from dataclasses import dataclass
import re

from hebrew import Hebrew
from jiwer import process_words
from transformers.models.whisper.english_normalizer import BasicTextNormalizer


class HebrewTextNormalizer(BasicTextNormalizer):
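    """Hebrew-aware text normalizer.

    Extends Whisper's BasicTextNormalizer with Hebrew-specific cleanup:
    removes niqqud, strips bidi/zero-width control characters, and drops
    quote marks (including geresh/gershayim) as well as a hyphen that
    immediately precedes a quote.
    """
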
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        superfluous_chars_to_remove = "\u061c"  # Arabic letter mark
        superfluous_chars_to_remove += (
            "\u200b\u200c\u200d"  # Zero-width space, non-joiner, joiner
        )
        superfluous_chars_to_remove += "\u200e\u200f"  # LTR and RTL marks
        superfluous_chars_to_remove += (
            "\u202a\u202b\u202c\u202d\u202e"  # LTR/RTL embedding, pop, override
        )
        superfluous_chars_to_remove += "\u2066\u2067\u2068\u2069"  # Isolate controls
        superfluous_chars_to_remove += "\ufeff"  # Zero-width no-break space
        self.superfluous_hebrew_unicode_symbols_translator = str.maketrans(
            {ord(c): None for c in superfluous_chars_to_remove}
        )

        quotes_str = "\"'״׳"
        self.quotes_translator = str.maketrans({ord(c): None for c in quotes_str})
        self.pre_quote_hyphen_removal_pattern = re.compile(f"-[{quotes_str}]")

    def __remove_niqqud(self, text: str) -> str:
        return Hebrew(text).no_niqqud().string

    def __remove_superfluous_hebrew_unicode_symbols(self, text: str) -> str:
        return text.translate(self.superfluous_hebrew_unicode_symbols_translator)

    def __remove_quotes(self, text: str) -> str:
        text = self.pre_quote_hyphen_removal_pattern.sub("", text)
        return text.translate(self.quotes_translator)

    def __call__(self, text):
        text = self.__remove_niqqud(text)
        text = self.__remove_superfluous_hebrew_unicode_symbols(text)
        text = self.__remove_quotes(text)
        text = super().__call__(text)
        return text


# Number of extra alignment words to keep on each side of an error chunk as context.
context_expansion_size = 4


@dataclass
class SubsSample:
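    """A single substitution error site extracted from a jiwer word alignment.

    ref/hyp hold the aligned reference/hypothesis words; the context spans are
    word-index ranges around the error within the full reference/hypothesis.
    """
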
    ref_context_span: tuple[int, int]
    hyp_context_span: tuple[int, int]
    ref: list[str]
    hyp: list[str]


def merge_spans(span1, span2):
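    """Return the smallest (start, end) span covering both input spans."""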
    return (min(span1[0], span2[0]), max(span1[1], span2[1]))


def merge_sub_samples(sub_samples: list[SubsSample]):
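    """Fold several SubsSample objects into one by unioning their context spans
    and concatenating their reference/hypothesis word lists."""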
    merged_sample = None
    for sample in sub_samples:
        if not merged_sample:
            merged_sample = sample
            continue
        merged_sample = SubsSample(
            ref_context_span=merge_spans(
                merged_sample.ref_context_span, sample.ref_context_span
            ),
            hyp_context_span=merge_spans(
                merged_sample.hyp_context_span, sample.hyp_context_span
            ),
            ref=merged_sample.ref + sample.ref,
            hyp=merged_sample.hyp + sample.hyp,
        )
    return merged_sample


def get_aligned_chunk_words(wer_word_output, chunk):
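    """Return the reference/hypothesis words covered by one alignment chunk.

    Deleted/inserted positions are padded with empty strings so both lists keep
    the same length. Also returns context spans expanded by
    context_expansion_size words on each side, clamped to the text bounds.
    """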
    ref_words = None
    hyp_words = None
    ref_context_span = (
        max(0, chunk.ref_start_idx - context_expansion_size),
        min(
            chunk.ref_end_idx + context_expansion_size,
            len(wer_word_output.references[0]),
        ),
    )
    hyp_context_span = (
        max(0, chunk.hyp_start_idx - context_expansion_size),
        min(
            chunk.hyp_end_idx + context_expansion_size,
            len(wer_word_output.hypotheses[0]),
        ),
    )
    if chunk.type == "equal":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
    elif chunk.type == "delete":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = [""] * len(ref_words)
    elif chunk.type == "insert":
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
        ref_words = [""] * len(hyp_words)
    elif chunk.type == "substitute":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
    return ref_words, hyp_words, ref_context_span, hyp_context_span


def extract_substitution_samples(wer_word_output) -> list[SubsSample]:
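    """Collect a SubsSample for every substitution in the alignment.

    A substitution that borders an insert or delete chunk is merged with that
    neighbor into a single sample, so the combined region is treated as one
    error site.
    """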
    subs_samples = []
    prev_chunk = None
    all_chunks = wer_word_output.alignments[0]
    for chunk, next_chunk in zip(all_chunks, all_chunks[1:] + [None]):
        sample_to_store = None

        if chunk.type in ["delete", "insert"]:
            if prev_chunk and prev_chunk.type in ["substitute"]:
                ref_words, hyp_words, ref_context_span, hyp_context_span = (
                    get_aligned_chunk_words(wer_word_output, prev_chunk)
                )
                prev_sample = SubsSample(
                    ref_context_span=ref_context_span,
                    hyp_context_span=hyp_context_span,
                    ref=ref_words,
                    hyp=hyp_words,
                )

                ref_words, hyp_words, ref_context_span, hyp_context_span = (
                    get_aligned_chunk_words(wer_word_output, chunk)
                )
                sample = SubsSample(
                    ref_context_span=ref_context_span,
                    hyp_context_span=hyp_context_span,
                    ref=ref_words,
                    hyp=hyp_words,
                )
                sample_to_store = merge_sub_samples([prev_sample, sample])

        if chunk.type == "substitute":
            if next_chunk and next_chunk.type in ["insert", "delete"]:
                pass  # allow the next chunk to capture this chunk
            else:
                prev_sample = None
                if prev_chunk and prev_chunk.type in ["insert", "delete"]:
                    ref_words, hyp_words, ref_context_span, hyp_context_span = (
                        get_aligned_chunk_words(wer_word_output, prev_chunk)
                    )
                    prev_sample = SubsSample(
                        ref_context_span=ref_context_span,
                        hyp_context_span=hyp_context_span,
                        ref=ref_words,
                        hyp=hyp_words,
                    )

                ref_words, hyp_words, ref_context_span, hyp_context_span = (
                    get_aligned_chunk_words(wer_word_output, chunk)
                )
                sample = SubsSample(
                    ref_context_span=ref_context_span,
                    hyp_context_span=hyp_context_span,
                    ref=ref_words,
                    hyp=hyp_words,
                )
                sample_to_store = (
                    merge_sub_samples([prev_sample, sample]) if prev_sample else sample
                )

        if sample_to_store:
            subs_samples.append(sample_to_store)
            prev_chunk = None  # consume once
        else:
            prev_chunk = chunk

    return subs_samples
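

# The sketch below is an assumed usage example, not part of the original module:
# it shows how HebrewTextNormalizer and extract_substitution_samples might be
# wired together with jiwer.process_words (imported above). The sample strings
# are illustrative only.
if __name__ == "__main__":
    htn = HebrewTextNormalizer()
    reference = htn("שָׁלוֹם עוֹלָם טוֹב")
    hypothesis = htn("שלום עולמי טוב")

    # process_words returns a WordOutput whose references/hypotheses/alignments
    # are indexed per sentence pair; this module assumes a single pair (index 0).
    wer_word_output = process_words([reference], [hypothesis])
    for sample in extract_substitution_samples(wer_word_output):
        print(sample.ref, "->", sample.hyp)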