FractalAIR commited on
Commit
5889e5d
·
verified ·
1 Parent(s): feb219b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -13
app.py CHANGED
@@ -4,7 +4,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStream
4
  import torch
5
  from threading import Thread
6
  import re
 
7
 
 
8
  phi4_model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
9
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
10
 
@@ -16,7 +18,15 @@ def format_math(text):
16
  text = text.replace(r"\(", "$").replace(r"\)", "$")
17
  return text
18
 
 
 
 
 
 
 
 
19
  @spaces.GPU(duration=60)
 
20
  def generate_response(user_message, max_tokens, temperature, top_p, history_state):
21
  if not user_message.strip():
22
  return history_state, history_state
@@ -54,7 +64,7 @@ def generate_response(user_message, max_tokens, temperature, top_p, history_stat
54
 
55
  try:
56
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
57
- thread.start()
58
  except Exception:
59
  yield history_state + [{"role": "user", "content": user_message}, {"role": "assistant", "content": "⚠️ Generation failed."}], history_state
60
  return
@@ -78,25 +88,35 @@ def generate_response(user_message, max_tokens, temperature, top_p, history_stat
78
 
79
  yield new_history, new_history
80
 
 
81
  example_messages = {
82
  "JEE Main 2025 Combinatorics": "From all the English alphabets, five letters are chosen and are arranged in alphabetical order. The total number of ways, in which the middle letter is 'M', is?",
83
  "JEE Main 2025 Co-ordinate Geometry": "A circle \\(C\\) of radius 2 lies in the second quadrant and touches both the coordinate axes. Let \\(r\\) be the radius of a circle that has centre at the point \\((2, 5)\\) and intersects the circle \\(C\\) at exactly two points. If the set of all possible values of \\(r\\) is the interval \\((\\alpha, \\beta)\\), then \\(3\\beta - 2\\alpha\\) is?",
84
- "JEE Main 2025 Prob-Stats": "A coin is tossed three times. Let \(X\) denote the number of times a tail follows a head. If \\(\\mu\\) and \\(\\sigma^2\\) denote the mean and variance of \\(X\\), then the value of \\(64(\\mu + \\sigma^2)\\) is?"
 
85
  }
86
 
87
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
88
  gr.Markdown(
89
  """
90
  # Ramanujan Ganit R1 14B V1 Chatbot
91
-
92
  Welcome to the Ramanujan Ganit R1 14B V1 Chatbot, developed by Fractal AI Research!
93
-
94
  Our model excels at reasoning tasks in mathematics and science.
95
-
96
  Try the example problems below from JEE Main 2025 or type in your own problems to see how our model breaks down complex reasoning problems.
97
  """
98
  )
99
 
 
 
 
 
 
 
 
 
100
  history_state = gr.State([])
101
 
102
  with gr.Row():
@@ -125,24 +145,51 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
125
 
126
  with gr.Column(scale=4):
127
  chatbot = gr.Chatbot(label="Chat", type="messages")
128
- with gr.Column():
129
  user_input = gr.Textbox(
130
  label="User Input",
131
  placeholder="Type your question here...",
132
- lines=2
133
  )
134
- with gr.Row():
135
- submit_button = gr.Button("Send", variant="primary")
136
- clear_button = gr.Button("Clear")
137
  gr.Markdown("**Try these examples:**")
138
  with gr.Row():
139
  example1_button = gr.Button("JEE Main 2025 Combinatorics")
140
  example2_button = gr.Button("JEE Main 2025 Co-ordinate Geometry")
141
  example3_button = gr.Button("JEE Main 2025 Prob-Stats")
142
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  submit_button.click(
144
- fn=generate_response,
145
- inputs=[user_input, max_tokens_slider, temperature_slider, top_p_slider, history_state],
146
  outputs=[chatbot, history_state]
147
  ).then(
148
  fn=lambda: gr.update(value=""),
@@ -156,6 +203,18 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
156
  outputs=[chatbot, history_state]
157
  )
158
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  example1_button.click(
160
  fn=lambda: gr.update(value=example_messages["JEE Main 2025 Combinatorics"]),
161
  inputs=None,
@@ -171,5 +230,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
171
  inputs=None,
172
  outputs=user_input
173
  )
 
 
 
 
 
174
 
175
  demo.launch(share=True, ssr_mode=False)
 
4
  import torch
5
  from threading import Thread
6
  import re
7
+ import uuid
8
 
9
+ # Load model and tokenizer
10
  phi4_model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
11
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
 
 
18
  text = text.replace(r"\(", "$").replace(r"\)", "$")
19
  return text
20
 
21
+ # Global dictionary to store all conversations
22
+ conversations = {}
23
+
24
+ # Function to generate a unique conversation ID
25
+ def generate_conversation_id():
26
+ return str(uuid.uuid4())[:8]
27
+
28
  @spaces.GPU(duration=60)
29
+ # Function to generate response
30
  def generate_response(user_message, max_tokens, temperature, top_p, history_state):
31
  if not user_message.strip():
32
  return history_state, history_state
 
64
 
65
  try:
66
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
67
+ thread.start()
68
  except Exception:
69
  yield history_state + [{"role": "user", "content": user_message}, {"role": "assistant", "content": "⚠️ Generation failed."}], history_state
70
  return
 
88
 
89
  yield new_history, new_history
90
 
91
+ # Example messages
92
  example_messages = {
93
  "JEE Main 2025 Combinatorics": "From all the English alphabets, five letters are chosen and are arranged in alphabetical order. The total number of ways, in which the middle letter is 'M', is?",
94
  "JEE Main 2025 Co-ordinate Geometry": "A circle \\(C\\) of radius 2 lies in the second quadrant and touches both the coordinate axes. Let \\(r\\) be the radius of a circle that has centre at the point \\((2, 5)\\) and intersects the circle \\(C\\) at exactly two points. If the set of all possible values of \\(r\\) is the interval \\((\\alpha, \\beta)\\), then \\(3\\beta - 2\\alpha\\) is?",
95
+ "JEE Main 2025 Prob-Stats": "A coin is tossed three times. Let \(X\) denote the number of times a tail follows a head. If \\(\\mu\\) and \\(\\sigma^2\\) denote the mean and variance of \\(X\\), then the value of \\(64(\\mu + \\sigma^2)\\) is?",
96
+ "JEE Main 2025 Physics": "A massless spring gets elongated by amount x_1 under a tension of 5 N . Its elongation is x_2 under the tension of 7 N . For the elongation of 5x_1 - 2x_2 , the tension in the spring will be?"
97
  }
98
 
99
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
100
  gr.Markdown(
101
  """
102
  # Ramanujan Ganit R1 14B V1 Chatbot
103
+
104
  Welcome to the Ramanujan Ganit R1 14B V1 Chatbot, developed by Fractal AI Research!
105
+
106
  Our model excels at reasoning tasks in mathematics and science.
107
+
108
  Try the example problems below from JEE Main 2025 or type in your own problems to see how our model breaks down complex reasoning problems.
109
  """
110
  )
111
 
112
+ # Sidebar for conversation history
113
+ with gr.Sidebar():
114
+ gr.Markdown("## Conversations")
115
+ conversation_selector = gr.Radio(choices=[], label="Select Conversation", interactive=True)
116
+ new_convo_button = gr.Button("New Conversation")
117
+
118
+ # State to store current conversation ID and history
119
+ current_convo_id = gr.State(generate_conversation_id())
120
  history_state = gr.State([])
121
 
122
  with gr.Row():
 
145
 
146
  with gr.Column(scale=4):
147
  chatbot = gr.Chatbot(label="Chat", type="messages")
148
+ with gr.Row():
149
  user_input = gr.Textbox(
150
  label="User Input",
151
  placeholder="Type your question here...",
152
+ scale=8 # This makes the textbox take up the entire width
153
  )
154
+ with gr.Column():
155
+ submit_button = gr.Button("Send", variant="primary", scale=1)
156
+ clear_button = gr.Button("Clear", scale=1)
157
  gr.Markdown("**Try these examples:**")
158
  with gr.Row():
159
  example1_button = gr.Button("JEE Main 2025 Combinatorics")
160
  example2_button = gr.Button("JEE Main 2025 Co-ordinate Geometry")
161
  example3_button = gr.Button("JEE Main 2025 Prob-Stats")
162
+ example4_button = gr.Button("JEE Main 2025 Physics")
163
+
164
+ # Function to update conversation list
165
+ def update_conversation_list():
166
+ return list(conversations.keys())
167
+
168
+ # Function to start a new conversation
169
+ def start_new_conversation():
170
+ new_id = generate_conversation_id()
171
+ conversations[new_id] = []
172
+ return new_id, [], gr.update(choices=update_conversation_list(), value=new_id)
173
+
174
+ # Function to load selected conversation
175
+ def load_conversation(selected_id):
176
+ if selected_id in conversations:
177
+ return selected_id, conversations[selected_id], conversations[selected_id]
178
+ else:
179
+ return current_convo_id.value, history_state.value, history_state.value
180
+
181
+ # Send message
182
+ def send_message(user_message, max_tokens, temperature, top_p, convo_id, history):
183
+ if convo_id not in conversations:
184
+ conversations[convo_id] = history
185
+ for updated_history, new_history in generate_response(user_message, max_tokens, temperature, top_p, history):
186
+ conversations[convo_id] = new_history
187
+ yield updated_history, new_history
188
+
189
+ # Button and event handlers
190
  submit_button.click(
191
+ fn=send_message,
192
+ inputs=[user_input, max_tokens_slider, temperature_slider, top_p_slider, current_convo_id, history_state],
193
  outputs=[chatbot, history_state]
194
  ).then(
195
  fn=lambda: gr.update(value=""),
 
203
  outputs=[chatbot, history_state]
204
  )
205
 
206
+ new_convo_button.click(
207
+ fn=start_new_conversation,
208
+ inputs=None,
209
+ outputs=[current_convo_id, history_state, conversation_selector]
210
+ )
211
+
212
+ conversation_selector.change(
213
+ fn=load_conversation,
214
+ inputs=conversation_selector,
215
+ outputs=[current_convo_id, history_state, chatbot]
216
+ )
217
+
218
  example1_button.click(
219
  fn=lambda: gr.update(value=example_messages["JEE Main 2025 Combinatorics"]),
220
  inputs=None,
 
230
  inputs=None,
231
  outputs=user_input
232
  )
233
+ example4_button.click(
234
+ fn=lambda: gr.update(value=example_messages["JEE Main 2025 Physics"]),
235
+ inputs=None,
236
+ outputs=user_input
237
+ )
238
 
239
  demo.launch(share=True, ssr_mode=False)