|
import gradio as gr |
|
import spaces |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
import torch |
|
from threading import Thread |
|
import re |
|
import uuid |
|
|
|
|
|
# Hugging Face Hub id of the checkpoint this demo serves.
our_model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"

# Device that tokenized prompts are moved to before generation: the first CUDA
# GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the model and tokenizer once, at import time, so every request reuses
# them. Weight placement and dtype are both left to transformers ("auto").
our_model = AutoModelForCausalLM.from_pretrained(
    our_model_path,
    torch_dtype="auto",
    device_map="auto",
)
our_tokenizer = AutoTokenizer.from_pretrained(our_model_path)
|
|
|
def format_math(text):
    r"""Rewrite LaTeX math delimiters as dollar-sign delimiters.

    Display math ``\[ ... \]`` (which may span several lines) becomes
    ``$$ ... $$`` and inline math ``\( ... \)`` becomes ``$ ... $``.
    Ordinary square brackets such as ``a[0]`` are left unchanged.

    Args:
        text: String that may contain LaTeX math delimiters.

    Returns:
        ``text`` with the delimiters replaced.
    """
    # Require the backslash before each bracket: a bare r"\[(.*?)\]" matches
    # plain "[...]" (list indices, citations) and mangles real "\[...\]".
    text = re.sub(r"\\\[(.*?)\\\]", r"$$\1$$", text, flags=re.DOTALL)
    return text.replace(r"\(", "$").replace(r"\)", "$")
|
|
|
|
|
# Module-level store of saved conversations:
#   conversation id -> {"title": str, "messages": list of chat-message dicts}.
# Being a module global, it is shared by every session this process serves.
conversations = dict()
|
|
|
def generate_conversation_id():
    """Return a short random conversation identifier.

    The id is the first eight lowercase hex digits of a random UUID4.
    """
    random_uuid = uuid.uuid4()
    return random_uuid.hex[:8]
|
|
|
# Literal chat-markup tags that should never reach the user if the model
# happens to emit them as plain text (special tokens are already skipped).
_LEFTOVER_TAG_RE = re.compile(r"<\|(?:im_start|im_sep|im_end|end[^|]*)\|>")

# System prompt placed at the start of every conversation.
_SYSTEM_MESSAGE = "Your role as an assistant..."


def _build_chat_messages(user_message, history_state):
    """Assemble the message list rendered by the tokenizer's chat template.

    Args:
        user_message: Text of the new user turn.
        history_state: Prior turns as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The system prompt, every prior user turn, every prior assistant turn
        with non-empty content, and finally the new user turn.
    """
    messages = [{"role": "system", "content": _SYSTEM_MESSAGE}]
    for message in history_state:
        role = message["role"]
        # Empty assistant turns are placeholders left by interrupted replies.
        if role == "user" or (role == "assistant" and message["content"]):
            messages.append({"role": role, "content": message["content"]})
    messages.append({"role": "user", "content": user_message})
    return messages


def _run_generation(model, generation_kwargs, streamer, errors):
    """Run ``model.generate`` on a worker thread, recording any failure.

    An exception raised inside the thread would otherwise be lost. The
    ``TextIteratorStreamer`` is created without a timeout, so its consumer
    would also wait forever for an end-of-stream signal that ``generate``
    never sends. On failure the exception is appended to ``errors`` *before*
    ``streamer.end()`` is called, so the consumer sees it once its loop ends.
    """
    try:
        model.generate(**generation_kwargs)
    except Exception as exc:
        import logging

        logging.getLogger(__name__).exception("model.generate failed")
        errors.append(exc)
        streamer.end()


