# Hugging Face Space app (author: huckiyang).
# History note: "optz the data loading" — commit 4e73867.
import gradio as gr
import pandas as pd
from datasets import load_dataset
import jiwer
import numpy as np
from functools import lru_cache
import traceback
# Cache the dataset so a leaderboard refresh does not re-download it.
@lru_cache(maxsize=1)
def load_data():
    """Return the test split of the SLT Task 1 post-ASR correction dataset.

    Tries the standard hub loader first; on failure, falls back to loading
    the test parquet shard directly. Re-raises if both attempts fail.
    """
    try:
        # Primary path: let `datasets` resolve the repo and fetch the split.
        return load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
    except Exception as primary_err:
        print(f"Error loading dataset: {str(primary_err)}")
        # Fallback: point the parquet builder straight at the test shard URL.
        try:
            return load_dataset(
                "parquet",
                data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
            )
        except Exception as fallback_err:
            print(f"Error loading with explicit path: {str(fallback_err)}")
            raise
# Calculate WER for a group of examples
def calculate_wer(examples):
    """Compute corpus-level WER for a collection of examples.

    Each example is expected to expose ``transcription`` (reference) and
    ``input1`` (hypothesis) fields; pairs with missing/empty/non-string
    values are skipped. Returns ``np.nan`` when no usable pair exists
    (including an empty input) or when the computation fails.
    """
    try:
        # Collect valid (reference, hypothesis) pairs in a single pass.
        valid_pairs = []
        for ex in examples:
            try:
                # Debug aid: show the field names of the first example seen.
                if not valid_pairs:
                    print(f"Sample example keys: {ex.keys()}")
                transcription = ex.get("transcription", "")
                input1 = ex.get("input1", "")
                # Only keep pairs where both sides are non-empty strings.
                if transcription and input1 and isinstance(transcription, str) and isinstance(input1, str):
                    # Truncate very long texts (1000 chars) to avoid pathological inputs.
                    valid_pairs.append((transcription.strip()[:1000], input1.strip()[:1000]))
            except Exception as ex_error:
                # Skip problematic examples but continue processing the rest.
                print(f"Error processing example: {str(ex_error)}")
                continue

        if not valid_pairs:
            # Covers both an empty input and inputs yielding no usable pairs.
            # (Previously an empty input returned 0.0 while "no valid pairs"
            # returned NaN — NaN, meaning "unknown", is the consistent answer.)
            print("No valid pairs found for WER calculation")
            return np.nan

        # Debug aids for spot-checking the data being scored.
        print(f"Sample pair for WER calculation: {valid_pairs[0]}")
        print(f"Total valid pairs: {len(valid_pairs)}")

        # valid_pairs is guaranteed non-empty here, so zip(*...) is safe
        # (the old trailing `if valid_pairs else ([], [])` was dead code).
        references, hypotheses = zip(*valid_pairs)

        try:
            wer = jiwer.wer(references, hypotheses)
            print(f"Calculated WER: {wer}")
            return wer
        except Exception as wer_error:
            print(f"Error calculating WER: {str(wer_error)}")
            return np.nan
    except Exception as e:
        print(f"Error in calculate_wer: {str(e)}")
        print(traceback.format_exc())
        return np.nan
# Get WER metrics by source
def get_wer_metrics(dataset):
    """Build a per-source WER summary table with a trailing OVERALL row.

    Returns a DataFrame with Source / Count / WER columns, or a one-row
    Error DataFrame if table construction itself fails.
    """
    try:
        # Bucket examples by their "source" field (unknown if absent).
        examples_by_source = {}
        for ex in dataset:
            try:
                examples_by_source.setdefault(ex.get("source", "unknown"), []).append(ex)
            except Exception as e:
                print(f"Error processing example: {str(e)}")
                continue

        rows = []
        # Emit one row per source, in sorted order.
        for source in sorted(examples_by_source):
            try:
                members = examples_by_source.get(source, [])
                n = len(members)
                if n > 0:
                    print(f"Calculating WER for source {source} with {n} examples")
                    source_wer = calculate_wer(members)
                else:
                    source_wer = np.nan
                rows.append({"Source": source, "Count": n, "WER": source_wer})
            except Exception as e:
                # A failing source still gets a placeholder row.
                print(f"Error processing source {source}: {str(e)}")
                rows.append({"Source": source, "Count": 0, "WER": np.nan})

        # Single aggregate row across the whole dataset.
        try:
            total = len(dataset)
            print(f"Calculating overall WER for {total} examples")
            rows.append({"Source": "OVERALL", "Count": total, "WER": calculate_wer(dataset)})
        except Exception as e:
            print(f"Error calculating overall metrics: {str(e)}")
            rows.append({"Source": "OVERALL", "Count": len(dataset), "WER": np.nan})

        return pd.DataFrame(rows)
    except Exception as e:
        print(f"Error in get_wer_metrics: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])
# Format the dataframe for display
def format_dataframe(df):
    """Return a copy of *df* with the WER column rendered for display.

    Numeric WER values become fixed 4-decimal strings; missing values
    become "N/A". Other columns are untouched; the input is not mutated.
    Returns a one-row Error DataFrame if formatting fails.
    """
    try:
        df = df.copy()
        if "WER" in df.columns:
            # Convert the whole column in one pass. The previous two-step
            # `.loc` approach wrote strings into a float column, which raises
            # pandas' incompatible-dtype FutureWarning and fails on pandas 3.x.
            df["WER"] = df["WER"].map(
                lambda x: f"{x:.4f}" if pd.notna(x) else "N/A"
            )
        return df
    except Exception as e:
        print(f"Error in format_dataframe: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])
# Main function to create the leaderboard
def create_leaderboard():
    """Load the dataset, compute WER metrics, and return a display-ready table."""
    try:
        # Pipeline: load -> score -> format for the UI.
        return format_dataframe(get_wer_metrics(load_data()))
    except Exception as e:
        error_msg = f"Error creating leaderboard: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return pd.DataFrame([{"Error": error_msg}])
# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for test data in GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset")

    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")

    with gr.Row():
        # Shown to the user so refresh errors surface in the UI, not just logs.
        error_output = gr.Textbox(label="Debug Information", visible=True)

    with gr.Row():
        try:
            # Compute the table once at app-construction time.
            initial_df = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception as e:
            error_msg = f"Error initializing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            # NOTE(review): calling .update(...) on a component instance at
            # build time may be a no-op or unsupported depending on the Gradio
            # version in use — confirm against the pinned gradio release.
            error_output.update(value=error_msg)
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": error_msg}]))

    def refresh_and_report():
        # Click handler: recompute the leaderboard and return
        # (table, status message) for the two outputs wired below.
        try:
            df = create_leaderboard()
            debug_info = "Leaderboard refreshed successfully."
            return df, debug_info
        except Exception as e:
            error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return pd.DataFrame([{"Error": error_msg}]), error_msg

    refresh_btn.click(refresh_and_report, outputs=[leaderboard, error_output])

if __name__ == "__main__":
    demo.launch()