from dataclasses import dataclass
import re
from hebrew import Hebrew
from jiwer import process_words
from transformers.models.whisper.english_normalizer import BasicTextNormalizer


class HebrewTextNormalizer(BasicTextNormalizer):
    """Hebrew-aware text normalizer for ASR evaluation.

    Extends Whisper's BasicTextNormalizer with Hebrew-specific steps:
    niqqud (vowel-point) removal, stripping of invisible Unicode
    bidi/formatting characters, and removal of quote characters
    (including the Hebrew geresh and gershayim).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
superfluous_chars_to_remove = "\u061c" # Arabic letter mark
superfluous_chars_to_remove += (
"\u200b\u200c\u200d" # Zero-width space, non-joiner, joiner
)
superfluous_chars_to_remove += "\u200e\u200f" # LTR and RTL marks
superfluous_chars_to_remove += (
"\u202a\u202b\u202c\u202d\u202e" # LTR/RTL embedding, pop, override
)
superfluous_chars_to_remove += "\u2066\u2067\u2068\u2069" # Isolate controls
superfluous_chars_to_remove += "\ufeff" # Zero-width no-break space
self.superfluous_hebrew_unicode_symbols_translator = str.maketrans(
{ord(c): None for c in superfluous_chars_to_remove}
)
quotes_str = "\"'״׳"
self.quotes_translator = str.maketrans({ord(c): None for c in quotes_str})
self.pre_quote_hyphen_removal_pattern = re.compile(f"-[{quotes_str}]")

    def __remove_niqqud(self, text: str) -> str:
        # Strip Hebrew vowel points (niqqud) via the `hebrew` package.
        return Hebrew(text).no_niqqud().string

    def __remove_superfluous_hebrew_unicode_symbols(self, text: str) -> str:
        # Drop invisible bidi/formatting characters that often leak into
        # transcripts copied from bidi-aware editors.
        return text.translate(self.superfluous_hebrew_unicode_symbols_translator)

    def __remove_quotes(self, text: str) -> str:
        # First drop a hyphen that immediately precedes a quote, then drop
        # the quote characters themselves.
        text = self.pre_quote_hyphen_removal_pattern.sub("", text)
        return text.translate(self.quotes_translator)

    def __call__(self, text):
        text = self.__remove_niqqud(text)
        text = self.__remove_superfluous_hebrew_unicode_symbols(text)
        text = self.__remove_quotes(text)
        # Finish with Whisper's language-agnostic normalization
        # (punctuation removal, lowercasing, whitespace collapsing).
        text = super().__call__(text)
        return text
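

# A minimal usage sketch (not part of the original module; the sample string
# is hypothetical, and the exact output depends on BasicTextNormalizer's
# punctuation handling):
#
#     htn = HebrewTextNormalizer()
#     htn("שָׁלוֹם, עוֹלָם!")  # niqqud and punctuation stripped -> "שלום עולם"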


# Number of alignment words to keep on each side of an error region as
# surrounding context.
context_expansion_size = 4


@dataclass
class SubsSample:
    """A substitution-error sample: misaligned words plus their context spans."""

    ref_context_span: tuple[int, int]  # word-index span in the reference
    hyp_context_span: tuple[int, int]  # word-index span in the hypothesis
    ref: list[str]  # reference words in the error region ("" pads insertions)
    hyp: list[str]  # hypothesis words in the error region ("" pads deletions)


def merge_spans(span1, span2):
    # Smallest span covering both inputs, e.g. (2, 5) and (4, 9) -> (2, 9).
    return (min(span1[0], span2[0]), max(span1[1], span2[1]))


def merge_sub_samples(sub_samples: list[SubsSample]):
    # Fold adjacent samples into one, unioning the context spans and
    # concatenating the word lists.
    merged_sample = None
for sample in sub_samples:
if not merged_sample:
merged_sample = sample
continue
merged_sample = SubsSample(
ref_context_span=merge_spans(
merged_sample.ref_context_span, sample.ref_context_span
),
hyp_context_span=merge_spans(
merged_sample.hyp_context_span, sample.hyp_context_span
),
ref=merged_sample.ref + sample.ref,
hyp=merged_sample.hyp + sample.hyp,
)
return merged_sample
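

# An illustrative sketch (hypothetical values, not in the original module):
# merging two adjacent one-word samples yields a single two-word sample.
#
#     a = SubsSample((0, 5), (0, 5), ["שלום"], ["שלוש"])
#     b = SubsSample((3, 8), (3, 8), ["עולם"], ["עולמות"])
#     merge_sub_samples([a, b])
#     # -> SubsSample(ref_context_span=(0, 8), hyp_context_span=(0, 8),
#     #               ref=["שלום", "עולם"], hyp=["שלוש", "עולמות"])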


def get_aligned_chunk_words(wer_word_output, chunk):
    """Return the ref/hyp words of an alignment chunk plus context spans.

    Deleted and inserted positions are padded with "" so that ref_words and
    hyp_words always have equal length. Context spans are clamped to the
    sentence bounds.
    """
    ref_words = None
    hyp_words = None
    ref_context_span = (
        max(0, chunk.ref_start_idx - context_expansion_size),
        min(
            chunk.ref_end_idx + context_expansion_size,
            len(wer_word_output.references[0]),
        ),
    )
    hyp_context_span = (
        max(0, chunk.hyp_start_idx - context_expansion_size),
        min(
            chunk.hyp_end_idx + context_expansion_size,
            len(wer_word_output.hypotheses[0]),
        ),
    )
if chunk.type == "equal":
ref_words = wer_word_output.references[0][
chunk.ref_start_idx : chunk.ref_end_idx
]
hyp_words = wer_word_output.hypotheses[0][
chunk.hyp_start_idx : chunk.hyp_end_idx
]
elif chunk.type == "delete":
ref_words = wer_word_output.references[0][
chunk.ref_start_idx : chunk.ref_end_idx
]
hyp_words = [""] * len(ref_words)
elif chunk.type == "insert":
hyp_words = wer_word_output.hypotheses[0][
chunk.hyp_start_idx : chunk.hyp_end_idx
]
ref_words = [""] * len(hyp_words)
elif chunk.type == "substitute":
ref_words = wer_word_output.references[0][
chunk.ref_start_idx : chunk.ref_end_idx
]
hyp_words = wer_word_output.hypotheses[0][
chunk.hyp_start_idx : chunk.hyp_end_idx
]
return ref_words, hyp_words, ref_context_span, hyp_context_span


def extract_substitution_samples(wer_word_output) -> list[SubsSample]:
    """Collect substitution errors from a jiwer word-level alignment.

    A substitution adjacent to an insertion or deletion is merged with it
    into a single sample, since the pair usually describes one
    mistranscribed region.
    """
    subs_samples = []
    prev_chunk = None
    all_chunks = wer_word_output.alignments[0]
    # Iterate with a one-chunk lookahead; the final chunk pairs with None.
    for chunk, next_chunk in zip(all_chunks, all_chunks[1:] + [None]):
sample_to_store = None
if chunk.type in ["delete", "insert"]:
if prev_chunk and prev_chunk.type in ["substitute"]:
ref_words, hyp_words, ref_context_span, hyp_context_span = (
get_aligned_chunk_words(wer_word_output, prev_chunk)
)
prev_sample = SubsSample(
ref_context_span=ref_context_span,
hyp_context_span=hyp_context_span,
ref=ref_words,
hyp=hyp_words,
)
ref_words, hyp_words, ref_context_span, hyp_context_span = (
get_aligned_chunk_words(wer_word_output, chunk)
)
sample = SubsSample(
ref_context_span=ref_context_span,
hyp_context_span=hyp_context_span,
ref=ref_words,
hyp=hyp_words,
)
sample_to_store = merge_sub_samples([prev_sample, sample])
if chunk.type == "substitute":
if next_chunk and next_chunk.type in ["insert", "delete"]:
                pass  # defer: the following insert/delete chunk will merge this substitution
else:
prev_sample = None
if prev_chunk and prev_chunk.type in ["insert", "delete"]:
ref_words, hyp_words, ref_context_span, hyp_context_span = (
get_aligned_chunk_words(wer_word_output, prev_chunk)
)
prev_sample = SubsSample(
ref_context_span=ref_context_span,
hyp_context_span=hyp_context_span,
ref=ref_words,
hyp=hyp_words,
)
ref_words, hyp_words, ref_context_span, hyp_context_span = (
get_aligned_chunk_words(wer_word_output, chunk)
)
sample = SubsSample(
ref_context_span=ref_context_span,
hyp_context_span=hyp_context_span,
ref=ref_words,
hyp=hyp_words,
)
sample_to_store = (
merge_sub_samples([prev_sample, sample]) if prev_sample else sample
)
if sample_to_store:
subs_samples.append(sample_to_store)
            prev_chunk = None  # consume once so a chunk is never merged twice
else:
prev_chunk = chunk
return subs_samples
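

# A minimal end-to-end sketch, not part of the original module: the sentences
# are hypothetical and chosen so the alignment contains one substitution
# ("fox" -> "box") for extract_substitution_samples to pick up.
if __name__ == "__main__":
    normalizer = HebrewTextNormalizer()
    reference = normalizer("the quick brown fox jumps over the lazy dog")
    hypothesis = normalizer("the quick brown box jumps over the lazy dog")
    word_output = process_words([reference], [hypothesis])
    for sample in extract_substitution_samples(word_output):
        print(sample)  # expect one SubsSample covering the fox -> box substitution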