mcamargo00 commited on
Commit
d8899dd
ยท
verified ยท
1 Parent(s): 9995215

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -91
app.py CHANGED
@@ -1,7 +1,7 @@
1
- # app.py โ”€โ”€ Math-solution classifier for HF Spaces
2
- # Compatible with both LoRA-classification and LoRA-causal-LM adapters
3
- # Requirements (pin in requirements.txt):
4
- # gradio torch transformers peft accelerate spaces
5
 
6
  import os
7
  import json
@@ -9,107 +9,83 @@ import logging
9
  from typing import Tuple
10
 
11
  import gradio as gr
12
- import torch
13
- import spaces
14
-
15
- from transformers import (
16
- AutoTokenizer,
17
- AutoModelForSequenceClassification,
18
- )
19
-
20
- # PEFT imports (optional)
21
- try:
22
- from peft.auto import (
23
- AutoPeftModelForSequenceClassification,
24
- AutoPeftModelForCausalLM,
25
- )
26
- PEFT_AVAILABLE = True
27
- except ImportError: # PEFT not installed
28
- PEFT_AVAILABLE = False
29
 
30
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
31
- # Config & logging
32
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
33
- logging.basicConfig(level=logging.INFO)
34
- logger = logging.getLogger(__name__)
35
-
36
- ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter") # local dir or Hub ID
37
  FALLBACK_MODEL = "distilbert-base-uncased"
38
  LABELS = {0: "โœ… Correct",
39
  1: "๐Ÿค” Conceptual Error",
40
  2: "๐Ÿ”ข Computational Error"}
41
 
42
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
43
 
 
44
  model = None
45
  tokenizer = None
46
  model_ty = None # "classification" | "causal_lm" | "baseline"
47
 
48
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
49
- # Model loader
50
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
51
- def load_model():
52
- """Try adapter as classifier โ†’ causal-LM โ†’ plain baseline."""
 
 
 
53
  global model, tokenizer, model_ty
 
 
 
 
 
 
 
 
 
54
 
55
- dtype = torch.float16 if device == "cuda" else torch.float32
56
-
57
- if PEFT_AVAILABLE and os.path.isdir(ADAPTER_PATH):
58
- logger.info(f"Found adapter at {ADAPTER_PATH}")
59
 
60
- # 1) Try sequence-classification adapter
61
- try:
62
  model = AutoPeftModelForSequenceClassification.from_pretrained(
63
- ADAPTER_PATH,
64
- torch_dtype=dtype,
65
- device_map="auto" if device == "cuda" else None,
66
  )
67
  model_ty = "classification"
68
- logger.info("Loaded adapter as sequence-classifier")
69
  except ValueError:
70
- # 2) Fall back to causal-LM adapter
71
- logger.info("Adapter is not a classifier โ€“ trying causal-LM")
72
  model = AutoPeftModelForCausalLM.from_pretrained(
73
- ADAPTER_PATH,
74
- torch_dtype=dtype,
75
- device_map="auto" if device == "cuda" else None,
76
  )
77
  model_ty = "causal_lm"
78
- logger.info("Loaded adapter as causal-LM")
79
 
80
  tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
81
 
82
  else:
83
- logger.warning("No adapter found โ€“ using baseline DistilBERT classifier")
84
  tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
85
  model = AutoModelForSequenceClassification.from_pretrained(
86
- FALLBACK_MODEL,
87
- num_labels=3,
88
- ignore_mismatched_sizes=True,
89
- torch_dtype=dtype,
90
  )
91
  model_ty = "baseline"
92
 
93
- # Make sure we have a pad token
94
  if tokenizer.pad_token is None:
95
  tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token
96
 
97
- model.to(device)
98
  model.eval()
99
- logger.info(f"Model ready on {device} as {model_ty}")
 
100
 
101
- # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
102
- # Inference helpers
103
- # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
104
  def _classify_logits(question: str, solution: str) -> Tuple[str, str, str]:
 
105
  text = f"Question: {question}\n\nSolution:\n{solution}"
106
  inputs = tokenizer(
107
- text,
108
- return_tensors="pt",
109
- padding=True,
110
- truncation=True,
111
- max_length=512,
112
- ).to(device)
113
 
114
  with torch.no_grad():
115
  logits = model(**inputs).logits
@@ -117,14 +93,14 @@ def _classify_logits(question: str, solution: str) -> Tuple[str, str, str]:
117
  pred = int(torch.argmax(probs))
118
  conf = f"{probs[pred].item():.3f}"
119
 
120
- return LABELS.get(pred, "Unknown"), conf, "โ€”"
 
121
 
122
  def _classify_generate(question: str, solution: str) -> Tuple[str, str, str]:
