# app.py - FastAPI backend for Math Solution Classifier

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Math Solution Classifier API")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for the model and tokenizer
model = None
tokenizer = None
label_mapping = {0: "correct", 1: "conceptually-flawed", 2: "computationally-flawed"}


class ClassificationRequest(BaseModel):
    question: str
    solution: str


class ClassificationResponse(BaseModel):
    classification: str
    confidence: float


def load_model():
    """Load your trained model here."""
    global model, tokenizer
    try:
        # Replace these with your actual model path/name.
        # Option 1: Load from local files
        # model = AutoModelForSequenceClassification.from_pretrained("./your_model_directory")
        # tokenizer = AutoTokenizer.from_pretrained("./your_model_directory")

        # Option 2: Load from the Hugging Face Hub (if you upload your model there)
        # model = AutoModelForSequenceClassification.from_pretrained("your-username/your-model-name")
        # tokenizer = AutoTokenizer.from_pretrained("your-username/your-model-name")

        # For now, we'll use a placeholder - replace this with your actual model loading.
        logger.warning("Using placeholder model loading - replace with your actual model!")

        # Placeholder model loading (replace this!)
        model_name = "distilbert-base-uncased"  # Replace with your model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=3,
            ignore_mismatched_sizes=True,
        )

        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise


def classify_solution(question: str, solution: str) -> tuple:
    """
    Classify the math solution.

    Returns: (classification_label, confidence_score)
    """
    try:
        # Combine question and solution into a single input
        text_input = f"Question: {question}\nSolution: {solution}"

        # Tokenize the input
        inputs = tokenizer(
            text_input,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )

        # Get the model prediction
        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

        predicted_class = torch.argmax(predictions, dim=-1).item()
        confidence = predictions[0][predicted_class].item()

        classification = label_mapping[predicted_class]
        return classification, confidence
    except Exception as e:
        logger.error(f"Error during classification: {e}")
        raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")


@app.on_event("startup")
async def startup_event():
    """Load the model on startup."""
    logger.info("Loading model...")
    load_model()


@app.post("/classify", response_model=ClassificationResponse)
async def classify_math_solution(request: ClassificationRequest):
    """
    Classify a math solution as correct, conceptually flawed, or computationally flawed.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not request.question.strip() or not request.solution.strip():
        raise HTTPException(status_code=400, detail="Both question and solution are required")

    try:
        classification, confidence = classify_solution(request.question, request.solution)
        return ClassificationResponse(
            classification=classification,
            confidence=confidence,
        )
    except Exception as e:
        logger.error(f"Classification failed: {e}")
        raise HTTPException(status_code=500, detail="Classification failed")


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "model_loaded": model is not None}


# Serve the frontend (for Hugging Face Spaces)
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
async def serve_frontend():
    return FileResponse("static/index.html")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)  # Port 7860 is the standard port for HF Spaces
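
# --- Example usage (a sketch, not part of the app) ------------------------------------
# One minimal way to exercise the /classify endpoint in-process is FastAPI's TestClient
# (requires the `httpx` package); the payload values below are illustrative only. Using
# the client as a context manager runs the startup event, so the model gets loaded.
#
#   from fastapi.testclient import TestClient
#
#   with TestClient(app) as client:
#       resp = client.post(
#           "/classify",
#           json={"question": "What is 2 + 2?", "solution": "2 + 2 = 5"},
#       )
#       print(resp.json())
#       # Response shape: {"classification": "<label>", "confidence": <float>}, where
#       # <label> is "correct", "conceptually-flawed", or "computationally-flawed".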