import spaces
import jiwer
import numpy as np
import re
import gradio as gr


def split_into_sentences(text):
"""
Simple sentence tokenizer using regular expressions.
Splits text into sentences based on punctuation.
"""
sentences = re.split(r'(?<=[.!?])\s*', text)
sentences = [s.strip() for s in sentences if s.strip()]
return sentences
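
# Illustrative example (not executed here), assuming the regex above, which
# splits after '.', '!' or '?':
#   split_into_sentences("Hello world. How are you? Fine!")
#   -> ["Hello world.", "How are you?", "Fine!"]
# Abbreviations such as "e.g." will split incorrectly; this is a simple
# heuristic, not a full sentence tokenizer.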


@spaces.GPU()
def calculate_wer(reference, hypothesis):
"""
Calculate the Word Error Rate (WER) using jiwer.
"""
wer = jiwer.wer(reference, hypothesis)
return wer
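
# Illustrative example (not executed here), assuming jiwer's standard WER
# formula (substitutions + deletions + insertions) / reference word count:
#   calculate_wer("the quick brown fox", "the quick brown dog")
#   -> 0.25  (one substitution out of four reference words)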


@spaces.GPU()
def calculate_cer(reference, hypothesis):
"""
Calculate the Character Error Rate (CER) using jiwer.
"""
cer = jiwer.cer(reference, hypothesis)
return cer
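
# Illustrative example (not executed here): CER counts character-level edits
# against the reference length, so
#   calculate_cer("hello", "hallo")
#   -> 0.2  (one substituted character out of five)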


@spaces.GPU()
def calculate_sentence_metrics(reference, hypothesis):
"""
Calculate WER and CER for each sentence and overall statistics.
Handles cases where the number of sentences differ.
"""
try:
reference_sentences = split_into_sentences(reference)
hypothesis_sentences = split_into_sentences(hypothesis)
sentence_wers = []
sentence_cers = []
min_length = min(len(reference_sentences), len(hypothesis_sentences))
for i in range(min_length):
ref = reference_sentences[i]
hyp = hypothesis_sentences[i]
wer = jiwer.wer(ref, hyp)
cer = jiwer.cer(ref, hyp)
sentence_wers.append(wer)
sentence_cers.append(cer)
# Calculate overall statistics
if sentence_wers:
average_wer = np.mean(sentence_wers)
std_dev_wer = np.std(sentence_wers)
else:
average_wer = 0.0
std_dev_wer = 0.0
if sentence_cers:
average_cer = np.mean(sentence_cers)
std_dev_cer = np.std(sentence_cers)
else:
average_cer = 0.0
std_dev_cer = 0.0
return {
"sentence_wers": sentence_wers,
"sentence_cers": sentence_cers,
"average_wer": average_wer,
"average_cer": average_cer,
"std_dev_wer": std_dev_wer,
"std_dev_cer": std_dev_cer
}
except Exception as e:
raise e
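
# The returned dictionary is consumed by process_files below; an illustrative
# shape (values assumed for the example, not computed here):
#   {"sentence_wers": [0.0, 0.5], "sentence_cers": [0.0, 0.2],
#    "average_wer": 0.25, "average_cer": 0.1,
#    "std_dev_wer": 0.25, "std_dev_cer": 0.1}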


def identify_misaligned_sentences(reference_text, hypothesis_text):
    """
    Identify sentences that don't match between reference and hypothesis.
    Returns a list of dictionaries, one per misaligned pair, with the
    sentence index, both sentences, and where the words diverge.
    """
reference_sentences = split_into_sentences(reference_text)
hypothesis_sentences = split_into_sentences(hypothesis_text)
misaligned = []
for i, (ref, hyp) in enumerate(zip(reference_sentences, hypothesis_sentences)):
if ref != hyp:
# Split sentences into words
ref_words = ref.split()
hyp_words = hyp.split()
            # Find the first position where the words diverge; if one
            # sentence is a word-level prefix of the other, the divergence
            # starts right after the shared words.
            min_length = min(len(ref_words), len(hyp_words))
            misalignment_start = min_length
            for j in range(min_length):
                if ref_words[j] != hyp_words[j]:
                    misalignment_start = j
                    break
            # Prepare the context for display, bolding the first diverging
            # word on each side (if that side has one).
            ref_marked = ref_words[:misalignment_start]
            if misalignment_start < len(ref_words):
                ref_marked = ref_marked + ['**' + ref_words[misalignment_start] + '**']
            hyp_marked = hyp_words[:misalignment_start]
            if misalignment_start < len(hyp_words):
                hyp_marked = hyp_marked + ['**' + hyp_words[misalignment_start] + '**']
            context_ref = ' '.join(ref_marked)
            context_hyp = ' '.join(hyp_marked)
misaligned.append({
"index": i+1,
"reference": ref,
"hypothesis": hyp,
"misalignment_start": misalignment_start,
"context_ref": context_ref,
"context_hyp": context_hyp
})
# Handle cases where the number of sentences differs
if len(reference_sentences) > len(hypothesis_sentences):
for i in range(len(hypothesis_sentences), len(reference_sentences)):
misaligned.append({
"index": i+1,
"reference": reference_sentences[i],
"hypothesis": "No corresponding sentence",
"misalignment_start": 0,
"context_ref": reference_sentences[i],
"context_hyp": "No corresponding sentence"
})
elif len(hypothesis_sentences) > len(reference_sentences):
for i in range(len(reference_sentences), len(hypothesis_sentences)):
misaligned.append({
"index": i+1,
"reference": "No corresponding sentence",
"hypothesis": hypothesis_sentences[i],
"misalignment_start": 0,
"context_ref": "No corresponding sentence",
"context_hyp": hypothesis_sentences[i]
})
return misaligned
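
# Illustrative example (not executed here):
#   identify_misaligned_sentences("The cat sat.", "The hat sat.")
# would return a single entry with misalignment_start == 1 and the first
# diverging word bolded on each side, e.g. context_ref == "The **cat**"
# and context_hyp == "The **hat**".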


def format_sentence_metrics(sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer, misaligned_sentences):
    """Build a Markdown report from sentence-level metrics and misalignments."""
md = "### Sentence-level Metrics\n\n"
md += "#### Word Error Rate (WER)\n"
md += f"* Average WER: {average_wer:.2f}\n"
md += f"* Standard Deviation: {std_dev_wer:.2f}\n\n"
md += "#### Character Error Rate (CER)\n"
md += f"* Average CER: {average_cer:.2f}\n"
md += f"* Standard Deviation: {std_dev_cer:.2f}\n\n"
md += "### WER for Each Sentence\n\n"
for i, wer in enumerate(sentence_wers):
md += f"* Sentence {i+1}: {wer:.2f}\n"
md += "\n### CER for Each Sentence\n\n"
for i, cer in enumerate(sentence_cers):
md += f"* Sentence {i+1}: {cer:.2f}\n"
if misaligned_sentences:
md += "\n### Misaligned Sentences\n\n"
for misaligned in misaligned_sentences:
md += f"#### Sentence {misaligned['index']}\n"
md += f"* Reference: {misaligned['context_ref']}\n"
md += f"* Hypothesis: {misaligned['context_hyp']}\n"
md += f"* Misalignment starts at word: {misaligned['misalignment_start'] + 1}\n\n"
else:
md += "\n### Misaligned Sentences\n\n"
md += "* No misaligned sentences found."
return md


@spaces.GPU()
def process_files(reference_file, hypothesis_file):
    """Read both uploaded files and compute overall and per-sentence metrics."""
try:
        with open(reference_file.name, 'r', encoding='utf-8') as f:
            reference_text = f.read()
        with open(hypothesis_file.name, 'r', encoding='utf-8') as f:
            hypothesis_text = f.read()
overall_wer = calculate_wer(reference_text, hypothesis_text)
overall_cer = calculate_cer(reference_text, hypothesis_text)
sentence_metrics = calculate_sentence_metrics(reference_text, hypothesis_text)
misaligned = identify_misaligned_sentences(reference_text, hypothesis_text)
return {
"Overall WER": overall_wer,
"Overall CER": overall_cer,
"Sentence WERs": sentence_metrics["sentence_wers"],
"Sentence CERs": sentence_metrics["sentence_cers"],
"Average WER": sentence_metrics["average_wer"],
"Average CER": sentence_metrics["average_cer"],
"Standard Deviation WER": sentence_metrics["std_dev_wer"],
"Standard Deviation CER": sentence_metrics["std_dev_cer"],
"Misaligned Sentences": misaligned
}
except Exception as e:
return {"error": str(e)}


def process_and_display(ref_file, hyp_file):
    """Gradio callback: returns (overall metrics, metrics Markdown, misaligned Markdown)."""
result = process_files(ref_file, hyp_file)
if "error" in result:
error_msg = result["error"]
return {"error": error_msg}, "", ""
metrics = {
"Overall WER": result["Overall WER"],
"Overall CER": result["Overall CER"]
}
metrics_md = format_sentence_metrics(
result["Sentence WERs"],
result["Sentence CERs"],
result["Average WER"],
result["Average CER"],
result["Standard Deviation WER"],
result["Standard Deviation CER"],
result["Misaligned Sentences"]
)
misaligned_md = "### Misaligned Sentences\n\n"
if result["Misaligned Sentences"]:
for misaligned in result["Misaligned Sentences"]:
misaligned_md += f"#### Sentence {misaligned['index']}\n"
misaligned_md += f"* Reference: {misaligned['context_ref']}\n"
misaligned_md += f"* Hypothesis: {misaligned['context_hyp']}\n"
misaligned_md += f"* Misalignment starts at position: {misaligned['misalignment_start']}\n\n"
else:
misaligned_md += "* No misaligned sentences found."
return metrics, metrics_md, misaligned_md


def main():
with gr.Blocks() as demo:
gr.Markdown("# ASR Metrics")
with gr.Row():
reference_file = gr.File(label="Upload Reference File")
hypothesis_file = gr.File(label="Upload Model Output File")
with gr.Row():
reference_preview = gr.Textbox(label="Reference Preview", lines=3)
hypothesis_preview = gr.Textbox(label="Hypothesis Preview", lines=3)
with gr.Row():
compute_button = gr.Button("Compute Metrics")
        results_output = gr.JSON(label="Results")
        metrics_output = gr.Markdown(label="Metrics")
        misaligned_output = gr.Markdown(label="Misaligned Sentences")
# Update previews when files are uploaded
def update_previews(ref_file, hyp_file):
ref_text = ""
hyp_text = ""
if ref_file:
                with open(ref_file.name, 'r', encoding='utf-8') as f:
                    ref_text = f.read()[:200]  # Show first 200 characters
            if hyp_file:
                with open(hyp_file.name, 'r', encoding='utf-8') as f:
                    hyp_text = f.read()[:200]  # Show first 200 characters
return ref_text, hyp_text
reference_file.change(
fn=update_previews,
inputs=[reference_file, hypothesis_file],
outputs=[reference_preview, hypothesis_preview]
)
hypothesis_file.change(
fn=update_previews,
inputs=[reference_file, hypothesis_file],
outputs=[reference_preview, hypothesis_preview]
)
        compute_button.click(
            fn=process_and_display,
            inputs=[reference_file, hypothesis_file],
            outputs=[results_output, metrics_output, misaligned_output]
        )
demo.launch()
if __name__ == "__main__":
main()