# app.py ── Math-solution classifier for HF Spaces
# Compatible with both LoRA-classification and LoRA-causal-LM adapters
# Requirements (pin in requirements.txt):
# gradio torch transformers peft accelerate spaces
import os
import re
import json
import logging
from typing import Tuple
import gradio as gr
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
)
# PEFT imports (optional)
try:
    from peft import (
        AutoPeftModelForSequenceClassification,
        AutoPeftModelForCausalLM,
    )
PEFT_AVAILABLE = True
except ImportError: # PEFT not installed
PEFT_AVAILABLE = False
# ──────────────────────────────────────────────────────────────────────────────
# Config & logging
# ──────────────────────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter") # local dir or Hub ID
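# Example (hypothetical repo ID): point the Space at a Hub-hosted adapter
#   ADAPTER_PATH="your-user/math-verdict-lora" python app.py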
FALLBACK_MODEL = "distilbert-base-uncased"
LABELS = {
    0: "✅ Correct",
    1: "🤔 Conceptual Error",
    2: "🔢 Computational Error",
}
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
tokenizer = None
model_ty = None # "classification" | "causal_lm" | "baseline"
# ──────────────────────────────────────────────────────────────────────────────
# Model loader
# ──────────────────────────────────────────────────────────────────────────────
def load_model():
"""Try adapter as classifier โ causal-LM โ plain baseline."""
global model, tokenizer, model_ty
dtype = torch.float16 if device == "cuda" else torch.float32
    # Accept a local adapter directory or something that looks like a Hub repo ID
    looks_like_hub_id = re.fullmatch(r"[\w-][\w.-]*/[\w.-]+", ADAPTER_PATH) is not None
    if PEFT_AVAILABLE and (os.path.isdir(ADAPTER_PATH) or looks_like_hub_id):
logger.info(f"Found adapter at {ADAPTER_PATH}")
# 1) Try sequence-classification adapter
try:
model = AutoPeftModelForSequenceClassification.from_pretrained(
ADAPTER_PATH,
torch_dtype=dtype,
device_map="auto" if device == "cuda" else None,
)
model_ty = "classification"
logger.info("Loaded adapter as sequence-classifier")
        except Exception:
            # 2) No classification head found, so fall back to a causal-LM adapter
            logger.info("Adapter is not a classifier → trying causal-LM")
model = AutoPeftModelForCausalLM.from_pretrained(
ADAPTER_PATH,
torch_dtype=dtype,
device_map="auto" if device == "cuda" else None,
)
model_ty = "causal_lm"
logger.info("Loaded adapter as causal-LM")
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
else:
logger.warning("No adapter found โ using baseline DistilBERT classifier")
tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
model = AutoModelForSequenceClassification.from_pretrained(
FALLBACK_MODEL,
num_labels=3,
ignore_mismatched_sizes=True,
torch_dtype=dtype,
)
model_ty = "baseline"
# Make sure we have a pad token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token
    # device_map="auto" already dispatches the model; only move it manually otherwise
    if getattr(model, "hf_device_map", None) is None:
        model.to(device)
model.eval()
logger.info(f"Model ready on {device} as {model_ty}")
# ──────────────────────────────────────────────────────────────────────────────
# Inference helpers
# ──────────────────────────────────────────────────────────────────────────────
def _classify_logits(question: str, solution: str) -> Tuple[str, str, str]:
text = f"Question: {question}\n\nSolution:\n{solution}"
inputs = tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512,
).to(device)
with torch.no_grad():
logits = model(**inputs).logits
probs = torch.softmax(logits, dim=-1)[0]
pred = int(torch.argmax(probs))
conf = f"{probs[pred].item():.3f}"
    return LABELS.get(pred, "Unknown"), conf, "—"
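# For reference, the classifier input for the first UI example below reads:
#   "Question: Solve for x: 2x + 5 = 13\n\nSolution:\n2x + 5 = 13\n2x = 8\nx = 4"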
def _classify_generate(question: str, solution: str) -> Tuple[str, str, str]:
    # The prompt must match the format used during fine-tuning
prompt = (
"You are a mathematics tutor.\n"
"You are given a math word problem and a student's solution. "
"Decide whether the solution is correct.\n\n"
"- Correct = all reasoning and calculations are correct.\n"
"- Conceptual Error = reasoning is wrong.\n"
"- Computational Error= reasoning okay but arithmetic off.\n\n"
"Reply with ONLY one of these JSON lines:\n"
'{"verdict": "correct"}\n'
'{"verdict": "conceptual"}\n'
'{"verdict": "computational"}\n\n'
f"Question: {question}\n\nSolution:\n{solution}\n\nAnswer:"
)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=32,
            do_sample=False,  # greedy decoding for a deterministic verdict
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
generated = tokenizer.decode(
out_ids[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
).strip()
    # Scan the output for the first line that parses as the expected JSON verdict
    verdict = "Unparsed"
    for line in generated.splitlines():
        try:
            data = json.loads(line.strip())
        except json.JSONDecodeError:
            continue
        if not isinstance(data, dict):
            continue
        v = str(data.get("verdict", "")).lower()
        if v.startswith("corr"):
            verdict = LABELS[0]
        elif v.startswith("conc"):
            verdict = LABELS[1]
        elif v.startswith("comp"):
            verdict = LABELS[2]
        if verdict != "Unparsed":
            break
return verdict, "", generated
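# A well-formed generation that the parser above accepts looks like (illustrative):
#   '{"verdict": "computational"}'  ->  verdict == LABELS[2]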
def classify(question: str, solution: str):
if not question.strip() or not solution.strip():
return "Please enter both fields.", "", ""
if model_ty in ("classification", "baseline"):
return _classify_logits(question, solution)
elif model_ty == "causal_lm":
return _classify_generate(question, solution)
else:
return "Model not loaded.", "", ""
# ──────────────────────────────────────────────────────────────────────────────
# Build Gradio UI
# ──────────────────────────────────────────────────────────────────────────────
load_model()
with gr.Blocks(title="Math Solution Classifier") as demo:
gr.Markdown("# ๐งฎ Math Solution Classifier")
gr.Markdown(
"Classify a student's math solution as **correct**, **conceptually flawed**, "
"or **computationally flawed**."
)
with gr.Row():
with gr.Column():
q_in = gr.Textbox(label="Math Question", lines=3)
s_in = gr.Textbox(label="Proposed Solution", lines=6)
btn = gr.Button("Classify", variant="primary")
with gr.Column():
verdict = gr.Textbox(label="Verdict", interactive=False)
conf = gr.Textbox(label="Confidence", interactive=False)
raw = gr.Textbox(label="Model Output", interactive=False)
btn.click(classify, [q_in, s_in], [verdict, conf, raw])
gr.Examples(
[
["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 8\nx = 4"],
["Find the derivative of f(x)=xยฒ", "f'(x)=2x+1"],
["What is 15% of 200?", "0.15 ร 200 = 30"],
],
inputs=[q_in, s_in],
)
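# Running locally: `python app.py` serves Gradio on http://127.0.0.1:7860 by default.
# On HF Spaces, app.py is executed automatically; no launch arguments are needed.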
if __name__ == "__main__":
demo.launch()