File size: 10,461 Bytes
1bf7168
15d32ba
 
 
 
 
5889e5d
15d32ba
5889e5d
865b696
15d32ba
 
865b696
 
15d32ba
 
 
 
 
 
08d1d51
5889e5d
 
08d1d51
 
5889e5d
a878ad7
15d32ba
 
 
 
865b696
 
15d32ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e844849
15d32ba
e844849
15d32ba
 
 
 
 
 
5889e5d
15d32ba
 
 
 
 
 
 
 
 
51c367c
15d32ba
 
 
 
 
 
 
 
 
 
 
b10c582
 
15d32ba
 
23dc432
 
 
 
15d32ba
 
feb219b
179464f
fb1bbbc
 
 
 
5a26bce
e3371ea
 
fb1bbbc
 
 
 
15d32ba
5889e5d
 
 
179464f
5889e5d
08d1d51
15d32ba
 
 
 
179464f
 
 
865b696
 
179464f
865b696
179464f
7faefe7
 
179464f
 
 
15d32ba
08d1d51
6123382
08d1d51
 
3f8bcd2
37ecec5
066a2fc
fb4b038
7faefe7
066a2fc
fb4b038
37ecec5
3f8bcd2
179464f
870cc0f
5889e5d
d336226
5889e5d
 
 
15d32ba
 
23dc432
 
 
 
02e4fca
5889e5d
08d1d51
5889e5d
 
08d1d51
 
 
5889e5d
08d1d51
 
 
 
 
5889e5d
 
08d1d51
23dc432
 
08d1d51
 
6b29f7a
 
5889e5d
08d1d51
 
5889e5d
15d32ba
5889e5d
 
08d1d51
15d32ba
 
 
 
 
 
 
 
 
 
 
 
5889e5d
 
 
 
 
 
 
 
 
 
 
 
23dc432
 
 
 
15d32ba
e844849
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import re
import uuid

# Load model and tokenizer.
# Hub id of the checkpoint; both the weights and the tokenizer are fetched from it.
our_model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"

# Device that tokenized prompts are moved to before generation.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# device_map="auto" lets transformers decide weight placement; torch_dtype="auto"
# loads the weights in the dtype the checkpoint declares.
our_model = AutoModelForCausalLM.from_pretrained(
    our_model_path,
    device_map="auto",
    torch_dtype="auto",
)
our_tokenizer = AutoTokenizer.from_pretrained(our_model_path)

