import gradio as gr
import random
import json
import os
from difflib import SequenceMatcher
from jiwer import wer
import torchaudio
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, HubertForCTC
import whisper
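# Requires: gradio, jiwer, torch, torchaudio, transformers, and openai-whisper (imported as `whisper`).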
# Load metadata
with open("common_voice_en_validated_249_hf_ready.json") as f:
    data = json.load(f)
# Available filter values
ages = sorted(set(entry["age"] for entry in data))
genders = sorted(set(entry["gender"] for entry in data))
accents = sorted(set(entry["accent"] for entry in data))
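# These distinct age/gender/accent values populate the dropdown filters in the UI.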
# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"
# Whisper
whisper_model = whisper.load_model("medium").to(device)
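# Note: "medium" is a fairly large checkpoint; a smaller one such as "base" can be substituted for faster, lower-memory inference.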
# Wav2Vec2
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(device)
# HuBERT (the fine-tuned checkpoint ships a Wav2Vec2-style processor, so Wav2Vec2Processor is reused here)
hubert_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
hubert_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device)
def load_audio(file_path):
    # Load the clip, resample to 16 kHz, and return the first channel as a NumPy array
    waveform, sr = torchaudio.load(file_path)
    return torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0].numpy()
def transcribe_whisper(file_path):
    result = whisper_model.transcribe(file_path)
    return result["text"].strip().lower()
def transcribe_wav2vec(file_path):
    audio = load_audio(file_path)
    inputs = wav2vec_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = wav2vec_model(**inputs.to(device)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return wav2vec_processor.batch_decode(predicted_ids)[0].strip().lower()
def transcribe_hubert(file_path):
    audio = load_audio(file_path)
    inputs = hubert_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = hubert_model(**inputs.to(device)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return hubert_processor.batch_decode(predicted_ids)[0].strip().lower()
def highlight_differences(ref, hyp):
    sm = SequenceMatcher(None, ref.split(), hyp.split())
    result = []
    for opcode, i1, i2, j1, j2 in sm.get_opcodes():
        if opcode == 'equal':
            result.extend(hyp.split()[j1:j2])
        elif opcode in ('replace', 'insert', 'delete'):
            # Mark hypothesis words that differ from the reference in red.
            # Pure deletions contribute no hypothesis words, so missing words are not shown.
            wrong = hyp.split()[j1:j2]
            result.extend([f"<span style='color:red'>{w}</span>" for w in wrong])
    return " ".join(result)
def run_demo(age, gender, accent):
    filtered = [
        entry for entry in data
        if entry["age"] == age and entry["gender"] == gender and entry["accent"] == accent
    ]
    if not filtered:
        # One value per output component (seven in total)
        return "No matching sample.", None, "", "", "", "", ""
    sample = random.choice(filtered)
    file_path = os.path.join("common_voice_en_validated_249", sample["path"])
    gold = sample["sentence"].strip().lower()
    whisper_text = transcribe_whisper(file_path)
    wav2vec_text = transcribe_wav2vec(file_path)
    hubert_text = transcribe_hubert(file_path)
    table = f"""
    <table border="1" style="width:100%">
      <tr><th>Model</th><th>Transcription</th><th>WER</th></tr>
      <tr><td><b>Gold</b></td><td>{gold}</td><td>0.00</td></tr>
      <tr><td>Whisper</td><td>{highlight_differences(gold, whisper_text)}</td><td>{wer(gold, whisper_text):.2f}</td></tr>
      <tr><td>Wav2Vec2</td><td>{highlight_differences(gold, wav2vec_text)}</td><td>{wer(gold, wav2vec_text):.2f}</td></tr>
      <tr><td>HuBERT</td><td>{highlight_differences(gold, hubert_text)}</td><td>{wer(gold, hubert_text):.2f}</td></tr>
    </table>
    """
    return gold, file_path, whisper_text, wav2vec_text, hubert_text, table, f"Audio path: {file_path}"
with gr.Blocks() as demo:
    gr.Markdown("# ASR Model Comparison on ESL Audio")
    gr.Markdown("Filter by age, gender, and accent. Then generate a random ESL learner's audio to compare how Whisper, Wav2Vec2, and HuBERT transcribe it.")
    with gr.Row():
        age = gr.Dropdown(choices=ages, label="Age")
        gender = gr.Dropdown(choices=genders, label="Gender")
        accent = gr.Dropdown(choices=accents, label="Accent")
    btn = gr.Button("Generate and Transcribe")
    audio = gr.Audio(label="Audio", type="filepath")
    wer_output = gr.HTML()
    btn.click(fn=run_demo, inputs=[age, gender, accent], outputs=[
        gr.Textbox(label="Gold (Correct)"),
        audio,
        gr.Textbox(label="Whisper Output"),
        gr.Textbox(label="Wav2Vec2 Output"),
        gr.Textbox(label="HuBERT Output"),
        wer_output,
        gr.Textbox(label="Path")
    ])

demo.launch()