# ASR Metrics — Gradio app (Hugging Face Spaces) for computing WER/CER
# between a reference transcript and an ASR model's hypothesis.
import spaces | |
import jiwer | |
import numpy as np | |
import re | |
import gradio as gr | |
def split_into_sentences(text):
    """
    Split *text* into sentences with a simple regex tokenizer.

    A sentence boundary is any (possibly empty) run of whitespace that
    follows terminal punctuation (``.``, ``!``, ``?``). Fragments that
    are empty after stripping are discarded.
    """
    fragments = re.split(r'(?<=[.!?])\s*', text)
    return [fragment.strip() for fragment in fragments if fragment.strip()]
def calculate_wer(reference, hypothesis):
    """Return the Word Error Rate between *reference* and *hypothesis* (via jiwer)."""
    return jiwer.wer(reference, hypothesis)
def calculate_cer(reference, hypothesis):
    """Return the Character Error Rate between *reference* and *hypothesis* (via jiwer)."""
    return jiwer.cer(reference, hypothesis)
def calculate_sentence_metrics(reference, hypothesis):
    """
    Compute per-sentence WER and CER plus summary statistics.

    Sentences are paired by position; when the two texts contain a
    different number of sentences, only the overlapping prefix is
    scored (extra sentences are reported separately by
    identify_misaligned_sentences).

    Returns a dict with keys:
        sentence_wers, sentence_cers -- per-sentence scores (lists)
        average_wer, average_cer     -- mean of the per-sentence scores
        std_dev_wer, std_dev_cer     -- population standard deviation

    All summary statistics are 0.0 when there are no comparable pairs.
    """
    reference_sentences = split_into_sentences(reference)
    hypothesis_sentences = split_into_sentences(hypothesis)

    sentence_wers = []
    sentence_cers = []
    # zip() truncates to the shorter list, matching the original
    # min-length pairing.
    for ref, hyp in zip(reference_sentences, hypothesis_sentences):
        sentence_wers.append(jiwer.wer(ref, hyp))
        sentence_cers.append(jiwer.cer(ref, hyp))

    # Guard the empty case: np.mean/np.std on [] warn and return nan.
    # float() converts np.float64 to plain floats so the result stays
    # JSON-serializable for the gr.JSON output component.
    if sentence_wers:
        average_wer = float(np.mean(sentence_wers))
        std_dev_wer = float(np.std(sentence_wers))
        average_cer = float(np.mean(sentence_cers))
        std_dev_cer = float(np.std(sentence_cers))
    else:
        average_wer = std_dev_wer = average_cer = std_dev_cer = 0.0

    return {
        "sentence_wers": sentence_wers,
        "sentence_cers": sentence_cers,
        "average_wer": average_wer,
        "average_cer": average_cer,
        "std_dev_wer": std_dev_wer,
        "std_dev_cer": std_dev_cer
    }
