Oysiyl committed
Commit d8ef477 · verified · 1 Parent(s): 7173226

Update app.py

Files changed (1):
  1. app.py +131 -28
app.py CHANGED
@@ -1,10 +1,27 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+import threading
+import torch
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+# Load model directly
+model = AutoModelForCausalLM.from_pretrained("Oysiyl/gemma-3-1B-GRPO")
+tokenizer = AutoTokenizer.from_pretrained("Oysiyl/gemma-3-1B-GRPO")
+
+
+def process_history(history):
+    """Process chat history into the format expected by the model."""
+    processed_history = []
+    for user_msg, assistant_msg in history:
+        if user_msg:
+            processed_history.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
+        if assistant_msg:
+            processed_history.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
+    return processed_history
+
+
+def process_new_user_message(message):
+    """Process a new user message into the format expected by the model."""
+    return [{"type": "text", "text": message}]
 
 
 def respond(
@@ -15,29 +32,100 @@ def respond(
     temperature,
     top_p,
 ):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
+    # Format messages according to Gemma's expected chat format
+    messages = []
+    if system_message:
+        messages.append({"role": "system", "content": [{"type": "text", "text": system_message}]})
+
+    messages.extend(process_history(history))
+    messages.append({"role": "user", "content": process_new_user_message(message)})
+
+    # Apply chat template
+    inputs = tokenizer.apply_chat_template(
         messages,
-        max_tokens=max_tokens,
-        stream=True,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt"
+    )
+
+    if torch.cuda.is_available():
+        inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
+        model.to("cuda")
+
+    # Set up the streamer
+    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False)
+
+    # Run generation in a separate thread
+    generate_kwargs = dict(
+        **inputs,
+        streamer=streamer,
+        max_new_tokens=max_tokens,
         temperature=temperature,
         top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
+        top_k=64,  # Recommended Gemma-3 setting
+        do_sample=True,
+    )
+
+    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    # Stream the output, add "Thinking" at the beginning
+    output = "Thinking: \n"
+    for token in streamer:
+        output += token
+        # Check for various solution patterns
+        if "<SOLUTION>" in output:
+            # Original solution pattern
+            solution_start = output.find("<SOLUTION>") + len("<SOLUTION>")
+            solution_end = output.find("</SOLUTION>")
+            if solution_end > solution_start:
+                formatted_output = (
+                    output[:solution_start] +
+                    "Final answer: **" + output[solution_start:solution_end] + "**" +
+                    output[solution_end:]
+                )
+                yield formatted_output
+            else:
+                # Handle case where closing tag is missing
+                formatted_output = (
+                    output[:solution_start] +
+                    "Final answer: **" + output[solution_start:] + "**"
+                )
+                yield formatted_output
+        # Check if end_working_out tag is present
+        elif "</end_working_out>" in output:
+            solution_start = output.find("</end_working_out>") + len("</end_working_out>")
+            formatted_output = (
+                output[:solution_start] +
+                "\nFinal answer: **" + output[solution_start:] + "**"
+            )
+            yield formatted_output
+        # Check if start_working_out is present but end_working_out is missing
+        elif "<start_working_out>" in output:
+            # Check if there's a SOLUTION tag after start_working_out
+            working_start = output.find("<start_working_out>")
+            if "<SOLUTION>" in output[working_start:]:
+                solution_start = output.find("<SOLUTION>", working_start) + len("<SOLUTION>")
+                solution_end = output.find("</SOLUTION>", solution_start)
+                if solution_end > solution_start:
+                    formatted_output = (
+                        output[:solution_start] +
+                        "Final answer: **" + output[solution_start:solution_end] + "**" +
+                        output[solution_end:]
+                    )
+                    yield formatted_output
+                else:
+                    formatted_output = (
+                        output[:solution_start] +
+                        "Final answer: **" + output[solution_start:] + "**"
+                    )
+                    yield formatted_output
+            else:
+                # No clear solution identified
+                yield output
+        else:
+            yield output
 
 
 """
@@ -46,9 +134,12 @@ For information on how to customize the ChatInterface, peruse the gradio docs: h
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+        gr.Textbox(
+            value="You are given a problem.\nThink about the problem and provide your working out.\nPlace it between <start_working_out> and <end_working_out>.\nThen, provide your solution between <SOLUTION></SOLUTION>",
+            label="System message"
+        ),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
@@ -57,8 +148,20 @@ demo = gr.ChatInterface(
             label="Top-p (nucleus sampling)",
         ),
     ],
+    examples=[
+        ["Apple sold 100 iPhones at their New York store today for an average cost of $1000. They also sold 20 iPads for an average cost of $900 and 80 Apple TVs for an average cost of $200. What was the average cost across all products sold today?"],
+        ["Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?"],
+        ["Mel is three years younger than Katherine. When Katherine is two dozen years old, how old will Mel be in years?"],
+        ["What is the sqrt of 101?"],
+    ],
+    chatbot=gr.Chatbot(
+        latex_delimiters=[
+            {"left": "$$", "right": "$$", "display": True},
+            {"left": "$", "right": "$", "display": False}
+        ],
+    ),
 )
 
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
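
The rewritten respond() streams partial text by launching model.generate on a background thread and reading decoded chunks from a TextIteratorStreamer. Below is a minimal, self-contained sketch of that pattern outside Gradio, assuming the same Oysiyl/gemma-3-1B-GRPO checkpoint; the prompt text and generation settings are illustrative assumptions, not values taken from the Space.

# Minimal sketch of the threaded streaming pattern used by the updated app.py.
# The prompt and generation settings below are illustrative assumptions.
import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Oysiyl/gemma-3-1B-GRPO"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
if torch.cuda.is_available():
    model.to("cuda")

# Same message layout that app.py passes to apply_chat_template.
messages = [{"role": "user", "content": [{"type": "text", "text": "What is 15% of 240?"}]}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

# skip_prompt=True makes the streamer yield only newly generated text.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=256, do_sample=True, top_k=64),
)
thread.start()

# Consume chunks as they arrive; in the Space this loop rebuilds the growing
# output string and yields it back to gr.ChatInterface for live updates.
output = ""
for chunk in streamer:
    output += chunk
    print(chunk, end="", flush=True)
thread.join()

Running generate on a worker thread is what lets the Gradio generator yield partial strings while decoding is still in progress, instead of blocking until the full completion is ready.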