from fastapi import FastAPI, File, UploadFile, HTTPException
import os
import numpy as np
import librosa
from typing import Dict, Any
import logging
from contextlib import asynccontextmanager
from models.nationality_model import NationalityModel
from models.age_and_gender_model import AgeGenderModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a'}
SAMPLING_RATE = 16000

os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Global model variables
age_gender_model = None
nationality_model = None

def allowed_file(filename: str) -> bool:
    """Return True if the filename has an allowed audio extension."""
    return '.' in filename and \
           filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

async def load_models() -> bool:
    """Load the age/gender and nationality models. Returns True only if both load."""
    global age_gender_model, nationality_model
    
    try:
        # Load age & gender model
        logger.info("Loading age & gender model...")
        age_gender_model = AgeGenderModel()
        age_gender_success = age_gender_model.load()
        
        if not age_gender_success:
            logger.error("Failed to load age & gender model")
            return False
        
        # Load nationality model
        logger.info("Loading nationality model...")
        nationality_model = NationalityModel()
        nationality_success = nationality_model.load()
        
        if not nationality_success:
            logger.error("Failed to load nationality model")
            return False
            
        logger.info("All models loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        return False

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    logger.info("Starting FastAPI application...")
    success = await load_models()
    if not success:
        logger.error("Failed to load models. Application will not work properly.")
    
    yield
    
    # Shutdown
    logger.info("Shutting down FastAPI application...")

# Create FastAPI app with lifespan events
app = FastAPI(
    title="Audio Analysis API",
    description="Audio analysis for age, gender, and nationality prediction",
    version="1.0.0",
    lifespan=lifespan
)

def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
    """Convert audio to mono, resample to SAMPLING_RATE, and cast to float32."""
    # Down-mix multi-channel audio to a single channel
    if len(audio_data.shape) > 1:
        audio_data = librosa.to_mono(audio_data)
        
    if sr != SAMPLING_RATE:
        logger.info(f"Resampling from {sr}Hz to {SAMPLING_RATE}Hz")
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLING_RATE)
        
    audio_data = audio_data.astype(np.float32)
        
    return audio_data, SAMPLING_RATE

async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
    """Validate, temporarily save, load, and preprocess an uploaded audio file."""
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file selected")
    
    if not allowed_file(file.filename):
        raise HTTPException(status_code=400, detail="Invalid file type. Allowed: wav, mp3, flac, m4a")
    
    # Sanitize the client-supplied name to prevent path traversal
    filename = f"temp_{os.path.basename(file.filename)}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    
    try:
        # Save uploaded file temporarily
        with open(filepath, "wb") as buffer:
            content = await file.read()
            buffer.write(content)
        
        # Load and preprocess audio
        audio_data, sr = librosa.load(filepath, sr=None)
        processed_audio, processed_sr = preprocess_audio(audio_data, sr)
        
        return processed_audio, processed_sr
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
    finally:
        # Clean up temporary file
        if os.path.exists(filepath):
            os.remove(filepath)

@app.get("/")
async def root() -> Dict[str, Any]:
    """Return API metadata, model load status, and available endpoints."""
    return {
        "message": "Audio Analysis API - Age, Gender & Nationality Prediction",
        "models_loaded": {
            "age_gender": age_gender_model is not None and hasattr(age_gender_model, 'model') and age_gender_model.model is not None,
            "nationality": nationality_model is not None and hasattr(nationality_model, 'model') and nationality_model.model is not None
        },
        "endpoints": {
            "/predict_age_and_gender": "POST - Upload audio file for age and gender prediction",
            "/predict_nationality": "POST - Upload audio file for nationality prediction",
            "/predict_all": "POST - Upload audio file for complete analysis (age, gender, nationality)",
        },
        "docs": "/docs - Interactive API documentation",
        "openapi": "/openapi.json - OpenAPI schema"
    }

@app.get("/health")
async def health_check() -> Dict[str, str]:
    """Liveness probe."""
    return {"status": "healthy"}

@app.post("/predict_age_and_gender")
async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Predict age and gender from uploaded audio file."""
    if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None:
        raise HTTPException(status_code=503, detail="Age & gender model not loaded")
    
    try:
        processed_audio, processed_sr = await process_audio_file(file)
        predictions = age_gender_model.predict(processed_audio, processed_sr)
        
        return {
            "success": True,
            "predictions": predictions
        }
        
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/predict_nationality")
async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Predict nationality/language from uploaded audio file."""
    if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None:
        raise HTTPException(status_code=503, detail="Nationality model not loaded")
    
    try:
        processed_audio, processed_sr = await process_audio_file(file)
        predictions = nationality_model.predict(processed_audio, processed_sr)
        
        return {
            "success": True,
            "predictions": predictions
        }
        
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/predict_all")
async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Predict age, gender, and nationality from uploaded audio file."""
    if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None:
        raise HTTPException(status_code=503, detail="Age & gender model not loaded")
    
    if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None:
        raise HTTPException(status_code=503, detail="Nationality model not loaded")
    
    try:
        processed_audio, processed_sr = await process_audio_file(file)
        
        # Get both predictions
        age_gender_predictions = age_gender_model.predict(processed_audio, processed_sr)
        nationality_predictions = nationality_model.predict(processed_audio, processed_sr)
        
        return {
            "success": True,
            "predictions": {
                "demographics": age_gender_predictions,
                "nationality": nationality_predictions
            }
        }
        
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        reload=False,  # Set to True for development
        log_level="info"
    )