akki2825 committed on
Commit
45347ae
·
verified ·
1 Parent(s): 85366cd

fix readlines issue

Browse files
Files changed (1) hide show
  1. app.py +82 -179
app.py CHANGED
@@ -1,112 +1,71 @@
1
- import spaces
2
  import jiwer
3
  import numpy as np
4
- import re
5
  import gradio as gr
6
 
7
- def split_into_sentences(text):
8
- """
9
- Simple sentence tokenizer using regular expressions.
10
- Splits text into sentences based on punctuation.
11
- """
12
- sentences = re.split(r'(?<=[.!?])\s+', text)
13
- sentences = [s.strip() for s in sentences if s.strip()]
14
- return sentences
15
-
16
- @spaces.GPU()
17
  def calculate_wer(reference, hypothesis):
18
- """
19
- Calculate the Word Error Rate (WER) using jiwer.
20
- """
21
- wer = jiwer.wer(reference, hypothesis)
22
- return wer
23
 
24
- @spaces.GPU()
25
  def calculate_cer(reference, hypothesis):
26
- """
27
- Calculate the Character Error Rate (CER) using jiwer.
28
- """
29
- cer = jiwer.cer(reference, hypothesis)
30
- return cer
31
 
32
- @spaces.GPU()
33
  def calculate_sentence_metrics(reference, hypothesis):
34
- """
35
- Calculate WER and CER for each sentence and overall statistics.
36
- Handles cases where the number of sentences differ.
37
- """
38
- try:
39
- reference_sentences = reference
40
- hypothesis_sentences = hypothesis
41
-
42
- sentence_wers = []
43
- sentence_cers = []
44
- min_length = min(len(reference_sentences), len(hypothesis_sentences))
45
-
46
- for i in range(min_length):
47
- ref = reference_sentences[i]
48
- hyp = hypothesis_sentences[i]
49
-
50
- wer = jiwer.wer(ref, hyp)
51
- cer = jiwer.cer(ref, hyp)
52
- sentence_wers.append(wer)
53
- sentence_cers.append(cer)
54
-
55
- # Calculate overall statistics
56
- if sentence_wers:
57
- average_wer = np.mean(sentence_wers)
58
- std_dev_wer = np.std(sentence_wers)
59
- else:
60
- average_wer = 0.0
61
- std_dev_wer = 0.0
62
-
63
- if sentence_cers:
64
- average_cer = np.mean(sentence_cers)
65
- std_dev_cer = np.std(sentence_cers)
66
- else:
67
- average_cer = 0.0
68
- std_dev_cer = 0.0
69
-
70
- return {
71
- "sentence_wers": sentence_wers,
72
- "sentence_cers": sentence_cers,
73
- "average_wer": average_wer,
74
- "average_cer": average_cer,
75
- "std_dev_wer": std_dev_wer,
76
- "std_dev_cer": std_dev_cer
77
- }
78
- except Exception as e:
79
- raise e
80
 
81
- def identify_misaligned_sentences(reference_text, hypothesis_text):
82
- """
83
- Identify sentences that don't match between reference and hypothesis.
84
- Returns a dictionary with misaligned sentence pairs, their indices, and misalignment details.
85
- """
86
- reference_sentences = reference_text
87
- hypothesis_sentences = hypothesis_text
88
 
89
  misaligned = []
 
90
  for i, (ref, hyp) in enumerate(zip(reference_sentences, hypothesis_sentences)):
91
  if ref != hyp:
92
- # Split sentences into words
93
  ref_words = ref.split()
94
  hyp_words = hyp.split()
95
-
96
- # Find the first position where the words diverge
97
  min_length = min(len(ref_words), len(hyp_words))
 
98
  misalignment_start = 0
99
  for j in range(min_length):
100
  if ref_words[j] != hyp_words[j]:
101
  misalignment_start = j
102
  break
103
 
104
- # Prepare the context for display
105
- context_ref = ' '.join(ref_words[:misalignment_start] + ['**' + ref_words[misalignment_start] + '**'])
106
- context_hyp = ' '.join(hyp_words[:misalignment_start] + ['**' + hyp_words[misalignment_start] + '**'])
107
 