def identify_misaligned_sentences(reference_text, hypothesis_text):
    """
    Identify sentence pairs that differ between reference and hypothesis.

    Sentences are paired by position. For each differing pair, records
    the 1-based index, both sentences, the character position where
    they first diverge, and display strings with the diverging suffix
    wrapped in ``**`` (Markdown bold). Extra sentences on the longer
    side are reported with "No corresponding sentence" placeholders.

    Returns a list of dicts with keys: index, reference, hypothesis,
    misalignment_start, context_ref, context_hyp.
    """
    reference_sentences = split_into_sentences(reference_text)
    hypothesis_sentences = split_into_sentences(hypothesis_text)
    misaligned = []
    min_length = min(len(reference_sentences), len(hypothesis_sentences))

    # Compare sentences up to the minimum length.
    for i in range(min_length):
        ref = reference_sentences[i]
        hyp = hypothesis_sentences[i]
        if ref == hyp:
            continue
        # Find the first character position where the sentences diverge.
        # BUG FIX: when one sentence is a strict prefix of the other the
        # original loop never broke and left the start at 0; the
        # divergence actually begins where the shorter sentence ends.
        min_len = min(len(ref), len(hyp))
        misalignment_start = min_len
        for j in range(min_len):
            if ref[j] != hyp[j]:
                misalignment_start = j
                break
        misaligned.append({
            "index": i + 1,
            "reference": ref,
            "hypothesis": hyp,
            "misalignment_start": misalignment_start,
            # Bold the diverging tail for Markdown display.
            "context_ref": ref[:misalignment_start] + f"**{ref[misalignment_start:]}**",
            "context_hyp": hyp[:misalignment_start] + f"**{hyp[misalignment_start:]}**",
        })

    # Note any extra sentences (beyond the shared prefix) as misaligned.
    if len(reference_sentences) > len(hypothesis_sentences):
        for i in range(min_length, len(reference_sentences)):
            misaligned.append({
                "index": i + 1,
                "reference": reference_sentences[i],
                "hypothesis": "No corresponding sentence",
                "misalignment_start": 0,
                "context_ref": reference_sentences[i],
                "context_hyp": "No corresponding sentence",
            })
    elif len(hypothesis_sentences) > len(reference_sentences):
        for i in range(min_length, len(hypothesis_sentences)):
            misaligned.append({
                "index": i + 1,
                "reference": "No corresponding sentence",
                "hypothesis": hypothesis_sentences[i],
                "misalignment_start": 0,
                "context_ref": "No corresponding sentence",
                "context_hyp": hypothesis_sentences[i],
            })
    return misaligned
def format_sentence_metrics(sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer, misaligned_sentences):
    """
    Build a Markdown report of sentence-level WER/CER statistics,
    per-sentence scores, and any misaligned sentence pairs.
    """
    parts = [
        "### Sentence-level Metrics\n\n",
        "#### Word Error Rate (WER)\n",
        f"* Average WER: {average_wer:.2f}\n",
        f"* Standard Deviation: {std_dev_wer:.2f}\n\n",
        "#### Character Error Rate (CER)\n",
        f"* Average CER: {average_cer:.2f}\n",
        f"* Standard Deviation: {std_dev_cer:.2f}\n\n",
        "### WER for Each Sentence\n\n",
    ]
    parts.extend(
        f"* Sentence {n}: {score:.2f}\n"
        for n, score in enumerate(sentence_wers, start=1)
    )
    parts.append("\n### CER for Each Sentence\n\n")
    parts.extend(
        f"* Sentence {n}: {score:.2f}\n"
        for n, score in enumerate(sentence_cers, start=1)
    )
    parts.append("\n### Misaligned Sentences\n\n")
    if misaligned_sentences:
        for entry in misaligned_sentences:
            parts.append(f"#### Sentence {entry['index']}\n")
            parts.append(f"* Reference: {entry['context_ref']}\n")
            parts.append(f"* Hypothesis: {entry['context_hyp']}\n")
            parts.append(f"* Misalignment starts at position: {entry['misalignment_start']}\n\n")
    else:
        parts.append("* No misaligned sentences found.")
    return "".join(parts)
