import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import threading
import torch

# Load base model directly and then add the adapter
model = AutoModelForCausalLM.from_pretrained("unsloth/gemma-3-1b-it")
# Apply adapter from the fine-tuned version
model.load_adapter("Oysiyl/gemma-3-1B-GRPO")
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
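# Note: load_adapter relies on the PEFT integration in transformers, so the
# `peft` package must be installed alongside transformers and torch.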

# Move the model to the best available accelerator (prefer CUDA, then Apple MPS)
if torch.cuda.is_available():
    model.to("cuda")
elif torch.backends.mps.is_available():
    model.to("mps")


def process_history(history):
    """Process chat history into the format expected by the model."""
    processed_history = []
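    # e.g. [("Hi", "Hello!")] is mapped to
    # [{"role": "user", "content": [{"type": "text", "text": "Hi"}]},
    #  {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]}]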
    for user_msg, assistant_msg in history:
        # Always add user message first, even if empty
        processed_history.append({"role": "user", "content": [{"type": "text", "text": user_msg or ""}]})
        # Always add assistant message, even if empty
        processed_history.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg or ""}]})
    return processed_history


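# Gemma 3's chat template accepts this list-of-parts content format as well as
# plain string content, so both styles can appear in the same conversation.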
def process_new_user_message(message):
    """Process a new user message into the format expected by the model."""
    return [{"type": "text", "text": message}]


def respond(
    user_message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Format messages according to Gemma's expected chat format
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    
    # Process the conversation history
    if history:
        messages.extend(process_history(history))
    
    # Add the new user message in the same content format as the history
    messages.append({"role": "user", "content": process_new_user_message(user_message)})
    
    # Apply chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
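    # The rendered prompt should follow Gemma's turn format, roughly
    # "<start_of_turn>user ... <end_of_turn>\n<start_of_turn>model\n",
    # with add_generation_prompt=True appending the opening model turn.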
    # The chat template already prepends <bos>, so don't add special tokens again
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")
    elif torch.backends.mps.is_available():
        inputs = inputs.to("mps")
    
    # Set up the streamer
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False)
    
    # Run generation in a separate thread
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
        temperature=temperature,
        top_p=top_p,
        top_k=64,  # Recommended Gemma-3 setting
    )
    
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
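    # model.generate blocks until completion, so it runs on a worker thread while
    # the streamer hands decoded tokens back to this generator as they arrive.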
    
    # Markers that the GRPO fine-tune was trained to emit
    start_tag = "<start_working_out>"
    sol_start = "<SOLUTION>"
    sol_end = "</SOLUTION>"

    output = ""
    for token in streamer:
        output += token

        thinking = ""
        final_answer = ""

        # Extract "Thinking" section: everything after <start_working_out>
        if start_tag in output:
            start_idx = output.find(start_tag) + len(start_tag)
            # If <SOLUTION> is also present, stop "Thinking" at <SOLUTION>
            if sol_start in output:
                end_idx = output.find(sol_start)
            else:
                end_idx = len(output)
            thinking = output[start_idx:end_idx].strip()

        # Extract "Final answer" section: everything after <SOLUTION>
        if sol_start in output:
            sol_start_idx = output.find(sol_start) + len(sol_start)
            # If </SOLUTION> is present, stop at it
            if sol_end in output:
                sol_end_idx = output.find(sol_end)
                final_answer = output[sol_start_idx:sol_end_idx].strip()
            else:
                final_answer = output[sol_start_idx:].strip()

        # Build formatted output
        formatted_output = ""
        if thinking:
            formatted_output += "### Thinking:\n" + thinking + "\n"
        if final_answer:
            formatted_output += "\n### Final answer:\n**" + final_answer + "**"

        # If nothing found yet, just show the raw output (for streaming effect)
        if not thinking and not final_answer:
            formatted_output = output

        yield formatted_output

        # If </SOLUTION> is found, end the response
        if sol_end in output:
            break


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are given a problem.\nThink about the problem and provide your working out.\nPlace it between <start_working_out> and <end_working_out>.\nThen, provide your solution between <SOLUTION></SOLUTION>",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["Toulouse has twice as many sheep as Charleston. Charleston has 4 times as many sheep as Seattle. How many sheep do Toulouse, Charleston, and Seattle have together if Seattle has 20 sheep?"],
        ["A football team played 22 games. They won 8 more than they lost. How many did they win?"],
        ["Jim spends 2 hours watching TV and then decides to go to bed and reads for half as long. He does this 3 times a week. How many hours does he spend on TV and reading in 4 weeks?"],
        ["Darrell and Allen's ages are in the ratio of 7:11. If their total age now is 162, calculate Allen's age 10 years from now."],
        ["In a neighborhood, the number of pet rabbits is twelve less than the combined number of pet dogs and cats. If there are two cats for every dog, and the number of dogs is 60, how many pets in total are in the neighborhood?"],
    ],
    cache_examples=False,
    chatbot=gr.Chatbot(
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False}
        ],
    ),
)
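# Note: respond() assumes Gradio's tuple-style chat history (a list of
# (user, assistant) pairs); with type="messages" in newer Gradio releases,
# process_history() would need to be adapted accordingly.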


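# Launch the UI when this file is run directly (e.g. `python app.py`); a
# Hugging Face Space running this script uses the same entry point.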
if __name__ == "__main__":
    demo.launch()