123
- # Prompt must match the format you used in tuning
124
  prompt = (
125
  "You are a mathematics tutor.\n"
126
- "You are given a math word problem and a student's solution. "
127
- "Decide whether the solution is correct.\n\n"
128
  "- Correct = all reasoning and calculations are correct.\n"
129
  "- Conceptual Error = reasoning is wrong.\n"
130
  "- Computational Error= reasoning okay but arithmetic off.\n\n"
@@ -135,8 +111,7 @@ def _classify_generate(question: str, solution: str) -> Tuple[str, str, str]:
135
  f"Question: {question}\n\nSolution:\n{solution}\n\nAnswer:"
136
  )
137
 
138
-
139
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
140
  with torch.no_grad():
141
  out_ids = model.generate(
142
  **inputs,
@@ -148,12 +123,9 @@ def _classify_generate(question: str, solution: str) -> Tuple[str, str, str]:
148
  skip_special_tokens=True,
149
  ).strip()
150
 
151
-
152
- # Try to parse last JSON line
153
  verdict = "Unparsed"
154
  try:
155
- line = generated.splitlines()[-1]
156
- data = json.loads(line)
157
  v = data.get("verdict", "").lower()
158
  if v.startswith("corr"):
159
  verdict = LABELS[0]
@@ -166,21 +138,30 @@ def _classify_generate(question: str, solution: str) -> Tuple[str, str, str]:
166
 
167
  return verdict, "", generated
168
 
169
- def classify(question: str, solution: str):
 
 
 
 
 
 
 
 
 
170
  if not question.strip() or not solution.strip():
171
- return "Please enter both fields.", "", ""
172
 
173
  if model_ty in ("classification", "baseline"):
174
  return _classify_logits(question, solution)
175
- elif model_ty == "causal_lm":
176
  return _classify_generate(question, solution)
177
- else:
178
- return "Model not loaded.", "", ""
179
 
180
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
181
- # Build Gradio UI
182
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
183
- load_model()
 
 
184
 
185
  with gr.Blocks(title="Math Solution Classifier") as demo:
186
  gr.Markdown("# ๐Ÿงฎ Math Solution Classifier")
@@ -191,15 +172,15 @@ with gr.Blocks(title="Math Solution Classifier") as demo:
191
 
192
  with gr.Row():
193
  with gr.Column():
194
- q_in = gr.Textbox(label="Math Question", lines=3)
195
- s_in = gr.Textbox(label="Proposed Solution", lines=6)
196
- btn = gr.Button("Classify", variant="primary")
197
  with gr.Column():
198
  verdict = gr.Textbox(label="Verdict", interactive=False)
199
  conf = gr.Textbox(label="Confidence", interactive=False)
200
  raw = gr.Textbox(label="Model Output", interactive=False)
201
 
202
- btn.click(classify, [q_in, s_in], [verdict, conf, raw])
203
 
