akki2825 committed on
Commit
af4ff35
·
verified ·
1 Parent(s): a1fe148

add wer fn

Browse files
Files changed (1) hide show
  1. app.py +77 -67
app.py CHANGED
@@ -1,78 +1,55 @@
 
1
  import spaces
2
- import gradio as gr
3
  import numpy as np
 
4
 
5
  @spaces.GPU()
6
def get_mismatched_sentences(reference, hypothesis):
    """Return (reference_word, hypothesis_word) pairs that differ.

    Both inputs are whitespace-tokenized and the words are compared
    positionally. Extra trailing words on either side (length mismatch)
    are reported as mismatches against an empty string instead of being
    silently dropped, which is what the previous ``zip`` did.
    """
    from itertools import zip_longest  # local: keeps the fix self-contained

    ref_words = reference.split()
    hyp_words = hypothesis.split()

    return [
        (ref, hyp)
        for ref, hyp in zip_longest(ref_words, hyp_words, fillvalue="")
        if ref != hyp
    ]
19
-
20
- @spaces.GPU()
21
def calculate_wer(reference, hypothesis):
    """Word Error Rate via Levenshtein distance over whitespace tokens.

    WER = (substitutions + deletions + insertions) / len(reference words).

    Returns 0.0 when both strings tokenize to nothing; when only the
    reference is empty, every hypothesis word counts as an insertion and
    the rate is len(hypothesis words). The original divided by zero in
    both of those cases.
    """
    reference_words = reference.split()
    hypothesis_words = hypothesis.split()

    m = len(reference_words)
    n = len(hypothesis_words)

    # Guard the final division: an empty reference previously produced a
    # numpy divide-by-zero (nan/inf) instead of a usable rate.
    if m == 0:
        return 0.0 if n == 0 else float(n)

    # dp[i][j] = edit distance between the first i reference words and
    # the first j hypothesis words.
    dp = np.zeros((m + 1, n + 1), dtype=np.int32)
    dp[:, 0] = np.arange(m + 1)  # delete all i reference words
    dp[0, :] = np.arange(n + 1)  # insert all j hypothesis words

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if reference_words[i - 1] == hypothesis_words[j - 1] else 1
            dp[i, j] = min(
                dp[i - 1, j] + 1,         # deletion
                dp[i, j - 1] + 1,         # insertion
                dp[i - 1, j - 1] + cost,  # substitution or match
            )

    wer = dp[m, n] / m
    return wer
47
 
48
  @spaces.GPU()
49
def calculate_cer(reference, hypothesis):
    """Character Error Rate via Levenshtein distance over characters.

    Spaces are stripped from both strings first, so only non-space
    characters are scored.

    Returns 0.0 when both strings are empty after stripping; when only
    the reference is empty, every hypothesis character counts as an
    insertion. The original divided by zero in both of those cases.
    """
    reference = reference.replace(" ", "")
    hypothesis = hypothesis.replace(" ", "")

    m = len(reference)
    n = len(hypothesis)

    # Guard the final division: an empty reference previously produced a
    # numpy divide-by-zero (nan/inf) instead of a usable rate.
    if m == 0:
        return 0.0 if n == 0 else float(n)

    # dp[i][j] = edit distance between the first i reference characters
    # and the first j hypothesis characters.
    dp = np.zeros((m + 1, n + 1), dtype=np.int32)
    dp[:, 0] = np.arange(m + 1)  # delete all i reference characters
    dp[0, :] = np.arange(n + 1)  # insert all j hypothesis characters

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if reference[i - 1] == hypothesis[j - 1] else 1
            dp[i, j] = min(
                dp[i - 1, j] + 1,         # deletion
                dp[i, j - 1] + 1,         # insertion
                dp[i - 1, j - 1] + cost,  # substitution or match
            )

    cer = dp[m, n] / m
    return cer
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  @spaces.GPU()
78
  def process_files(reference_file, hypothesis_file):
@@ -85,16 +62,30 @@ def process_files(reference_file, hypothesis_file):
85
 
86
  wer_value = calculate_wer(reference_text, hypothesis_text)
87
  cer_value = calculate_cer(reference_text, hypothesis_text)
88
- mismatched_sentences = get_mismatched_sentences(reference_text, hypothesis_text)
89
 
90
  return {
91
  "WER": wer_value,
92
  "CER": cer_value,
93
- "Mismatched Sentences": mismatched_sentences
 
 
94
  }
95
  except Exception as e:
96
  return {"error": str(e)}
97
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def main():
99
  with gr.Blocks() as demo:
100
  gr.Markdown("# ASR Metrics Calculator")
@@ -110,6 +101,7 @@ def main():
110
  with gr.Row():
111
  compute_button = gr.Button("Compute Metrics")
112
  results_output = gr.JSON(label="Results")
 
113
 
114
  # Update previews when files are uploaded
115
  def update_previews(ref_file, hyp_file):
@@ -136,10 +128,28 @@ def main():
136
  outputs=[reference_preview, hypothesis_preview]
137
  )
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  compute_button.click(
140
- fn=process_files,
141
  inputs=[reference_file, hypothesis_file],
142
- outputs=results_output
143
  )
144
 