108
  misaligned.append({
109
- "index": i+1,
110
  "reference": ref,
111
  "hypothesis": hyp,
112
  "misalignment_start": misalignment_start,
@@ -114,11 +73,11 @@ def identify_misaligned_sentences(reference_text, hypothesis_text):
114
  "context_hyp": context_hyp
115
  })
116
 
117
- # Handle cases where the number of sentences differs
118
  if len(reference_sentences) > len(hypothesis_sentences):
119
  for i in range(len(hypothesis_sentences), len(reference_sentences)):
120
  misaligned.append({
121
- "index": i+1,
122
  "reference": reference_sentences[i],
123
  "hypothesis": "No corresponding sentence",
124
  "misalignment_start": 0,
@@ -128,7 +87,7 @@ def identify_misaligned_sentences(reference_text, hypothesis_text):
128
  elif len(hypothesis_sentences) > len(reference_sentences):
129
  for i in range(len(reference_sentences), len(hypothesis_sentences)):
130
  misaligned.append({
131
- "index": i+1,
132
  "reference": "No corresponding sentence",
133
  "hypothesis": hypothesis_sentences[i],
134
  "misalignment_start": 0,
@@ -138,45 +97,27 @@ def identify_misaligned_sentences(reference_text, hypothesis_text):
138
 
139
  return misaligned
140
 
141
-
142
- def format_sentence_metrics(sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer, misaligned_sentences):
143
  md = "### Sentence-level Metrics\n\n"
144
- md += "#### Word Error Rate (WER)\n"
145
- md += f"* Average WER: {average_wer:.2f}\n"
146
- md += f"* Standard Deviation: {std_dev_wer:.2f}\n\n"
147
- md += "#### Character Error Rate (CER)\n"
148
- md += f"* Average CER: {average_cer:.2f}\n"
149
- md += f"* Standard Deviation: {std_dev_cer:.2f}\n\n"
150
-
151
- md += "### WER for Each Sentence\n\n"
152
  for i, wer in enumerate(sentence_wers):
153
- md += f"* Sentence {i+1}: {wer:.2f}\n"
154
 
155
- md += "\n### CER for Each Sentence\n\n"
156
  for i, cer in enumerate(sentence_cers):
157
- md += f"* Sentence {i+1}: {cer:.2f}\n"
158
-
159
- # if misaligned_sentences:
160
- # md += "\n### Misaligned Sentences\n\n"
161
- # for misaligned in misaligned_sentences:
162
- # md += f"#### Sentence {misaligned['index']}\n"
163
- # md += f"* Reference: {misaligned['context_ref']}\n"
164
- # md += f"* Hypothesis: {misaligned['context_hyp']}\n"
165
- # md += f"* Misalignment starts at word: {misaligned['misalignment_start'] + 1}\n\n"
166
- # else:
167
- # md += "\n### Misaligned Sentences\n\n"
168
- # md += "* No misaligned sentences found."
169
 
170
  return md
171
 
172
- @spaces.GPU()
173
  def process_files(reference_file, hypothesis_file):
174
  try:
175
- with open(reference_file.name, 'r') as f:
176
- reference_text = f.readlines()
177
-
178
- with open(hypothesis_file.name, 'r') as f:
179
- hypothesis_text = f.readlines()
180
 
181
  overall_wer = calculate_wer(reference_text, hypothesis_text)
182
  overall_cer = calculate_cer(reference_text, hypothesis_text)
@@ -186,24 +127,17 @@ def process_files(reference_file, hypothesis_file):
186
  return {
187
  "Overall WER": overall_wer,
188
  "Overall CER": overall_cer,
189
- "Sentence WERs": sentence_metrics["sentence_wers"],
190
- "Sentence CERs": sentence_metrics["sentence_cers"],
191
- "Average WER": sentence_metrics["average_wer"],
192
- "Average CER": sentence_metrics["average_cer"],
193
- "Standard Deviation WER": sentence_metrics["std_dev_wer"],
194
- "Standard Deviation CER": sentence_metrics["std_dev_cer"],
195
  "Misaligned Sentences": misaligned
196
  }
197
  except Exception as e:
198
  return {"error": str(e)}
199
 
200
-
201
  def process_and_display(ref_file, hyp_file):
202
  result = process_files(ref_file, hyp_file)
203
 
204
  if "error" in result:
205
- error_msg = result["error"]
206
- return {"error": error_msg}, "", ""
207
 
208
  metrics = {
209
  "Overall WER": result["Overall WER"],
@@ -211,46 +145,21 @@ def process_and_display(ref_file, hyp_file):
211
  }
212
 
213
  metrics_md = format_sentence_metrics(
214
- result["Sentence WERs"],
215
- result["Sentence CERs"],
216
- result["Average WER"],
217
- result["Average CER"],
218
- result["Standard Deviation WER"],
219
- result["Standard Deviation CER"],
220
- result["Misaligned Sentences"]
221
- )
222
-
223
- return metrics, metrics_md
224
-
225
- def process_and_display(ref_file, hyp_file):
226
- result = process_files(ref_file, hyp_file)
227
-
228
- if "error" in result:
229
- error_msg = result["error"]
230
- return {"error": error_msg}, "", ""
231
-
232
- metrics = {
233
- "Overall WER": result["Overall WER"],
234
- "Overall CER": result["Overall CER"]
235
- }
236
-
237
- metrics_md = format_sentence_metrics(
238
- result["Sentence WERs"],
239
- result["Sentence CERs"],
240
- result["Average WER"],
241
- result["Average CER"],
242
- result["Standard Deviation WER"],
243
- result["Standard Deviation CER"],
244
- result["Misaligned Sentences"]
245
  )
246
 
247
  misaligned_md = "### Misaligned Sentences\n\n"
248
  if result["Misaligned Sentences"]:
249
- for misaligned in result["Misaligned Sentences"]:
250
- misaligned_md += f"#### Sentence {misaligned['index']}\n"
251
- misaligned_md += f"* Reference: {misaligned['context_ref']}\n"
252
- misaligned_md += f"* Hypothesis: {misaligned['context_hyp']}\n"
253
- misaligned_md += f"* Misalignment starts at position: {misaligned['misalignment_start']}\n\n"
254
  else:
255
  misaligned_md += "* No misaligned sentences found."
256
 
@@ -258,25 +167,19 @@ def process_and_display(ref_file, hyp_file):
258
 
259
  def main():
260
  with gr.Blocks() as demo:
261
- gr.Markdown("# ASR Metrics Analysis")
262
 
263
- with gr.Row() as row:
264
  with gr.Column():
265
- gr.Markdown("### Instructions")
266
- gr.Markdown("* Upload your reference and hypothesis files")
267
- gr.Markdown("* The interface will automatically analyze and display metrics")
268
- gr.Markdown("* Results include overall metrics, sentence-level metrics, and misaligned sentences")
269
-
270
- with gr.Row():
271
- reference_file = gr.File(label="Upload Reference File")
272
- hypothesis_file = gr.File(label="Upload Model Output File")
273
-
274
- with gr.Row():
275
- compute_button = gr.Button("Compute Metrics", variant="primary")
276
- results_output = gr.JSON(label="Results")
277
- metrics_output = gr.Markdown(label="Metrics")
278
- misaligned_output = gr.Markdown(label="Misaligned Sentences")
279
 
 
 
 
 
280
 
281
  compute_button.click(
282
  fn=process_and_display,
@@ -287,4 +190,4 @@ def main():
287
  demo.launch()
288
 
289
  if __name__ == "__main__":
290
- main()
 
 
1
  import jiwer
2
  import numpy as np
 
3
  import gradio as gr
4
 
 
 
 
 
 
 
 
 
 
 
5
def calculate_wer(reference, hypothesis):
    """Compute the overall Word Error Rate between two transcripts.

    Both arguments are sequences of text lines; each sequence is flattened
    into a single space-joined string before scoring with jiwer, so the
    result is one corpus-level WER rather than a per-line average.
    """
    return jiwer.wer(" ".join(reference), " ".join(hypothesis))
 
 
9
 
 
10
def calculate_cer(reference, hypothesis):
    """Compute the overall Character Error Rate between two transcripts.

    Mirrors calculate_wer: the line sequences are flattened into single
    space-joined strings and scored once with jiwer at character level.
    """
    return jiwer.cer(" ".join(reference), " ".join(hypothesis))
 
 
14
 
 
15
def calculate_sentence_metrics(reference, hypothesis):
    """Score each aligned sentence pair and summarize the distribution.

    Lines are stripped and paired positionally; extra lines on the longer
    side are ignored here (they are reported by
    identify_misaligned_sentences instead). Returns a dict with the
    per-sentence WER/CER lists plus their means and standard deviations,
    all 0.0 when there are no pairs.
    """
    refs = [line.strip() for line in reference]
    hyps = [line.strip() for line in hypothesis]

    # zip() stops at the shorter list, which matches the original
    # min-length pairing.
    pairs = list(zip(refs, hyps))
    sentence_wers = [jiwer.wer(r, h) for r, h in pairs]
    sentence_cers = [jiwer.cer(r, h) for r, h in pairs]

    def _summary(scores):
        # np.mean/np.std on an empty sequence warn and return nan, so the
        # empty case is handled explicitly.
        if not scores:
            return 0.0, 0.0
        return np.mean(scores), np.std(scores)

    average_wer, std_dev_wer = _summary(sentence_wers)
    average_cer, std_dev_cer = _summary(sentence_cers)

    return {
        "sentence_wers": sentence_wers,
        "sentence_cers": sentence_cers,
        "average_wer": average_wer,
        "average_cer": average_cer,
        "std_dev_wer": std_dev_wer,
        "std_dev_cer": std_dev_cer,
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
def identify_misaligned_sentences(reference, hypothesis):
    """Report every sentence pair that differs between the two transcripts.

    Lines are stripped and paired positionally. For each differing pair the
    returned dict carries the 1-based sentence index, both sentences, the
    0-based index of the first diverging word, and display strings
    (``context_ref``/``context_hyp``) with that word bolded in Markdown.
    Unpaired trailing sentences on either side are reported with
    "No corresponding sentence" as their counterpart.
    """
    reference_sentences = [line.strip() for line in reference]
    hypothesis_sentences = [line.strip() for line in hypothesis]

    def _highlight(words, start):
        # Bold the first diverging word. When the divergence lies past the
        # end of this sentence (one sentence is a prefix of the other),
        # there is no word to bold, so return the sentence unchanged.
        if words and start < len(words):
            return ' '.join(words[:start] + ['**' + words[start] + '**'])
        return ' '.join(words)

    misaligned = []
    for i, (ref, hyp) in enumerate(zip(reference_sentences, hypothesis_sentences)):
        if ref == hyp:
            continue

        ref_words = ref.split()
        hyp_words = hyp.split()
        min_length = min(len(ref_words), len(hyp_words))

        # Bug fix: the divergence index previously defaulted to 0, so when
        # no word differed within the shared span (one sentence a strict
        # prefix of the other) the first — matching — word was wrongly
        # flagged. Defaulting to min_length points at the extra tail.
        misalignment_start = min_length
        for j in range(min_length):
            if ref_words[j] != hyp_words[j]:
                misalignment_start = j
                break

        misaligned.append({
            "index": i + 1,
            "reference": ref,
            "hypothesis": hyp,
            "misalignment_start": misalignment_start,
            "context_ref": _highlight(ref_words, misalignment_start),
            "context_hyp": _highlight(hyp_words, misalignment_start),
        })

    # Sentences with no counterpart on the other side.
    if len(reference_sentences) > len(hypothesis_sentences):
        for i in range(len(hypothesis_sentences), len(reference_sentences)):
            misaligned.append({
                "index": i + 1,
                "reference": reference_sentences[i],
                "hypothesis": "No corresponding sentence",
                "misalignment_start": 0,
                "context_ref": reference_sentences[i],
                "context_hyp": "No corresponding sentence",
            })
    elif len(hypothesis_sentences) > len(reference_sentences):
        for i in range(len(reference_sentences), len(hypothesis_sentences)):
            misaligned.append({
                "index": i + 1,
                "reference": "No corresponding sentence",
                "hypothesis": hypothesis_sentences[i],
                "misalignment_start": 0,
                "context_ref": "No corresponding sentence",
                "context_hyp": hypothesis_sentences[i],
            })

    return misaligned
99
 
100
def format_sentence_metrics(sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer):
    """Render the sentence-level WER/CER statistics as a Markdown report.

    The report shows the averages and standard deviations first, then one
    bullet per sentence for WER and for CER. Scores are formatted to two
    decimal places; sentence numbering is 1-based.
    """
    parts = [
        "### Sentence-level Metrics\n\n",
        f"**Average WER**: {average_wer:.2f}\n\n",
        f"**Standard Deviation WER**: {std_dev_wer:.2f}\n\n",
        f"**Average CER**: {average_cer:.2f}\n\n",
        f"**Standard Deviation CER**: {std_dev_cer:.2f}\n\n",
        "---\n**WER by Sentence**\n",
    ]
    parts.extend(
        f"- Sentence {num}: {score:.2f}\n"
        for num, score in enumerate(sentence_wers, start=1)
    )
    parts.append("\n**CER by Sentence**\n")
    parts.extend(
        f"- Sentence {num}: {score:.2f}\n"
        for num, score in enumerate(sentence_cers, start=1)
    )
    return "".join(parts)
116
 
 
117
def process_files(reference_file, hypothesis_file):
    """Read the two uploaded transcripts and compute all ASR metrics.

    Returns a flat result dict containing the overall WER/CER, every field
    from calculate_sentence_metrics, and the misaligned-sentence report;
    on any failure, returns ``{"error": <message>}`` instead.
    """
    try:
        def _read_lines(uploaded):
            # assumes .read() yields raw bytes from the upload widget —
            # decode to text before splitting into lines (TODO confirm
            # against the gradio File component in use)
            return uploaded.read().decode("utf-8").splitlines()

        reference_text = _read_lines(reference_file)
        hypothesis_text = _read_lines(hypothesis_file)

        sentence_metrics = calculate_sentence_metrics(reference_text, hypothesis_text)
        return {
            "Overall WER": calculate_wer(reference_text, hypothesis_text),
            "Overall CER": calculate_cer(reference_text, hypothesis_text),
            **sentence_metrics,
            "Misaligned Sentences": identify_misaligned_sentences(reference_text, hypothesis_text),
        }
    except Exception as e:
        # Surface any failure to the UI rather than crashing the callback.
        return {"error": str(e)}
135
 
 
136
def process_and_display(ref_file, hyp_file):
    """Gradio callback: run the metric pipeline and shape its three outputs.

    Returns a tuple of (summary dict for the JSON panel, sentence-metrics
    Markdown, misaligned-sentences Markdown). When processing fails, the
    error dict is passed through and both Markdown panes are left empty.
    """
    result = process_files(ref_file, hyp_file)

    # Guard clause: propagate any pipeline error straight to the UI.
    if "error" in result:
        return {"error": result["error"]}, "", ""

    summary = {key: result[key] for key in ("Overall WER", "Overall CER")}

    metrics_md = format_sentence_metrics(
        result["sentence_wers"],
        result["sentence_cers"],
        result["average_wer"],
        result["average_cer"],
        result["std_dev_wer"],
        result["std_dev_cer"],
    )

    misaligned = result["Misaligned Sentences"]
    misaligned_md = "### Misaligned Sentences\n\n"
    if not misaligned:
        misaligned_md += "* No misaligned sentences found."
    else:
        for mis in misaligned:
            misaligned_md += (
                f"**Sentence {mis['index']}**\n"
                f"- Reference: {mis['context_ref']}\n"
                f"- Hypothesis: {mis['context_hyp']}\n"
                f"- Misalignment starts at position: {mis['misalignment_start']}\n\n"
            )

    return summary, metrics_md, misaligned_md
 
168
  def main():
169
  with gr.Blocks() as demo:
170
+ gr.Markdown("# 📊 ASR Metrics Analysis Tool")
171
 
172
+ with gr.Row():
173
  with gr.Column():
174
+ gr.Markdown("### Upload your reference and hypothesis files")
175
+ reference_file = gr.File(label="Reference File (.txt)")
176
+ hypothesis_file = gr.File(label="Hypothesis File (.txt)")
177
+ compute_button = gr.Button("Compute Metrics", variant="primary")
 
 
 
 
 
 
 
 
 
 
178
 
179
+ with gr.Column():
180
+ results_output = gr.JSON(label="Results Summary")
181
+ metrics_output = gr.Markdown(label="Sentence Metrics")
182
+ misaligned_output = gr.Markdown(label="Misaligned Sentences")
183
 
184
  compute_button.click(
185
  fn=process_and_display,
 
190
  demo.launch()
191
 
192
  if __name__ == "__main__":
193
+ main()