# NOTE(review): duration=60 s may be shorter than generating up to max_tokens
# takes on this model — confirm the GPU allocation window is long enough.
@spaces.GPU(duration=60)
def generate_response(user_message, max_tokens, temperature, top_p, history_state):
    """Stream the model's reply to ``user_message`` given the prior history.

    Args:
        user_message: Text of the new user turn.
        max_tokens: Upper bound on newly generated tokens (cast to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        history_state: Prior turns as ``{"role": ..., "content": ...}`` dicts.
            The list itself is not mutated.

    Yields:
        ``(chat_history, new_history_state)`` pairs. While streaming, both are
        ``history_state`` extended with the user turn and the partial
        assistant reply. Blank input yields the unchanged history once. If
        the generation thread cannot be started, the second element is the
        unchanged ``history_state``. If generation fails mid-way, the
        assistant turn ends with a "⚠️ Generation failed." notice.
    """
    # In a generator, ``return value`` just ends iteration without emitting
    # anything, so blank input must yield explicitly for the UI to refresh.
    if not user_message.strip():
        yield history_state, history_state
        return

    model = our_model
    tokenizer = our_tokenizer

    # Render the prompt with the checkpoint's own chat template; hand-written
    # <|im_start|>/<|im_sep|> markup is not this model's prompt format.
    prompt = tokenizer.apply_chat_template(
        _build_chat_messages(user_message, history_state),
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template already contains any BOS token it needs, so the
    # tokenizer must not prepend a second one.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_k": 50,
        "top_p": top_p,
        "repetition_penalty": 1.0,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
    }

    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]

    errors = []
    try:
        thread = Thread(
            target=_run_generation,
            args=(model, generation_kwargs, streamer, errors),
        )
        thread.start()
    except RuntimeError:  # raised by Thread.start when no thread can be created
        new_history[-1]["content"] = "⚠️ Generation failed."
        yield new_history, history_state
        return

    assistant_response = ""
    try:
        for chunk in streamer:
            # Strip stray markup but keep the surrounding text of the chunk.
            assistant_response += _LEFTOVER_TAG_RE.sub("", chunk)
            new_history[-1]["content"] = assistant_response.strip()
            yield new_history, new_history
    except Exception as exc:
        # Keep whatever was streamed so far, but tell the user it is incomplete.
        import logging

        logging.getLogger(__name__).exception("Streaming the model output failed")
        errors.append(exc)
    else:
        # The streamer is exhausted, so the worker has finished or failed.
        thread.join()

    if errors:
        partial = assistant_response.strip()
        notice = "⚠️ Generation failed."
        new_history[-1]["content"] = f"{partial}\n\n{notice}" if partial else notice

    yield new_history, new_history