204
  gr.Examples(
205
  [
@@ -210,11 +191,6 @@ with gr.Blocks(title="Math Solution Classifier") as demo:
210
  inputs=[q_in, s_in],
211
  )
212
 
213
-
214
- @spaces.GPU # or @spaces.CPU if you deploy on CPU
215
-
216
  def launch_app():
217
- return demo # the Gradio Blocks object you built
218
-
219
- if __name__ == "__main__":
220
- demo.launch()
 
1
+ # app.py โ”€โ”€ Math-solution classifier on HF Spaces (Zero-GPU-safe)
2
+ #
3
+ # Pin in requirements.txt:
4
+ # gradio==4.44.0 torch==2.1.0 transformers==4.35.0 peft==0.7.1 accelerate==0.25.0 spaces
5
 
6
  import os
7
  import json
 
9
  from typing import Tuple
10
 
11
  import gradio as gr
12
+ import spaces # <- Hugging Face Spaces SDK (Zero)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
15
+ # CONSTANTS (no CUDA use here)
16
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
17
+ ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter") # dir or Hub repo
 
 
 
18
  FALLBACK_MODEL = "distilbert-base-uncased"
19
  LABELS = {0: "โœ… Correct",
20
  1: "๐Ÿค” Conceptual Error",
21
  2: "๐Ÿ”ข Computational Error"}
22
 
23
+ logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger(__name__)
25
 
26
+ # Globals that will live **inside the GPU worker**
27
  model = None
28
  tokenizer = None
29
  model_ty = None # "classification" | "causal_lm" | "baseline"
30
 
31
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
32
+ # GPU-SIDE INITIALISATION & INFERENCE
33
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
34
+ def _load_model_gpu():
35
+ """
36
+ Runs **inside the GPU worker**.
37
+ Tries LoRA classification adapter โ†’ LoRA causal-LM adapter โ†’ plain baseline.
38
+ """
39
  global model, tokenizer, model_ty
40
+ import torch
41
+ from transformers import (
42
+ AutoTokenizer,
43
+ AutoModelForSequenceClassification,
44
+ )
45
+ from peft.auto import (
46
+ AutoPeftModelForSequenceClassification,
47
+ AutoPeftModelForCausalLM,
48
+ )
49
 
50
+ dtype = torch.float16
51
+ if os.path.isdir(ADAPTER_PATH):
52
+ logger.info(f"[GPU] Loading adapter from {ADAPTER_PATH}")
 
53
 
54
+ try: # 1) classification adapter
 
55
  model = AutoPeftModelForSequenceClassification.from_pretrained(
56
+ ADAPTER_PATH, torch_dtype=dtype, device_map="auto"
 
 
57
  )
58
  model_ty = "classification"
 
59
  except ValueError:
60
+ logger.info("[GPU] Not a classifier, trying causal-LM")
 
61
  model = AutoPeftModelForCausalLM.from_pretrained(
62
+ ADAPTER_PATH, torch_dtype=dtype, device_map="auto"
 
 
63
  )
64
  model_ty = "causal_lm"
 
65
 
66
  tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
67
 
68
  else:
69
+ logger.warning("[GPU] No adapter found โ€“ using baseline DistilBERT")
70
  tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
71
  model = AutoModelForSequenceClassification.from_pretrained(
72
+ FALLBACK_MODEL, num_labels=3, ignore_mismatched_sizes=True
 
 
 
73
  )
74
  model_ty = "baseline"
75
 
 
76
  if tokenizer.pad_token is None:
77
  tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token
78
 
 
79
  model.eval()
80
+ logger.info(f"[GPU] Model ready ({model_ty})")
81
+
82
 
 
 
 
83
  def _classify_logits(question: str, solution: str) -> Tuple[str, str, str]:
84
+ import torch
85
  text = f"Question: {question}\n\nSolution:\n{solution}"
86
  inputs = tokenizer(
87
+ text, return_tensors="pt", padding=True, truncation=True, max_length=512
88
+ ).to("cuda")
 
 
 
 
89
 
90
  with torch.no_grad():
91
  logits = model(**inputs).logits
 
93
  pred = int(torch.argmax(probs))
94
  conf = f"{probs[pred].item():.3f}"
95
 
96
+ return LABELS[pred], conf, "โ€”"
97
+
98
 
99
  def _classify_generate(question: str, solution: str) -> Tuple[str, str, str]:
100
+ import torch
101
  prompt = (
102
  "You are a mathematics tutor.\n"
103
+ "You are given a math word problem and a student's solution. Decide whether the solution is correct.\n\n"
 
104
  "- Correct = all reasoning and calculations are correct.\n"
105
  "- Conceptual Error = reasoning is wrong.\n"
106
  "- Computational Error= reasoning okay but arithmetic off.\n\n"
 
111
  f"Question: {question}\n\nSolution:\n{solution}\n\nAnswer:"
112
  )
113
 
114
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
 
115
  with torch.no_grad():
116
  out_ids = model.generate(
117
  **inputs,
 
123
  skip_special_tokens=True,
124
  ).strip()
125
 
 
 
126
  verdict = "Unparsed"
127
  try:
128
+ data = json.loads(generated.splitlines()[-1])
 
129
  v = data.get("verdict", "").lower()
130
  if v.startswith("corr"):
131
  verdict = LABELS[0]
 
138
 
139
  return verdict, "", generated
140
 
141
+
142
+ @spaces.GPU # <-- every CUDA op happens inside here
143
+ def gpu_classify(question: str, solution: str):
144
+ """
145
+ Proxy target for Gradio. Executed in the GPU worker so CUDA is allowed.
146
+ Returns (verdict, confidence, raw_output)
147
+ """
148
+ if model is None:
149
+ _load_model_gpu()
150
+
151
  if not question.strip() or not solution.strip():
152
+ return "Please fill both fields.", "", ""
153
 
154
  if model_ty in ("classification", "baseline"):
155
  return _classify_logits(question, solution)
156
+ else: # causal_lm
157
  return _classify_generate(question, solution)
 
 
158
 
159
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
160
+ # CPU-SIDE UI (no torch.cuda here)
161
  # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
162
+ def classify_proxy(q, s):
163
+ """Simple wrapper so Gradio can call the GPU function."""
164
+ return gpu_classify(q, s)
165
 
166
  with gr.Blocks(title="Math Solution Classifier") as demo:
167
  gr.Markdown("# ๐Ÿงฎ Math Solution Classifier")
 
172
 
173
  with gr.Row():
174
  with gr.Column():
175
+ q_in = gr.Textbox(label="Math Question", lines=3)
176
+ s_in = gr.Textbox(label="Proposed Solution", lines=6)
177
+ btn = gr.Button("Classify", variant="primary")
178
  with gr.Column():
179
  verdict = gr.Textbox(label="Verdict", interactive=False)
180
  conf = gr.Textbox(label="Confidence", interactive=False)
181
  raw = gr.Textbox(label="Model Output", interactive=False)
182
 
183
+ btn.click(classify_proxy, [q_in, s_in], [verdict, conf, raw])
184
 
185
  gr.Examples(
186
  [
 
191
  inputs=[q_in, s_in],
192
  )
193
 
194
+ @spaces.CPU # UI served from the CPU worker
 
 
195
  def launch_app():
196
+ return demo