import os

import numpy as np
import torch
import gradio as gr
import spaces
from transformers import pipeline

# load the model
print("Loading model...")
model_id = "badrex/mms-300m-arabic-dialect-identifier"
classifier = pipeline("audio-classification", model=model_id, device='cuda')
print("Model loaded on GPU successfully")


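# the spaces.GPU decorator requests a GPU (ZeroGPU) from Hugging Face Spaces for the call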
@spaces.GPU
def predict(audio_segment, sr=16000):
    return classifier({"sampling_rate": sr, "raw": audio_segment})

# define dialect mapping

dialect_mapping = {
    "MSA": "Modern Standard Arabic (MSA) - العربية الفصحى الحديثة",
    "Egyptian": "Egyptian Arabic - اللهجة المصرية العامية",
    # the remaining entries are assumed from the classes described in the demo
    # text (Peninsular/Gulf, Levantine, Maghrebi); unknown labels fall back to
    # the raw model label in predict_dialect below
    "Gulf": "Peninsular (Gulf) Arabic - اللهجة الخليجية",
    "Levantine": "Levantine Arabic - اللهجة الشامية",
    "Maghrebi": "Maghrebi Arabic - اللهجة المغاربية",
}


def predict_dialect(audio):
    """Identify the Arabic dialect in a clip from the Gradio Audio input."""
    if audio is None:
        return {"Error": 1.0}
    

    # Gradio's Audio component returns a (sample_rate, numpy_array) tuple
    sr, audio_array = audio

    # downmix stereo to mono
    if len(audio_array.shape) > 1:
        audio_array = audio_array.mean(axis=1)

    # convert integer PCM to float32 in [-1, 1]
    if audio_array.dtype != np.float32:
        if audio_array.dtype == np.int16:
            audio_array = audio_array.astype(np.float32) / 32768.0
        else:
            # fall back to a plain float32 cast for other dtypes
            audio_array = audio_array.astype(np.float32)

    print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")

    # run inference through the GPU-decorated helper
    predictions = predict(sr=sr, audio_segment=audio_array)

    # collect scores keyed by human-readable dialect names for the gr.Label output
    results = {}
    for pred in predictions:
        dialect_name = dialect_mapping.get(pred['label'], pred['label'])
        results[dialect_name] = pred['score']

    return results

# prepare examples
examples = []
examples_dir = "examples"
if os.path.exists(examples_dir):
    for filename in os.listdir(examples_dir):
        if filename.endswith((".wav", ".mp3", ".ogg")):
            examples.append([os.path.join(examples_dir, filename)])

    print(f"Found {len(examples)} example files")
else:
    print("Examples directory not found")

# clean description without problematic HTML
description = """
By <a href="https://badrex.github.io/">Badr Alabsi</a> with ❤️🤍💚

This is a demo for the accurate and robust Transformer-based <a href="https://huggingface.co/badrex/mms-300m-arabic-dialect-identifier">model</a> for Spoken Arabic Dialect Identification (ADI). 
From just a short audio clip (5-10 seconds), the model can identify Modern Standard Arabic (MSA) as well as four major regional Arabic varieties: Egyptian Arabic, Peninsular Arabic (Gulf, Yemeni, and Iraqi), Levantine Arabic, and Maghrebi Arabic.

Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!
"""


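# assemble the Gradio interface: audio in, top-5 dialect probabilities out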
demo = gr.Interface(
    fn=predict_dialect,
    inputs=gr.Audio(),
    outputs=gr.Label(num_top_classes=5, label="Predicted Dialect"),
    title="Tamyïz 🍉 Arabic Dialect Identification in Speech",
    description=description,
    examples=examples if examples else None,
    cache_examples=False,
    flagging_mode="never",
)


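# share=True creates a public gradio.live link when running locally; it is ignored on Spaces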
demo.launch(share=True)