import gradio as gr
import pandas as pd
from datasets import load_dataset
import numpy as np
from functools import lru_cache
import re
from collections import Counter
import editdistance
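
# Third-party dependencies used here: gradio, pandas, datasets, numpy, editdistance
# (everything else above is from the standard library).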

# Cache the dataset loading to avoid reloading on refresh
@lru_cache(maxsize=1)
def load_data():
    try:
        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
        return dataset
    except Exception:
        # Fallback to the explicit parquet file if default loading fails;
        # loading from data_files exposes a single "train" split.
        return load_dataset(
            "parquet",
            data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
            split="train",
        )
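
# Quick sanity check (illustrative; field names follow the usage below):
#   ds = load_data()
#   print(len(ds), ds[0]["source"], ds[0]["transcription"])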

# Preprocess text for better WER calculation
def preprocess_text(text):
    if not text or not isinstance(text, str):
        return ""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
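
# Example: preprocess_text("Hello,   World!!") -> "hello world"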

# N-gram scoring for hypothesis ranking
def score_hypothesis(hypothesis, n=4):
    if not hypothesis:
        return 0
    words = hypothesis.split()
    if len(words) < n:
        return len(words)
    ngrams = []
    for i in range(len(words) - n + 1):
        ngram = ' '.join(words[i:i+n])
        ngrams.append(ngram)
    unique_ngrams = len(set(ngrams))
    total_ngrams = len(ngrams)
    score = len(words) + unique_ngrams / max(1, total_ngrams) * 5
    return score
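
# Example: score_hypothesis("the cat sat on the mat") with n=4 gives
# 6 words + (3 unique / 3 total) 4-grams * 5 = 11.0, so longer, less
# repetitive hypotheses rank higher.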

# N-gram ranking approach
def get_best_hypothesis_lm(hypotheses):
    if not hypotheses:
        return ""
    if isinstance(hypotheses, str):
        return hypotheses
    hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    if not hypothesis_list:
        return ""
    scores = [(score_hypothesis(h), h) for h in hypothesis_list]
    best_hypothesis = max(scores, key=lambda x: x[0])[1]
    return best_hypothesis
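
# Example: get_best_hypothesis_lm(["a cat", "the cat sat on the mat"])
# returns "the cat sat on the mat", the higher-scoring hypothesis.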

# Subwords voting correction approach
def correct_hypotheses(hypotheses):
    if not hypotheses:
        return ""
    if isinstance(hypotheses, str):
        return hypotheses
    hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    if not hypothesis_list:
        return ""
    word_lists = [h.split() for h in hypothesis_list]
    lengths = [len(words) for words in word_lists]
    if not lengths:
        return ""
    most_common_length = Counter(lengths).most_common(1)[0][0]
    filtered_word_lists = [words for words in word_lists if len(words) == most_common_length]
    if not filtered_word_lists:
        return max(hypothesis_list, key=len)
    corrected_words = []
    for i in range(most_common_length):
        position_words = [words[i] for words in filtered_word_lists]
        most_common_word = Counter(position_words).most_common(1)[0][0]
        corrected_words.append(most_common_word)
    return ' '.join(corrected_words)
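
# Example: correct_hypotheses(["the cat sit", "the cat sat", "a cat sat"])
# -> "the cat sat" (all candidates have 3 words; each position takes the
# majority word).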

# Calculate WER
def calculate_simple_wer(reference, hypothesis):
    if not reference or not hypothesis:
        return 1.0
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    distance = editdistance.eval(ref_words, hyp_words)
    if len(ref_words) == 0:
        return 1.0
    return float(distance) / float(len(ref_words))
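
# Example: calculate_simple_wer("the cat sat", "the cat sit") -> 1/3
# (one substitution over three reference words).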

# Calculate WER for a group of examples with multiple methods
def calculate_wer_methods(examples, max_samples=200):
    if not examples or len(examples) == 0:
        return np.nan, np.nan, np.nan
    # Limit sample size for efficiency
    if hasattr(examples, 'select'):
        items_to_process = examples.select(range(min(max_samples, len(examples))))
    else:
        items_to_process = examples[:max_samples]
    wer_values_no_lm = []
    wer_values_lm_ranking = []
    wer_values_n_best_correction = []
    for ex in items_to_process:
        # Get reference transcription
        transcription = ex.get("transcription")
        if not transcription or not isinstance(transcription, str):
            continue
        reference = preprocess_text(transcription)
        if not reference:
            continue
        # Get 1-best hypothesis for baseline
        input1 = ex.get("input1")
        if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
            if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
                input1 = ex["hypothesis"][0]
            elif isinstance(ex["hypothesis"], str):
                input1 = ex["hypothesis"]
        # Get n-best hypotheses for other methods
        n_best_hypotheses = ex.get("hypothesis", [])
        # Method 1: No LM (1-best ASR output)
        if input1 and isinstance(input1, str):
            no_lm_hyp = preprocess_text(input1)
            if no_lm_hyp:
                wer_no_lm = calculate_simple_wer(reference, no_lm_hyp)
                wer_values_no_lm.append(wer_no_lm)
        # Method 2: N-gram ranking
        if n_best_hypotheses:
            lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
            if lm_best_hyp:
                wer_lm = calculate_simple_wer(reference, lm_best_hyp)
                wer_values_lm_ranking.append(wer_lm)
        # Method 3: Subwords voting correction
        if n_best_hypotheses:
            corrected_hyp = correct_hypotheses(n_best_hypotheses)
            if corrected_hyp:
                wer_corrected = calculate_simple_wer(reference, corrected_hyp)
                wer_values_n_best_correction.append(wer_corrected)
    # Calculate average WER for each method
    no_lm_wer = np.mean(wer_values_no_lm) if wer_values_no_lm else np.nan
    lm_ranking_wer = np.mean(wer_values_lm_ranking) if wer_values_lm_ranking else np.nan
    n_best_correction_wer = np.mean(wer_values_n_best_correction) if wer_values_n_best_correction else np.nan
    return no_lm_wer, lm_ranking_wer, n_best_correction_wer
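
# Illustrative call on a toy example (dict fields match the dataset schema used above):
#   toy = [{"transcription": "the cat sat",
#           "input1": "the cat sit",
#           "hypothesis": ["the cat sit", "the cat sat", "a cat sat"]}]
#   calculate_wer_methods(toy)  # -> (0.3333..., 0.3333..., 0.0)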

# Get WER metrics by source
def get_wer_metrics(dataset):
    # Group examples by source
    examples_by_source = {}
    for ex in dataset:
        source = ex.get("source", "unknown")
        # Skip all_et05_real as requested
        if source == "all_et05_real":
            continue
        if source not in examples_by_source:
            examples_by_source[source] = []
        examples_by_source[source].append(ex)
    # Get all unique sources
    all_sources = sorted(examples_by_source.keys())
    # Calculate metrics for each source
    source_results = {}
    for source in all_sources:
        examples = examples_by_source.get(source, [])
        count = len(examples)
        if count > 0:
            no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(examples)
        else:
            no_lm_wer, lm_ranking_wer, n_best_wer = np.nan, np.nan, np.nan
        source_results[source] = {
            "Count": count,
            "No LM Baseline": no_lm_wer,
            "N-best LM Ranking": lm_ranking_wer,
            "N-best Correction": n_best_wer
        }
    # Calculate overall metrics
    filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
    total_count = len(filtered_dataset)
    sample_size = min(500, total_count)
    sample_dataset = filtered_dataset[:sample_size]
    no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(sample_dataset)
    source_results["OVERALL"] = {
        "Count": total_count,
        "No LM Baseline": no_lm_wer,
        "N-best LM Ranking": lm_ranking_wer,
        "N-best Correction": n_best_wer
    }
    # Create flat DataFrame with labels in the first column
    rows = []
    # First add row for number of examples
    example_row = {"Methods": "Number of Examples"}
    for source in all_sources + ["OVERALL"]:
        example_row[source] = source_results[source]["Count"]
    rows.append(example_row)
    # Then add rows for each WER method with simplified names
    no_lm_row = {"Methods": "No LM"}
    lm_ranking_row = {"Methods": "N-gram Ranking"}
    n_best_row = {"Methods": "Subwords Voting"}
    # Add the additional methods from the figure
    llama_lora_row = {"Methods": "LLaMA-7B-LoRA"}
    nb_oracle_row = {"Methods": "N-best Oracle (o_nb)"}
    cp_oracle_row = {"Methods": "Compositional Oracle (o_cp)"}
    # Populate the existing methods
    for source in all_sources + ["OVERALL"]:
        no_lm_row[source] = source_results[source]["No LM Baseline"]
        lm_ranking_row[source] = source_results[source]["N-best LM Ranking"]
        n_best_row[source] = source_results[source]["N-best Correction"]
        # Default the figure-based methods to NaN for sources not in the figure
        llama_lora_row[source] = np.nan
        nb_oracle_row[source] = np.nan
        cp_oracle_row[source] = np.nan
    # Add hardcoded values from the figure for each source
    # CHiME-4
    if "test_chime4" in all_sources:
        llama_lora_row["test_chime4"] = 6.6 / 100  # Convert from percentage
        nb_oracle_row["test_chime4"] = 9.1 / 100
        cp_oracle_row["test_chime4"] = 2.8 / 100
    # Tedlium-3
    if "test_td3" in all_sources:
        llama_lora_row["test_td3"] = 4.6 / 100
        nb_oracle_row["test_td3"] = 3.0 / 100
        cp_oracle_row["test_td3"] = 0.7 / 100
    # CommonVoice (CV-accent)
    if "test_cv" in all_sources:
        llama_lora_row["test_cv"] = 11.0 / 100
        nb_oracle_row["test_cv"] = 11.4 / 100
        cp_oracle_row["test_cv"] = 7.9 / 100
    # SwitchBoard
    if "test_swbd" in all_sources:
        llama_lora_row["test_swbd"] = 14.1 / 100
        nb_oracle_row["test_swbd"] = 12.6 / 100
        cp_oracle_row["test_swbd"] = 4.2 / 100
    # LRS2
    if "test_lrs2" in all_sources:
        llama_lora_row["test_lrs2"] = 8.8 / 100
        nb_oracle_row["test_lrs2"] = 6.9 / 100
        cp_oracle_row["test_lrs2"] = 2.6 / 100
    # CORAAL
    if "test_coraal" in all_sources:
        llama_lora_row["test_coraal"] = 19.2 / 100
        nb_oracle_row["test_coraal"] = 21.8 / 100
        cp_oracle_row["test_coraal"] = 10.7 / 100
    # LibriSpeech Clean
    if "test_ls_clean" in all_sources:
        llama_lora_row["test_ls_clean"] = 1.7 / 100
        nb_oracle_row["test_ls_clean"] = 1.0 / 100
        cp_oracle_row["test_ls_clean"] = 0.6 / 100
    # LibriSpeech Other
    if "test_ls_other" in all_sources:
        llama_lora_row["test_ls_other"] = 3.8 / 100
        nb_oracle_row["test_ls_other"] = 2.7 / 100
        cp_oracle_row["test_ls_other"] = 1.6 / 100
    # Add rows in the desired order
    rows.append(no_lm_row)
    rows.append(lm_ranking_row)
    rows.append(n_best_row)
    rows.append(llama_lora_row)
    rows.append(nb_oracle_row)
    rows.append(cp_oracle_row)
    # Create DataFrame from rows
    result_df = pd.DataFrame(rows)
    return result_df
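
# The returned DataFrame has a "Methods" column plus one column per source
# (e.g. test_chime4, ..., OVERALL); the first row holds example counts and the
# remaining rows hold per-method WER values (computed or taken from the figure).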

# Format the dataframe for display
def format_dataframe(df):
    df = df.copy()
    # Find the rows containing WER values
    wer_row_indices = []
    for i, method in enumerate(df["Methods"]):
        if method != "Number of Examples":
            wer_row_indices.append(i)
    # Format WER values
    for idx in wer_row_indices:
        for col in df.columns:
            if col != "Methods":
                value = df.loc[idx, col]
                if pd.notna(value):
                    df.loc[idx, col] = f"{value:.4f}"
                else:
                    df.loc[idx, col] = "N/A"
    return df

# Main function to create the leaderboard
def create_leaderboard():
    dataset = load_data()
    metrics_df = get_wer_metrics(dataset)
    return format_dataframe(metrics_df)

# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
    gr.Markdown("# HyPoradise N-Best WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for different speech sources with multiple correction approaches")
    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")
    with gr.Row():
        gr.Markdown("### Word Error Rate (WER)")
    with gr.Row():
        try:
            initial_df = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception:
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": "Error initializing leaderboard"}]))

    def refresh_and_report():
        return create_leaderboard()

    refresh_btn.click(refresh_and_report, outputs=[leaderboard])

if __name__ == "__main__":
    demo.launch()