mcamargo00 committed · verified
Commit 5a7a017 · 1 Parent(s): 4c7dba1

Upload app.py

Files changed (1):
  1. app.py +123 -40
app.py CHANGED
@@ -1,18 +1,28 @@
  # app.py ── Math-solution classifier for HF Spaces
- # Requires: gradio, torch, transformers, peft, accelerate, spaces

  import os
  import logging

  import gradio as gr
  import torch
- from transformers import AutoTokenizer, AutoModelForSequenceClassification

- # Optional PEFT import (only available if you include it in requirements.txt)
  try:
-     from peft import AutoPeftModelForSequenceClassification
      PEFT_AVAILABLE = True
- except ImportError:
      PEFT_AVAILABLE = False

  # ──────────────────────────────────────────────────────────────────────────
@@ -21,55 +31,75 @@ except ImportError:
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter")  # local dir or Hub ID
  FALLBACK_MODEL = "distilbert-base-uncased"
- LABELS = {0: "✅ Correct",
-           1: "🤔 Conceptual Error",
-           2: "🔢 Computational Error"}

  device = "cuda" if torch.cuda.is_available() else "cpu"
- model = None
  tokenizer = None

  # ──────────────────────────────────────────────────────────────────────────
- # Load model & tokenizer
  # ──────────────────────────────────────────────────────────────────────────
  def load_model():
-     """Load the LoRA adapter if present, otherwise a baseline classifier."""
-     global model, tokenizer

      if PEFT_AVAILABLE and os.path.isdir(ADAPTER_PATH):
-         logger.info(f"Loading LoRA adapter from {ADAPTER_PATH}")
-         model = AutoPeftModelForSequenceClassification.from_pretrained(
-             ADAPTER_PATH,
-             torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-             device_map="auto" if device == "cuda" else None,
-         )
          tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
      else:
-         logger.warning("LoRA adapter not found – falling back to baseline model")
          tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
-         model = AutoModelForSequenceClassification.from_pretrained(
              FALLBACK_MODEL,
              num_labels=3,
              ignore_mismatched_sizes=True,
          )

      if tokenizer.pad_token is None:
          tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token

      model.to(device)
      model.eval()
-     logger.info("Model & tokenizer ready")

  # ──────────────────────────────────────────────────────────────────────────
- # Inference helper
  # ──────────────────────────────────────────────────────────────────────────
- def classify(question: str, solution: str):
-     """Return (label, confidence, placeholder-explanation)."""
-     if not question.strip() or not solution.strip():
-         return "Please provide both question and solution.", "", ""
-
      text = f"Question: {question}\n\nSolution:\n{solution}"
      inputs = tokenizer(
          text,
@@ -81,11 +111,64 @@ def classify(question: str, solution: str):

      with torch.no_grad():
          logits = model(**inputs).logits
-     probs = torch.softmax(logits, dim=-1)[0]
-     pred = int(torch.argmax(probs))
-     confidence = f"{probs[pred].item():.3f}"

-     return LABELS.get(pred, "Unknown"), confidence, "—"

  # ──────────────────────────────────────────────────────────────────────────
  # Build Gradio UI
@@ -95,27 +178,27 @@ load_model()
  with gr.Blocks(title="Math Solution Classifier") as demo:
      gr.Markdown("# 🧮 Math Solution Classifier")
      gr.Markdown(
-         "Classify a student’s math solution as **correct**, **conceptually flawed**, "
          "or **computationally flawed**."
      )

      with gr.Row():
          with gr.Column():
-             q_in = gr.Textbox(label="Math Question", lines=3)
-             s_in = gr.Textbox(label="Proposed Solution", lines=6)
-             btn = gr.Button("Classify", variant="primary")
          with gr.Column():
              verdict = gr.Textbox(label="Verdict", interactive=False)
-             conf = gr.Textbox(label="Confidence", interactive=False)
-             expl = gr.Textbox(label="Explanation", interactive=False)

-     btn.click(classify, [q_in, s_in], [verdict, conf, expl])

      gr.Examples(
          [
              ["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 8\nx = 4"],
              ["Find the derivative of f(x)=x²", "f'(x)=2x+1"],
-             ["What is 15 % of 200?", "0.15 × 200 = 30"],
          ],
          inputs=[q_in, s_in],
      )
 
  # app.py ── Math-solution classifier for HF Spaces
+ # Compatible with both LoRA-classification and LoRA-causal-LM adapters
+ # Requirements (pin in requirements.txt):
+ #   gradio torch transformers peft accelerate spaces

  import os
+ import json
  import logging
+ from typing import Tuple

  import gradio as gr
  import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSequenceClassification,
+ )

+ # PEFT imports (optional)
  try:
+     from peft.auto import (
+         AutoPeftModelForSequenceClassification,
+         AutoPeftModelForCausalLM,
+     )
      PEFT_AVAILABLE = True
+ except ImportError:  # PEFT not installed
      PEFT_AVAILABLE = False

  # ──────────────────────────────────────────────────────────────────────────
 
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter")  # local dir or Hub ID
  FALLBACK_MODEL = "distilbert-base-uncased"
+ LABELS = {0: "✅ Correct",
+           1: "🤔 Conceptual Error",
+           2: "🔢 Computational Error"}

  device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model = None
  tokenizer = None
+ model_ty = None  # "classification" | "causal_lm" | "baseline"

  # ──────────────────────────────────────────────────────────────────────────
+ # Model loader
  # ──────────────────────────────────────────────────────────────────────────
  def load_model():
+     """Try adapter as classifier → causal-LM → plain baseline."""
+     global model, tokenizer, model_ty
+
+     dtype = torch.float16 if device == "cuda" else torch.float32

      if PEFT_AVAILABLE and os.path.isdir(ADAPTER_PATH):
+         logger.info(f"Found adapter at {ADAPTER_PATH}")
+
+         # 1) Try sequence-classification adapter
+         try:
+             model = AutoPeftModelForSequenceClassification.from_pretrained(
+                 ADAPTER_PATH,
+                 torch_dtype=dtype,
+                 device_map="auto" if device == "cuda" else None,
+             )
+             model_ty = "classification"
+             logger.info("Loaded adapter as sequence-classifier")
+         except ValueError:
+             # 2) Fall back to causal-LM adapter
+             logger.info("Adapter is not a classifier – trying causal-LM")
+             model = AutoPeftModelForCausalLM.from_pretrained(
+                 ADAPTER_PATH,
+                 torch_dtype=dtype,
+                 device_map="auto" if device == "cuda" else None,
+             )
+             model_ty = "causal_lm"
+             logger.info("Loaded adapter as causal-LM")
+
          tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
+
      else:
+         logger.warning("No adapter found – using baseline DistilBERT classifier")
          tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
+         model = AutoModelForSequenceClassification.from_pretrained(
              FALLBACK_MODEL,
              num_labels=3,
              ignore_mismatched_sizes=True,
+             torch_dtype=dtype,
          )
+         model_ty = "baseline"

+     # Make sure we have a pad token
      if tokenizer.pad_token is None:
          tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token

      model.to(device)
      model.eval()
+     logger.info(f"Model ready on {device} as {model_ty}")
99
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€๏ฟฝ๏ฟฝ๏ฟฝโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
100
+ # Inference helpers
101
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
102
+ def _classify_logits(question: str, solution: str) -> Tuple[str, str, str]:
 
 
 
 
103
  text = f"Question: {question}\n\nSolution:\n{solution}"
104
  inputs = tokenizer(
105
  text,
 
111
 
112
  with torch.no_grad():
113
  logits = model(**inputs).logits
114
+ probs = torch.softmax(logits, dim=-1)[0]
115
+ pred = int(torch.argmax(probs))
116
+ conf = f"{probs[pred].item():.3f}"
117
+
118
+ return LABELS.get(pred, "Unknown"), conf, "โ€”"
119
+
120
+ def _classify_generate(question: str, solution: str) -> Tuple[str, str, str]:
121
+ # Prompt must match the format you used in tuning
122
+ prompt = (
123
+ "You are a mathematics tutor.\n"
124
+ "You are given a math word problem and a student's solution. Decide whether the solution is correct.\n\n"
125
+ "- Correct = all reasoning and calculations are correct.\n"
126
+ "- Conceptual Error = reasoning is wrong.\n"
127
+ "- Computational Error = reasoning okay but arithmetic off.\n\n"
128
+ "Reply with ONLY one of these JSON lines:\n"
129
+ '{"verdict": "correct"}\n'
130
+ '{"verdict": "conceptual"}\n'
131
+ '{"verdict": "computational"}\n\n"
132
+ f"Question: {question}\n\nSolution:\n{solution}\n\nAnswer:"
133
+ )
134
+
135
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
136
+ with torch.no_grad():
137
+ out_ids = model.generate(
138
+ **inputs,
139
+ max_new_tokens=32,
140
+ pad_token_id=tokenizer.eos_token_id,
141
+ )
142
+ generated = tokenizer.decode(out_ids[0][inputs["input_ids"].shape[1]:],
143
+ skip_special_tokens=True).strip()
144
+
145
+ # Try to parse last JSON line
146
+ verdict = "Unparsed"
147
+ try:
148
+ line = generated.splitlines()[-1]
149
+ data = json.loads(line)
150
+ v = data.get("verdict", "").lower()
151
+ if v.startswith("corr"):
152
+ verdict = LABELS[0]
153
+ elif v.startswith("conc"):
154
+ verdict = LABELS[1]
155
+ elif v.startswith("comp"):
156
+ verdict = LABELS[2]
157
+ except Exception:
158
+ pass
159
+
160
+ return verdict, "", generated
161
 
162
+ def classify(question: str, solution: str):
163
+ if not question.strip() or not solution.strip():
164
+ return "Please enter both fields.", "", ""
165
+
166
+ if model_ty in ("classification", "baseline"):
167
+ return _classify_logits(question, solution)
168
+ elif model_ty == "causal_lm":
169
+ return _classify_generate(question, solution)
170
+ else:
171
+ return "Model not loaded.", "", ""
172
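
For reference, the verdict parser in _classify_generate inspects only the last line of the generation. A quick check of that logic with a made-up model output:

import json

generated = 'Reasoning is fine, arithmetic is off.\n{"verdict": "computational"}'
line = generated.splitlines()[-1]    # only the last line is parsed
print(json.loads(line)["verdict"])   # -> computational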
 
  # ──────────────────────────────────────────────────────────────────────────
  # Build Gradio UI

  with gr.Blocks(title="Math Solution Classifier") as demo:
      gr.Markdown("# 🧮 Math Solution Classifier")
      gr.Markdown(
+         "Classify a student's math solution as **correct**, **conceptually flawed**, "
          "or **computationally flawed**."
      )

      with gr.Row():
          with gr.Column():
+             q_in = gr.Textbox(label="Math Question", lines=3)
+             s_in = gr.Textbox(label="Proposed Solution", lines=6)
+             btn = gr.Button("Classify", variant="primary")
          with gr.Column():
              verdict = gr.Textbox(label="Verdict", interactive=False)
+             conf = gr.Textbox(label="Confidence", interactive=False)
+             raw = gr.Textbox(label="Model Output", interactive=False)

+     btn.click(classify, [q_in, s_in], [verdict, conf, raw])

      gr.Examples(
          [
              ["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 8\nx = 4"],
              ["Find the derivative of f(x)=x²", "f'(x)=2x+1"],
+             ["What is 15% of 200?", "0.15 × 200 = 30"],
          ],
          inputs=[q_in, s_in],
      )
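
The new header comment asks for dependencies pinned in requirements.txt; one possible file, with exact version pins left to whatever was actually tested:

gradio
torch
transformers
peft
accelerate
spaces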