dtrovato997 committed on
Commit
825d7f4
·
1 Parent(s): 87728c7

improved logging and persistent storage on HF hub

Browse files
Dockerfile CHANGED
@@ -32,8 +32,12 @@ WORKDIR $HOME/app
32
  # Copy application code with proper ownership
33
  COPY --chown=user . $HOME/app
34
 
35
- # Create directories with proper permissions
36
- RUN mkdir -p $HOME/app/uploads $HOME/app/cache
 
 
 
 
37
 
38
  # Expose port 7860 (HF Spaces default)
39
  EXPOSE 7860
 
32
  # Copy application code with proper ownership
33
  COPY --chown=user . $HOME/app
34
 
35
+ # Create uploads directory in app folder (for temporary files)
36
+ RUN mkdir -p $HOME/app/uploads
37
+
38
+ # Create the local cache directory, used as a fallback when /data is not available
39
+ # (no symlink is created here; the model classes select /data at runtime when persistent storage is mounted)
40
+ RUN mkdir -p $HOME/app/cache
41
 
42
  # Expose port 7860 (HF Spaces default)
43
  EXPOSE 7860
main.py CHANGED
@@ -5,12 +5,16 @@ import numpy as np
5
  import librosa
6
  from typing import Dict, Any
7
  import logging
 
8
  from contextlib import asynccontextmanager
9
  from models.nationality_model import NationalityModel
10
  from models.age_and_gender_model import AgeGenderModel
11
 
12
- # Configure logging
13
- logging.basicConfig(level=logging.INFO)
 
 
 
14
  logger = logging.getLogger(__name__)
15
 
16
  UPLOAD_FOLDER = 'uploads'
@@ -31,25 +35,36 @@ async def load_models() -> bool:
31
  global age_gender_model, nationality_model
32
 
33
  try:
 
 
34
  # Load age & gender model
35
- logger.info("Loading age & gender model...")
 
36
  age_gender_model = AgeGenderModel()
37
  age_gender_success = age_gender_model.load()
 
38
 
39
  if not age_gender_success:
40
  logger.error("Failed to load age & gender model")
41
  return False
42
 
 
 
43
  # Load nationality model
44
- logger.info("Loading nationality model...")
 
45
  nationality_model = NationalityModel()
46
  nationality_success = nationality_model.load()
 
47
 
48
  if not nationality_success:
49
  logger.error("Failed to load nationality model")
50
  return False
51
 
52
- logger.info("All models loaded successfully!")
 
 
 
53
  return True
54
  except Exception as e:
55
  logger.error(f"Error loading models: {e}")
@@ -59,9 +74,14 @@ async def load_models() -> bool:
59
  async def lifespan(app: FastAPI):
60
  # Startup
61
  logger.info("Starting FastAPI application...")
 
62
  success = await load_models()
 
 
63
  if not success:
64
  logger.error("Failed to load models. Application will not work properly.")
 
 
65
 
66
  yield
67
 
@@ -77,49 +97,93 @@ app = FastAPI(
77
  )
78
 
79
  def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
 
 
 
 
 
80
  if len(audio_data.shape) > 1:
 
81
  audio_data = librosa.to_mono(audio_data)
82
-
 
 
 
83
  if sr != SAMPLING_RATE:
84
- logger.info(f"Resampling from {sr}Hz to {SAMPLING_RATE}Hz")
 
85
  audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLING_RATE)
86
-
 
 
 
 
 
87
  audio_data = audio_data.astype(np.float32)
88
-
 
 
 
 
 
89
  return audio_data, SAMPLING_RATE
90
 
91
  async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
 
 
 
92
  if not file.filename:
93
  raise HTTPException(status_code=400, detail="No file selected")
94
 
95
  if not allowed_file(file.filename):
 
96
  raise HTTPException(status_code=400, detail="Invalid file type. Allowed: wav, mp3, flac, m4a")
97
 
 
 
 
 
98
  # Create a secure filename
99
- filename = f"temp_{file.filename}"
100
  filepath = os.path.join(UPLOAD_FOLDER, filename)
101
 
102
  try:
103
  # Save uploaded file temporarily
 
104
  with open(filepath, "wb") as buffer:
105
  content = await file.read()
106
  buffer.write(content)
 
 
 
 
107
 
108
  # Load and preprocess audio
 
 
109
  audio_data, sr = librosa.load(filepath, sr=None)
 
 
 
