# Hugging Face Spaces page header captured by the scrape (not code):
# "Spaces: Sleeping Sleeping"
# Third-party dependencies, grouped and alphabetized (no stdlib imports needed).
import gradio as gr
import jiwer
import nltk
import numpy as np
import spaces  # required by the Hugging Face Spaces runtime even though unused directly
from nltk.tokenize import sent_tokenize

# sent_tokenize() needs the Punkt sentence-boundary models at runtime.
# NLTK >= 3.8.2 ships them as 'punkt_tab'; older releases use 'punkt'.
# Download both — nltk.download() reports (not raises) on an unknown package,
# so requesting the one the installed NLTK doesn't know is harmless.
nltk.download('punkt')
nltk.download('punkt_tab')
def calculate_wer(reference, hypothesis):
    """Return the Word Error Rate between *reference* and *hypothesis*.

    Thin convenience wrapper around ``jiwer.wer``.
    """
    return jiwer.wer(reference, hypothesis)
def calculate_cer(reference, hypothesis):
    """Return the Character Error Rate between *reference* and *hypothesis*.

    Thin convenience wrapper around ``jiwer.cer``.
    """
    return jiwer.cer(reference, hypothesis)
def calculate_sentence_wer(reference, hypothesis):
    """Compute per-sentence WER plus summary statistics.

    Both texts are split into sentences with NLTK's ``sent_tokenize``.
    If the texts contain a different number of sentences, only the first
    ``min(n_ref, n_hyp)`` pairs are compared and a warning is reported.

    Returns a dict with keys:
        sentence_wers: list[float] — WER for each compared sentence pair
        average_wer:   float       — mean of sentence_wers (0.0 if empty)
        std_dev:       float       — population std-dev of sentence_wers
        warning:       str | None  — set when sentence counts differ
        error:         str         — present only if an exception occurred
    """
    try:
        reference_sentences = sent_tokenize(reference)
        hypothesis_sentences = sent_tokenize(hypothesis)

        # Record the original counts BEFORE trimming. The previous version
        # compared lengths after both lists were cut to min_sentences, so
        # the lengths were always equal and the warning could never fire.
        n_ref = len(reference_sentences)
        n_hyp = len(hypothesis_sentences)
        min_sentences = min(n_ref, n_hyp)

        reference_sentences = reference_sentences[:min_sentences]
        hypothesis_sentences = hypothesis_sentences[:min_sentences]

        sentence_wers = [
            jiwer.wer(ref, hyp)
            for ref, hyp in zip(reference_sentences, hypothesis_sentences)
        ]

        if not sentence_wers:
            return {
                "sentence_wers": [],
                "average_wer": 0.0,
                "std_dev": 0.0,
                "warning": "No sentences to compare"
            }

        if n_ref != n_hyp:
            warning = f"Reference has {n_ref} sentences, " \
                      f"hypothesis has {n_hyp} sentences. " \
                      f"Only compared the first {min_sentences} sentences."
        else:
            warning = None

        return {
            "sentence_wers": sentence_wers,
            "average_wer": np.mean(sentence_wers),
            "std_dev": np.std(sentence_wers),
            "warning": warning
        }
    except Exception as e:
        # Boundary handler: the Gradio callback expects a dict, never a raise.
        return {
            "sentence_wers": [],
            "average_wer": 0.0,
            "std_dev": 0.0,
            "error": str(e)
        }
