mcamargo00 committed
Commit 4c7dba1 · verified · 1 Parent(s): 7ab7841

Upload app.py

Files changed (1)
  1. app.py +94 -194
app.py CHANGED
@@ -1,224 +1,124 @@
- # app.py - Gradio version (much simpler for HF Spaces)

  import gradio as gr
  import torch
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
- import logging

- # Set up logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- # Global variables for model and tokenizer
  model = None
  tokenizer = None
- label_mapping = {0: "✅ Correct", 1: "🤔 Conceptually Flawed", 2: "🔢 Computationally Flawed"}

  def load_model():
-     """Load your trained LoRA adapter with classification head"""
      global model, tokenizer
-
-     try:
-         from peft import AutoPeftModelForSequenceClassification  # Back to classification
-
-         # Load the LoRA adapter model for classification
          model = AutoPeftModelForSequenceClassification.from_pretrained(
-             "./lora_adapter",  # Path to your adapter files
-             torch_dtype=torch.float16,
-             device_map="auto"
          )
-
-         # Load tokenizer from the same directory
-         tokenizer = AutoTokenizer.from_pretrained("./lora_adapter")
-
-         # Fix padding token issue
-         if tokenizer.pad_token is None:
-             tokenizer.pad_token = tokenizer.eos_token
-             logger.info("Set pad_token to eos_token")
-
-         logger.info("LoRA classification model loaded successfully")
-         return "LoRA classification model loaded successfully!"
-
-     except Exception as e:
-         logger.error(f"Error loading LoRA model: {e}")
-         # Fallback to placeholder for testing
-         logger.warning("Using placeholder model loading - replace with your actual model!")
-
-         model_name = "distilbert-base-uncased"  # Simple fallback
-         tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-         # Fix padding token for fallback model too
-         if tokenizer.pad_token is None:
-             tokenizer.pad_token = tokenizer.eos_token
-
-         from transformers import AutoModelForSequenceClassification
          model = AutoModelForSequenceClassification.from_pretrained(
-             model_name,
              num_labels=3,
-             ignore_mismatched_sizes=True
          )
-
-         return f"Fallback model loaded. LoRA error: {e}"
-
- def get_system_prompt():
-     """Generates the specific system prompt for the fine-tuning task."""
-     return """You are a mathematics tutor.
- You are given a math word problem, and a solution written by a student.
- Analyze the solution carefully, line-by-line, and classify it into one of the following categories:
- - Correct (All logic is correct, and all calculations are correct)
- - Conceptual Error (There is an error in reasoning or logic somewhere in the solution)
- - Computational Error (All logic and reasoning is correct, but the result of some calculation is incorrect)
- Respond *only* with a valid JSON object that follows this exact schema:
- ```json
- {
-     "verdict": "must be one of 'correct', 'conceptual_error', or 'computational_error'",
-     "erroneous_line": "the exact, verbatim text of the first incorrect line, or null if the verdict is 'correct'",
-     "explanation": "a brief, one-sentence explanation of the error, or null if the verdict is 'correct'"
- }
- ```
- Do NOT add any text or explanations before or after the JSON object.
- """
-
- # Add this import at the top
- import spaces
-
- # Add this decorator to the classify function
- @spaces.GPU
- def classify_solution(question: str, solution: str):
-     """
-     Classify the math solution using the exact training format
-     Returns: (classification_label, confidence_score, explanation)
-     """
      if not question.strip() or not solution.strip():
-         return "Please fill in both fields", "", ""
-
-     if not model or not tokenizer:
-         return "Model not loaded", "", ""
-
-     try:
-         # Create the exact prompt format used in training
-         system_prompt = get_system_prompt()
-         user_message = f"Problem: {question}\n\nSolution:\n{solution}"
-
-         # Format as chat messages (common for instruction-tuned models)
-         messages = [
-             {"role": "system", "content": system_prompt},
-             {"role": "user", "content": user_message}
-         ]
-
-         # Apply chat template
-         text_input = tokenizer.apply_chat_template(
-             messages,
-             tokenize=False,
-             add_generation_token=True
-         )
-
-         # Tokenize input
-         inputs = tokenizer(
-             text_input,
-             return_tensors="pt",
-             truncation=True,
-             padding=True,
-             max_length=2048  # Increased for longer prompts
-         )
-
-         # Generate response with CPU optimization
-         with torch.no_grad():
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=150,  # Reduced from 200
-                 temperature=0.1,
-                 do_sample=False,  # Faster greedy decoding
-                 pad_token_id=tokenizer.pad_token_id,
-                 use_cache=True  # Speed up generation
-             )
-
-         # Decode the generated response
-         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-         # Extract just the JSON response (after the input)
-         response_start = generated_text.find(text_input) + len(text_input)
-         json_response = generated_text[response_start:].strip()
-
-         # Parse the JSON response
-         import json
-         try:
-             result = json.loads(json_response)
-             verdict = result.get("verdict", "unknown")
-             erroneous_line = result.get("erroneous_line", "")
-             explanation = result.get("explanation", "")
-
-             # Map verdict to display format
-             verdict_mapping = {
-                 "correct": "✅ Correct",
-                 "conceptual_error": "🤔 Conceptual Error",
-                 "computational_error": "🔢 Computational Error"
-             }
-
-             display_verdict = verdict_mapping.get(verdict, f"❓ {verdict}")
-
-             return display_verdict, erroneous_line or "None", explanation or "Solution is correct"
-
-         except json.JSONDecodeError:
-             return f"Model response: {json_response}", "", "Could not parse JSON response"
-
-     except Exception as e:
-         logger.error(f"Error during classification: {e}")
-         return f"Classification error: {str(e)}", "", ""
-
- # Load model on startup
  load_model()

- # Create Gradio interface
- with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
      gr.Markdown("# 🧮 Math Solution Classifier")
-     gr.Markdown("Classify math solutions as correct, conceptually flawed, or computationally flawed.")
-
      with gr.Row():
          with gr.Column():
-             question_input = gr.Textbox(
-                 label="Math Question",
-                 placeholder="e.g., Solve for x: 2x + 5 = 13",
-                 lines=3
-             )
-
-             solution_input = gr.Textbox(
-                 label="Proposed Solution",
-                 placeholder="e.g., 2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
-                 lines=5
-             )
-
-             classify_btn = gr.Button("Classify Solution", variant="primary")
-
          with gr.Column():
-             classification_output = gr.Textbox(label="Classification", interactive=False)
-             confidence_output = gr.Textbox(label="Confidence", interactive=False)
-             explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)
-
-     # Examples
      gr.Examples(
-         examples=[
-             [
-                 "Solve for x: 2x + 5 = 13",
-                 "2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4"
-             ],
-             [
-                 "Find the derivative of f(x) = x²",
-                 "f'(x) = 2x + 1"  # This should be computationally flawed
-             ],
-             [
-                 "What is 15% of 200?",
-                 "15% = 15/100 = 0.15\n0.15 × 200 = 30"
-             ]
          ],
-         inputs=[question_input, solution_input]
-     )
-
-     classify_btn.click(
-         fn=classify_solution,
-         inputs=[question_input, solution_input],
-         outputs=[classification_output, confidence_output, explanation_output]
-     )

  if __name__ == "__main__":
-     app.launch()
+ # app.py ── Math-solution classifier for HF Spaces
+ # Requires: gradio, torch, transformers, peft, accelerate, spaces
+
+ import os
+ import logging

  import gradio as gr
  import torch
  from transformers import AutoTokenizer, AutoModelForSequenceClassification

+ # Optional PEFT import (only available if you include it in requirements.txt)
+ try:
+     from peft import AutoPeftModelForSequenceClassification
+     PEFT_AVAILABLE = True
+ except ImportError:
+     PEFT_AVAILABLE = False
+
+ # ──────────────────────────────────────────────────────────────────────
+ # Config & logging
+ # ──────────────────────────────────────────────────────────────────────
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ ADAPTER_PATH = os.getenv("ADAPTER_PATH", "./lora_adapter")  # local dir or Hub ID
+ FALLBACK_MODEL = "distilbert-base-uncased"
+ LABELS = {0: "✅ Correct",
+           1: "🤔 Conceptual Error",
+           2: "🔢 Computational Error"}
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
  model = None
  tokenizer = None

+ # ──────────────────────────────────────────────────────────────────────
+ # Load model & tokenizer
+ # ──────────────────────────────────────────────────────────────────────
  def load_model():
+     """Load the LoRA adapter if present, otherwise a baseline classifier."""
      global model, tokenizer
+
+     if PEFT_AVAILABLE and os.path.isdir(ADAPTER_PATH):
+         logger.info(f"Loading LoRA adapter from {ADAPTER_PATH}")
          model = AutoPeftModelForSequenceClassification.from_pretrained(
+             ADAPTER_PATH,
+             torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+             device_map="auto" if device == "cuda" else None,
          )
+         tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
+     else:
+         logger.warning("LoRA adapter not found – falling back to baseline model")
+         tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
          model = AutoModelForSequenceClassification.from_pretrained(
+             FALLBACK_MODEL,
              num_labels=3,
+             ignore_mismatched_sizes=True,
          )
+
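+     # GPT-style tokenizers ship without a pad token; reuse EOS (or SEP for BERT-family models)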
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token
+
+     if not getattr(model, "hf_device_map", None):  # .to() would fail on an accelerate-dispatched model
+         model.to(device)
+     model.eval()
+     logger.info("Model & tokenizer ready")
+
+ # ──────────────────────────────────────────────────────────────────────
+ # Inference helper
+ # ──────────────────────────────────────────────────────────────────────
+ def classify(question: str, solution: str):
+     """Return (label, confidence, placeholder explanation)."""
      if not question.strip() or not solution.strip():
+         return "Please provide both question and solution.", "", ""
+
+     text = f"Question: {question}\n\nSolution:\n{solution}"
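+     # Single flattened input sequence; raise max_length if your problems run longer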
+     inputs = tokenizer(
+         text,
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+         max_length=512,
+     ).to(device)
+
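+     # One forward pass through the classification head; no autoregressive generation needed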
+     with torch.no_grad():
+         logits = model(**inputs).logits
+     probs = torch.softmax(logits, dim=-1)[0]
+     pred = int(torch.argmax(probs))
+     confidence = f"{probs[pred].item():.3f}"
+
+     return LABELS.get(pred, "Unknown"), confidence, "—"
+
+ # ──────────────────────────────────────────────────────────────────────
+ # Build Gradio UI
+ # ──────────────────────────────────────────────────────────────────────
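+ # Load once at startup so the first request doesn't pay the model-loading cost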
  load_model()

+ with gr.Blocks(title="Math Solution Classifier") as demo:
      gr.Markdown("# 🧮 Math Solution Classifier")
+     gr.Markdown(
+         "Classify a student's math solution as **correct**, **conceptually flawed**, "
+         "or **computationally flawed**."
+     )
+
      with gr.Row():
          with gr.Column():
+             q_in = gr.Textbox(label="Math Question", lines=3)
+             s_in = gr.Textbox(label="Proposed Solution", lines=6)
+             btn = gr.Button("Classify", variant="primary")
          with gr.Column():
+             verdict = gr.Textbox(label="Verdict", interactive=False)
+             conf = gr.Textbox(label="Confidence", interactive=False)
+             expl = gr.Textbox(label="Explanation", interactive=False)
+
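+     # Wire the button: (question, solution) in → (verdict, confidence, explanation) out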
+     btn.click(classify, [q_in, s_in], [verdict, conf, expl])
+
      gr.Examples(
+         [
+             ["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 8\nx = 4"],
+             ["Find the derivative of f(x)=x²", "f'(x)=2x+1"],
+             ["What is 15% of 200?", "0.15 × 200 = 30"],
          ],
+         inputs=[q_in, s_in],
      )

  if __name__ == "__main__":
+     demo.launch()