Update app.py
app.py CHANGED
@@ -1,64 +1,175 @@
 import gradio as gr
-[lines 2-64 of the previous version removed; their content is not recoverable from this view]
+import spaces
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+import torch
+from threading import Thread
+import re
+
+phi4_model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+# Load the model and tokenizer once at import time.
+phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
+phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
+
+def format_math(text):
+    # Normalize LaTeX delimiters for Markdown rendering:
+    # [...] becomes $$...$$ (display math); \( and \) become $ (inline math).
+    # Note: this helper is not referenced elsewhere in this file.
+    text = re.sub(r"\[(.*?)\]", r"$$\1$$", text, flags=re.DOTALL)
+    text = text.replace(r"\(", "$").replace(r"\)", "$")
+    return text
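As a quick sanity check of the helper above, here is what `format_math` produces for each delimiter style (the sample sentence is made up for illustration):

```python
import re

def format_math(text):
    # Same helper as in app.py: [...] -> $$...$$, \( \) -> $...$
    text = re.sub(r"\[(.*?)\]", r"$$\1$$", text, flags=re.DOTALL)
    text = text.replace(r"\(", "$").replace(r"\)", "$")
    return text

print(format_math(r"Roots of [ x^2 - 1 = 0 ] are \(x = \pm 1\)"))
# -> Roots of $$ x^2 - 1 = 0 $$ are $x = \pm 1$
```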
+
+@spaces.GPU(duration=60)
+def generate_response(user_message, max_tokens, temperature, top_p, history_state):
+    if not user_message.strip():
+        # This is a generator, so yield (rather than return) the unchanged history.
+        yield history_state, history_state
+        return
+
+    model = phi4_model
+    tokenizer = phi4_tokenizer
+    start_tag = "<|im_start|>"
+    sep_tag = "<|im_sep|>"
+    end_tag = "<|im_end|>"
+
+    system_message = "Your role as an assistant..."
+    # Flatten the full conversation into a single Phi-4-style chat prompt.
+    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
+    for message in history_state:
+        if message["role"] == "user":
+            prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
+        elif message["role"] == "assistant" and message["content"]:
+            prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
+    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    generation_kwargs = {
+        "input_ids": inputs["input_ids"],
+        "attention_mask": inputs["attention_mask"],
+        "max_new_tokens": int(max_tokens),
+        "do_sample": True,
+        "temperature": temperature,
+        "top_k": 50,  # fixed default
+        "top_p": top_p,
+        "repetition_penalty": 1.0,  # fixed default
+        "pad_token_id": tokenizer.eos_token_id,
+        "streamer": streamer,
+    }
+
+    try:
+        # Run generation in a background thread so tokens can be streamed as they arrive.
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+    except Exception:
+        yield history_state + [{"role": "user", "content": user_message}, {"role": "assistant", "content": "⚠️ Generation failed."}], history_state
+        return
+
+    assistant_response = ""
+    new_history = history_state + [
+        {"role": "user", "content": user_message},
+        {"role": "assistant", "content": ""}
+    ]
+
+    try:
+        for new_token in streamer:
+            if "<|end" in new_token:
+                continue
+            cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
+            assistant_response += cleaned_token
+            new_history[-1]["content"] = assistant_response.strip()
+            yield new_history, new_history
+    except Exception:
+        # Swallow streaming errors and fall through to the final yield.
+        pass
+
+    yield new_history, new_history
+
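For reference, the loop above serializes the chat into one flat string; a minimal sketch of the prompt built for a single prior exchange (illustrative contents only, mirroring the construction in `generate_response`):

```python
# Illustrative only: shows the prompt shape generate_response assembles.
start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
prompt = f"{start_tag}system{sep_tag}You are a helpful assistant.{end_tag}"
for m in history:
    prompt += f"{start_tag}{m['role']}{sep_tag}{m['content']}{end_tag}"
prompt += f"{start_tag}user{sep_tag}And 3 + 3?{end_tag}{start_tag}assistant{sep_tag}"
# The model continues the string after the trailing assistant separator.
print(prompt)
```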
+example_messages = {
+    "Combinatorics": "From all the English alphabets, five letters are chosen and are arranged in alphabetical order. The total number of ways, in which the middle letter is 'M', is?",
+    "Co-ordinate Geometry": "A circle \\(C\\) of radius 2 lies in the second quadrant and touches both the coordinate axes. Let \\(r\\) be the radius of a circle that has centre at the point \\((2, 5)\\) and intersects the circle \\(C\\) at exactly two points. If the set of all possible values of \\(r\\) is the interval \\((\\alpha, \\beta)\\), then \\(3\\beta - 2\\alpha\\) is?",
+    "Prob-Stats": "A coin is tossed three times. Let \\(X\\) denote the number of times a tail follows a head. If \\(\\mu\\) and \\(\\sigma^2\\) denote the mean and variance of \\(X\\), then the value of \\(64(\\mu + \\sigma^2)\\) is?"
+}
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+# Ramanujan Ganit R1 14B V1 Chatbot
+
+Welcome to the Ramanujan Ganit R1 14B V1 Chatbot, developed by Fractal AI Research!
+
+Our model excels at reasoning tasks in mathematics and science.
+
+Try the example problems below from JEE Main 2025 or type in your own problems to see how our model breaks down complex reasoning problems.
+        """
+    )
+
+    history_state = gr.State([])
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("### Settings")
+            max_tokens_slider = gr.Slider(
+                minimum=6144,
+                maximum=32768,
+                step=1024,
+                value=16384,
+                label="Max Tokens"
+            )
+            with gr.Accordion("Advanced Settings", open=False):
+                temperature_slider = gr.Slider(
+                    minimum=0.1,
+                    maximum=2.0,
+                    value=0.6,
+                    label="Temperature"
+                )
+                top_p_slider = gr.Slider(
+                    minimum=0.1,
+                    maximum=1.0,
+                    value=0.95,
+                    label="Top-p"
+                )
+
+        with gr.Column(scale=4):
+            chatbot = gr.Chatbot(label="Chat", type="messages")
+            with gr.Row():
+                user_input = gr.Textbox(
+                    label="User Input",
+                    placeholder="Type your question here...",
+                    scale=3
+                )
+                submit_button = gr.Button("Send", variant="primary", scale=1)
+                # stop_button = gr.Button("Stop", variant="stop", scale=1, interactive=True)
+                clear_button = gr.Button("Clear", scale=1)
+            gr.Markdown("**Try these examples:**")
+            with gr.Row():
+                example1_button = gr.Button("Combinatorics")
+                example2_button = gr.Button("Co-ordinate Geometry")
+                example3_button = gr.Button("Prob-Stats")
+
+    submit_button.click(
+        fn=generate_response,
+        inputs=[user_input, max_tokens_slider, temperature_slider, top_p_slider, history_state],
+        outputs=[chatbot, history_state]
+    ).then(
+        fn=lambda: gr.update(value=""),
+        inputs=None,
+        outputs=user_input
+    )
+
+    clear_button.click(
+        fn=lambda: ([], []),
+        inputs=None,
+        outputs=[chatbot, history_state]
+    )
+
+    example1_button.click(
+        fn=lambda: gr.update(value=example_messages["Combinatorics"]),
+        inputs=None,
+        outputs=user_input
+    )
+    example2_button.click(
+        fn=lambda: gr.update(value=example_messages["Co-ordinate Geometry"]),
+        inputs=None,
+        outputs=user_input
+    )
+    example3_button.click(
+        fn=lambda: gr.update(value=example_messages["Prob-Stats"]),
+        inputs=None,
+        outputs=user_input
+    )
+
+demo.launch(ssr_mode=False)
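One deployment note: `import spaces` and the `@spaces.GPU` decorator come from the Hugging Face Spaces ZeroGPU package and are only meaningful inside a Space. To try this file locally, a guarded import like the following is a common workaround; this stub is an assumption for local runs, not part of the committed file:

```python
# Hypothetical fallback for running app.py outside Hugging Face Spaces,
# where the `spaces` package may not be installed.
try:
    import spaces
except ImportError:
    class spaces:  # minimal stand-in namespace
        @staticmethod
        def GPU(duration=None):
            # Return a no-op decorator so @spaces.GPU(duration=60) still works.
            def decorator(fn):
                return fn
            return decorator
```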