from dataclasses import dataclass
import re

from hebrew import Hebrew
from jiwer import process_words
from transformers.models.whisper.english_normalizer import BasicTextNormalizer


class HebrewTextNormalizer(BasicTextNormalizer):
    """Whisper's BasicTextNormalizer extended with Hebrew-specific cleanup:
    niqqud removal, bidi/zero-width control stripping, and quote removal."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        superfluous_chars_to_remove = "\u061c"  # Arabic letter mark
        superfluous_chars_to_remove += (
            "\u200b\u200c\u200d"  # Zero-width space, non-joiner, joiner
        )
        superfluous_chars_to_remove += "\u200e\u200f"  # LTR and RTL marks
        superfluous_chars_to_remove += (
            "\u202a\u202b\u202c\u202d\u202e"  # LTR/RTL embedding, pop, override
        )
        superfluous_chars_to_remove += "\u2066\u2067\u2068\u2069"  # Isolate controls
        superfluous_chars_to_remove += "\ufeff"  # Zero-width no-break space
        self.superfluous_hebrew_unicode_symbols_translator = str.maketrans(
            {ord(c): None for c in superfluous_chars_to_remove}
        )
        quotes_str = "\"'״׳"
        self.quotes_translator = str.maketrans({ord(c): None for c in quotes_str})
        self.pre_quote_hyphen_removal_pattern = re.compile(f"-[{quotes_str}]")

    def __remove_niqqud(self, text: str) -> str:
        return Hebrew(text).no_niqqud().string

    def __remove_superfluous_hebrew_unicode_symbols(self, text: str) -> str:
        return text.translate(self.superfluous_hebrew_unicode_symbols_translator)

    def __remove_quotes(self, text: str) -> str:
        # Remove a hyphen together with the quote that immediately follows it,
        # then strip any remaining quote characters.
        text = self.pre_quote_hyphen_removal_pattern.sub("", text)
        return text.translate(self.quotes_translator)

    def __call__(self, text):
        text = self.__remove_niqqud(text)
        text = self.__remove_superfluous_hebrew_unicode_symbols(text)
        text = self.__remove_quotes(text)
        text = super().__call__(text)
        return text


# Number of context words to keep on each side of an error chunk.
context_expansion_size = 4


@dataclass
class SubsSample:
    ref_context_span: tuple[int, int]
    hyp_context_span: tuple[int, int]
    ref: list[str]
    hyp: list[str]


def merge_spans(span1, span2):
    return (min(span1[0], span2[0]), max(span1[1], span2[1]))


def merge_sub_samples(sub_samples: list[SubsSample]):
    """Fold a list of SubsSamples into one, unioning the context spans and
    concatenating the word lists. Returns None for an empty list."""
    merged_sample = None
    for sample in sub_samples:
        if not merged_sample:
            merged_sample = sample
            continue
        merged_sample = SubsSample(
            ref_context_span=merge_spans(
                merged_sample.ref_context_span, sample.ref_context_span
            ),
            hyp_context_span=merge_spans(
                merged_sample.hyp_context_span, sample.hyp_context_span
            ),
            ref=merged_sample.ref + sample.ref,
            hyp=merged_sample.hyp + sample.hyp,
        )
    return merged_sample


def get_aligned_chunk_words(wer_word_output, chunk):
    """Return (ref_words, hyp_words, ref_context_span, hyp_context_span) for a
    single jiwer alignment chunk, padding delete/insert chunks with empty
    strings so both sides have equal length."""
    ref_words = None
    hyp_words = None
    ref_context_span = (
        max(0, chunk.ref_start_idx - context_expansion_size),
        min(
            chunk.ref_end_idx + context_expansion_size,
            len(wer_word_output.references[0]),
        ),
    )
    hyp_context_span = (
        max(0, chunk.hyp_start_idx - context_expansion_size),
        min(
            chunk.hyp_end_idx + context_expansion_size,
            len(wer_word_output.hypotheses[0]),
        ),
    )
    if chunk.type == "equal":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
    elif chunk.type == "delete":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = [""] * len(ref_words)
    elif chunk.type == "insert":
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
        ref_words = [""] * len(hyp_words)
    elif chunk.type == "substitute":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
    return ref_words, hyp_words, ref_context_span, hyp_context_span
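
# Illustrative sketch (made-up indices, not project data): for a "substitute"
# chunk with ref_start_idx=10 and ref_end_idx=11, get_aligned_chunk_words
# returns the substituted word pair plus context spans widened by
# context_expansion_size words on each side, i.e. ref_context_span == (6, 15),
# clamped to the start and end of the reference word list.
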
def extract_substitution_samples(wer_word_output) -> list[SubsSample]:
    """Walk the jiwer alignment and collect substitution errors as SubsSamples,
    fusing a substitution with a directly adjacent insert/delete chunk so the
    whole error region is captured as one sample."""

    def chunk_to_sample(target_chunk):
        ref_words, hyp_words, ref_context_span, hyp_context_span = (
            get_aligned_chunk_words(wer_word_output, target_chunk)
        )
        return SubsSample(
            ref_context_span=ref_context_span,
            hyp_context_span=hyp_context_span,
            ref=ref_words,
            hyp=hyp_words,
        )

    subs_samples = []
    prev_chunk = None
    all_chunks = wer_word_output.alignments[0]
    for chunk, next_chunk in zip(all_chunks, all_chunks[1:] + [None]):
        sample_to_store = None
        if chunk.type in ("delete", "insert"):
            # A delete/insert directly following a substitution belongs to the
            # same error region; merge the two chunks into one sample.
            if prev_chunk and prev_chunk.type == "substitute":
                sample_to_store = merge_sub_samples(
                    [chunk_to_sample(prev_chunk), chunk_to_sample(chunk)]
                )
        if chunk.type == "substitute":
            if next_chunk and next_chunk.type in ("insert", "delete"):
                pass  # allow the next chunk to capture this chunk
            else:
                prev_sample = (
                    chunk_to_sample(prev_chunk)
                    if prev_chunk and prev_chunk.type in ("insert", "delete")
                    else None
                )
                sample = chunk_to_sample(chunk)
                sample_to_store = (
                    merge_sub_samples([prev_sample, sample]) if prev_sample else sample
                )
        if sample_to_store:
            subs_samples.append(sample_to_store)
            prev_chunk = None  # consume each chunk at most once
        else:
            prev_chunk = chunk
    return subs_samples
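
# A minimal end-to-end usage sketch. The two sentences below are made-up
# examples standing in for real reference and ASR transcripts.
if __name__ == "__main__":
    normalizer = HebrewTextNormalizer()
    reference = normalizer("הוא אמר שלום לכולם")  # ground-truth transcript
    hypothesis = normalizer("הוא אמרה שלום")  # hypothetical ASR output
    wer_word_output = process_words(reference, hypothesis)
    for subs_sample in extract_substitution_samples(wer_word_output):
        # Expected to surface the אמר -> אמרה substitution with its context spans.
        print("ref:", " ".join(subs_sample.ref), "| hyp:", " ".join(subs_sample.hyp))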