# mcamargo00's picture
# Upload app.py
# 6a197e4 verified
# raw
# history blame
# 5.41 kB
# app.py - Gradio version (much simpler for HF Spaces)
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging
# Emit INFO-level (and more severe) log records.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared state assigned by load_model(); both stay None until a model loads.
model = tokenizer = None

# Classifier output index -> label text shown in the UI.
label_mapping = dict(
    enumerate(
        [
            "✅ Correct",
            "🤔 Conceptually Flawed",
            "🔢 Computationally Flawed",
        ]
    )
)
def _ensure_pad_token(tok, mdl):
    """Make sure ``tok`` can pad and ``mdl.config`` knows the padding id.

    GPT-2-style tokenizers (such as the DialoGPT fallback) ship without a
    pad token, so ``tok(..., padding=True)`` raises ``ValueError``. Reusing
    the EOS token as padding is the usual remedy. The id is copied into
    ``mdl.config.pad_token_id`` only when that is still unset, since GPT-2
    style classification heads use it to find the last non-padding token.

    Args:
        tok: A Hugging Face tokenizer (anything with ``pad_token``,
            ``eos_token`` and ``pad_token_id`` attributes).
        mdl: A model exposing a ``config`` attribute.
    """
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    if getattr(mdl.config, "pad_token_id", None) is None:
        mdl.config.pad_token_id = tok.pad_token_id


def load_model():
    """Load the trained LoRA classifier, falling back to a placeholder model.

    Assigns the module-level ``model`` and ``tokenizer`` globals.

    The adapter and tokenizer are read from ``./lora_adapter`` via
    ``AutoPeftModelForSequenceClassification``. If anything on that path
    fails (e.g. ``peft`` not installed, files missing, load error), the
    ``microsoft/DialoGPT-medium`` checkpoint is loaded with a 3-label
    classification head instead. That head is not trained for this task,
    so the fallback's predictions should not be trusted; it only lets the
    UI run.

    Returns:
        A human-readable status string saying which model was loaded (and,
        for the fallback, the LoRA loading error).
    """
    global model, tokenizer
    # Half precision is only worthwhile on GPU; on CPU many float16 kernels
    # are slow or unimplemented, so use float32 there.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        from peft import AutoPeftModelForSequenceClassification

        # Load the LoRA adapter (and its base model) from the adapter folder.
        model = AutoPeftModelForSequenceClassification.from_pretrained(
            "./lora_adapter",  # Path to your adapter files
            torch_dtype=dtype,
            device_map="auto",
        )
        # The tokenizer is saved alongside the adapter.
        tokenizer = AutoTokenizer.from_pretrained("./lora_adapter")
        _ensure_pad_token(tokenizer, model)
        logger.info("LoRA model loaded successfully")
        return "LoRA model loaded successfully!"
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception("Error loading LoRA model: %s", e)
        # Fallback to placeholder for testing
        logger.warning("Using placeholder model loading - replace with your actual model!")
        model_name = "microsoft/DialoGPT-medium"  # GPT-2-based placeholder checkpoint
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=3,
            ignore_mismatched_sizes=True,
        )
        # Without this, tokenizing with padding=True fails for GPT-2 tokenizers.
        _ensure_pad_token(tokenizer, model)
        return f"Fallback model loaded. LoRA error: {e}"
def classify_solution(question: str, solution: str):
    """Classify a proposed solution to a math question.

    Args:
        question: The problem statement. ``None`` is treated as empty.
        solution: The proposed solution text. ``None`` is treated as empty.

    Returns:
        A ``(classification, confidence, explanation)`` tuple of strings for
        the three output textboxes. ``confidence`` is always a percentage
        string (e.g. ``"98.67%"``); when input is missing, no model is
        loaded, or classification fails, it is ``"0%"`` and ``explanation``
        is empty.
    """
    if not (question or "").strip() or not (solution or "").strip():
        return "Please fill in both fields", "0%", ""
    # Explicit None checks: object truthiness may be delegated to __len__
    # (tokenizers define one), which is not an "is loaded" test.
    if model is None or tokenizer is None:
        return "Model not loaded", "0%", ""

    # Human-readable explanation per predicted class index.
    explanations = {
        0: "The mathematical approach and calculations are both sound.",
        1: "The approach or understanding has fundamental issues.",
        2: "The approach is correct, but there are calculation errors.",
    }

    try:
        # Combine question and solution for input
        text_input = f"Question: {question}\nSolution: {solution}"
        inputs = tokenizer(
            text_input,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )
        # The tokenizer returns CPU tensors, but the weights may sit on a
        # GPU (device_map="auto"); move inputs to the model's device.
        device = next(model.parameters()).device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            # Upcast so the softmax is not computed in float16.
            probabilities = torch.softmax(outputs.logits.float(), dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_class].item()

        classification = label_mapping[predicted_class]
        explanation = explanations[predicted_class]
        return classification, f"{confidence:.2%}", explanation
    except Exception as e:
        # logger.exception records the traceback as well as the message.
        logger.exception("Error during classification: %s", e)
        return f"Classification error: {str(e)}", "0%", ""
# Load model on startup (once, at import time). Keep the returned status:
# it says whether the trained LoRA classifier or the placeholder fallback is
# serving predictions, and users need to know which.
load_status = load_model()
logger.info("Model load status: %s", load_status)

# Create Gradio interface
with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🧮 Math Solution Classifier")
    gr.Markdown("Classify math solutions as correct, conceptually flawed, or computationally flawed.")
    # Surface which model is active; the fallback's classification head is
    # not trained for this task, so its outputs should not be trusted.
    gr.Markdown(f"**Model status:** {load_status}")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Math Question",
                placeholder="e.g., Solve for x: 2x + 5 = 13",
                lines=3,
            )
            solution_input = gr.Textbox(
                label="Proposed Solution",
                placeholder="e.g., 2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
                lines=5,
            )
            classify_btn = gr.Button("Classify Solution", variant="primary")
        with gr.Column():
            classification_output = gr.Textbox(label="Classification", interactive=False)
            confidence_output = gr.Textbox(label="Confidence", interactive=False)
            explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)

    # Examples (clicking one fills the two input boxes)
    gr.Examples(
        examples=[
            [
                "Solve for x: 2x + 5 = 13",
                "2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
            ],
            [
                "Find the derivative of f(x) = x²",
                "f'(x) = 2x + 1",  # Deliberately flawed example (intended: computationally flawed)
            ],
            [
                "What is 15% of 200?",
                "15% = 15/100 = 0.15\n0.15 × 200 = 30",
            ],
        ],
        inputs=[question_input, solution_input],
    )

    classify_btn.click(
        fn=classify_solution,
        inputs=[question_input, solution_input],
        outputs=[classification_output, confidence_output, explanation_output],
    )

if __name__ == "__main__":
    app.launch()