Upload app.py
app.py
CHANGED
@@ -15,18 +15,17 @@ tokenizer = None
 label_mapping = {0: "✅ Correct", 1: "🤔 Conceptually Flawed", 2: "🔢 Computationally Flawed"}
 
 def load_model():
-    """Load your trained LoRA adapter with
+    """Load your trained LoRA adapter with classification head"""
     global model, tokenizer
 
     try:
-        from peft import
+        from peft import AutoPeftModelForSequenceClassification  # Back to classification
 
-        # Load the LoRA adapter model for
-        model =
+        # Load the LoRA adapter model for classification
+        model = AutoPeftModelForSequenceClassification.from_pretrained(
             "./lora_adapter",  # Path to your adapter files
-            torch_dtype=torch.
-            device_map="
-            low_cpu_mem_usage=True  # Optimize for low memory
+            torch_dtype=torch.float16,
+            device_map="auto"
         )
 
         # Load tokenizer from the same directory
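Note: AutoPeftModelForSequenceClassification.from_pretrained expects "./lora_adapter" to hold the output of save_pretrained() on a PEFT sequence-classification model (adapter_config.json plus the adapter weights). A minimal sketch of how such a directory could have been produced; the base checkpoint and LoRA hyperparameters here are illustrative assumptions, not taken from this commit:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification, AutoTokenizer

base_name = "distilbert-base-uncased"  # assumption: any seq-cls base checkpoint works
base = AutoModelForSequenceClassification.from_pretrained(base_name, num_labels=3)
peft_model = get_peft_model(base, LoraConfig(task_type=TaskType.SEQ_CLS, r=16, lora_alpha=32))
# ... fine-tune peft_model on labeled (question, solution) pairs ...
peft_model.save_pretrained("./lora_adapter")  # writes adapter_config.json + adapter weights
AutoTokenizer.from_pretrained(base_name).save_pretrained("./lora_adapter")  # tokenizer in the same dir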
@@ -37,23 +36,27 @@ def load_model():
             tokenizer.pad_token = tokenizer.eos_token
             logger.info("Set pad_token to eos_token")
 
-        logger.info("LoRA model loaded successfully")
-        return "LoRA model loaded successfully!"
+        logger.info("LoRA classification model loaded successfully")
+        return "LoRA classification model loaded successfully!"
 
     except Exception as e:
         logger.error(f"Error loading LoRA model: {e}")
         # Fallback to placeholder for testing
         logger.warning("Using placeholder model loading - replace with your actual model!")
 
-        model_name = "
+        model_name = "distilbert-base-uncased"  # Simple fallback
         tokenizer = AutoTokenizer.from_pretrained(model_name)
 
         # Fix padding token for fallback model too
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
-        from transformers import
-        model =
+        from transformers import AutoModelForSequenceClassification
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=3,
+            ignore_mismatched_sizes=True
+        )
 
         return f"Fallback model loaded. LoRA error: {e}"
 
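Whichever branch runs, load_model() leaves a 3-way sequence classifier in the module-level model and tokenizer globals. A minimal sketch of the inference step the app presumably builds on them; the prompt format and the predict helper name are assumptions, not part of this commit:

import torch

def predict(question: str, solution: str):
    """Hypothetical helper: score one (question, solution) pair."""
    text = f"Question: {question}\nSolution: {solution}"  # assumed input format
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits  # shape (1, 3): one logit per class
    probs = torch.softmax(logits, dim=-1)[0]
    idx = int(probs.argmax())
    return label_mapping[idx], f"{probs[idx].item():.1%}"  # e.g. ("✅ Correct", "92.3%")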
@@ -189,7 +192,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
 
         with gr.Column():
             classification_output = gr.Textbox(label="Classification", interactive=False)
-
+            confidence_output = gr.Textbox(label="Confidence", interactive=False)
             explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)
 
     # Examples
@@ -214,7 +217,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
     classify_btn.click(
         fn=classify_solution,
         inputs=[question_input, solution_input],
-        outputs=[classification_output,
+        outputs=[classification_output, confidence_output, explanation_output]
     )
 
 if __name__ == "__main__":
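Note: with three components wired into outputs, classify_solution must return a 3-tuple in the same order (classification, confidence, explanation). A compatible signature, sketched with the hypothetical predict() helper above and a placeholder explanation string:

def classify_solution(question: str, solution: str):
    label, confidence = predict(question, solution)  # hypothetical helper sketched earlier
    explanation = f"The model rated this solution {label} with {confidence} confidence."
    # One return value per Gradio output component, in order:
    return label, confidence, explanation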