145
  demo.launch()
 
import re

import gradio as gr
import jiwer
import numpy as np
import spaces
5
 
6
  @spaces.GPU()
7
def calculate_wer(reference, hypothesis):
    """Compute the word error rate of *hypothesis* against *reference*.

    Thin wrapper around ``jiwer.wer``.
    """
    return jiwer.wer(reference, hypothesis)
13
 
14
  @spaces.GPU()
15
def calculate_cer(reference, hypothesis):
    """Compute the character error rate of *hypothesis* against *reference*.

    Thin wrapper around ``jiwer.cer``.
    """
    return jiwer.cer(reference, hypothesis)
21
 
22
def _split_into_sentences(text):
    """Split text on sentence-ending punctuation (. ! ?); drop empties."""
    return [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]


@spaces.GPU()
def calculate_sentence_wer(reference, hypothesis):
    """Per-sentence WER plus aggregate statistics.

    Splits both texts into sentences, scores each aligned pair with
    ``jiwer.wer``, and returns a dict with keys ``sentence_wers``,
    ``average_wer`` and ``std_dev``.

    Raises:
        ValueError: if the two texts contain different numbers of
            sentences (pairwise alignment would be meaningless).
    """
    # Bug fix: jiwer has no split_into_sentences helper — the previous
    # call raised AttributeError at runtime. Split on sentence-ending
    # punctuation instead (approximate; abbreviations are not handled).
    reference_sentences = _split_into_sentences(reference)
    hypothesis_sentences = _split_into_sentences(hypothesis)

    if len(reference_sentences) != len(hypothesis_sentences):
        raise ValueError("Reference and hypothesis must contain the same number of sentences")

    sentence_wers = [
        jiwer.wer(ref, hyp)
        for ref, hyp in zip(reference_sentences, hypothesis_sentences)
    ]

    if not sentence_wers:
        return {
            "sentence_wers": [],
            "average_wer": 0.0,
            "std_dev": 0.0,
        }

    return {
        "sentence_wers": sentence_wers,
        # float() keeps the numpy scalars JSON-serializable for the UI.
        "average_wer": float(np.mean(sentence_wers)),
        "std_dev": float(np.std(sentence_wers)),
    }
53
 
54
  @spaces.GPU()
55
  def process_files(reference_file, hypothesis_file):
 
62
 
63
  wer_value = calculate_wer(reference_text, hypothesis_text)
64
  cer_value = calculate_cer(reference_text, hypothesis_text)
65
+ sentence_wer_stats = calculate_sentence_wer(reference_text, hypothesis_text)
66
 
67
  return {
68
  "WER": wer_value,
69
  "CER": cer_value,
70
+ "Sentence WERs": sentence_wer_stats["sentence_wers"],
71
+ "Average WER": sentence_wer_stats["average_wer"],
72
+ "Standard Deviation": sentence_wer_stats["std_dev"]
73
  }
74
  except Exception as e:
75
  return {"error": str(e)}
76
 
77
def format_sentence_wer_stats(sentence_wers, average_wer, std_dev):
    """Render sentence-level WER statistics as a Markdown report.

    Returns a fixed message when there are no per-sentence scores.
    """
    if not sentence_wers:
        return "All sentences match perfectly!"

    parts = [
        "### Sentence-level WER Analysis\n\n",
        f"* Average WER: {average_wer:.2f}\n",
        f"* Standard Deviation: {std_dev:.2f}\n\n",
        "### WER for Each Sentence\n\n",
    ]
    parts.extend(
        f"* Sentence {idx}: {score:.2f}\n"
        for idx, score in enumerate(sentence_wers, start=1)
    )
    return "".join(parts)
88
+
89
  def main():
90
  with gr.Blocks() as demo:
91
  gr.Markdown("# ASR Metrics Calculator")
 
101
  with gr.Row():
102
  compute_button = gr.Button("Compute Metrics")
103
  results_output = gr.JSON(label="Results")
104
+ wer_stats_output = gr.Markdown(label="WER Statistics")
105
 
106
  # Update previews when files are uploaded
107
  def update_previews(ref_file, hyp_file):
 
128
  outputs=[reference_preview, hypothesis_preview]
129
  )
130
 
131
def process_and_display(ref_file, hyp_file):
    """Run the metrics pipeline and shape results for the two outputs.

    Returns exactly two values — (metrics dict, Markdown report) — one
    per Gradio output component bound to the compute button.
    """
    result = process_files(ref_file, hyp_file)
    if "error" in result:
        # Bug fix: the error path used to return THREE values while the
        # click handler binds only TWO outputs, which makes Gradio fail
        # on every error. Return an empty metrics dict plus the message.
        return {}, "Error: " + result["error"]

    metrics = {
        "WER": result["WER"],
        "CER": result["CER"]
    }

    wer_stats_md = format_sentence_wer_stats(
        result["Sentence WERs"],
        result["Average WER"],
        result["Standard Deviation"]
    )

    return metrics, wer_stats_md
148
+
149
  compute_button.click(
150
+ fn=process_and_display,
151
  inputs=[reference_file, hypothesis_file],
152
+ outputs=[results_output, wer_stats_output]
153
  )
154
 
155
  demo.launch()