# app.py - Gradio version (much simpler for HF Spaces)
import json
import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for model and tokenizer
model = None
tokenizer = None

# Legacy mapping from the earlier sequence-classification version of this app;
# the generation path below uses verdict_mapping instead.
label_mapping = {0: "✅ Correct", 1: "🤔 Conceptually Flawed", 2: "🔢 Computationally Flawed"}

def load_model():
    """Load the trained LoRA adapter together with its base model."""
    global model, tokenizer
    try:
        from peft import AutoPeftModelForCausalLM  # generation model, not SequenceClassification

        # Load the LoRA adapter model for text generation
        model = AutoPeftModelForCausalLM.from_pretrained(
            "./lora_adapter",  # Path to the adapter files
            torch_dtype=torch.float16,
            device_map="auto"
        )
        # Load the tokenizer from the same directory
        tokenizer = AutoTokenizer.from_pretrained("./lora_adapter")

        # Fix missing padding token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            logger.info("Set pad_token to eos_token")

        logger.info("LoRA model loaded successfully")
        return "LoRA model loaded successfully!"
    except Exception as e:
        logger.error(f"Error loading LoRA model: {e}")
        # Fall back to a small placeholder model so the UI stays testable
        logger.warning("Using placeholder model loading - replace with your actual model!")
        model_name = "microsoft/DialoGPT-medium"  # small causal LM, for smoke-testing only
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Fix missing padding token for the fallback model too
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # Note: DialoGPT's tokenizer ships without a chat template, so
        # classify_solution() may still fail when running on this fallback.
        return f"Fallback model loaded. LoRA error: {e}"

def get_system_prompt():
    """Generates the specific system prompt for the fine-tuning task."""
    return """You are a mathematics tutor.
You are given a math word problem, and a solution written by a student.
Analyze the solution carefully, line-by-line, and classify it into one of the following categories:
- Correct (All logic is correct, and all calculations are correct)
- Conceptual Error (There is an error in reasoning or logic somewhere in the solution)
- Computational Error (All logic and reasoning is correct, but the result of some calculation is incorrect)
Respond *only* with a valid JSON object that follows this exact schema:
```json
{
"verdict": "must be one of 'correct', 'conceptual_error', or 'computational_error'",
"erroneous_line": "the exact, verbatim text of the first incorrect line, or null if the verdict is 'correct'",
"explanation": "a brief, one-sentence explanation of the error, or null if the verdict is 'correct'"
}
```
Do NOT add any text or explanations before or after the JSON object.
"""

def classify_solution(question: str, solution: str):
    """
    Classify the math solution using the exact training format.
    Returns: (classification_label, erroneous_line, explanation)
    """
    if not question.strip() or not solution.strip():
        return "Please fill in both fields", "", ""
    if model is None or tokenizer is None:
        return "Model not loaded", "", ""

    try:
        # Create the exact prompt format used in training
        system_prompt = get_system_prompt()
        user_message = f"Problem: {question}\n\nSolution:\n{solution}"

        # Format as chat messages (common for instruction-tuned models)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message}
        ]

        # Apply chat template
        text_input = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize the input and move it to the model's device
        inputs = tokenizer(
            text_input,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=2048  # Increased for longer prompts
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate a response (free-form generation, not a classification head)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.1,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id
            )

        # Decode only the newly generated tokens (everything after the prompt);
        # this is more reliable than searching for the prompt text in the output
        prompt_length = inputs["input_ids"].shape[1]
        json_response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

        # Parse the JSON response
        try:
            result = json.loads(json_response)
            verdict = result.get("verdict", "unknown")
            erroneous_line = result.get("erroneous_line", "")
            explanation = result.get("explanation", "")

            # Map the verdict to its display format
            verdict_mapping = {
                "correct": "✅ Correct",
                "conceptual_error": "🤔 Conceptual Error",
                "computational_error": "🔢 Computational Error"
            }
            display_verdict = verdict_mapping.get(verdict, f"❓ {verdict}")

            return display_verdict, erroneous_line or "None", explanation or "Solution is correct"
        except json.JSONDecodeError:
            return f"Model response: {json_response}", "", "Could not parse JSON response"
    except Exception as e:
        logger.error(f"Error during classification: {e}")
        return f"Classification error: {str(e)}", "", ""
# Load model on startup
load_model()
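# Quick manual smoke test (hypothetical; run from a Python REPL after startup):
#   print(classify_solution("Solve for x: 2x + 5 = 13",
#                           "2x = 13 - 5\n2x = 8\nx = 4"))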

# Create Gradio interface
with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🧮 Math Solution Classifier")
    gr.Markdown("Classify math solutions as correct, conceptually flawed, or computationally flawed.")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Math Question",
                placeholder="e.g., Solve for x: 2x + 5 = 13",
                lines=3
            )
            solution_input = gr.Textbox(
                label="Proposed Solution",
                placeholder="e.g., 2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
                lines=5
            )
            classify_btn = gr.Button("Classify Solution", variant="primary")
        with gr.Column():
            classification_output = gr.Textbox(label="Classification", interactive=False)
            erroneous_line_output = gr.Textbox(label="Erroneous Line", interactive=False)
            explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)

    # Examples
    gr.Examples(
        examples=[
            [
                "Solve for x: 2x + 5 = 13",
                "2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4"
            ],
            [
                "Find the derivative of f(x) = x²",
"f'(x) = 2x + 1" # This should be computationally flawed
            ],
            [
                "What is 15% of 200?",
                "15% = 15/100 = 0.15\n0.15 × 200 = 30"
            ]
        ],
        inputs=[question_input, solution_input]
    )

    classify_btn.click(
        fn=classify_solution,
        inputs=[question_input, solution_input],
        outputs=[classification_output, erroneous_line_output, explanation_output]
    )

if __name__ == "__main__":
    app.launch()
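
# A Space running this app needs these packages installed; a hypothetical
# requirements.txt sketch based on the imports above:
#   gradio
#   torch
#   transformers
#   peft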