import asyncio
import functools
import re

import gradio as gr
import spaces
from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer

JS = """
() => {
    // auto scroll .auto-scroll elements when text has changed
    const observer = new MutationObserver((mutations) => {
        mutations.forEach((mutation) => {
            // find the parent element with .auto-scroll class and having the "overflow"
            // style attribute set to "auto"
            let element = mutation.target;
            while (element.parentElement !== null && element.parentElement.style.overflow !== "auto") {
                element = element.parentElement;
            }
            if (element.parentElement === null) {
                return;
            }
            element = element.parentElement;
            element.scrollTop = element.scrollHeight;
        });
    });

    document.querySelectorAll('.auto-scroll > *').forEach((elem) => {
        console.log("observing", elem);
        observer.observe(elem, {
            childList: true,
            characterData: true,
        });
    });
}
"""

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
# debug: inspect the loaded model at startup
print(dir(model))
print(model.config)
tokenizer = AutoTokenizer.from_pretrained(model_name)


def reformat_math(text):
    """Fix MathJax delimiters to use the Gradio syntax.

    This is a workaround to display math formulas in Gradio. For now, I
    haven't found a way to make it work as expected with other
    latex_delimiters...
    """
    text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
    return text


async def chat(prompt, history):
    """Respond to a chat prompt, streaming the answer and the reasoning separately."""
    message = {
        "role": "user",
        "content": prompt,
    }
    history = [] if history is None else history

    @spaces.GPU
    def _generate():
        # Build the full prompt from the history and run generation in a
        # background thread; tokens are consumed asynchronously via the streamer.
        text = tokenizer.apply_chat_template(
            history + [message],
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        streamer = AsyncTextIteratorStreamer(tokenizer, skip_special_tokens=True)
        task = asyncio.get_running_loop().run_in_executor(
            None,
            functools.partial(
                model.generate,
                max_new_tokens=1024 * 128,
                streamer=streamer,
                **model_inputs,
            ),
        )
        return task, streamer

    task, streamer = _generate()

    buffer = ""
    reasoning = ""
    thinking = False

    try:
        async for new_text in streamer:
            if task.done() or task.cancelled():
                print("Cancelled")
                break  # stop streaming if the task has been cancelled

            # DeepSeek-R1 models wrap their chain of thought in <think>...</think>
            # tags; route that text to the reasoning panel instead of the answer.
            if not thinking and "<think>" in new_text:
                thinking = True
                continue
            if thinking and "</think>" in new_text:
                thinking = False
                continue

            if thinking:
                reasoning += new_text
                heading = "# Reasoning\n\n"
                yield "I'm thinking, please wait a moment...", heading + reasoning
                continue

            buffer += new_text
            yield reformat_math(buffer), reasoning
    except asyncio.CancelledError:
        # this doesn't work: I haven't found a way to stop the generation thread
        print("Cancelled")
        streamer.on_finalized_text("cancelled", True)
        print("Signal sent")
        raise


chat_bot = gr.Chatbot(
    latex_delimiters=[
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
    ],
    scale=1,
    type="messages",
)

with gr.Blocks(js=JS) as demo:
    reasoning = gr.Markdown(
        "# Reasoning\n\nWhen the model is reasoning, its thoughts will be displayed here.",
        label="Reasoning",
        show_label=True,
        container=True,
        elem_classes="auto-scroll",
        max_height="90vh",
        render=False,
    )
    with gr.Row(equal_height=True, height="90vh"):
        with gr.Column(scale=3):
            gr.ChatInterface(
                chat,
                type="messages",
                chatbot=chat_bot,
                title=str(model_name),
                description=(
                    f"*{model_name}* is a large language model "
                    "trained on a mixture of instruction and "
                    "conversational data."
                ),
                additional_outputs=[reasoning],
            )
        with gr.Column():
            reasoning.render()

if __name__ == "__main__":
    demo.queue().launch()