Spaces:
Sleeping
Sleeping
import re

import gradio as gr
import jiwer
import numpy as np
import spaces
def calculate_wer(reference, hypothesis):
    """Return the word error rate of ``hypothesis`` measured against ``reference``.

    This is a thin wrapper: both arguments are passed straight to
    ``jiwer.wer`` and its result is returned unchanged.
    """
    return jiwer.wer(reference, hypothesis)
def calculate_cer(reference, hypothesis):
    """Return the character error rate of ``hypothesis`` against ``reference``.

    This is a thin wrapper: both arguments are passed straight to
    ``jiwer.cer`` and its result is returned unchanged.
    """
    return jiwer.cer(reference, hypothesis)
# A sentence ends at whitespace that follows '.', '!' or '?', or at a line
# break (transcripts are often stored one utterance per line).
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+|\n+")


def _split_into_sentences(text):
    """Split ``text`` into a list of non-empty, stripped sentences."""
    pieces = (piece.strip() for piece in _SENTENCE_BOUNDARY.split(text))
    return [piece for piece in pieces if piece]


def calculate_sentence_wer(reference, hypothesis):
    """Calculate WER for each sentence pair plus summary statistics.

    Sentence splitting is done locally because jiwer does not provide a
    sentence splitter. The split is a simple heuristic (terminal
    punctuation followed by whitespace, or a newline), so abbreviations
    such as "Dr. Smith" are treated as a sentence boundary.

    Args:
        reference: Ground-truth transcript text.
        hypothesis: Transcript text to evaluate.

    Returns:
        A dict with keys ``"sentence_wers"`` (list of per-sentence WER
        values, in order), ``"average_wer"`` and ``"std_dev"`` (floats).
        When neither text contains a sentence, the list is empty and both
        statistics are ``0.0``.

    Raises:
        ValueError: If the two texts do not split into the same number of
            sentences, since sentences are compared pairwise by position.
    """
    reference_sentences = _split_into_sentences(reference)
    hypothesis_sentences = _split_into_sentences(hypothesis)

    if len(reference_sentences) != len(hypothesis_sentences):
        raise ValueError("Reference and hypothesis must contain the same number of sentences")

    sentence_wers = [
        jiwer.wer(ref, hyp)
        for ref, hyp in zip(reference_sentences, hypothesis_sentences)
    ]

    # np.mean / np.std of an empty list would yield NaN with a warning.
    if not sentence_wers:
        return {
            "sentence_wers": [],
            "average_wer": 0.0,
            "std_dev": 0.0,
        }

    return {
        "sentence_wers": sentence_wers,
        # Plain floats keep the result trivially JSON-serializable.
        "average_wer": float(np.mean(sentence_wers)),
        "std_dev": float(np.std(sentence_wers)),
    }
def _uploaded_path(uploaded):
    """Return the filesystem path behind an uploaded-file value.

    Plain path strings and path-like objects are returned as-is; any other
    object is expected to expose its path as ``.name``.
    """
    if isinstance(uploaded, str) or hasattr(uploaded, "__fspath__"):
        return uploaded
    return uploaded.name


def process_files(reference_file, hypothesis_file):
    """Read both uploaded files and compute all ASR metrics for them.

    Args:
        reference_file: Uploaded reference transcript (path string or an
            object whose ``.name`` is the path).
        hypothesis_file: Uploaded hypothesis transcript, same form.

    Returns:
        On success, a dict with keys ``"WER"``, ``"CER"``,
        ``"Sentence WERs"``, ``"Average WER"`` and ``"Standard Deviation"``.
        On any failure, a dict with the single key ``"error"`` holding a
        message; this function does not raise.
    """
    if reference_file is None or hypothesis_file is None:
        return {"error": "Both a reference file and a hypothesis file are required"}

    try:
        # Explicit encoding: the platform default may not be UTF-8 and would
        # mis-decode non-ASCII transcripts.
        with open(_uploaded_path(reference_file), "r", encoding="utf-8") as f:
            reference_text = f.read()
        with open(_uploaded_path(hypothesis_file), "r", encoding="utf-8") as f:
            hypothesis_text = f.read()

        wer_value = calculate_wer(reference_text, hypothesis_text)
        cer_value = calculate_cer(reference_text, hypothesis_text)
        sentence_wer_stats = calculate_sentence_wer(reference_text, hypothesis_text)

        return {
            "WER": wer_value,
            "CER": cer_value,
            "Sentence WERs": sentence_wer_stats["sentence_wers"],
            "Average WER": sentence_wer_stats["average_wer"],
            "Standard Deviation": sentence_wer_stats["std_dev"],
        }
    except Exception as e:
        # Deliberately broad: this is the UI boundary, and callers expect
        # failures to be reported through the "error" key.
        return {"error": str(e)}
