|
|
|
|
|
from fastapi import FastAPI, HTTPException |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.responses import FileResponse |
|
from pydantic import BaseModel |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import logging |
|
|
|
|
|
# Process-wide logging at INFO, plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Math Solution Classifier API")

# Open CORS policy: any origin, method, and header may call the API.
# Credentials stay disabled because a wildcard origin must not be combined
# with credentialed requests; nothing in this module reads cookies or
# Authorization headers. If credentialed cross-origin calls are ever needed,
# list the allowed origins explicitly and re-enable allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Assigned by load_model(); each stays None until that function sets it.
model = None
tokenizer = None

# Class index produced by the model -> label string returned to the client.
label_mapping = {
    0: "correct",
    1: "conceptually-flawed",
    2: "computationally-flawed",
}
|
|
|
class ClassificationRequest(BaseModel):
    """Request body accepted by the ``/classify`` endpoint."""

    # Text of the math problem being solved.
    question: str
    # Text of the proposed solution to classify.
    solution: str
|
|
|
class ClassificationResponse(BaseModel):
    """Response body returned by the ``/classify`` endpoint."""

    # Predicted label; one of the values of ``label_mapping``.
    classification: str
    # Softmax probability assigned to the predicted label.
    confidence: float
|
|
|
def load_model():
    """Load the tokenizer and classifier and publish them as module globals.

    This is still placeholder loading: it pulls the stock
    ``distilbert-base-uncased`` checkpoint and logs a warning saying so.
    Replace ``model_name`` with the trained checkpoint before relying on the
    predictions.

    The globals ``tokenizer`` and ``model`` are assigned only after both
    objects have loaded, so a failure never leaves one set and the other
    ``None``.

    Raises:
        Exception: Any error raised while loading is logged with its
            traceback and re-raised unchanged.
    """
    global model, tokenizer

    try:
        logger.warning("Using placeholder model loading - replace with your actual model!")

        model_name = "distilbert-base-uncased"
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            # Size the classification head from the label table so the two
            # cannot drift apart.
            num_labels=len(label_mapping),
            ignore_mismatched_sizes=True,
        )
        # The service only runs inference; make evaluation mode explicit.
        loaded_model.eval()
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error loading model: %s", e)
        raise

    tokenizer = loaded_tokenizer
    model = loaded_model
    logger.info("Model loaded successfully")
|
|
|
def classify_solution(question: str, solution: str) -> tuple:
    """Run the classifier on one question/solution pair.

    The two texts are merged into a single prompt, tokenized (truncated to at
    most 512 tokens), and passed through the module-level ``model`` with
    gradient tracking disabled.

    Returns:
        A ``(label, confidence)`` pair: the ``label_mapping`` entry for the
        highest-probability class and that class's softmax probability.

    Raises:
        HTTPException: Status 500, wrapping any error raised while
            tokenizing, running the model, or looking up the label.
    """
    try:
        prompt = f"Question: {question}\nSolution: {solution}"
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )

        with torch.no_grad():
            logits = model(**encoded).logits
            probs = torch.softmax(logits, dim=-1)
            # Single-item batch: argmax yields one index, read from row 0.
            best_idx = probs.argmax(dim=-1).item()
            best_prob = probs[0][best_idx].item()

        return label_mapping[best_idx], best_prob
    except Exception as e:
        logger.error(f"Error during classification: {e}")
        raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
|
|
|
@app.on_event("startup")
async def startup_event():
    """Startup hook: announce the load in the log, then call load_model()."""
    logger.info("Loading model...")
    # Any exception from load_model() propagates out of this hook.
    load_model()
|
|
|
@app.post("/classify", response_model=ClassificationResponse)
async def classify_math_solution(request: ClassificationRequest):
    """Classify a math solution as correct, conceptually flawed, or computationally flawed.

    Args:
        request: The question and the proposed solution to classify.

    Returns:
        ClassificationResponse holding the predicted label and its confidence.

    Raises:
        HTTPException: 503 when the model or tokenizer has not been loaded;
            400 when either field is empty or whitespace-only; 500 when
            classification fails for any reason.
    """
    # Explicit identity checks (matching /health) instead of truthiness,
    # which for arbitrary objects may consult __bool__/__len__.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not request.question.strip() or not request.solution.strip():
        raise HTTPException(status_code=400, detail="Both question and solution are required")

    try:
        classification, confidence = classify_solution(request.question, request.solution)
    except Exception as e:
        # Deliberately generic detail: internal error text is logged, not
        # returned to the client. The cause stays attached for debugging.
        logger.error(f"Classification failed: {e}")
        raise HTTPException(status_code=500, detail="Classification failed") from e

    return ClassificationResponse(
        classification=classification,
        confidence=confidence,
    )
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness and whether the module-level model has been set."""
    model_ready = model is not None
    return {
        "status": "healthy",
        "model_loaded": model_ready,
    }
|
|
|
|
|
# Serve files from the local "static" directory under the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
|
|
|
@app.get("/")
async def serve_frontend():
    """Return the frontend page, static/index.html, for the site root."""
    index_page = "static/index.html"
    return FileResponse(index_page)
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: serve the app on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )