SpeechAnalysisDemo / models /nationality_model.py
dtrovato997's picture
Initial commit
d7e7912
raw
history blame
2.94 kB
import os
import torch
import numpy as np
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
# --- Module-level configuration -------------------------------------------
# Hugging Face Hub checkpoint loaded by NationalityModel.load()
# (an MMS language-identification model).
MODEL_ID = "facebook/mms-lid-256"
# Audio sample rate in Hz; presumably the rate the checkpoint was trained
# on — verify against the model's preprocessor config.
SAMPLING_RATE = 16_000
class NationalityModel:
    """Spoken-language identification wrapper around a Hugging Face checkpoint.

    Despite the class name, :meth:`predict` returns the model's language
    labels (``config.id2label``); mapping a language to a nationality is left
    to the caller.

    Typical use: construct, call :meth:`load` once, then call :meth:`predict`.
    """

    def __init__(self, cache_dir="./cache/nationality"):
        """Create an unloaded model and make sure ``cache_dir`` exists.

        Args:
            cache_dir: Directory passed as ``cache_dir`` to ``from_pretrained``.
                It is created immediately if missing.
        """
        self.processor = None  # feature extractor; populated by load()
        self.model = None  # sequence-classification model; populated by load()
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def load(self):
        """Load the feature extractor and model for ``MODEL_ID``.

        Returns:
            True on success, False if loading raised. The error is printed
            rather than propagated so callers can degrade gracefully.
        """
        try:
            print(f"Loading nationality prediction model from {MODEL_ID}...")
            self.processor = AutoFeatureExtractor.from_pretrained(
                MODEL_ID, cache_dir=self.cache_dir
            )
            self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
                MODEL_ID, cache_dir=self.cache_dir
            )
            print("Nationality prediction model loaded successfully!")
            return True
        except Exception as e:
            print(f"Error loading nationality prediction model: {e}")
            return False

    def predict(self, audio_data, sampling_rate, top_k=5):
        """Identify the language spoken in ``audio_data``.

        Args:
            audio_data: Array-like waveform. Either 1-D mono, or 2-D
                multi-channel in ``(channels, samples)`` or
                ``(samples, channels)`` layout. The shorter axis is treated
                as the channel axis and averaged away.
            sampling_rate: Sample rate of ``audio_data`` in Hz. If it differs
                from the feature extractor's ``sampling_rate``, the audio is
                linearly resampled first.
            top_k: How many ranked candidates to return. Clamped to the
                number of labels the model has.

        Returns:
            dict with these keys:
                ``predicted_language``: the most likely label.
                ``confidence``: its softmax probability.
                ``top_languages``: a list of
                ``{"language_code", "probability"}`` dicts, most likely first.

        Raises:
            ValueError: if :meth:`load` has not populated the model, or if
                ``top_k < 1``.
            Exception: wrapping any failure during preprocessing or
                inference. The original error is chained as ``__cause__``.
        """
        if self.model is None or self.processor is None:
            raise ValueError("Model not loaded. Call load() first.")
        if top_k < 1:
            raise ValueError("top_k must be at least 1.")
        try:
            audio = self._to_mono_float32(audio_data)

            # The Wav2Vec2 feature extractor rejects audio whose rate differs
            # from its own, so bring the input to that rate first.
            target_rate = getattr(self.processor, "sampling_rate", None)
            if target_rate and sampling_rate and sampling_rate != target_rate:
                audio = self._resample(audio, sampling_rate, target_rate)
                sampling_rate = target_rate

            inputs = self.processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
            with torch.no_grad():
                logits = self.model(**inputs).logits

            # Single clip -> take the first (only) row of the batch.
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
            # torch.topk raises if k exceeds the number of classes.
            k = min(top_k, probabilities.shape[-1])
            top_values, top_indices = torch.topk(probabilities, k=k)

            id2label = self.model.config.id2label
            top_languages = [
                {"language_code": id2label[idx], "probability": prob}
                for idx, prob in zip(top_indices.tolist(), top_values.tolist())
            ]
            # topk returns candidates sorted by descending probability, so
            # the first entry is the argmax.
            best = top_languages[0]
            return {
                "predicted_language": best["language_code"],
                "confidence": best["probability"],
                "top_languages": top_languages,
            }
        except Exception as e:
            raise Exception(f"Nationality prediction error: {str(e)}") from e

    @staticmethod
    def _to_mono_float32(audio_data):
        """Return ``audio_data`` as a 1-D float32 array, averaging channels."""
        audio = np.asarray(audio_data, dtype=np.float32)
        if audio.size == 0:
            raise ValueError("Audio input is empty.")
        if audio.ndim == 2:
            # Channel count is far smaller than sample count, so the shorter
            # axis is the channel axis in both common layouts.
            channel_axis = 0 if audio.shape[0] <= audio.shape[1] else 1
            audio = audio.mean(axis=channel_axis)
        elif audio.ndim > 2:
            raise ValueError(f"Expected 1-D or 2-D audio, got shape {audio.shape}.")
        return audio

    @staticmethod
    def _resample(audio, orig_rate, target_rate):
        """Linearly resample 1-D ``audio`` from ``orig_rate`` to ``target_rate`` Hz.

        No anti-aliasing filter is applied. Callers that need high-quality
        downsampling should resample before calling :meth:`predict`.
        """
        n_in = audio.shape[0]
        n_out = max(1, int(round(n_in * target_rate / orig_rate)))
        src_times = np.arange(n_in) / orig_rate
        dst_times = np.arange(n_out) / target_rate
        return np.interp(dst_times, src_times, audio).astype(np.float32)