import gradio as gr
import numpy as np
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Fine-tuned Nepali Whisper checkpoints on the Hugging Face Hub
model_urls = [
    "kiranpantha/whisper-tiny-ne",
    "kiranpantha/whisper-base-ne",
    "kiranpantha/whisper-small-np",
    "kiranpantha/whisper-medium-nepali",
    "kiranpantha/whisper-large-v3-nepali",
    "kiranpantha/whisper-large-v3-turbo-nepali",
]

# Each fine-tuned checkpoint reuses the processor of its base Whisper model
processor_mappings = {
    "kiranpantha/whisper-tiny-ne": "openai/whisper-tiny",
    "kiranpantha/whisper-base-ne": "openai/whisper-base",
    "kiranpantha/whisper-small-np": "openai/whisper-small",
    "kiranpantha/whisper-medium-nepali": "openai/whisper-medium",
    "kiranpantha/whisper-large-v3-nepali": "openai/whisper-large-v3",
    "kiranpantha/whisper-large-v3-turbo-nepali": "openai/whisper-large-v3",
}

# Cache models and processors
model_cache = {}

def load_model(model_name):
    """Loads and caches the model and processor with proper device management."""
    if model_name not in model_cache:
        processor_name = processor_mappings.get(model_name, model_name)  # fall back to the checkpoint itself if unmapped
        
        processor = AutoProcessor.from_pretrained(processor_name)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(device)
        model.eval()
        
        model_cache[model_name] = (processor, model, device)
    
    return model_cache[model_name]
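
# Example usage (illustrative only, not called by the app; assumes the
# checkpoint can be downloaded from the Hub):
#   processor, model, device = load_model("kiranpantha/whisper-tiny-ne")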


def transcribe_audio(model_name, audio_chunk):
    try:
        print("Received audio_chunk:", type(audio_chunk), audio_chunk)

        if audio_chunk is None:
            return "Error: No audio received"

        if isinstance(audio_chunk, str):
            # File-path case: torchaudio returns a (channels, samples) tensor
            audio_tensor, sample_rate = torchaudio.load(audio_chunk)
            audio_array = audio_tensor.mean(dim=0).numpy()  # downmix to mono

        elif isinstance(audio_chunk, tuple) and isinstance(audio_chunk[1], np.ndarray):
            # Microphone/upload case: Gradio passes (sample_rate, samples),
            # where samples has shape (n,) or (n, channels)
            sample_rate, audio_array = audio_chunk
            if np.issubdtype(audio_array.dtype, np.integer):
                # Gradio records 16-bit PCM; scale to float32 in [-1, 1]
                audio_array = audio_array.astype(np.float32) / 32768.0
            if audio_array.ndim == 2:
                audio_array = audio_array.mean(axis=1)  # downmix to mono

        else:
            return "Error: Invalid audio input format"

        # Resample to the 16 kHz rate Whisper expects
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio_array = resampler(
                torch.tensor(audio_array, dtype=torch.float32).unsqueeze(0)
            ).squeeze(0).numpy()

        # Load model
        processor, model, device = load_model(model_name)

        # Prepare inputs
        inputs = processor(
            torch.tensor(audio_array), sampling_rate=16000, return_tensors="pt"
        )
        input_features = inputs.input_features.to(device)

        # Generate output
        generated_ids = model.generate(
            input_features,
            forced_decoder_ids=processor.get_decoder_prompt_ids(language="ne", task="transcribe"),
            max_length=448,
        )

        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()
    
    except Exception as e:
        return f"Error: {str(e)}"




# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# πŸŽ™οΈ Nepali Speech Recognition with Whisper Models")
    
    model_dropdown = gr.Dropdown(choices=model_urls, label="Select Model", value=model_urls[0])
    audio_input = gr.Audio(type="numpy", label="🎀 Record your voice here")
    output_text = gr.Textbox(label="πŸ“„ Transcription Output")
    transcribe_button = gr.Button("Transcribe")

    transcribe_button.click(
        fn=transcribe_audio,
        inputs=[model_dropdown, audio_input],
        outputs=output_text,
    )

demo.launch()
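
# To expose the demo through a temporary public URL, launch with sharing enabled instead:
#   demo.launch(share=True)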