import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Populated once by load_model() at startup.
model = None
tokenizer = None

# Class ids emitted by the classifier head, mapped to display labels.
label_mapping = {
    0: "✅ Correct",
    1: "🤔 Conceptually Flawed",
    2: "🔢 Computationally Flawed",
}

def load_model():
    """Load the trained LoRA adapter together with its base model."""
    global model, tokenizer

    try:
        from peft import AutoPeftModelForSequenceClassification

        model = AutoPeftModelForSequenceClassification.from_pretrained(
            "./lora_adapter",
            torch_dtype=torch.float16,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained("./lora_adapter")
        model.eval()

        logger.info("LoRA model loaded successfully")
        return "LoRA model loaded successfully!"

    except Exception as e:
        logger.error(f"Error loading LoRA model: {e}")
        logger.warning("Using placeholder model loading - replace with your actual model!")

        # Fallback: a generic base model with a freshly initialized 3-way
        # classification head. Its predictions are meaningless until
        # fine-tuned; it only keeps the UI usable for testing.
        model_name = "microsoft/DialoGPT-medium"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # GPT-2-style tokenizers ship without a pad token, which would
            # break tokenization with padding=True below.
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=3,
            ignore_mismatched_sizes=True,
        )
        model.config.pad_token_id = tokenizer.pad_token_id
        model.eval()

        return f"Fallback model loaded. LoRA error: {e}"

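# For reference, a minimal sketch of the explicit two-step load that
# AutoPeftModelForSequenceClassification performs internally — handy on older
# peft releases that lack the Auto class. It assumes ./lora_adapter contains an
# adapter_config.json whose base_model_name_or_path is resolvable, and that the
# classification head was trained with the same 3 labels.
def load_model_manually():
    from peft import PeftConfig, PeftModel

    peft_config = PeftConfig.from_pretrained("./lora_adapter")
    base = AutoModelForSequenceClassification.from_pretrained(
        peft_config.base_model_name_or_path,
        num_labels=3,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # Attach the LoRA weights (and any saved head modules) to the base model.
    return PeftModel.from_pretrained(base, "./lora_adapter")
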
def classify_solution(question: str, solution: str):
    """
    Classify a math solution.
    Returns: (classification_label, confidence_score, explanation)
    """
    if not question.strip() or not solution.strip():
        return "Please fill in both fields", "0%", ""

    if model is None or tokenizer is None:
        return "Model not loaded", "0%", ""

    try:
        # Combine question and solution into the single text format the
        # classifier expects.
        text_input = f"Question: {question}\nSolution: {solution}"

        inputs = tokenizer(
            text_input,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )
        # Keep inputs on the same device as the model (relevant when
        # device_map places the model on GPU).
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            # Softmax turns logits into class probabilities; argmax picks the
            # predicted class, whose probability serves as the confidence.
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(predictions, dim=-1).item()
            confidence = predictions[0][predicted_class].item()

        classification = label_mapping[predicted_class]

        explanations = {
            0: "The mathematical approach and calculations are both sound.",
            1: "The approach or understanding has fundamental issues.",
            2: "The approach is correct, but there are calculation errors.",
        }
        explanation = explanations[predicted_class]

        return classification, f"{confidence:.2%}", explanation

    except Exception as e:
        logger.error(f"Error during classification: {e}")
        return f"Classification error: {str(e)}", "0%", ""

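# Quick smoke test outside the UI (a sketch — run after load_model() below has
# executed; the inputs mirror the first gr.Examples entry):
#
#   label, confidence, note = classify_solution(
#       "Solve for x: 2x + 5 = 13",
#       "2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
#   )
#   print(label, confidence, note)
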
# Load the model once at import time so the UI is ready when the app starts.
load_model()

with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🧮 Math Solution Classifier")
    gr.Markdown("Classify math solutions as correct, conceptually flawed, or computationally flawed.")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Math Question",
                placeholder="e.g., Solve for x: 2x + 5 = 13",
                lines=3,
            )
            solution_input = gr.Textbox(
                label="Proposed Solution",
                placeholder="e.g., 2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
                lines=5,
            )
            classify_btn = gr.Button("Classify Solution", variant="primary")

        with gr.Column():
            classification_output = gr.Textbox(label="Classification", interactive=False)
            confidence_output = gr.Textbox(label="Confidence", interactive=False)
            explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)

    # Note: the derivative example is a flawed solution (d/dx x² = 2x,
    # not 2x + 1), giving the classifier a non-correct case to flag.
    gr.Examples(
        examples=[
            [
                "Solve for x: 2x + 5 = 13",
                "2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
            ],
            [
                "Find the derivative of f(x) = x²",
                "f'(x) = 2x + 1",
            ],
            [
                "What is 15% of 200?",
                "15% = 15/100 = 0.15\n0.15 × 200 = 30",
            ],
        ],
        inputs=[question_input, solution_input],
    )

    classify_btn.click(
        fn=classify_solution,
        inputs=[question_input, solution_input],
        outputs=[classification_output, confidence_output, explanation_output],
    )

if __name__ == "__main__":
    app.launch()
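    # For remote hosting (e.g., a container), you may want to bind to all
    # interfaces instead — a sketch using standard gradio launch kwargs:
    # app.launch(server_name="0.0.0.0", server_port=7860)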