
import gradio as gr
import random
import json
import os
from difflib import SequenceMatcher
from jiwer import wer
import torchaudio
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, HubertForCTC
import whisper
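
# Overview: this script builds a Gradio demo that picks a Common Voice ESL clip
# matching the selected age/gender/accent, transcribes it with Whisper, Wav2Vec2,
# and HuBERT, and reports each model's WER and word-level differences against
# the gold sentence.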

# Load metadata
with open("common_voice_en_validated_249_hf_ready.json") as f:
    data = json.load(f)
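
# Each metadata entry is expected to carry the fields used below; an
# illustrative (not actual) record might look like:
# {"path": "common_voice_en_00000001.mp3", "sentence": "Hello world.",
#  "age": "twenties", "gender": "female", "accent": "united states english"}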

# Available filter values
ages = sorted(set(entry["age"] for entry in data))
genders = sorted(set(entry["gender"] for entry in data))
accents = sorted(set(entry["accent"] for entry in data))

# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper
whisper_model = whisper.load_model("medium").to(device)

# Wav2Vec2
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(device)

# HuBERT (the fine-tuned checkpoint ships a Wav2Vec2-compatible processor;
# transformers has no separate HubertProcessor class)
hubert_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
hubert_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device)

def load_audio(file_path):
    # Load the clip and resample to the 16 kHz rate expected by Wav2Vec2/HuBERT.
    waveform, sr = torchaudio.load(file_path)
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
    # Mix down to mono in case a clip has more than one channel.
    return waveform.mean(dim=0).numpy()

def transcribe_whisper(file_path):
    result = whisper_model.transcribe(file_path)
    return result["text"].strip().lower()
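
# Note: whisper's transcribe() also accepts decoding options such as
# language="en" to skip automatic language detection; the defaults are used here.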

def transcribe_wav2vec(file_path):
    audio = load_audio(file_path)
    inputs = wav2vec_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = wav2vec_model(**inputs.to(device)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return wav2vec_processor.batch_decode(predicted_ids)[0].strip().lower()

def transcribe_hubert(file_path):
    audio = load_audio(file_path)
    inputs = hubert_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = hubert_model(**inputs.to(device)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return hubert_processor.batch_decode(predicted_ids)[0].strip().lower()

def highlight_differences(ref, hyp):
    """Mark hypothesis words that differ from the reference in red; words the
    model dropped entirely are struck through so deletions stay visible."""
    ref_words, hyp_words = ref.split(), hyp.split()
    sm = SequenceMatcher(None, ref_words, hyp_words)
    result = []
    for opcode, i1, i2, j1, j2 in sm.get_opcodes():
        if opcode == 'equal':
            result.extend(hyp_words[j1:j2])
        elif opcode in ('replace', 'insert'):
            result.extend(f"<span style='color:red'>{w}</span>" for w in hyp_words[j1:j2])
        elif opcode == 'delete':
            result.extend(f"<span style='color:red;text-decoration:line-through'>{w}</span>" for w in ref_words[i1:i2])
    return " ".join(result)

def run_demo(age, gender, accent):
    filtered = [
        entry for entry in data
        if entry["age"] == age and entry["gender"] == gender and entry["accent"] == accent
    ]
    if not filtered:
        # One placeholder per output component wired to this function below.
        return "No matching sample.", None, "", "", "", "", ""

    sample = random.choice(filtered)
    file_path = os.path.join("common_voice_en_validated_249", sample["path"])
    gold = sample["sentence"].strip().lower()

    whisper_text = transcribe_whisper(file_path)
    wav2vec_text = transcribe_wav2vec(file_path)
    hubert_text = transcribe_hubert(file_path)

    table = f"""
    <table border="1" style="width:100%">
        <tr><th>Model</th><th>Transcription</th><th>WER</th></tr>
        <tr><td><b>Gold</b></td><td>{gold}</td><td>0.00</td></tr>
        <tr><td>Whisper</td><td>{highlight_differences(gold, whisper_text)}</td><td>{wer(gold, whisper_text):.2f}</td></tr>
        <tr><td>Wav2Vec2</td><td>{highlight_differences(gold, wav2vec_text)}</td><td>{wer(gold, wav2vec_text):.2f}</td></tr>
        <tr><td>HuBERT</td><td>{highlight_differences(gold, hubert_text)}</td><td>{wer(gold, hubert_text):.2f}</td></tr>
    </table>
    """

    # One value per output component: gold text, audio, three transcripts, WER table, path.
    return sample["sentence"], file_path, whisper_text, wav2vec_text, hubert_text, table, f"Audio path: {file_path}"

with gr.Blocks() as demo:
    gr.Markdown("# ASR Model Comparison on ESL Audio")
    gr.Markdown("Filter by age, gender, and accent. Then generate a random ESL learner's audio to compare how Whisper, Wav2Vec2, and HuBERT transcribe it.")

    with gr.Row():
        age = gr.Dropdown(choices=ages, label="Age")
        gender = gr.Dropdown(choices=genders, label="Gender")
        accent = gr.Dropdown(choices=accents, label="Accent")

    btn = gr.Button("Generate and Transcribe")
    audio = gr.Audio(label="Audio", type="filepath")
    gold_box = gr.Textbox(label="Gold (Correct)")
    whisper_box = gr.Textbox(label="Whisper Output")
    wav2vec_box = gr.Textbox(label="Wav2Vec2 Output")
    hubert_box = gr.Textbox(label="HuBERT Output")
    wer_output = gr.HTML()
    path_box = gr.Textbox(label="Path")

    btn.click(fn=run_demo, inputs=[age, gender, accent], outputs=[
        gold_box, audio, whisper_box, wav2vec_box, hubert_box, wer_output, path_box
    ])

demo.launch()
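
# demo.launch() serves the app locally (http://127.0.0.1:7860 by default).
# Optionally, demo.launch(share=True) can be used during development to get a
# temporary public link; this is an extra convenience, not required by the demo.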