mcamargo00 committed on
Commit
76ebeac
verified
1 Parent(s): a8f4e5d

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +183 -352
  2. requirements.txt +5 -1
app.py CHANGED
@@ -17,7 +17,6 @@ from huggingface_hub import hf_hub_download
17
  import json
18
  import re
19
  import math
20
- import time
21
 
22
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
 
@@ -65,271 +64,99 @@ class GPTSequenceClassifier(nn.Module):
65
  loss = nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
66
  return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
67
 
68
-
69
-
70
 
 
 
 
 
71
  # --- Helper Functions ---
72
def format_solution_into_json_str(solution_text: str) -> str:
    """
    Convert a free-text solution into a JSON string keyed by line.

    Non-blank lines become "L1", "L2", ... in order. If the last non-blank
    line contains "FINAL ANSWER:" (any case), the text after that marker is
    stored under "FA" and the line is not numbered; otherwise "FA" is "".

    Args:
        solution_text: The solution, one step per line. None is treated as "".

    Returns:
        A JSON object string (indent=2) holding the "L<n>" keys and "FA".
    """
    lines = [ln.strip() for ln in (solution_text or "").split('\n') if ln.strip()]

    final_answer = ""
    if lines:
        # Slice from where the marker actually ends rather than a fixed prefix
        # length, so a line such as "So the final answer: 4" yields "4".
        marker = re.search(r'FINAL ANSWER:', lines[-1], flags=re.IGNORECASE)
        if marker:
            final_answer = lines[-1][marker.end():].strip()
            lines.pop()

    solution_dict = {f"L{i}": line for i, line in enumerate(lines, start=1)}
    solution_dict["FA"] = final_answer
    return json.dumps(solution_dict, indent=2)
81
 
82
def sanitize_equation_string(expression: str) -> str:
    """
    Reduce a transcribed calculation to a bare arithmetic string.

    Steps, in order: convert percentages ("12%" -> "(12/100)"), drop a single
    stray parenthesis at either edge, remove units written after a slash
    ("/hr"), normalise multiplication signs to "*", discard every character
    that is not a digit or one of ". ( ) + - * / = %", and collapse runs of
    "=" into one.

    Args:
        expression: Raw equation text; anything that is not a str yields "".

    Returns:
        The sanitised string (possibly empty).
    """
    if not isinstance(expression, str):
        return ""

    s = expression.strip()

    # Convert percentages like '12%' -> '(12/100)'. Done first so that a
    # following "x" sees the closing parenthesis as its left operand.
    s = re.sub(r'(?<!\d)(\d+(?:\.\d+)?)\s*%', r'(\1/100)', s)

    # Simple paren balancing trims (only when a single stray exists at an edge)
    if s.count('(') > s.count(')') and s.startswith('('):
        s = s[1:]
    elif s.count(')') > s.count('(') and s.endswith(')'):
        s = s[:-1]

    # Drop units right after a slash (/hr, /dogs) before touching "x", so a
    # unit that contains an "x" (e.g. /box) is removed whole.
    s = re.sub(r'/([a-zA-Z]+)', '', s)

    # Dedicated multiplication signs always mean '*'.
    s = re.sub(r'[×✕⋅·]', '*', s)
    # An ASCII x/X is multiplication only when it sits between two operands;
    # an x inside a word must not turn into a stray '*'.
    s = re.sub(r'(?<=[\d)])\s*[xX]\s*(?=[\d(.+\-])', '*', s)

    # Keep only numeric math tokens
    s = re.sub(r'[^\d.()+\-*/=%]', '', s)

    # Collapse repeated '=' (e.g., '==24/2=12' -> '=24/2=12')
    return re.sub(r'=+', '=', s)
110
-
111
- import re, math
112
-
113
def _parse_equation(eq_str: str):
    """
    Split a raw equation into its sanitised (lhs, rhs) halves.

    The text is sanitised, leading '=' characters are dropped (lines such as
    '=24/2=12'), and the remainder is cut at its last '='.

    Returns:
        A (lhs, rhs) tuple of non-empty strings, or None when there is no
        '=' or either side is empty.
    """
    cleaned = sanitize_equation_string(eq_str or "").lstrip('=')
    if '=' not in cleaned:
        return None

    # rpartition cuts at the final '=', which covers both the single-'='
    # case and the chained case ('a=b=c' -> 'a=b', 'c').
    left, _, right = cleaned.rpartition('=')
    left = left.strip()
    right = right.strip()
    if left and right:
        return left, right
    return None
127
-
128
- def _abs_tol_from_display(rhs_str: str):
129
- """
130
- If RHS is a single numeric literal like 0.33, use half-ULP at that precision.
131
- e.g., '0.33' -> 0.5 * 10^-2 = 0.005
132
- Otherwise return None and fall back to base tolerances.
133
  """
