dtrovato997 committed
Commit d7e7912 · Parent(s): 7c86e3c

Initial commit
.gitignore ADDED
@@ -0,0 +1,97 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # Virtual environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Model cache directories (these will be large!)
+ cache/
+ models/cache/
+ *.pt
+ *.pth
+ *.bin
+ *.onnx
+ *.h5
+ *.pkl
+ *.joblib
+
+ # Audio uploads (temporary files)
+ uploads/
+ *.wav
+ *.mp3
+ *.flac
+ *.m4a
+ *.ogg
+
+ # IDE/Editor files
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
+ logs/
+
+ # Environment variables
+ .env.local
+ .env.development.local
+ .env.test.local
+ .env.production.local
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Flask specific
+ instance/
+ .webassets-cache
+
+ # Coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Temporary files
+ *.tmp
+ *.temp
+ .tmp/
+ .temp/
Dockerfile ADDED
@@ -0,0 +1,29 @@
+ FROM python:3.9-slim
+
+ # Install system dependencies for audio processing
+ RUN apt-get update && apt-get install -y \
+     ffmpeg \
+     libsndfile1 \
+     gcc \
+     g++ \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy requirements and install Python dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Create directories
+ RUN mkdir -p uploads cache
+
+ # Expose port 7860 (HF Spaces default)
+ EXPOSE 7860
+
+ # Start the FastAPI app
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,220 @@
+ from fastapi import FastAPI, File, UploadFile, HTTPException
+ from fastapi.responses import JSONResponse
+ import os
+ import numpy as np
+ import librosa
+ from typing import Dict, Any
+ import logging
+ from contextlib import asynccontextmanager
+ from models.nationality_model import NationalityModel
+ from models.age_and_gender_model import AgeGenderModel
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ UPLOAD_FOLDER = 'uploads'
+ ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a'}
+ SAMPLING_RATE = 16000
+
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+
+ # Global model variables
+ age_gender_model = None
+ nationality_model = None
+
+ def allowed_file(filename: str) -> bool:
+     return '.' in filename and \
+         filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+ async def load_models() -> bool:
+     global age_gender_model, nationality_model
+
+     try:
+         # Load age & gender model
+         logger.info("Loading age & gender model...")
+         age_gender_model = AgeGenderModel()
+         age_gender_success = age_gender_model.load()
+
+         if not age_gender_success:
+             logger.error("Failed to load age & gender model")
+             return False
+
+         # Load nationality model
+         logger.info("Loading nationality model...")
+         nationality_model = NationalityModel()
+         nationality_success = nationality_model.load()
+
+         if not nationality_success:
+             logger.error("Failed to load nationality model")
+             return False
+
+         logger.info("All models loaded successfully!")
+         return True
+     except Exception as e:
+         logger.error(f"Error loading models: {e}")
+         return False
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Startup
+     logger.info("Starting FastAPI application...")
+     success = await load_models()
+     if not success:
+         logger.error("Failed to load models. Application will not work properly.")
+
+     yield
+
+     # Shutdown
+     logger.info("Shutting down FastAPI application...")
+
+ # Create FastAPI app with lifespan events
+ app = FastAPI(
+     title="Audio Analysis API",
+     description="Audio analysis for age, gender, and nationality prediction",
+     version="1.0.0",
+     lifespan=lifespan
+ )
+
+ def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
+     if len(audio_data.shape) > 1:
+         audio_data = librosa.to_mono(audio_data)
+
+     if sr != SAMPLING_RATE:
+         logger.info(f"Resampling from {sr}Hz to {SAMPLING_RATE}Hz")
+         audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLING_RATE)
+
+     audio_data = audio_data.astype(np.float32)
+
+     return audio_data, SAMPLING_RATE
+
+ async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
+     if not file.filename:
+         raise HTTPException(status_code=400, detail="No file selected")
+
+     if not allowed_file(file.filename):
+         raise HTTPException(status_code=400, detail="Invalid file type. Allowed: wav, mp3, flac, m4a")
+
+     # Create a secure filename
+     filename = f"temp_{file.filename}"
+     filepath = os.path.join(UPLOAD_FOLDER, filename)
+
+     try:
+         # Save uploaded file temporarily
+         with open(filepath, "wb") as buffer:
+             content = await file.read()
+             buffer.write(content)
+
+         # Load and preprocess audio
+         audio_data, sr = librosa.load(filepath, sr=None)
+         processed_audio, processed_sr = preprocess_audio(audio_data, sr)
+
+         return processed_audio, processed_sr
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
+     finally:
+         # Clean up temporary file
+         if os.path.exists(filepath):
+             os.remove(filepath)
+
+ @app.get("/")
+ async def root() -> Dict[str, Any]:
+     return {
+         "message": "Audio Analysis API - Age, Gender & Nationality Prediction",
+         "models_loaded": {
+             "age_gender": age_gender_model is not None and hasattr(age_gender_model, 'model') and age_gender_model.model is not None,
+             "nationality": nationality_model is not None and hasattr(nationality_model, 'model') and nationality_model.model is not None
+         },
+         "endpoints": {
+             "/predict_age_and_gender": "POST - Upload audio file for age and gender prediction",
+             "/predict_nationality": "POST - Upload audio file for nationality prediction",
+             "/predict_all": "POST - Upload audio file for complete analysis (age, gender, nationality)",
+         },
+         "docs": "/docs - Interactive API documentation",
+         "openapi": "/openapi.json - OpenAPI schema"
+     }
+
+ @app.get("/health")
+ async def health_check() -> Dict[str, str]:
+     return {"status": "healthy"}
+
+ @app.post("/predict_age_and_gender")
+ async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
+     """Predict age and gender from uploaded audio file."""
+     if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None:
+         raise HTTPException(status_code=500, detail="Age & gender model not loaded")
+
+     try:
+         processed_audio, processed_sr = await process_audio_file(file)
+         predictions = age_gender_model.predict(processed_audio, processed_sr)
+
+         return {
+             "success": True,
+             "predictions": predictions
+         }
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/predict_nationality")
+ async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
+     """Predict nationality/language from uploaded audio file."""
+     if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None:
+         raise HTTPException(status_code=500, detail="Nationality model not loaded")
+
+     try:
+         processed_audio, processed_sr = await process_audio_file(file)
+         predictions = nationality_model.predict(processed_audio, processed_sr)
+
+         return {
+             "success": True,
+             "predictions": predictions
+         }
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/predict_all")
+ async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
+     """Run age, gender, and nationality prediction on a single uploaded audio file."""
+     if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None:
+         raise HTTPException(status_code=500, detail="Age & gender model not loaded")
+
+     if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None:
+         raise HTTPException(status_code=500, detail="Nationality model not loaded")
+
+     try:
+         processed_audio, processed_sr = await process_audio_file(file)
+
+         # Get both predictions
+         age_gender_predictions = age_gender_model.predict(processed_audio, processed_sr)
+         nationality_predictions = nationality_model.predict(processed_audio, processed_sr)
+
+         return {
+             "success": True,
+             "predictions": {
+                 "demographics": age_gender_predictions,
+                 "nationality": nationality_predictions
+             }
+         }
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ if __name__ == "__main__":
+     import uvicorn
+     port = int(os.environ.get("PORT", 7860))
+     uvicorn.run(
+         "app:app",
+         host="0.0.0.0",
+         port=port,
+         reload=False,  # Set to True for development
+         log_level="info"
+     )
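
For reference, a minimal client sketch for the endpoints defined above. It assumes the app is running locally on port 7860 and that a file named sample.wav sits next to the script; the base URL, the file name, and the use of the requests package are assumptions, not part of this commit.

# hypothetical client, not part of this commit
import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment (HF Spaces expose the same port)

with open("sample.wav", "rb") as f:  # placeholder file name
    resp = requests.post(f"{BASE_URL}/predict_all",
                         files={"file": ("sample.wav", f, "audio/wav")})

resp.raise_for_status()
result = resp.json()
print(result["predictions"]["demographics"])  # output of AgeGenderModel.predict()
print(result["predictions"]["nationality"])   # output of NationalityModel.predict()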
models/age_and_gender_model.py ADDED
@@ -0,0 +1,128 @@
+ import os
+ import numpy as np
+ import audeer
+ import audonnx
+ import audinterface
+ import librosa
+
+ class AgeGenderModel:
+     def __init__(self, model_path="./cache/age_and_gender"):
+         self.model_path = model_path
+         self.model = None
+         self.interface = None
+         self.sampling_rate = 16000
+         os.makedirs(model_path, exist_ok=True)
+
+     def download_model(self):
+         model_onnx = os.path.join(self.model_path, 'model.onnx')
+         model_yaml = os.path.join(self.model_path, 'model.yaml')
+
+         if os.path.exists(model_onnx) and os.path.exists(model_yaml):
+             print("Age & gender model files already exist, skipping download.")
+             return True
+
+         print("Age & gender model files not found. Downloading...")
+
+         try:
+             cache_root = 'cache'
+             audeer.mkdir(cache_root)
+             audeer.mkdir(self.model_path)
+
+             def cache_path(file):
+                 return os.path.join(cache_root, file)
+
+             url = 'https://zenodo.org/record/7761387/files/w2v2-L-robust-24-age-gender.728d5a4c-1.1.1.zip'
+             dst_path = cache_path('model.zip')
+
+             if not os.path.exists(dst_path):
+                 print(f"Downloading model from {url}...")
+                 audeer.download_url(url, dst_path, verbose=True)
+
+             print(f"Extracting model to {self.model_path}...")
+             audeer.extract_archive(dst_path, self.model_path, verbose=True)
+
+             if os.path.exists(model_onnx) and os.path.exists(model_yaml):
+                 print("Age & gender model downloaded and extracted successfully!")
+
+                 if os.path.exists(dst_path):
+                     os.remove(dst_path)
+                 return True
+             else:
+                 print("Age & gender model extraction failed, files not found after extraction")
+                 return False
+
+         except Exception as e:
+             print(f"Error downloading age & gender model: {e}")
+             return False
+
+     def load(self):
+         try:
+             # Download model if needed
+             if not self.download_model():
+                 print("Failed to download age & gender model")
+                 return False
+
+             # Load the audonnx model
+             print("Loading age & gender model...")
+             self.model = audonnx.load(self.model_path)
+
+             # Create the audinterface Feature interface
+             outputs = ['logits_age', 'logits_gender']
+             self.interface = audinterface.Feature(
+                 self.model.labels(outputs),
+                 process_func=self.model,
+                 process_func_args={
+                     'outputs': outputs,
+                     'concat': True,
+                 },
+                 sampling_rate=self.sampling_rate,
+                 resample=False,  # We handle resampling manually
+                 verbose=False,
+             )
+             print("Age & gender model loaded successfully!")
+             return True
+         except Exception as e:
+             print(f"Error loading age & gender model: {e}")
+             return False
+
+
+     def predict(self, audio_data, sr):
+         if self.model is None or self.interface is None:
+             raise ValueError("Model not loaded. Call load() first.")
+
+         try:  # Process with the interface
+             result = self.interface.process_signal(audio_data, sr)
+
+             # Extract and process results
+             age_score = result['age'].values[0]
+             gender_logits = {
+                 'female': result['female'].values[0],
+                 'male': result['male'].values[0],
+                 'child': result['child'].values[0]
+             }
+
+             predicted_age = age_score * 100
+             gender_values = np.array(list(gender_logits.values()))
+             gender_probs = np.exp(gender_values) / np.sum(np.exp(gender_values))
+
+             gender_labels = ['female', 'male', 'child']
+             gender_probabilities = {
+                 label: float(prob) for label, prob in zip(gender_labels, gender_probs)
+             }
+
+             # Find most likely gender
+             predicted_gender = gender_labels[np.argmax(gender_probs)]
+             max_probability = float(np.max(gender_probs))
+
+             return {
+                 'age': {
+                     'predicted_age': float(predicted_age)
+                 },
+                 'gender': {
+                     'predicted_gender': predicted_gender,
+                     'probabilities': gender_probabilities,
+                     'confidence': max_probability
+                 }
+             }
+         except Exception as e:
+             raise Exception(f"Age & gender prediction error: {str(e)}")
models/nationality_model.py ADDED
@@ -0,0 +1,73 @@
+ import os
+ import torch
+ import numpy as np
+ from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
+
+ # Constants
+ MODEL_ID = "facebook/mms-lid-256"
+ SAMPLING_RATE = 16000
+
+ class NationalityModel:
+     def __init__(self, cache_dir="./cache/nationality"):
+         self.processor = None
+         self.model = None
+         self.cache_dir = cache_dir
+         os.makedirs(cache_dir, exist_ok=True)
+
+     def load(self):
+         try:
+             print(f"Loading nationality prediction model from {MODEL_ID}...")
+             self.processor = AutoFeatureExtractor.from_pretrained(MODEL_ID, cache_dir=self.cache_dir)
+             self.model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=self.cache_dir)
+             print("Nationality prediction model loaded successfully!")
+             return True
+         except Exception as e:
+             print(f"Error loading nationality prediction model: {e}")
+             return False
+
+     def predict(self, audio_data, sampling_rate):
+         if self.model is None or self.processor is None:
+             raise ValueError("Model not loaded. Call load() first.")
+
+         try:
+             # Ensure audio is properly formatted (float32, mono)
+             if len(audio_data.shape) > 1:
+                 audio_data = audio_data.mean(axis=0)
+
+             audio_data = audio_data.astype(np.float32)
+
+             # Process audio with the feature extractor
+             inputs = self.processor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
+
+             # Get model predictions
+             with torch.no_grad():
+                 outputs = self.model(**inputs).logits
+
+             # Get top 5 predictions
+             probabilities = torch.nn.functional.softmax(outputs, dim=-1)[0]
+             top_k_values, top_k_indices = torch.topk(probabilities, k=5)
+
+             # Convert to language codes and probabilities
+             top_languages = []
+             for i, idx in enumerate(top_k_indices):
+                 lang_id = idx.item()
+                 lang_code = self.model.config.id2label[lang_id]
+                 probability = top_k_values[i].item()
+                 top_languages.append({
+                     "language_code": lang_code,
+                     "probability": probability
+                 })
+
+             # Get the most likely language
+             predicted_lang_id = torch.argmax(outputs, dim=-1)[0].item()
+             predicted_lang = self.model.config.id2label[predicted_lang_id]
+             max_probability = probabilities[predicted_lang_id].item()
+
+             return {
+                 "predicted_language": predicted_lang,
+                 "confidence": max_probability,
+                 "top_languages": top_languages
+             }
+
+         except Exception as e:
+             raise Exception(f"Nationality prediction error: {str(e)}")
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ fastapi[all]
+ uvicorn[standard]
+ python-multipart
+ audonnx
+ audinterface
+ librosa
+ numpy
+ audeer
+ torch
+ transformers
+ torchaudio
+ datasets