import re

import gradio as gr
import jiwer
import numpy as np
import spaces  # Hugging Face Spaces SDK hook; not used directly, kept for the Spaces runtime

def split_into_sentences(text):
    """
    Simple sentence tokenizer using regular expressions.
    Splits text into sentences on terminal punctuation (., !, ?).
    """
    sentences = re.split(r'(?<=[.!?])\s*', text)
    return [s.strip() for s in sentences if s.strip()]

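# Illustrative behavior (a simple regex splitter, not a full tokenizer):
#   split_into_sentences("Hello there. How are you?")
#   -> ["Hello there.", "How are you?"]
# Note that abbreviations such as "Dr." will also trigger a split.
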
def calculate_wer(reference, hypothesis):
    """
    Calculate the Word Error Rate (WER) using jiwer.
    """
    return jiwer.wer(reference, hypothesis)


def calculate_cer(reference, hypothesis):
    """
    Calculate the Character Error Rate (CER) using jiwer.
    """
    return jiwer.cer(reference, hypothesis)

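# Quick sanity check (illustrative values; WER follows the standard definition,
# (substitutions + deletions + insertions) / words in the reference):
#   calculate_wer("the cat sat", "the cat sit")  # 1 substitution / 3 words -> 0.33
#   calculate_cer("abc", "abd")                  # 1 substitution / 3 chars -> 0.33
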
def calculate_sentence_metrics(reference, hypothesis):
    """
    Calculate WER and CER for each sentence and overall statistics.
    Raises ValueError if the two texts contain different sentence counts.
    """
    reference_sentences = split_into_sentences(reference)
    hypothesis_sentences = split_into_sentences(hypothesis)
    if len(reference_sentences) != len(hypothesis_sentences):
        raise ValueError("Reference and hypothesis must contain the same number of sentences")
    sentence_wers = []
    sentence_cers = []
    for ref, hyp in zip(reference_sentences, hypothesis_sentences):
        sentence_wers.append(jiwer.wer(ref, hyp))
        sentence_cers.append(jiwer.cer(ref, hyp))
    if not sentence_wers:
        # No sentences found in either text: return zeroed statistics.
        return {
            "sentence_wers": [],
            "sentence_cers": [],
            "average_wer": 0.0,
            "average_cer": 0.0,
            "std_dev_wer": 0.0,
            "std_dev_cer": 0.0
        }
    # Cast numpy scalars to plain floats so the results stay JSON-serializable.
    return {
        "sentence_wers": sentence_wers,
        "sentence_cers": sentence_cers,
        "average_wer": float(np.mean(sentence_wers)),
        "average_cer": float(np.mean(sentence_cers)),
        "std_dev_wer": float(np.std(sentence_wers)),
        "std_dev_cer": float(np.std(sentence_cers))
    }

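# Illustrative return value for two perfectly matching sentences:
#   {"sentence_wers": [0.0, 0.0], "sentence_cers": [0.0, 0.0],
#    "average_wer": 0.0, "average_cer": 0.0, "std_dev_wer": 0.0, "std_dev_cer": 0.0}
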
def process_files(reference_file, hypothesis_file):
    try:
        with open(reference_file.name, 'r', encoding='utf-8') as f:
            reference_text = f.read()
        with open(hypothesis_file.name, 'r', encoding='utf-8') as f:
            hypothesis_text = f.read()
        overall_wer = calculate_wer(reference_text, hypothesis_text)
        overall_cer = calculate_cer(reference_text, hypothesis_text)
        sentence_metrics = calculate_sentence_metrics(reference_text, hypothesis_text)
        return {
            "Overall WER": overall_wer,
            "Overall CER": overall_cer,
            "Sentence WERs": sentence_metrics["sentence_wers"],
            "Sentence CERs": sentence_metrics["sentence_cers"],
            "Average WER": sentence_metrics["average_wer"],
            "Average CER": sentence_metrics["average_cer"],
            "Standard Deviation WER": sentence_metrics["std_dev_wer"],
            "Standard Deviation CER": sentence_metrics["std_dev_cer"]
        }
    except Exception as e:
        return {"error": str(e)}

def format_sentence_metrics(sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer):
    if not sentence_wers and not sentence_cers:
        return "No sentences were found to score."
    md = "### Sentence-level Metrics\n\n"
    md += "#### Word Error Rate (WER)\n"
    md += f"* Average WER: {average_wer:.2f}\n"
    md += f"* Standard Deviation: {std_dev_wer:.2f}\n\n"
    md += "#### Character Error Rate (CER)\n"
    md += f"* Average CER: {average_cer:.2f}\n"
    md += f"* Standard Deviation: {std_dev_cer:.2f}\n\n"
    md += "### WER for Each Sentence\n\n"
    for i, wer in enumerate(sentence_wers):
        md += f"* Sentence {i + 1}: {wer:.2f}\n"
    md += "\n### CER for Each Sentence\n\n"
    for i, cer in enumerate(sentence_cers):
        md += f"* Sentence {i + 1}: {cer:.2f}\n"
    return md

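# For example, sentence WERs of [0.00, 0.50] render as "* Sentence 1: 0.00" and
# "* Sentence 2: 0.50" under the "WER for Each Sentence" heading.
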
def main():
    with gr.Blocks() as demo:
        gr.Markdown("# ASR Metrics")
        with gr.Row():
            reference_file = gr.File(label="Upload Reference File")
            hypothesis_file = gr.File(label="Upload Model Output File")
        with gr.Row():
            reference_preview = gr.Textbox(label="Reference Preview", lines=3)
            hypothesis_preview = gr.Textbox(label="Hypothesis Preview", lines=3)
        with gr.Row():
            compute_button = gr.Button("Compute Metrics")
        results_output = gr.JSON(label="Results")
        metrics_output = gr.Markdown(label="Sentence Metrics")
        # Update previews when files are uploaded
        def update_previews(ref_file, hyp_file):
            ref_text = ""
            hyp_text = ""
            if ref_file:
                with open(ref_file.name, 'r', encoding='utf-8') as f:
                    ref_text = f.read()[:200]  # Show the first 200 characters
            if hyp_file:
                with open(hyp_file.name, 'r', encoding='utf-8') as f:
                    hyp_text = f.read()[:200]  # Show the first 200 characters
            return ref_text, hyp_text
        # Both file inputs share one callback, so either upload refreshes both previews.
        reference_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview]
        )
        hypothesis_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview]
        )
        def process_and_display(ref_file, hyp_file):
            result = process_files(ref_file, hyp_file)
            if "error" in result:
                # Two outputs are wired up: an empty JSON payload, plus the
                # error message surfaced in the Markdown pane.
                return {}, "Error: " + result["error"]
            metrics = {
                "Overall WER": result["Overall WER"],
                "Overall CER": result["Overall CER"]
            }
            metrics_md = format_sentence_metrics(
                result["Sentence WERs"],
                result["Sentence CERs"],
                result["Average WER"],
                result["Average CER"],
                result["Standard Deviation WER"],
                result["Standard Deviation CER"]
            )
            return metrics, metrics_md
        compute_button.click(
            fn=process_and_display,
            inputs=[reference_file, hypothesis_file],
            outputs=[results_output, metrics_output]
        )

    demo.launch()


if __name__ == "__main__":
    main()
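
# To run locally (assuming gradio, jiwer, and numpy are installed, e.g. via
# `pip install gradio jiwer numpy`): run `python app.py`, then open the URL
# Gradio prints. The `spaces` import above only matters on Hugging Face Spaces.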