# app.py (Revised for Unsloth LoRA Gemma Model)
import gradio as gr
import torch
# Unsloth's loader returns both the model and the tokenizer, so transformers'
# AutoTokenizer / AutoModelForCausalLM are not needed here.
from unsloth import FastLanguageModel
# --- Configuration ---
BASE_MODEL_NAME = "unsloth/gemma-3-4b-it"     # Base model the adapters were trained on (kept for logging)
ADAPTER_MODEL_NAME = "neuralnets/cf_codebot"  # Your friend's fine-tuned LoRA adapters
# --- Model Loading ---
# This block runs once when the Space starts up.
try:
    # Load the fine-tuned model directly from the adapter repo. Given a LoRA adapter
    # repo, Unsloth reads its adapter_config.json, downloads the base model it was
    # trained on (unsloth/gemma-3-4b-it) and attaches the trained adapter weights.
    # max_seq_length sets the context window; 2048 is a common default for Gemma.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = ADAPTER_MODEL_NAME,
        max_seq_length = 2048,   # Max context length the model can handle
        dtype = None,            # Auto-detect: bfloat16 on Ampere+ GPUs, float16 otherwise
        load_in_4bit = True,     # 4-bit quantization via bitsandbytes to save GPU memory
    )

    # Switch Unsloth into its optimized inference mode and disable dropout.
    FastLanguageModel.for_inference(model)
    model.eval()

    # Note: with load_in_4bit=True, device placement is handled by bitsandbytes/accelerate,
    # so there is no need to call model.to(device) explicitly.
    print(f"Base model '{BASE_MODEL_NAME}' with adapters '{ADAPTER_MODEL_NAME}' loaded successfully.")
    print(f"Model device: {next(model.parameters()).device}")
except Exception as e:
    print(f"Error loading model '{BASE_MODEL_NAME}' or adapters '{ADAPTER_MODEL_NAME}': {e}")
    print("Using a dummy function for demonstration purposes.")
    tokenizer, model = None, None  # Signal to the inference function that the model is unavailable
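# Alternative loading path (a sketch, not part of the original app): if passing the
# adapter repo straight to FastLanguageModel.from_pretrained ever fails, the base model
# can be loaded first and the trained adapters attached with peft. Both APIs below are
# standard, but using them here is an assumption, so the sketch is left commented out.
#
# from peft import PeftModel
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name = BASE_MODEL_NAME,
#     max_seq_length = 2048,
#     dtype = None,
#     load_in_4bit = True,
# )
# model = PeftModel.from_pretrained(model, ADAPTER_MODEL_NAME)
# FastLanguageModel.for_inference(model)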
# --- Inference Function ---
def generate_editorial(problem_statement: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    if model is None or tokenizer is None:  # Model failed to load, fall back to dummy output
        print("Model not loaded, using dummy generation.")
        if "watermelon" in problem_statement.lower():
            return "To be able to split the watermelon such that each part is even..."
        return "This is a placeholder editorial based on your problem statement.\n(Model failed to load, check logs)"
    try:
        # Build the prompt in the chat format the instruction-tuned model expects.
        # Getting this right is crucial for Gemma-IT style models: the tokenizer's chat
        # template produces the <start_of_turn>/<end_of_turn> structure the model was
        # trained on, so we use apply_chat_template instead of hand-rolling an
        # "### Instruction: / ### Response:" style prompt.
        messages = [
            {"role": "user", "content": problem_statement}
        ]
        # add_generation_prompt=True appends the opening of the model turn, e.g.
        # <bos><start_of_turn>user ... <end_of_turn>\n<start_of_turn>model\n
        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,  # We want the string, not token IDs
            add_generation_prompt=True
        )
        # Tokenize the input string
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048  # Match the max_seq_length used when loading the model
        ).to(model.device)  # Ensure inputs are on the same device as the model
        # Generate text
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
            top_p=top_p,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,  # Ensure pad_token_id is set
            # Stopping: generation halts at the model's eos token(s) by default. If the
            # fine-tuned model reliably ends turns with "<end_of_turn>", that token can be
            # added explicitly, e.g.
            # eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")],
        )
        # Decode the full sequence. We keep special tokens here so we can locate the
        # model-turn marker and strip the prompt from the output ourselves.
        generated_sequence = tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Extract only the model's response. The chat template produces something like:
        # "<bos><start_of_turn>user\n{problem_statement}<end_of_turn>\n<start_of_turn>model\n"
        # so everything after the model-turn marker is the generated editorial.
        response_start_marker = "<start_of_turn>model\n"
        if response_start_marker in generated_sequence:
            editorial_content = generated_sequence.split(response_start_marker)[-1].strip()
        else:
            # Fallback if the marker is not found, or the output starts with the input prompt
            editorial_content = generated_sequence.strip()
            if editorial_content.startswith(input_text):
                editorial_content = editorial_content[len(input_text):].strip()

        # Remove any lingering special tokens such as <end_of_turn> or the eos token
        editorial_content = editorial_content.replace("<end_of_turn>", "").replace(tokenizer.eos_token, "").strip()
        return editorial_content
    except Exception as e:
        print(f"Error during inference: {e}")
        return f"An error occurred during editorial generation: {e}"
# --- Gradio Interface Setup ---
iface = gr.Interface(
    fn=generate_editorial,
    inputs=[
        gr.Textbox(lines=10, label="Problem Statement", placeholder="Paste your problem statement here...", autofocus=True),
        gr.Slider(minimum=1, maximum=1024, value=400, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    outputs=gr.Textbox(label="Generated Editorial"),
    title="Codeforces Editorial Assistant (Gemma LoRA)",
    description="Paste a Codeforces problem statement and get a generated editorial from neuralnets/cf_codebot (Gemma-3-4b-it LoRA).",
    flagging_mode="auto",  # Updated from allow_flagging
    examples=[
        [
            "A. Watermelon\ntime limit per test\n1 second\nmemory limit per test\n64 megabytes\n\nOne hot summer day Pete and his friend Billy decided to buy a watermelon. They chose the biggest and the ripest one, in their opinion. After that the watermelon was weighed, and the scales showed w kilos. They rushed home, dying of thirst, and decided to divide the berry, however they faced a hard problem.\n\nPete and Billy are great fans of even numbers, that's why they want to divide the watermelon in such a way that each of the two parts weighs even number of kilos, at the same time it is not obligatory that the parts are equal. The boys are extremely tired and want to start their meal as soon as possible, that's why you should help them and find out, if they can divide the watermelon in the way they want. For sure, each of them should get a part of positive weight.\nInput\n\nThe first (and the only) input line contains integer number w (1 ≤ w ≤ 100), the weight of the watermelon bought by the boys.\nOutput\n\nPrint YES, if the boys can divide the watermelon into two parts, each of them weighing even number of kilos; and NO in the opposite case.",
            400,
            0.7,
            0.95
        ]
    ]
)
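# Optional (an assumption about expected traffic, not in the original app): Gradio's
# request queue serializes concurrent generation calls, which helps avoid GPU
# out-of-memory errors when several users submit at once.
# iface.queue(max_size=8)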
if __name__ == "__main__":
    iface.launch()