Commit 5277669
Parent(s): ead33a6

fix : clip audio to max 2 mins

Files changed:
- main.py +67 -19
- models/age_and_gender_model.py +3 -7
- models/nationality_model.py +0 -5
main.py
CHANGED
@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
 UPLOAD_FOLDER = 'uploads'
 ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a'}
 SAMPLING_RATE = 16000
+MAX_DURATION_SECONDS = 120  # 2 minutes maximum
 
 os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
@@ -31,6 +32,23 @@ def allowed_file(filename: str) -> bool:
     return '.' in filename and \
            filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
 
+def clip_audio_to_max_duration(audio_data: np.ndarray, sr: int, max_duration: int = MAX_DURATION_SECONDS) -> tuple[np.ndarray, bool]:
+    current_duration = len(audio_data) / sr
+
+    if current_duration <= max_duration:
+        logger.info(f"Audio duration ({current_duration:.2f}s) is within limit ({max_duration}s) - no clipping needed")
+        return audio_data, False
+
+    # Calculate how many samples we need for the max duration
+    max_samples = int(max_duration * sr)
+
+    # Clip to first max_duration seconds
+    clipped_audio = audio_data[:max_samples]
+
+    logger.info(f"Audio clipped from {current_duration:.2f}s to {max_duration}s ({len(audio_data)} samples → {len(clipped_audio)} samples)")
+
+    return clipped_audio, True
+
 async def load_models() -> bool:
     global age_gender_model, nationality_model
 
@@ -96,10 +114,11 @@ app = FastAPI(
     lifespan=lifespan
 )
 
-def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
+def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int, bool]:
     preprocess_start = time.time()
     original_shape = audio_data.shape
-    …
+    original_duration = len(audio_data) / sr
+    logger.info(f"Starting audio preprocessing Sample rate: {sr}Hz, Duration: {original_duration:.2f}s")
 
     # Convert to mono if stereo
     if len(audio_data.shape) > 1:
@@ -115,20 +134,24 @@ def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
         audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLING_RATE)
         resample_end = time.time()
         logger.info(f"Resampling completed in {resample_end - resample_start:.3f} seconds")
+        sr = SAMPLING_RATE
     else:
         logger.info(f"No resampling needed - already at {SAMPLING_RATE}Hz")
 
+    # Clip audio to maximum duration if needed
+    audio_data, was_clipped = clip_audio_to_max_duration(audio_data, sr)
+
     # Convert to float32
     audio_data = audio_data.astype(np.float32)
 
     preprocess_end = time.time()
-    …
+    final_duration_seconds = len(audio_data) / sr
     logger.info(f"Audio preprocessing completed in {preprocess_end - preprocess_start:.3f} seconds")
-    logger.info(f"Final audio: {audio_data.shape} samples, {…
+    logger.info(f"Final audio: {audio_data.shape} samples, {final_duration_seconds:.2f} seconds duration")
 
-    return audio_data, …
+    return audio_data, sr, was_clipped
 
-async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
+async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int, bool]:
     process_start = time.time()
     logger.info(f"Processing uploaded file: {file.filename}")
 
@@ -165,12 +188,12 @@ async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
         load_end = time.time()
         logger.info(f"Audio loaded in {load_end - load_start:.3f} seconds")
 
-        processed_audio, processed_sr = preprocess_audio(audio_data, sr)
+        processed_audio, processed_sr, was_clipped = preprocess_audio(audio_data, sr)
 
         process_end = time.time()
         logger.info(f"Total file processing completed in {process_end - process_start:.3f} seconds")
 
-        return processed_audio, processed_sr
+        return processed_audio, processed_sr, was_clipped
 
     except Exception as e:
         logger.error(f"Error processing audio file {file.filename}: {str(e)}")
@@ -186,6 +209,7 @@ async def root() -> Dict[str, Any]:
     logger.info("Root endpoint accessed")
     return {
         "message": "Audio Analysis API - Age, Gender & Nationality Prediction",
+        "max_audio_duration": f"{MAX_DURATION_SECONDS} seconds (files longer than this will be automatically clipped)",
         "models_loaded": {
             "age_gender": age_gender_model is not None and hasattr(age_gender_model, 'model') and age_gender_model.model is not None,
             "nationality": nationality_model is not None and hasattr(nationality_model, 'model') and nationality_model.model is not None
@@ -206,7 +230,6 @@ async def health_check() -> Dict[str, str]:
 
 @app.post("/predict_age_and_gender")
 async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
-    """Predict age and gender from uploaded audio file."""
     endpoint_start = time.time()
     logger.info(f"Age & Gender prediction requested for file: {file.filename}")
 
@@ -215,7 +238,7 @@ async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
         raise HTTPException(status_code=500, detail="Age & gender model not loaded")
 
     try:
-        processed_audio, processed_sr = await process_audio_file(file)
+        processed_audio, processed_sr, was_clipped = await process_audio_file(file)
 
         # Make prediction
         prediction_start = time.time()
@@ -230,12 +253,21 @@ async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
         endpoint_end = time.time()
         logger.info(f"Total age & gender endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
 
-        return {
+        response = {
            "success": True,
            "predictions": predictions,
-            "processing_time": round(endpoint_end - endpoint_start, 3)
+            "processing_time": round(endpoint_end - endpoint_start, 3),
+            "audio_info": {
+                "was_clipped": was_clipped,
+                "max_duration_seconds": MAX_DURATION_SECONDS
+            }
         }
 
+        if was_clipped:
+            response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis."
+
+        return response
+
     except HTTPException:
         raise
     except Exception as e:
@@ -244,7 +276,6 @@ async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
 
 @app.post("/predict_nationality")
 async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
-    """Predict nationality/language from uploaded audio file."""
     endpoint_start = time.time()
     logger.info(f"Nationality prediction requested for file: {file.filename}")
 
@@ -253,7 +284,7 @@ async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
         raise HTTPException(status_code=500, detail="Nationality model not loaded")
 
     try:
-        processed_audio, processed_sr = await process_audio_file(file)
+        processed_audio, processed_sr, was_clipped = await process_audio_file(file)
 
         # Make prediction
         prediction_start = time.time()
@@ -268,12 +299,21 @@ async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
         endpoint_end = time.time()
         logger.info(f"Total nationality endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
 
-        return {
+        response = {
            "success": True,
            "predictions": predictions,
-            "processing_time": round(endpoint_end - endpoint_start, 3)
+            "processing_time": round(endpoint_end - endpoint_start, 3),
+            "audio_info": {
+                "was_clipped": was_clipped,
+                "max_duration_seconds": MAX_DURATION_SECONDS
+            }
         }
 
+        if was_clipped:
+            response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis."
+
+        return response
+
     except HTTPException:
         raise
     except Exception as e:
@@ -282,7 +322,6 @@ async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
 
 @app.post("/predict_all")
 async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
-    """Predict age, gender, and nationality from uploaded audio file."""
     endpoint_start = time.time()
     logger.info(f"Complete analysis requested for file: {file.filename}")
 
@@ -295,7 +334,7 @@ async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
         raise HTTPException(status_code=500, detail="Nationality model not loaded")
 
     try:
-        processed_audio, processed_sr = await process_audio_file(file)
+        processed_audio, processed_sr, was_clipped = await process_audio_file(file)
 
         # Get age & gender predictions
         age_prediction_start = time.time()
@@ -323,7 +362,7 @@ async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
         logger.info(f"Total prediction time: {total_prediction_time:.3f} seconds")
         logger.info(f"Total complete analysis endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
 
-        return {
+        response = {
            "success": True,
            "predictions": {
                "demographics": age_gender_predictions,
@@ -333,9 +372,18 @@ async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
                "total": round(endpoint_end - endpoint_start, 3),
                "age_gender": round(age_prediction_end - age_prediction_start, 3),
                "nationality": round(nationality_prediction_end - nationality_prediction_start, 3)
-            }
+            },
+            "audio_info": {
+                "was_clipped": was_clipped,
+                "max_duration_seconds": MAX_DURATION_SECONDS
+            }
         }
 
+        if was_clipped:
+            response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis."
+
+        return response
+
     except HTTPException:
         raise
     except Exception as e:
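The new helper is plain sample arithmetic: duration in seconds is len(audio_data) / sr, so the first max_duration seconds are simply the first max_duration * sr samples. Note that the diff also sets sr = SAMPLING_RATE after resampling, so the clip length is computed against the current rate rather than the original one. A self-contained sketch of the same behaviour, assuming 16 kHz audio (the three-minute silent signal is illustrative only):

import numpy as np

SAMPLING_RATE = 16000
MAX_DURATION_SECONDS = 120

def clip_audio_to_max_duration(audio_data: np.ndarray, sr: int,
                               max_duration: int = MAX_DURATION_SECONDS) -> tuple[np.ndarray, bool]:
    # Duration in seconds is sample count divided by sample rate
    if len(audio_data) / sr <= max_duration:
        return audio_data, False
    # Keep only the first max_duration * sr samples
    return audio_data[:int(max_duration * sr)], True

# Three minutes of silence at 16 kHz is clipped to exactly two minutes
audio = np.zeros(180 * SAMPLING_RATE, dtype=np.float32)
clipped, was_clipped = clip_audio_to_max_duration(audio, SAMPLING_RATE)
assert was_clipped and len(clipped) == MAX_DURATION_SECONDS * SAMPLING_RATE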
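On the client side, each prediction endpoint now returns a response dict with an audio_info block and, when clipping occurred, a warning key. A usage sketch (the localhost URL and sample.wav filename are assumptions, not part of this commit):

import requests

# Hypothetical deployment URL; adjust to wherever the FastAPI app is served
URL = "http://localhost:8000/predict_age_and_gender"

with open("sample.wav", "rb") as f:  # placeholder input file
    resp = requests.post(URL, files={"file": ("sample.wav", f, "audio/wav")})

data = resp.json()
print(data["audio_info"]["was_clipped"])  # True if the upload exceeded 120 s
print(data.get("warning"))                # clipping notice, or None if under the limit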
models/age_and_gender_model.py
CHANGED
@@ -7,7 +7,6 @@ import librosa
 
 class AgeGenderModel:
     def __init__(self, model_path=None):
-        # Use persistent storage if available, fallback to local cache
         if model_path is None:
             if os.path.exists("/data"):
                 # HF Spaces persistent storage
@@ -34,7 +33,7 @@ class AgeGenderModel:
         print("Age & gender model files not found. Downloading...")
 
         try:
-            # Use /data for cache if available, otherwise use local cache
+            # Use /data for cache if available, otherwise use local cache; this is in line with HF Spaces persistent storage
             if os.path.exists("/data"):
                 cache_root = '/data/cache'
             else:
@@ -72,16 +71,13 @@ class AgeGenderModel:
 
     def load(self):
         try:
-            # Download model if needed
             if not self.download_model():
                 print("Failed to download age & gender model")
                 return False
 
-            # Load the audonnx model
             print(f"Loading age & gender model from {self.model_path}...")
             self.model = audonnx.load(self.model_path)
 
-            # Create the audinterface Feature interface
             outputs = ['logits_age', 'logits_gender']
             self.interface = audinterface.Feature(
                 self.model.labels(outputs),
@@ -91,7 +87,7 @@ class AgeGenderModel:
                     'concat': True,
                 },
                 sampling_rate=self.sampling_rate,
-                resample=False,
+                resample=False,
                 verbose=False,
             )
             print("Age & gender model loaded successfully!")
@@ -105,7 +101,7 @@ class AgeGenderModel:
         if self.model is None or self.interface is None:
            raise ValueError("Model not loaded. Call load() first.")
 
-        try:
+        try:
            result = self.interface.process_signal(audio_data, sr)
 
            # Extract and process results
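Since the interface keeps resample=False, it will not correct a mismatched rate itself; the model depends on preprocess_audio in main.py having already resampled to 16 kHz. A minimal guard one might place before process_signal (a sketch, not part of this commit):

def _check_sampling_rate(sr: int, expected: int = 16000) -> None:
    # With resample=False the interface won't fix a wrong rate, so fail fast
    if sr != expected:
        raise ValueError(f"Expected {expected} Hz audio, got {sr} Hz")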
models/nationality_model.py
CHANGED
@@ -9,7 +9,6 @@ SAMPLING_RATE = 16000
 
 class NationalityModel:
     def __init__(self, cache_dir=None):
-        # Use persistent storage if available, fallback to local cache
         if cache_dir is None:
             if os.path.exists("/data"):
                 # HF Spaces persistent storage
@@ -48,16 +47,13 @@ class NationalityModel:
             raise ValueError("Model not loaded. Call load() first.")
 
         try:
-            # Ensure audio is properly formatted (float32, mono)
             if len(audio_data.shape) > 1:
                 audio_data = audio_data.mean(axis=0)
 
             audio_data = audio_data.astype(np.float32)
 
-            # Process audio with the feature extractor
             inputs = self.processor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
 
-            # Get model predictions
             with torch.no_grad():
                 outputs = self.model(**inputs).logits
 
@@ -65,7 +61,6 @@ class NationalityModel:
             probabilities = torch.nn.functional.softmax(outputs, dim=-1)[0]
             top_k_values, top_k_indices = torch.topk(probabilities, k=5)
 
-            # Convert to language codes and probabilities
             top_languages = []
             for i, idx in enumerate(top_k_indices):
                 lang_id = idx.item()
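The prediction path that remains in predict() is a standard softmax-plus-top-k readout. A self-contained sketch with dummy logits (the class count of 107 is illustrative, not taken from the model):

import torch

logits = torch.randn(1, 107)  # (batch, num_languages); 107 is an assumed count
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
top_k_values, top_k_indices = torch.topk(probabilities, k=5)
for value, idx in zip(top_k_values, top_k_indices):
    print(idx.item(), round(value.item(), 4))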