110
  processed_audio, processed_sr = preprocess_audio(audio_data, sr)
111
 
 
 
 
112
  return processed_audio, processed_sr
113
 
114
  except Exception as e:
 
115
  raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
116
  finally:
117
  # Clean up temporary file
118
  if os.path.exists(filepath):
119
  os.remove(filepath)
 
120
 
121
  @app.get("/")
122
  async def root() -> Dict[str, Any]:
 
123
  return {
124
  "message": "Audio Analysis API - Age, Gender & Nationality Prediction",
125
  "models_loaded": {
@@ -137,79 +201,151 @@ async def root() -> Dict[str, Any]:
137
 
138
  @app.get("/health")
139
  async def health_check() -> Dict[str, str]:
 
140
  return {"status": "healthy"}
141
 
142
  @app.post("/predict_age_and_gender")
143
  async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
144
  """Predict age and gender from uploaded audio file."""
 
 
 
145
  if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None:
 
146
  raise HTTPException(status_code=500, detail="Age & gender model not loaded")
147
 
148
  try:
149
  processed_audio, processed_sr = await process_audio_file(file)
 
 
 
 
150
  predictions = age_gender_model.predict(processed_audio, processed_sr)
 
 
 
 
 
 
 
 
151
 
152
  return {
153
  "success": True,
154
- "predictions": predictions
 
155
  }
156
 
157
  except HTTPException:
158
  raise
159
  except Exception as e:
 
160
  raise HTTPException(status_code=500, detail=str(e))
161
 
162
  @app.post("/predict_nationality")
163
  async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
164
  """Predict nationality/language from uploaded audio file."""
 
 
 
165
  if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None:
 
166
  raise HTTPException(status_code=500, detail="Nationality model not loaded")
167
 
168
  try:
169
  processed_audio, processed_sr = await process_audio_file(file)
 
 
 
 
170
  predictions = nationality_model.predict(processed_audio, processed_sr)
 
 
 
 
 
 
 
 
171
 
172
  return {
173
  "success": True,
174
- "predictions": predictions
 
175
  }
176
 
177
  except HTTPException:
178
  raise
179
  except Exception as e:
 
180
  raise HTTPException(status_code=500, detail=str(e))
181
 
182
  @app.post("/predict_all")
183
  async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
 
 
 
 
184
  if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None:
 
185
  raise HTTPException(status_code=500, detail="Age & gender model not loaded")
186
 
187
  if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None:
 
188
  raise HTTPException(status_code=500, detail="Nationality model not loaded")
189
 
190
  try:
191
  processed_audio, processed_sr = await process_audio_file(file)
192
 
193
- # Get both predictions
 
 
194
  age_gender_predictions = age_gender_model.predict(processed_audio, processed_sr)
 
 
 
 
 
 
195
  nationality_predictions = nationality_model.predict(processed_audio, processed_sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  return {
198
  "success": True,
199
  "predictions": {
200
  "demographics": age_gender_predictions,
201
  "nationality": nationality_predictions
 
 
 
 
 
202
  }
203
  }
204
 
205
  except HTTPException:
206
  raise
207
  except Exception as e:
 
208
  raise HTTPException(status_code=500, detail=str(e))
209
 
210
  if __name__ == "__main__":
211
  import uvicorn
212
  port = int(os.environ.get("PORT", 7860))
 
213
  uvicorn.run(
214
  "app:app",
215
  host="0.0.0.0",
 
5
  import librosa
6
  from typing import Dict, Any
7
  import logging
8
+ import time
9
  from contextlib import asynccontextmanager
10
  from models.nationality_model import NationalityModel
11
  from models.age_and_gender_model import AgeGenderModel
12
 
13
+ # Configure logging with more detailed format
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
+ )
18
  logger = logging.getLogger(__name__)
19
 
20
  UPLOAD_FOLDER = 'uploads'
 
35
  global age_gender_model, nationality_model
36
 
37
  try:
38
+ total_start_time = time.time()
39
+
40
  # Load age & gender model
41
+ logger.info("Starting age & gender model loading...")
42
+ age_start = time.time()
43
  age_gender_model = AgeGenderModel()
44
  age_gender_success = age_gender_model.load()
45
+ age_end = time.time()
46
 
47
  if not age_gender_success:
48
  logger.error("Failed to load age & gender model")
49
  return False
50
 
51
+ logger.info(f"Age & gender model loaded successfully in {age_end - age_start:.2f} seconds")
52
+
53
  # Load nationality model
54
+ logger.info("Starting nationality model loading...")
55
+ nationality_start = time.time()
56
  nationality_model = NationalityModel()
57
  nationality_success = nationality_model.load()
58
+ nationality_end = time.time()
59
 
60
  if not nationality_success:
61
  logger.error("Failed to load nationality model")
62
  return False
63
 
64
+ logger.info(f"Nationality model loaded successfully in {nationality_end - nationality_start:.2f} seconds")
65
+
66
+ total_end = time.time()
67
+ logger.info(f"All models loaded successfully! Total time: {total_end - total_start_time:.2f} seconds")
68
  return True
69
  except Exception as e:
70
  logger.error(f"Error loading models: {e}")
 
74
  async def lifespan(app: FastAPI):
75
  # Startup
76
  logger.info("Starting FastAPI application...")
77
+ startup_start = time.time()
78
  success = await load_models()
79
+ startup_end = time.time()
80
+
81
  if not success:
82
  logger.error("Failed to load models. Application will not work properly.")
83
+ else:
84
+ logger.info(f"FastAPI application started successfully in {startup_end - startup_start:.2f} seconds")
85
 
86
  yield
87
 
 
97
  )
98
 
99
  def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
100
+ preprocess_start = time.time()
101
+ original_shape = audio_data.shape
102
+ logger.info(f"Starting audio preprocessing - Original shape: {original_shape}, Sample rate: {sr}Hz")
103
+
104
+ # Convert to mono if stereo
105
  if len(audio_data.shape) > 1:
106
+ mono_start = time.time()
107
  audio_data = librosa.to_mono(audio_data)
108
+ mono_end = time.time()
109
+ logger.info(f"Converted stereo to mono in {mono_end - mono_start:.3f} seconds - New shape: {audio_data.shape}")
110
+
111
+ # Resample if needed
112
  if sr != SAMPLING_RATE:
113
+ resample_start = time.time()
114
+ logger.info(f"Resampling from {sr}Hz to {SAMPLING_RATE}Hz...")
115
  audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLING_RATE)
116
+ resample_end = time.time()
117
+ logger.info(f"Resampling completed in {resample_end - resample_start:.3f} seconds")
118
+ else:
119
+ logger.info(f"No resampling needed - already at {SAMPLING_RATE}Hz")
120
+
121
+ # Convert to float32
122
  audio_data = audio_data.astype(np.float32)
123
+
124
+ preprocess_end = time.time()
125
+ duration_seconds = len(audio_data) / SAMPLING_RATE
126
+ logger.info(f"Audio preprocessing completed in {preprocess_end - preprocess_start:.3f} seconds")
127
+ logger.info(f"Final audio: {audio_data.shape} samples, {duration_seconds:.2f} seconds duration")
128
+
129
  return audio_data, SAMPLING_RATE
130
 
131
  async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
132
+ process_start = time.time()
133
+ logger.info(f"Processing uploaded file: {file.filename}")
134
+
135
  if not file.filename:
136
  raise HTTPException(status_code=400, detail="No file selected")
137
 
138
  if not allowed_file(file.filename):
139
+ logger.warning(f"Invalid file type uploaded: {file.filename}")
140
  raise HTTPException(status_code=400, detail="Invalid file type. Allowed: wav, mp3, flac, m4a")
141
 
142
+ # Get file extension and log it
143
+ file_ext = file.filename.rsplit('.', 1)[1].lower()
144
+ logger.info(f"Processing {file_ext.upper()} file: {file.filename}")
145
+
146
  # Create a secure filename
147
+ filename = f"temp_{int(time.time())}_{file.filename}"
148
  filepath = os.path.join(UPLOAD_FOLDER, filename)
149
 
150
  try:
151
  # Save uploaded file temporarily
152
+ save_start = time.time()
153
  with open(filepath, "wb") as buffer:
154
  content = await file.read()
155
  buffer.write(content)
156
+ save_end = time.time()
157
+
158
+ file_size_mb = len(content) / (1024 * 1024)
159
+ logger.info(f"File saved ({file_size_mb:.2f} MB) in {save_end - save_start:.3f} seconds")
160
 
161
  # Load and preprocess audio
162
+ load_start = time.time()
163
+ logger.info(f"Loading audio from {filepath}...")
164
  audio_data, sr = librosa.load(filepath, sr=None)
165
+ load_end = time.time()
166
+ logger.info(f"Audio loaded in {load_end - load_start:.3f} seconds")
167
+
168
  processed_audio, processed_sr = preprocess_audio(audio_data, sr)
169
 
170
+ process_end = time.time()
171
+ logger.info(f"Total file processing completed in {process_end - process_start:.3f} seconds")
172
+
173
  return processed_audio, processed_sr
174
 
175
  except Exception as e:
176
+ logger.error(f"Error processing audio file {file.filename}: {str(e)}")
177
  raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
178
  finally:
179
  # Clean up temporary file
180
  if os.path.exists(filepath):
181
  os.remove(filepath)
182
+ logger.info(f"Temporary file {filename} cleaned up")
183
 
184
  @app.get("/")
185
  async def root() -> Dict[str, Any]:
186
+ logger.info("Root endpoint accessed")
187
  return {
188
  "message": "Audio Analysis API - Age, Gender & Nationality Prediction",
189
  "models_loaded": {
 
201
 
202
  @app.get("/health")
203
  async def health_check() -> Dict[str, str]:
204
+ logger.info("Health check endpoint accessed")
205
  return {"status": "healthy"}
206
 
207
  @app.post("/predict_age_and_gender")
208
  async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
209
  """Predict age and gender from uploaded audio file."""
210
+ endpoint_start = time.time()
211
+ logger.info(f"Age & Gender prediction requested for file: {file.filename}")
212
+
213
  if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None:
214
+ logger.error("Age & gender model not loaded - returning 500 error")
215
  raise HTTPException(status_code=500, detail="Age & gender model not loaded")
216
 
217
  try:
218
  processed_audio, processed_sr = await process_audio_file(file)
219
+
220
+ # Make prediction
221
+ prediction_start = time.time()
222
+ logger.info("Starting age & gender prediction...")
223
  predictions = age_gender_model.predict(processed_audio, processed_sr)
224
+ prediction_end = time.time()
225
+
226
+ logger.info(f"Age & gender prediction completed in {prediction_end - prediction_start:.3f} seconds")
227
+ logger.info(f"Predicted age: {predictions['age']['predicted_age']:.1f} years")
228
+ logger.info(f"Predicted gender: {predictions['gender']['predicted_gender']} (confidence: {predictions['gender']['confidence']:.3f})")
229
+
230
+ endpoint_end = time.time()
231
+ logger.info(f"Total age & gender endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
232
 
233
  return {
234
  "success": True,
235
+ "predictions": predictions,
236
+ "processing_time": round(endpoint_end - endpoint_start, 3)
237
  }
238
 
239
  except HTTPException:
240
  raise
241
  except Exception as e:
242
+ logger.error(f"Error in age & gender prediction: {str(e)}")
243
  raise HTTPException(status_code=500, detail=str(e))
244
 
245
  @app.post("/predict_nationality")
246
  async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
247
  """Predict nationality/language from uploaded audio file."""
248
+ endpoint_start = time.time()
249
+ logger.info(f"Nationality prediction requested for file: {file.filename}")
250
+
251
  if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None:
252
+ logger.error("Nationality model not loaded - returning 500 error")
253
  raise HTTPException(status_code=500, detail="Nationality model not loaded")
254
 
255
  try:
256
  processed_audio, processed_sr = await process_audio_file(file)
257
+
258
+ # Make prediction
259
+ prediction_start = time.time()
260
+ logger.info("Starting nationality prediction...")
261
  predictions = nationality_model.predict(processed_audio, processed_sr)
262
+ prediction_end = time.time()
263
+
264
+ logger.info(f"Nationality prediction completed in {prediction_end - prediction_start:.3f} seconds")
265
+ logger.info(f"Predicted language: {predictions['predicted_language']} (confidence: {predictions['confidence']:.3f})")
266
+ logger.info(f"Top 3 languages: {[lang['language_code'] for lang in predictions['top_languages'][:3]]}")
267
+
268
+ endpoint_end = time.time()
269
+ logger.info(f"Total nationality endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
270
 
271
  return {
272
  "success": True,
273
+ "predictions": predictions,
274
+ "processing_time": round(endpoint_end - endpoint_start, 3)
275
  }
276
 
277
  except HTTPException:
278
  raise
279
  except Exception as e:
280
+ logger.error(f"Error in nationality prediction: {str(e)}")
281
  raise HTTPException(status_code=500, detail=str(e))
282
 
283
  @app.post("/predict_all")
284
  async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
285
+ """Predict age, gender, and nationality from uploaded audio file."""
286
+ endpoint_start = time.time()
287
+ logger.info(f"Complete analysis requested for file: {file.filename}")
288
+
289
  if age_gender_model is None or not hasattr(age_gender_model, 'model') or age_gender_model.model is None:
290
+ logger.error("Age & gender model not loaded - returning 500 error")
291
  raise HTTPException(status_code=500, detail="Age & gender model not loaded")
292
 
293
  if nationality_model is None or not hasattr(nationality_model, 'model') or nationality_model.model is None:
294
+ logger.error("Nationality model not loaded - returning 500 error")
295
  raise HTTPException(status_code=500, detail="Nationality model not loaded")
296
 
297
  try:
298
  processed_audio, processed_sr = await process_audio_file(file)
299
 
300
+ # Get age & gender predictions
301
+ age_prediction_start = time.time()
302
+ logger.info("Starting age & gender prediction for complete analysis...")
303
  age_gender_predictions = age_gender_model.predict(processed_audio, processed_sr)
304
+ age_prediction_end = time.time()
305
+ logger.info(f"Age & gender prediction completed in {age_prediction_end - age_prediction_start:.3f} seconds")
306
+
307
+ # Get nationality predictions
308
+ nationality_prediction_start = time.time()
309
+ logger.info("Starting nationality prediction for complete analysis...")
310
  nationality_predictions = nationality_model.predict(processed_audio, processed_sr)
311
+ nationality_prediction_end = time.time()
312
+ logger.info(f"Nationality prediction completed in {nationality_prediction_end - nationality_prediction_start:.3f} seconds")
313
+
314
+ # Log combined results
315
+ logger.info(f"Complete analysis results:")
316
+ logger.info(f" - Age: {age_gender_predictions['age']['predicted_age']:.1f} years")
317
+ logger.info(f" - Gender: {age_gender_predictions['gender']['predicted_gender']} (confidence: {age_gender_predictions['gender']['confidence']:.3f})")
318
+ logger.info(f" - Language: {nationality_predictions['predicted_language']} (confidence: {nationality_predictions['confidence']:.3f})")
319
+
320
+ total_prediction_time = (age_prediction_end - age_prediction_start) + (nationality_prediction_end - nationality_prediction_start)
321
+ endpoint_end = time.time()
322
+
323
+ logger.info(f"Total prediction time: {total_prediction_time:.3f} seconds")
324
+ logger.info(f"Total complete analysis endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
325
 
326
  return {
327
  "success": True,
328
  "predictions": {
329
  "demographics": age_gender_predictions,
330
  "nationality": nationality_predictions
331
+ },
332
+ "processing_time": {
333
+ "total": round(endpoint_end - endpoint_start, 3),
334
+ "age_gender": round(age_prediction_end - age_prediction_start, 3),
335
+ "nationality": round(nationality_prediction_end - nationality_prediction_start, 3)
336
  }
337
  }
338
 
339
  except HTTPException:
340
  raise
341
  except Exception as e:
342
+ logger.error(f"Error in complete analysis: {str(e)}")
343
  raise HTTPException(status_code=500, detail=str(e))
344
 
345
  if __name__ == "__main__":
346
  import uvicorn
347
  port = int(os.environ.get("PORT", 7860))
348
+ logger.info(f"Starting server on port {port}")
349
  uvicorn.run(
350
  "app:app",
351
  host="0.0.0.0",
models/age_and_gender_model.py CHANGED
@@ -6,12 +6,22 @@ import audinterface
6
  import librosa
7
 
8
  class AgeGenderModel:
9
- def __init__(self, model_path="./cache/age_and_gender"):
10
- self.model_path = model_path
 
 
 
 
 
 
 
 
 
 
11
  self.model = None
12
  self.interface = None
13
  self.sampling_rate = 16000
14
- os.makedirs(model_path, exist_ok=True)
15
 
16
  def download_model(self):
17
  model_onnx = os.path.join(self.model_path, 'model.onnx')
@@ -24,7 +34,12 @@ class AgeGenderModel:
24
  print("Age & gender model files not found. Downloading...")
25
 
26
  try:
27
- cache_root = 'cache'
 
 
 
 
 
28
  audeer.mkdir(cache_root)
29
  audeer.mkdir(self.model_path)
30
 
@@ -63,7 +78,7 @@ class AgeGenderModel:
63
  return False
64
 
65
  # Load the audonnx model
66
- print("Loading age & gender model...")
67
  self.model = audonnx.load(self.model_path)
68
 
69
  # Create the audinterface Feature interface
 
6
  import librosa
7
 
8
  class AgeGenderModel:
9
+ def __init__(self, model_path=None):
10
+ # Use persistent storage if available, fallback to local cache
11
+ if model_path is None:
12
+ if os.path.exists("/data"):
13
+ # HF Spaces persistent storage
14
+ self.model_path = "/data/age_and_gender"
15
+ else:
16
+ # Local development or other platforms
17
+ self.model_path = "./cache/age_and_gender"
18
+ else:
19
+ self.model_path = model_path
20
+
21
  self.model = None
22
  self.interface = None
23
  self.sampling_rate = 16000
24
+ os.makedirs(self.model_path, exist_ok=True)
25
 
26
  def download_model(self):
27
  model_onnx = os.path.join(self.model_path, 'model.onnx')
 
34
  print("Age & gender model files not found. Downloading...")
35
 
36
  try:
37
+ # Use /data for cache if available, otherwise use local cache
38
+ if os.path.exists("/data"):
39
+ cache_root = '/data/cache'
40
+ else:
41
+ cache_root = 'cache'
42
+
43
  audeer.mkdir(cache_root)
44
  audeer.mkdir(self.model_path)
45
 
 
78
  return False
79
 
80
  # Load the audonnx model
81
+ print(f"Loading age & gender model from {self.model_path}...")
82
  self.model = audonnx.load(self.model_path)
83
 
84
  # Create the audinterface Feature interface
models/nationality_model.py CHANGED
@@ -8,17 +8,35 @@ MODEL_ID = "facebook/mms-lid-256"
8
  SAMPLING_RATE = 16000
9
 
10
  class NationalityModel:
11
- def __init__(self, cache_dir="./cache/nationality"):
 
 
 
 
 
 
 
 
 
 
 
12
  self.processor = None
13
  self.model = None
14
- self.cache_dir = cache_dir
15
- os.makedirs(cache_dir, exist_ok=True)
16
 
17
  def load(self):
18
  try:
19
  print(f"Loading nationality prediction model from {MODEL_ID}...")
20
- self.processor = AutoFeatureExtractor.from_pretrained(MODEL_ID, cache_dir=self.cache_dir)
21
- self.model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=self.cache_dir)
 
 
 
 
 
 
 
 
22
  print("Nationality prediction model loaded successfully!")
23
  return True
24
  except Exception as e:
@@ -70,4 +88,4 @@ class NationalityModel:
70
  }
71
 
72
  except Exception as e:
73
- raise Exception(f"Nationality prediction error: {str(e)}")
 
8
  SAMPLING_RATE = 16000
9
 
10
  class NationalityModel:
11
+ def __init__(self, cache_dir=None):
12
+ # Use persistent storage if available, fallback to local cache
13
+ if cache_dir is None:
14
+ if os.path.exists("/data"):
15
+ # HF Spaces persistent storage
16
+ self.cache_dir = "/data/nationality"
17
+ else:
18
+ # Local development or other platforms
19
+ self.cache_dir = "./cache/nationality"
20
+ else:
21
+ self.cache_dir = cache_dir
22
+
23
  self.processor = None
24
  self.model = None
25
+ os.makedirs(self.cache_dir, exist_ok=True)
 
26
 
27
  def load(self):
28
  try:
29
  print(f"Loading nationality prediction model from {MODEL_ID}...")
30
+ print(f"Using cache directory: {self.cache_dir}")
31
+
32
+ self.processor = AutoFeatureExtractor.from_pretrained(
33
+ MODEL_ID,
34
+ cache_dir=self.cache_dir
35
+ )
36
+ self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
37
+ MODEL_ID,
38
+ cache_dir=self.cache_dir
39
+ )
40
  print("Nationality prediction model loaded successfully!")
41
  return True
42
  except Exception as e:
 
88
  }
89
 
90
  except Exception as e:
91
+ raise Exception(f"Nationality prediction error: {str(e)}")