File size: 3,213 Bytes
d7e7912
 
 
 
 
 
 
 
 
 
825d7f4
 
 
 
 
 
 
 
 
 
 
d7e7912
 
825d7f4
d7e7912
 
 
 
825d7f4
 
 
 
 
 
 
 
 
 
d7e7912
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
825d7f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import torch
import numpy as np
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor

# Module-level configuration.
# Hugging Face Hub checkpoint loaded by NationalityModel.load(); presumably the
# MMS language-identification head over 256 languages (per the checkpoint name).
MODEL_ID = "facebook/mms-lid-256"
# Audio sample rate in Hz. Not referenced by the code in this section;
# presumably the rate callers should provide audio at — confirm with callers.
SAMPLING_RATE = 16_000

class NationalityModel:
    """Spoken-language identification wrapper around the ``MODEL_ID`` checkpoint.

    Despite the class name, the model classifies the *language* spoken in an
    audio clip (labels come from ``model.config.id2label``); any mapping from
    language to nationality is left to the caller.

    Typical use::

        model = NationalityModel()
        if model.load():
            result = model.predict(waveform, sampling_rate=16000)
    """

    def __init__(self, cache_dir=None):
        """Choose (and create) the directory used to cache downloaded weights.

        Args:
            cache_dir: Explicit cache directory. When ``None``,
                ``/data/nationality`` is used if ``/data`` exists (persistent
                storage on Hugging Face Spaces), otherwise
                ``./cache/nationality``.
        """
        if cache_dir is not None:
            self.cache_dir = cache_dir
        elif os.path.exists("/data"):
            # HF Spaces persistent storage
            self.cache_dir = "/data/nationality"
        else:
            # Local development or other platforms
            self.cache_dir = "./cache/nationality"

        # Populated by load(); predict() refuses to run until both are set.
        self.processor = None
        self.model = None
        os.makedirs(self.cache_dir, exist_ok=True)

    def load(self):
        """Load the feature extractor and classifier for ``MODEL_ID``.

        Weights are fetched via ``from_pretrained`` using ``self.cache_dir``.
        Failures are printed rather than raised so callers can degrade
        gracefully. If the processor loads but the model does not,
        ``self.model`` stays ``None`` and ``predict`` will still refuse to run.

        Returns:
            ``True`` if both components loaded, ``False`` otherwise.
        """
        try:
            print(f"Loading nationality prediction model from {MODEL_ID}...")
            print(f"Using cache directory: {self.cache_dir}")

            self.processor = AutoFeatureExtractor.from_pretrained(
                MODEL_ID,
                cache_dir=self.cache_dir
            )
            self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
                MODEL_ID,
                cache_dir=self.cache_dir
            )
            print("Nationality prediction model loaded successfully!")
            return True
        except Exception as e:
            print(f"Error loading nationality prediction model: {e}")
            return False

    def predict(self, audio_data, sampling_rate, top_k=5):
        """Identify the language spoken in ``audio_data``.

        Args:
            audio_data: Waveform as a NumPy array, or anything ``np.asarray``
                accepts (e.g. a list of samples). Input with more than one
                dimension is averaged over axis 0 to get a mono signal.
            sampling_rate: Sample rate of ``audio_data`` in Hz; forwarded to
                the feature extractor.
            top_k: Maximum number of ranked candidates to return (default 5).
                Clamped to the number of labels the model has. Must be >= 1.

        Returns:
            dict with keys ``"predicted_language"`` (label of the most likely
            class), ``"confidence"`` (its softmax probability) and
            ``"top_languages"`` (list of ``{"language_code", "probability"}``
            dicts in descending probability order).

        Raises:
            ValueError: if ``load()`` has not populated both processor and
                model, or if ``top_k`` < 1.
            Exception: wrapping (and chained to) any error raised while
                preprocessing the audio or running the model.
        """
        if self.model is None or self.processor is None:
            raise ValueError("Model not loaded. Call load() first.")
        if top_k < 1:
            raise ValueError("top_k must be at least 1.")

        try:
            audio = np.asarray(audio_data)
            if audio.ndim > 1:
                # NOTE(review): averaging over axis 0 assumes a channels-first
                # (channels, samples) layout; (samples, channels) input would be
                # averaged across time instead — confirm with callers.
                audio = audio.mean(axis=0)
            audio = audio.astype(np.float32)

            inputs = self.processor(audio, sampling_rate=sampling_rate, return_tensors="pt")

            with torch.no_grad():
                logits = self.model(**inputs).logits

            # A single clip is passed in, so only the first batch row is used.
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
            id2label = self.model.config.id2label

            # torch.topk raises if k exceeds the number of classes, so clamp it.
            k = min(top_k, probabilities.shape[-1])
            top_values, top_indices = torch.topk(probabilities, k=k)
            top_languages = [
                {"language_code": id2label[idx], "probability": prob}
                for idx, prob in zip(top_indices.tolist(), top_values.tolist())
            ]

            # Most likely language and its probability.
            predicted_lang_id = torch.argmax(logits, dim=-1)[0].item()
            return {
                "predicted_language": id2label[predicted_lang_id],
                "confidence": probabilities[predicted_lang_id].item(),
                "top_languages": top_languages,
            }

        except Exception as e:
            # Chain the original error so its traceback is not lost.
            raise Exception(f"Nationality prediction error: {str(e)}") from e