Update app.py
app.py (CHANGED)
Before:

@@ -1,10 +1,27 @@
 import gradio as gr
-from ...
 
-...
-""...
-...
 
 
 def respond(
@@ -15,29 +32,100 @@ def respond(
     temperature,
     top_p,
 ):
-    messages ...
-...
-    response = ""
-...
-    for message in client.chat_completion(
         messages,
-...
         temperature=temperature,
         top_p=top_p,
-...
 
 
 """
@@ -46,9 +134,12 @@ For information on how to customize the ChatInterface, peruse the gradio docs: h
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(...
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0...
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
@@ -57,8 +148,20 @@ demo = gr.ChatInterface(
             label="Top-p (nucleus sampling)",
         ),
     ],
 )
 
 
 if __name__ == "__main__":
-    demo.launch()
After:

@@ -1,10 +1,27 @@
 import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+import threading
+import torch
 
+# Load model directly
+model = AutoModelForCausalLM.from_pretrained("Oysiyl/gemma-3-1B-GRPO")
+tokenizer = AutoTokenizer.from_pretrained("Oysiyl/gemma-3-1B-GRPO")
+
+
+def process_history(history):
+    """Process chat history into the format expected by the model."""
+    processed_history = []
+    for user_msg, assistant_msg in history:
+        if user_msg:
+            processed_history.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
+        if assistant_msg:
+            processed_history.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
+    return processed_history
+
+
+def process_new_user_message(message):
+    """Process a new user message into the format expected by the model."""
+    return [{"type": "text", "text": message}]
 
 
 def respond(
@@ -15,29 +32,100 @@
     temperature,
     top_p,
 ):
+    # Format messages according to Gemma's expected chat format
+    messages = []
+    if system_message:
+        messages.append({"role": "system", "content": [{"type": "text", "text": system_message}]})
+
+    messages.extend(process_history(history))
+    messages.append({"role": "user", "content": process_new_user_message(message)})
+
+    # Apply chat template
+    inputs = tokenizer.apply_chat_template(
         messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt"
+    )
+
+    if torch.cuda.is_available():
+        inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
+        model.to("cuda")
+
+    # Set up the streamer
+    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False)
+
+    # Run generation in a separate thread
+    generate_kwargs = dict(
+        **inputs,
+        streamer=streamer,
+        max_new_tokens=max_tokens,
         temperature=temperature,
         top_p=top_p,
+        top_k=64, # Recommended Gemma-3 setting
+        do_sample=True,
+    )
+
+    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    # Stream the output, add "Thinking" at the beginning
+    output = "Thinking: \n"
+    for token in streamer:
+        output += token
+        # Check for various solution patterns
+        if "<SOLUTION>" in output:
+            # Original solution pattern
+            solution_start = output.find("<SOLUTION>") + len("<SOLUTION>")
+            solution_end = output.find("</SOLUTION>")
+            if solution_end > solution_start:
+                formatted_output = (
+                    output[:solution_start] +
+                    "Final answer: **" + output[solution_start:solution_end] + "**" +
+                    output[solution_end:]
+                )
+                yield formatted_output
+            else:
+                # Handle case where closing tag is missing
+                formatted_output = (
+                    output[:solution_start] +
+                    "Final answer: **" + output[solution_start:] + "**"
+                )
+                yield formatted_output
+        # Check if end_working_out tag is present
+        elif "</end_working_out>" in output:
+            solution_start = output.find("</end_working_out>") + len("</end_working_out>")
+            formatted_output = (
+                output[:solution_start] +
+                "\nFinal answer: **" + output[solution_start:] + "**"
+            )
+            yield formatted_output
+        # Check if start_working_out is present but end_working_out is missing
+        elif "<start_working_out>" in output:
+            # Check if there's a SOLUTION tag after start_working_out
+            working_start = output.find("<start_working_out>")
+            if "<SOLUTION>" in output[working_start:]:
+                solution_start = output.find("<SOLUTION>", working_start) + len("<SOLUTION>")
+                solution_end = output.find("</SOLUTION>", solution_start)
+                if solution_end > solution_start:
+                    formatted_output = (
+                        output[:solution_start] +
+                        "Final answer: **" + output[solution_start:solution_end] + "**" +
+                        output[solution_end:]
+                    )
+                    yield formatted_output
+                else:
+                    formatted_output = (
+                        output[:solution_start] +
+                        "Final answer: **" + output[solution_start:] + "**"
+                    )
+                    yield formatted_output
+            else:
+                # No clear solution identified
+                yield output
+        else:
+            yield output
 
 
 """
@@ -46,9 +134,12 @@
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
+        gr.Textbox(
+            value="You are given a problem.\nThink about the problem and provide your working out.\nPlace it between <start_working_out> and <end_working_out>.\nThen, provide your solution between <SOLUTION></SOLUTION>",
+            label="System message"
+        ),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
@@ -57,8 +148,20 @@
             label="Top-p (nucleus sampling)",
         ),
     ],
+    examples=[
+        ["Apple sold 100 iPhones at their New York store today for an average cost of $1000. They also sold 20 iPads for an average cost of $900 and 80 Apple TVs for an average cost of $200. What was the average cost across all products sold today?"],
+        ["Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?"],
+        ["Mel is three years younger than Katherine. When Katherine is two dozen years old, how old will Mel be in years?"],
+        ["What is the sqrt of 101?"],
+    ],
+    chatbot=gr.Chatbot(
+        latex_delimiters=[
+            {"left": "$$", "right": "$$", "display": True},
+            {"left": "$", "right": "$", "display": False}
+        ],
+    ),
 )
 
 
 if __name__ == "__main__":
+    demo.launch()
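For reference, the core of this change is the move from a hosted client.chat_completion(...) stream to local generation with transformers: model.generate runs in a background thread while a TextIteratorStreamer hands decoded text back to the Gradio generator. The sketch below shows that pattern on its own, outside the Space. It assumes the Oysiyl/gemma-3-1B-GRPO checkpoint can be loaded on the local machine; the prompt is taken from the Space's examples, and max_new_tokens=128 and skip_special_tokens=True are illustrative choices rather than the app's settings.

# Minimal standalone sketch of the threaded-streaming pattern used in the new app.py.
# Assumes the checkpoint fits in local memory; generation settings are illustrative.
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Oysiyl/gemma-3-1B-GRPO"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Same message structure the app builds in respond().
messages = [
    {"role": "user", "content": [{"type": "text", "text": "What is the sqrt of 101?"}]},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)

# The streamer is an iterator: generate() pushes tokens into it from a
# background thread while the main thread consumes decoded text chunks.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=128, do_sample=True, top_k=64),
)
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
print()

Calling model.generate directly on the main thread would block until the whole completion is finished, which is why the app moves generation to a thread and yields partial text from the streamer as it arrives.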