akki2825 committed on
Commit
89fac21
·
verified ·
1 Parent(s): 0f22edd
Files changed (1) hide show
  1. app.py +85 -30
app.py CHANGED
@@ -1,7 +1,11 @@
1
  import jiwer
2
- import spaces
3
  import numpy as np
4
  import gradio as gr
 
 
 
 
 
5
 
6
  @spaces.GPU()
7
  def calculate_wer(reference, hypothesis):
@@ -24,33 +28,55 @@ def calculate_sentence_wer(reference, hypothesis):
24
  """
25
  Calculate WER for each sentence and overall statistics.
26
  """
27
- reference_sentences = reference.split()
28
- hypothesis_sentences = hypothesis.split()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- if len(reference_sentences) != len(hypothesis_sentences):
31
- raise ValueError("Reference and hypothesis must contain the same number of sentences")
32
 
33
- sentence_wers = []
34
- for ref, hyp in zip(reference_sentences, hypothesis_sentences):
35
- sentence_wer = jiwer.wer(ref, hyp)
36
- sentence_wers.append(sentence_wer)
 
 
 
37
 
38
- if not sentence_wers:
 
 
 
 
 
 
39
  return {
40
  "sentence_wers": [],
41
  "average_wer": 0.0,
42
- "std_dev": 0.0
 
43
  }
44
 
45
- average_wer = np.mean(sentence_wers)
46
- std_dev = np.std(sentence_wers)
47
-
48
- return {
49
- "sentence_wers": sentence_wers,
50
- "average_wer": average_wer,
51
- "std_dev": std_dev
52
- }
53
-
54
  @spaces.GPU()
55
  def process_files(reference_file, hypothesis_file):
56
  try:
@@ -60,6 +86,11 @@ def process_files(reference_file, hypothesis_file):
60
  with open(hypothesis_file.name, 'r') as f:
61
  hypothesis_text = f.read()
62
 
 
 
 
 
 
63
  wer_value = calculate_wer(reference_text, hypothesis_text)
64
  cer_value = calculate_cer(reference_text, hypothesis_text)
65
  sentence_wer_stats = calculate_sentence_wer(reference_text, hypothesis_text)
@@ -69,21 +100,39 @@ def process_files(reference_file, hypothesis_file):
69
  "CER": cer_value,
70
  "Sentence WERs": sentence_wer_stats["sentence_wers"],
71
  "Average WER": sentence_wer_stats["average_wer"],
72
- "Standard Deviation": sentence_wer_stats["std_dev"]
 
 
73
  }
74
  except Exception as e:
75
- return {"error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- def format_sentence_wer_stats(sentence_wers, average_wer, std_dev):
78
  if not sentence_wers:
79
- return "All sentences match perfectly!"
 
80
 
81
- md = "### Sentence-level WER Analysis\n\n"
82
  md += f"* Average WER: {average_wer:.2f}\n"
83
  md += f"* Standard Deviation: {std_dev:.2f}\n\n"
84
  md += "### WER for Each Sentence\n\n"
85
  for i, wer in enumerate(sentence_wers):
86
  md += f"* Sentence {i+1}: {wer:.2f}\n"
 
87
  return md
88
 
89
  def main():
@@ -130,18 +179,24 @@ def main():
130
 
131
  def process_and_display(ref_file, hyp_file):
132
  result = process_files(ref_file, hyp_file)
133
- if "error" in result:
134
- return {}, {}, "Error: " + result["error"]
135
 
136
  metrics = {
137
  "WER": result["WER"],
138
  "CER": result["CER"]
139
  }
140
 
 
 
 
 
 
 
141
  wer_stats_md = format_sentence_wer_stats(
142
- result["Sentence WERs"],
143
- result["Average WER"],
144
- result["Standard Deviation"]
 
 
145
  )
146
 
147
  return metrics, wer_stats_md
 
import jiwer
# `spaces` supplies the `@spaces.GPU()` decorator applied to the functions
# below; without this import the module fails with NameError at load time.
import spaces
import numpy as np
import gradio as gr
import nltk
from nltk.tokenize import sent_tokenize

# Ensure the NLTK sentence-tokenizer data is available before sent_tokenize
# is called.  'punkt' is the resource the original code fetched; 'punkt_tab'
# is additionally requested because newer NLTK releases load the tokenizer
# from that resource instead (NOTE(review): confirm against the pinned NLTK
# version).
for _nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_resource)
9
 
10
  @spaces.GPU()
11
  def calculate_wer(reference, hypothesis):
 
def calculate_sentence_wer(reference, hypothesis):
    """
    Calculate WER for each sentence and overall statistics.

    Both texts are split into sentences with ``sent_tokenize`` and compared
    pairwise, in order, with ``jiwer.wer``.  When the two texts contain a
    different number of sentences, only the leading sentences present in
    both are compared and a warning describing the mismatch is returned.

    Args:
        reference: Reference (ground-truth) text.
        hypothesis: Hypothesis text to score against the reference.

    Returns:
        A dict with the keys ``"sentence_wers"`` (list of per-sentence WER
        values), ``"average_wer"`` and ``"std_dev"``.  On success it also
        carries ``"warning"`` (a message, or ``None``); if any exception is
        raised while tokenizing or scoring, the statistics are zeroed and
        the dict carries ``"error"`` with the exception text instead.
    """
    try:
        reference_sentences = sent_tokenize(reference)
        hypothesis_sentences = sent_tokenize(hypothesis)

        # Capture the counts BEFORE pairing: comparing the lengths after
        # trimming both lists to the same size can never detect a mismatch.
        reference_count = len(reference_sentences)
        hypothesis_count = len(hypothesis_sentences)
        min_sentences = min(reference_count, hypothesis_count)

        # zip() stops at the shorter list, so only the first
        # `min_sentences` pairs are scored.
        sentence_wers = [
            jiwer.wer(ref, hyp)
            for ref, hyp in zip(reference_sentences, hypothesis_sentences)
        ]

        if not sentence_wers:
            return {
                "sentence_wers": [],
                "average_wer": 0.0,
                "std_dev": 0.0,
                "warning": "No sentences to compare"
            }

        average_wer = np.mean(sentence_wers)
        std_dev = np.std(sentence_wers)

        # Report sentences that were left out of the comparison.
        if reference_count != hypothesis_count:
            warning = f"Reference has {reference_count} sentences, " \
                      f"hypothesis has {hypothesis_count} sentences. " \
                      f"Only compared the first {min_sentences} sentences."
        else:
            warning = None

        return {
            "sentence_wers": sentence_wers,
            "average_wer": average_wer,
            "std_dev": std_dev,
            "warning": warning
        }
    except Exception as e:
        # Callers expect a result dict rather than an exception; the error
        # text is passed along so it can be shown to the user.
        return {
            "sentence_wers": [],
            "average_wer": 0.0,
            "std_dev": 0.0,
            "error": str(e)
        }
79
 
 
 
 
 
 
 
 
 
 
80
  @spaces.GPU()
81
  def process_files(reference_file, hypothesis_file):
82
  try:
 
86
  with open(hypothesis_file.name, 'r') as f:
87
  hypothesis_text = f.read()
88
 
89
+ if not reference_text or not hypothesis_text:
90
+ return {
91
+ "error": "Both reference and hypothesis files must contain text"
92
+ }
93
+
94
  wer_value = calculate_wer(reference_text, hypothesis_text)
95
  cer_value = calculate_cer(reference_text, hypothesis_text)
96
  sentence_wer_stats = calculate_sentence_wer(reference_text, hypothesis_text)
 
100
  "CER": cer_value,
101
  "Sentence WERs": sentence_wer_stats["sentence_wers"],
102
  "Average WER": sentence_wer_stats["average_wer"],
103
+ "Standard Deviation": sentence_wer_stats["std_dev"],
104
+ "Warning": sentence_wer_stats.get("warning"),
105
+ "Error": sentence_wer_stats.get("error")
106
  }
107
  except Exception as e:
108
+ return {
109
+ "WER": 0.0,
110
+ "CER": 0.0,
111
+ "Sentence WERs": [],
112
+ "Average WER": 0.0,
113
+ "Standard Deviation": 0.0,
114
+ "Error": str(e)
115
+ }
116
+
def format_sentence_wer_stats(sentence_wers, average_wer, std_dev, warning, error):
    """Render sentence-level WER statistics as a Markdown string.

    Args:
        sentence_wers: Sequence of per-sentence WER values.
        average_wer: Mean of the per-sentence WER values.
        std_dev: Standard deviation of the per-sentence WER values.
        warning: Optional warning message; shown only when ``error`` is
            falsy.
        error: Optional error message; takes precedence over ``warning``.

    Returns:
        Markdown text: an optional "Error" or "Warning" section followed by
        either the statistics with one bullet per sentence, or the notice
        "No sentences to compare" when ``sentence_wers`` is empty.
    """
    # Collect fragments and join once instead of repeated string `+=`.
    parts = []

    if error:
        parts.append(f"### Error\n{error}\n\n")
    elif warning:
        parts.append(f"### Warning\n{warning}\n\n")

    if not sentence_wers:
        parts.append("No sentences to compare")
        return "".join(parts)

    parts.append("### Sentence-level WER Analysis\n\n")
    parts.append(f"* Average WER: {average_wer:.2f}\n")
    parts.append(f"* Standard Deviation: {std_dev:.2f}\n\n")
    parts.append("### WER for Each Sentence\n\n")
    parts.extend(
        f"* Sentence {number}: {wer:.2f}\n"
        for number, wer in enumerate(sentence_wers, start=1)
    )

    return "".join(parts)
137
 
138
  def main():
 
179
 
def process_and_display(ref_file, hyp_file):
    """Run the metric computation and prepare the two values shown in the UI.

    Args:
        ref_file: Uploaded reference file, passed through to
            ``process_files``.
        hyp_file: Uploaded hypothesis file, passed through to
            ``process_files``.

    Returns:
        A ``(metrics, wer_stats_md)`` tuple: ``metrics`` is a dict with the
        overall ``"WER"`` and ``"CER"``; ``wer_stats_md`` is the Markdown
        produced by ``format_sentence_wer_stats``.
    """
    result = process_files(ref_file, hyp_file)

    # process_files returns only {"error": ...} when either file is empty,
    # so the metric keys may be absent; indexing them directly would raise
    # KeyError instead of showing the message.
    metrics = {
        "WER": result.get("WER", 0.0),
        "CER": result.get("CER", 0.0)
    }

    # Accept both spellings of the error key used by process_files.
    error = result.get("Error") or result.get("error")
    warning = result.get("Warning")
    sentence_wers = result.get("Sentence WERs", [])
    average_wer = result.get("Average WER", 0.0)
    std_dev = result.get("Standard Deviation", 0.0)

    wer_stats_md = format_sentence_wer_stats(
        sentence_wers,
        average_wer,
        std_dev,
        warning,
        error
    )

    return metrics, wer_stats_md