from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
from PIL import Image
import io
import logging
from datetime import datetime
import asyncio
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Age Detection API", version="1.0.0")

# Allow cross-origin requests so browser-based clients can reach the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict this to your FlutterFlow domain
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# Global variable to store the model
pipe = None
def load_model():
    """Load the model with error handling."""
    global pipe
    try:
        logger.info("Loading age classification model...")
        pipe = pipeline("image-classification", model="nateraw/vit-age-classifier")
        logger.info("Model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return False

# Load model on startup
@app.on_event("startup")
async def startup_event():
    # Note: on_event is deprecated in recent FastAPI versions in favor of
    # lifespan handlers, but it still works and is kept here for simplicity.
    success = load_model()
    if not success:
        logger.error("Failed to initialize model on startup")

@app.get("/")
async def root():
    return {"message": "Age Detection API is running", "status": "healthy"}

@app.get("/health")
async def health_check():
    """Keep-alive endpoint to prevent sleeping."""
    global pipe
    model_status = "loaded" if pipe is not None else "not_loaded"
    return {
        "status": "alive",
        "timestamp": datetime.now().isoformat(),
        "model_status": model_status
    }
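
# The /health endpoint is intended to be polled by an external scheduler so
# the Space is not suspended for inactivity. A minimal sketch of such a
# pinger (the URL is a placeholder and the 10-minute interval is an
# assumption, not a requirement):
#
#     import time, urllib.request
#     while True:
#         urllib.request.urlopen("https://your-space.hf.space/health")
#         time.sleep(600)
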
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    global pipe
    try:
        # Make sure the model is loaded; try a lazy load if startup failed
        if pipe is None:
            logger.warning("Model not loaded, attempting to load...")
            success = load_model()
            if not success:
                raise HTTPException(status_code=500, detail="Model failed to load")

        # Validate the file type. Don't rely solely on content_type, as
        # clients sometimes send an incorrect or missing value; accept the
        # upload if either the content type or the extension looks like an image.
        valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp']
        filename_lower = (file.filename or '').lower()
        is_valid_content_type = file.content_type and file.content_type.startswith('image/')
        is_valid_extension = any(filename_lower.endswith(ext) for ext in valid_extensions)
        if not (is_valid_content_type or is_valid_extension):
            logger.warning(f"Invalid file type: content_type={file.content_type}, filename={file.filename}")
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and decode the image
        logger.info(f"Processing image: {file.filename}")
        image_data = await file.read()
        try:
            image = Image.open(io.BytesIO(image_data))
            # verify() raises an exception if the data is not a valid image
            image.verify()
            # Reopen the image, since verify() closes the underlying file
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
            # Downscale large images to improve inference speed
            max_size = (1024, 1024)
            if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
                image.thumbnail(max_size, Image.Resampling.LANCZOS)
                logger.info(f"Resized image to {image.size}")
        except Exception as e:
            logger.error(f"Image processing error: {e}")
            raise HTTPException(status_code=400, detail="Invalid or corrupted image file")

        # Run the prediction in a worker thread with a timeout so a hung
        # model call cannot block the event loop indefinitely
        try:
            logger.info("Running model prediction...")
            results = await asyncio.wait_for(
                asyncio.to_thread(pipe, image),
                timeout=30.0
            )
            logger.info(f"Prediction completed: {len(results)} results")
        except asyncio.TimeoutError:
            logger.error("Model prediction timed out")
            raise HTTPException(status_code=504, detail="Prediction timed out")
        except Exception as e:
            logger.error(f"Model prediction error: {e}")
            raise HTTPException(status_code=500, detail="Prediction failed")

        return JSONResponse(content={
            "results": results,
            "timestamp": datetime.now().isoformat(),
            "image_size": image.size
        })

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")

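# Example request against the endpoint above (localhost:7860 and test.jpg are
# placeholders for illustration):
#
#     curl -X POST -F "file=@test.jpg" http://localhost:7860/predict
#
# A successful response carries "results" (a list of {label, score} dicts from
# the image-classification pipeline), a "timestamp", and the "image_size".
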
# Additional endpoint to warm up the model
@app.post("/warmup")
async def warmup():
    """Endpoint to warm up the model."""
    global pipe
    if pipe is None:
        success = load_model()
        return {"status": "loaded" if success else "failed"}
    return {"status": "already_loaded"}
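
# Local entrypoint. A sketch that assumes uvicorn is installed; Hugging Face
# Spaces conventionally expose port 7860, but adjust host/port to match your
# deployment.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)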