# app.py ── Math-solution classifier for HF Spaces
# Requires: gradio, torch, transformers, peft, accelerate, spaces
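# Run locally with `python app.py`; set the ADAPTER_PATH environment variable
# (read below) to point at your own LoRA adapter directory or Hub repo ID —
# the default "./lora_adapter" is just a placeholder path.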
import os
import logging
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Optional PEFT import (only available if you include it in requirements.txt)
try:
    from peft import AutoPeftModelForSequenceClassification

    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
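# AutoPeftModelForSequenceClassification reads the adapter's config to find the
# base model, loads the base weights, and applies the LoRA weights on top in a
# single from_pretrained() call.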
# ──────────────────────────────────────────────────────────────────────────────
# Config & logging
# ──────────────────────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter") # local dir or Hub ID
FALLBACK_MODEL = "distilbert-base-uncased"
LABELS = {
    0: "✅ Correct",
    1: "🤔 Conceptual Error",
    2: "🔢 Computational Error",
}
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
tokenizer = None
# ──────────────────────────────────────────────────────────────────────────────
# Load model & tokenizer
# ──────────────────────────────────────────────────────────────────────────────
def load_model():
    """Load the LoRA adapter if present, otherwise a baseline classifier."""
    global model, tokenizer
    if PEFT_AVAILABLE and os.path.isdir(ADAPTER_PATH):
        logger.info(f"Loading LoRA adapter from {ADAPTER_PATH}")
        model = AutoPeftModelForSequenceClassification.from_pretrained(
            ADAPTER_PATH,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None,
        )
        tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
    else:
        logger.warning("LoRA adapter not found; falling back to baseline model")
        tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
        model = AutoModelForSequenceClassification.from_pretrained(
            FALLBACK_MODEL,
            num_labels=3,
            ignore_mismatched_sizes=True,
        )
    # Decoder-style tokenizers may lack a pad token; BERT-style ones already have one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token
    # A model dispatched with device_map="auto" is already placed by accelerate,
    # and calling .to() on it can raise; only move the model if it wasn't dispatched.
    if getattr(model, "hf_device_map", None) is None:
        model.to(device)
    model.eval()
    logger.info("Model & tokenizer ready")
# ──────────────────────────────────────────────────────────────────────────────
# Inference helper
# ──────────────────────────────────────────────────────────────────────────────
def classify(question: str, solution: str):
    """Return (label, confidence, placeholder-explanation)."""
    if not question.strip() or not solution.strip():
        return "Please provide both question and solution.", "", ""
    text = f"Question: {question}\n\nSolution:\n{solution}"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)[0]
    pred = int(torch.argmax(probs))
    confidence = f"{probs[pred].item():.3f}"
    return LABELS.get(pred, "Unknown"), confidence, "-"
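# Example (after load_model() has run): classify("What is 2 + 2?", "2 + 2 = 4")
# returns a (label, confidence, explanation) 3-tuple; the confidence is the
# softmax probability of the predicted class, formatted to three decimal places.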
# ──────────────────────────────────────────────────────────────────────────────
# Build Gradio UI
# ──────────────────────────────────────────────────────────────────────────────
load_model()
with gr.Blocks(title="Math Solution Classifier") as demo:
    gr.Markdown("# 🧮 Math Solution Classifier")
    gr.Markdown(
        "Classify a student's math solution as **correct**, **conceptually flawed**, "
        "or **computationally flawed**."
    )
    with gr.Row():
        with gr.Column():
            q_in = gr.Textbox(label="Math Question", lines=3)
            s_in = gr.Textbox(label="Proposed Solution", lines=6)
            btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            verdict = gr.Textbox(label="Verdict", interactive=False)
            conf = gr.Textbox(label="Confidence", interactive=False)
            expl = gr.Textbox(label="Explanation", interactive=False)
    btn.click(classify, [q_in, s_in], [verdict, conf, expl])
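    # The examples below deliberately mix outcomes: the first and third
    # solutions are correct (2x = 8 gives x = 4, and 0.15 × 200 = 30), while
    # "f'(x) = 2x + 1" is wrong (the derivative of x² is 2x).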
    gr.Examples(
        [
            ["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 8\nx = 4"],
            ["Find the derivative of f(x) = x²", "f'(x) = 2x + 1"],
            ["What is 15% of 200?", "0.15 × 200 = 30"],
        ],
        inputs=[q_in, s_in],
    )
if __name__ == "__main__":
    demo.launch()