# GRPO-Any-Model / app.py
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
import json
import os
class GRPOTrainer:
def __init__(self):
self.model = None
self.ref_model = None
self.tokenizer = None
self.optimizer = None
self.training_history = []
def load_model(self, model_name: str) -> str:
"""Load the model and tokenizer"""
try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
            # Left padding so batched generation continues from the real prompt tokens;
            # use fp16 on GPU and fall back to fp32 on CPU, where half precision is unreliable
            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.float16 if device == "cuda" else torch.float32
            self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
# Set padding token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
            # Freeze the reference model and keep it in eval mode
            for param in self.ref_model.parameters():
                param.requires_grad = False
            self.ref_model.eval()
return f"βœ… Successfully loaded model: {model_name}"
except Exception as e:
return f"❌ Error loading model: {str(e)}"
def compute_rewards(self, prompts: List[str], responses: List[str]) -> torch.Tensor:
"""Compute rewards for responses (simplified reward function)"""
rewards = []
for response in responses:
# Simple reward based on response length and diversity
length_reward = min(len(response.split()) / 50, 1.0)
unique_words = len(set(response.lower().split()))
diversity_reward = min(unique_words / 20, 1.0)
reward = (length_reward + diversity_reward) / 2
rewards.append(reward)
return torch.tensor(rewards)
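    # The reward above is only a stand-in; in practice a task-specific reward drives
    # GRPO. The method below is a hypothetical illustration (keyword coverage) and is
    # not called anywhere in this app:
    def compute_keyword_rewards(self, responses: List[str], keywords: List[str]) -> torch.Tensor:
        """Example reward: fraction of target keywords that appear in each response"""
        rewards = []
        for response in responses:
            text = response.lower()
            hits = sum(1 for kw in keywords if kw.lower() in text)
            rewards.append(hits / max(len(keywords), 1))
        return torch.tensor(rewards)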
def compute_kl_penalty(self, logits: torch.Tensor, ref_logits: torch.Tensor) -> torch.Tensor:
"""Compute KL divergence penalty"""
probs = F.softmax(logits, dim=-1)
ref_probs = F.softmax(ref_logits, dim=-1)
kl = (probs * (probs / ref_probs).log()).sum(-1)
return kl.mean()
def grpo_step(self, prompts: List[str], beta: float = 0.1) -> Dict:
"""Perform one GRPO training step"""
if not self.model or not self.tokenizer:
return {"error": "Model not loaded"}
        # Tokenize prompts and move them to the model's device
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(self.model.device)
        # Sample responses from the current policy
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.8,
                pad_token_id=self.tokenizer.pad_token_id
            )
# Decode responses
responses = []
for output in outputs:
response = self.tokenizer.decode(output[inputs.input_ids.shape[1]:], skip_special_tokens=True)
responses.append(response)
        # Compute rewards and group-relative advantages; in this simplified setup the
        # batch of prompts acts as one group, so rewards are centered and scaled by the
        # batch statistics
        rewards = self.compute_rewards(prompts, responses).to(self.model.device)
        advantages = rewards - rewards.mean()
        advantages = advantages / (advantages.pow(2).mean().sqrt() + 1e-8)
        # Forward pass over the full prompt+response sequences with both models
        self.model.train()
        model_outputs = self.model(outputs)
        with torch.no_grad():
            ref_outputs = self.ref_model(outputs)
        # Log-probabilities of the sampled response tokens under the current policy
        # (logits at position t predict token t + 1, hence the one-token shift)
        token_logprobs = F.log_softmax(model_outputs.logits[:, :-1, :], dim=-1)
        token_logprobs = token_logprobs.gather(-1, outputs[:, 1:].unsqueeze(-1)).squeeze(-1)
        response_logprobs = token_logprobs[:, inputs.input_ids.shape[1] - 1:].sum(-1)
        # Compute KL penalty against the frozen reference model
        kl_penalty = self.compute_kl_penalty(model_outputs.logits, ref_outputs.logits)
        # Simplified GRPO loss: advantage-weighted policy gradient plus KL penalty
        loss = -(advantages * response_logprobs).mean() + beta * kl_penalty
        # Backward pass
        if self.optimizer:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
return {
"loss": loss.item(),
"reward": rewards.mean().item(),
"kl_penalty": kl_penalty.item(),
"responses": responses
}
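    # Full GRPO (like PPO) optimizes a clipped importance-ratio surrogate rather than
    # the plain advantage-weighted log-probability used in grpo_step above. A minimal
    # sketch of that objective, shown for illustration only and not called here:
    @staticmethod
    def clipped_surrogate_loss(logprobs: torch.Tensor,
                               old_logprobs: torch.Tensor,
                               advantages: torch.Tensor,
                               clip_eps: float = 0.2) -> torch.Tensor:
        """PPO/GRPO-style clipped objective over per-sequence log-probabilities"""
        ratio = (logprobs - old_logprobs).exp()
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
        return -torch.min(unclipped, clipped).mean()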
def train(self, prompts: List[str], num_steps: int, lr: float, beta: float) -> str:
"""Run GRPO training"""
if not self.model:
return "❌ Please load a model first"
# Initialize optimizer
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
results = []
for step in range(num_steps):
step_result = self.grpo_step(prompts, beta)
if "error" in step_result:
return f"❌ Error: {step_result['error']}"
result_str = f"Step {step + 1}/{num_steps} - Loss: {step_result['loss']:.4f}, Reward: {step_result['reward']:.4f}, KL: {step_result['kl_penalty']:.4f}"
results.append(result_str)
# Store training history
self.training_history.append({
"step": step + 1,
"loss": step_result['loss'],
"reward": step_result['reward'],
"kl_penalty": step_result['kl_penalty']
})
return "\n".join(results)
def generate_response(self, prompt: str, max_length: int = 100, temperature: float = 0.8) -> str:
"""Generate a response using the trained model"""
if not self.model or not self.tokenizer:
return "❌ Please load a model first"
        self.model.eval()
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id
            )
response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return response
def save_model(self, save_path: str) -> str:
"""Save the trained model"""
if not self.model:
return "❌ No model to save"
try:
self.model.save_pretrained(save_path)
self.tokenizer.save_pretrained(save_path)
# Save training history
with open(os.path.join(save_path, "training_history.json"), "w") as f:
json.dump(self.training_history, f)
return f"βœ… Model saved to {save_path}"
except Exception as e:
return f"❌ Error saving model: {str(e)}"
# Initialize trainer
trainer = GRPOTrainer()
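# Illustrative programmatic usage, kept as a comment (the Gradio UI below is the
# intended entry point; "gpt2" is just an example model name):
#   trainer.load_model("gpt2")
#   trainer.train(["Tell me about AI"], num_steps=5, lr=1e-5, beta=0.1)
#   print(trainer.generate_response("Tell me about AI"))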
# Gradio interface
def load_model_interface(model_name):
return trainer.load_model(model_name)
def train_interface(prompts_text, num_steps, learning_rate, beta):
prompts = [p.strip() for p in prompts_text.split("\n") if p.strip()]
if not prompts:
return "❌ Please provide at least one prompt"
return trainer.train(prompts, int(num_steps), float(learning_rate), float(beta))
def generate_interface(prompt, max_length, temperature):
return trainer.generate_response(prompt, int(max_length), float(temperature))
def save_model_interface(save_path):
return trainer.save_model(save_path)
def get_training_history():
if not trainer.training_history:
return "No training history available"
history_str = "Training History:\n"
history_str += "-" * 50 + "\n"
for entry in trainer.training_history[-10:]: # Show last 10 entries
history_str += f"Step {entry['step']}: Loss={entry['loss']:.4f}, Reward={entry['reward']:.4f}, KL={entry['kl_penalty']:.4f}\n"
return history_str
# Create Gradio interface
with gr.Blocks(title="GRPO Model Training") as app:
gr.Markdown("# πŸš€ GRPO (Group Relative Policy Optimization) Training App")
gr.Markdown("Train language models using GRPO technique with this simple interface")
with gr.Tab("πŸ”§ Model Setup"):
with gr.Row():
model_input = gr.Textbox(
label="Model Name",
value="Writer/Palmyra-56B-Instruct",
placeholder="Enter HuggingFace model name (e.g., Palmyra, Qwen, Llama)"
)
load_btn = gr.Button("Load Model", variant="primary")
model_status = gr.Textbox(label="Status", lines=2)
load_btn.click(load_model_interface, inputs=model_input, outputs=model_status)
with gr.Tab("🎯 Training"):
with gr.Row():
with gr.Column():
prompts_input = gr.Textbox(
label="Training Prompts (one per line)",
lines=5,
value="Tell me about artificial intelligence\nExplain quantum computing\nWhat is machine learning?",
placeholder="Enter your prompts here..."
)
with gr.Column():
num_steps_input = gr.Slider(
label="Number of Training Steps",
minimum=1,
maximum=100,
value=10,
step=1
)
lr_input = gr.Number(
label="Learning Rate",
value=1e-5,
step=1e-6
)
beta_input = gr.Number(
label="KL Penalty Weight (Ξ²)",
value=0.1,
step=0.01
)
train_btn = gr.Button("Start Training", variant="primary")
training_output = gr.Textbox(label="Training Progress", lines=10)
train_btn.click(
train_interface,
inputs=[prompts_input, num_steps_input, lr_input, beta_input],
outputs=training_output
)
with gr.Tab("πŸ’¬ Generation"):
with gr.Row():
with gr.Column():
gen_prompt = gr.Textbox(
label="Prompt",
placeholder="Enter your prompt here...",
value="Tell me about"
)
max_length = gr.Slider(
label="Max Length",
minimum=10,
maximum=500,
value=100,
step=10
)
temp_slider = gr.Slider(
label="Temperature",
minimum=0.1,
maximum=2.0,
value=0.8,
step=0.1
)
with gr.Column():
gen_btn = gr.Button("Generate", variant="primary")
gen_output = gr.Textbox(label="Generated Response", lines=10)
gen_btn.click(
generate_interface,
inputs=[gen_prompt, max_length, temp_slider],
outputs=gen_output
)
with gr.Tab("πŸ’Ύ Save Model"):
save_path_input = gr.Textbox(
label="Save Path",
value="./grpo_trained_model",
placeholder="Enter path to save the model"
)
save_btn = gr.Button("Save Model", variant="primary")
save_status = gr.Textbox(label="Save Status")
save_btn.click(save_model_interface, inputs=save_path_input, outputs=save_status)
with gr.Tab("πŸ“Š Training History"):
history_btn = gr.Button("Refresh History", variant="secondary")
history_output = gr.Textbox(label="Training History", lines=15)
history_btn.click(get_training_history, outputs=history_output)
gr.Markdown("""
    ## 📝 Instructions:
1. **Load Model**: Start by loading a pre-trained model from HuggingFace
2. **Training**: Add your prompts and configure training parameters
3. **Generation**: Test your trained model with custom prompts
4. **Save**: Save your fine-tuned model for later use
## ⚠️ Note:
- This is a simplified GRPO implementation for demonstration
- For production use, consider more sophisticated reward functions
- GPU recommended for larger models
""")
# Launch the app
if __name__ == "__main__":
app.launch(share=True)