# asr_metrics / app.py
# Source: Hugging Face Space by akki2825 (commit 2b32a02, "typo")
import spaces
import jiwer
import numpy as np
import re
import gradio as gr
def split_into_sentences(text):
    """Split *text* into sentences on terminal punctuation.

    A lookbehind regex splits after '.', '!' or '?', so the punctuation
    stays attached to its sentence. Whitespace is trimmed and empty
    fragments are discarded.
    """
    stripped = (chunk.strip() for chunk in re.split(r'(?<=[.!?])\s*', text))
    return [chunk for chunk in stripped if chunk]
@spaces.GPU()
def calculate_wer(reference, hypothesis):
    """Return the Word Error Rate between *reference* and *hypothesis*.

    Thin wrapper around ``jiwer.wer``; the ``@spaces.GPU()`` decorator is
    kept for Hugging Face ZeroGPU scheduling, though the metric itself is
    CPU-only.
    """
    return jiwer.wer(reference, hypothesis)
@spaces.GPU()
def calculate_cer(reference, hypothesis):
    """Return the Character Error Rate between *reference* and *hypothesis*.

    Thin wrapper around ``jiwer.cer``; decorated to match the other metric
    entry points used by the Space.
    """
    return jiwer.cer(reference, hypothesis)
@spaces.GPU()
def calculate_sentence_metrics(reference, hypothesis):
    """Calculate per-sentence WER/CER plus mean and standard deviation.

    Both texts are split into sentences with ``split_into_sentences`` and
    compared pairwise, so they must yield the same number of sentences.

    Returns:
        dict with keys ``sentence_wers``, ``sentence_cers``,
        ``average_wer``, ``average_cer``, ``std_dev_wer``, ``std_dev_cer``.
        All aggregate values are plain floats (JSON-friendly).

    Raises:
        ValueError: if the sentence counts differ.
    """
    # NOTE: the previous version wrapped everything in
    # `try: ... except Exception as e: raise e`, which is a no-op; removed.
    reference_sentences = split_into_sentences(reference)
    hypothesis_sentences = split_into_sentences(hypothesis)
    if len(reference_sentences) != len(hypothesis_sentences):
        raise ValueError("Reference and hypothesis must contain the same number of sentences")

    pairs = list(zip(reference_sentences, hypothesis_sentences))
    sentence_wers = [jiwer.wer(ref, hyp) for ref, hyp in pairs]
    sentence_cers = [jiwer.cer(ref, hyp) for ref, hyp in pairs]

    # Both input texts empty -> no sentences; report zeros rather than
    # letting numpy warn about the mean of an empty slice.
    if not sentence_wers:
        return {
            "sentence_wers": [],
            "sentence_cers": [],
            "average_wer": 0.0,
            "average_cer": 0.0,
            "std_dev_wer": 0.0,
            "std_dev_cer": 0.0
        }

    return {
        "sentence_wers": sentence_wers,
        "sentence_cers": sentence_cers,
        "average_wer": float(np.mean(sentence_wers)),
        "average_cer": float(np.mean(sentence_cers)),
        "std_dev_wer": float(np.std(sentence_wers)),
        "std_dev_cer": float(np.std(sentence_cers))
    }
@spaces.GPU()
def process_files(reference_file, hypothesis_file):
    """Read two uploaded transcript files and compute all ASR metrics.

    Args:
        reference_file: Gradio file object for the ground-truth transcript.
        hypothesis_file: Gradio file object for the model output.

    Returns:
        dict of overall and sentence-level metrics, or ``{"error": msg}``
        on any failure (mismatched sentence counts, unreadable file, ...).
    """
    try:
        # Explicit UTF-8: user uploads should not depend on the server's
        # locale-default encoding.
        with open(reference_file.name, 'r', encoding='utf-8') as f:
            reference_text = f.read()
        with open(hypothesis_file.name, 'r', encoding='utf-8') as f:
            hypothesis_text = f.read()

        overall_wer = calculate_wer(reference_text, hypothesis_text)
        overall_cer = calculate_cer(reference_text, hypothesis_text)
        sentence_metrics = calculate_sentence_metrics(reference_text, hypothesis_text)

        return {
            "Overall WER": overall_wer,
            "Overall CER": overall_cer,
            "Sentence WERs": sentence_metrics["sentence_wers"],
            "Sentence CERs": sentence_metrics["sentence_cers"],
            "Average WER": sentence_metrics["average_wer"],
            "Average CER": sentence_metrics["average_cer"],
            "Standard Deviation WER": sentence_metrics["std_dev_wer"],
            "Standard Deviation CER": sentence_metrics["std_dev_cer"]
        }
    except Exception as e:
        # Surface the failure to the UI instead of crashing the handler.
        return {"error": str(e)}
