import re
import threading

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

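# CSS: cap the reasoning panel's height and let its content scroll; the JS
# below keeps it scrolled to the bottom as new text streams in.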
CSS = """
.m3d-auto-scroll > * {
    overflow: auto;
}

#reasoning {
    overflow: auto;
    height: calc(100vh - 128px);
    scroll-behavior: smooth;
}
"""

JS = """
() => {
    // auto-scroll the #reasoning panel whenever its content changes
    const block = document.querySelector('#reasoning');
    const observer = new MutationObserver((mutations) => {
        block.scrollTop = block.scrollHeight;
    })
    observer.observe(block, {
        childList: true,
        characterData: true,
        subtree: true,
    });
}
"""

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

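# Load the model and tokenizer once at startup. torch_dtype="auto" keeps the
# checkpoint's native precision and device_map="auto" places the weights on
# whatever accelerator is available.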
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
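# Debug output: dump the model's attributes and configuration to the logs.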
print(dir(model))
print(model.config)
tokenizer = AutoTokenizer.from_pretrained(model_name)


def reformat_math(text):
    """Fix MathJax delimiters to use the Gradio syntax.

    This is a workaround to display math formulas in Gradio. For now, I haven't found a way to
    make it work as expected using other latex_delimiters settings...
    """
    text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
    return text


@spaces.GPU
def chat(prompt, history):
    """Respond to a chat prompt."""
    message = {
        "role": "user",
        "content": prompt,
    }

    # build the messages list
    history = [] if history is None else history
    message_list = history + [message]

    text = tokenizer.apply_chat_template(
        message_list,
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

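    # Run generate() in a background thread; the streamer then yields decoded
    # text chunks as tokens are produced, so we can stream them to the UI.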
    threading.Thread(
        target=model.generate,
        kwargs=dict(
            max_new_tokens=1024 * 128,
            streamer=streamer,
            **model_inputs,
        ),
    ).start()

    buffer = ""
    reasoning = ""
    thinking = False
    reasoning_heading = "# Reasoning\n\n"

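    # DeepSeek-R1 wraps its chain of thought in <think>...</think> tags. Text
    # inside the tags goes to the reasoning panel; everything after </think>
    # is streamed as the actual answer.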
    for new_text in streamer:
        if not thinking and "<think>" in new_text:
            thinking = True
            continue
        if thinking and "</think>" in new_text:
            thinking = False
            continue

        if thinking:
            reasoning += new_text
            yield (
                "I'm thinking, please wait a moment...",
                reasoning_heading + reasoning,
            )
            continue

        buffer += new_text
        yield reformat_math(buffer), reasoning_heading + reasoning


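# The Chatbot's latex_delimiters must match what reformat_math() emits:
# $$...$$ for display math and $...$ for inline math.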
chat_bot = gr.Chatbot(
    latex_delimiters=[
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
    ],
    scale=1,
    type="messages",
)


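# Two-column layout: the chat interface next to a live reasoning panel. The
# Markdown is created with render=False so ChatInterface can reference it via
# additional_outputs before it is placed in the layout below.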
with gr.Blocks(
    theme="davehornik/Tealy",
    js=JS,
    css=CSS,
    fill_height=True,
    title="Reasoning model example",
) as demo:
    reasoning = gr.Markdown(
        "# Reasoning\n\nWhen the model will reasoning, its thoughts will be displayed here.",
        label="Reasoning",
        show_label=True,
        container=True,
        elem_classes="m3d-auto-scroll",
        render=False,
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=3, variant="compact"):
            gr.ChatInterface(
                chat,
                type="messages",
                chatbot=chat_bot,
                title="Simple conversational AI with reasoning",
                description=(
                    f"We're using the **{model_name}**. It is a large language model "
                    "trained on a mixture of instruction and "
                    "conversational data. It has the capability to reason about the "
                    "prompt (the user question). "
                    "When you ask a question, you can see its thoughts "
                    "on the left block."
                ),
                additional_outputs=[reasoning],
            )

        with gr.Column(elem_id="reasoning"):
            reasoning.render()


if __name__ == "__main__":
    demo.queue().launch()