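"""Gradio app that classifies student math solutions as correct, conceptually
flawed, or computationally flawed, using a fine-tuned LoRA adapter."""
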
import json
import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Globals populated by load_model().
model = None
tokenizer = None
# Currently unused; display labels come from verdict_mapping in classify_solution().
label_mapping = {0: "✅ Correct", 1: "🤔 Conceptually Flawed", 2: "🔢 Computationally Flawed"}


def load_model():
    """Load the trained LoRA adapter with its base model, falling back to a placeholder."""
    global model, tokenizer

    try:
        from peft import AutoPeftModelForCausalLM

        # AutoPeftModelForCausalLM reads adapter_config.json and loads the
        # base model plus the LoRA weights in a single call.
        model = AutoPeftModelForCausalLM.from_pretrained(
            "./lora_adapter",
            torch_dtype=torch.float16,
            device_map="auto",
        )

        tokenizer = AutoTokenizer.from_pretrained("./lora_adapter")

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            logger.info("Set pad_token to eos_token")

        logger.info("LoRA model loaded successfully")
        return "LoRA model loaded successfully!"

    except Exception as e:
        logger.error(f"Error loading LoRA model: {e}")
        logger.warning("Using placeholder model loading - replace with your actual model!")

        # Placeholder fallback so the UI still starts without the adapter.
        model_name = "microsoft/DialoGPT-medium"
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(model_name)

        return f"Fallback model loaded. LoRA error: {e}"


def get_system_prompt():
    """Generates the specific system prompt for the fine-tuning task."""
    return """You are a mathematics tutor.
You are given a math word problem, and a solution written by a student.
Analyze the solution carefully, line-by-line, and classify it into one of the following categories:
- Correct (All logic is correct, and all calculations are correct)
- Conceptual Error (There is an error in reasoning or logic somewhere in the solution)
- Computational Error (All logic and reasoning is correct, but the result of some calculation is incorrect)
Respond *only* with a valid JSON object that follows this exact schema:
```json
{
    "verdict": "must be one of 'correct', 'conceptual_error', or 'computational_error'",
    "erroneous_line": "the exact, verbatim text of the first incorrect line, or null if the verdict is 'correct'",
    "explanation": "a brief, one-sentence explanation of the error, or null if the verdict is 'correct'"
}
```
Do NOT add any text or explanations before or after the JSON object.
"""


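# Illustrative (hypothetical) example of a well-formed model response for a
# conceptually flawed solution such as "f'(x) = 2x + 1" for f(x) = x²:
# {
#     "verdict": "conceptual_error",
#     "erroneous_line": "f'(x) = 2x + 1",
#     "explanation": "The power rule gives f'(x) = 2x; no constant term is added."
# }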
def classify_solution(question: str, solution: str):
    """
    Classify the math solution using the exact training format.

    Returns: (classification_label, erroneous_line, explanation)
    """
    if not question.strip() or not solution.strip():
        return "Please fill in both fields", "", ""

    if model is None or tokenizer is None:
        return "Model not loaded", "", ""

    try:
        system_prompt = get_system_prompt()
        user_message = f"Problem: {question}\n\nSolution:\n{solution}"

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        # add_generation_prompt (not add_generation_token) appends the
        # assistant header so the model starts generating the reply.
        text_input = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = tokenizer(
            text_input,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=2048,
        )
        # Move input tensors to the model's device; device_map="auto" may
        # have placed the model on a GPU.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

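        # Low temperature with sampling keeps decoding close to greedy,
        # which helps the model stick to the JSON schema.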
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.1,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens. Searching the decoded text
        # for text_input is fragile: skip_special_tokens strips the chat
        # template markers that text_input still contains, so find() fails.
        prompt_length = inputs["input_ids"].shape[1]
        json_response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

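        # If the model wraps its answer in extra text (e.g. a Markdown code
        # fence), json.loads raises JSONDecodeError and the raw response is
        # returned below.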
        try:
            result = json.loads(json_response)
            verdict = result.get("verdict", "unknown")
            erroneous_line = result.get("erroneous_line", "")
            explanation = result.get("explanation", "")

            verdict_mapping = {
                "correct": "✅ Correct",
                "conceptual_error": "🤔 Conceptual Error",
                "computational_error": "🔢 Computational Error",
            }

            display_verdict = verdict_mapping.get(verdict, f"❓ {verdict}")

            return display_verdict, erroneous_line or "None", explanation or "Solution is correct"

        except json.JSONDecodeError:
            return f"Model response: {json_response}", "", "Could not parse JSON response"

    except Exception as e:
        logger.error(f"Error during classification: {e}")
        return f"Classification error: {str(e)}", "", ""


# Load the model once at import time so the UI is ready when the app starts.
load_model()


with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🧮 Math Solution Classifier")
    gr.Markdown("Classify math solutions as correct, conceptually flawed, or computationally flawed.")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Math Question",
                placeholder="e.g., Solve for x: 2x + 5 = 13",
                lines=3,
            )

            solution_input = gr.Textbox(
                label="Proposed Solution",
                placeholder="e.g., 2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
                lines=5,
            )

            classify_btn = gr.Button("Classify Solution", variant="primary")

        with gr.Column():
            classification_output = gr.Textbox(label="Classification", interactive=False)
            erroneous_line_output = gr.Textbox(label="Erroneous Line", interactive=False)
            explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)

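    # The second example intentionally contains a conceptual error (the power
    # rule gives f'(x) = 2x, not 2x + 1); the first and third are correct.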
    gr.Examples(
        examples=[
            [
                "Solve for x: 2x + 5 = 13",
                "2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4"
            ],
            [
                "Find the derivative of f(x) = x²",
                "f'(x) = 2x + 1"
            ],
            [
                "What is 15% of 200?",
                "15% = 15/100 = 0.15\n0.15 × 200 = 30"
            ]
        ],
        inputs=[question_input, solution_input]
    )

    classify_btn.click(
        fn=classify_solution,
        inputs=[question_input, solution_input],
        outputs=[classification_output, erroneous_line_output, explanation_output]
    )


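# Start the Gradio server when the script is executed directly.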
if __name__ == "__main__":
    app.launch()