def format_sentence_metrics(sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer):
    """Render sentence-level WER/CER statistics as a Markdown report.

    Returns a fixed "perfect match" message when both metric lists are
    empty; otherwise a Markdown document with the aggregate statistics
    followed by one bullet per sentence for each metric.
    """
    if not sentence_wers and not sentence_cers:
        return "All sentences match perfectly!"

    # Assemble the report as parts and join once, instead of repeated +=.
    parts = [
        "### Sentence-level Metrics\n\n",
        "#### Word Error Rate (WER)\n",
        f"* Average WER: {average_wer:.2f}\n",
        f"* Standard Deviation: {std_dev_wer:.2f}\n\n",
        "#### Character Error Rate (CER)\n",
        f"* Average CER: {average_cer:.2f}\n",
        f"* Standard Deviation: {std_dev_cer:.2f}\n\n",
        "### WER for Each Sentence\n\n",
    ]
    parts.extend(f"* Sentence {idx}: {wer:.2f}\n" for idx, wer in enumerate(sentence_wers, start=1))
    parts.append("\n### CER for Each Sentence\n\n")
    parts.extend(f"* Sentence {idx}: {cer:.2f}\n" for idx, cer in enumerate(sentence_cers, start=1))
    return "".join(parts)
def main():
    """Build and launch the Gradio UI for comparing ASR transcripts."""
    with gr.Blocks() as demo:
        gr.Markdown("# ASR Metrics")
        with gr.Row():
            reference_file = gr.File(label="Upload Reference File")
            hypothesis_file = gr.File(label="Upload Model Output File")
        with gr.Row():
            reference_preview = gr.Textbox(label="Reference Preview", lines=3)
            hypothesis_preview = gr.Textbox(label="Hypothesis Preview", lines=3)
        with gr.Row():
            compute_button = gr.Button("Compute Metrics")
        results_output = gr.JSON(label="Results")
        metrics_output = gr.Markdown(label="Sentence Metrics")

        def update_previews(ref_file, hyp_file):
            """Show the first 200 characters of each uploaded file."""
            ref_text = ""
            hyp_text = ""
            if ref_file:
                with open(ref_file.name, 'r', encoding='utf-8') as f:
                    ref_text = f.read()[:200]  # Show first 200 characters
            if hyp_file:
                with open(hyp_file.name, 'r', encoding='utf-8') as f:
                    hyp_text = f.read()[:200]  # Show first 200 characters
            return ref_text, hyp_text

        # Refresh both previews whenever either upload changes.
        reference_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview]
        )
        hypothesis_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview]
        )

        def process_and_display(ref_file, hyp_file):
            """Compute metrics and split them into JSON + Markdown outputs."""
            result = process_files(ref_file, hyp_file)
            if "error" in result:
                # BUG FIX: the error path previously returned THREE values
                # ({}, {}, msg) for TWO outputs, so Gradio raised instead of
                # showing the error message.
                return {}, "Error: " + result["error"]
            metrics = {
                "Overall WER": result["Overall WER"],
                "Overall CER": result["Overall CER"]
            }
            metrics_md = format_sentence_metrics(
                result["Sentence WERs"],
                result["Sentence CERs"],
                result["Average WER"],
                result["Average CER"],
                result["Standard Deviation WER"],
                result["Standard Deviation CER"]
            )
            return metrics, metrics_md

        compute_button.click(
            fn=process_and_display,
            inputs=[reference_file, hypothesis_file],
            outputs=[results_output, metrics_output]
        )
    demo.launch()
if __name__ == "__main__":
main()