Spaces:
Sleeping
Sleeping
File size: 7,337 Bytes
43380b2 af4ff35 aff6746 af4ff35 89fac21 43380b2 af4ff35 45347ae aff6746 43380b2 aff6746 45347ae aff6746 43380b2 905ed31 45347ae 7115c2e 45347ae 01e4e5c 45347ae dd10290 01e4e5c dd10290 45347ae d41eee2 dd10290 d41eee2 dd10290 45347ae d41eee2 01e4e5c 45347ae 01e4e5c d41eee2 01e4e5c 8c82d13 45347ae 7115c2e dd10290 7115c2e 45347ae 7115c2e dd10290 7115c2e 45347ae 7115c2e 01e4e5c 45347ae 01e4e5c 45347ae 01e4e5c 45347ae 01e4e5c 45347ae 01e4e5c 45347ae 01e4e5c ea825ae 613c9f7 ea825ae 45347ae ea825ae dc5a821 aff6746 dc5a821 45347ae aff6746 dc5a821 45347ae 7115c2e 45347ae 7115c2e aff6746 45347ae 0115245 45347ae 0115245 45347ae 0115245 45347ae aff6746 af4ff35 aff6746 0115245 aff6746 b4734b3 aff6746 45347ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import spaces
import jiwer
import numpy as np
import gradio as gr
@spaces.GPU()
def calculate_wer(reference, hypothesis):
    """Return the document-level word error rate.

    Both arguments are lists of lines; each list is flattened into one
    space-separated string so WER is scored over the whole text at once.
    """
    return jiwer.wer(" ".join(reference), " ".join(hypothesis))
@spaces.GPU()
def calculate_cer(reference, hypothesis):
    """Return the document-level character error rate.

    Both arguments are lists of lines; each list is flattened into one
    space-separated string so CER is scored over the whole text at once.
    """
    return jiwer.cer(" ".join(reference), " ".join(hypothesis))
@spaces.GPU()
def calculate_sentence_metrics(reference, hypothesis):
    """Score each aligned line pair individually and aggregate the results.

    Lines are paired positionally; any surplus lines in the longer input
    are ignored (zip truncates, matching the original min-length loop).

    Returns a dict with per-sentence WER/CER lists plus their mean and
    standard deviation (0.0 for all aggregates when there are no pairs).
    """
    pairs = list(zip((line.strip() for line in reference),
                     (line.strip() for line in hypothesis)))
    sentence_wers = [jiwer.wer(ref, hyp) for ref, hyp in pairs]
    sentence_cers = [jiwer.cer(ref, hyp) for ref, hyp in pairs]

    def _mean_std(values):
        # Guard against an empty list: np.mean/np.std would warn and
        # return nan, so fall back to 0.0 like the original code.
        if not values:
            return 0.0, 0.0
        return np.mean(values), np.std(values)

    average_wer, std_dev_wer = _mean_std(sentence_wers)
    average_cer, std_dev_cer = _mean_std(sentence_cers)
    return {
        "sentence_wers": sentence_wers,
        "sentence_cers": sentence_cers,
        "average_wer": average_wer,
        "average_cer": average_cer,
        "std_dev_wer": std_dev_wer,
        "std_dev_cer": std_dev_cer
    }
def _highlight_mismatch(words, start):
    """Join `words` up to and including index `start`, bolding that word.

    Returns "" for an empty word list, and the full plain sentence when
    `start` is past the end (i.e. the mismatch lies beyond these words).
    """
    if not words:
        return ""
    if start >= len(words):
        return ' '.join(words)
    return ' '.join(words[:start] + ['**' + words[start] + '**'])


def identify_misaligned_sentences(reference, hypothesis):
    """Find line pairs that differ and report where the first mismatch begins.

    Args:
        reference: iterable of reference transcript lines.
        hypothesis: iterable of hypothesis transcript lines.

    Returns:
        A list of dicts (one per differing pair, plus one per unmatched
        trailing line) with keys: index (1-based), reference, hypothesis,
        misalignment_start (word index of the first differing word), and
        context_ref/context_hyp (sentence prefix with the mismatch bolded).
    """
    reference_sentences = [line.strip() for line in reference]
    hypothesis_sentences = [line.strip() for line in hypothesis]
    misaligned = []
    for i, (ref, hyp) in enumerate(zip(reference_sentences, hypothesis_sentences)):
        if ref == hyp:
            continue
        ref_words = ref.split()
        hyp_words = hyp.split()
        min_length = min(len(ref_words), len(hyp_words))
        # BUGFIX: default to min_length, not 0. If the shorter sentence is
        # a prefix of the longer one ("a b" vs "a b c"), no word in the
        # common span differs, so the mismatch starts where the shorter
        # sentence ends; the old default of 0 wrongly bolded the first
        # (matching) word.
        misalignment_start = min_length
        for j in range(min_length):
            if ref_words[j] != hyp_words[j]:
                misalignment_start = j
                break
        misaligned.append({
            "index": i + 1,
            "reference": ref,
            "hypothesis": hyp,
            "misalignment_start": misalignment_start,
            "context_ref": _highlight_mismatch(ref_words, misalignment_start),
            "context_hyp": _highlight_mismatch(hyp_words, misalignment_start)
        })
    # Lines present in only one file have no counterpart to compare with;
    # at most one of these two loops actually runs.
    common = min(len(reference_sentences), len(hypothesis_sentences))
    for i in range(common, len(reference_sentences)):
        misaligned.append({
            "index": i + 1,
            "reference": reference_sentences[i],
            "hypothesis": "No corresponding sentence",
            "misalignment_start": 0,
            "context_ref": reference_sentences[i],
            "context_hyp": "No corresponding sentence"
        })
    for i in range(common, len(hypothesis_sentences)):
        misaligned.append({
            "index": i + 1,
            "reference": "No corresponding sentence",
            "hypothesis": hypothesis_sentences[i],
            "misalignment_start": 0,
            "context_ref": "No corresponding sentence",
            "context_hyp": hypothesis_sentences[i]
        })
    return misaligned
