import spaces
import jiwer
import numpy as np
import gradio as gr


@spaces.GPU()
def calculate_wer(reference, hypothesis):
    """Corpus-level WER over all lines joined into a single string."""
    reference_str = " ".join(reference)
    hypothesis_str = " ".join(hypothesis)
    return jiwer.wer(reference_str, hypothesis_str)


@spaces.GPU()
def calculate_cer(reference, hypothesis):
    """Corpus-level CER over all lines joined into a single string."""
    reference_str = " ".join(reference)
    hypothesis_str = " ".join(hypothesis)
    return jiwer.cer(reference_str, hypothesis_str)


@spaces.GPU()
def calculate_sentence_metrics(reference, hypothesis):
    """Per-sentence WER/CER plus their mean and standard deviation."""
    reference_sentences = [line.strip() for line in reference]
    hypothesis_sentences = [line.strip() for line in hypothesis]

    sentence_wers = []
    sentence_cers = []
    min_length = min(len(reference_sentences), len(hypothesis_sentences))

    for i in range(min_length):
        ref = reference_sentences[i]
        hyp = hypothesis_sentences[i]
        sentence_wers.append(jiwer.wer(ref, hyp))
        sentence_cers.append(jiwer.cer(ref, hyp))

    average_wer = np.mean(sentence_wers) if sentence_wers else 0.0
    std_dev_wer = np.std(sentence_wers) if sentence_wers else 0.0
    average_cer = np.mean(sentence_cers) if sentence_cers else 0.0
    std_dev_cer = np.std(sentence_cers) if sentence_cers else 0.0

    return {
        "sentence_wers": sentence_wers,
        "sentence_cers": sentence_cers,
        "average_wer": average_wer,
        "average_cer": average_cer,
        "std_dev_wer": std_dev_wer,
        "std_dev_cer": std_dev_cer
    }


def identify_misaligned_sentences(reference, hypothesis):
    """Find line pairs that differ and locate the first diverging word."""
    reference_sentences = [line.strip() for line in reference]
    hypothesis_sentences = [line.strip() for line in hypothesis]
    misaligned = []

    for i, (ref, hyp) in enumerate(zip(reference_sentences, hypothesis_sentences)):
        if ref != hyp:
            ref_words = ref.split()
            hyp_words = hyp.split()
            min_length = min(len(ref_words), len(hyp_words))

            # Default to min_length so that, when one sentence is a prefix of the
            # other, the misalignment is reported at the first extra word.
            misalignment_start = min_length
            for j in range(min_length):
                if ref_words[j] != hyp_words[j]:
                    misalignment_start = j
                    break

            # Mark the first diverging word with ** so it stands out in the report.
            if misalignment_start < len(ref_words):
                context_ref = ' '.join(ref_words[:misalignment_start] + ['**' + ref_words[misalignment_start] + '**'])
            else:
                context_ref = ref
            if misalignment_start < len(hyp_words):
                context_hyp = ' '.join(hyp_words[:misalignment_start] + ['**' + hyp_words[misalignment_start] + '**'])
            else:
                context_hyp = hyp

            misaligned.append({
                "index": i + 1,
                "reference": ref,
                "hypothesis": hyp,
                "misalignment_start": misalignment_start,
                "context_ref": context_ref,
                "context_hyp": context_hyp
            })

    # Handle extra sentences present in only one of the files.
    if len(reference_sentences) > len(hypothesis_sentences):
        for i in range(len(hypothesis_sentences), len(reference_sentences)):
            misaligned.append({
                "index": i + 1,
                "reference": reference_sentences[i],
                "hypothesis": "No corresponding sentence",
                "misalignment_start": 0,
                "context_ref": reference_sentences[i],
                "context_hyp": "No corresponding sentence"
            })
    elif len(hypothesis_sentences) > len(reference_sentences):
        for i in range(len(reference_sentences), len(hypothesis_sentences)):
            misaligned.append({
                "index": i + 1,
                "reference": "No corresponding sentence",
                "hypothesis": hypothesis_sentences[i],
                "misalignment_start": 0,
                "context_ref": "No corresponding sentence",
                "context_hyp": hypothesis_sentences[i]
            })

    return misaligned