def format_sentence_wer_stats(sentence_wers, average_wer, std_dev):
    """Render sentence-level WER statistics as a Markdown string.

    Args:
        sentence_wers: Per-sentence WER values, in sentence order.
        average_wer: Average WER to report in the summary.
        std_dev: Standard deviation to report in the summary.

    Returns:
        Markdown with a summary section followed by one bullet per
        sentence (values formatted to two decimals), or a fixed one-line
        message when ``sentence_wers`` is empty.
    """
    # NOTE(review): an empty list means no sentence was scored, which is not
    # the same as every sentence matching; the message is kept unchanged so
    # existing output stays the same.
    if not sentence_wers:
        return "All sentences match perfectly!"

    parts = [
        "### Sentence-level WER Analysis\n\n",
        f"* Average WER: {average_wer:.2f}\n",
        f"* Standard Deviation: {std_dev:.2f}\n\n",
        "### WER for Each Sentence\n\n",
    ]
    parts.extend(
        f"* Sentence {number}: {wer:.2f}\n"
        for number, wer in enumerate(sentence_wers, start=1)
    )
    # Single join instead of repeated string concatenation.
    return "".join(parts)
def main():
    """Build the Gradio UI for the ASR metrics calculator and launch it.

    The interface has two file uploads (reference and hypothesis), a short
    text preview of each, and a button that computes the metrics through
    ``process_files`` and renders the per-sentence statistics with
    ``format_sentence_wer_stats``.
    """
    preview_chars = 200  # number of characters shown in each preview box

    def upload_path(uploaded):
        """Return the path for an upload given as a path or as an object with ``.name``."""
        if isinstance(uploaded, str) or hasattr(uploaded, "__fspath__"):
            return uploaded
        return uploaded.name

    def read_preview(uploaded):
        """Return the first ``preview_chars`` characters of an upload, or "" if none."""
        if not uploaded:
            return ""
        # errors="replace": a preview should never fail on undecodable bytes.
        with open(upload_path(uploaded), "r", encoding="utf-8", errors="replace") as f:
            # Read only what is displayed instead of loading the whole file.
            return f.read(preview_chars)

    with gr.Blocks() as demo:
        gr.Markdown("# ASR Metrics Calculator")

        with gr.Row():
            reference_file = gr.File(label="Upload Reference File")
            hypothesis_file = gr.File(label="Upload Hypothesis File")

        with gr.Row():
            reference_preview = gr.Textbox(label="Reference Preview", lines=3)
            hypothesis_preview = gr.Textbox(label="Hypothesis Preview", lines=3)

        with gr.Row():
            compute_button = gr.Button("Compute Metrics")

        results_output = gr.JSON(label="Results")
        wer_stats_output = gr.Markdown(label="WER Statistics")

        def update_previews(ref_file, hyp_file):
            """Refresh both preview boxes from the currently uploaded files."""
            return read_preview(ref_file), read_preview(hyp_file)

        # Either upload changing refreshes both previews.
        reference_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview],
        )
        hypothesis_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview],
        )

        def process_and_display(ref_file, hyp_file):
            """Compute metrics and return ``(json_metrics, markdown_stats)``."""
            result = process_files(ref_file, hyp_file)

            if "error" in result:
                # Exactly two values, one per output component below.
                return {}, "Error: " + result["error"]

            metrics = {
                "WER": result["WER"],
                "CER": result["CER"],
            }
            wer_stats_md = format_sentence_wer_stats(
                result["Sentence WERs"],
                result["Average WER"],
                result["Standard Deviation"],
            )
            return metrics, wer_stats_md

        compute_button.click(
            fn=process_and_display,
            inputs=[reference_file, hypothesis_file],
            outputs=[results_output, wer_stats_output],
        )

    demo.launch()
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()