|
|
|
|
|
|
|
import os
import re
import uuid
from threading import Thread

import gradio as gr
import spaces
import tiktoken
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
|
|
|
# OpenAI-style client pointed at the Hugging Face endpoint that serves the
# model. The token is taken from the HF_TOKEN environment variable so a real
# credential never has to live in the source; the original placeholder is kept
# as the fallback, so nothing changes when the variable is unset or empty.
client = OpenAI(
    base_url="https://a7g1ajqixo23revq.us-east-1.aws.endpoints.huggingface.cloud/v1/",
    api_key=os.environ.get("HF_TOKEN") or "hf_XXXXX",
)
|
|
|
|
|
def format_math(text):
    r"""Rewrite LaTeX math delimiters into dollar-sign delimiters.

    Display math written as ``\[ ... \]`` (or as bare ``[ ... ]``) becomes
    ``$$ ... $$``; inline math written as ``\( ... \)`` becomes ``$ ... $``.

    Args:
        text: The string to convert.

    Returns:
        The converted string.
    """
    # The backslashes are optional so bare brackets keep converting exactly as
    # before, but when they are present they are consumed together with the
    # bracket instead of being left behind as "\$$ ... \$$".
    text = re.sub(r"\\?\[(.*?)\\?\]", r"$$\1$$", text, flags=re.DOTALL)
    text = text.replace(r"\(", "$").replace(r"\)", "$")
    return text
|
|
|
def generate_conversation_id() -> str:
    """Return a short random identifier: the first 8 hex digits of a UUID4."""
    # The first eight characters of str(uuid4()) are exactly the first eight
    # hex digits (the first dash only appears at index 8).
    random_uuid = uuid.uuid4()
    return random_uuid.hex[:8]
|
|
|
# Tokenizer used by generate_response to count how many tokens have been
# streamed so far.
# NOTE(review): this is the gpt-3.5-turbo encoding; it presumably differs from
# the served model's own tokenizer, so the counts are approximate — confirm.
_TOKEN_COUNT_MODEL = "gpt-3.5-turbo"
enc = tiktoken.encoding_for_model(_TOKEN_COUNT_MODEL)
|
|
|
|
|
def generate_response(user_message,
                      max_tokens,
                      temperature,
                      top_p,
                      history_state):
    """Stream an assistant reply to ``user_message`` from the chat endpoint.

    This is a generator. Every item is a ``(chat_history, state_history)``
    pair of message lists, each message being a
    ``{"role": ..., "content": ...}`` dict.

    Args:
        user_message: Text entered by the user.
        max_tokens: Generation limit; sent to the endpoint and also enforced
            locally by counting streamed text with the module-level ``enc``.
        temperature: Sampling temperature forwarded to the endpoint.
        top_p: Top-p value forwarded to the endpoint.
        history_state: Earlier messages of the conversation. The list itself
            is not modified.

    Yields:
        * For a blank message: the unchanged history, once.
        * If the request cannot be started: the history plus the user message
          and a failure notice as ``chat_history``, with the unchanged history
          as ``state_history``.
        * Otherwise: the growing history after every streamed piece of text,
          and once more when streaming ends.
    """
    if not user_message.strip():
        # A ``return value`` inside a generator is silently dropped, so the
        # unchanged history has to be yielded for callers to receive it.
        yield history_state, history_state
        return

    system_message = "Your role as an assistant..."
    messages = [{"role": "system", "content": system_message}]
    messages.extend(
        {"role": m["role"], "content": m["content"]} for m in history_state
    )
    messages.append({"role": "user", "content": user_message})

    try:
        response = client.chat.completions.create(
            model="tgi",
            messages=messages,
            max_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
            stream=True,
        )
    except Exception as e:
        print(f"[ERROR] OpenAI API call failed: {e}")
        # Show the failure in the chat, but hand back the old history as the
        # state so the failed exchange is not kept.
        yield history_state + [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": "⚠️ Generation failed."},
        ], history_state
        return

    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]

    token_budget = int(max_tokens)
    tokens_seen = 0

    try:
        for chunk in response:
            # Skip chunks that carry no text.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if not delta or not delta.content:
                continue

            token_text = delta.content
            assistant_response += token_text
            tokens_seen += len(enc.encode(token_text))

            new_history[-1]["content"] = assistant_response.strip()
            yield new_history, new_history

            if tokens_seen >= token_budget:
                break
    except Exception as e:
        # Best effort: keep whatever text already arrived, but do not hide
        # the failure.
        print(f"[ERROR] Streaming interrupted: {e}")
    finally:
        # Release the connection when the loop ends early (budget reached,
        # error, or the consumer stopped iterating).
        close_stream = getattr(response, "close", None)
        if callable(close_stream):
            close_stream()

    yield new_history, new_history