|
|
|
|
|
# Example prompts inserted into the input box by the "Try these examples"
# buttons, keyed by topic. Backslashes are escaped ("\\(") so each string
# contains the literal LaTeX inline-math delimiters "\(" and "\)"; an
# unescaped "\(" in a normal string literal is an invalid escape sequence.
example_messages = {
    "JEE Main 2025 Combinatorics": "From all the English alphabets, five letters are chosen and are arranged in alphabetical order. The total number of ways, in which the middle letter is 'M', is?",
    "JEE Main 2025 Coordinate Geometry": "A circle \\(C\\) of radius 2 lies in the second quadrant and touches both the coordinate axes. Let \\(r\\) be the radius of a circle that has centre at the point \\((2, 5)\\) and intersects the circle \\(C\\) at exactly two points. If the set of all possible values of \\(r\\) is the interval \\((\\alpha, \\beta)\\), then \\(3\\beta - 2\\alpha\\) is?",
    "JEE Main 2025 Probability & Statistics": "A coin is tossed three times. Let \\(X\\) denote the number of times a tail follows a head. If \\(\\mu\\) and \\(\\sigma^2\\) denote the mean and variance of \\(X\\), then the value of \\(64(\\mu + \\sigma^2)\\) is?",
    "JEE Main 2025 Laws of Motion": "A massless spring gets elongated by amount x_1 under a tension of 5 N . Its elongation is x_2 under the tension of 7 N . For the elongation of 5x_1 - 2x_2 , the tension in the spring will be?",
}
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header: logo and title.
    gr.HTML(
        """
        <div style="display: flex; align-items: center; gap: 16px; margin-bottom: 1em;">
            <div style="background-color: black; padding: 6px; border-radius: 8px;">
                <img src="https://framerusercontent.com/images/j0KjQQyrUfkFw4NwSaxQOLAoBU.png" alt="Fractal AI Logo" style="height: 48px;">
            </div>
            <h1 style="margin: 0;">Ramanujan Ganit R1 14B V1 Chatbot</h1>
        </div>
        """
    )

    with gr.Sidebar():
        gr.Markdown("## Conversations")
        # Choices are (title, conversation_id) pairs: the radio displays titles
        # but its value is the id, so conversations sharing a title stay distinct.
        conversation_selector = gr.Radio(choices=[], label="Select Conversation", interactive=True)
        new_convo_button = gr.Button("New Conversation ➕")

    # Session state. Saved conversations live in a gr.State rather than a
    # module-level dict: Gradio keeps a separate copy of each State value per
    # browser session, so one visitor cannot list or load another's chats.
    # The initial id below is computed once at build time (so it is the same
    # for every session); that is harmless because ids only need to be unique
    # within a session's own store.
    current_convo_id = gr.State(generate_conversation_id())
    history_state = gr.State([])  # messages of the active conversation
    conversations_state = gr.State({})  # conversation id -> {"title", "messages"}

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(
                """
                Welcome to the Ramanujan Ganit R1 14B V1 Chatbot, developed by Fractal AI Research!

                Our model excels at reasoning tasks in mathematics and science.

                Try the example problems below from JEE Main 2025 or type in your own problems to see how our model breaks down complex reasoning problems.

                Please note that once you close this demo window, all currently saved conversations will be lost.
                """
            )

            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(minimum=6144, maximum=32768, step=1024, value=16384, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=True):
                temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.6, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")

            gr.Markdown(
                """
                We sincerely acknowledge [VIDraft](https://huggingface.co/VIDraft) for their Phi 4 Reasoning Plus [space](https://huggingface.co/spaces/VIDraft/phi-4-reasoning-plus), which served as the starting point for this demo.
                """
            )

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages", height=520)
            with gr.Row():
                user_input = gr.Textbox(label="User Input", placeholder="Type your question here...", lines=3, scale=8)
                with gr.Column():
                    submit_button = gr.Button("Send", variant="primary", scale=1)
                    clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1_button = gr.Button("JEE Main 2025\nCombinatorics")
                example2_button = gr.Button("JEE Main 2025\nCoordinate Geometry")
                example3_button = gr.Button("JEE Main 2025\nProbability & Statistics")
                example4_button = gr.Button("JEE Main 2025\nLaws of Motion")

    def update_conversation_list(convos):
        """Return ``(title, conversation_id)`` radio choices for every saved conversation."""
        return [(convo["title"], cid) for cid, convo in convos.items()]

    def start_new_conversation(convos):
        """Register an empty conversation, make it active, and clear the chat view.

        Returns values for ``[current_convo_id, history_state, chatbot,
        conversation_selector, conversations_state]``.
        """
        new_id = generate_conversation_id()
        convos[new_id] = {"title": f"New Conversation {new_id}", "messages": []}
        selector = gr.update(choices=update_conversation_list(convos), value=new_id)
        return new_id, [], [], selector, convos

    def load_conversation(selected_id, convo_id, history, convos):
        """Switch to the conversation picked in the sidebar.

        When ``selected_id`` is unknown, the session's *current* conversation
        is kept (not the State components' build-time defaults).

        Returns values for ``[current_convo_id, history_state, chatbot]``.
        """
        convo = convos.get(selected_id)
        if convo is None:
            return convo_id, history, history
        return selected_id, convo["messages"], convo["messages"]

    def send_message(user_message, max_tokens, temperature, top_p, convo_id, history, convos):
        """Save the turn under ``convo_id`` and stream the model's reply into the UI.

        Yields values for ``[chatbot, history_state, conversation_selector,
        conversations_state]``. A blank message leaves everything unchanged
        and does not register a conversation.
        """
        if not user_message.strip():
            yield history, history, gr.update(), convos
            return

        # The first five words of the conversation's first message become its title.
        title = " ".join(user_message.split()[:5])
        convo = convos.setdefault(convo_id, {"title": title, "messages": history})
        if convo["title"].startswith("New Conversation"):
            convo["title"] = title

        # Titles do not change while streaming, so compute the choices once;
        # a fresh gr.update is still built per yield.
        choices = update_conversation_list(convos)
        for updated_history, new_history in generate_response(user_message, max_tokens, temperature, top_p, history):
            convo["messages"] = new_history
            yield updated_history, new_history, gr.update(choices=choices, value=convo_id), convos

    submit_button.click(
        fn=send_message,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_p_slider, current_convo_id, history_state, conversations_state],
        outputs=[chatbot, history_state, conversation_selector, conversations_state],
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input,
    )

    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state],
    )

    new_convo_button.click(
        fn=start_new_conversation,
        inputs=conversations_state,
        outputs=[current_convo_id, history_state, chatbot, conversation_selector, conversations_state],
    )

    # `.input` fires only when the user picks a conversation; `.change` would
    # also fire on the programmatic selector updates made by send_message and
    # start_new_conversation, reloading the chat while a reply is streaming.
    conversation_selector.input(
        fn=load_conversation,
        inputs=[conversation_selector, current_convo_id, history_state, conversations_state],
        outputs=[current_convo_id, history_state, chatbot],
    )

    example1_button.click(fn=lambda: gr.update(value=example_messages["JEE Main 2025 Combinatorics"]), inputs=None, outputs=user_input)
    example2_button.click(fn=lambda: gr.update(value=example_messages["JEE Main 2025 Coordinate Geometry"]), inputs=None, outputs=user_input)
    example3_button.click(fn=lambda: gr.update(value=example_messages["JEE Main 2025 Probability & Statistics"]), inputs=None, outputs=user_input)
    example4_button.click(fn=lambda: gr.update(value=example_messages["JEE Main 2025 Laws of Motion"]), inputs=None, outputs=user_input)

demo.launch(share=True, ssr_mode=False)
|
|