134
- s = rhs_str.strip()
135
- # allow optional parens and sign
136
- m = re.fullmatch(r'\(?\s*[-+]?\d+(?:\.(\d+))?\s*\)?', s)
137
- if not m:
138
- return None
139
- frac = m.group(1) or ""
140
- d = len(frac)
141
- return 0.5 * (10 ** (-d)) if d > 0 else 0.5 # if integer shown, allow 卤0.5
142
-
143
def evaluate_equations(eq_dict: dict, sol_dict: dict,
                       base_rel_tol: float = 1e-6,
                       base_abs_tol: float = 1e-9,
                       honor_display_precision: bool = True):
    """
    Evaluate extracted equations and report the first one whose sides disagree.

    Accepts rounded RHS values based on displayed precision when
    honor_display_precision is True.

    Args:
        eq_dict: Mapping of line key -> extracted equation text (None is
            treated as empty).
        sol_dict: Mapping of line key -> original solution line, used only to
            quote the offending line (None is treated as empty).
        base_rel_tol: Relative tolerance passed to math.isclose.
        base_abs_tol: Minimum absolute tolerance passed to math.isclose.
        honor_display_precision: Widen the absolute tolerance to match the
            number of decimals shown on the RHS.

    Returns:
        {"error": False} when nothing is flagged, otherwise a dict with
        "error": True plus "line_key", "line_text", "original_flawed_calc",
        "sanitized_lhs", "original_rhs" and "correct_rhs".
    """
    sol_dict = sol_dict or {}

    for key, eq_str in (eq_dict or {}).items():
        parsed = _parse_equation(eq_str)
        if not parsed:
            continue
        lhs, rhs_str = parsed

        # dynamic absolute tolerance from RHS formatting (e.g., 0.33 -> 0.005)
        abs_tol = base_abs_tol
        if honor_display_precision:
            dyn = _abs_tol_from_display(rhs_str)
            if dyn is not None:
                abs_tol = max(abs_tol, dyn)

        try:
            # NOTE(security): eval() runs on model-produced text. Builtins are
            # disabled here, but an expression such as a long '**' chain can
            # still burn CPU — review before exposing to untrusted input.
            lhs_val = eval(lhs, {"__builtins__": None}, {})
            rhs_val = eval(rhs_str, {"__builtins__": None}, {})
            # The comparison stays inside the try: eval can yield a tuple or a
            # complex number, for which isclose raises instead of answering.
            mismatch = not math.isclose(
                lhs_val, rhs_val, rel_tol=base_rel_tol, abs_tol=abs_tol
            )
        except Exception:
            # Best effort: a line that cannot be evaluated is skipped.
            continue

        if mismatch:
            correct_rhs_val = round(lhs_val, 6)
            correct_rhs_str = f"{correct_rhs_val:.6f}".rstrip('0').rstrip('.')
            return {
                "error": True,
                "line_key": key,
                "line_text": sol_dict.get(key, "N/A"),
                "original_flawed_calc": eq_str,
                "sanitized_lhs": lhs,
                "original_rhs": rhs_str,
                "correct_rhs": correct_rhs_str,
            }

    return {"error": False}
183
-
184
 
 
 
185
 
 
186
 
 
 
187
 
188
- def extract_json_from_response(response: str) -> dict:
189
- """
190
- Recover equations from the extractor's output.
191
-
192
- Strategy:
193
- 1) Try to parse a real JSON object (if present).
194
- 2) Parse relaxed key-value lines like 'L1: ...' or 'FA=...'.
195
- 3) Also fall back to linewise equations (e.g., '=24/2=12', '7*2=14') and
196
- merge them as L1, L2, ... preserving order. Keep FA if present.
197
- """
198
- out = {}
199
-
200
- if not response or not isinstance(response, str):
201
- return out
202
-
203
- text = response.strip()
204
 
205
- # --- 1) strict JSON block, if any ---
206
- m = re.search(r'\{.*\}', text, flags=re.S)
207
- if m:
208
- try:
209
- obj = json.loads(m.group(0))
210
- if isinstance(obj, dict) and any(k.startswith("L") for k in obj):
211
- return obj
212
- elif isinstance(obj, dict):
213
- out.update(obj) # keep FA etc., then continue
 
 
 
 
 
214
  except Exception:
215
- pass
216
-
217
- # --- 2) relaxed key/value lines: Lk : value or FA = value ---
218
- relaxed = {}
219
- for ln in text.splitlines():
220
- ln = ln.strip().strip(',')
221
- if not ln:
222
  continue