def format_math(text):
    r"""Rewrite LaTeX math delimiters into ``$``-style ones.

    Display math ``\[ ... \]`` (which may span several lines) becomes
    ``$$ ... $$`` and inline math ``\( ... \)`` becomes ``$ ... $``.
    Ordinary square brackets such as ``[1, 2]`` are left untouched.

    Args:
        text: String that may contain LaTeX math delimiters.

    Returns:
        The string with the delimiters rewritten.
    """
    # Both the backslash and the bracket must be escaped in the pattern so that
    # only the LaTeX delimiters "\[" / "\]" match, not any plain "[...]" text.
    text = re.sub(r"\\\[(.*?)\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = text.replace(r"\(", "$").replace(r"\)", "$")
    return text

# Module-level conversation store. Being a global, it is shared by every
# session served by this process.
# Schema: {conversation_id: {"title": str, "messages": list}}.
conversations: dict[str, dict] = dict()

def generate_conversation_id():
    """Return a short random id: the first 8 hex digits of a fresh UUID4."""
    return uuid.uuid4().hex[:8]

@spaces.GPU(duration=60)
def generate_response(user_message, max_tokens, temperature, top_p, history_state):
    """Stream the model's reply to ``user_message`` on top of ``history_state``.

    This is a generator. Each yield is a pair of chat histories (lists of
    ``{"role": ..., "content": ...}`` dicts) intended for the chatbot
    component and the history state respectively. While streaming, both are
    the same list and its last (assistant) entry grows as text arrives.

    Args:
        user_message: New user turn. Whitespace-only input is not sent to the
            model; the unchanged history is yielded once instead.
        max_tokens: Maximum number of new tokens (cast to ``int``).
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        history_state: Earlier turns. The list itself is not mutated; a new
            list with the user turn and the reply appended is yielded.

    Yields:
        ``(chat_history, state_history)`` tuples. The final one carries the
        complete reply, or a "⚠️ Generation failed." marker if generation broke.
    """
    if not user_message.strip():
        # Inside a generator, ``return value`` discards the value and yields
        # nothing, leaving the caller with no output. Yield explicitly.
        yield history_state, history_state
        return

    model = our_model
    tokenizer = our_tokenizer
    start_tag = "<|im_start|>"
    sep_tag = "<|im_sep|>"
    end_tag = "<|im_end|>"

    # NOTE(review): the prompt is assembled by hand with these Phi-4 style
    # tags; confirm they match the chat template of the checkpoint loaded from
    # ``our_model_path`` (``tokenizer.apply_chat_template`` would keep them in sync).
    system_message = "Your role as an assistant..."
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for message in history_state:
        if message["role"] == "user":
            prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
        elif message["role"] == "assistant" and message["content"]:
            # Empty assistant turns are left out of the prompt.
            prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
    # Append the new user turn and open an assistant turn for the model to complete.
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_k": 50,
        "top_p": top_p,
        "repetition_penalty": 1.0,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
    }

    failure_notice = "⚠️ Generation failed."
    # Exceptions raised by generate() in the worker thread; read after the stream ends.
    generation_errors = []

    def _run_generation():
        """Run ``model.generate`` in the worker thread, recording any failure."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # reported to the user below
            generation_errors.append(exc)
            # generate() does not close the streamer when it raises, and the
            # streamer has no timeout, so without this the consuming loop
            # below would block forever.
            streamer.end()

    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]

    thread = Thread(target=_run_generation)
    try:
        thread.start()
    except RuntimeError:
        # The worker thread could not be spawned.
        new_history[-1]["content"] = failure_notice
        yield new_history, history_state
        return

    assistant_response = ""
    stream_failed = False
    try:
        for new_token in streamer:
            # NOTE(review): a chunk containing "<|end" is dropped whole,
            # including any text before the marker; confirm this is intended.
            if "<|end" in new_token:
                continue
            cleaned_token = new_token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "")
            assistant_response += cleaned_token
            new_history[-1]["content"] = assistant_response.strip()
            yield new_history, new_history
    except Exception:
        # Keep the partial reply, but flag it instead of silently truncating.
        stream_failed = True
    else:
        # The streamer only ends once generation has finished (or failed and
        # closed it above), so this join returns promptly and reaps the thread.
        thread.join()

    if stream_failed or generation_errors:
        partial = assistant_response.strip()
        new_history[-1]["content"] = f"{partial}\n\n{failure_notice}" if partial else failure_notice
    yield new_history, new_history


# Example problems offered as one-click prompts, keyed by their display name.
# Entries containing LaTeX use raw strings: a plain "\(" is an invalid escape
# sequence (SyntaxWarning on Python 3.12+), while the raw form yields the same
# characters without relying on that fallback.
example_messages = {
    "JEE Main 2025 Combinatorics": "From all the English alphabets, five letters are chosen and are arranged in alphabetical order. The total number of ways, in which the middle letter is 'M', is?",
    "JEE Main 2025 Coordinate Geometry": r"A circle \(C\) of radius 2 lies in the second quadrant and touches both the coordinate axes. Let \(r\) be the radius of a circle that has centre at the point \((2, 5)\) and intersects the circle \(C\) at exactly two points. If the set of all possible values of \(r\) is the interval \((\alpha, \beta)\), then \(3\beta - 2\alpha\) is?",
    "JEE Main 2025 Probability & Statistics": r"A coin is tossed three times. Let \(X\) denote the number of times a tail follows a head. If \(\mu\) and \(\sigma^2\) denote the mean and variance of \(X\), then the value of \(64(\mu + \sigma^2)\) is?",
    "JEE Main 2025 Laws of Motion": "A massless spring gets elongated by amount x_1 under a tension of 5 N . Its elongation is x_2 under the tension of 7 N . For the elongation of 5x_1 - 2x_2 , the tension in the spring will be?",
}

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header (logo + title) rendered above everything else.
    gr.HTML(
    """
    <div style="display: flex; align-items: center; gap: 16px; margin-bottom: 1em;">
        <div style="background-color: black; padding: 6px; border-radius: 8px;">
            <img src="https://framerusercontent.com/images/j0KjQQyrUfkFw4NwSaxQOLAoBU.png" alt="Fractal AI Logo" style="height: 48px;">
        </div>
        <h1 style="margin: 0;">Ramanujan Ganit R1 14B V1 Chatbot</h1>
    </div>
    """
    )

    with gr.Sidebar():
        gr.Markdown("## Conversations")
        # Choices are (title, conversation_id) pairs: the selected value is the
        # id, so conversations with identical titles stay distinguishable.
        conversation_selector = gr.Radio(choices=[], label="Select Conversation", interactive=True)
        new_convo_button = gr.Button("New Conversation ➕")

    # Per-session state. Passing the function itself (not its result) makes
    # Gradio call it for every session, instead of evaluating it once at build
    # time and handing the same id to every visitor.
    current_convo_id = gr.State(generate_conversation_id)
    history_state = gr.State([])
    # This session's saved conversations:
    # {conversation_id: {"title": str, "messages": list}}. Kept per session so
    # visitors never see or load each other's chats.
    session_conversations = gr.State({})

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(
                """
                Welcome to the Ramanujan Ganit R1 14B V1 Chatbot, developed by Fractal AI Research! 
                
                Our model excels at reasoning tasks in mathematics and science.  
                
                Try the example problems below from JEE Main 2025 or type in your own problems to see how our model breaks down complex reasoning problems.

                Please note that once you close this demo window, all currently saved conversations will be lost.
                """
            )

            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(minimum=6144, maximum=32768, step=1024, value=16384, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=True):
                temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.6, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")

            # Acknowledgment shown at the bottom of the settings column.
            gr.Markdown("""
            
                        We sincerely acknowledge [VIDraft](https://huggingface.co/VIDraft) for their Phi 4 Reasoning Plus [space](https://huggingface.co/spaces/VIDraft/phi-4-reasoning-plus), which served as the starting point for this demo.
                        """
            )

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages", height=520)
            with gr.Row():
                user_input = gr.Textbox(label="User Input", placeholder="Type your question here...", lines=3, scale=8)
                with gr.Column():
                    submit_button = gr.Button("Send", variant="primary", scale=1)
                    clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1_button = gr.Button("JEE Main 2025\nCombinatorics")
                example2_button = gr.Button("JEE Main 2025\nCoordinate Geometry")
                example3_button = gr.Button("JEE Main 2025\nProbability & Statistics")
                example4_button = gr.Button("JEE Main 2025\nLaws of Motion")

    def update_conversation_list(convos):
        """Return sidebar radio choices as ``(title, conversation_id)`` pairs."""
        return [(convo["title"], cid) for cid, convo in convos.items()]

    def start_new_conversation(convos):
        """Create an empty conversation with a placeholder title and switch to it.

        Returns values for ``[current_convo_id, history_state, chatbot,
        conversation_selector, session_conversations]``.
        """
        new_id = generate_conversation_id()
        convos[new_id] = {
            "title": f"New Conversation {new_id}",
            "messages": [],
            # Tells send_message to replace the title with the first message.
            "placeholder_title": True,
        }
        selector = gr.update(choices=update_conversation_list(convos), value=new_id)
        # The chatbot is cleared here because the selector's ``.input`` event
        # (used below) does not fire for programmatic value changes.
        return new_id, [], [], selector, convos

    def load_conversation(selected_id, convos, convo_id, history):
        """Switch to the conversation the user picked in the sidebar.

        Returns values for ``[current_convo_id, history_state, chatbot]``. An
        id that is not in this session's store (e.g. an empty selection) keeps
        the current conversation rather than resetting the session.
        """
        convo = convos.get(selected_id)
        if convo is None:
            return convo_id, history, history
        return selected_id, convo["messages"], convo["messages"]

    def send_message(user_message, max_tokens, temperature, top_p, convo_id, history, convos):
        """Stream a reply for ``user_message`` and save it under ``convo_id``.

        Yields values for ``[chatbot, history_state, conversation_selector,
        session_conversations]``. A conversation is created on its first
        message and titled with that message's first five words.
        """
        if not user_message.strip():
            # Nothing to send: leave every output as it is.
            yield history, history, gr.update(), convos
            return

        title = " ".join(user_message.split()[:5])
        convo = convos.setdefault(convo_id, {"title": title, "messages": history})
        if convo.pop("placeholder_title", False):
            convo["title"] = title

        for updated_history, new_history in generate_response(user_message, max_tokens, temperature, top_p, history):
            convo["messages"] = new_history
            # A fresh update per yield: Gradio may consume keys of the dict it receives.
            yield (
                updated_history,
                new_history,
                gr.update(choices=update_conversation_list(convos), value=convo_id),
                convos,
            )

    submit_button.click(
        fn=send_message,
        inputs=[
            user_input,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            current_convo_id,
            history_state,
            session_conversations,
        ],
        outputs=[chatbot, history_state, conversation_selector, session_conversations],
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input,
    )

    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state],
    )

    new_convo_button.click(
        fn=start_new_conversation,
        inputs=session_conversations,
        outputs=[current_convo_id, history_state, chatbot, conversation_selector, session_conversations],
    )

    # ``.input`` fires only when the user picks an entry, not when the handlers
    # above set the radio's value programmatically (e.g. on every streamed token).
    conversation_selector.input(
        fn=load_conversation,
        inputs=[conversation_selector, session_conversations, current_convo_id, history_state],
        outputs=[current_convo_id, history_state, chatbot],
    )

    def _example_setter(example_key):
        """Build a zero-argument click handler that fills the input box with one example."""
        return lambda: gr.update(value=example_messages[example_key])

    for example_button, example_key in (
        (example1_button, "JEE Main 2025 Combinatorics"),
        (example2_button, "JEE Main 2025 Coordinate Geometry"),
        (example3_button, "JEE Main 2025 Probability & Statistics"),
        (example4_button, "JEE Main 2025 Laws of Motion"),
    ):
        example_button.click(fn=_example_setter(example_key), inputs=None, outputs=user_input)

demo.launch(share=True, ssr_mode=False)