def format_sentence_metrics(sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer):
    """Render the sentence-level metrics as a Markdown report."""
    md = "### Sentence-level Metrics\n\n"
    md += f"**Average WER**: {average_wer:.2f}\n\n"
    md += f"**Standard Deviation WER**: {std_dev_wer:.2f}\n\n"
    md += f"**Average CER**: {average_cer:.2f}\n\n"
    md += f"**Standard Deviation CER**: {std_dev_cer:.2f}\n\n"
    md += "---\n**WER by Sentence**\n"
    for i, wer in enumerate(sentence_wers):
        md += f"- Sentence {i+1}: {wer:.2f}\n"
    md += "\n**CER by Sentence**\n"
    for i, cer in enumerate(sentence_cers):
        md += f"- Sentence {i+1}: {cer:.2f}\n"
    return md


def process_files(reference_file, hypothesis_file):
    """Read both text files and compute overall, sentence-level, and alignment metrics."""
    try:
        with open(reference_file.name, 'r', encoding='utf-8') as f:
            reference_text = f.read().splitlines()
        with open(hypothesis_file.name, 'r', encoding='utf-8') as f:
            hypothesis_text = f.read().splitlines()

        overall_wer = calculate_wer(reference_text, hypothesis_text)
        overall_cer = calculate_cer(reference_text, hypothesis_text)
        sentence_metrics = calculate_sentence_metrics(reference_text, hypothesis_text)
        misaligned = identify_misaligned_sentences(reference_text, hypothesis_text)

        return {
            "Overall WER": overall_wer,
            "Overall CER": overall_cer,
            **sentence_metrics,
            "Misaligned Sentences": misaligned
        }
    except Exception as e:
        return {"error": str(e)}


def process_and_display(ref_file, hyp_file):
    """Gradio callback: turn the metrics dict into JSON and Markdown outputs."""
    result = process_files(ref_file, hyp_file)
    if "error" in result:
        return {"error": result["error"]}, "", ""

    metrics = {
        "Overall WER": result["Overall WER"],
        "Overall CER": result["Overall CER"]
    }
    metrics_md = format_sentence_metrics(
        result["sentence_wers"],
        result["sentence_cers"],
        result["average_wer"],
        result["average_cer"],
        result["std_dev_wer"],
        result["std_dev_cer"]
    )

    misaligned_md = "### Misaligned Sentences\n\n"
    if result["Misaligned Sentences"]:
        for mis in result["Misaligned Sentences"]:
            misaligned_md += f"**Sentence {mis['index']}**\n"
            misaligned_md += f"- Reference: {mis['context_ref']}\n"
            misaligned_md += f"- Hypothesis: {mis['context_hyp']}\n"
            misaligned_md += f"- Misalignment starts at position: {mis['misalignment_start']}\n\n"
    else:
        misaligned_md += "* No misaligned sentences found."

    return metrics, metrics_md, misaligned_md


def main():
    """Build and launch the Gradio interface."""
    with gr.Blocks() as demo:
        gr.Markdown("# 📊 ASR Metrics Analysis Tool")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Upload your reference and hypothesis files")
                reference_file = gr.File(label="Reference File (.txt)")
                hypothesis_file = gr.File(label="Hypothesis File (.txt)")
                compute_button = gr.Button("Compute Metrics", variant="primary")
            with gr.Column():
                results_output = gr.JSON(label="Results Summary")
                metrics_output = gr.Markdown(label="Sentence Metrics")
                misaligned_output = gr.Markdown(label="Misaligned Sentences")

        compute_button.click(
            fn=process_and_display,
            inputs=[reference_file, hypothesis_file],
            outputs=[results_output, metrics_output, misaligned_output]
        )

    demo.launch(ssr_mode=False)


if __name__ == "__main__":
    main()