223
- m = re.match(r'(?i)^(L\d+|FA)\s*[:=]\s*(.+?)\s*$', ln)
224
- if m:
225
- k = m.group(1).strip()
226
- v = m.group(2).strip().rstrip(',')
227
- relaxed[k] = v
228
- out.update(relaxed)
229
-
230
- # Count how many L-keys we already have
231
- existing_L = sorted(
232
- int(k[1:]) for k in out.keys()
233
- if k.startswith("L") and k[1:].isdigit()
234
- )
235
- next_L = (max(existing_L) + 1) if existing_L else 1
236
-
237
- # --- 3) linewise fallback: harvest bare equations and merge ---
238
- def _looks_like_equation(s: str) -> str | None:
239
- s = sanitize_equation_string(s or "").lstrip('=')
240
- if '=' in s and any(ch.isdigit() for ch in s):
241
- return s
242
- return None
243
-
244
- # set of existing equation strings to avoid duplicates
245
- seen_vals = set(v for v in out.values() if isinstance(v, str))
246
-
247
- for ln in text.splitlines():
248
- ln = ln.strip().strip(',')
249
- if not ln or re.match(r'(?i)^(L\d+|FA)\s*[:=]', ln):
250
- # skip lines we already captured as relaxed pairs
251
- continue
252
- eq = _looks_like_equation(ln)
253
- if eq and eq not in seen_vals:
254
- out[f"L{next_L}"] = eq
255
- seen_vals.add(eq)
256
- next_L += 1
257
 
258
- return out
259
 
260
  # --- Prompts ---
261
  EXTRACTOR_SYSTEM_PROMPT = \
262
  """[ROLE]
263
  You are an expert at parsing mathematical solutions.
 
264
  [TASK]
265
- You are given a mathematical solution. Your task is to extract the calculation performed on each line and represent it as a simple equation.
 
266
  **This is a literal transcription task. Follow these rules with extreme precision:**
267
  - **RULE 1: Transcribe EXACTLY.** Do not correct mathematical errors. If a line implies `2+2=5`, your output for that line must be `2+2=5`.
268
- - **RULE 2: Isolate the Equation.** Your output must contain ONLY the equation. Do not include any surrounding text, units (like `/hour`), or currency symbols (like `$`).
269
- - **RULE 3: Use Standard Operators.** Always use `*` for multiplication. Never use `x`.
270
  [RESPONSE FORMAT]
271
- Your response must be ONLY a single, valid JSON object, adhering strictly to these rules:
272
- For each line of the solution, create a key-value pair.
273
- - The key should be the line identifier (e.g., "L1", "L2", "FA" for the final answer line).
274
- - The value should be the extracted equation string (e.g., "10+5=15").
275
- - If a line contains no calculation, the value must be an empty string.
276
  """
277
-
278
  CLASSIFIER_SYSTEM_PROMPT = \
279
  """You are a mathematics tutor.
280
  You will be given a math word problem and a solution written by a student.
281
  Carefully analyze the problem and solution LINE-BY-LINE and determine whether there are any errors in the solution."""
282
 
283
- # --- Example 1 ---
284
- FEW_SHOT_EXAMPLE_1_SOLUTION = {
285
- "L1": "2% of $90 is (2/100)*$90 = $1.8",
286
- "L2": "2% of $60 is (2/100)*$60 = $1.2",
287
- "L3": "The second transaction was reversed without the service charge so only a total of $90+$1.8+$1.2 = $39 was deducted from his account",
288
- "L4": "He will have a balance of $400-$39 = $361",
289
- "FA": "361"
290
- }
291
-
292
- FEW_SHOT_EXAMPLE_1_EQUATIONS = {
293
- "L1": "(2/100)*90=1.8",
294
- "L2": "(2/100)*60=1.2",
295
- "L3": "90+1.8+1.2=39",
296
- "L4": "400-39=361",
297
- "FA": ""
298
- }
299
-
300
-
301
- # --- Example 2 ---
302
- FEW_SHOT_EXAMPLE_2_SOLUTION = {
303
- "L1": "She drinks 2 bottles a day and there are 24 bottles in a case so a case will last 24/2 = 12 days",
304
- "L2": "She needs enough to last her 240 days and 1 case will last her 12 days so she needs 240/12 = 20 cases",
305
- "L3": "Each case is on sale for $12.00 and she needs 20 cases so that's 12*20 = $240.00",
306
- "FA": "240"
307
- }
308
-
309
- FEW_SHOT_EXAMPLE_2_EQUATIONS = {
310
- "L1": "24/2=12",
311
- "L2": "240/12=20",
312
- "L3": "12*20=240.00",
313
- "FA": ""
314
- }
315
-
316
def create_extractor_messages(solution_json_str: str) -> list:
    """
    Build the chat history sent to the equation extractor.

    The list holds the two module-level few-shot examples as alternating
    user/assistant turns (the first user turn is prefixed with the extractor
    system prompt), followed by a final user turn carrying
    `solution_json_str`.
    """
    def _solution_block(payload: str) -> str:
        return f"### Solution:\n{payload}"

    few_shots = (
        (FEW_SHOT_EXAMPLE_1_SOLUTION, FEW_SHOT_EXAMPLE_1_EQUATIONS),
        (FEW_SHOT_EXAMPLE_2_SOLUTION, FEW_SHOT_EXAMPLE_2_EQUATIONS),
    )

    messages = []
    for solution, equations in few_shots:
        messages.append(
            {"role": "user", "content": _solution_block(json.dumps(solution, indent=2))}
        )
        messages.append(
            {"role": "assistant", "content": json.dumps(equations, indent=2)}
        )

    # Only the very first user turn carries the instructions.
    messages[0]["content"] = f"{EXTRACTOR_SYSTEM_PROMPT}\n\n{messages[0]['content']}"

    # The actual query goes last.
    messages.append({"role": "user", "content": _solution_block(solution_json_str)})
    return messages
