mcamargo00 commited on
Commit
63ca988
·
verified ·
1 Parent(s): 9f84309

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -15,18 +15,17 @@ tokenizer = None
15
  label_mapping = {0: "✅ Correct", 1: "🤔 Conceptually Flawed", 2: "🔢 Computationally Flawed"}
16
 
17
  def load_model():
18
- """Load your trained LoRA adapter with base model"""
19
  global model, tokenizer
20
 
21
  try:
22
- from peft import AutoPeftModelForCausalLM # Changed from SequenceClassification
23
 
24
- # Load the LoRA adapter model for text generation
25
- model = AutoPeftModelForCausalLM.from_pretrained(
26
  "./lora_adapter", # Path to your adapter files
27
- torch_dtype=torch.float32, # Use float32 for CPU
28
- device_map="cpu", # Force CPU
29
- low_cpu_mem_usage=True # Optimize for low memory
30
  )
31
 
32
  # Load tokenizer from the same directory
@@ -37,23 +36,27 @@ def load_model():
37
  tokenizer.pad_token = tokenizer.eos_token
38
  logger.info("Set pad_token to eos_token")
39
 
40
- logger.info("LoRA model loaded successfully")
41
- return "LoRA model loaded successfully!"
42
 
43
  except Exception as e:
44
  logger.error(f"Error loading LoRA model: {e}")
45
  # Fallback to placeholder for testing
46
  logger.warning("Using placeholder model loading - replace with your actual model!")
47
 
48
- model_name = "microsoft/DialoGPT-medium" # Closer to Phi-4 architecture
49
  tokenizer = AutoTokenizer.from_pretrained(model_name)
50
 
51
  # Fix padding token for fallback model too
52
  if tokenizer.pad_token is None:
53
  tokenizer.pad_token = tokenizer.eos_token
54
 
55
- from transformers import AutoModelForCausalLM
56
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
57
 
58
  return f"Fallback model loaded. LoRA error: {e}"
59
 
@@ -189,7 +192,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
189
 
190
  with gr.Column():
191
  classification_output = gr.Textbox(label="Classification", interactive=False)
192
- erroneous_line_output = gr.Textbox(label="Erroneous Line", interactive=False)
193
  explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)
194
 
195
  # Examples
@@ -214,7 +217,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
214
  classify_btn.click(
215
  fn=classify_solution,
216
  inputs=[question_input, solution_input],
217
- outputs=[classification_output, erroneous_line_output, explanation_output]
218
  )
219
 
220
  if __name__ == "__main__":
 
15
  label_mapping = {0: "✅ Correct", 1: "🤔 Conceptually Flawed", 2: "🔢 Computationally Flawed"}
16
 
17
  def load_model():
18
+ """Load your trained LoRA adapter with classification head"""
19
  global model, tokenizer
20
 
21
  try:
22
+ from peft import AutoPeftModelForSequenceClassification # Back to classification
23
 
24
+ # Load the LoRA adapter model for classification
25
+ model = AutoPeftModelForSequenceClassification.from_pretrained(
26
  "./lora_adapter", # Path to your adapter files
27
+ torch_dtype=torch.float16,
28
+ device_map="auto"
 
29
  )
30
 
31
  # Load tokenizer from the same directory
 
36
  tokenizer.pad_token = tokenizer.eos_token
37
  logger.info("Set pad_token to eos_token")
38
 
39
+ logger.info("LoRA classification model loaded successfully")
40
+ return "LoRA classification model loaded successfully!"
41
 
42
  except Exception as e:
43
  logger.error(f"Error loading LoRA model: {e}")
44
  # Fallback to placeholder for testing
45
  logger.warning("Using placeholder model loading - replace with your actual model!")
46
 
47
+ model_name = "distilbert-base-uncased" # Simple fallback
48
  tokenizer = AutoTokenizer.from_pretrained(model_name)
49
 
50
  # Fix padding token for fallback model too
51
  if tokenizer.pad_token is None:
52
  tokenizer.pad_token = tokenizer.eos_token
53
 
54
+ from transformers import AutoModelForSequenceClassification
55
+ model = AutoModelForSequenceClassification.from_pretrained(
56
+ model_name,
57
+ num_labels=3,
58
+ ignore_mismatched_sizes=True
59
+ )
60
 
61
  return f"Fallback model loaded. LoRA error: {e}"
62
 
 
192
 
193
  with gr.Column():
194
  classification_output = gr.Textbox(label="Classification", interactive=False)
195
+ confidence_output = gr.Textbox(label="Confidence", interactive=False)
196
  explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)
197
 
198
  # Examples
 
217
  classify_btn.click(
218
  fn=classify_solution,
219
  inputs=[question_input, solution_input],
220
+ outputs=[classification_output, confidence_output, explanation_output]
221
  )
222
 
223
  if __name__ == "__main__":