"""Gradio app that computes ASR quality metrics for a pair of text files.

Given a reference transcript and a hypothesis transcript, the app reports the
overall Word Error Rate (WER), the Character Error Rate (CER), and a
sentence-by-sentence WER breakdown with mean and standard deviation.
"""

import spaces
import jiwer
import numpy as np
import gradio as gr
import nltk
from nltk.tokenize import sent_tokenize

# Ensure the NLTK sentence-tokenizer data is downloaded.
# 'punkt' is what the original script fetched; newer NLTK releases load the
# 'punkt_tab' resource for sent_tokenize instead, so both are requested.
nltk.download('punkt')
nltk.download('punkt_tab')

# Number of characters shown in the preview text boxes.
PREVIEW_CHARS = 200


# NOTE(review): the metric functions below only call jiwer/NumPy/NLTK; the
# @spaces.GPU() decorators are kept exactly as in the original -- confirm
# whether the hosting Space actually requires them on every function.
@spaces.GPU()
def calculate_wer(reference, hypothesis):
    """Return the Word Error Rate (WER) of *hypothesis* against *reference*.

    Thin wrapper around ``jiwer.wer``.
    """
    return jiwer.wer(reference, hypothesis)


@spaces.GPU()
def calculate_cer(reference, hypothesis):
    """Return the Character Error Rate (CER) of *hypothesis* against *reference*.

    Thin wrapper around ``jiwer.cer``.
    """
    return jiwer.cer(reference, hypothesis)


@spaces.GPU()
def calculate_sentence_wer(reference, hypothesis):
    """Calculate WER for each aligned sentence pair plus summary statistics.

    Both texts are split with ``sent_tokenize`` and paired by position. If
    the texts contain a different number of sentences, only the first
    ``min(n_ref, n_hyp)`` pairs are compared and a warning is returned.

    Returns:
        A dict with keys ``"sentence_wers"`` (list of per-sentence WERs),
        ``"average_wer"``, ``"std_dev"`` and either ``"warning"`` (a message
        or ``None``) or, if anything raised, ``"error"`` (the exception text).
    """
    try:
        reference_sentences = sent_tokenize(reference)
        hypothesis_sentences = sent_tokenize(hypothesis)

        # Record the ORIGINAL counts: the mismatch warning below must compare
        # these, not the lengths of the already-aligned pairs (which are
        # always equal and would make the warning unreachable).
        ref_count = len(reference_sentences)
        hyp_count = len(hypothesis_sentences)
        min_sentences = min(ref_count, hyp_count)

        # zip() stops at the shorter list, so this compares exactly the
        # first `min_sentences` pairs.
        sentence_wers = [
            jiwer.wer(ref, hyp)
            for ref, hyp in zip(reference_sentences, hypothesis_sentences)
        ]

        if not sentence_wers:
            return {
                "sentence_wers": [],
                "average_wer": 0.0,
                "std_dev": 0.0,
                "warning": "No sentences to compare"
            }

        # Convert NumPy scalars to built-in floats so the values serialise
        # and format like ordinary Python numbers.
        average_wer = float(np.mean(sentence_wers))
        std_dev = float(np.std(sentence_wers))

        if ref_count != hyp_count:
            warning = f"Reference has {ref_count} sentences, " \
                      f"hypothesis has {hyp_count} sentences. " \
                      f"Only compared the first {min_sentences} sentences."
        else:
            warning = None

        return {
            "sentence_wers": sentence_wers,
            "average_wer": average_wer,
            "std_dev": std_dev,
            "warning": warning
        }
    except Exception as e:
        # UI boundary: report the failure to the caller instead of raising.
        return {
            "sentence_wers": [],
            "average_wer": 0.0,
            "std_dev": 0.0,
            "error": str(e)
        }


def _file_path(file_obj):
    """Return the filesystem path for an uploaded file.

    Accepts either an object exposing a ``name`` attribute (what the original
    code assumed) or a plain path string.
    """
    return getattr(file_obj, "name", file_obj)


def _read_text(file_obj, limit=-1, errors="strict"):
    """Read up to *limit* characters (all when negative) of a file as UTF-8."""
    with open(_file_path(file_obj), 'r', encoding='utf-8', errors=errors) as f:
        return f.read(limit)


def _empty_result(error):
    """Build the zeroed result dict returned by process_files on failure."""
    return {
        "WER": 0.0,
        "CER": 0.0,
        "Sentence WERs": [],
        "Average WER": 0.0,
        "Standard Deviation": 0.0,
        "Warning": None,
        "Error": error
    }