333
 
334
  gemma_model = None
335
  gemma_tokenizer = None
@@ -344,51 +171,38 @@ def load_model():
344
  device = DEVICE
345
 
346
  # --- Model 1: Equation Extractor (Gemma-3 with Unsloth) ---
347
- extractor_adapter_repo = "arvindsuresh-math/gemma-3-1b-equation-extractor-lora"
348
  base_gemma_model = "unsloth/gemma-3-1b-it-unsloth-bnb-4bit"
349
 
350
  gemma_model, gemma_tokenizer = FastModel.from_pretrained(
351
  model_name=base_gemma_model,
352
- max_seq_length=2048,
353
  dtype=None,
354
  load_in_4bit=True,
355
  )
356
-
357
- # --- Gemma tokenizer hygiene (fix the right-padding warning) ---
358
- if gemma_tokenizer.pad_token is None:
359
- gemma_tokenizer.pad_token = gemma_tokenizer.eos_token
360
- gemma_tokenizer.padding_side = "left" # align last tokens across the batch
361
-
362
  gemma_model = PeftModel.from_pretrained(gemma_model, extractor_adapter_repo)
363
 
364
-
365
  # --- Model 2: Conceptual Error Classifier (Phi-4) ---
366
  classifier_adapter_repo = "arvindsuresh-math/phi-4-error-binary-classifier"
367
  base_phi_model = "microsoft/Phi-4-mini-instruct"
368
 
369
- # T4 does fp16 (not bf16)
370
- DTYPE = torch.float32
371
  quantization_config = BitsAndBytesConfig(
372
  load_in_4bit=True,
373
  bnb_4bit_quant_type="nf4",
374
- bnb_4bit_compute_dtype=DTYPE,
375
- )
376
-
377
-
378
  classifier_backbone_base = AutoModelForCausalLM.from_pretrained(
379
  base_phi_model,
380
  quantization_config=quantization_config,
381
- device_map={"": 0},
382
- trust_remote_code=False, # keep this if you switched it earlier
383
- # safest with eager attention when mixing kernels:
384
- attn_implementation="eager",
385
- )
386
 
387
  classifier_tokenizer = AutoTokenizer.from_pretrained(
388
  base_phi_model,
389
- trust_remote_code=False # <-- match above
390
- )
391
-
392
  classifier_tokenizer.padding_side = "left"
393
  if classifier_tokenizer.pad_token is None:
394
  classifier_tokenizer.pad_token = classifier_tokenizer.eos_token
@@ -404,121 +218,138 @@ def load_model():
404
  classifier_model.classifier.load_state_dict(torch.load(classifier_head_path, map_location=device))
405
 
406
  classifier_model.to(device)
407
- classifier_model = classifier_model.to(device=DEVICE, dtype=torch.float32)
408
 
409
- classifier_model.eval() # Set model to evaluation mode
410
-
411
- logger.info("Model loaded successfully")
412
- return "Model loaded successfully!"
413
 
414
  except Exception as e:
415
  logger.error(f"Error loading model: {e}")
416
  return f"Error loading model: {e}"
417
def models_ready():
    """Return True only when all four model/tokenizer globals are truthy."""
    components = (gemma_model, gemma_tokenizer, classifier_model, classifier_tokenizer)
    for component in components:
        if not component:
            return False
    return True
 
 
 
 
 
 
419
 
