|
from fastapi import FastAPI, File, UploadFile, HTTPException |
|
from fastapi.responses import JSONResponse |
|
import os |
|
import numpy as np |
|
import librosa |
|
from typing import Dict, Any |
|
import logging |
|
import time |
|
from contextlib import asynccontextmanager |
|
from models.nationality_model import NationalityModel |
|
from models.age_and_gender_model import AgeGenderModel |
|
|
|
|
|
# Configure root logging once at import time so every module in the process
# emits timestamped, severity-tagged messages in a uniform format.
logging.basicConfig(

    level=logging.INFO,

    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'

)

# Module-level logger for this application.
logger = logging.getLogger(__name__)
|
|
|
# Directory where uploads are written temporarily before analysis.
UPLOAD_FOLDER = 'uploads'

# Audio file extensions accepted by the upload endpoints.
ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a'}

# Target sample rate (Hz) fed to the models; inputs are resampled to this.
SAMPLING_RATE = 16000

# Uploads longer than this are clipped to their first MAX_DURATION_SECONDS.
MAX_DURATION_SECONDS = 120

# Ensure the upload directory exists (no-op when already present).
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Model singletons; populated by load_models() during application startup.
age_gender_model = None

nationality_model = None
|
|
|
def allowed_file(filename: str) -> bool:
    """Return True when *filename* carries an extension in ALLOWED_EXTENSIONS."""
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
|
|
|
def clip_audio_to_max_duration(audio_data: np.ndarray, sr: int, max_duration: int = MAX_DURATION_SECONDS) -> tuple[np.ndarray, bool]:
    """Truncate *audio_data* to at most *max_duration* seconds.

    Returns the (possibly shortened) sample array together with a flag
    indicating whether any truncation was performed.
    """
    duration_s = len(audio_data) / sr
    limit_samples = int(max_duration * sr)

    # Within the limit: hand back the samples untouched.
    if len(audio_data) <= limit_samples:
        logger.info(f"Audio duration ({duration_s:.2f}s) is within limit ({max_duration}s) - no clipping needed")
        return audio_data, False

    # Keep only the leading `limit_samples` samples.
    truncated = audio_data[:limit_samples]
    logger.info(f"Audio clipped from {duration_s:.2f}s to {max_duration}s ({len(audio_data)} samples → {len(truncated)} samples)")
    return truncated, True
|
|
|
def _load_model_with_timing(label: str, factory):
    """Construct and load one model, logging how long it took.

    Args:
        label: Lower-case human-readable model name used in log messages.
        factory: Zero-argument callable returning an unloaded model wrapper.

    Returns:
        (model, ok) — the constructed wrapper and whether .load() succeeded.
    """
    logger.info(f"Starting {label} model loading...")
    start = time.time()
    model = factory()
    ok = model.load()
    elapsed = time.time() - start
    if not ok:
        logger.error(f"Failed to load {label} model")
        return model, False
    # Capitalize only the first character so messages match the historical format.
    logger.info(f"{label[:1].upper()}{label[1:]} model loaded successfully in {elapsed:.2f} seconds")
    return model, True