def process_files(reference_file, hypothesis_file):
    """Read two uploaded transcripts and compute WER/CER metrics.

    Parameters are Gradio file objects exposing a ``.name`` filesystem path.
    Always returns a dict containing the full set of result keys so callers
    may index "WER", "CER", etc. unconditionally; failures are reported
    under "Error" instead of raising.
    """
    # One fixed shape for every failure path. The previous version returned
    # a bare {"error": ...} dict for empty files, which made the caller
    # crash with KeyError on result["WER"].
    def _failure(message):
        return {
            "WER": 0.0,
            "CER": 0.0,
            "Sentence WERs": [],
            "Average WER": 0.0,
            "Standard Deviation": 0.0,
            "Error": message
        }

    try:
        # Transcripts are commonly UTF-8; be explicit so the platform
        # default encoding (e.g. cp1252 on Windows) can't corrupt them.
        with open(reference_file.name, 'r', encoding='utf-8') as f:
            reference_text = f.read()
        with open(hypothesis_file.name, 'r', encoding='utf-8') as f:
            hypothesis_text = f.read()

        if not reference_text or not hypothesis_text:
            return _failure("Both reference and hypothesis files must contain text")

        wer_value = calculate_wer(reference_text, hypothesis_text)
        cer_value = calculate_cer(reference_text, hypothesis_text)
        sentence_wer_stats = calculate_sentence_wer(reference_text, hypothesis_text)

        return {
            "WER": wer_value,
            "CER": cer_value,
            "Sentence WERs": sentence_wer_stats["sentence_wers"],
            "Average WER": sentence_wer_stats["average_wer"],
            "Standard Deviation": sentence_wer_stats["std_dev"],
            "Warning": sentence_wer_stats.get("warning"),
            "Error": sentence_wer_stats.get("error")
        }
    except Exception as e:
        return _failure(str(e))
def format_sentence_wer_stats(sentence_wers, average_wer, std_dev, warning, error):
    """Render sentence-level WER results as a Markdown report.

    An error takes precedence over a warning. When there are no
    per-sentence scores the report is just the optional error/warning
    section followed by a "nothing to compare" notice.
    """
    parts = []
    if error:
        parts.append(f"### Error\n{error}\n\n")
    elif warning:
        parts.append(f"### Warning\n{warning}\n\n")

    if not sentence_wers:
        parts.append("No sentences to compare")
        return "".join(parts)

    parts.append("### Sentence-level WER Analysis\n\n")
    parts.append(f"* Average WER: {average_wer:.2f}\n")
    parts.append(f"* Standard Deviation: {std_dev:.2f}\n\n")
    parts.append("### WER for Each Sentence\n\n")
    parts.extend(
        f"* Sentence {idx}: {score:.2f}\n"
        for idx, score in enumerate(sentence_wers, start=1)
    )
    return "".join(parts)
def main():
    """Build and launch the Gradio UI for the ASR metrics calculator."""
    with gr.Blocks() as demo:
        gr.Markdown("# ASR Metrics Calculator")

        with gr.Row():
            reference_file = gr.File(label="Upload Reference File")
            hypothesis_file = gr.File(label="Upload Hypothesis File")

        with gr.Row():
            reference_preview = gr.Textbox(label="Reference Preview", lines=3)
            hypothesis_preview = gr.Textbox(label="Hypothesis Preview", lines=3)

        with gr.Row():
            compute_button = gr.Button("Compute Metrics")

        results_output = gr.JSON(label="Results")
        wer_stats_output = gr.Markdown(label="WER Statistics")

        def update_previews(ref_file, hyp_file):
            """Show the first 200 characters of each uploaded file."""
            ref_text = ""
            hyp_text = ""
            # Explicit UTF-8: the platform default encoding (e.g. cp1252
            # on Windows) would mangle non-ASCII transcripts.
            if ref_file:
                with open(ref_file.name, 'r', encoding='utf-8') as f:
                    ref_text = f.read()[:200]
            if hyp_file:
                with open(hyp_file.name, 'r', encoding='utf-8') as f:
                    hyp_text = f.read()[:200]
            return ref_text, hyp_text

        # Refresh both previews whenever either upload changes.
        reference_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview]
        )
        hypothesis_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview]
        )

        def process_and_display(ref_file, hyp_file):
            """Run the metric pipeline and split results for the two outputs."""
            result = process_files(ref_file, hyp_file)
            # Defensive .get(): a failure dict may omit keys, and hard
            # indexing result["WER"] crashed the callback in that case.
            metrics = {
                "WER": result.get("WER", 0.0),
                "CER": result.get("CER", 0.0)
            }
            wer_stats_md = format_sentence_wer_stats(
                result.get("Sentence WERs", []),
                result.get("Average WER", 0.0),
                result.get("Standard Deviation", 0.0),
                result.get("Warning"),
                result.get("Error")
            )
            return metrics, wer_stats_md

        compute_button.click(
            fn=process_and_display,
            inputs=[reference_file, hypothesis_file],
            outputs=[results_output, wer_stats_output]
        )

    demo.launch()


if __name__ == "__main__":
    main()