# app.py ── Math-solution classifier for HF Spaces
# Compatible with both LoRA-classification and LoRA-causal-LM adapters
# Requirements (pin in requirements.txt):
# gradio torch transformers peft accelerate spaces
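# Example requirements.txt sketch (unpinned; pin the versions you actually
# tested, and note the AutoPeft* classes need a reasonably recent peft release):
#   gradio
#   torch
#   transformers
#   peft
#   accelerate
#   spaces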
import os
import json
import logging
from typing import Tuple
import gradio as gr
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
)
# PEFT imports (optional)
try:
    from peft import (
        AutoPeftModelForSequenceClassification,
        AutoPeftModelForCausalLM,
    )
    PEFT_AVAILABLE = True
except ImportError:  # PEFT not installed
    PEFT_AVAILABLE = False
# ──────────────────────────────────────────────────────────────────────────
# Config & logging
# ──────────────────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter")  # local dir (the isdir check in load_model skips Hub IDs)
FALLBACK_MODEL = "distilbert-base-uncased"
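# NOTE: with no adapter present, the fallback classifier gets a freshly
# initialised 3-label head (ignore_mismatched_sizes below), so baseline
# predictions are effectively untrained placeholders.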
LABELS = {0: "✅ Correct",
          1: "🤔 Conceptual Error",
          2: "🔢 Computational Error"}
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
tokenizer = None
model_ty = None # "classification" | "causal_lm" | "baseline"
# ──────────────────────────────────────────────────────────────────────────
# Model loader
# ──────────────────────────────────────────────────────────────────────────
def load_model():
"""Try adapter as classifier โ†’ causal-LM โ†’ plain baseline."""
global model, tokenizer, model_ty
dtype = torch.float16 if device == "cuda" else torch.float32
if PEFT_AVAILABLE and os.path.isdir(ADAPTER_PATH):
logger.info(f"Found adapter at {ADAPTER_PATH}")
# 1) Try sequence-classification adapter
try:
model = AutoPeftModelForSequenceClassification.from_pretrained(
ADAPTER_PATH,
torch_dtype=dtype,
device_map="auto" if device == "cuda" else None,
)
model_ty = "classification"
logger.info("Loaded adapter as sequence-classifier")
except ValueError:
# 2) Fall back to causal-LM adapter
logger.info("Adapter is not a classifier โ€“ trying causal-LM")
model = AutoPeftModelForCausalLM.from_pretrained(
ADAPTER_PATH,
torch_dtype=dtype,
device_map="auto" if device == "cuda" else None,
)
model_ty = "causal_lm"
logger.info("Loaded adapter as causal-LM")
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
else:
logger.warning("No adapter found โ€“ using baseline DistilBERT classifier")
tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
model = AutoModelForSequenceClassification.from_pretrained(
FALLBACK_MODEL,
num_labels=3,
ignore_mismatched_sizes=True,
torch_dtype=dtype,
)
model_ty = "baseline"
    # Causal LMs often ship without a pad token; borrow EOS (or SEP) for padding
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token
    # device_map="auto" has already placed the adapter weights, so only move
    # models that were loaded without a device map
    if model_ty == "baseline" or device == "cpu":
        model.to(device)
    model.eval()
    logger.info(f"Model ready on {device} as {model_ty}")
# ──────────────────────────────────────────────────────────────────────────
# Inference helpers
# ──────────────────────────────────────────────────────────────────────────
def _classify_logits(question: str, solution: str) -> Tuple[str, str, str]:
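    """Single forward pass through the classification head; softmax over the 3 labels."""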
text = f"Question: {question}\n\nSolution:\n{solution}"
inputs = tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512,
).to(device)
with torch.no_grad():
logits = model(**inputs).logits
probs = torch.softmax(logits, dim=-1)[0]
pred = int(torch.argmax(probs))
conf = f"{probs[pred].item():.3f}"
    return LABELS.get(pred, "Unknown"), conf, "—"
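# The generative path below returns the raw model text in the third tuple slot;
# the logits path above fills it with a dash so the output shape stays consistent.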
def _classify_generate(question: str, solution: str) -> Tuple[str, str, str]:
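    """Prompt the causal LM for a one-line JSON verdict and map it to a label."""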
# Prompt must match the format you used in tuning
prompt = (
"You are a mathematics tutor.\n"
"You are given a math word problem and a student's solution. "
"Decide whether the solution is correct.\n\n"
"- Correct = all reasoning and calculations are correct.\n"
"- Conceptual Error = reasoning is wrong.\n"
"- Computational Error= reasoning okay but arithmetic off.\n\n"
"Reply with ONLY one of these JSON lines:\n"
'{"verdict": "correct"}\n'
'{"verdict": "conceptual"}\n'
'{"verdict": "computational"}\n\n'
f"Question: {question}\n\nSolution:\n{solution}\n\nAnswer:"
)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
out_ids = model.generate(
**inputs,
max_new_tokens=32,
pad_token_id=tokenizer.eos_token_id,
)
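    # Decode only the newly generated tokens (slice off the prompt)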
generated = tokenizer.decode(
out_ids[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
).strip()
# Try to parse last JSON line
verdict = "Unparsed"
try:
line = generated.splitlines()[-1]
data = json.loads(line)
v = data.get("verdict", "").lower()
if v.startswith("corr"):
verdict = LABELS[0]
elif v.startswith("conc"):
verdict = LABELS[1]
elif v.startswith("comp"):
verdict = LABELS[2]
except Exception:
pass
return verdict, "", generated
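# Illustrative _classify_generate round-trip: a reply of '{"verdict": "computational"}'
# yields ("🔢 Computational Error", "", '{"verdict": "computational"}').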
def classify(question: str, solution: str):
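    """Dispatch to the logits- or generation-based helper for the loaded model type."""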
if not question.strip() or not solution.strip():
return "Please enter both fields.", "", ""
if model_ty in ("classification", "baseline"):
return _classify_logits(question, solution)
elif model_ty == "causal_lm":
return _classify_generate(question, solution)
else:
return "Model not loaded.", "", ""
# ──────────────────────────────────────────────────────────────────────────
# Build Gradio UI
# ──────────────────────────────────────────────────────────────────────────
load_model()
with gr.Blocks(title="Math Solution Classifier") as demo:
gr.Markdown("# ๐Ÿงฎ Math Solution Classifier")
gr.Markdown(
"Classify a student's math solution as **correct**, **conceptually flawed**, "
"or **computationally flawed**."
)
with gr.Row():
with gr.Column():
q_in = gr.Textbox(label="Math Question", lines=3)
s_in = gr.Textbox(label="Proposed Solution", lines=6)
btn = gr.Button("Classify", variant="primary")
with gr.Column():
verdict = gr.Textbox(label="Verdict", interactive=False)
conf = gr.Textbox(label="Confidence", interactive=False)
raw = gr.Textbox(label="Model Output", interactive=False)
btn.click(classify, [q_in, s_in], [verdict, conf, raw])
gr.Examples(
[
["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 8\nx = 4"],
["Find the derivative of f(x)=xยฒ", "f'(x)=2x+1"],
["What is 15% of 200?", "0.15 ร— 200 = 30"],
],
inputs=[q_in, s_in],
)
if __name__ == "__main__":
demo.launch()