def process_files(reference_file, hypothesis_file):
    """
    Read the uploaded reference and hypothesis text files and compute
    overall and sentence-level WER/CER metrics.

    Both parameters are Gradio file objects exposing a ``.name`` path.
    Returns a dict of metrics on success, or ``{"error": message}`` on
    failure so the UI can surface the problem instead of crashing.
    """
    try:
        # Read as UTF-8 explicitly: without it, open() falls back to the
        # platform default encoding (e.g. cp1252 on Windows), which
        # corrupts or rejects non-ASCII transcripts.
        with open(reference_file.name, 'r', encoding='utf-8') as f:
            reference_text = f.read()
        with open(hypothesis_file.name, 'r', encoding='utf-8') as f:
            hypothesis_text = f.read()
        overall_wer = calculate_wer(reference_text, hypothesis_text)
        overall_cer = calculate_cer(reference_text, hypothesis_text)
        sentence_metrics = calculate_sentence_metrics(reference_text, hypothesis_text)
        misaligned = identify_misaligned_sentences(reference_text, hypothesis_text)
        return {
            "Overall WER": overall_wer,
            "Overall CER": overall_cer,
            "Sentence WERs": sentence_metrics["sentence_wers"],
            "Sentence CERs": sentence_metrics["sentence_cers"],
            "Average WER": sentence_metrics["average_wer"],
            "Average CER": sentence_metrics["average_cer"],
            "Standard Deviation WER": sentence_metrics["std_dev_wer"],
            "Standard Deviation CER": sentence_metrics["std_dev_cer"],
            "Misaligned Sentences": misaligned
        }
    except Exception as e:
        # Broad catch is intentional: any failure (missing file, bad
        # encoding, metric error) becomes a message the UI displays.
        return {"error": str(e)}
def process_and_display(ref_file, hyp_file):
    """
    Run the metric computation and shape its output for the Gradio UI.

    Returns a (metrics dict, sentence-metrics markdown, misaligned
    markdown) triple; on failure the dict carries the error message
    and both markdown strings are empty.
    """
    result = process_files(ref_file, hyp_file)
    if "error" in result:
        return {"error": result["error"]}, "", ""

    metrics = {
        "Overall WER": result["Overall WER"],
        "Overall CER": result["Overall CER"],
    }
    metrics_md = format_sentence_metrics(
        result["Sentence WERs"],
        result["Sentence CERs"],
        result["Average WER"],
        result["Average CER"],
        result["Standard Deviation WER"],
        result["Standard Deviation CER"],
        result["Misaligned Sentences"],
    )

    # Stand-alone misaligned report for its dedicated output panel.
    chunks = ["### Misaligned Sentences\n\n"]
    entries = result["Misaligned Sentences"]
    if entries:
        for entry in entries:
            chunks.append(f"#### Sentence {entry['index']}\n")
            chunks.append(f"* Reference: {entry['context_ref']}\n")
            chunks.append(f"* Hypothesis: {entry['context_hyp']}\n")
            chunks.append(f"* Misalignment starts at position: {entry['misalignment_start']}\n\n")
    else:
        chunks.append("* No misaligned sentences found.")
    return metrics, metrics_md, "".join(chunks)
def main():
    """Build and launch the ASR metrics Gradio interface."""
    with gr.Blocks() as demo:
        gr.Markdown("# ASR Metrics")

        with gr.Row():
            ref_upload = gr.File(label="Upload Reference File")
            hyp_upload = gr.File(label="Upload Model Output File")
        with gr.Row():
            ref_preview = gr.Textbox(label="Reference Preview", lines=3)
            hyp_preview = gr.Textbox(label="Hypothesis Preview", lines=3)
        with gr.Row():
            compute_btn = gr.Button("Compute Metrics")

        results_json = gr.JSON(label="Results")
        metrics_panel = gr.Markdown(label="Sentence Metrics")
        misaligned_panel = gr.Markdown(label="Misaligned Sentences")

        def update_previews(ref_file, hyp_file):
            """Return the first 200 characters of each uploaded file."""
            previews = []
            for upload in (ref_file, hyp_file):
                if upload:
                    with open(upload.name, 'r') as handle:
                        previews.append(handle.read()[:200])
                else:
                    previews.append("")
            return tuple(previews)

        # Refresh both previews whenever either upload changes.
        for upload in (ref_upload, hyp_upload):
            upload.change(
                fn=update_previews,
                inputs=[ref_upload, hyp_upload],
                outputs=[ref_preview, hyp_preview]
            )

        compute_btn.click(
            fn=process_and_display,
            inputs=[ref_upload, hyp_upload],
            outputs=[results_json, metrics_panel, misaligned_panel]
        )
    demo.launch()
# Launch the app only when run as a script, not when imported as a module.
if __name__ == "__main__":
    main()