# asr_metrics/app.py
# Gradio app that computes ASR evaluation metrics (WER, CER, and
# sentence-level WER statistics) from uploaded transcript files.
import spaces
import jiwer
import numpy as np
import gradio as gr
import nltk
from nltk.tokenize import sent_tokenize
# Ensure the Punkt sentence-tokenizer models used by sent_tokenize are
# available. NLTK >= 3.8.2 resolves the "punkt_tab" resource instead of
# "punkt", so fetch both for compatibility; quiet=True avoids spamming
# the startup log on every launch.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
@spaces.GPU()
def calculate_wer(reference, hypothesis):
    """Return the word error rate (WER) of *hypothesis* against *reference*.

    Thin wrapper around ``jiwer.wer``; both arguments are plain strings.
    """
    return jiwer.wer(reference, hypothesis)
@spaces.GPU()
def calculate_cer(reference, hypothesis):
    """Return the character error rate (CER) of *hypothesis* against *reference*.

    Thin wrapper around ``jiwer.cer``; both arguments are plain strings.
    """
    return jiwer.cer(reference, hypothesis)
@spaces.GPU()
def calculate_sentence_wer(reference, hypothesis):
    """Compute per-sentence WER plus summary statistics.

    Sentences are split with NLTK's ``sent_tokenize`` and paired
    positionally; surplus sentences on the longer side are ignored.

    Returns a dict with keys:
        sentence_wers: list[float] — WER of each aligned sentence pair.
        average_wer / std_dev: float — mean and population std-dev.
        warning: str | None — set when sentence counts differ.
        error: str — present only if tokenization/scoring raised.
    """
    try:
        reference_sentences = sent_tokenize(reference)
        hypothesis_sentences = sent_tokenize(hypothesis)
        # BUG FIX: record the ORIGINAL counts before trimming. The old
        # code trimmed both lists first and then compared their lengths,
        # so the mismatch warning could never fire.
        n_ref = len(reference_sentences)
        n_hyp = len(hypothesis_sentences)
        min_sentences = min(n_ref, n_hyp)
        sentence_wers = [
            jiwer.wer(ref, hyp)
            for ref, hyp in zip(reference_sentences[:min_sentences],
                                hypothesis_sentences[:min_sentences])
        ]
        if not sentence_wers:
            return {
                "sentence_wers": [],
                "average_wer": 0.0,
                "std_dev": 0.0,
                "warning": "No sentences to compare"
            }
        if n_ref != n_hyp:
            warning = f"Reference has {n_ref} sentences, " \
                      f"hypothesis has {n_hyp} sentences. " \
                      f"Only compared the first {min_sentences} sentences."
        else:
            warning = None
        return {
            "sentence_wers": sentence_wers,
            # Cast numpy scalars to plain floats so the result is
            # JSON-serializable in the gr.JSON output downstream.
            "average_wer": float(np.mean(sentence_wers)),
            "std_dev": float(np.std(sentence_wers)),
            "warning": warning
        }
    except Exception as e:
        # Best-effort: surface the failure in the result rather than
        # crashing the UI callback.
        return {
            "sentence_wers": [],
            "average_wer": 0.0,
            "std_dev": 0.0,
            "error": str(e)
        }
@spaces.GPU()
def process_files(reference_file, hypothesis_file):
    """Read the two uploaded transcript files and compute all metrics.

    Args:
        reference_file / hypothesis_file: uploaded file objects exposing
            a ``.name`` path attribute (as provided by gr.File).

    Returns a dict with WER, CER and sentence-level statistics; on
    failure, a dict carrying an "error"/"Error" message instead.
    """
    try:
        # Explicit UTF-8: transcripts are text files and the platform
        # default encoding (e.g. cp1252 on Windows) would mangle any
        # non-ASCII characters.
        with open(reference_file.name, 'r', encoding='utf-8') as f:
            reference_text = f.read()
        with open(hypothesis_file.name, 'r', encoding='utf-8') as f:
            hypothesis_text = f.read()
        if not reference_text or not hypothesis_text:
            return {
                "error": "Both reference and hypothesis files must contain text"
            }
        wer_value = calculate_wer(reference_text, hypothesis_text)
        cer_value = calculate_cer(reference_text, hypothesis_text)
        sentence_wer_stats = calculate_sentence_wer(reference_text, hypothesis_text)
        return {
            "WER": wer_value,
            "CER": cer_value,
            "Sentence WERs": sentence_wer_stats["sentence_wers"],
            "Average WER": sentence_wer_stats["average_wer"],
            "Standard Deviation": sentence_wer_stats["std_dev"],
            # .get(): these keys are only present on warning/error paths.
            "Warning": sentence_wer_stats.get("warning"),
            "Error": sentence_wer_stats.get("error")
        }
    except Exception as e:
        # Best-effort boundary: report the failure in the payload so the
        # UI still renders instead of raising inside the callback.
        return {
            "WER": 0.0,
            "CER": 0.0,
            "Sentence WERs": [],
            "Average WER": 0.0,
            "Standard Deviation": 0.0,
            "Error": str(e)
        }
