# Hugging Face Space — runs on ZeroGPU ("Running on Zero") hardware.
import asyncio
import functools
import re

import gradio as gr
import spaces
from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
# Client-side JavaScript injected into the Gradio page (gr.Blocks(js=JS)).
# It observes text mutations under elements carrying the "auto-scroll" CSS
# class and pins the nearest scrollable ancestor (style.overflow == "auto")
# to the bottom, so the streamed reasoning text stays in view.
JS = """
() => {
    // auto scroll .auto-scroll elements when text has changed
    const observer = new MutationObserver((mutations) => {
        mutations.forEach((mutation) => {
            // find the parent element with .auto-scroll class and having the "overflow"
            // style attribute to "auto"
            let element = mutation.target;
            while(element.parentElement !== null && element.parentElement.style.overflow !== "auto") {
                element = element.parentElement;
            }
            if (element.parentElement === null) {
                return;
            }
            element = element.parentElement;
            element.scrollTop = element.scrollHeight;
        });
    })
    document.querySelectorAll('.auto-scroll > *').forEach((elem) => {
        console.log("observing", elem)
        observer.observe(elem, {
            childList: true,
            characterData: true,
        })
    });
}
"""
# Checkpoint served by this Space.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Load weights in their checkpoint-native precision (torch_dtype="auto") and
# let accelerate pick device placement (device_map="auto").
# Removed leftover debug output: print(dir(model)) / print(model.config).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def reformat_math(text):
    """Translate MathJax delimiters into the ones Gradio renders.

    ``\\[ ... \\]`` becomes ``$$ ... $$`` (display math) and ``\\( ... \\)``
    becomes ``$ ... $`` (inline math), with surrounding whitespace trimmed.
    Workaround: Gradio's ``latex_delimiters`` did not handle the backslash
    forms directly, so the model output is rewritten before display.
    """
    conversions = (
        (r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$"),  # display math
        (r"\\\(\s*(.*?)\s*\\\)", r"$\1$"),    # inline math
    )
    for pattern, replacement in conversions:
        text = re.sub(pattern, replacement, text, flags=re.DOTALL)
    return text
def _generate(history):
    """Start model generation for *history* on the default thread executor.

    Returns a ``(task, streamer)`` pair: *task* is the future wrapping the
    blocking ``model.generate`` call, and *streamer* yields decoded text
    chunks asynchronously as tokens are produced.
    """
    prompt = tokenizer.apply_chat_template(
        history,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer([prompt], return_tensors="pt").to(model.device)
    streamer = AsyncTextIteratorStreamer(tokenizer, skip_special_tokens=True)
    run_generation = functools.partial(
        model.generate,
        max_new_tokens=1024 * 128,
        streamer=streamer,
        **encoded,
    )
    loop = asyncio.get_running_loop()
    task = loop.run_in_executor(None, run_generation)
    return task, streamer
async def chat(prompt, history):
    """Respond to a chat prompt.

    Async generator consumed by ``gr.ChatInterface``: yields
    ``(answer_markdown, reasoning_markdown)`` tuples as tokens stream in.
    Text between ``<think>`` and ``</think>`` goes to the reasoning panel
    (second output); everything else is the visible answer.
    """
    message = {
        "role": "user",
        "content": prompt,
    }
    # build the messages list
    history = [] if history is None else history
    message_list = history + [message]
    # get the task and the streamer
    task, streamer = _generate(message_list)
    buffer = ""  # visible answer accumulated so far
    reasoning = ""  # chain-of-thought accumulated so far
    thinking = False  # True while inside a <think>...</think> span
    try:
        async for new_text in streamer:
            # NOTE(review): task.done() is also True on normal completion,
            # so this may drop text still queued in the streamer — confirm.
            if task.done() or task.cancelled():
                print("Cancelled")
                break  # stop streaming if the task was cancelled
            # assumes the <think>/</think> markers arrive whole inside a
            # single chunk — TODO confirm the tokenizer never splits them
            if not thinking and "<think>" in new_text:
                thinking = True
                continue
            if thinking and "</think>" in new_text:
                thinking = False
                continue
            if thinking:
                reasoning += new_text
                heading = "# Reasoning\n\n"
                # placeholder answer while the model is still "thinking"
                yield "I'm thinking, please wait a moment...", heading + reasoning
                continue
            buffer += new_text
            yield reformat_math(buffer), reasoning
    except asyncio.CancelledError:
        # this doesn't work, I don't find a way to stop generation thread
        print("Cancelled")
        streamer.on_finalized_text("cancelled", True)
        print("Signal sent")
        raise
# Chatbot widget handed to gr.ChatInterface below. The dollar-style LaTeX
# delimiters must match what reformat_math() emits ($$...$$ and $...$).
chat_bot = gr.Chatbot(
    latex_delimiters=[
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
    ],
    scale=1,
    type="messages",
)
# Page layout: chat interface on the left, live reasoning panel on the right.
with gr.Blocks(js=JS) as demo:
    # Reasoning panel. render=False defers placement until reasoning.render()
    # inside the right-hand column; the "auto-scroll" class hooks into JS above.
    reasoning = gr.Markdown(
        "# Reasoning\n\nWhen the model will reasoning, its thoughts will be displayed here.",
        label="Reasoning",
        show_label=True,
        container=True,
        elem_classes="auto-scroll",
        max_height="90vh",
        render=False,
    )
    with gr.Row(equal_height=True, height="90vh"):
        with gr.Column(scale=3):
            # chat() yields two values; the second is routed to `reasoning`
            # via additional_outputs.
            gr.ChatInterface(
                chat,
                type="messages",
                chatbot=chat_bot,
                title=str(model_name),
                description=(
                    f"*{model_name}* is a large language model "
                    "trained on a mixture of instruction and "
                    "conversational data."
                ),
                additional_outputs=[reasoning],
            )
        with gr.Column():
            reasoning.render()

if __name__ == "__main__":
    demo.queue().launch()