# app.py ── Math-solution classifier for HF Spaces
# Requires: gradio, torch, transformers, peft, accelerate, spaces
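#
# A minimal requirements.txt sketch for this Space (versions unpinned here;
# pin them in practice, and note that peft/accelerate are only needed for
# the LoRA path):
#   gradio
#   torch
#   transformers
#   peft
#   accelerate
#   spaces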

import os
import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Optional PEFT import (only available if you include it in requirements.txt)
try:
    from peft import AutoPeftModelForSequenceClassification
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False

# ──────────────────────────────────────────────────────────────────────
# Config & logging
# ──────────────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter")   # local dir or Hub ID
FALLBACK_MODEL = "distilbert-base-uncased"
LABELS = {0: "✅ Correct",
          1: "🤔 Conceptual Error",
          2: "🔢 Computational Error"}

device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
tokenizer = None

# ──────────────────────────────────────────────────────────────────────
# Load model & tokenizer
# ──────────────────────────────────────────────────────────────────────
def load_model():
    """Load the LoRA adapter if present, otherwise a baseline classifier."""
    global model, tokenizer

    if PEFT_AVAILABLE and os.path.isdir(ADAPTER_PATH):
        logger.info(f"Loading LoRA adapter from {ADAPTER_PATH}")
        model = AutoPeftModelForSequenceClassification.from_pretrained(
            ADAPTER_PATH,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None,
        )
        tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
    else:
        logger.warning("LoRA adapter not found โ€“ falling back to baseline model")
        tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
        model = AutoModelForSequenceClassification.from_pretrained(
            FALLBACK_MODEL,
            num_labels=3,
            ignore_mismatched_sizes=True,
        )

    # GPT-style backbones ship without a pad token; reuse EOS (or SEP for
    # encoder models) so batched padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token

    # device_map="auto" already dispatches weights via accelerate, so only
    # move the model when it was loaded without a device map.
    if getattr(model, "hf_device_map", None) is None:
        model.to(device)
    model.eval()
    logger.info("Model & tokenizer ready")

# ──────────────────────────────────────────────────────────────────────
# Inference helper
# ──────────────────────────────────────────────────────────────────────
def classify(question: str, solution: str):
    """Return (label, confidence, placeholder-explanation)."""
    if not question.strip() or not solution.strip():
        return "Please provide both question and solution.", "", ""

    text = f"Question: {question}\n\nSolution:\n{solution}"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)[0]
        pred = int(torch.argmax(probs))
        confidence = f"{probs[pred].item():.3f}"

    return LABELS.get(pred, "Unknown"), confidence, "—"
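
# Quick local smoke test (a sketch; assumes load_model() has already run):
#   label, confidence, _ = classify("Solve for x: 2x = 8", "x = 4")
#   print(label, confidence)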

# ──────────────────────────────────────────────────────────────────────
# Build Gradio UI
# ──────────────────────────────────────────────────────────────────────
load_model()

with gr.Blocks(title="Math Solution Classifier") as demo:
    gr.Markdown("# ๐Ÿงฎ Math Solution Classifier")
    gr.Markdown(
        "Classify a studentโ€™s math solution as **correct**, **conceptually flawed**, "
        "or **computationally flawed**."
    )

    with gr.Row():
        with gr.Column():
            q_in = gr.Textbox(label="Math Question", lines=3)
            s_in = gr.Textbox(label="Proposed Solution", lines=6)
            btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            verdict = gr.Textbox(label="Verdict", interactive=False)
            conf = gr.Textbox(label="Confidence", interactive=False)
            expl = gr.Textbox(label="Explanation", interactive=False)

    btn.click(classify, [q_in, s_in], [verdict, conf, expl])

    gr.Examples(
        [
            ["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 8\nx = 4"],
            ["Find the derivative of f(x)=xยฒ", "f'(x)=2x+1"],
            ["What is 15 % of 200?", "0.15 ร— 200 = 30"],
        ],
        inputs=[q_in, s_in],
    )

if __name__ == "__main__":
    demo.launch()
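
# On Hugging Face Spaces, demo.launch() needs no arguments. For local testing,
# demo.launch(server_name="0.0.0.0", server_port=7860) exposes the app on the
# LAN (standard Gradio launch options).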