Oysiyl commited on
Commit
cbfd371
·
verified ·
1 Parent(s): e2df7c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -30
app.py CHANGED
@@ -9,6 +9,12 @@ model = AutoModelForCausalLM.from_pretrained("unsloth/gemma-3-1b-it")
9
  model.load_adapter("Oysiyl/gemma-3-1B-GRPO")
10
  tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
11
 
 
 
 
 
 
 
12
 
13
  def process_history(history):
14
  """Process chat history into the format expected by the model."""
@@ -27,7 +33,7 @@ def process_new_user_message(message):
27
 
28
 
29
  def respond(
30
- message,
31
  history: list[tuple[str, str]],
32
  system_message,
33
  max_tokens,
@@ -37,27 +43,26 @@ def respond(
37
  # Format messages according to Gemma's expected chat format
38
  messages = []
39
  if system_message:
40
- messages.append({"role": "system", "content": [{"type": "text", "text": system_message}]})
41
 
42
  # Process the conversation history
43
  if history:
44
  messages.extend(process_history(history))
45
 
46
  # Add the new user message
47
- messages.append({"role": "user", "content": process_new_user_message(message)})
48
 
49
  # Apply chat template
50
- inputs = tokenizer.apply_chat_template(
51
  messages,
52
  add_generation_prompt=True,
53
- tokenize=True,
54
- return_dict=True,
55
- return_tensors="pt"
56
  )
57
-
58
  if torch.cuda.is_available():
59
  inputs = inputs.to("cuda")
60
- model.to("cuda")
 
61
 
62
  # Set up the streamer
63
  streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False)
@@ -70,31 +75,50 @@ def respond(
70
  temperature=temperature,
71
  top_p=top_p,
72
  top_k=64, # Recommended Gemma-3 setting
73
- do_sample=True,
74
  )
75
 
76
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
77
  thread.start()
78
 
79
- # Stream the output, add "Thinking" at the beginning
80
- output = "Thinking: \n"
81
  for token in streamer:
82
  output += token
83
- # Check if "<SOLUTION>" token is in the output and format everything after it as bold
84
- if "<SOLUTION>" in output:
85
- solution_start = output.find("<SOLUTION>") + len("<SOLUTION>")
86
- solution_end = output.find("</SOLUTION>")
87
- if solution_end > solution_start:
88
- formatted_output = (
89
- output[:solution_start] +
90
- "Final answer: **" + output[solution_start:solution_end] + "**" +
91
- output[solution_end:]
92
- )
93
- yield formatted_output
 
 
 
94
  else:
95
- yield output
96
- else:
97
- yield output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
 
100
  """
@@ -118,10 +142,11 @@ demo = gr.ChatInterface(
118
  ),
119
  ],
120
  examples=[
121
- ["Apple sold 100 iPhones at their New York store today for an average cost of $1000. They also sold 20 iPads for an average cost of $900 and 80 Apple TVs for an average cost of $200. What was the average cost across all products sold today?"],
122
- ["Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?"],
123
- ["Mel is three years younger than Katherine. When Katherine is two dozen years old, how old will Mel be in years?"],
124
- ["What is the sqrt of 101?"],
 
125
  ],
126
  cache_examples=False,
127
  chatbot=gr.Chatbot(
 
9
  model.load_adapter("Oysiyl/gemma-3-1B-GRPO")
10
  tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
11
 
12
+ if torch.cuda.is_available():
13
+ model.to("cuda")
14
+
15
+ if torch.backends.mps.is_available():
16
+ model.to("mps")
17
+
18
 
19
  def process_history(history):
20
  """Process chat history into the format expected by the model."""
 
33
 
34
 
35
  def respond(
36
+ user_message,
37
  history: list[tuple[str, str]],
38
  system_message,
39
  max_tokens,
 
43
  # Format messages according to Gemma's expected chat format
44
  messages = []
45
  if system_message:
46
+ messages.append({"role": "system", "content": system_message})
47
 
48
  # Process the conversation history
49
  if history:
50
  messages.extend(process_history(history))
51
 
52
  # Add the new user message
53
+ messages.append({"role": "user", "content": user_message})
54
 
55
  # Apply chat template
56
+ prompt = tokenizer.apply_chat_template(
57
  messages,
58
  add_generation_prompt=True,
59
+ tokenize=False,
 
 
60
  )
61
+ inputs = tokenizer(prompt, return_tensors="pt")
62
  if torch.cuda.is_available():
63
  inputs = inputs.to("cuda")
64
+ elif torch.backends.mps.is_available():
65
+ inputs = inputs.to("mps")
66
 
67
  # Set up the streamer
68
  streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False)
 
75
  temperature=temperature,
76
  top_p=top_p,
77
  top_k=64, # Recommended Gemma-3 setting
78
+ # do_sample=True,
79
  )
80
 
81
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
82
  thread.start()
83
 
84
+ output = ""
 
85
  for token in streamer:
86
  output += token
87
+
88
+ # Tags
89
+ start_tag = "<start_working_out>"
90
+ sol_start = "<SOLUTION>"
91
+
92
+ thinking = ""
93
+ final_answer = ""
94
+
95
+ # Extract "Thinking" section: everything after <start_working_out>
96
+ if start_tag in output:
97
+ start_idx = output.find(start_tag) + len(start_tag)
98
+ # If <SOLUTION> is also present, stop "Thinking" at <SOLUTION>
99
+ if sol_start in output:
100
+ end_idx = output.find(sol_start)
101
  else:
102
+ end_idx = len(output)
103
+ thinking = output[start_idx:end_idx].strip()
104
+
105
+ # Extract "Final answer" section: everything after <SOLUTION>
106
+ if sol_start in output:
107
+ sol_start_idx = output.find(sol_start) + len(sol_start)
108
+ final_answer = output[sol_start_idx:].strip()
109
+
110
+ # Build formatted output
111
+ formatted_output = ""
112
+ if thinking:
113
+ formatted_output += "### Thinking:\n" + thinking + "\n"
114
+ if final_answer:
115
+ formatted_output += "\n### Final answer:\n**" + final_answer + "**"
116
+
117
+ # If nothing found yet, just show the raw output (for streaming effect)
118
+ if not thinking and not final_answer:
119
+ formatted_output = output
120
+
121
+ yield formatted_output
122
 
123
 
124
  """
 
142
  ),
143
  ],
144
  examples=[
145
+ ["Toulouse has twice as many sheep as Charleston. Charleston has 4 times as many sheep as Seattle. How many sheep do Toulouse, Charleston, and Seattle have together if Seattle has 20 sheep?"],
146
+ ["A football team played 22 games. They won 8 more than they lost. How many did they win?"],
147
+ ["Jim spends 2 hours watching TV and then decides to go to bed and reads for half as long. He does this 3 times a week. How many hours does he spend on TV and reading in 4 weeks?"],
148
+ ["Darrell and Allen's ages are in the ratio of 7:11. If their total age now is 162, calculate Allen's age 10 years from now."],
149
+ ["In a neighborhood, the number of rabbits pets is twelve less than the combined number of pet dogs and cats. If there are two cats for every dog, and the number of dogs is 60, how many pets in total are in the neighborhood?"],
150
  ],
151
  cache_examples=False,
152
  chatbot=gr.Chatbot(