async def load_models() -> bool:
    """Load both inference models into the module-level singletons.

    Loads the age & gender model first, then the nationality model; stops
    and returns False on the first failure (the second model is then never
    attempted). Any unexpected exception is logged and reported as failure.

    Returns:
        True only when both models loaded successfully.
    """
    global age_gender_model, nationality_model

    try:
        total_start_time = time.time()

        age_gender_model, ok = _load_model_with_timing("age & gender", AgeGenderModel)
        if not ok:
            return False

        nationality_model, ok = _load_model_with_timing("nationality", NationalityModel)
        if not ok:
            return False

        logger.info(f"All models loaded successfully! Total time: {time.time() - total_start_time:.2f} seconds")
        return True
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        return False
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load models before serving, log on shutdown."""
    logger.info("Starting FastAPI application...")
    t0 = time.time()
    loaded = await load_models()
    elapsed = time.time() - t0

    if loaded:
        logger.info(f"FastAPI application started successfully in {elapsed:.2f} seconds")
    else:
        logger.error("Failed to load models. Application will not work properly.")

    # Control returns here while the application serves requests.
    yield

    logger.info("Shutting down FastAPI application...")
|
|
|
|
|
# FastAPI application instance; `lifespan` loads both models at startup.
app = FastAPI(

    title="Audio Analysis API",

    description="audio analysis for age, gender, and nationality prediction",

    version="1.0.0",

    lifespan=lifespan

)
|
|
|
def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int, bool]:
    """Normalize raw audio for model input.

    Pipeline: collapse stereo to mono, resample to SAMPLING_RATE, clip to
    MAX_DURATION_SECONDS, cast to float32.

    Args:
        audio_data: Samples as loaded by librosa — 1-D mono or 2-D
            (channels, samples) stereo.
        sr: Sample rate of `audio_data` in Hz.

    Returns:
        (samples, sample_rate, was_clipped); sample_rate is always
        SAMPLING_RATE on return.
    """
    preprocess_start = time.time()
    # BUGFIX: use shape[-1] so the duration is correct for 2-D stereo input,
    # where len(audio_data) is the channel count rather than the sample count.
    original_duration = audio_data.shape[-1] / sr
    logger.info(f"Starting audio preprocessing Sample rate: {sr}Hz, Duration: {original_duration:.2f}s")

    # Collapse multi-channel audio to a single channel.
    if len(audio_data.shape) > 1:
        mono_start = time.time()
        audio_data = librosa.to_mono(audio_data)
        logger.info(f"Converted stereo to mono in {time.time() - mono_start:.3f} seconds - New shape: {audio_data.shape}")

    # Resample only when the source rate differs from the models' rate.
    if sr != SAMPLING_RATE:
        resample_start = time.time()
        logger.info(f"Resampling from {sr}Hz to {SAMPLING_RATE}Hz...")
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLING_RATE)
        logger.info(f"Resampling completed in {time.time() - resample_start:.3f} seconds")
        sr = SAMPLING_RATE
    else:
        logger.info(f"No resampling needed - already at {SAMPLING_RATE}Hz")

    # Enforce the maximum analysis duration.
    audio_data, was_clipped = clip_audio_to_max_duration(audio_data, sr)

    # Models expect 32-bit float input.
    audio_data = audio_data.astype(np.float32)

    final_duration_seconds = len(audio_data) / sr
    logger.info(f"Audio preprocessing completed in {time.time() - preprocess_start:.3f} seconds")
    # BUGFIX: log the sample count, not the shape tuple, as "samples".
    logger.info(f"Final audio: {len(audio_data)} samples, {final_duration_seconds:.2f} seconds duration")

    return audio_data, sr, was_clipped
|
|
|
async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int, bool]:
    """Validate, persist, decode and preprocess an uploaded audio file.

    The upload is written to a temporary file (librosa decodes from a path),
    loaded at its native sample rate, run through preprocess_audio, and the
    temporary file is always removed afterwards.

    Returns:
        (samples, sample_rate, was_clipped) ready for model inference.

    Raises:
        HTTPException: 400 for a missing filename or disallowed extension;
            500 for any decoding/processing failure.
    """
    process_start = time.time()
    logger.info(f"Processing uploaded file: {file.filename}")

    if not file.filename:
        raise HTTPException(status_code=400, detail="No file selected")

    if not allowed_file(file.filename):
        logger.warning(f"Invalid file type uploaded: {file.filename}")
        raise HTTPException(status_code=400, detail="Invalid file type. Allowed: wav, mp3, flac, m4a")

    file_ext = file.filename.rsplit('.', 1)[1].lower()
    logger.info(f"Processing {file_ext.upper()} file: {file.filename}")

    # SECURITY: strip any directory components from the client-supplied name
    # so a crafted filename (e.g. "../../x.wav") cannot escape UPLOAD_FOLDER.
    safe_name = os.path.basename(file.filename)
    filename = f"temp_{int(time.time())}_{safe_name}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)

    try:
        save_start = time.time()
        with open(filepath, "wb") as buffer:
            content = await file.read()
            buffer.write(content)
        save_end = time.time()

        file_size_mb = len(content) / (1024 * 1024)
        logger.info(f"File saved ({file_size_mb:.2f} MB) in {save_end - save_start:.3f} seconds")

        load_start = time.time()
        logger.info(f"Loading audio from {filepath}...")
        # sr=None preserves the native sample rate; preprocess_audio resamples.
        audio_data, sr = librosa.load(filepath, sr=None)
        load_end = time.time()
        logger.info(f"Audio loaded in {load_end - load_start:.3f} seconds")

        processed_audio, processed_sr, was_clipped = preprocess_audio(audio_data, sr)

        process_end = time.time()
        logger.info(f"Total file processing completed in {process_end - process_start:.3f} seconds")

        return processed_audio, processed_sr, was_clipped

    except HTTPException:
        # BUGFIX: let deliberate HTTP errors propagate instead of rewrapping
        # them as generic 500s below.
        raise
    except Exception as e:
        logger.error(f"Error processing audio file {file.filename}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
    finally:
        # Always remove the temp file, whether processing succeeded or not.
        if os.path.exists(filepath):
            os.remove(filepath)
            # BUGFIX: the message previously printed the literal "(unknown)"
            # instead of the path that was cleaned up.
            logger.info(f"Temporary file {filepath} cleaned up")
|
|
|
@app.get("/") |
|
async def root() -> Dict[str, Any]: |
|
logger.info("Root endpoint accessed") |
|
return { |
|
"message": "Audio Analysis API - Age, Gender & Nationality Prediction", |
|
"max_audio_duration": f"{MAX_DURATION_SECONDS} seconds (files longer than this will be automatically clipped)", |
|
"models_loaded": { |
|
"age_gender": age_gender_model is not None and hasattr(age_gender_model, 'model') and age_gender_model.model is not None, |
|
"nationality": nationality_model is not None and hasattr(nationality_model, 'model') and nationality_model.model is not None |
|
}, |
|
"endpoints": { |
|
"/predict_age_and_gender": "POST - Upload audio file for age and gender prediction", |
|
"/predict_nationality": "POST - Upload audio file for nationality prediction", |
|
"/predict_all": "POST - Upload audio file for complete analysis (age, gender, nationality)", |
|
}, |
|
"docs": "/docs - Interactive API documentation", |
|
"openapi": "/openapi.json - OpenAPI schema" |
|
} |
|
|
|
@app.get("/health") |
|
async def health_check() -> Dict[str, str]: |
|
logger.info("Health check endpoint accessed") |
|
return {"status": "healthy"} |
|
|
|
@app.post("/predict_age_and_gender") |
|
async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]: |
|
endpoint_start = time.time() |
|
logger.info(f"Age & Gender prediction requested for file: {file.filename}") |
|
|
|
if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None: |
|
logger.error("Age & gender model not loaded - returning 500 error") |
|
raise HTTPException(status_code=500, detail="Age & gender model not loaded") |
|
|
|
try: |
|
processed_audio, processed_sr, was_clipped = await process_audio_file(file) |
|
|
|
|
|
prediction_start = time.time() |
|
logger.info("Starting age & gender prediction...") |
|
predictions = age_gender_model.predict(processed_audio, processed_sr) |
|
prediction_end = time.time() |
|
|
|
logger.info(f"Age & gender prediction completed in {prediction_end - prediction_start:.3f} seconds") |
|
logger.info(f"Predicted age: {predictions['age']['predicted_age']:.1f} years") |
|
logger.info(f"Predicted gender: {predictions['gender']['predicted_gender']} (confidence: {predictions['gender']['confidence']:.3f})") |
|
|
|
endpoint_end = time.time() |
|
logger.info(f"Total age & gender endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds") |
|
|
|
response = { |
|
"success": True, |
|
"predictions": predictions, |
|
"processing_time": round(endpoint_end - endpoint_start, 3), |
|
"audio_info": { |
|
"was_clipped": was_clipped, |
|
"max_duration_seconds": MAX_DURATION_SECONDS |
|
} |
|
} |
|
|
|
if was_clipped: |
|
response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis." |
|
|
|
return response |
|
|
|
except HTTPException: |
|
raise |
|
except Exception as e: |
|
logger.error(f"Error in age & gender prediction: {str(e)}") |
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
@app.post("/predict_nationality") |
|
async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]: |
|
endpoint_start = time.time() |
|
logger.info(f"Nationality prediction requested for file: {file.filename}") |
|
|
|
if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None: |
|
logger.error("Nationality model not loaded - returning 500 error") |
|
raise HTTPException(status_code=500, detail="Nationality model not loaded") |
|
|
|
try: |
|
processed_audio, processed_sr, was_clipped = await process_audio_file(file) |
|
|
|
|
|
prediction_start = time.time() |
|
logger.info("Starting nationality prediction...") |
|
predictions = nationality_model.predict(processed_audio, processed_sr) |
|
prediction_end = time.time() |
|
|
|
logger.info(f"Nationality prediction completed in {prediction_end - prediction_start:.3f} seconds") |
|
logger.info(f"Predicted language: {predictions['predicted_language']} (confidence: {predictions['confidence']:.3f})") |
|
logger.info(f"Top 3 languages: {[lang['language_code'] for lang in predictions['top_languages'][:3]]}") |
|
|
|
endpoint_end = time.time() |
|
logger.info(f"Total nationality endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds") |
|
|
|
response = { |
|
"success": True, |
|
"predictions": predictions, |
|
"processing_time": round(endpoint_end - endpoint_start, 3), |
|
"audio_info": { |
|
"was_clipped": was_clipped, |
|
"max_duration_seconds": MAX_DURATION_SECONDS |
|
} |
|
} |
|
|
|
if was_clipped: |
|
response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis." |
|
|
|
return response |
|
|
|
except HTTPException: |
|
raise |
|
except Exception as e: |
|
logger.error(f"Error in nationality prediction: {str(e)}") |
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
@app.post("/predict_all") |
|
async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]: |
|
endpoint_start = time.time() |
|
logger.info(f"Complete analysis requested for file: {file.filename}") |
|
|
|
if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None: |
|
logger.error("Age & gender model not loaded - returning 500 error") |
|
raise HTTPException(status_code=500, detail="Age & gender model not loaded") |
|
|
|
if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None: |
|
logger.error("Nationality model not loaded - returning 500 error") |
|
raise HTTPException(status_code=500, detail="Nationality model not loaded") |
|
|
|
try: |
|
processed_audio, processed_sr, was_clipped = await process_audio_file(file) |
|
|
|
|
|
age_prediction_start = time.time() |
|
logger.info("Starting age & gender prediction for complete analysis...") |
|
age_gender_predictions = age_gender_model.predict(processed_audio, processed_sr) |
|
age_prediction_end = time.time() |
|
logger.info(f"Age & gender prediction completed in {age_prediction_end - age_prediction_start:.3f} seconds") |
|
|
|
|
|
nationality_prediction_start = time.time() |
|
logger.info("Starting nationality prediction for complete analysis...") |
|
nationality_predictions = nationality_model.predict(processed_audio, processed_sr) |
|
nationality_prediction_end = time.time() |
|
logger.info(f"Nationality prediction completed in {nationality_prediction_end - nationality_prediction_start:.3f} seconds") |
|
|
|
|
|
logger.info(f"Complete analysis results:") |
|
logger.info(f" - Age: {age_gender_predictions['age']['predicted_age']:.1f} years") |
|
logger.info(f" - Gender: {age_gender_predictions['gender']['predicted_gender']} (confidence: {age_gender_predictions['gender']['confidence']:.3f})") |
|
logger.info(f" - Language: {nationality_predictions['predicted_language']} (confidence: {nationality_predictions['confidence']:.3f})") |
|
|
|
total_prediction_time = (age_prediction_end - age_prediction_start) + (nationality_prediction_end - nationality_prediction_start) |
|
endpoint_end = time.time() |
|
|
|
logger.info(f"Total prediction time: {total_prediction_time:.3f} seconds") |
|
logger.info(f"Total complete analysis endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds") |
|
|
|
response = { |
|
"success": True, |
|
"predictions": { |
|
"demographics": age_gender_predictions, |
|
"nationality": nationality_predictions |
|
}, |
|
"processing_time": { |
|
"total": round(endpoint_end - endpoint_start, 3), |
|
"age_gender": round(age_prediction_end - age_prediction_start, 3), |
|
"nationality": round(nationality_prediction_end - nationality_prediction_start, 3) |
|
}, |
|
"audio_info": { |
|
"was_clipped": was_clipped, |
|
"max_duration_seconds": MAX_DURATION_SECONDS |
|
} |
|
} |
|
|
|
if was_clipped: |
|
response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis." |
|
|
|
return response |
|
|
|
except HTTPException: |
|
raise |
|
except Exception as e: |
|
logger.error(f"Error in complete analysis: {str(e)}") |
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
if __name__ == "__main__": |
|
import uvicorn |
|
port = int(os.environ.get("PORT", 7860)) |
|
logger.info(f"Starting server on port {port}") |
|
uvicorn.run( |
|
"app:app", |
|
host="0.0.0.0", |
|
port=port, |
|
reload=False, |
|
log_level="info" |
|
) |