# app.py ── Math-solution classifier for HF Spaces
# Works with both LoRA sequence-classification and LoRA causal-LM adapters
# Requirements (pin in requirements.txt):
#   gradio torch transformers peft accelerate spaces
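#   Illustrative pins (version numbers are assumptions, not a tested set):
#   gradio>=4.0 torch>=2.1 transformers>=4.38 peft>=0.10 accelerate>=0.26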

import os
import json
import logging
from typing import Tuple

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

# PEFT imports (optional)
try:
    from peft import (
        AutoPeftModelForSequenceClassification,
        AutoPeftModelForCausalLM,
    )
    PEFT_AVAILABLE = True
except ImportError:  # PEFT not installed
    PEFT_AVAILABLE = False

# ──────────────────────────────────────────────────────────────────────
# Config & logging
# ──────────────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ADAPTER_PATH   = os.getenv("ADAPTER_PATH", "./lora_adapter")  # local dir or Hub ID
FALLBACK_MODEL = "distilbert-base-uncased"
LABELS         = {0: "✅ Correct",
                  1: "🤔 Conceptual Error",
                  2: "🔢 Computational Error"}

device = "cuda" if torch.cuda.is_available() else "cpu"

model     = None
tokenizer = None
model_ty  = None          # "classification" | "causal_lm" | "baseline"

# ──────────────────────────────────────────────────────────────────────
# Model loader
# ──────────────────────────────────────────────────────────────────────
def load_model():
    """Try adapter as classifier → causal-LM → plain baseline."""
    global model, tokenizer, model_ty

    dtype = torch.float16 if device == "cuda" else torch.float32
    device_map = "auto" if device == "cuda" else None

    # Accept a local directory or a Hub repo ID such as "user/repo"
    # (a repo ID contains "/" but is not a filesystem path).
    adapter_available = os.path.isdir(ADAPTER_PATH) or (
        "/" in ADAPTER_PATH and not ADAPTER_PATH.startswith((".", "/"))
    )

    if PEFT_AVAILABLE and adapter_available:
        logger.info(f"Found adapter at {ADAPTER_PATH}")
        try:
            # 1) Try sequence-classification adapter
            try:
                model = AutoPeftModelForSequenceClassification.from_pretrained(
                    ADAPTER_PATH,
                    torch_dtype=dtype,
                    device_map=device_map,
                )
                model_ty = "classification"
                logger.info("Loaded adapter as sequence classifier")
            except ValueError:
                # 2) Fall back to causal-LM adapter
                logger.info("Adapter is not a classifier - trying causal-LM")
                model = AutoPeftModelForCausalLM.from_pretrained(
                    ADAPTER_PATH,
                    torch_dtype=dtype,
                    device_map=device_map,
                )
                model_ty = "causal_lm"
                logger.info("Loaded adapter as causal-LM")
            tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
        except Exception as exc:
            logger.warning(f"Adapter load failed ({exc}) - falling back")
            model = None

    if model is None:
        # 3) No usable adapter - plain baseline classifier. Its head is
        # randomly initialized, so predictions are demo-quality only.
        logger.warning("No adapter found - using baseline DistilBERT classifier")
        tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
        model = AutoModelForSequenceClassification.from_pretrained(
            FALLBACK_MODEL,
            num_labels=3,
            ignore_mismatched_sizes=True,
            torch_dtype=dtype,
        )
        model_ty = "baseline"
        device_map = None  # baseline was loaded without device_map

    # Generation and padding both need a pad token; borrow one if missing.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token

    # Models loaded with device_map="auto" are already dispatched; calling
    # .to() on a dispatched model raises, so only move the others.
    if device_map is None:
        model.to(device)
    model.eval()
    logger.info(f"Model ready on {device} as {model_ty}")

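# Usage sketch (the repo ID below is hypothetical): point ADAPTER_PATH at a
# trained adapter before launching, e.g.
#   ADAPTER_PATH=your-username/math-grader-lora python app.py
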
# ──────────────────────────────────────────────────────────────────────
# Inference helpers
# ──────────────────────────────────────────────────────────────────────
def _classify_logits(question: str, solution: str) -> Tuple[str, str, str]:
    text = f"Question: {question}\n\nSolution:\n{solution}"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        probs  = torch.softmax(logits, dim=-1)[0]
        pred   = int(torch.argmax(probs))
        conf   = f"{probs[pred].item():.3f}"

    return LABELS.get(pred, "Unknown"), conf, "—"

def _classify_generate(question: str, solution: str) -> Tuple[str, str, str]:
    # The prompt must match the format used during fine-tuning
    prompt = (
        "You are a mathematics tutor.\n"
        "You are given a math word problem and a student's solution. "
        "Decide whether the solution is correct.\n\n"
        "- Correct            = all reasoning and calculations are correct.\n"
        "- Conceptual Error   = reasoning is wrong.\n"
        "- Computational Error= reasoning okay but arithmetic off.\n\n"
        "Reply with ONLY one of these JSON lines:\n"
        '{"verdict": "correct"}\n'
        '{"verdict": "conceptual"}\n'
        '{"verdict": "computational"}\n\n'
        f"Question: {question}\n\nSolution:\n{solution}\n\nAnswer:"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=32,
            pad_token_id=tokenizer.eos_token_id,
        )
    generated = tokenizer.decode(
        out_ids[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    ).strip()

    # Parse the first JSON verdict line in the generation; the model may
    # echo extra text, so scan every line instead of assuming the last one.
    verdict = "Unparsed"
    for line in generated.splitlines():
        try:
            data = json.loads(line.strip())
        except json.JSONDecodeError:
            continue
        if not isinstance(data, dict):
            continue
        v = str(data.get("verdict", "")).lower()
        if v.startswith("corr"):
            verdict = LABELS[0]
        elif v.startswith("conc"):
            verdict = LABELS[1]
        elif v.startswith("comp"):
            verdict = LABELS[2]
        if verdict != "Unparsed":
            break

    return verdict, "", generated

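# Illustrative round trip for the generate path: a reply line of
#   {"verdict": "computational"}
# parses to LABELS[2], i.e. "🔢 Computational Error".
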
def classify(question: str, solution: str):
    if not question.strip() or not solution.strip():
        return "Please enter both fields.", "", ""

    if model_ty in ("classification", "baseline"):
        return _classify_logits(question, solution)
    elif model_ty == "causal_lm":
        return _classify_generate(question, solution)
    else:
        return "Model not loaded.", "", ""

# ──────────────────────────────────────────────────────────────────────
# Build Gradio UI
# ──────────────────────────────────────────────────────────────────────
load_model()

with gr.Blocks(title="Math Solution Classifier") as demo:
    gr.Markdown("# ๐Ÿงฎ Math Solution Classifier")
    gr.Markdown(
        "Classify a student's math solution as **correct**, **conceptually flawed**, "
        "or **computationally flawed**."
    )

    with gr.Row():
        with gr.Column():
            q_in  = gr.Textbox(label="Math Question", lines=3)
            s_in  = gr.Textbox(label="Proposed Solution", lines=6)
            btn   = gr.Button("Classify", variant="primary")
        with gr.Column():
            verdict = gr.Textbox(label="Verdict", interactive=False)
            conf    = gr.Textbox(label="Confidence", interactive=False)
            raw     = gr.Textbox(label="Model Output", interactive=False)

    btn.click(classify, [q_in, s_in], [verdict, conf, raw])

    gr.Examples(
        [
            ["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 8\nx = 4"],
            ["Find the derivative of f(x)=xยฒ", "f'(x)=2x+1"],
            ["What is 15% of 200?", "0.15 ร— 200 = 30"],
        ],
        inputs=[q_in, s_in],
    )
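
    # On Spaces you may want demo.queue() before launch so concurrent GPU
    # requests are serialized; it is omitted here to keep the demo minimal.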

if __name__ == "__main__":
    demo.launch()