import asyncio
import functools
import re
import gradio as gr
import spaces
from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
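# JavaScript injected into the page: a MutationObserver keeps every
# ".auto-scroll" container scrolled to the bottom as streamed text grows.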
JS = """
() => {
    // auto-scroll ".auto-scroll" elements whenever their text changes
const observer = new MutationObserver((mutations) => {
mutations.forEach((mutation) => {
            // walk up to the closest ancestor whose "overflow" style is set
            // to "auto" (the scrollable .auto-scroll container)
let element = mutation.target;
while(element.parentElement !== null && element.parentElement.style.overflow !== "auto") {
element = element.parentElement;
}
if (element.parentElement === null) {
return;
}
element = element.parentElement;
element.scrollTop = element.scrollHeight;
});
    });
    document.querySelectorAll('.auto-scroll > *').forEach((elem) => {
        console.log("observing", elem);
        observer.observe(elem, {
            childList: true,
            characterData: true,
        });
    });
}
"""
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto",
)
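# Debug output: inspect the model's attributes and configuration at startup.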
print(dir(model))
print(model.config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def reformat_math(text):
"""Fix MathJax delimiters to use the Gradio syntax.
This is a workaround to display math formulas in Gradio. For now, I havn't found a way to
make it work as expected using others latex_delimites...
"""
text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
return text
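
# The spaces.GPU decorator requests a GPU for this call on ZeroGPU Spaces.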
@spaces.GPU
def _generate(history):
text = tokenizer.apply_chat_template(
history,
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
streamer = AsyncTextIteratorStreamer(tokenizer, skip_special_tokens=True)
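    # generate() blocks, so run it in the default thread-pool executor; the
    # async streamer yields tokens back to the event loop as they are produced.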
task = asyncio.get_running_loop().run_in_executor(
None,
functools.partial(
model.generate,
max_new_tokens=1024 * 128,
streamer=streamer,
**model_inputs,
),
)
return task, streamer
async def chat(prompt, history):
"""Respond to a chat prompt."""
message = {
"role": "user",
"content": prompt,
}
# build the messages list
history = [] if history is None else history
message_list = history + [message]
# get the task and the streamer
task, streamer = _generate(message_list)
buffer = ""
reasoning = ""
thinking = False
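    # DeepSeek-R1 wraps its chain of thought in <think>...</think> tags: that
    # text is routed to the reasoning panel, the rest to the chat window.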
try:
async for new_text in streamer:
if task.done() or task.cancelled():
print("Cancelled")
                break  # stop streaming if the task was cancelled
if not thinking and "<think>" in new_text:
thinking = True
continue
if thinking and "</think>" in new_text:
thinking = False
continue
if thinking:
reasoning += new_text
heading = "# Reasoning\n\n"
yield "I'm thinking, please wait a moment...", heading + reasoning
continue
buffer += new_text
yield reformat_math(buffer), reasoning
except asyncio.CancelledError:
        # This doesn't work: I haven't found a way to stop the generation thread.
print("Cancelled")
streamer.on_finalized_text("cancelled", True)
print("Signal sent")
raise
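
# The latex_delimiters below match the output of reformat_math() above.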
chat_bot = gr.Chatbot(
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
],
scale=1,
type="messages",
)
with gr.Blocks(js=JS) as demo:
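    # Build the reasoning panel with render=False so it can be passed to the
    # ChatInterface as an additional output, then rendered in the right column.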
reasoning = gr.Markdown(
"# Reasoning\n\nWhen the model will reasoning, its thoughts will be displayed here.",
label="Reasoning",
show_label=True,
container=True,
elem_classes="auto-scroll",
max_height="90vh",
render=False,
)
with gr.Row(equal_height=True, height="90vh"):
with gr.Column(scale=3):
gr.ChatInterface(
chat,
type="messages",
chatbot=chat_bot,
title=str(model_name),
                description=(
                    f"*{model_name}* is a reasoning model "
                    "distilled from DeepSeek-R1 into a Qwen "
                    "1.5B base model."
                ),
additional_outputs=[reasoning],
)
with gr.Column():
reasoning.render()
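
# queue() enables the request queue required for streaming generator outputs.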
if __name__ == "__main__":
demo.queue().launch()