import asyncio
import functools
import re

import gradio as gr
import spaces
from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer

JS = """
() => {
    // auto-scroll .auto-scroll elements whenever their text content changes
    const observer = new MutationObserver((mutations) => {
        mutations.forEach((mutation) => {
            // walk up to the nearest ancestor whose inline "overflow" style is
            // "auto": that is the scrollable container to keep pinned to the bottom
            let element = mutation.target;
            while(element.parentElement !== null && element.parentElement.style.overflow !== "auto") {
                element = element.parentElement;
            }
            if (element.parentElement === null) {
                return;
            }
            element = element.parentElement;
            element.scrollTop = element.scrollHeight;
        });
    })
    document.querySelectorAll('.auto-scroll > *').forEach((elem) => {
        console.log("observing", elem)
        observer.observe(elem, {
            childList: true,
            characterData: true,
        })
    });
}
"""

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

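# torch_dtype="auto" keeps the checkpoint's native precision and
# device_map="auto" places the weights on the available device(s).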
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


def reformat_math(text):
    """Fix MathJax delimiters to use the Gradio syntax.

    This is a workaround to display math formulas in Gradio. For now, I haven't found a
    way to make it work as expected using other latex_delimiters...
    """
    text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
    return text
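
# Illustrative check of the rewrite above (assumed inputs):
#   reformat_math(r"\[ x^2 \]")    -> "$$x^2$$"
#   reformat_math(r"\( \alpha \)") -> "$\alpha$"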


async def chat(prompt, history):
    """Respond to a chat prompt."""
    message = {
        "role": "user",
        "content": prompt,
    }

    history = [] if history is None else history

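    # On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU for the
    # duration of this call; outside Spaces the decorator is a no-op.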
    @spaces.GPU
    def _generate():
        text = tokenizer.apply_chat_template(
            history + [message],
            tokenize=False,
            add_generation_prompt=True,
        )

        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        streamer = AsyncTextIteratorStreamer(tokenizer, skip_special_tokens=True)

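        # model.generate() is blocking, so run it on the default thread pool;
        # the async streamer lets the event loop pull tokens as they arrive.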
        task = asyncio.get_running_loop().run_in_executor(
            None,
            functools.partial(
                model.generate,
                max_new_tokens=1024 * 128,
                streamer=streamer,
                **model_inputs,
            ),
        )
        return task, streamer

    task, streamer = _generate()

    buffer = ""
    reasoning = ""
    thinking = False

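    # Split the stream: text between <think> and </think> feeds the reasoning
    # panel; everything else accumulates into the visible chat answer.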
    try:
        async for new_text in streamer:
            if task.done() or task.cancelled():
                print("Cancelled")
                break  # stop streaming if the generation task was cancelled

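            # Assumes each <think>/</think> marker arrives whole in a single
            # streamed chunk; a tag split across chunks would be missed.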
            if not thinking and "<think>" in new_text:
                thinking = True
                continue
            if thinking and "</think>" in new_text:
                thinking = False
                continue

            if thinking:
                reasoning += new_text
                heading = "# Reasoning\n\n"
                yield "I'm thinking, please wait a moment...", heading + reasoning
                continue

            buffer += new_text
            yield reformat_math(buffer), reasoning

    except asyncio.CancelledError:
        # this doesn't fully work; I haven't found a way to stop the generation thread
        print("Cancelled")
        streamer.on_finalized_text("cancelled", True)
        print("Signal sent")
        raise


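# These delimiters match what reformat_math() emits: $$...$$ renders display
# math, $...$ renders inline math.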
chat_bot = gr.Chatbot(
    latex_delimiters=[
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
    ],
    scale=1,
    type="messages",
)

with gr.Blocks(js=JS) as demo:
    reasoning = gr.Markdown(
        "# Reasoning\n\nWhen the model will reasoning, its thoughts will be displayed here.",
        label="Reasoning",
        show_label=True,
        container=True,
        elem_classes="auto-scroll",
        max_height="90vh",
        render=False,
    )
    with gr.Row(equal_height=True, height="90vh"):
        with gr.Column(scale=3):
            gr.ChatInterface(
                chat,
                type="messages",
                chatbot=chat_bot,
                title=model_name,
                description=(
                    f"*{model_name}* is a reasoning model distilled "
                    "from DeepSeek-R1 into a 1.5B-parameter Qwen base; "
                    "its chain of thought is streamed to the panel on the right."
                ),
                additional_outputs=[reasoning],
            )

        with gr.Column():
            reasoning.render()


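# queue() enables Gradio's event queue, which generator handlers like chat()
# need in order to stream partial outputs to the browser.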
if __name__ == "__main__":
    demo.queue().launch()