420
def analyze_single(math_question: str, proposed_solution: str, debug: bool = False):
    """
    Classify a single (question, solution) pair.

    Stage 1 looks for arithmetic slips: Gemma transcribes the calculation on
    each line and the evaluator recomputes it. Stage 2 runs only when stage 1
    finds nothing, and asks the Phi-4 classifier whether the reasoning is
    flawed.

    Returns: {"classification": "...", "confidence": "...", "explanation": "..."}
    """
    # -----------------------------
    # STAGE 1: COMPUTATIONAL CHECK
    # -----------------------------
    solution_json_str = format_solution_into_json_str(proposed_solution)
    solution_dict = json.loads(solution_json_str)

    prompt = gemma_tokenizer.apply_chat_template(
        create_extractor_messages(solution_json_str),
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = gemma_tokenizer([prompt], return_tensors="pt").to(DEVICE)

    generation_kwargs = {
        "max_new_tokens": 300,
        "use_cache": True,
        "pad_token_id": gemma_tokenizer.pad_token_id,
        "do_sample": False,
        "temperature": 0.0,
    }
    outputs = gemma_model.generate(**inputs, **generation_kwargs)

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs.input_ids.shape[1]
    extracted_text = gemma_tokenizer.batch_decode(
        outputs[:, prompt_len:], skip_special_tokens=True
    )[0]
    if debug:
        print("\n[Gemma raw output]\n", extracted_text)

    # Evaluate only equations whose source line actually contains a digit.
    final_eq_to_eval = {}
    for line_key, equation in extract_json_from_response(extracted_text).items():
        if any(ch.isdigit() for ch in solution_dict.get(line_key, "")):
            final_eq_to_eval[line_key] = equation
    if debug:
        print("\n[Equations to evaluate]\n", json.dumps(final_eq_to_eval, indent=2))

    computational_error = evaluate_equations(final_eq_to_eval, solution_dict)
    if computational_error.get("error"):
        lhs = computational_error["sanitized_lhs"]
        rhs = computational_error["original_rhs"]
        correct_rhs = computational_error["correct_rhs"]
        line_txt = computational_error.get("line_text", "")
        explanation = (
            "A computational error was found.\n"
            f'On line: "{line_txt}"\n'
            f"The student wrote '{lhs} = {rhs}', but the correct result of '{lhs}' is {correct_rhs}."
        )
        return {
            "classification": "Computational Error",
            "confidence": "100%",
            "explanation": explanation,
        }

    # --------------------------
    # STAGE 2: CONCEPTUAL CHECK
    # --------------------------
    input_text = (
        f"{CLASSIFIER_SYSTEM_PROMPT}\n\n"
        f"### Problem:\n{math_question}\n\n"
        f"### Answer:\n{proposed_solution}"
    )
    cls_inputs = classifier_tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=512
    ).to(DEVICE)

    with torch.no_grad():
        logits = classifier_model(**cls_inputs)["logits"]
        probs = torch.softmax(logits, dim=-1).squeeze()

    # Index 0 is read as "correct" and index 1 as "flawed".
    is_correct_prob = float(probs[0])
    is_flawed_prob = float(probs[1])

    if debug:
        print("\n[Phi-4 logits]", logits.to(torch.float32).cpu().numpy())
        print("[Phi-4 probs] [Correct, Flawed]:", [is_correct_prob, is_flawed_prob])

    if is_flawed_prob > 0.5:
        return {
            "classification": "Conceptual Error",
            "confidence": f"{is_flawed_prob:.2%}",
            "explanation": "Logic or setup appears to have a conceptual error.",
        }
    return {
        "classification": "Correct",
        "confidence": f"{is_correct_prob:.2%}",
        "explanation": "Solution appears correct.",
    }
 
 
 
 
 
 
 
 
 
 
522
 
523
 
524
  def classify_solution(question: str, solution: str):
@@ -533,9 +364,9 @@ def classify_solution(question: str, solution: str):
533
  return "Models not loaded", 0.0, ""
534
 
535
  try:
536
- res = analyze_single(question, solution)
537
 
538
- return res["classification"], res["confidence"], res["explanation"]
539
  except Exception:
540
  logger.exception("inference failed")
541
 
 
17
  import json
18
  import re
19
  import math
 
20
 
21
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
 
 
64
  loss = nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
65
  return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
66
 
 
 
67
 
68
+ # ===================================================================
69
+ # 3. HELPERS
70
+ # ===================================================================
71
+
72
  # --- Helper Functions ---
73
+ def extract_equation_from_response(response: str) -> str | None:
74
+ """Extracts content from between <eq> and </eq> tags."""
75
+ match = re.search(r'<eq>(.*?)</eq>', response, re.DOTALL)
76
+ return match.group(1) if match else None
 
 
 
 
 
77
 
