"""Age Detection API.

FastAPI service exposing an image-classification pipeline built on the
``nateraw/vit-age-classifier`` model. Clients upload an image to ``/predict``
and receive the pipeline's classification results.
"""

import asyncio
import io
import logging
import threading
from datetime import datetime

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import pipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_NAME = "nateraw/vit-age-classifier"
# File extensions accepted when the client-supplied content type is missing or
# wrong (content type alone is not trusted; Pillow verifies the bytes later).
VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")
# Images larger than this in either dimension are downscaled before inference.
MAX_IMAGE_SIZE = (1024, 1024)
# Maximum time (seconds) a request waits for the model to produce a prediction.
PREDICTION_TIMEOUT_S = 30.0

app = FastAPI(title="Age Detection API", version="1.0.0")

# Add CORS middleware so browser-based clients can call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette echo back any request Origin for credentialed requests. In
# production, list your FlutterFlow domain explicitly (or drop credentials).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your FlutterFlow domain
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# Loaded transformers pipeline; None until load_model() succeeds.
pipe = None
# Serialises model loading so concurrent callers don't each load the model.
# A threading.Lock (not asyncio.Lock) because load_model() is executed in
# worker threads via asyncio.to_thread.
_model_lock = threading.Lock()


def load_model():
    """Load the age-classification pipeline into the global ``pipe``.

    Thread-safe and idempotent: if the model is already loaded this returns
    immediately; otherwise exactly one caller performs the load while others
    wait on the lock. This call blocks, so async code must run it through
    ``asyncio.to_thread``.

    Returns:
        True if the model is loaded (now or previously), False if loading failed.
    """
    global pipe
    with _model_lock:
        if pipe is not None:
            return True
        try:
            logger.info("Loading age classification model...")
            pipe = pipeline("image-classification", model=MODEL_NAME)
            logger.info("Model loaded successfully")
            return True
        except Exception as e:
            logger.exception("Failed to load model: %s", e)
            return False


async def _ensure_model_loaded():
    """Return True once the model is available, loading it off the event loop if needed."""
    if pipe is not None:
        return True
    return await asyncio.to_thread(load_model)


# Load model on startup.
# NOTE(review): on_event is deprecated in newer FastAPI in favour of lifespan
# handlers; kept here for compatibility with older FastAPI versions.
@app.on_event("startup")
async def startup_event():
    """Load the model at server start without blocking the event loop thread."""
    success = await _ensure_model_loaded()
    if not success:
        logger.error("Failed to initialize model on startup")


@app.get("/")
async def root():
    """Basic liveness message."""
    return {"message": "Age Detection API is running", "status": "healthy"}


@app.get("/health")
async def health_check():
    """Keep-alive endpoint to prevent sleeping; also reports model load state."""
    model_status = "loaded" if pipe is not None else "not_loaded"
    return {
        "status": "alive",
        "timestamp": datetime.now().isoformat(),
        "model_status": model_status,
    }


def _is_image_upload(file: UploadFile) -> bool:
    """Return True if the upload claims to be an image by content type OR extension.

    The content type is not relied on alone because clients may send a generic
    or incorrect one; the actual bytes are verified by Pillow afterwards.
    """
    content_type = file.content_type or ""
    filename_lower = (file.filename or "").lower()
    return content_type.startswith("image/") or filename_lower.endswith(VALID_EXTENSIONS)


def _decode_image(image_data: bytes) -> Image.Image:
    """Decode and verify uploaded bytes into an RGB image, downscaling large ones.

    Raises:
        HTTPException: 400 if the bytes are not a valid, readable image.
    """
    try:
        # verify() detects corrupt/truncated files but leaves the image object
        # unusable, so the data is reopened for actual decoding.
        with Image.open(io.BytesIO(image_data)) as probe:
            probe.verify()
        with Image.open(io.BytesIO(image_data)) as source:
            image = source.convert("RGB")

        # Resize large images to improve speed; thumbnail() keeps aspect ratio.
        if image.size[0] > MAX_IMAGE_SIZE[0] or image.size[1] > MAX_IMAGE_SIZE[1]:
            image.thumbnail(MAX_IMAGE_SIZE, Image.Resampling.LANCZOS)
            logger.info("Resized image to %s", image.size)
    except Exception as e:
        logger.error("Image processing error: %s", e)
        raise HTTPException(status_code=400, detail="Invalid or corrupted image file") from e
    return image


async def _run_inference(model, image):
    """Run the pipeline on *image* in a worker thread, bounded by a timeout.

    Raises:
        HTTPException: 504 on timeout, 500 if the model raises.
    """
    try:
        logger.info("Running model prediction...")
        # The pipeline is blocking, so it runs in a worker thread to keep the
        # event loop responsive. NOTE: a timed-out worker thread cannot be
        # cancelled; it keeps running to completion in the background.
        results = await asyncio.wait_for(
            asyncio.to_thread(model, image),
            timeout=PREDICTION_TIMEOUT_S,
        )
    except asyncio.TimeoutError:
        logger.error("Model prediction timed out")
        raise HTTPException(status_code=504, detail="Prediction timed out")
    except Exception as e:
        logger.exception("Model prediction error: %s", e)
        raise HTTPException(status_code=500, detail="Prediction failed") from e
    logger.info("Prediction completed: %d results", len(results))
    return results


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the age-classification pipeline.

    Responses:
        200: ``{"results": [...], "timestamp": ISO-8601, "image_size": [w, h]}``
            where ``image_size`` is the size after any downscaling.
        400: upload is not an image, or the bytes are invalid/corrupted.
        500: model could not be loaded, prediction failed, or unexpected error.
        504: prediction exceeded the timeout.
    """
    try:
        if pipe is None:
            logger.warning("Model not loaded, attempting to load...")
            if not await _ensure_model_loaded():
                raise HTTPException(status_code=500, detail="Model failed to load")
        # Local reference so the whole request uses one model instance.
        model = pipe

        if not _is_image_upload(file):
            logger.warning(
                "Invalid file type: content_type=%s, filename=%s",
                file.content_type,
                file.filename,
            )
            raise HTTPException(status_code=400, detail="File must be an image")

        logger.info("Processing image: %s", file.filename)
        image_data = await file.read()
        image = _decode_image(image_data)
        results = await _run_inference(model, image)

        return JSONResponse(content={
            "results": results,
            "timestamp": datetime.now().isoformat(),
            "image_size": image.size,
        })
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e


# Additional endpoint to warm up the model
@app.post("/warmup")
async def warmup():
    """Load the model if needed; report "loaded", "failed" or "already_loaded"."""
    if pipe is not None:
        return {"status": "already_loaded"}
    success = await _ensure_model_loaded()
    return {"status": "loaded" if success else "failed"}