def format_sentence_metrics(sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer):
    """Render sentence-level WER/CER statistics as a Markdown report.

    Produces an aggregate summary (means and standard deviations, two
    decimal places) followed by per-sentence WER and CER bullet lists.
    """
    parts = [
        "### Sentence-level Metrics\n\n",
        f"**Average WER**: {average_wer:.2f}\n\n",
        f"**Standard Deviation WER**: {std_dev_wer:.2f}\n\n",
        f"**Average CER**: {average_cer:.2f}\n\n",
        f"**Standard Deviation CER**: {std_dev_cer:.2f}\n\n",
        "---\n**WER by Sentence**\n",
    ]
    parts.extend(f"- Sentence {n}: {wer:.2f}\n"
                 for n, wer in enumerate(sentence_wers, start=1))
    parts.append("\n**CER by Sentence**\n")
    parts.extend(f"- Sentence {n}: {cer:.2f}\n"
                 for n, cer in enumerate(sentence_cers, start=1))
    # join once instead of repeated string concatenation
    return "".join(parts)
def process_files(reference_file, hypothesis_file):
    """Load both uploaded transcripts and compute every metric.

    Args:
        reference_file: uploaded file object (its `.name` is a path).
        hypothesis_file: uploaded file object (its `.name` is a path).

    Returns:
        A dict with overall WER/CER, the sentence-level metric fields,
        and the misaligned-sentence report — or {"error": msg} if
        anything fails (missing file, bad encoding, etc.).
    """
    def _read_lines(upload):
        # Files are split into lines; each line is treated as one sentence.
        with open(upload.name, 'r', encoding='utf-8') as f:
            return f.read().splitlines()

    try:
        reference_text = _read_lines(reference_file)
        hypothesis_text = _read_lines(hypothesis_file)
        results = {
            "Overall WER": calculate_wer(reference_text, hypothesis_text),
            "Overall CER": calculate_cer(reference_text, hypothesis_text),
        }
        results.update(calculate_sentence_metrics(reference_text, hypothesis_text))
        results["Misaligned Sentences"] = identify_misaligned_sentences(
            reference_text, hypothesis_text)
        return results
    except Exception as e:
        # Surface the failure to the UI instead of crashing the app.
        return {"error": str(e)}
def process_and_display(ref_file, hyp_file):
    """Adapt process_files output to the three Gradio output widgets.

    Returns a (summary dict, metrics markdown, misalignment markdown)
    tuple; on error the summary carries the message and both markdown
    panes are empty strings.
    """
    result = process_files(ref_file, hyp_file)
    if "error" in result:
        return {"error": result["error"]}, "", ""
    summary = {key: result[key] for key in ("Overall WER", "Overall CER")}
    metrics_md = format_sentence_metrics(
        result["sentence_wers"],
        result["sentence_cers"],
        result["average_wer"],
        result["average_cer"],
        result["std_dev_wer"],
        result["std_dev_cer"],
    )
    sections = ["### Misaligned Sentences\n\n"]
    mismatches = result["Misaligned Sentences"]
    if mismatches:
        for mis in mismatches:
            sections.append(f"**Sentence {mis['index']}**\n")
            sections.append(f"- Reference: {mis['context_ref']}\n")
            sections.append(f"- Hypothesis: {mis['context_hyp']}\n")
            sections.append(f"- Misalignment starts at position: {mis['misalignment_start']}\n\n")
    else:
        sections.append("* No misaligned sentences found.")
    return summary, metrics_md, "".join(sections)
def main():
    """Build and launch the Gradio UI: two file inputs on the left,
    results summary plus two Markdown report panes on the right."""
    with gr.Blocks() as demo:
        gr.Markdown("# π ASR Metrics Analysis Tool")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Upload your reference and hypothesis files")
                ref_input = gr.File(label="Reference File (.txt)")
                hyp_input = gr.File(label="Hypothesis File (.txt)")
                run_button = gr.Button("Compute Metrics", variant="primary")
            with gr.Column():
                summary_pane = gr.JSON(label="Results Summary")
                metrics_pane = gr.Markdown(label="Sentence Metrics")
                misaligned_pane = gr.Markdown(label="Misaligned Sentences")
        run_button.click(
            fn=process_and_display,
            inputs=[ref_input, hyp_input],
            outputs=[summary_pane, metrics_pane, misaligned_pane],
        )
    # ssr_mode=False: serve the classic client-rendered frontend.
    demo.launch(ssr_mode=False)
if __name__ == "__main__":
    main()
|