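# Gradio leaderboard app: computes baseline Word Error Rate (WER) with jiwer on
# the GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset, comparing the
# reference "transcription" field against the "input1" ASR hypothesis field,
# broken down by "source" for the train and test splits.
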
import gradio as gr
import pandas as pd
from datasets import load_dataset
import jiwer
import numpy as np
from functools import lru_cache


@lru_cache(maxsize=1)
def load_data():
    """Load and cache the dataset so repeated refreshes reuse the same copy."""
    return load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction")


def calculate_wer(examples):
    """Return corpus-level WER of the "input1" hypotheses against the "transcription" references."""
    if not examples:
        return 0.0

    # Keep only examples where both the reference and the hypothesis are
    # present and non-empty.
    valid_pairs = [(ex.get("transcription", "").strip(), ex.get("input1", "").strip())
                   for ex in examples
                   if ex.get("transcription") and ex.get("input1")]

    if not valid_pairs:
        return np.nan

    references, hypotheses = zip(*valid_pairs)
    return jiwer.wer(list(references), list(hypotheses))
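
# Illustrative example (not executed): for a single example with
# transcription "hello world" and input1 "hello word", jiwer counts one
# substitution over two reference words, so calculate_wer returns 0.5.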


def get_wer_metrics(dataset):
    """Build a per-source WER summary for the train and test splits, plus an OVERALL row."""
    # Group examples by their "source" field within each split.
    train_by_source = {}
    test_by_source = {}

    for ex in dataset["train"]:
        train_by_source.setdefault(ex["source"], []).append(ex)

    for ex in dataset["test"]:
        test_by_source.setdefault(ex["source"], []).append(ex)

    all_sources = sorted(set(train_by_source) | set(test_by_source))

    results = []
    for source in all_sources:
        train_examples = train_by_source.get(source, [])
        test_examples = test_by_source.get(source, [])

        train_count = len(train_examples)
        test_count = len(test_examples)

        train_wer = calculate_wer(train_examples) if train_count > 0 else np.nan
        test_wer = calculate_wer(test_examples) if test_count > 0 else np.nan

        results.append({
            "Source": source,
            "Train Count": train_count,
            "Train WER": train_wer,
            "Test Count": test_count,
            "Test WER": test_wer
        })

    # Append an overall row computed across all sources in each split.
    train_wer = calculate_wer(dataset["train"])
    test_wer = calculate_wer(dataset["test"])

    results.append({
        "Source": "OVERALL",
        "Train Count": len(dataset["train"]),
        "Train WER": train_wer,
        "Test Count": len(dataset["test"]),
        "Test WER": test_wer
    })

    return pd.DataFrame(results)
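
# get_wer_metrics keeps WER values as raw floats (np.nan where a split has no
# usable pairs); format_dataframe below converts them to display strings.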


def format_dataframe(df):
    """Render the WER columns as 4-decimal strings, with N/A for missing values."""
    df = df.copy()
    for col in ["Train WER", "Test WER"]:
        df[col] = df[col].map(lambda x: f"{x:.4f}" if pd.notna(x) else "N/A")
    return df


def create_leaderboard():
    """Load the dataset, compute WER metrics, and return a display-ready table."""
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        return format_dataframe(metrics_df)
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})


with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard")
    gr.Markdown("Word Error Rate (WER) metrics for the GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset")

    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")

    with gr.Row():
        leaderboard = gr.DataFrame(create_leaderboard())

    refresh_btn.click(create_leaderboard, outputs=leaderboard)


if __name__ == "__main__":
    demo.launch()