mcamargo00 committed on
Commit e87cacb · verified · 1 Parent(s): c31c5ad

Upload app.py

Files changed (1)
  1. app.py +573 -154
app.py CHANGED
@@ -1,155 +1,574 @@
1
- # app.py - Gradio version (much simpler for HF Spaces)
2
-
3
- import gradio as gr
4
- import torch
5
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
- import logging
7
- import spaces
8
-
9
- # Set up logging
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
12
-
13
- # Global variables for model and tokenizer
14
- model = None
15
- tokenizer = None
16
- label_mapping = {0: "✅ Correct", 1: "🤔 Conceptually Flawed", 2: "🔢 Computationally Flawed"}
17
-
18
- def load_model():
19
- """Load your trained model here"""
20
- global model, tokenizer
21
-
22
- try:
23
- # Replace these with your actual model path/name
24
- # Option 1: Load from local files
25
- # model = AutoModelForSequenceClassification.from_pretrained("./your_model_directory")
26
- # tokenizer = AutoTokenizer.from_pretrained("./your_model_directory")
27
-
28
- # Option 2: Load from Hugging Face Hub (if you upload your model there)
29
- # model = AutoModelForSequenceClassification.from_pretrained("your-username/your-model-name")
30
- # tokenizer = AutoTokenizer.from_pretrained("your-username/your-model-name")
31
-
32
- # For now, we'll use a placeholder - replace this with your actual model loading
33
- logger.warning("Using placeholder model loading - replace with your actual model!")
34
-
35
- # Placeholder model loading (replace this!)
36
- model_name = "distilbert-base-uncased" # Replace with your model
37
- tokenizer = AutoTokenizer.from_pretrained(model_name)
38
- model = AutoModelForSequenceClassification.from_pretrained(
39
- model_name,
40
- num_labels=3,
41
- ignore_mismatched_sizes=True
42
- )
43
-
44
- logger.info("Model loaded successfully")
45
- return "Model loaded successfully!"
46
-
47
- except Exception as e:
48
- logger.error(f"Error loading model: {e}")
49
- return f"Error loading model: {e}"
50
-
51
- @spaces.GPU
52
- def classify_solution(question: str, solution: str):
53
- """
54
- Classify the math solution
55
- Returns: (classification_label, confidence_score, explanation)
56
- """
57
- if not question.strip() or not solution.strip():
58
- return "Please fill in both fields", 0.0, ""
59
-
60
- if not model or not tokenizer:
61
- return "Model not loaded", 0.0, ""
62
-
63
- try:
64
- # Combine question and solution for input
65
- text_input = f"Question: {question}\nSolution: {solution}"
66
-
67
- # Tokenize input
68
- inputs = tokenizer(
69
- text_input,
70
- return_tensors="pt",
71
- truncation=True,
72
- padding=True,
73
- max_length=512
74
- )
75
-
76
- # Get model prediction
77
- with torch.no_grad():
78
- outputs = model(**inputs)
79
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
80
- predicted_class = torch.argmax(predictions, dim=-1).item()
81
- confidence = predictions[0][predicted_class].item()
82
-
83
- classification = label_mapping[predicted_class]
84
-
85
- # Create explanation based on classification
86
- explanations = {
87
- 0: "The mathematical approach and calculations are both sound.",
88
- 1: "The approach or understanding has fundamental issues.",
89
- 2: "The approach is correct, but there are calculation errors."
90
- }
91
-
92
- explanation = explanations[predicted_class]
93
-
94
- return classification, f"{confidence:.2%}", explanation
95
-
96
- except Exception as e:
97
- logger.error(f"Error during classification: {e}")
98
- return f"Classification error: {str(e)}", "0%", ""
99
-
100
- # Load model on startup
101
- load_model()
102
-
103
- # Create Gradio interface
104
- with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
105
- gr.Markdown("# 🧮 Math Solution Classifier")
106
- gr.Markdown("Classify math solutions as correct, conceptually flawed, or computationally flawed.")
107
-
108
- with gr.Row():
109
- with gr.Column():
110
- question_input = gr.Textbox(
111
- label="Math Question",
112
- placeholder="e.g., Solve for x: 2x + 5 = 13",
113
- lines=3
114
- )
115
-
116
- solution_input = gr.Textbox(
117
- label="Proposed Solution",
118
- placeholder="e.g., 2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
119
- lines=5
120
- )
121
-
122
- classify_btn = gr.Button("Classify Solution", variant="primary")
123
-
124
- with gr.Column():
125
- classification_output = gr.Textbox(label="Classification", interactive=False)
126
- confidence_output = gr.Textbox(label="Confidence", interactive=False)
127
- explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)
128
-
129
- # Examples
130
- gr.Examples(
131
- examples=[
132
- [
133
- "Solve for x: 2x + 5 = 13",
134
- "2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4"
135
- ],
136
- [
137
- "John has three apples and Mary has seven, how many apples do they have together?",
138
- "They have 7 + 3 = 11 apples." # This should be computationally flawed
139
- ],
140
- [
141
- "What is 15% of 200?",
142
- "15% = 15/100 = 0.15\n0.15 × 200 = 30"
143
- ]
144
- ],
145
- inputs=[question_input, solution_input]
146
- )
147
-
148
- classify_btn.click(
149
- fn=classify_solution,
150
- inputs=[question_input, solution_input],
151
- outputs=[classification_output, confidence_output, explanation_output]
152
- )
153
-
154
- if __name__ == "__main__":
155
  app.launch()
 
1
+ # app.py - Gradio version (much simpler for HF Spaces)
2
+
3
+ import gradio as gr
4
+ import logging
5
+ import spaces
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, BitsAndBytesConfig
10
+
11
+ import unsloth
12
+ from unsloth import FastModel
13
+ from peft import PeftModel
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ import json
17
+ import re
18
+ import math
19
+ import time
20
+
21
+ # Set up logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # Global variables for model and tokenizer
26
+ label_mapping = {0: "✅ Correct", 1: "🤔 Conceptually Flawed", 2: "🔢 Computationally Flawed"}
27
+
28
+
29
+
30
+
31
+ # ===================================================================
32
+ # 1. DEFINE CUSTOM CLASSIFIER (Required for Phi-4)
33
+ # ===================================================================
34
+ class GPTSequenceClassifier(nn.Module):
35
+ def __init__(self, base_model, num_labels):
36
+ super().__init__()
37
+ self.base = base_model
38
+ hidden_size = base_model.config.hidden_size
39
+ self.classifier = nn.Linear(hidden_size, num_labels, bias=True)
40
+ self.num_labels = num_labels
41
+
42
+ def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
43
+ outputs = self.base(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, **kwargs)
44
+ last_hidden_state = outputs.hidden_states[-1]
45
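+ # NOTE: last-token pooling; for padded batches this relies on left padding (the classifier tokenizer sets padding_side="left" in load_model below)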
+ pooled_output = last_hidden_state[:, -1, :]
46
+ logits = self.classifier(pooled_output)
47
+ loss = None
48
+ if labels is not None:
49
+ loss = nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
50
+ return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
51
+
52
+
53
+
54
+
55
+ # --- Helper Functions ---
56
+ def format_solution_into_json_str(solution_text: str) -> str:
57
+ lines = [line.strip() for line in solution_text.strip().split('\n') if line.strip()]
58
+ final_answer = ""
59
+ if lines and "FINAL ANSWER:" in lines[-1].upper():
+ # split case-insensitively so the marker does not have to start the line
+ final_answer = re.split(r'(?i)FINAL ANSWER:', lines[-1], maxsplit=1)[-1].strip()
61
+ lines = lines[:-1]
62
+ solution_dict = {f"L{i+1}": line for i, line in enumerate(lines)}
63
+ solution_dict["FA"] = final_answer
64
+ return json.dumps(solution_dict, indent=2)
65
+
66
+ def sanitize_equation_string(expression: str) -> str:
67
+ if not isinstance(expression, str):
68
+ return ""
69
+
70
+ s = expression.strip()
71
+
72
+ # Normalize common symbols
73
+ s = s.replace('×', '*').replace('·', '*')
+ # treat 'x' as multiplication only when it sits between digits (e.g., '2x3'), so letters inside words are left for the later strip
+ s = re.sub(r'(?<=\d)\s*x\s*(?=\d)', '*', s)
74
+
75
+ # Convert percentages like '12%' -> '(12/100)'
76
+ s = re.sub(r'(?<!\d)(\d+(?:\.\d+)?)\s*%', r'(\1/100)', s)
77
+
78
+ # Simple paren balancing trims (only when a single stray exists at an edge)
79
+ if s.count('(') > s.count(')') and s.startswith('('):
80
+ s = s[1:]
81
+ elif s.count(')') > s.count('(') and s.endswith(')'):
82
+ s = s[:-1]
83
+
84
+ # Drop units right after a slash: /hr, /dogs
85
+ s = re.sub(r'/([a-zA-Z]+)', '', s)
86
+
87
+ # Keep only numeric math tokens
88
+ s = re.sub(r'[^\d.()+\-*/=%]', '', s)
89
+
90
+ # Collapse repeated '=' (e.g., '==24/2=12' -> '=24/2=12')
91
+ s = re.sub(r'=+', '=', s)
92
+
93
+ return s
94
+
96
+
97
+ def _parse_equation(eq_str: str):
98
+ s = sanitize_equation_string(eq_str or "")
99
+ s = s.lstrip('=') # handle lines like '=24/2=12'
100
+ if '=' not in s:
101
+ return None
102
+ if s.count('=') > 1:
103
+ pos = s.rfind('=')
104
+ lhs, rhs = s[:pos], s[pos+1:]
105
+ else:
106
+ lhs, rhs = s.split('=', 1)
107
+ lhs, rhs = lhs.strip(), rhs.strip()
108
+ if not lhs or not rhs:
109
+ return None
110
+ return lhs, rhs
111
+
112
+ def _abs_tol_from_display(rhs_str: str):
113
+ """
114
+ If RHS is a single numeric literal like 0.33, use half-ULP at that precision.
115
+ e.g., '0.33' -> 0.5 * 10^-2 = 0.005
116
+ Otherwise return None and fall back to base tolerances.
117
+ """
118
+ s = rhs_str.strip()
119
+ # allow optional parens and sign
120
+ m = re.fullmatch(r'\(?\s*[-+]?\d+(?:\.(\d+))?\s*\)?', s)
121
+ if not m:
122
+ return None
123
+ frac = m.group(1) or ""
124
+ d = len(frac)
125
+ return 0.5 * (10 ** (-d)) if d > 0 else 0.5 # if integer shown, allow ±0.5
126
+
127
+ def evaluate_equations(eq_dict: dict, sol_dict: dict,
128
+ base_rel_tol: float = 1e-6,
129
+ base_abs_tol: float = 1e-9,
130
+ honor_display_precision: bool = True):
131
+ """
132
+ Evaluates extracted equations. Accepts rounded RHS values based on displayed precision.
133
+ """
134
+ for key, eq_str in (eq_dict or {}).items():
135
+ parsed = _parse_equation(eq_str)
136
+ if not parsed:
137
+ continue
138
+ lhs, rhs_str = parsed
139
+
140
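+ # eval() is only lightly sandboxed: builtins are disabled, and sanitize_equation_string has already stripped everything except digits, '.', parentheses, and arithmetic operators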
+ try:
141
+ lhs_val = eval(lhs, {"__builtins__": None}, {})
142
+ rhs_val = eval(rhs_str, {"__builtins__": None}, {})
143
+ except Exception:
144
+ continue
145
+
146
+ # dynamic absolute tolerance from RHS formatting (e.g., 0.33 -> 0.005)
147
+ abs_tol = base_abs_tol
148
+ if honor_display_precision:
149
+ dyn = _abs_tol_from_display(rhs_str)
150
+ if dyn is not None:
151
+ abs_tol = max(abs_tol, dyn)
152
+
153
+ if not math.isclose(lhs_val, rhs_val, rel_tol=base_rel_tol, abs_tol=abs_tol):
154
+ correct_rhs_val = round(lhs_val, 6)
155
+ correct_rhs_str = f"{correct_rhs_val:.6f}".rstrip('0').rstrip('.')
156
+ return {
157
+ "error": True,
158
+ "line_key": key,
159
+ "line_text": sol_dict.get(key, "N/A"),
160
+ "original_flawed_calc": eq_str,
161
+ "sanitized_lhs": lhs,
162
+ "original_rhs": rhs_str,
163
+ "correct_rhs": correct_rhs_str,
164
+ }
165
+
166
+ return {"error": False}
167
+
168
+
169
+
170
+
171
+
172
+ def extract_json_from_response(response: str) -> dict:
173
+ """
174
+ Recover equations from the extractor's output.
175
+
176
+ Strategy:
177
+ 1) Try to parse a real JSON object (if present).
178
+ 2) Parse relaxed key-value lines like 'L1: ...' or 'FA=...'.
179
+ 3) Also fall back to linewise equations (e.g., '=24/2=12', '7*2=14') and
180
+ merge them as L1, L2, ... preserving order. Keep FA if present.
181
+ """
182
+ out = {}
183
+
184
+ if not response or not isinstance(response, str):
185
+ return out
186
+
187
+ text = response.strip()
188
+
189
+ # --- 1) strict JSON block, if any ---
190
+ m = re.search(r'\{.*\}', text, flags=re.S)
191
+ if m:
192
+ try:
193
+ obj = json.loads(m.group(0))
194
+ if isinstance(obj, dict) and any(k.startswith("L") for k in obj):
195
+ return obj
196
+ elif isinstance(obj, dict):
197
+ out.update(obj) # keep FA etc., then continue
198
+ except Exception:
199
+ pass
200
+
201
+ # --- 2) relaxed key/value lines: Lk : value or FA = value ---
202
+ relaxed = {}
203
+ for ln in text.splitlines():
204
+ ln = ln.strip().strip(',')
205
+ if not ln:
206
+ continue
207
+ m = re.match(r'(?i)^(L\d+|FA)\s*[:=]\s*(.+?)\s*$', ln)
208
+ if m:
209
+ k = m.group(1).strip()
210
+ v = m.group(2).strip().rstrip(',')
211
+ relaxed[k] = v
212
+ out.update(relaxed)
213
+
214
+ # Count how many L-keys we already have
215
+ existing_L = sorted(
216
+ int(k[1:]) for k in out.keys()
217
+ if k.startswith("L") and k[1:].isdigit()
218
+ )
219
+ next_L = (max(existing_L) + 1) if existing_L else 1
220
+
221
+ # --- 3) linewise fallback: harvest bare equations and merge ---
222
+ def _looks_like_equation(s: str) -> str | None:
223
+ s = sanitize_equation_string(s or "").lstrip('=')
224
+ if '=' in s and any(ch.isdigit() for ch in s):
225
+ return s
226
+ return None
227
+
228
+ # set of existing equation strings to avoid duplicates
229
+ seen_vals = set(v for v in out.values() if isinstance(v, str))
230
+
231
+ for ln in text.splitlines():
232
+ ln = ln.strip().strip(',')
233
+ if not ln or re.match(r'(?i)^(L\d+|FA)\s*[:=]', ln):
234
+ # skip lines we already captured as relaxed pairs
235
+ continue
236
+ eq = _looks_like_equation(ln)
237
+ if eq and eq not in seen_vals:
238
+ out[f"L{next_L}"] = eq
239
+ seen_vals.add(eq)
240
+ next_L += 1
241
+
242
+ return out
243
+
244
+ # --- Prompts ---
245
+ EXTRACTOR_SYSTEM_PROMPT = \
246
+ """[ROLE]
247
+ You are an expert at parsing mathematical solutions.
248
+ [TASK]
249
+ You are given a mathematical solution. Your task is to extract the calculation performed on each line and represent it as a simple equation.
250
+ **This is a literal transcription task. Follow these rules with extreme precision:**
251
+ - **RULE 1: Transcribe EXACTLY.** Do not correct mathematical errors. If a line implies `2+2=5`, your output for that line must be `2+2=5`.
252
+ - **RULE 2: Isolate the Equation.** Your output must contain ONLY the equation. Do not include any surrounding text, units (like `/hour`), or currency symbols (like `$`).
253
+ - **RULE 3: Use Standard Operators.** Always use `*` for multiplication. Never use `x`.
254
+ [RESPONSE FORMAT]
255
+ Your response must be ONLY a single, valid JSON object, adhering strictly to these rules:
256
+ For each line of the solution, create a key-value pair.
257
+ - The key should be the line identifier (e.g., "L1", "L2", "FA" for the final answer line).
258
+ - The value should be the extracted equation string (e.g., "10+5=15").
259
+ - If a line contains no calculation, the value must be an empty string.
260
+ """
261
+
262
+ CLASSIFIER_SYSTEM_PROMPT = \
263
+ """You are a mathematics tutor.
264
+ You will be given a math word problem and a solution written by a student.
265
+ Carefully analyze the problem and solution LINE-BY-LINE and determine whether there are any errors in the solution."""
266
+
267
+ # --- Example 1 ---
268
+ FEW_SHOT_EXAMPLE_1_SOLUTION = {
269
+ "L1": "2% of $90 is (2/100)*$90 = $1.8",
270
+ "L2": "2% of $60 is (2/100)*$60 = $1.2",
271
+ "L3": "The second transaction was reversed without the service charge so only a total of $90+$1.8+$1.2 = $39 was deducted from his account",
272
+ "L4": "He will have a balance of $400-$39 = $361",
273
+ "FA": "361"
274
+ }
275
+
276
+ FEW_SHOT_EXAMPLE_1_EQUATIONS = {
277
+ "L1": "(2/100)*90=1.8",
278
+ "L2": "(2/100)*60=1.2",
279
+ "L3": "90+1.8+1.2=39",
280
+ "L4": "400-39=361",
281
+ "FA": ""
282
+ }
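+ # Note: the arithmetic in "L3" above is wrong (90 + 1.8 + 1.2 = 93, not 39); per RULE 1 the extractor still transcribes it verbatim.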
283
+
284
+
285
+ # --- Example 2 ---
286
+ FEW_SHOT_EXAMPLE_2_SOLUTION = {
287
+ "L1": "She drinks 2 bottles a day and there are 24 bottles in a case so a case will last 24/2 = 12 days",
288
+ "L2": "She needs enough to last her 240 days and 1 case will last her 12 days so she needs 240/12 = 20 cases",
289
+ "L3": "Each case is on sale for $12.00 and she needs 20 cases so that's 12*20 = $240.00",
290
+ "FA": "240"
291
+ }
292
+
293
+ FEW_SHOT_EXAMPLE_2_EQUATIONS = {
294
+ "L1": "24/2=12",
295
+ "L2": "240/12=20",
296
+ "L3": "12*20=240.00",
297
+ "FA": ""
298
+ }
299
+
300
+ def create_extractor_messages(solution_json_str: str) -> list:
301
+ """
302
+ Returns a list of dictionaries representing the conversation history for the prompt.
303
+ """
304
+ # Start with the constant few-shot examples defined globally
305
+ messages = [
306
+ {"role": "user", "content": f"{EXTRACTOR_SYSTEM_PROMPT}\n\n### Solution:\n{json.dumps(FEW_SHOT_EXAMPLE_1_SOLUTION, indent=2)}"},
307
+ {"role": "assistant", "content": json.dumps(FEW_SHOT_EXAMPLE_1_EQUATIONS, indent=2)},
308
+ {"role": "user", "content": f"### Solution:\n{json.dumps(FEW_SHOT_EXAMPLE_2_SOLUTION, indent=2)}"},
309
+ {"role": "assistant", "content": json.dumps(FEW_SHOT_EXAMPLE_2_EQUATIONS, indent=2)},
310
+ ]
311
+
312
+ # Add the final user query to the end of the conversation
313
+ final_user_prompt = f"### Solution:\n{solution_json_str}"
314
+ messages.append({"role": "user", "content": final_user_prompt})
315
+
316
+ return messages
317
+
318
+ gemma_model = None
319
+ gemma_tokenizer = None
320
+ classifier_model = None
321
+ classifier_tokenizer = None
322
+
323
+ def load_model():
324
+ """Load your trained model here"""
325
+ global gemma_model, gemma_tokenizer, classifier_model, classifier_tokenizer
326
+
327
+ try:
328
+ device = "cuda" if torch.cuda.is_available() else "cpu"
329
+
330
+ # --- Model 1: Equation Extractor (Gemma-3 with Unsloth) ---
331
+ extractor_adapter_repo = "arvindsuresh-math/gemma-3-1b-equation-extractor-lora"
332
+ base_gemma_model = "unsloth/gemma-3-1b-it-unsloth-bnb-4bit"
333
+
334
+ gemma_model, gemma_tokenizer = FastModel.from_pretrained(
335
+ model_name=base_gemma_model,
336
+ max_seq_length=2048,
337
+ dtype=None,
338
+ load_in_4bit=True,
339
+ )
340
+
341
+ # --- Gemma tokenizer hygiene (fix the right-padding warning) ---
342
+ if gemma_tokenizer.pad_token is None:
343
+ gemma_tokenizer.pad_token = gemma_tokenizer.eos_token
344
+ gemma_tokenizer.padding_side = "left" # align last tokens across the batch
345
+
346
+ gemma_model = PeftModel.from_pretrained(gemma_model, extractor_adapter_repo)
347
+
348
+
349
+ # --- Model 2: Conceptual Error Classifier (Phi-4) ---
350
+ classifier_adapter_repo = "arvindsuresh-math/phi-4-error-binary-classifier"
351
+ base_phi_model = "microsoft/Phi-4-mini-instruct"
352
+
353
+ DTYPE = torch.bfloat16
354
+ quantization_config = BitsAndBytesConfig(
355
+ load_in_4bit=True,
356
+ bnb_4bit_quant_type="nf4",
357
+ bnb_4bit_compute_dtype=DTYPE
358
+ )
359
+ classifier_backbone_base = AutoModelForCausalLM.from_pretrained(
360
+ base_phi_model,
361
+ quantization_config=quantization_config,
362
+ device_map="auto",
363
+ trust_remote_code=True,
364
+ )
365
+
366
+ classifier_tokenizer = AutoTokenizer.from_pretrained(
367
+ base_phi_model,
368
+ trust_remote_code=True
369
+ )
370
+ classifier_tokenizer.padding_side = "left"
371
+ if classifier_tokenizer.pad_token is None:
372
+ classifier_tokenizer.pad_token = classifier_tokenizer.eos_token
373
+
374
+ classifier_backbone_peft = PeftModel.from_pretrained(
375
+ classifier_backbone_base,
376
+ classifier_adapter_repo
377
+ )
378
+ classifier_model = GPTSequenceClassifier(classifier_backbone_peft, num_labels=2)
379
+
380
+ # Download and load the custom classifier head's state dictionary
381
+ classifier_head_path = hf_hub_download(repo_id=classifier_adapter_repo, filename="classifier_head.pth")
382
+ classifier_model.classifier.load_state_dict(torch.load(classifier_head_path, map_location=device))
383
+
384
+ classifier_model.to(device)
385
+ classifier_model = classifier_model.to(torch.bfloat16)
386
+
387
+ classifier_model.eval() # Set model to evaluation mode
388
+
389
+ logger.info("Model loaded successfully")
390
+ return "Model loaded successfully!"
391
+
392
+ except Exception as e:
393
+ logger.error(f"Error loading model: {e}")
394
+ return f"Error loading model: {e}"
395
+
396
+ @spaces.GPU
397
+ def analyze_single(math_question: str, proposed_solution: str, debug: bool = False):
398
+ """
399
+ Single (question, solution) classifier.
400
+ Stage 1: computational check via Gemma extraction + evaluator.
401
+ Stage 2: conceptual/correct check via Phi-4 classifier.
402
+ Returns: {"classification": "...", "confidence": "...", "explanation": "..."}
403
+ """
404
+ # -----------------------------
405
+ # STAGE 1: COMPUTATIONAL CHECK
406
+ # -----------------------------
407
+ # 1) Format and extract equations
408
+ solution_json_str = format_solution_into_json_str(proposed_solution)
409
+ solution_dict = json.loads(solution_json_str)
410
+
411
+ messages = create_extractor_messages(solution_json_str)
412
+ prompt = gemma_tokenizer.apply_chat_template(
413
+ messages, tokenize=False, add_generation_prompt=True
414
+ )
415
+
416
+ # `device` is local to load_model, so derive it again here
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ inputs = gemma_tokenizer([prompt], return_tensors="pt").to(device)
417
+ outputs = gemma_model.generate(
418
+ **inputs,
419
+ max_new_tokens=300,
420
+ use_cache=True,
421
+ pad_token_id=gemma_tokenizer.pad_token_id,
422
+ do_sample=False,
424
+ )
425
+
426
+ extracted_text = gemma_tokenizer.batch_decode(
427
+ outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
428
+ )[0]
429
+
430
+ if debug:
431
+ print("\n[Gemma raw output]\n", extracted_text)
432
+
433
+ extracted_eq_dict = extract_json_from_response(extracted_text)
434
+
435
+ # 2) Keep only lines that actually contain digits in the original text
436
+ final_eq_to_eval = {
437
+ k: v
438
+ for k, v in extracted_eq_dict.items()
439
+ if any(ch.isdigit() for ch in solution_dict.get(k, ""))
440
+ }
441
+ if debug:
442
+ print("\n[Equations to evaluate]\n", json.dumps(final_eq_to_eval, indent=2))
443
+
444
+ # 3) Evaluate
445
+ computational_error = evaluate_equations(final_eq_to_eval, solution_dict)
446
+ if computational_error.get("error"):
447
+ lhs = computational_error["sanitized_lhs"]
448
+ rhs = computational_error["original_rhs"]
449
+ correct_rhs = computational_error["correct_rhs"]
450
+ line_txt = computational_error.get("line_text", "")
451
+ explanation = (
452
+ "A computational error was found.\n"
453
+ f'On line: "{line_txt}"\n'
454
+ f"The student wrote '{lhs} = {rhs}', but the correct result of '{lhs}' is {correct_rhs}."
455
+ )
456
+ return {
457
+ "classification": "Computational Error",
458
+ "confidence": "100%",
459
+ "explanation": explanation,
460
+ }
461
+
462
+ # --------------------------
463
+ # STAGE 2: CONCEPTUAL CHECK
464
+ # --------------------------
465
+ input_text = (
466
+ f"{CLASSIFIER_SYSTEM_PROMPT}\n\n"
467
+ f"### Problem:\n{math_question}\n\n"
468
+ f"### Answer:\n{proposed_solution}"
469
+ )
470
+ cls_inputs = classifier_tokenizer(
471
+ input_text, return_tensors="pt", truncation=True, max_length=512
472
+ ).to(device)
473
+
474
+ with torch.no_grad():
475
+ logits = classifier_model(**cls_inputs)["logits"]
476
+ probs = torch.softmax(logits, dim=-1).squeeze()
477
+
478
+ is_correct_prob = float(probs[0])
479
+ is_flawed_prob = float(probs[1])
480
+
481
+ if debug:
482
+ print("\n[Phi-4 logits]", logits.to(torch.float32).cpu().numpy())
483
+ print("[Phi-4 probs] [Correct, Flawed]:", [is_correct_prob, is_flawed_prob])
484
+
485
+ if is_flawed_prob > 0.5:
486
+ return {
487
+ "classification": "Conceptual Error",
488
+ "confidence": f"{is_flawed_prob:.2%}",
489
+ "explanation": "Logic or setup appears to have a conceptual error.",
490
+ }
491
+ else:
492
+ return {
493
+ "classification": "Correct",
494
+ "confidence": f"{is_correct_prob:.2%}",
495
+ "explanation": "Solution appears correct.",
496
+ }
497
+
498
+
499
+
500
+ @spaces.GPU
501
+ def classify_solution(question: str, solution: str):
502
+ """
503
+ Classify the math solution
504
+ Returns: (classification_label, confidence_score, explanation)
505
+ """
506
+ if not question.strip() or not solution.strip():
507
+ return "Please fill in both fields", 0.0, ""
508
+
509
+ # both pipeline stages must be loaded
+ if gemma_model is None or classifier_model is None:
+ return "Model not loaded", 0.0, ""
+
+ try:
+ res = analyze_single(question, solution)
+ return res["classification"], res["confidence"], res["explanation"]
+ except Exception as e:
+ logger.error(f"Error during classification: {e}")
+ return f"Classification error: {str(e)}", "0%", ""
516
+
517
+
518
+
519
+ # Load model on startup
520
+ load_model()
521
+
522
+ # Create Gradio interface
523
+ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
524
+ gr.Markdown("# 🧮 Math Solution Classifier")
525
+ gr.Markdown("Classify math solutions as correct, conceptually flawed, or computationally flawed.")
526
+
527
+ with gr.Row():
528
+ with gr.Column():
529
+ question_input = gr.Textbox(
530
+ label="Math Question",
531
+ placeholder="e.g., Solve for x: 2x + 5 = 13",
532
+ lines=3
533
+ )
534
+
535
+ solution_input = gr.Textbox(
536
+ label="Proposed Solution",
537
+ placeholder="e.g., 2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
538
+ lines=5
539
+ )
540
+
541
+ classify_btn = gr.Button("Classify Solution", variant="primary")
542
+
543
+ with gr.Column():
544
+ classification_output = gr.Textbox(label="Classification", interactive=False)
545
+ confidence_output = gr.Textbox(label="Confidence", interactive=False)
546
+ explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)
547
+
548
+ # Examples
549
+ gr.Examples(
550
+ examples=[
551
+ [
552
+ "Solve for x: 2x + 5 = 13",
553
+ "2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4"
554
+ ],
555
+ [
556
+ "John has three apples and Mary has seven, how many apples do they have together?",
557
+ "They have 7 + 3 = 11 apples." # This should be computationally flawed
558
+ ],
559
+ [
560
+ "What is 15% of 200?",
561
+ "15% = 15/100 = 0.15\n0.15 × 200 = 30"
562
+ ]
563
+ ],
564
+ inputs=[question_input, solution_input]
565
+ )
566
+
567
+ classify_btn.click(
568
+ fn=classify_solution,
569
+ inputs=[question_input, solution_input],
570
+ outputs=[classification_output, confidence_output, explanation_output]
571
+ )
572
+
573
+ if __name__ == "__main__":
574
  app.launch()