mcamargo00 committed
Commit cecea85 · verified · 1 Parent(s): 2aa4dcf

Upload app.py

Files changed (1)
  app.py (+81 -33)
app.py CHANGED
@@ -19,11 +19,10 @@ def load_model():
    global model, tokenizer

    try:
-        from peft import AutoPeftModelForSequenceClassification
+        from peft import AutoPeftModelForCausalLM  # Changed from SequenceClassification

-        # Load the LoRA adapter model
-        # The adapter files should be in a folder (e.g., "./lora_adapter")
-        model = AutoPeftModelForSequenceClassification.from_pretrained(
+        # Load the LoRA adapter model for text generation
+        model = AutoPeftModelForCausalLM.from_pretrained(
            "./lora_adapter",  # Path to your adapter files
            torch_dtype=torch.float16,
            device_map="auto"
@@ -52,28 +51,58 @@ def load_model():
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

-        model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
-            num_labels=3,
-            ignore_mismatched_sizes=True
-        )
+        from transformers import AutoModelForCausalLM
+        model = AutoModelForCausalLM.from_pretrained(model_name)

        return f"Fallback model loaded. LoRA error: {e}"

+ def get_system_prompt():
+     """Generates the specific system prompt for the fine-tuning task."""
+     return """You are a mathematics tutor.
+     You are given a math word problem, and a solution written by a student.
+     Analyze the solution carefully, line-by-line, and classify it into one of the following categories:
+     - Correct (All logic is correct, and all calculations are correct)
+     - Conceptual Error (There is an error in reasoning or logic somewhere in the solution)
+     - Computational Error (All logic and reasoning is correct, but the result of some calculation is incorrect)
+     Respond *only* with a valid JSON object that follows this exact schema:
+     ```json
+     {
+         "verdict": "must be one of 'correct', 'conceptual_error', or 'computational_error'",
+         "erroneous_line": "the exact, verbatim text of the first incorrect line, or null if the verdict is 'correct'",
+         "explanation": "a brief, one-sentence explanation of the error, or null if the verdict is 'correct'"
+     }
+     ```
+     Do NOT add any text or explanations before or after the JSON object.
+     """
+
def classify_solution(question: str, solution: str):
    """
-    Classify the math solution
+    Classify the math solution using the exact training format
    Returns: (classification_label, confidence_score, explanation)
    """
    if not question.strip() or not solution.strip():
-        return "Please fill in both fields", 0.0, ""
+        return "Please fill in both fields", "", ""

    if not model or not tokenizer:
-        return "Model not loaded", 0.0, ""
+        return "Model not loaded", "", ""

    try:
-        # Combine question and solution for input
-        text_input = f"Question: {question}\nSolution: {solution}"
+        # Create the exact prompt format used in training
+        system_prompt = get_system_prompt()
+        user_message = f"Problem: {question}\n\nSolution:\n{solution}"
+
+        # Format as chat messages (common for instruction-tuned models)
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_message}
+        ]
+
+        # Apply chat template
+        text_input = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )

        # Tokenize input
        inputs = tokenizer(
@@ -81,32 +110,51 @@ def classify_solution(question: str, solution: str):
            return_tensors="pt",
            truncation=True,
            padding=True,
-            max_length=512
+            max_length=2048  # Increased for longer prompts
        )

-        # Get model prediction
+        # Generate response (not just classify)
        with torch.no_grad():
-            outputs = model(**inputs)
-            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
-            predicted_class = torch.argmax(predictions, dim=-1).item()
-            confidence = predictions[0][predicted_class].item()
-
-        classification = label_mapping[predicted_class]
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=200,
+                temperature=0.1,
+                do_sample=True,
+                pad_token_id=tokenizer.pad_token_id
+            )

-        # Create explanation based on classification
-        explanations = {
-            0: "The mathematical approach and calculations are both sound.",
-            1: "The approach or understanding has fundamental issues.",
-            2: "The approach is correct, but there are calculation errors."
-        }
+        # Decode only the newly generated tokens (slice off the prompt tokens)
+        new_token_ids = outputs[0][inputs["input_ids"].shape[-1]:]

-        explanation = explanations[predicted_class]
+        # Extract just the JSON response
+        generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True)
+        json_response = generated_text.strip()

-        return classification, f"{confidence:.2%}", explanation
+        # Parse the JSON response
+        import json
+        try:
+            result = json.loads(json_response)
+            verdict = result.get("verdict", "unknown")
+            erroneous_line = result.get("erroneous_line", "")
+            explanation = result.get("explanation", "")
+
+            # Map verdict to display format
+            verdict_mapping = {
+                "correct": "✅ Correct",
+                "conceptual_error": "🤔 Conceptual Error",
+                "computational_error": "🔢 Computational Error"
+            }
+
+            display_verdict = verdict_mapping.get(verdict, f"❓ {verdict}")
+
+            return display_verdict, erroneous_line or "None", explanation or "Solution is correct"
+
+        except json.JSONDecodeError:
+            return f"Model response: {json_response}", "", "Could not parse JSON response"

    except Exception as e:
        logger.error(f"Error during classification: {e}")
-        return f"Classification error: {str(e)}", "0%", ""
+        return f"Classification error: {str(e)}", "", ""

# Load model on startup
load_model()
@@ -134,7 +182,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:

        with gr.Column():
            classification_output = gr.Textbox(label="Classification", interactive=False)
-            confidence_output = gr.Textbox(label="Confidence", interactive=False)
+            erroneous_line_output = gr.Textbox(label="Erroneous Line", interactive=False)
            explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)

    # Examples
@@ -159,7 +207,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
    classify_btn.click(
        fn=classify_solution,
        inputs=[question_input, solution_input],
-        outputs=[classification_output, confidence_output, explanation_output]
+        outputs=[classification_output, erroneous_line_output, explanation_output]
    )

if __name__ == "__main__":
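
A note on the new JSON-parsing path: the system prompt shows its schema inside a ```json fence, and instruction-tuned models sometimes echo that fence around their answer, which would send an otherwise valid verdict into the json.JSONDecodeError branch above. Below is a minimal sketch, not part of this commit, of a more forgiving parser; the helper name parse_verdict and the regex fallback are illustrative assumptions rather than existing app.py code.

import json
import re

def parse_verdict(raw: str) -> dict:
    """Parse the model's verdict JSON, tolerating a Markdown code fence
    or stray text around the object."""
    text = raw.strip()
    # Strip a leading ```json / ``` fence and a trailing ``` fence if present
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text).strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the first {...} block anywhere in the output
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match:
            return json.loads(match.group(0))
        raise

# Example:
# parse_verdict('```json\n{"verdict": "correct", "erroneous_line": null, "explanation": null}\n```')
# -> {'verdict': 'correct', 'erroneous_line': None, 'explanation': None}

classify_solution could call such a helper in place of the bare json.loads(json_response) if fenced or chatty output turns out to be common in practice.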