78
def sanitize_equation_string(expression: str) -> str:
    """
    Reduce a transcribed calculation to a bare arithmetic string.

    Steps, in order: trim surrounding whitespace, drop a single stray
    parenthesis at either edge, remove units written after a slash ("/hr"),
    normalise multiplication signs to "*", and discard every character that
    is not a digit or one of ". ( ) + - * / =" (which also removes spaces).

    Args:
        expression: Raw equation text; anything that is not a str yields "".

    Returns:
        The sanitised string (possibly empty).
    """
    if not isinstance(expression, str):
        return ""

    # Trim first so the edge checks below look at real characters.
    s = expression.strip()

    # Drop one unmatched parenthesis, but only when it sits at an edge.
    if s.count('(') > s.count(')') and s.startswith('('):
        s = s[1:]
    elif s.count(')') > s.count('(') and s.endswith(')'):
        s = s[:-1]

    # Remove slash-units while spaces are still present and before touching
    # "x": otherwise "12/box x 2" collapses to "12/boxx2" and ends up as
    # "12**2".
    s = re.sub(r'/([a-zA-Z]+)', '', s)

    # Dedicated multiplication signs always mean '*'.
    s = re.sub(r'[×✕⋅·]', '*', s)
    # An ASCII x/X is multiplication only between two operands; an x inside
    # a word must not become a stray '*'.
    s = re.sub(r'(?<=[\d)])\s*[xX]\s*(?=[\d(.+\-])', '*', s)

    return re.sub(r'[^\d.()+\-*/=]', '', s)
96
 
97
def evaluate_equations(eq_dict: dict, sol_dict: dict,
                       rel_tol: float = 1e-2, abs_tol: float = 1e-9):
    """
    Evaluate extracted equations and report the first one whose sides disagree.

    Args:
        eq_dict: Mapping of line key -> extracted equation text (None is
            treated as empty).
        sol_dict: Mapping of line key -> original solution line, used only to
            quote the offending line (None is treated as empty).
        rel_tol: Relative tolerance for the comparison.
        abs_tol: Absolute tolerance for the comparison; without one, any
            result that should be 0 but differs by float noise is flagged.

    Returns:
        {"error": False} when nothing is flagged, otherwise a dict with
        "error": True plus "line_key", "line_text", "original_flawed_calc",
        "sanitized_lhs", "original_rhs" and "correct_rhs".
    """
    sol_dict = sol_dict or {}

    for key, eq_str in (eq_dict or {}).items():
        if not eq_str or "=" not in eq_str:
            continue

        try:
            sanitized_eq = sanitize_equation_string(eq_str)
            if not sanitized_eq or "=" not in sanitized_eq:
                continue

            lhs, rhs_str = sanitized_eq.split('=', 1)
            if not lhs or not rhs_str:
                continue

            # NOTE(security): eval() runs on model-produced text. Builtins are
            # disabled here, but an expression such as a long '**' chain can
            # still burn CPU — review before exposing to untrusted input.
            lhs_val = eval(lhs, {"__builtins__": None}, {})
            rhs_val = eval(rhs_str, {"__builtins__": None}, {})

            if math.isclose(lhs_val, rhs_val, rel_tol=rel_tol, abs_tol=abs_tol):
                continue

            correct_rhs_val = round(lhs_val, 4)
            correct_rhs_str = f"{correct_rhs_val:.4f}".rstrip('0').rstrip('.')
            return {
                "error": True,
                "line_key": key,
                "line_text": sol_dict.get(key, "N/A"),
                "original_flawed_calc": eq_str,   # the raw model output
                "sanitized_lhs": lhs,             # the clean left side
                "original_rhs": rhs_str,          # the clean right side
                "correct_rhs": correct_rhs_str,   # the recomputed right side
            }
        except Exception:
            # Best effort: a line that cannot be evaluated is skipped.
            continue

    return {"error": False}
137
 
138
  # --- Prompts ---
139
  EXTRACTOR_SYSTEM_PROMPT = \
140
  """[ROLE]
141
  You are an expert at parsing mathematical solutions.
142
+
143
  [TASK]
144
+ You are given a single line from a mathematical solution. Your task is to extract the calculation from this line.
145
+
146
  **This is a literal transcription task. Follow these rules with extreme precision:**
147
  - **RULE 1: Transcribe EXACTLY.** Do not correct mathematical errors. If a line implies `2+2=5`, your output for that line must be `2+2=5`.
148
+ - **RULE 2: Isolate the Equation.** Your output must contain ONLY the equation, with no surrounding text, units, or currency symbols. Always use `*` for multiplication.
149
+
150
  [RESPONSE FORMAT]
151
+ Your response must ONLY contain the extracted equation, wrapped in <eq> and </eq> tags.
152
+ If the line contains no calculation, respond with empty tags: <eq></eq>.
 
 
 
153
  """
 
154
  CLASSIFIER_SYSTEM_PROMPT = \
155
  """You are a mathematics tutor.
156
  You will be given a math word problem and a solution written by a student.
157
  Carefully analyze the problem and solution LINE-BY-LINE and determine whether there are any errors in the solution."""
158
 
159
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  gemma_model = None
162
  gemma_tokenizer = None
 
171
  device = DEVICE
172
 
173
  # --- Model 1: Equation Extractor (Gemma-3 with Unsloth) ---
174
+ extractor_adapter_repo = "arvindsuresh-math/gemma-3-1b-equation-line-extractor-aug-10"
175
  base_gemma_model = "unsloth/gemma-3-1b-it-unsloth-bnb-4bit"
