# app.py ── Math-solution classifier for HF Spaces
# Compatible with both LoRA-classification and LoRA-causal-LM adapters
# Requirements (pin in requirements.txt):
#     gradio torch transformers peft accelerate spaces

import os
import json
import logging
from typing import Tuple

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

# PEFT imports (optional – the app degrades to a plain baseline without them)
try:
    from peft import (
        AutoPeftModelForSequenceClassification,
        AutoPeftModelForCausalLM,
    )

    PEFT_AVAILABLE = True
except ImportError:  # PEFT not installed
    PEFT_AVAILABLE = False

# ──────────────────────────────────────────────────────────────────────────────
# Config & logging
# ──────────────────────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Local adapter directory (the loader below only checks for a directory, so a
# bare Hub ID falls through to the baseline model).
ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter")
FALLBACK_MODEL = "distilbert-base-uncased"
LABELS = {0: "✅ Correct", 1: "🤔 Conceptual Error", 2: "🔢 Computational Error"}

device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
tokenizer = None
model_ty = None  # "classification" | "causal_lm" | "baseline"


# ──────────────────────────────────────────────────────────────────────────────
# Model loader
# ──────────────────────────────────────────────────────────────────────────────
def load_model():
    """Try the adapter as a classifier → causal-LM → plain baseline."""
    global model, tokenizer, model_ty

    dtype = torch.float16 if device == "cuda" else torch.float32

    if PEFT_AVAILABLE and os.path.isdir(ADAPTER_PATH):
        logger.info(f"Found adapter at {ADAPTER_PATH}")
        # 1) Try a sequence-classification adapter
        try:
            model = AutoPeftModelForSequenceClassification.from_pretrained(
                ADAPTER_PATH,
                torch_dtype=dtype,
                device_map="auto" if device == "cuda" else None,
            )
            model_ty = "classification"
            logger.info("Loaded adapter as sequence-classifier")
        except ValueError:
            # 2) Fall back to a causal-LM adapter
            logger.info("Adapter is not a classifier – trying causal-LM")
            model = AutoPeftModelForCausalLM.from_pretrained(
                ADAPTER_PATH,
                torch_dtype=dtype,
                device_map="auto" if device == "cuda" else None,
            )
            model_ty = "causal_lm"
            logger.info("Loaded adapter as causal-LM")
        tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
    else:
        logger.warning("No adapter found – using baseline DistilBERT classifier")
        tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
        model = AutoModelForSequenceClassification.from_pretrained(
            FALLBACK_MODEL,
            num_labels=3,
            ignore_mismatched_sizes=True,
            torch_dtype=dtype,
        )
        model_ty = "baseline"

    # Make sure we have a pad token (causal-LM tokenizers often lack one)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token

    # device_map="auto" already dispatches the model; moving it again would fail
    if getattr(model, "hf_device_map", None) is None:
        model.to(device)
    model.eval()
    logger.info(f"Model ready on {device} as {model_ty}")


# ──────────────────────────────────────────────────────────────────────────────
# Inference helpers
# ──────────────────────────────────────────────────────────────────────────────
def _classify_logits(question: str, solution: str) -> Tuple[str, str, str]:
    text = f"Question: {question}\n\nSolution:\n{solution}"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]
    pred = int(torch.argmax(probs))
    conf = f"{probs[pred].item():.3f}"
    return LABELS.get(pred, "Unknown"), conf, "—"
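
# Illustrative use of the logits path (a sketch only – the label and the
# confidence depend entirely on whichever weights load_model() ended up with):
#
#   verdict, conf, raw = _classify_logits("What is 15% of 200?", "0.15 × 200 = 30")
#   # verdict is one of LABELS.values(), conf is a 3-decimal string, raw is "—"
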
def _classify_generate(question: str, solution: str) -> Tuple[str, str, str]:
    # The prompt must match the format used during fine-tuning
    prompt = (
        "You are a mathematics tutor.\n"
        "You are given a math word problem and a student's solution. "
        "Decide whether the solution is correct.\n\n"
        "- Correct = all reasoning and calculations are correct.\n"
        "- Conceptual Error = reasoning is wrong.\n"
        "- Computational Error = reasoning okay but arithmetic off.\n\n"
        "Reply with ONLY one of these JSON lines:\n"
        '{"verdict": "correct"}\n'
        '{"verdict": "conceptual"}\n'
        '{"verdict": "computational"}\n\n'
        f"Question: {question}\n\nSolution:\n{solution}\n\nAnswer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=32,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt
    generated = tokenizer.decode(
        out_ids[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    ).strip()

    # Try to parse the last line of the generation as a JSON verdict
    verdict = "Unparsed"
    try:
        line = generated.splitlines()[-1]
        data = json.loads(line)
        v = data.get("verdict", "").lower()
        if v.startswith("corr"):
            verdict = LABELS[0]
        elif v.startswith("conc"):
            verdict = LABELS[1]
        elif v.startswith("comp"):
            verdict = LABELS[2]
    except Exception:
        pass
    return verdict, "", generated


def classify(question: str, solution: str):
    if not question.strip() or not solution.strip():
        return "Please enter both fields.", "", ""
    if model_ty in ("classification", "baseline"):
        return _classify_logits(question, solution)
    elif model_ty == "causal_lm":
        return _classify_generate(question, solution)
    else:
        return "Model not loaded.", "", ""


# ──────────────────────────────────────────────────────────────────────────────
# Build Gradio UI
# ──────────────────────────────────────────────────────────────────────────────
load_model()

with gr.Blocks(title="Math Solution Classifier") as demo:
    gr.Markdown("# 🧮 Math Solution Classifier")
    gr.Markdown(
        "Classify a student's math solution as **correct**, **conceptually flawed**, "
        "or **computationally flawed**."
    )
    with gr.Row():
        with gr.Column():
            q_in = gr.Textbox(label="Math Question", lines=3)
            s_in = gr.Textbox(label="Proposed Solution", lines=6)
            btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            verdict = gr.Textbox(label="Verdict", interactive=False)
            conf = gr.Textbox(label="Confidence", interactive=False)
            raw = gr.Textbox(label="Model Output", interactive=False)

    btn.click(classify, [q_in, s_in], [verdict, conf, raw])

    gr.Examples(
        [
            ["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 8\nx = 4"],
            ["Find the derivative of f(x)=x²", "f'(x)=2x+1"],
            ["What is 15% of 200?", "0.15 × 200 = 30"],
        ],
        inputs=[q_in, s_in],
    )

if __name__ == "__main__":
    demo.launch()
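
# Headless sanity check (a minimal sketch – the exact outputs depend on which
# model load_model() resolved to; load_model() runs at import time, as above):
#
#   python -c "import app; print(app.classify('What is 2 + 2?', '2 + 2 = 4'))"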