import asyncio
import functools
import re

import gradio as gr
import spaces
from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
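
# JS injected into the Gradio page: a MutationObserver keeps each scrollable
# ".auto-scroll" container pinned to the bottom while streamed text arrives.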
JS = """
() => {
    // auto-scroll .auto-scroll elements whenever their text changes
    const observer = new MutationObserver((mutations) => {
        mutations.forEach((mutation) => {
            // walk up to the parent element that has the .auto-scroll class,
            // i.e. the one whose "overflow" style is set to "auto"
            let element = mutation.target;
            while (element.parentElement !== null && element.parentElement.style.overflow !== "auto") {
                element = element.parentElement;
            }
            if (element.parentElement === null) {
                return;
            }
            element = element.parentElement;
            element.scrollTop = element.scrollHeight;
        });
    });
    document.querySelectorAll('.auto-scroll > *').forEach((elem) => {
        console.log("observing", elem);
        observer.observe(elem, {
            childList: true,
            characterData: true,
        });
    });
}
"""
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
print(dir(model))  # debug: inspect the loaded model's attributes
print(model.config)  # debug: show the model configuration
tokenizer = AutoTokenizer.from_pretrained(model_name)


def reformat_math(text):
    """Fix MathJax delimiters to use the Gradio syntax.

    This is a workaround to display math formulas in Gradio. For now, I
    haven't found a way to make it work as expected with other
    latex_delimiters...
    """
    text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
    return text
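
# For example, reformat_math(r"\[ x^2 \]") returns "$$x^2$$" and
# reformat_math(r"\( x \)") returns "$x$".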


async def chat(prompt, history):
    """Respond to a chat prompt."""
    message = {
        "role": "user",
        "content": prompt,
    }
    history = [] if history is None else history
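
    # _generate runs on the GPU worker (via the @spaces.GPU decorator below):
    # it builds the prompt from the chat template, then launches model.generate
    # in the default thread-pool executor so the async streamer can be consumed
    # without blocking the event loop.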
    @spaces.GPU
    def _generate():
        text = tokenizer.apply_chat_template(
            history + [message],
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        streamer = AsyncTextIteratorStreamer(tokenizer, skip_special_tokens=True)
        task = asyncio.get_running_loop().run_in_executor(
            None,
            functools.partial(
                model.generate,
                max_new_tokens=1024 * 128,
                streamer=streamer,
                **model_inputs,
            ),
        )
        return task, streamer

    task, streamer = _generate()
    buffer = ""
    reasoning = ""
    thinking = False
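
    # DeepSeek-R1 wraps its chain of thought in <think>...</think> tags: text
    # inside them is streamed to the reasoning panel, and everything after
    # </think> becomes the chat answer.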
    try:
        async for new_text in streamer:
            if task.done() or task.cancelled():
                print("Cancelled")
                break  # stop streaming if the task was cancelled
            if not thinking and "<think>" in new_text:
                thinking = True
                continue
            if thinking and "</think>" in new_text:
                thinking = False
                continue
            if thinking:
                reasoning += new_text
                heading = "# Reasoning\n\n"
                yield "I'm thinking, please wait a moment...", heading + reasoning
                continue
            buffer += new_text
            yield reformat_math(buffer), reasoning
    except asyncio.CancelledError:
        # this doesn't work; I haven't found a way to stop the generation thread
        print("Cancelled")
        streamer.on_finalized_text("cancelled", True)
        print("Signal sent")
        raise
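

# The chatbot's latex_delimiters mirror what reformat_math produces:
# $$...$$ for display math and $...$ for inline math.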
chat_bot = gr.Chatbot(
    latex_delimiters=[
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
    ],
    scale=1,
    type="messages",
)
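
# Two-column layout: the chat interface on the left, the streamed reasoning on
# the right. The injected JS keeps the reasoning panel (elem_classes
# "auto-scroll") scrolled to the bottom as new text arrives.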
with gr.Blocks(js=JS) as demo:
    reasoning = gr.Markdown(
        "# Reasoning\n\nWhen the model is reasoning, its thoughts will be displayed here.",
        label="Reasoning",
        show_label=True,
        container=True,
        elem_classes="auto-scroll",
        max_height="90vh",
        render=False,
    )
    with gr.Row(equal_height=True, height="90vh"):
        with gr.Column(scale=3):
            gr.ChatInterface(
                chat,
                type="messages",
                chatbot=chat_bot,
                title=str(model_name),
                description=(
                    f"*{model_name}* is a large language model "
                    "trained on a mixture of instruction and "
                    "conversational data."
                ),
                additional_outputs=[reasoning],
            )
        with gr.Column():
            reasoning.render()

if __name__ == "__main__":
    demo.queue().launch()