|
|
|
|
|
# Prompt text for each example button, keyed by the button's label.
example_messages = {
    "IIT-JEE 2024 Mathematics": "...",
    "IIT-JEE 2025 Physics": "...",
    "Goldman Sachs Interview Puzzle": "...",
    "IIT-JEE 2025 Mathematics": "...",
}
|
|
|
|
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo:

    # State holders: every conversation keyed by id, the id of the
    # conversation currently shown, and that conversation's message list.
    conversations_state = gr.State({})
    current_convo_id = gr.State(generate_conversation_id())
    history_state = gr.State([])

    gr.HTML("""
    <div style="display:flex;align-items:center;gap:16px;margin-bottom:1em;">
        <div style="background-color:black;padding:6px;border-radius:8px;">
            <img src="https://framerusercontent.com/images/j0KjQQyrUfkFw4NwSaxQOLAoBU.png"
                 style="height:48px;">
        </div>
        <h1 style="margin:0;">Fathom R1 14B Chatbot</h1>
    </div>
    """)

    with gr.Sidebar():
        gr.Markdown("## Conversations")
        conversation_selector = gr.Radio(choices=[], label="Select Conversation", interactive=True)
        new_convo_button = gr.Button("New Conversation ➕")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("""Welcome to the Fathom R1 14B Chatbot, developed by Fractal AI Research! ...""")
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(6144, 32768, step=1024, value=16384, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=True):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.6, label="Temperature")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
            gr.Markdown("""We sincerely acknowledge [VIDraft]...""")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages", height=520)
            with gr.Row():
                user_input = gr.Textbox(label="User Input", placeholder="Type your question here...", lines=3, scale=8)
                with gr.Column():
                    submit_button = gr.Button("Send", variant="primary", scale=1)
                    clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1_button = gr.Button("IIT-JEE 2025 Mathematics")
                example2_button = gr.Button("IIT-JEE 2025 Physics")
                example3_button = gr.Button("Goldman Sachs Interview Puzzle")
                example4_button = gr.Button("IIT-JEE 2024 Mathematics")

    def _placeholder_title(convo_id):
        """Return the title a conversation carries until its first message."""
        return f"New Conversation {convo_id}"

    def _title_from_message(user_message, convo_id, conversations):
        """Build a title from the first five words of ``user_message``.

        Conversations are looked up by title when one is selected, so if
        another conversation already uses the same title the id is appended
        to keep the two apart.
        """
        title = " ".join(user_message.strip().split()[:5])
        taken = {
            convo["title"]
            for cid, convo in conversations.items()
            if cid != convo_id
        }
        if title in taken:
            title = f"{title} ({convo_id})"
        return title

    def update_conversation_list(conversations):
        """Return the titles of all stored conversations, in dict order."""
        return [convo["title"] for convo in conversations.values()]

    def start_new_conversation(conversations):
        """Register an empty conversation and select it.

        Returns the new id, an empty history, a selector update listing all
        titles with the new one selected, and the updated conversations dict.
        """
        new_id = generate_conversation_id()
        conversations[new_id] = {"title": _placeholder_title(new_id), "messages": []}
        return (new_id, [],
                gr.update(choices=update_conversation_list(conversations),
                          value=conversations[new_id]["title"]),
                conversations)

    def load_conversation(selected_title, conversations):
        """Return the id and messages of the conversation titled ``selected_title``.

        The messages are returned twice: once for the history state and once
        for the chatbot.
        """
        for cid, convo in conversations.items():
            if convo["title"] == selected_title:
                return cid, convo["messages"], convo["messages"]
        # No match (e.g. nothing is selected): leave all three outputs as they
        # are. Reading ``.value`` from the State components here would return
        # their construction-time defaults, not the current values, and would
        # reset the active conversation.
        return gr.update(), gr.update(), gr.update()

    def send_message(user_message, max_tokens, temperature, top_p,
                     convo_id, history, conversations):
        """Send ``user_message`` and stream the reply into the active conversation.

        Yields ``(chatbot messages, history, selector update, conversations)``
        once per item produced by ``generate_response``.
        """
        if not user_message.strip():
            # Nothing to send: change nothing, and in particular do not
            # register a conversation with an empty title.
            yield gr.update(), history, gr.update(), conversations
            return

        if convo_id not in conversations:
            conversations[convo_id] = {
                "title": _placeholder_title(convo_id),
                "messages": history,
            }

        # Only the exact placeholder is replaced, so a chat whose first
        # message happens to begin with "New Conversation" keeps its title on
        # later messages.
        if conversations[convo_id]["title"] == _placeholder_title(convo_id):
            conversations[convo_id]["title"] = _title_from_message(
                user_message, convo_id, conversations)

        for updated_history, new_history in generate_response(
                user_message, max_tokens, temperature, top_p, history):
            conversations[convo_id]["messages"] = new_history
            yield (updated_history,
                   new_history,
                   gr.update(choices=update_conversation_list(conversations),
                             value=conversations[convo_id]["title"]),
                   conversations)

    # Send the message, then empty the input box once the handler finishes.
    submit_button.click(
        fn=send_message,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_p_slider,
                current_convo_id, history_state, conversations_state],
        outputs=[chatbot, history_state, conversation_selector, conversations_state],
        concurrency_limit=16
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input
    )

    # Clear only the visible chat and the working history.
    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state]
    )

    new_convo_button.click(
        fn=start_new_conversation,
        inputs=[conversations_state],
        outputs=[current_convo_id, history_state, conversation_selector, conversations_state]
    )

    conversation_selector.change(
        fn=load_conversation,
        inputs=[conversation_selector, conversations_state],
        outputs=[current_convo_id, history_state, chatbot]
    )

    # Each example button copies its prompt into the input box.
    example1_button.click(lambda: gr.update(value=example_messages["IIT-JEE 2025 Mathematics"]),
                          None, user_input)
    example2_button.click(lambda: gr.update(value=example_messages["IIT-JEE 2025 Physics"]),
                          None, user_input)
    example3_button.click(lambda: gr.update(value=example_messages["Goldman Sachs Interview Puzzle"]),
                          None, user_input)
    example4_button.click(lambda: gr.update(value=example_messages["IIT-JEE 2024 Mathematics"]),
                          None, user_input)
|
|
|
|
|
# Script entry point: enable the queue, then start the server.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch(share=True, ssr_mode=False)