176
 
177
  gemma_model, gemma_tokenizer = FastModel.from_pretrained(
178
  model_name=base_gemma_model,
179
+ max_seq_length=350,
180
  dtype=None,
181
  load_in_4bit=True,
182
  )
 
 
 
 
 
 
183
  gemma_model = PeftModel.from_pretrained(gemma_model, extractor_adapter_repo)
184
 
 
185
  # --- Model 2: Conceptual Error Classifier (Phi-4) ---
186
  classifier_adapter_repo = "arvindsuresh-math/phi-4-error-binary-classifier"
187
  base_phi_model = "microsoft/Phi-4-mini-instruct"
188
 
189
+ DTYPE = torch.float16
 
190
  quantization_config = BitsAndBytesConfig(
191
  load_in_4bit=True,
192
  bnb_4bit_quant_type="nf4",
193
+ bnb_4bit_compute_dtype=DTYPE
194
+ )
 
 
195
  classifier_backbone_base = AutoModelForCausalLM.from_pretrained(
196
  base_phi_model,
197
  quantization_config=quantization_config,
198
+ device_map="auto",
199
+ trust_remote_code=True,
200
+ )
 
 
201
 
202
  classifier_tokenizer = AutoTokenizer.from_pretrained(
203
  base_phi_model,
204
+ trust_remote_code=True
205
+ )
 
206
  classifier_tokenizer.padding_side = "left"
207
  if classifier_tokenizer.pad_token is None:
208
  classifier_tokenizer.pad_token = classifier_tokenizer.eos_token
 
218
  classifier_model.classifier.load_state_dict(torch.load(classifier_head_path, map_location=device))
219
 
220
  classifier_model.to(device)
221
+ classifier_model = classifier_model.to(torch.float16)
222
 
223
+ classifier_model.eval() # Set model to evaluation mode
 
 
 
224
 
225
  except Exception as e:
226
  logger.error(f"Error loading model: {e}")
227
  return f"Error loading model: {e}"
228
def models_ready():
    """Return True only when all four model/tokenizer globals are truthy."""
    components = (gemma_model, gemma_tokenizer, classifier_model, classifier_tokenizer)
    for component in components:
        if not component:
            return False
    return True
230
+
231
+
232
+
233
+ # ===================================================================
234
+ # 4. PIPELINE COMPONENTS
235
+ # ===================================================================
236
 
237
def run_conceptual_check(question: str, solution: str, model, tokenizer) -> dict:
    """
    STAGE 1: Score a (question, solution) pair with the sequence classifier.

    Args:
        question: The math word problem.
        solution: The student's solution text.
        model: Classifier whose forward pass returns a mapping with "logits".
        tokenizer: Tokenizer matching `model`.

    Returns:
        {"prediction": "correct" | "flawed",
         "probabilities": {"correct": float, "flawed": float}}
    """
    input_text = f"{CLASSIFIER_SYSTEM_PROMPT}\n\n### Problem:\n{question}\n\n### Answer:\n{solution}"

    # Take the target device from the model itself: no `device` name is
    # defined in this function, so relying on one raises NameError unless a
    # global happens to exist.
    target_device = next(model.parameters()).device
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(target_device)

    # inference_mode skips autograd bookkeeping; no KV cache is needed for a
    # single forward pass.
    with torch.inference_mode():
        outputs = model(**inputs, use_cache=False)

    # Softmax in float32 regardless of the dtype the model runs in.
    logits = outputs["logits"].to(torch.float32)
    # assumes a single example, so squeeze() leaves one probability per label
    # — TODO confirm logits shape is (1, num_labels)
    probs = torch.softmax(logits, dim=-1).squeeze().tolist()

    # Index 0 is read as "correct" and index 1 as "flawed".
    prob_correct, prob_flawed = probs[0], probs[1]
    return {
        "prediction": "flawed" if prob_flawed > 0.5 else "correct",
        "probabilities": {"correct": prob_correct, "flawed": prob_flawed},
    }
263
 
 
 
 
 
 
 
 
 
 
264
 
265
def _check_extracted_equation(raw_output: str, line: str):
    """Return an error dict when the equation extracted for `line` is wrong, else None."""
    equation = extract_equation_from_response(raw_output)
    if not equation or "=" not in equation:
        return None

    try:
        sanitized_eq = sanitize_equation_string(equation)
        if "=" not in sanitized_eq:
            return None
        lhs, rhs_str = sanitized_eq.split('=', 1)

        # NOTE(security): eval() runs on model-produced text. Builtins are
        # disabled here, but an expression such as a long '**' chain can
        # still burn CPU — review before exposing to untrusted input.
        lhs_val = eval(lhs, {"__builtins__": None}, {})

        # abs_tol keeps results that should be 0 from being flagged over
        # floating-point noise; rel_tol alone can never accept them.
        if math.isclose(lhs_val, float(rhs_str), rel_tol=1e-2, abs_tol=1e-9):
            return None
        return {
            "error": True,
            "line_text": line,
            "correct_calc": f"{lhs} = {round(lhs_val, 4)}",
        }
    except Exception:
        return None  # skip lines where evaluation fails


