KDM999 committed
Commit 63da647 · verified · 1 Parent(s): d27ba16

Upload app.py

Files changed (1)
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
+
+ import gradio as gr
+ import random
+ import json
+ import os
+ from difflib import SequenceMatcher
+ from jiwer import wer
+ import torchaudio
+ import torch
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, HubertForCTC  # transformers has no HubertProcessor; HuBERT CTC checkpoints use Wav2Vec2Processor
+ import whisper
+
+ # Load metadata
+ with open("common_voice_en_validated_249_hf_ready.json") as f:
+     data = json.load(f)
+
+ # Available filter values
+ ages = sorted(set(entry["age"] for entry in data))
+ genders = sorted(set(entry["gender"] for entry in data))
+ accents = sorted(set(entry["accent"] for entry in data))
+
+ # Load models
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Whisper
+ whisper_model = whisper.load_model("medium").to(device)
+
+ # Wav2Vec2
+ wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
+ wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(device)
+
+ # HuBERT (this CTC checkpoint ships a Wav2Vec2Processor for feature extraction and decoding)
+ hubert_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
+ hubert_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device)
+
+ def load_audio(file_path):
+     waveform, sr = torchaudio.load(file_path)
+     return torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0].numpy()
+
+ def transcribe_whisper(file_path):
+     result = whisper_model.transcribe(file_path)
+     return result["text"].strip().lower()
+
+ def transcribe_wav2vec(file_path):
+     audio = load_audio(file_path)
+     inputs = wav2vec_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
+     with torch.no_grad():
+         logits = wav2vec_model(**inputs.to(device)).logits
+     predicted_ids = torch.argmax(logits, dim=-1)
+     return wav2vec_processor.batch_decode(predicted_ids)[0].strip().lower()
+
+ def transcribe_hubert(file_path):
+     audio = load_audio(file_path)
+     inputs = hubert_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
+     with torch.no_grad():
+         logits = hubert_model(**inputs.to(device)).logits
+     predicted_ids = torch.argmax(logits, dim=-1)
+     return hubert_processor.batch_decode(predicted_ids)[0].strip().lower()
+
+ def highlight_differences(ref, hyp):
+     sm = SequenceMatcher(None, ref.split(), hyp.split())
+     result = []
+     for opcode, i1, i2, j1, j2 in sm.get_opcodes():
+         if opcode == 'equal':
+             result.extend(hyp.split()[j1:j2])
+         elif opcode in ('replace', 'insert', 'delete'):
+             wrong = hyp.split()[j1:j2]
+             result.extend([f"<span style='color:red'>{w}</span>" for w in wrong])
+     return " ".join(result)
+
+ def run_demo(age, gender, accent):
+     filtered = [
+         entry for entry in data
+         if entry["age"] == age and entry["gender"] == gender and entry["accent"] == accent
+     ]
+     if not filtered:
+         return "No matching sample.", None, "", "", "", "", ""
+
+     sample = random.choice(filtered)
+     file_path = os.path.join("common_voice_en_validated_249", sample["path"])
+     gold = sample["sentence"].strip().lower()
+
+     whisper_text = transcribe_whisper(file_path)
+     wav2vec_text = transcribe_wav2vec(file_path)
+     hubert_text = transcribe_hubert(file_path)
+
+     table = f"""
+     <table border="1" style="width:100%">
+     <tr><th>Model</th><th>Transcription</th><th>WER</th></tr>
+     <tr><td><b>Gold</b></td><td>{gold}</td><td>0.00</td></tr>
+     <tr><td>Whisper</td><td>{highlight_differences(gold, whisper_text)}</td><td>{wer(gold, whisper_text):.2f}</td></tr>
+     <tr><td>Wav2Vec2</td><td>{highlight_differences(gold, wav2vec_text)}</td><td>{wer(gold, wav2vec_text):.2f}</td></tr>
+     <tr><td>HuBERT</td><td>{highlight_differences(gold, hubert_text)}</td><td>{wer(gold, hubert_text):.2f}</td></tr>
+     </table>
+     """
+
+     # One value per output component: gold sentence, audio path, three transcripts, WER table, path string
+     return sample["sentence"], file_path, whisper_text, wav2vec_text, hubert_text, table, f"Audio path: {file_path}"
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# ASR Model Comparison on ESL Audio")
+     gr.Markdown("Filter by age, gender, and accent. Then generate a random ESL learner's audio to compare how Whisper, Wav2Vec2, and HuBERT transcribe it.")
+
+     with gr.Row():
+         age = gr.Dropdown(choices=ages, label="Age")
+         gender = gr.Dropdown(choices=genders, label="Gender")
+         accent = gr.Dropdown(choices=accents, label="Accent")
+
+     btn = gr.Button("Generate and Transcribe")
+     audio = gr.Audio(label="Audio", type="filepath")
+     wer_output = gr.HTML()
+
+     btn.click(fn=run_demo, inputs=[age, gender, accent], outputs=[
+         gr.Textbox(label="Gold (Correct)"),
+         audio,
+         gr.Textbox(label="Whisper Output"),
+         gr.Textbox(label="Wav2Vec2 Output"),
+         gr.Textbox(label="HuBERT Output"),
+         wer_output,
+         gr.Textbox(label="Path")
+     ])
+
+ demo.launch()
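
Note on the data this app expects: alongside app.py it reads common_voice_en_validated_249_hf_ready.json (a list of records with path, sentence, age, gender, and accent keys, per the code above) and resolves each path inside the common_voice_en_validated_249/ folder. A minimal sketch of one such record and a sanity check; the key names come from the code, but every value below is purely illustrative:

    # Hypothetical example of one metadata record; values are illustrative, not from the actual dataset.
    example_entry = {
        "path": "common_voice_en_00000001.mp3",   # resolved against common_voice_en_validated_249/
        "sentence": "The quick brown fox jumps over the lazy dog.",
        "age": "twenties",
        "gender": "male",
        "accent": "india",
    }

    # Quick sanity check that the metadata file has the keys the app reads.
    import json

    with open("common_voice_en_validated_249_hf_ready.json") as f:
        records = json.load(f)

    required = {"path", "sentence", "age", "gender", "accent"}
    bad = [i for i, r in enumerate(records) if not required.issubset(r)]
    print(f"{len(records)} records, {len(bad)} missing a required key")

The imports suggest the Space also needs gradio, jiwer, torch, torchaudio, transformers, and openai-whisper installed (Whisper's transcribe additionally relies on ffmpeg to decode the audio files).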