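"""Baseline WER leaderboard for the GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction
test split.

Each example's ASR hypothesis (`input1`) is scored against its reference
`transcription` with jiwer, and the resulting Word Error Rate is reported per
source corpus and overall in a small Gradio app.
"""
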
import traceback
from functools import lru_cache

import gradio as gr
import jiwer
import numpy as np
import pandas as pd
from datasets import load_dataset


@lru_cache(maxsize=1)
def load_data():
    """Load the test split of the dataset, caching it across refreshes."""
    try:
        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
        return dataset
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        # Fall back to reading the parquet file straight from the Hub.
        try:
            dataset = load_dataset(
                "parquet",
                data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
                split="train",  # raw data files are exposed under a single "train" split
            )
            return dataset
        except Exception as e2:
            print(f"Error loading with explicit path: {str(e2)}")
            raise

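# Because load_data() is wrapped in lru_cache, the Refresh button below
# recomputes the metrics but does not re-download the dataset; restart the
# app to force a fresh download.
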
def calculate_wer(examples):
    """Compute corpus-level WER, treating `transcription` as the reference
    and `input1` as the hypothesis."""
    if not examples:
        return 0.0

    try:
        # Collect (reference, hypothesis) pairs, skipping malformed examples.
        valid_pairs = []
        for ex in examples:
            try:
                if len(valid_pairs) == 0:
                    print(f"Sample example keys: {ex.keys()}")

                transcription = ex.get("transcription", "")
                input1 = ex.get("input1", "")

                if transcription and input1 and isinstance(transcription, str) and isinstance(input1, str):
                    # Trim whitespace and cap very long strings to keep the
                    # alignment cheap.
                    transcription = transcription.strip()[:1000]
                    input1 = input1.strip()[:1000]
                    valid_pairs.append((transcription, input1))
            except Exception as ex_error:
                print(f"Error processing example: {str(ex_error)}")
                continue

        if not valid_pairs:
            print("No valid pairs found for WER calculation")
            return np.nan

        print(f"Sample pair for WER calculation: {valid_pairs[0]}")
        print(f"Total valid pairs: {len(valid_pairs)}")

        references, hypotheses = zip(*valid_pairs)

        try:
            # Convert the tuples from zip() to lists before handing them to jiwer.
            wer = jiwer.wer(list(references), list(hypotheses))
            print(f"Calculated WER: {wer}")
            return wer
        except Exception as wer_error:
            print(f"Error calculating WER: {str(wer_error)}")
            return np.nan

    except Exception as e:
        print(f"Error in calculate_wer: {str(e)}")
        print(traceback.format_exc())
        return np.nan

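# Note: jiwer.wer() on lists of sentences returns corpus-level WER (total
# substitutions, deletions, and insertions divided by total reference words),
# not an average of per-sentence error rates. Illustrative values:
#
#   jiwer.wer(["the cat sat"], ["the cat sat"])  # 0.0
#   jiwer.wer(["the cat sat"], ["the bat sat"])  # 0.333... (1 sub / 3 words)
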
def get_wer_metrics(dataset):
    """Group examples by source corpus and compute per-source and overall WER."""
    try:
        # Bucket the examples by their source corpus.
        examples_by_source = {}
        for ex in dataset:
            try:
                source = ex.get("source", "unknown")
                if source not in examples_by_source:
                    examples_by_source[source] = []
                examples_by_source[source].append(ex)
            except Exception as e:
                print(f"Error processing example: {str(e)}")
                continue

        all_sources = sorted(examples_by_source.keys())

        # Per-source metrics.
        results = []
        for source in all_sources:
            try:
                examples = examples_by_source.get(source, [])
                count = len(examples)

                if count > 0:
                    print(f"Calculating WER for source {source} with {count} examples")
                    wer = calculate_wer(examples)
                else:
                    wer = np.nan

                results.append({
                    "Source": source,
                    "Count": count,
                    "WER": wer,
                })
            except Exception as e:
                print(f"Error processing source {source}: {str(e)}")
                results.append({
                    "Source": source,
                    "Count": 0,
                    "WER": np.nan,
                })

        # Overall metrics across every source.
        try:
            total_count = len(dataset)
            print(f"Calculating overall WER for {total_count} examples")
            overall_wer = calculate_wer(dataset)

            results.append({
                "Source": "OVERALL",
                "Count": total_count,
                "WER": overall_wer,
            })
        except Exception as e:
            print(f"Error calculating overall metrics: {str(e)}")
            results.append({
                "Source": "OVERALL",
                "Count": len(dataset),
                "WER": np.nan,
            })

        return pd.DataFrame(results)

    except Exception as e:
        print(f"Error in get_wer_metrics: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])

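# The returned frame has one row per source corpus plus a final OVERALL row
# (source names and counts depend on the dataset):
#
#        Source  Count    WER
#   0  <source>    <n>  <wer>
#   ...
#   k   OVERALL    <N>  <wer>
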
def format_dataframe(df):
    """Render WER as a fixed-precision string, with N/A for missing values."""
    try:
        df = df.copy()

        if "WER" in df.columns:
            # Cast to object first so string values can be assigned into
            # what is otherwise a float column.
            df["WER"] = df["WER"].astype(object)
            mask = df["WER"].notna()
            df.loc[mask, "WER"] = df.loc[mask, "WER"].map(lambda x: f"{x:.4f}")
            df.loc[~mask, "WER"] = "N/A"

        return df

    except Exception as e:
        print(f"Error in format_dataframe: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])


def create_leaderboard():
    """Load the data, compute the metrics, and format them for display."""
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        return format_dataframe(metrics_df)
    except Exception as e:
        error_msg = f"Error creating leaderboard: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return pd.DataFrame([{"Error": error_msg}])


with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
    gr.Markdown(
        "Word Error Rate (WER) metrics for the test split of the "
        "GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset."
    )

    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")

    with gr.Row():
        error_output = gr.Textbox(label="Debug Information", visible=True)

    with gr.Row():
        # create_leaderboard() traps its own exceptions, but guard anyway so
        # a failure here cannot keep the rest of the UI from rendering.
        try:
            initial_df = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception as e:
            error_msg = f"Error initializing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            # Components cannot be updated while the Blocks context is still
            # being built, so surface the failure in the table itself instead
            # of calling error_output.update() here.
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": error_msg}]))

    def refresh_and_report():
        try:
            df = create_leaderboard()
            debug_info = "Leaderboard refreshed successfully."
            return df, debug_info
        except Exception as e:
            error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return pd.DataFrame([{"Error": error_msg}]), error_msg

    refresh_btn.click(refresh_and_report, outputs=[leaderboard, error_output])

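# demo.launch() serves the app on http://127.0.0.1:7860 by default; pass
# share=True for a temporary public link. On Hugging Face Spaces this file
# runs as-is when saved as app.py.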
if __name__ == "__main__":
    demo.launch()