import os
import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# PEFT is optional: without it the LoRA adapter cannot be loaded and the app
# falls back to the plain baseline classifier.
try:
    from peft import AutoPeftModelForSequenceClassification

    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter")
FALLBACK_MODEL = "distilbert-base-uncased"
LABELS = {
    0: "✅ Correct",
    1: "🤔 Conceptual Error",
    2: "🔢 Computational Error",
}
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
tokenizer = None


def load_model():
    """Load the LoRA adapter if present, otherwise a baseline classifier."""
    global model, tokenizer

    if PEFT_AVAILABLE and os.path.isdir(ADAPTER_PATH):
        logger.info(f"Loading LoRA adapter from {ADAPTER_PATH}")
        # Use fp16 and let accelerate place the weights when a GPU is available.
        model = AutoPeftModelForSequenceClassification.from_pretrained(
            ADAPTER_PATH,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None,
        )
        tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
    else:
        logger.warning("LoRA adapter not found; falling back to baseline model")
        tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
        # Baseline classifier with a freshly initialised 3-way head.
        model = AutoModelForSequenceClassification.from_pretrained(
            FALLBACK_MODEL,
            num_labels=3,
            ignore_mismatched_sizes=True,
        )

    # Some tokenizers ship without a pad token; reuse EOS or SEP so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token

    # A model loaded with device_map="auto" is already placed by accelerate,
    # so only move it manually when no device map was set.
    if getattr(model, "hf_device_map", None) is None:
        model.to(device)
    model.eval()
    logger.info("Model & tokenizer ready")


def classify(question: str, solution: str):
    """Return (label, confidence, placeholder explanation)."""
    if not question.strip() or not solution.strip():
        return "Please provide both question and solution.", "", ""

    # Combine question and solution into a single sequence for the classifier.
    text = f"Question: {question}\n\nSolution:\n{solution}"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]
    pred = int(torch.argmax(probs))
    confidence = f"{probs[pred].item():.3f}"

    # Explanation generation is not implemented yet, so return a placeholder.
    return LABELS.get(pred, "Unknown"), confidence, "-"
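

# Load the model once at startup, then build the Gradio interface.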
load_model()

with gr.Blocks(title="Math Solution Classifier") as demo:
    gr.Markdown("# 🧮 Math Solution Classifier")
    gr.Markdown(
        "Classify a student's math solution as **correct**, **conceptually flawed**, "
        "or **computationally flawed**."
    )

    with gr.Row():
        with gr.Column():
            q_in = gr.Textbox(label="Math Question", lines=3)
            s_in = gr.Textbox(label="Proposed Solution", lines=6)
            btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            verdict = gr.Textbox(label="Verdict", interactive=False)
            conf = gr.Textbox(label="Confidence", interactive=False)
            expl = gr.Textbox(label="Explanation", interactive=False)

    btn.click(classify, [q_in, s_in], [verdict, conf, expl])
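
    # Prefilled examples users can click to try the classifier.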
    gr.Examples(
        [
            ["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 8\nx = 4"],
            ["Find the derivative of f(x)=x²", "f'(x)=2x+1"],
            ["What is 15% of 200?", "0.15 × 200 = 30"],
        ],
        inputs=[q_in, s_in],
    )

if __name__ == "__main__":
    demo.launch()