Oysiyl committed (verified)
Commit cca068b · 1 Parent(s): add68a2

Update app.py

Files changed (1)
  1. app.py +12 -31
app.py CHANGED
@@ -16,22 +16,6 @@ if torch.backends.mps.is_available():
     model.to("mps")


-def process_history(history):
-    """Process chat history into the format expected by the model."""
-    processed_history = []
-    for user_msg, assistant_msg in history:
-        # Always add user message first, even if empty
-        processed_history.append({"role": "user", "content": [{"type": "text", "text": user_msg or ""}]})
-        # Always add assistant message, even if empty
-        processed_history.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg or ""}]})
-    return processed_history
-
-
-def process_new_user_message(message):
-    """Process a new user message into the format expected by the model."""
-    return [{"type": "text", "text": message}]
-
-
 def respond(
     user_message,
     history: list[tuple[str, str]],
@@ -45,9 +29,10 @@ def respond(
     if system_message:
         messages.append({"role": "system", "content": system_message})

-    # Process the conversation history
-    if history:
-        messages.extend(process_history(history))
+    # # Process the conversation history
+    # No need for history functionality for this reasoning model
+    # if history:
+    #     messages.extend(process_history(history))

     # Add the new user message
     messages.append({"role": "user", "content": user_message})
@@ -75,7 +60,6 @@ def respond(
         temperature=temperature,
         top_p=top_p,
         top_k=64,  # Recommended Gemma-3 setting
-        # do_sample=True,
     )

     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
@@ -88,7 +72,6 @@ def respond(
     # Tags
     start_tag = "<start_working_out>"
     sol_start = "<SOLUTION>"
-    sol_end = "</SOLUTION>"

     thinking = ""
     final_answer = ""
@@ -106,12 +89,7 @@ def respond(
         # Extract "Final answer" section: everything after <SOLUTION>
         if sol_start in output:
             sol_start_idx = output.find(sol_start) + len(sol_start)
-            # If </SOLUTION> is present, stop at it
-            if sol_end in output:
-                sol_end_idx = output.find(sol_end)
-                final_answer = output[sol_start_idx:sol_end_idx].strip()
-            else:
-                final_answer = output[sol_start_idx:].strip()
+            final_answer = output[sol_start_idx:].strip()

         # Build formatted output
         formatted_output = ""
@@ -126,10 +104,6 @@ def respond(

         yield formatted_output

-        # If </SOLUTION> is found, end the response
-        if sol_end in output:
-            break
-

 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
@@ -165,6 +139,13 @@ demo = gr.ChatInterface(
             {"left": "$", "right": "$", "display": False}
         ],
     ),
+    title="Demo: Finetuned Gemma 3 1B (Oysiyl/gemma-3-1B-GRPO) on GSM8K with GRPO",
+    description=(
+        "This is a demo of a finetuned Gemma 3 1B model ([Oysiyl/gemma-3-1B-GRPO](https://huggingface.co/Oysiyl/gemma-3-1B-GRPO)) "
+        "on the [openai/gsm8k](https://huggingface.co/datasets/openai/gsm8k) dataset using the GRPO technique. "
+        "Finetuning and reasoning approach inspired by [Unsloth's notebook](https://docs.unsloth.ai/basics/reasoning-grpo-and-rl/tutorial-train-your-own-reasoning-model-with-grpo). "
+        "This demo does not support conversation history, as the GSM8K dataset consists of single-turn questions."
+    ),
 )

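For reference, a minimal sketch (not part of the commit) of how the tag parsing behaves after this change: with the </SOLUTION> handling and the early break removed, everything emitted after <SOLUTION> is streamed as the final answer until generation ends. The parse_output helper below is a hypothetical standalone approximation of the logic inside respond(), assuming the model emits <start_working_out> ... <SOLUTION> ... as in the GRPO reasoning format.

# Hypothetical, self-contained sketch of the post-commit parsing logic.
# It mirrors the simplified extraction in respond(): no </SOLUTION> lookup,
# no early break; the accumulated stream is parsed until it is exhausted.

def parse_output(output: str) -> tuple[str, str]:
    """Split accumulated model output into (thinking, final_answer)."""
    start_tag = "<start_working_out>"
    sol_start = "<SOLUTION>"

    thinking = ""
    final_answer = ""

    if start_tag in output:
        start_idx = output.find(start_tag) + len(start_tag)
        # Thinking runs until <SOLUTION> appears (or to the end while streaming).
        end_idx = output.find(sol_start) if sol_start in output else len(output)
        thinking = output[start_idx:end_idx].strip()

    if sol_start in output:
        sol_start_idx = output.find(sol_start) + len(sol_start)
        # Everything after <SOLUTION> is the final answer, matching this commit.
        final_answer = output[sol_start_idx:].strip()

    return thinking, final_answer


# Example: simulate an accumulating token stream, as the app's streamer loop does.
chunks = ["<start_working_out> 2 + 2", " equals 4 <SOLUTION>", " 4"]
output = ""
for chunk in chunks:
    output += chunk
    thinking, final_answer = parse_output(output)
print(thinking)      # -> 2 + 2 equals 4
print(final_answer)  # -> 4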