@spaces.GPU()
def process_files(reference_file, hypothesis_file):
    """Read both uploaded files and compute all metrics.

    Args:
        reference_file: Uploaded reference transcript (object with ``.name``
            or a path string).
        hypothesis_file: Uploaded hypothesis transcript (same forms).

    Returns:
        A dict with keys ``"WER"``, ``"CER"``, ``"Sentence WERs"``,
        ``"Average WER"``, ``"Standard Deviation"``, ``"Warning"`` and
        ``"Error"``. Every failure path returns the same shape with zeroed
        metrics and ``"Error"`` set, so callers can index it uniformly.
    """
    try:
        if not reference_file or not hypothesis_file:
            return _empty_result(
                "Both reference and hypothesis files must be uploaded"
            )

        reference_text = _read_text(reference_file)
        hypothesis_text = _read_text(hypothesis_file)

        if not reference_text or not hypothesis_text:
            # Same message as before, but in the standard result shape so the
            # caller does not fail on the missing "WER"/"CER" keys.
            return _empty_result(
                "Both reference and hypothesis files must contain text"
            )

        wer_value = calculate_wer(reference_text, hypothesis_text)
        cer_value = calculate_cer(reference_text, hypothesis_text)
        sentence_wer_stats = calculate_sentence_wer(reference_text, hypothesis_text)

        return {
            "WER": wer_value,
            "CER": cer_value,
            "Sentence WERs": sentence_wer_stats["sentence_wers"],
            "Average WER": sentence_wer_stats["average_wer"],
            "Standard Deviation": sentence_wer_stats["std_dev"],
            "Warning": sentence_wer_stats.get("warning"),
            "Error": sentence_wer_stats.get("error")
        }
    except Exception as e:
        # UI boundary: surface the problem in the result instead of raising.
        return _empty_result(str(e))


def format_sentence_wer_stats(sentence_wers, average_wer, std_dev, warning, error):
    """Render the sentence-level statistics as a Markdown string.

    An error takes precedence over a warning in the header. When there are no
    per-sentence values, only the header (if any) and a short notice are
    returned.
    """
    parts = []

    if error:
        parts.append(f"### Error\n{error}\n\n")
    elif warning:
        parts.append(f"### Warning\n{warning}\n\n")

    if not sentence_wers:
        parts.append("No sentences to compare")
        return "".join(parts)

    parts.append("### Sentence-level WER Analysis\n\n")
    parts.append(f"* Average WER: {average_wer:.2f}\n")
    parts.append(f"* Standard Deviation: {std_dev:.2f}\n\n")
    parts.append("### WER for Each Sentence\n\n")
    parts.extend(
        f"* Sentence {i}: {wer:.2f}\n"
        for i, wer in enumerate(sentence_wers, start=1)
    )
    return "".join(parts)


def main():
    """Build the Gradio interface, wire up its callbacks and launch it."""
    with gr.Blocks() as demo:
        gr.Markdown("# ASR Metrics Calculator")

        with gr.Row():
            reference_file = gr.File(label="Upload Reference File")
            hypothesis_file = gr.File(label="Upload Hypothesis File")

        with gr.Row():
            reference_preview = gr.Textbox(label="Reference Preview", lines=3)
            hypothesis_preview = gr.Textbox(label="Hypothesis Preview", lines=3)

        with gr.Row():
            compute_button = gr.Button("Compute Metrics")

        results_output = gr.JSON(label="Results")
        wer_stats_output = gr.Markdown(label="WER Statistics")

        def update_previews(ref_file, hyp_file):
            """Return the first PREVIEW_CHARS characters of each uploaded file."""
            # Undecodable bytes are replaced rather than raised on: the
            # preview is informational only and should not break the UI.
            ref_text = (
                _read_text(ref_file, PREVIEW_CHARS, errors="replace")
                if ref_file else ""
            )
            hyp_text = (
                _read_text(hyp_file, PREVIEW_CHARS, errors="replace")
                if hyp_file else ""
            )
            return ref_text, hyp_text

        # Refresh both previews whenever either upload changes.
        reference_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview]
        )
        hypothesis_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview]
        )

        def process_and_display(ref_file, hyp_file):
            """Compute the metrics and split them into JSON and Markdown outputs."""
            result = process_files(ref_file, hyp_file)

            # .get with defaults keeps the callback alive even if a result
            # arrives without the headline metrics.
            metrics = {
                "WER": result.get("WER", 0.0),
                "CER": result.get("CER", 0.0)
            }

            wer_stats_md = format_sentence_wer_stats(
                result.get("Sentence WERs", []),
                result.get("Average WER", 0.0),
                result.get("Standard Deviation", 0.0),
                result.get("Warning"),
                result.get("Error")
            )

            return metrics, wer_stats_md

        compute_button.click(
            fn=process_and_display,
            inputs=[reference_file, hypothesis_file],
            outputs=[results_output, wer_stats_output]
        )

    demo.launch()


if __name__ == "__main__":
    main()