File size: 7,801 Bytes
d7e7912 93ec2a8 d7e7912 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
import logging
import os
import uuid
from contextlib import asynccontextmanager
from typing import Any, Dict

import librosa
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse

from models.age_and_gender_model import AgeGenderModel
from models.nationality_model import NationalityModel
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Upload handling configuration.
UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a'}
SAMPLING_RATE = 16000

os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Global model handles, populated by load_models() during startup.
age_gender_model = None
nationality_model = None


def allowed_file(filename: str) -> bool:
    """Return True when *filename* has an extension in ALLOWED_EXTENSIONS."""
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
async def load_models() -> bool:
    """Instantiate both inference models and load their weights.

    Populates the module-level ``age_gender_model`` and ``nationality_model``
    globals. Returns True only when every model loaded successfully.
    """
    global age_gender_model, nationality_model
    try:
        # Load age & gender model
        logger.info("Loading age & gender model...")
        age_gender_model = AgeGenderModel()
        if not age_gender_model.load():
            logger.error("Failed to load age & gender model")
            return False

        # Load nationality model
        logger.info("Loading nationality model...")
        nationality_model = NationalityModel()
        if not nationality_model.load():
            logger.error("Failed to load nationality model")
            return False

        logger.info("All models loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        return False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load models at startup, log at shutdown.

    A load failure is logged but does not abort startup, so the app still
    serves (and reports unloaded models via the endpoints' own guards).
    """
    logger.info("Starting FastAPI application...")
    if not await load_models():
        logger.error("Failed to load models. Application will not work properly.")
    yield
    logger.info("Shutting down FastAPI application...")
# Create FastAPI app with lifespan events
# The lifespan handler above loads the ML models before requests are served.
app = FastAPI(
    title="Audio Analysis API",
    description="audio analysis for age, gender, and nationality prediction",
    version="1.0.0",
    lifespan=lifespan
)
def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
    """Normalize raw audio for the models: mono, float32, SAMPLING_RATE Hz.

    Returns the processed samples together with the (fixed) target rate.
    """
    # Collapse multi-channel recordings down to a single channel.
    if audio_data.ndim > 1:
        audio_data = librosa.to_mono(audio_data)
    # Resample only when the source rate differs from the target.
    if sr != SAMPLING_RATE:
        logger.info(f"Resampling from {sr}Hz to {SAMPLING_RATE}Hz")
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLING_RATE)
    return audio_data.astype(np.float32), SAMPLING_RATE
async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
    """Validate an uploaded audio file, then load and preprocess it.

    Saves the upload to a unique temporary path, decodes it with librosa,
    normalizes it via preprocess_audio(), and always removes the temp file.

    Raises:
        HTTPException(400): missing filename or disallowed extension.
        HTTPException(500): any failure while saving or decoding the audio.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file selected")
    if not allowed_file(file.filename):
        raise HTTPException(status_code=400, detail="Invalid file type. Allowed: wav, mp3, flac, m4a")
    # Never reuse the client-supplied name on disk: it may contain path
    # separators (path traversal out of UPLOAD_FOLDER) and two concurrent
    # uploads with the same name would clobber each other. Keep only the
    # already-validated extension and generate a unique basename.
    ext = file.filename.rsplit('.', 1)[1].lower()
    filepath = os.path.join(UPLOAD_FOLDER, f"temp_{uuid.uuid4().hex}.{ext}")
    try:
        # Save uploaded file temporarily
        with open(filepath, "wb") as buffer:
            content = await file.read()
            buffer.write(content)
        # Load and preprocess audio (sr=None keeps the native sampling rate)
        audio_data, sr = librosa.load(filepath, sr=None)
        return preprocess_audio(audio_data, sr)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
    finally:
        # Clean up temporary file
        if os.path.exists(filepath):
            os.remove(filepath)
@app.get("/")
async def root() -> Dict[str, Any]:
    """Landing endpoint: report model load status and enumerate the routes."""
    def _loaded(model) -> bool:
        # A model counts as loaded once its wrapped .model attribute exists.
        return model is not None and getattr(model, 'model', None) is not None

    return {
        "message": "Audio Analysis API - Age, Gender & Nationality Prediction",
        "models_loaded": {
            "age_gender": _loaded(age_gender_model),
            "nationality": _loaded(nationality_model),
        },
        "endpoints": {
            "/predict_age_and_gender": "POST - Upload audio file for age and gender prediction",
            "/predict_nationality": "POST - Upload audio file for nationality prediction",
            "/predict_all": "POST - Upload audio file for complete analysis (age, gender, nationality)",
        },
        "docs": "/docs - Interactive API documentation",
        "openapi": "/openapi.json - OpenAPI schema"
    }
@app.get("/health")
async def health_check() -> Dict[str, str]:
    """Liveness probe; always reports healthy when the process is up."""
    return {"status": "healthy"}
@app.post("/predict_age_and_gender")
async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Predict age and gender from uploaded audio file."""
    # Refuse requests until the startup hook has loaded the model.
    model_ready = (
        age_gender_model is not None
        and hasattr(age_gender_model, 'model')
        and age_gender_model.model is not None
    )
    if not model_ready:
        raise HTTPException(status_code=500, detail="Age & gender model not loaded")
    try:
        audio, rate = await process_audio_file(file)
        predictions = age_gender_model.predict(audio, rate)
        return {"success": True, "predictions": predictions}
    except HTTPException:
        # Preserve status codes raised by process_audio_file (400/500).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/predict_nationality")
async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Predict nationality/language from uploaded audio file."""
    # Refuse requests until the startup hook has loaded the model.
    model_ready = (
        nationality_model is not None
        and hasattr(nationality_model, 'model')
        and nationality_model.model is not None
    )
    if not model_ready:
        raise HTTPException(status_code=500, detail="Nationality model not loaded")
    try:
        audio, rate = await process_audio_file(file)
        predictions = nationality_model.predict(audio, rate)
        return {"success": True, "predictions": predictions}
    except HTTPException:
        # Preserve status codes raised by process_audio_file (400/500).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/predict_all")
async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Run both models on one upload: demographics plus nationality."""
    # Both models must be ready before any audio processing starts.
    def _missing(model) -> bool:
        return model is None or not hasattr(model, 'model') or model.model is None

    if _missing(age_gender_model):
        raise HTTPException(status_code=500, detail="Age & gender model not loaded")
    if _missing(nationality_model):
        raise HTTPException(status_code=500, detail="Nationality model not loaded")
    try:
        audio, rate = await process_audio_file(file)
        # Get both predictions from the same preprocessed audio.
        demographics = age_gender_model.predict(audio, rate)
        nationality = nationality_model.predict(audio, rate)
        return {
            "success": True,
            "predictions": {
                "demographics": demographics,
                "nationality": nationality
            }
        }
    except HTTPException:
        # Preserve status codes raised by process_audio_file (400/500).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    # Honor a platform-provided PORT (e.g. container hosts); default to 7860.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        reload=False,  # Set to True for development
        log_level="info",
    )