def run_computational_check(solution: str, model, tokenizer, batch_size: int = 32) -> dict:
    """
    STAGE 2: Split a solution into lines and check each line's arithmetic.

    Every non-blank line that does not contain "FINAL ANSWER:" is sent to the
    extractor model, `batch_size` lines at a time. The extracted equation's
    left side is evaluated and compared with its right side.

    Args:
        solution: The student's solution text, one step per line.
        model: Generative extractor model.
        tokenizer: Tokenizer matching `model`.
        batch_size: Maximum number of lines per generate() call.

    Returns:
        {"error": False}, or for the first failing line
        {"error": True, "line_text": ..., "correct_calc": ...}.
    """
    lines = [
        line.strip()
        for line in solution.strip().split('\n')
        if line.strip() and "FINAL ANSWER:" not in line.upper()
    ]
    if not lines:
        return {"error": False}

    # One chat-formatted prompt per solution line.
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": f"{EXTRACTOR_SYSTEM_PROMPT}\n\n### Solution Line:\n{line}"}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for line in lines
    ]

    # Left padding keeps each prompt's last token adjacent to the generated
    # tokens, which the slice after generate() relies on.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Take the device from the model: no `device` name is defined here.
    target_device = next(model.parameters()).device
    step = max(1, int(batch_size))

    for start in range(0, len(prompts), step):
        inputs = tokenizer(
            prompts[start:start + step], return_tensors="pt", padding=True
        ).to(target_device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            use_cache=True,
            pad_token_id=tokenizer.pad_token_id,
        )
        decoded_outputs = tokenizer.batch_decode(
            outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
        )

        # Batches run in line order, so the first hit is the earliest failing
        # line and later batches need not be generated.
        for line, raw_output in zip(lines[start:start + step], decoded_outputs):
            verdict = _check_extracted_equation(raw_output, line)
            if verdict is not None:
                return verdict

    return {"error": False}
316
+
317
+
318
def analyze_solution(question: str, solution: str):
    """
    Run both checks on a solution and assemble the final verdict.

    The conceptual classifier runs first, then the per-line computational
    check. A computational error takes precedence; otherwise the classifier's
    prediction decides between 'correct' and 'conceptual_error'.

    Returns:
        {"classification": str, "explanation": str, "confidence": float}
        where "confidence" is the classifier probability of its own
        prediction, or 1.0 when a computational error was found.
    """
    # STAGE 1: Conceptual Check
    conceptual_result = run_conceptual_check(
        question, solution, classifier_model, classifier_tokenizer
    )
    confidence = conceptual_result['probabilities'][conceptual_result['prediction']]

    # STAGE 2: Computational Check (batched)
    computational_result = run_computational_check(solution, gemma_model, gemma_tokenizer)

    if computational_result["error"]:
        explanation = (
            f"A calculation error was found.\n"
            f"On the line: \"{computational_result['line_text']}\"\n"
            f"The correct calculation should be: {computational_result['correct_calc']}"
        )
        return {
            "classification": "computational_error",
            "explanation": explanation,
            # The verdict comes from recomputing the arithmetic, not from the
            # classifier, so its probability does not apply here.
            "confidence": 1.0,
        }

    # Calculations are fine: the verdict is the conceptual one.
    if conceptual_result['prediction'] == 'correct':
        classification = 'correct'
        explanation = "All calculations are correct and the overall logic appears to be sound."
    else:
        classification = 'conceptual_error'  # user-facing label for 'flawed'
        explanation = "All calculations are correct, but there appears to be a conceptual error in the logic or setup of the solution."

    return {
        "classification": classification,
        "explanation": explanation,
        # Surfaced so callers can report it instead of it being discarded.
        "confidence": confidence,
    }
351
+
352
+
353
 
354
 
355
  def classify_solution(question: str, solution: str):
 
364
  return "Models not loaded", 0.0, ""
365
 
366
  try:
367
+ res = analyze_solution(question, solution)
368
 
369
+ return res["classification"], res["explanation"]
370
  except Exception:
371
  logger.exception("inference failed")
372
 
requirements.txt CHANGED
@@ -1,7 +1,11 @@
1
  gradio
2
  torch
3
  transformers
 
4
  peft
 
 
5
  accelerate
6
  spaces
7
- unsloth
 
 
1
  gradio
2
  torch
3
  transformers
4
+ bitsandbytes
5
  peft
6
+ trl
7
+ triton
8
  accelerate
9
  spaces
10
+ unsloth
11
+ unsloth_zoo