def format_sentence_wer_stats(sentence_wers, average_wer, std_dev, warning, error):
    """Render sentence-level WER statistics as a Markdown report string.

    An error takes precedence over a warning; with no sentence scores the
    report ends early with a "nothing to compare" note.
    """
    parts = []
    if error:
        parts.append(f"### Error\n{error}\n\n")
    elif warning:
        parts.append(f"### Warning\n{warning}\n\n")
    if not sentence_wers:
        parts.append("No sentences to compare")
        return "".join(parts)
    parts.append("### Sentence-level WER Analysis\n\n")
    parts.append(f"* Average WER: {average_wer:.2f}\n")
    parts.append(f"* Standard Deviation: {std_dev:.2f}\n\n")
    parts.append("### WER for Each Sentence\n\n")
    for sentence_no, sentence_wer in enumerate(sentence_wers, start=1):
        parts.append(f"* Sentence {sentence_no}: {sentence_wer:.2f}\n")
    return "".join(parts)
def main():
    """Build and launch the Gradio UI for the ASR metrics calculator.

    Layout: two upload slots, read-only previews of each file, a compute
    button, a JSON panel for the headline metrics and a Markdown panel
    for the sentence-level report. Component creation order inside the
    Blocks context defines the on-screen layout, so statement order here
    is significant.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# ASR Metrics Calculator")
        with gr.Row():
            reference_file = gr.File(label="Upload Reference File")
            hypothesis_file = gr.File(label="Upload Hypothesis File")
        with gr.Row():
            reference_preview = gr.Textbox(label="Reference Preview", lines=3)
            hypothesis_preview = gr.Textbox(label="Hypothesis Preview", lines=3)
        with gr.Row():
            compute_button = gr.Button("Compute Metrics")
        results_output = gr.JSON(label="Results")
        wer_stats_output = gr.Markdown(label="WER Statistics")
        # Update previews when files are uploaded
        def update_previews(ref_file, hyp_file):
            # Show only a short prefix so large transcripts don't flood
            # the preview boxes. NOTE(review): files are read with the
            # platform default encoding here — presumably UTF-8 intended.
            ref_text = ""
            hyp_text = ""
            if ref_file:
                with open(ref_file.name, 'r') as f:
                    ref_text = f.read()[:200]  # Show first 200 characters
            if hyp_file:
                with open(hyp_file.name, 'r') as f:
                    hyp_text = f.read()[:200]  # Show first 200 characters
            return ref_text, hyp_text
        # Either upload slot changing refreshes BOTH previews, so the two
        # handlers share the same inputs/outputs lists.
        reference_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview]
        )
        hypothesis_file.change(
            fn=update_previews,
            inputs=[reference_file, hypothesis_file],
            outputs=[reference_preview, hypothesis_preview]
        )
        def process_and_display(ref_file, hyp_file):
            # Run the full metric pipeline, then split the result dict:
            # WER/CER go to the JSON panel, the sentence-level stats are
            # rendered to Markdown for the report panel.
            result = process_files(ref_file, hyp_file)
            metrics = {
                "WER": result["WER"],
                "CER": result["CER"]
            }
            # .get() with defaults: these keys are absent on the error path.
            error = result.get("Error")
            warning = result.get("Warning")
            sentence_wers = result.get("Sentence WERs", [])
            average_wer = result.get("Average WER", 0.0)
            std_dev = result.get("Standard Deviation", 0.0)
            wer_stats_md = format_sentence_wer_stats(
                sentence_wers,
                average_wer,
                std_dev,
                warning,
                error
            )
            return metrics, wer_stats_md
        compute_button.click(
            fn=process_and_display,
            inputs=[reference_file, hypothesis_file],
            outputs=[results_output, wer_stats_output]
        )
    demo.launch()
# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()