File size: 5,361 Bytes
60c0001
25139b3
60c0001
25139b3
 
 
60c0001
 
 
25139b3
60c0001
 
 
25139b3
60c0001
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25139b3
 
 
60c0001
25139b3
60c0001
 
 
 
 
 
 
 
cc6bef8
 
 
 
 
 
 
 
 
 
 
60c0001
 
 
 
 
 
cc6bef8
60c0001
 
cc6bef8
 
 
 
 
 
60c0001
 
 
 
 
 
 
 
 
cc6bef8
60c0001
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
from PIL import Image
import io
import logging
from datetime import datetime
import asyncio

# Configure logging
# The root logger is set to INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object; every route below is registered on it.
app = FastAPI(title="Age Detection API", version="1.0.0")

# Add CORS middleware so browser-based clients on other origins can call the API.
# NOTE(review): allow_origins=["*"] is combined with allow_credentials=True.
# FastAPI's CORS documentation states that wildcard values must not be used
# together with credentials. Confirm whether any client actually sends
# credentials; if not, drop that flag, otherwise list the origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your FlutterFlow domain
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# Global variable to store the model
# None until load_model() has successfully built the pipeline.
pipe = None

def load_model():
    """Build the age-classification pipeline and store it in the global ``pipe``.

    Returns:
        True if the pipeline was created and assigned to ``pipe``.
        False if construction raised; ``pipe`` is then left as it was and the
        failure is logged together with its traceback.
    """
    global pipe
    logger.info("Loading age classification model...")
    try:
        pipe = pipeline("image-classification", model="nateraw/vit-age-classifier")
    except Exception as e:
        # Deliberate catch-all: callers rely on the boolean result rather than
        # on an exception. logger.exception keeps the traceback, which a plain
        # error message would drop.
        logger.exception("Failed to load model: %s", e)
        return False
    logger.info("Model loaded successfully")
    return True

# Load model on startup
@app.on_event("startup")
async def startup_event():
    """Attempt to load the model once when the application starts.

    A failed load is only logged; nothing is raised, so the application
    still starts with ``pipe`` unset.
    """
    if load_model():
        return
    logger.error("Failed to initialize model on startup")

@app.get("/")
async def root():
    """Return a fixed payload confirming that the API is reachable."""
    return dict(message="Age Detection API is running", status="healthy")

@app.get("/health")
async def health_check():
    """Keep-alive endpoint that also reports whether the model is loaded.

    Returns a dict with ``status`` (always "alive"), ``timestamp`` (the
    current local time in ISO format) and ``model_status`` ("loaded" or
    "not_loaded", depending on whether the global ``pipe`` is set).
    """
    if pipe is None:
        model_state = "not_loaded"
    else:
        model_state = "loaded"
    return dict(
        status="alive",
        timestamp=datetime.now().isoformat(),
        model_status=model_state,
    )

# File extensions accepted when the declared content type is missing or wrong.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')

# Images larger than this box are downscaled before inference to save time.
_MAX_IMAGE_SIZE = (1024, 1024)


def _is_image_upload(content_type, filename):
    """Return True if the content type or the file extension indicates an image."""
    if content_type and content_type.startswith('image/'):
        return True
    # Do not rely on content_type alone: clients sometimes send a wrong one.
    return (filename or '').lower().endswith(_IMAGE_EXTENSIONS)


def _decode_image(image_data):
    """Validate raw upload bytes and return an RGB image within _MAX_IMAGE_SIZE.

    Raises whatever Pillow raises when the bytes are not a readable image.
    """
    # verify() leaves the image object unusable, so it runs on a throwaway
    # handle and the data is opened a second time for the real decode.
    with Image.open(io.BytesIO(image_data)) as probe:
        probe.verify()

    with Image.open(io.BytesIO(image_data)) as source:
        image = source.convert("RGB")

    if image.size[0] > _MAX_IMAGE_SIZE[0] or image.size[1] > _MAX_IMAGE_SIZE[1]:
        # thumbnail() resizes in place and keeps the aspect ratio.
        image.thumbnail(_MAX_IMAGE_SIZE, Image.Resampling.LANCZOS)
        logger.info("Resized image to %s", image.size)
    return image


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the age group of the person in an uploaded image.

    Args:
        file: The uploaded image. It is accepted when either its content type
            starts with ``image/`` or its filename has a known image extension.

    Returns:
        A JSON response with ``results`` (the pipeline output), ``timestamp``
        (current local time in ISO format) and ``image_size`` (the size of
        the image after any downscaling).

    Raises:
        HTTPException: 400 if the upload is not an image or cannot be decoded;
            500 if the model cannot be loaded, prediction fails, or anything
            unexpected happens; 504 if prediction exceeds 30 seconds.
    """
    try:
        # Fallback for when the startup load failed. load_model() is a plain
        # synchronous call, so the event loop is blocked while it runs.
        if pipe is None:
            logger.warning("Model not loaded, attempting to load...")
            if not load_model():
                raise HTTPException(status_code=500, detail="Model failed to load")

        if not _is_image_upload(file.content_type, file.filename):
            logger.warning(
                "Invalid file type: content_type=%s, filename=%s",
                file.content_type,
                file.filename,
            )
            raise HTTPException(status_code=400, detail="File must be an image")

        logger.info("Processing image: %s", file.filename)
        image_data = await file.read()

        try:
            image = _decode_image(image_data)
        except Exception as e:
            # Client error (bad upload): log the message only, no traceback.
            logger.error("Image processing error: %s", e)
            raise HTTPException(
                status_code=400, detail="Invalid or corrupted image file"
            ) from e

        try:
            logger.info("Running model prediction...")
            # Inference runs in a worker thread so the event loop stays free.
            # NOTE: the timeout only abandons the wait; a thread started by
            # to_thread cannot be cancelled and keeps running to completion.
            results = await asyncio.wait_for(
                asyncio.to_thread(pipe, image),
                timeout=30.0,
            )
            logger.info("Prediction completed: %s results", len(results))
        except asyncio.TimeoutError as e:
            logger.error("Model prediction timed out")
            raise HTTPException(status_code=504, detail="Prediction timed out") from e
        except Exception as e:
            # Server-side failure: keep the traceback for diagnosis.
            logger.exception("Model prediction error: %s", e)
            raise HTTPException(status_code=500, detail="Prediction failed") from e

        return JSONResponse(content={
            "results": results,
            "timestamp": datetime.now().isoformat(),
            "image_size": image.size
        })

    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e

# Additional endpoint to warm up the model
@app.post("/warmup")
async def warmup():
    """Load the model on demand if it is not in memory yet.

    Returns ``{"status": "already_loaded"}`` when the global ``pipe`` is set,
    otherwise ``{"status": "loaded"}`` or ``{"status": "failed"}`` depending
    on the result of ``load_model()``.
    """
    if pipe is not None:
        return {"status": "already_loaded"}
    loaded = load_model()
    return {"status": "loaded" if loaded else "failed"}