dtrovato997 committed
Commit 5277669 · 1 Parent(s): ead33a6

fix: clip audio to max 2 mins

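At its core, the change is a sample-count cut: keep only the first MAX_DURATION_SECONDS * sr samples and report whether anything was dropped. A minimal, self-contained sketch of that rule (the 150-second buffer is a hypothetical input, not from the repo):

import numpy as np

MAX_DURATION_SECONDS = 120  # 2-minute cap, as introduced in this commit
sr = 16000                  # service sampling rate
audio_data = np.zeros(sr * 150, dtype=np.float32)  # hypothetical 150 s of audio

max_samples = int(MAX_DURATION_SECONDS * sr)
was_clipped = len(audio_data) > max_samples
audio_data = audio_data[:max_samples]

print(was_clipped, len(audio_data) / sr)  # True 120.0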
main.py CHANGED
@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
 UPLOAD_FOLDER = 'uploads'
 ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a'}
 SAMPLING_RATE = 16000
+MAX_DURATION_SECONDS = 120  # 2 minutes maximum
 
 os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
@@ -31,6 +32,23 @@ def allowed_file(filename: str) -> bool:
     return '.' in filename and \
            filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
 
+def clip_audio_to_max_duration(audio_data: np.ndarray, sr: int, max_duration: int = MAX_DURATION_SECONDS) -> tuple[np.ndarray, bool]:
+    current_duration = len(audio_data) / sr
+
+    if current_duration <= max_duration:
+        logger.info(f"Audio duration ({current_duration:.2f}s) is within limit ({max_duration}s) - no clipping needed")
+        return audio_data, False
+
+    # Calculate how many samples we need for the max duration
+    max_samples = int(max_duration * sr)
+
+    # Clip to first max_duration seconds
+    clipped_audio = audio_data[:max_samples]
+
+    logger.info(f"Audio clipped from {current_duration:.2f}s to {max_duration}s ({len(audio_data)} samples → {len(clipped_audio)} samples)")
+
+    return clipped_audio, True
+
 async def load_models() -> bool:
     global age_gender_model, nationality_model
 
@@ -96,10 +114,11 @@ app = FastAPI(
     lifespan=lifespan
 )
 
-def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
+def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int, bool]:
     preprocess_start = time.time()
     original_shape = audio_data.shape
-    logger.info(f"Starting audio preprocessing - Original shape: {original_shape}, Sample rate: {sr}Hz")
+    original_duration = len(audio_data) / sr
+    logger.info(f"Starting audio preprocessing - Sample rate: {sr}Hz, Duration: {original_duration:.2f}s")
 
     # Convert to mono if stereo
     if len(audio_data.shape) > 1:
@@ -115,20 +134,24 @@ def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
         audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLING_RATE)
         resample_end = time.time()
         logger.info(f"Resampling completed in {resample_end - resample_start:.3f} seconds")
+        sr = SAMPLING_RATE
     else:
         logger.info(f"No resampling needed - already at {SAMPLING_RATE}Hz")
 
+    # Clip audio to maximum duration if needed
+    audio_data, was_clipped = clip_audio_to_max_duration(audio_data, sr)
+
     # Convert to float32
     audio_data = audio_data.astype(np.float32)
 
     preprocess_end = time.time()
-    duration_seconds = len(audio_data) / SAMPLING_RATE
+    final_duration_seconds = len(audio_data) / sr
     logger.info(f"Audio preprocessing completed in {preprocess_end - preprocess_start:.3f} seconds")
-    logger.info(f"Final audio: {audio_data.shape} samples, {duration_seconds:.2f} seconds duration")
+    logger.info(f"Final audio: {audio_data.shape} samples, {final_duration_seconds:.2f} seconds duration")
 
-    return audio_data, SAMPLING_RATE
+    return audio_data, sr, was_clipped
 
-async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
+async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int, bool]:
     process_start = time.time()
     logger.info(f"Processing uploaded file: {file.filename}")
 
@@ -165,12 +188,12 @@ async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
         load_end = time.time()
         logger.info(f"Audio loaded in {load_end - load_start:.3f} seconds")
 
-        processed_audio, processed_sr = preprocess_audio(audio_data, sr)
+        processed_audio, processed_sr, was_clipped = preprocess_audio(audio_data, sr)
 
         process_end = time.time()
         logger.info(f"Total file processing completed in {process_end - process_start:.3f} seconds")
 
-        return processed_audio, processed_sr
+        return processed_audio, processed_sr, was_clipped
 
     except Exception as e:
         logger.error(f"Error processing audio file {file.filename}: {str(e)}")
@@ -186,6 +209,7 @@ async def root() -> Dict[str, Any]:
     logger.info("Root endpoint accessed")
     return {
         "message": "Audio Analysis API - Age, Gender & Nationality Prediction",
+        "max_audio_duration": f"{MAX_DURATION_SECONDS} seconds (files longer than this will be automatically clipped)",
         "models_loaded": {
             "age_gender": age_gender_model is not None and hasattr(age_gender_model, 'model') and age_gender_model.model is not None,
             "nationality": nationality_model is not None and hasattr(nationality_model, 'model') and nationality_model.model is not None
@@ -206,7 +230,6 @@ async def health_check() -> Dict[str, str]:
 
 @app.post("/predict_age_and_gender")
 async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
-    """Predict age and gender from uploaded audio file."""
     endpoint_start = time.time()
     logger.info(f"Age & Gender prediction requested for file: {file.filename}")
 
@@ -215,7 +238,7 @@ async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]
         raise HTTPException(status_code=500, detail="Age & gender model not loaded")
 
     try:
-        processed_audio, processed_sr = await process_audio_file(file)
+        processed_audio, processed_sr, was_clipped = await process_audio_file(file)
 
         # Make prediction
         prediction_start = time.time()
@@ -230,12 +253,21 @@ async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]
         endpoint_end = time.time()
         logger.info(f"Total age & gender endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
 
-        return {
+        response = {
             "success": True,
             "predictions": predictions,
-            "processing_time": round(endpoint_end - endpoint_start, 3)
+            "processing_time": round(endpoint_end - endpoint_start, 3),
+            "audio_info": {
+                "was_clipped": was_clipped,
+                "max_duration_seconds": MAX_DURATION_SECONDS
+            }
         }
 
+        if was_clipped:
+            response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis."
+
+        return response
+
     except HTTPException:
         raise
     except Exception as e:
@@ -244,7 +276,6 @@ async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]
 
 @app.post("/predict_nationality")
 async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
-    """Predict nationality/language from uploaded audio file."""
     endpoint_start = time.time()
     logger.info(f"Nationality prediction requested for file: {file.filename}")
 
@@ -253,7 +284,7 @@ async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
         raise HTTPException(status_code=500, detail="Nationality model not loaded")
 
     try:
-        processed_audio, processed_sr = await process_audio_file(file)
+        processed_audio, processed_sr, was_clipped = await process_audio_file(file)
 
         # Make prediction
         prediction_start = time.time()
@@ -268,12 +299,21 @@ async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
         endpoint_end = time.time()
         logger.info(f"Total nationality endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
 
-        return {
+        response = {
             "success": True,
             "predictions": predictions,
-            "processing_time": round(endpoint_end - endpoint_start, 3)
+            "processing_time": round(endpoint_end - endpoint_start, 3),
+            "audio_info": {
+                "was_clipped": was_clipped,
+                "max_duration_seconds": MAX_DURATION_SECONDS
+            }
         }
 
+        if was_clipped:
+            response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis."
+
+        return response
+
     except HTTPException:
         raise
     except Exception as e:
@@ -282,7 +322,6 @@ async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
 
 @app.post("/predict_all")
 async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
-    """Predict age, gender, and nationality from uploaded audio file."""
     endpoint_start = time.time()
     logger.info(f"Complete analysis requested for file: {file.filename}")
 
@@ -295,7 +334,7 @@ async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
         raise HTTPException(status_code=500, detail="Nationality model not loaded")
 
     try:
-        processed_audio, processed_sr = await process_audio_file(file)
+        processed_audio, processed_sr, was_clipped = await process_audio_file(file)
 
         # Get age & gender predictions
         age_prediction_start = time.time()
@@ -323,7 +362,7 @@ async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
         logger.info(f"Total prediction time: {total_prediction_time:.3f} seconds")
         logger.info(f"Total complete analysis endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
 
-        return {
+        response = {
             "success": True,
             "predictions": {
                 "demographics": age_gender_predictions,
@@ -333,9 +372,18 @@ async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
                 "total": round(endpoint_end - endpoint_start, 3),
                 "age_gender": round(age_prediction_end - age_prediction_start, 3),
                 "nationality": round(nationality_prediction_end - nationality_prediction_start, 3)
+            },
+            "audio_info": {
+                "was_clipped": was_clipped,
+                "max_duration_seconds": MAX_DURATION_SECONDS
            }
         }
 
+        if was_clipped:
+            response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis."
+
+        return response
+
     except HTTPException:
         raise
     except Exception as e:
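With these changes, a successful prediction response for an over-length upload would look roughly like the dict below. This is a sketch: the predictions payload and timing values are illustrative stand-ins, not actual model output.

response = {
    "success": True,
    "predictions": {"age": 34, "gender": "female"},  # hypothetical model output
    "processing_time": 1.234,
    "audio_info": {
        "was_clipped": True,
        "max_duration_seconds": 120
    },
    "warning": "Audio was longer than 120 seconds and was automatically clipped to the first 120 seconds for analysis."
}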
models/age_and_gender_model.py CHANGED
@@ -7,7 +7,6 @@ import librosa
 
 class AgeGenderModel:
     def __init__(self, model_path=None):
-        # Use persistent storage if available, fallback to local cache
         if model_path is None:
             if os.path.exists("/data"):
                 # HF Spaces persistent storage
@@ -34,7 +33,7 @@ class AgeGenderModel:
         print("Age & gender model files not found. Downloading...")
 
         try:
-            # Use /data for cache if available, otherwise use local cache
+            # Use /data for cache if available, otherwise use local cache; this is in line with HF Spaces persistent storage
             if os.path.exists("/data"):
                 cache_root = '/data/cache'
             else:
@@ -72,16 +71,13 @@ class AgeGenderModel:
 
     def load(self):
         try:
-            # Download model if needed
             if not self.download_model():
                 print("Failed to download age & gender model")
                 return False
 
-            # Load the audonnx model
             print(f"Loading age & gender model from {self.model_path}...")
             self.model = audonnx.load(self.model_path)
 
-            # Create the audinterface Feature interface
             outputs = ['logits_age', 'logits_gender']
             self.interface = audinterface.Feature(
                 self.model.labels(outputs),
@@ -91,7 +87,7 @@ class AgeGenderModel:
                     'concat': True,
                 },
                 sampling_rate=self.sampling_rate,
-                resample=False,  # We handle resampling manually
+                resample=False,
                 verbose=False,
             )
             print("Age & gender model loaded successfully!")
@@ -105,7 +101,7 @@ class AgeGenderModel:
         if self.model is None or self.interface is None:
             raise ValueError("Model not loaded. Call load() first.")
 
-        try:  # Process with the interface
+        try:
             result = self.interface.process_signal(audio_data, sr)
 
             # Extract and process results
models/nationality_model.py CHANGED
@@ -9,7 +9,6 @@ SAMPLING_RATE = 16000
 
 class NationalityModel:
     def __init__(self, cache_dir=None):
-        # Use persistent storage if available, fallback to local cache
         if cache_dir is None:
             if os.path.exists("/data"):
                 # HF Spaces persistent storage
@@ -48,16 +47,13 @@ class NationalityModel:
             raise ValueError("Model not loaded. Call load() first.")
 
         try:
-            # Ensure audio is properly formatted (float32, mono)
             if len(audio_data.shape) > 1:
                 audio_data = audio_data.mean(axis=0)
 
             audio_data = audio_data.astype(np.float32)
 
-            # Process audio with the feature extractor
             inputs = self.processor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
 
-            # Get model predictions
             with torch.no_grad():
                 outputs = self.model(**inputs).logits
 
@@ -65,7 +61,6 @@ class NationalityModel:
             probabilities = torch.nn.functional.softmax(outputs, dim=-1)[0]
             top_k_values, top_k_indices = torch.topk(probabilities, k=5)
 
-            # Convert to language codes and probabilities
             top_languages = []
             for i, idx in enumerate(top_k_indices):
                 lang_id = idx.item()
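For reference, the top-5 extraction in the context lines above reduces to a softmax followed by torch.topk. A runnable sketch (the 107-class logits tensor is a hypothetical stand-in for the real language-ID model output):

import torch

logits = torch.randn(1, 107)  # hypothetical logits for 107 language classes
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
top_k_values, top_k_indices = torch.topk(probabilities, k=5)
for value, idx in zip(top_k_values, top_k_indices):
    print(f"language id {idx.item()}: probability {value.item():.3f}")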