Yoad committed on
Commit 7b3ae60 · 1 Parent(s): abf28c9

Add subs visualizer

src/app.py CHANGED
@@ -12,6 +12,7 @@ from huggingface_hub import HfFileSystem
 from st_fixed_container import st_fixed_container
 from visual_eval.evaluator import HebrewTextNormalizer
 from visual_eval.visualization import render_visualize_jiwer_result_html
+from substitutions_visualizer import visualize_substitutions
 
 HF_API_TOKEN = None
 try:
@@ -377,6 +378,8 @@ def main():
 
     # Toggle for normalized vs raw text
     use_normalized = st.sidebar.toggle("Use normalized text", value=True)
+    show_metadata = st.sidebar.toggle("Show entry metadata", value=False)
+    visualize_subs = st.sidebar.toggle("List Substitutions", value=False)
 
     # Create sidebar for entry selection
     st.sidebar.header("Select Entry")
@@ -495,22 +498,26 @@ def main():
     html = render_visualize_jiwer_result_html(ref, hyp)
     display_rtl(html)
 
-    # Display metadata
-    st.header("Metadata")
-    metadata_cols = [
-        "metadata_uuid",
-        "model",
-        "dataset",
-        "dataset_split",
-        "engine",
-    ]
-    metadata = eval_results.iloc[selected_entry][metadata_cols]
-
-    # Create a DataFrame for better display
-    metadata_df = pd.DataFrame(
-        {"Field": metadata_cols, "Value": [str(v) for v in metadata.values]}
-    )
-    st.table(metadata_df)
+    if show_metadata:
+        # Display metadata
+        st.header("Metadata")
+        metadata_cols = [
+            "metadata_uuid",
+            "model",
+            "dataset",
+            "dataset_split",
+            "engine",
+        ]
+        metadata = eval_results.iloc[selected_entry][metadata_cols]
+
+        # Create a DataFrame for better display
+        metadata_df = pd.DataFrame(
+            {"Field": metadata_cols, "Value": [str(v) for v in metadata.values]}
+        )
+        st.table(metadata_df)
+
+    if visualize_subs:
+        visualize_substitutions(ref, hyp)
 
     # If we have audio URL, display it in the sticky container
     if "audio_url" in locals() and audio_url:
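Note: the two new sidebar controls follow the standard Streamlit gating pattern — `st.sidebar.toggle` returns a bool on every rerun, and the optional panel renders only inside the `if`. A minimal self-contained sketch of the same pattern (labels and data are illustrative, not from this repo):

```python
import streamlit as st

# Each toggle returns a bool on every rerun; gating render code on it
# keeps the optional panel off the page until the user asks for it.
show_details = st.sidebar.toggle("Show details", value=False)

st.write("Always-visible content")
if show_details:
    st.header("Details")
    st.table({"Field": ["model"], "Value": ["demo"]})
```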
src/substitutions_visualizer.py ADDED
@@ -0,0 +1,101 @@
+import pandas as pd
+import streamlit as st
+from jiwer import process_words
+
+from visual_eval.evaluator import extract_substitution_samples, HebrewTextNormalizer
+
+subs_table_styles = """
+<style>
+.sub-table {
+    background: white;
+    width: 100%;
+    border-collapse: collapse;
+    color: black;
+}
+.sub-row {
+    cursor: pointer;
+    transition: all 0.2s;
+}
+.sub-row .txt {
+    text-align: center;
+}
+.sub-row:nth-child(even):hover {
+    background: #eee;
+}
+.sub-row:nth-child(odd):hover {
+    background: #eee;
+}
+.sub-row:nth-child(even):hover + .sub-row {
+    background: #eee;
+}
+.sub-row:nth-child(even):has(+ .sub-row:hover) {
+    background: #eee;
+}
+.sub-row.ref {
+    color: green;
+}
+.sub-row.ref .ctx {
+    text-align: end;
+}
+.sub-row.hyp {
+    color: red;
+    border-bottom: 1px solid black;
+}
+.sub-row.hyp .ctx {
+    text-align: start;
+}
+</style>
+"""
+
+
+@st.cache_data
+def visualize_substitutions(ref, hyp):
+    norm = HebrewTextNormalizer()
+    wer_word_output = process_words(norm(ref), norm(hyp))
+    subs_rows = []
+    for sample in extract_substitution_samples(wer_word_output):
+        subs_rows.append(
+            {
+                "ref": " ".join(sample.ref),
+                "hyp": " ".join(sample.hyp),
+                "hyp_ctx": " ".join(
+                    wer_word_output.hypotheses[0][slice(*sample.hyp_context_span)]
+                ),
+                "ref_ctx": " ".join(
+                    wer_word_output.references[0][slice(*sample.ref_context_span)]
+                ),
+            }
+        )
+
+    sub_rows_html = []
+    for row in subs_rows:
+        sub_rows_html.append(
+            f"""
+            <tr class="sub-row ref">
+                <td class="ctx">{row['ref_ctx']}</td>
+                <td class="txt">{row['ref']}</td>
+                <td></td>
+            </tr>
+            <tr class="sub-row hyp">
+                <td></td>
+                <td class="txt">{row['hyp']}</td>
+                <td class="ctx">{row['hyp_ctx']}</td>
+            </tr>
+            """
+        )
+
+    st.subheader("Substitutions List")
+
+    table_html = f"""
+    {subs_table_styles}
+    <table class="sub-table" dir="rtl" lang="he">
+        <tr>
+            <th style="text-align: end;">Ref Context</th>
+            <th style="text-align: center;">Ref/Hyp</th>
+            <th style="text-align: start;">Hyp Context</th>
+        </tr>
+        {"".join(sub_rows_html)}
+    </table>
+    """
+
+    st.html(table_html)
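For context, the new module can be driven standalone. A minimal sketch (the file name and sample strings are hypothetical, and it assumes `src/` is on the import path):

```python
# demo_subs.py -- hypothetical driver; run with: streamlit run demo_subs.py
import streamlit as st

from substitutions_visualizer import visualize_substitutions

# Made-up transcripts; any reference/hypothesis pair works.
ref = "the quick brown fox jumps over the lazy dog"
hyp = "the quick brown fox jumped over a lazy dog"

st.title("Substitutions demo")
# Renders the styled RTL table of substituted words with surrounding context.
visualize_substitutions(ref, hyp)
```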
src/visual_eval/evaluator.py CHANGED
@@ -1,23 +1,9 @@
-"""
-Evaluator module.
-Provides functions to evaluate a given model on a dataset sample using the Faster Whisper model,
-and generate HTML visualization blocks of the word alignment.
-"""
-
-import concurrent.futures
-import gc
-import io
-import queue
-import threading
-from typing import Dict, Generator, List
-
-import soundfile as sf
+from dataclasses import dataclass
+
 from hebrew import Hebrew
-from tqdm import tqdm
+from jiwer import process_words
 from transformers.models.whisper.english_normalizer import BasicTextNormalizer
 
-from visual_eval.visualization import render_visualize_jiwer_result_html
-
 
 class HebrewTextNormalizer(BasicTextNormalizer):
     def __init__(self, *args, **kwargs):
@@ -54,3 +40,149 @@ class HebrewTextNormalizer(BasicTextNormalizer):
         text = self.__remove_quotes(text)
         text = super().__call__(text)
         return text
+
+
+context_expansion_size = 4
+
+
+@dataclass
+class SubsSample:
+    ref_context_span: tuple[int, int]
+    hyp_context_span: tuple[int, int]
+    ref: list[str]
+    hyp: list[str]
+
+
+def merge_spans(span1, span2):
+    return (min(span1[0], span2[0]), max(span1[1], span2[1]))
+
+
+def merge_sub_samples(sub_samples: list[SubsSample]):
+    merged_sample = None
+    for sample in sub_samples:
+        if not merged_sample:
+            merged_sample = sample
+            continue
+        merged_sample = SubsSample(
+            ref_context_span=merge_spans(
+                merged_sample.ref_context_span, sample.ref_context_span
+            ),
+            hyp_context_span=merge_spans(
+                merged_sample.hyp_context_span, sample.hyp_context_span
+            ),
+            ref=merged_sample.ref + sample.ref,
+            hyp=merged_sample.hyp + sample.hyp,
+        )
+    return merged_sample
+
+
+def get_aligned_chunk_words(wer_word_output, chunk):
+    ref_words = None
+    hyp_words = None
+    ref_context_span = [
+        max(0, chunk.ref_start_idx - context_expansion_size),
+        min(
+            chunk.ref_end_idx + context_expansion_size,
+            len(wer_word_output.references[0]),
+        ),
+    ]
+    hyp_context_span = [
+        max(0, chunk.hyp_start_idx - context_expansion_size),
+        min(
+            chunk.hyp_end_idx + context_expansion_size,
+            len(wer_word_output.hypotheses[0]),
+        ),
+    ]
+    if chunk.type == "equal":
+        ref_words = wer_word_output.references[0][
+            chunk.ref_start_idx : chunk.ref_end_idx
+        ]
+        hyp_words = wer_word_output.hypotheses[0][
+            chunk.hyp_start_idx : chunk.hyp_end_idx
+        ]
+    elif chunk.type == "delete":
+        ref_words = wer_word_output.references[0][
+            chunk.ref_start_idx : chunk.ref_end_idx
+        ]
+        hyp_words = [""] * len(ref_words)
+    elif chunk.type == "insert":
+        hyp_words = wer_word_output.hypotheses[0][
+            chunk.hyp_start_idx : chunk.hyp_end_idx
+        ]
+        ref_words = [""] * len(hyp_words)
+    elif chunk.type == "substitute":
+        ref_words = wer_word_output.references[0][
+            chunk.ref_start_idx : chunk.ref_end_idx
+        ]
+        hyp_words = wer_word_output.hypotheses[0][
+            chunk.hyp_start_idx : chunk.hyp_end_idx
+        ]
+    return ref_words, hyp_words, ref_context_span, hyp_context_span
+
+
+def extract_substitution_samples(wer_word_output) -> list[SubsSample]:
+    subs_samples = []
+    prev_chunk = None
+    all_chunks = wer_word_output.alignments[0]
+    for chunk, next_chunk in zip(all_chunks, all_chunks[1:] + [None]):
+        sample_to_store = None
+
+        if chunk.type in ["delete", "insert"]:
+            if prev_chunk and prev_chunk.type in ["substitute"]:
+                ref_words, hyp_words, ref_context_span, hyp_context_span = (
+                    get_aligned_chunk_words(wer_word_output, prev_chunk)
+                )
+                prev_sample = SubsSample(
+                    ref_context_span=ref_context_span,
+                    hyp_context_span=hyp_context_span,
+                    ref=ref_words,
+                    hyp=hyp_words,
+                )
+
+                ref_words, hyp_words, ref_context_span, hyp_context_span = (
+                    get_aligned_chunk_words(wer_word_output, chunk)
+                )
+                sample = SubsSample(
+                    ref_context_span=ref_context_span,
+                    hyp_context_span=hyp_context_span,
+                    ref=ref_words,
+                    hyp=hyp_words,
+                )
+                sample_to_store = merge_sub_samples([prev_sample, sample])
+
+        if chunk.type == "substitute":
+            if next_chunk and next_chunk.type in ["insert", "delete"]:
+                pass  # allow the next chunk to capture this chunk
+            else:
+                prev_sample = None
+                if prev_chunk and prev_chunk.type in ["insert", "delete"]:
+                    ref_words, hyp_words, ref_context_span, hyp_context_span = (
+                        get_aligned_chunk_words(wer_word_output, prev_chunk)
+                    )
+                    prev_sample = SubsSample(
+                        ref_context_span=ref_context_span,
+                        hyp_context_span=hyp_context_span,
+                        ref=ref_words,
+                        hyp=hyp_words,
+                    )
+
+                ref_words, hyp_words, ref_context_span, hyp_context_span = (
+                    get_aligned_chunk_words(wer_word_output, chunk)
+                )
+                sample = SubsSample(
+                    ref_context_span=ref_context_span,
+                    hyp_context_span=hyp_context_span,
+                    ref=ref_words,
+                    hyp=hyp_words,
+                )
+                sample_to_store = (
+                    merge_sub_samples([prev_sample, sample]) if prev_sample else sample
+                )
+
+        if sample_to_store:
+            subs_samples.append(sample_to_store)
+            prev_chunk = None  # consume once
+        else:
+            prev_chunk = chunk
+
+    return subs_samples
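The new alignment helpers can also be sanity-checked without Streamlit. A toy sketch (made-up strings, not project data; assumes `jiwer` is installed and `src/` is importable):

```python
from jiwer import process_words

from visual_eval.evaluator import HebrewTextNormalizer, extract_substitution_samples

norm = HebrewTextNormalizer()
out = process_words(norm("the cat sat on the mat"), norm("the dog sat on a mat"))

# Each SubsSample carries the substituted words plus context spans that
# index into the aligned word lists.
for s in extract_substitution_samples(out):
    print("ref:", s.ref, "-> hyp:", s.hyp)
    print("  ref context:", " ".join(out.references[0][slice(*s.ref_context_span)]))
    print("  hyp context:", " ".join(out.hypotheses[0][slice(*s.hyp_context_span)]))
```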