import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import threading
import torch
# Load base model directly and then add the adapter
model = AutoModelForCausalLM.from_pretrained("unsloth/gemma-3-1b-it")
# Apply adapter from the fine-tuned version
model.load_adapter("Oysiyl/gemma-3-1B-GRPO")
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
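# Move the model to an accelerator if one is available (CUDA GPU or Apple Silicon MPS)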
if torch.cuda.is_available():
    model.to("cuda")
elif torch.backends.mps.is_available():
    model.to("mps")
def respond(
    user_message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Format messages according to Gemma's expected chat format
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    # Process the conversation history
    # No need for history functionality for this reasoning model
    # if history:
    #     messages.extend(process_history(history))
    # Add the new user message
    messages.append({"role": "user", "content": user_message})
    # Apply chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
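    # Tokenize the prompt and move the input tensors to the same device as the model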
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")
    elif torch.backends.mps.is_available():
        inputs = inputs.to("mps")
    # Set up the streamer
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False)
    # Run generation in a separate thread
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=64,  # Recommended Gemma-3 setting
    )
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
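    # Consume the streamer token by token, re-parsing the accumulated output on each step
    # so the "Thinking" and "Final answer" sections update live in the UI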
output = ""
for token in streamer:
output += token
# Tags
start_tag = "<start_working_out>"
sol_start = "<SOLUTION>"
thinking = ""
final_answer = ""
# Extract "Thinking" section: everything after <start_working_out>
if start_tag in output:
start_idx = output.find(start_tag) + len(start_tag)
# If <SOLUTION> is also present, stop "Thinking" at <SOLUTION>
if sol_start in output:
end_idx = output.find(sol_start)
else:
end_idx = len(output)
thinking = output[start_idx:end_idx].strip()
# Extract "Final answer" section: everything after <SOLUTION>
if sol_start in output:
sol_start_idx = output.find(sol_start) + len(sol_start)
final_answer = output[sol_start_idx:].strip()
# Build formatted output
formatted_output = ""
if thinking:
formatted_output += "### Thinking:\n" + thinking + "\n"
if final_answer:
formatted_output += "\n### Final answer:\n**" + final_answer + "**"
# If nothing found yet, just show the raw output (for streaming effect)
if not thinking and not final_answer:
formatted_output = output
yield formatted_output
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are given a problem.\nThink about the problem and provide your working out.\nPlace it between <start_working_out> and <end_working_out>.\nThen, provide your solution between <SOLUTION></SOLUTION>",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["Toulouse has twice as many sheep as Charleston. Charleston has 4 times as many sheep as Seattle. How many sheep do Toulouse, Charleston, and Seattle have together if Seattle has 20 sheep?"],
        ["A football team played 22 games. They won 8 more than they lost. How many did they win?"],
        ["Jim spends 2 hours watching TV and then decides to go to bed and reads for half as long. He does this 3 times a week. How many hours does he spend on TV and reading in 4 weeks?"],
        ["Darrell and Allen's ages are in the ratio of 7:11. If their total age now is 162, calculate Allen's age 10 years from now."],
        ["In a neighborhood, the number of rabbits pets is twelve less than the combined number of pet dogs and cats. If there are two cats for every dog, and the number of dogs is 60, how many pets in total are in the neighborhood?"],
    ],
    cache_examples=False,
    chatbot=gr.Chatbot(
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    ),
    title="Demo: Finetuned Gemma 3 1B (Oysiyl/gemma-3-1B-GRPO) on GSM8K with GRPO",
    description=(
        "This is a demo of a finetuned Gemma 3 1B model ([Oysiyl/gemma-3-1B-GRPO](https://huggingface.co/Oysiyl/gemma-3-1B-GRPO)) "
        "on the [openai/gsm8k](https://huggingface.co/datasets/openai/gsm8k) dataset using the GRPO technique. "
        "Finetuning and reasoning approach inspired by [Unsloth's notebook](https://docs.unsloth.ai/basics/reasoning-grpo-and-rl/tutorial-train-your-own-reasoning-model-with-grpo). "
        "This demo does not support conversation history, as the GSM8K dataset consists of single-turn questions."
    ),
)
if __name__ == "__main__":
    demo.launch()