from fastapi import FastAPI, File, UploadFile, HTTPException
import os
import uuid
import numpy as np
import librosa
from typing import Dict, Any
import logging
from contextlib import asynccontextmanager

from models.nationality_model import NationalityModel
from models.age_and_gender_model import AgeGenderModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a'}
SAMPLING_RATE = 16000  # sample rate (Hz) the models expect

os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Loaded once at startup via the lifespan handler and shared across requests.
age_gender_model = None
nationality_model = None


def allowed_file(filename: str) -> bool:
    """Return True if the filename carries an allowed audio extension."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
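
# Illustrative check (hypothetical filenames):
#   allowed_file("voice.wav")  -> True
#   allowed_file("notes.txt")  -> False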


async def load_models() -> bool:
    """Instantiate and load both models; return True only if both succeed."""
    global age_gender_model, nationality_model

    try:
        logger.info("Loading age & gender model...")
        age_gender_model = AgeGenderModel()
        age_gender_success = age_gender_model.load()

        if not age_gender_success:
            logger.error("Failed to load age & gender model")
            return False

        logger.info("Loading nationality model...")
        nationality_model = NationalityModel()
        nationality_success = nationality_model.load()

        if not nationality_success:
            logger.error("Failed to load nationality model")
            return False

        logger.info("All models loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        return False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup; log on shutdown."""
    logger.info("Starting FastAPI application...")
    success = await load_models()
    if not success:
        logger.error("Failed to load models. Application will not work properly.")

    yield

    logger.info("Shutting down FastAPI application...")


app = FastAPI(
    title="Audio Analysis API",
    description="Audio analysis for age, gender, and nationality prediction",
    version="1.0.0",
    lifespan=lifespan
)


def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
    """Convert audio to mono float32 at SAMPLING_RATE."""
    if audio_data.ndim > 1:
        audio_data = librosa.to_mono(audio_data)

    if sr != SAMPLING_RATE:
        logger.info(f"Resampling from {sr}Hz to {SAMPLING_RATE}Hz")
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLING_RATE)

    audio_data = audio_data.astype(np.float32)

    return audio_data, SAMPLING_RATE
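
# Sketch of the intended effect (assumed input): a (2, N) stereo array recorded at
# 44100 Hz comes back as a mono float32 array resampled to SAMPLING_RATE (16000 Hz).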


async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
    """Validate, save, load, and preprocess an uploaded audio file."""
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file selected")

    if not allowed_file(file.filename):
        raise HTTPException(status_code=400, detail="Invalid file type. Allowed: wav, mp3, flac, m4a")

    # Use a unique temp name (and strip any client-supplied path components)
    # so concurrent uploads cannot clobber or delete each other's files.
    filename = f"temp_{uuid.uuid4().hex}_{os.path.basename(file.filename)}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)

    try:
        with open(filepath, "wb") as buffer:
            content = await file.read()
            buffer.write(content)

        audio_data, sr = librosa.load(filepath, sr=None)
        processed_audio, processed_sr = preprocess_audio(audio_data, sr)

        return processed_audio, processed_sr

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
    finally:
        # Remove the temp file whether processing succeeded or failed.
        if os.path.exists(filepath):
            os.remove(filepath)


@app.get("/")
async def root() -> Dict[str, Any]:
    return {
        "message": "Audio Analysis API - Age, Gender & Nationality Prediction",
        "models_loaded": {
            "age_gender": age_gender_model is not None and hasattr(age_gender_model, 'model') and age_gender_model.model is not None,
            "nationality": nationality_model is not None and hasattr(nationality_model, 'model') and nationality_model.model is not None
        },
        "endpoints": {
            "/predict_age_and_gender": "POST - Upload audio file for age and gender prediction",
            "/predict_nationality": "POST - Upload audio file for nationality prediction",
            "/predict_all": "POST - Upload audio file for complete analysis (age, gender, nationality)",
        },
        "docs": "/docs - Interactive API documentation",
        "openapi": "/openapi.json - OpenAPI schema"
    }


@app.get("/health")
async def health_check() -> Dict[str, str]:
    return {"status": "healthy"}
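
# Liveness probe example (illustrative; assumes the default port configured below):
#   curl http://localhost:7860/health   ->   {"status": "healthy"}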


@app.post("/predict_age_and_gender")
async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Predict age and gender from an uploaded audio file."""
    if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None:
        raise HTTPException(status_code=500, detail="Age & gender model not loaded")

    try:
        processed_audio, processed_sr = await process_audio_file(file)
        predictions = age_gender_model.predict(processed_audio, processed_sr)

        return {
            "success": True,
            "predictions": predictions
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/predict_nationality")
async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Predict nationality/language from an uploaded audio file."""
    if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None:
        raise HTTPException(status_code=500, detail="Nationality model not loaded")

    try:
        processed_audio, processed_sr = await process_audio_file(file)
        predictions = nationality_model.predict(processed_audio, processed_sr)

        return {
            "success": True,
            "predictions": predictions
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/predict_all")
async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Run both models on an uploaded audio file and return combined predictions."""
    if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None:
        raise HTTPException(status_code=500, detail="Age & gender model not loaded")

    if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None:
        raise HTTPException(status_code=500, detail="Nationality model not loaded")

    try:
        processed_audio, processed_sr = await process_audio_file(file)

        age_gender_predictions = age_gender_model.predict(processed_audio, processed_sr)
        nationality_predictions = nationality_model.predict(processed_audio, processed_sr)

        return {
            "success": True,
            "predictions": {
                "demographics": age_gender_predictions,
                "nationality": nationality_predictions
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        reload=False,
        log_level="info"
    )
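
# Example client call (a sketch, not part of the app; assumes the server is running
# on the default port and that a local file named "sample.wav" exists):
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       resp = requests.post("http://localhost:7860/predict_all", files={"file": f})
#   resp.json()  # -> {"success": true, "predictions": {"demographics": ..., "nationality": ...}}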