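"""Gradio app: baseline Word Error Rate (WER) leaderboard for the
GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset, reported per source
and per train/test split."""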
import gradio as gr
import pandas as pd
from datasets import load_dataset
import jiwer
import numpy as np
from functools import lru_cache
# Cache the dataset loading to avoid reloading on refresh
@lru_cache(maxsize=1)
def load_data():
    return load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction")
# Calculate WER for a group of examples
def calculate_wer(examples):
    if not examples:
        return np.nan
    # Keep only (reference, hypothesis) pairs where both fields are present and non-empty
    valid_pairs = [(ex.get("transcription", "").strip(), ex.get("input1", "").strip())
                   for ex in examples
                   if ex.get("transcription") and ex.get("input1")]
    if not valid_pairs:
        return np.nan
    # Unzip the pairs into parallel reference and hypothesis sequences
    references, hypotheses = zip(*valid_pairs)
    # jiwer computes corpus-level WER over the aligned lists
    return jiwer.wer(list(references), list(hypotheses))
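# Example: calculate_wer([{"transcription": "hello world", "input1": "hello word"}])
# returns 0.5, i.e. one substitution against a two-word reference.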
# Get WER metrics by source and split
def get_wer_metrics(dataset):
    # Group examples by source in a single pass per split
    train_by_source = {}
    test_by_source = {}
    for ex in dataset["train"]:
        train_by_source.setdefault(ex["source"], []).append(ex)
    for ex in dataset["test"]:
        test_by_source.setdefault(ex["source"], []).append(ex)
    # Union of sources appearing in either split
    all_sources = sorted(set(train_by_source) | set(test_by_source))
    # Calculate metrics for each source
    results = []
    for source in all_sources:
        train_examples = train_by_source.get(source, [])
        test_examples = test_by_source.get(source, [])
        train_count = len(train_examples)
        test_count = len(test_examples)
        train_wer = calculate_wer(train_examples) if train_count > 0 else np.nan
        test_wer = calculate_wer(test_examples) if test_count > 0 else np.nan
        results.append({
            "Source": source,
            "Train Count": train_count,
            "Train WER": train_wer,
            "Test Count": test_count,
            "Test WER": test_wer,
        })
    # Calculate overall metrics once
    results.append({
        "Source": "OVERALL",
        "Train Count": len(dataset["train"]),
        "Train WER": calculate_wer(dataset["train"]),
        "Test Count": len(dataset["test"]),
        "Test WER": calculate_wer(dataset["test"]),
    })
    return pd.DataFrame(results)
# Format the dataframe for display
def format_dataframe(df):
    # Render WER values as fixed-precision strings, showing "N/A" for missing data
    df = df.copy()
    for col in ("Train WER", "Test WER"):
        df[col] = df[col].map(lambda x: "N/A" if pd.isna(x) else f"{x:.4f}")
    return df
# Main function to create the leaderboard
def create_leaderboard():
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        return format_dataframe(metrics_df)
    except Exception as e:
        # Show the error in the table rather than crashing the UI
        return pd.DataFrame({"Error": [str(e)]})
# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
gr.Markdown("# ASR Text Correction Baseline WER Leaderboard")
gr.Markdown("Word Error Rate (WER) metrics for GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset")
with gr.Row():
refresh_btn = gr.Button("Refresh Leaderboard")
with gr.Row():
leaderboard = gr.DataFrame(create_leaderboard())
refresh_btn.click(create_leaderboard, outputs=leaderboard)
if __name__ == "__main__":
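    # launch() serves the app locally (http://127.0.0.1:7860 by default); pass
    # share=True for a temporary public link.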
    demo.launch()