from dataclasses import dataclass
import re

from hebrew import Hebrew
from jiwer import process_words
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

class HebrewTextNormalizer(BasicTextNormalizer):
    """Normalizes Hebrew text for WER evaluation: strips niqqud, bidi and
    zero-width control characters, and quote marks, then applies Whisper's
    BasicTextNormalizer."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        superfluous_chars_to_remove = "\u061c"  # Arabic letter mark
        superfluous_chars_to_remove += (
            "\u200b\u200c\u200d"  # Zero-width space, non-joiner, joiner
        )
        superfluous_chars_to_remove += "\u200e\u200f"  # LTR and RTL marks
        superfluous_chars_to_remove += (
            "\u202a\u202b\u202c\u202d\u202e"  # LTR/RTL embedding, pop, override
        )
        superfluous_chars_to_remove += "\u2066\u2067\u2068\u2069"  # Isolate controls
        superfluous_chars_to_remove += "\ufeff"  # Zero-width no-break space
        self.superfluous_hebrew_unicode_symbols_translator = str.maketrans(
            {ord(c): None for c in superfluous_chars_to_remove}
        )
        quotes_str = "\"'״׳"
        self.quotes_translator = str.maketrans({ord(c): None for c in quotes_str})
        self.pre_quote_hyphen_removal_pattern = re.compile(f"-[{quotes_str}]")

    def __remove_niqqud(self, text: str) -> str:
        return Hebrew(text).no_niqqud().string

    def __remove_superfluous_hebrew_unicode_symbols(self, text: str) -> str:
        return text.translate(self.superfluous_hebrew_unicode_symbols_translator)

    def __remove_quotes(self, text: str) -> str:
        # Remove hyphen+quote sequences outright, then strip any remaining
        # quote characters.
        text = self.pre_quote_hyphen_removal_pattern.sub("", text)
        return text.translate(self.quotes_translator)

    def __call__(self, text):
        text = self.__remove_niqqud(text)
        text = self.__remove_superfluous_hebrew_unicode_symbols(text)
        text = self.__remove_quotes(text)
        text = super().__call__(text)
        return text
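
# Example usage, as a minimal sketch (the input string is illustrative):
#
#   normalizer = HebrewTextNormalizer()
#   normalizer('וַיֹּאמֶר, "שָׁלוֹם"')  # -> 'ויאמר שלום'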

# Number of alignment words to include on each side of an error span as context.
context_expansion_size = 4

@dataclass
class SubsSample:
    ref_context_span: tuple[int, int]
    hyp_context_span: tuple[int, int]
    ref: list[str]
    hyp: list[str]

def merge_spans(span1, span2):
    return (min(span1[0], span2[0]), max(span1[1], span2[1]))
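
# Quick illustration: merge_spans((2, 5), (4, 9)) -> (2, 9), the smallest
# span covering both inputs.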

def merge_sub_samples(sub_samples: list[SubsSample]):
    merged_sample = None
    for sample in sub_samples:
        if not merged_sample:
            merged_sample = sample
            continue
        merged_sample = SubsSample(
            ref_context_span=merge_spans(
                merged_sample.ref_context_span, sample.ref_context_span
            ),
            hyp_context_span=merge_spans(
                merged_sample.hyp_context_span, sample.hyp_context_span
            ),
            ref=merged_sample.ref + sample.ref,
            hyp=merged_sample.hyp + sample.hyp,
        )
    return merged_sample
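
# Illustration (values are hypothetical): merging a sample with spans
# (0, 6)/(0, 5) and one with spans (4, 10)/(3, 9) yields spans (0, 10)/(0, 9),
# with the ref/hyp word lists concatenated in order.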

def get_aligned_chunk_words(wer_word_output, chunk):
    ref_words = None
    hyp_words = None
    # Expand the chunk's span by a few words on each side, clamped to the
    # bounds of the reference/hypothesis word lists.
    ref_context_span = (
        max(0, chunk.ref_start_idx - context_expansion_size),
        min(
            chunk.ref_end_idx + context_expansion_size,
            len(wer_word_output.references[0]),
        ),
    )
    hyp_context_span = (
        max(0, chunk.hyp_start_idx - context_expansion_size),
        min(
            chunk.hyp_end_idx + context_expansion_size,
            len(wer_word_output.hypotheses[0]),
        ),
    )
    if chunk.type == "equal":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
    elif chunk.type == "delete":
        # Deleted reference words have no hypothesis counterpart; pad with
        # empty strings to keep the two lists aligned one-to-one.
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = [""] * len(ref_words)
    elif chunk.type == "insert":
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
        ref_words = [""] * len(hyp_words)
    elif chunk.type == "substitute":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
    return ref_words, hyp_words, ref_context_span, hyp_context_span
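
# Illustration (hypothetical chunk): for a "delete" chunk covering two
# reference words, this returns ref_words with those two words and
# hyp_words=["", ""], so both lists stay the same length for display.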

def extract_substitution_samples(wer_word_output) -> list[SubsSample]:
    subs_samples = []
    prev_chunk = None
    all_chunks = wer_word_output.alignments[0]
    for chunk, next_chunk in zip(all_chunks, all_chunks[1:] + [None]):
        sample_to_store = None
        if chunk.type in ["delete", "insert"]:
            # A delete/insert right after a substitution belongs to the same
            # error region, so merge the two chunks into a single sample.
            if prev_chunk and prev_chunk.type in ["substitute"]:
                ref_words, hyp_words, ref_context_span, hyp_context_span = (
                    get_aligned_chunk_words(wer_word_output, prev_chunk)
                )
                prev_sample = SubsSample(
                    ref_context_span=ref_context_span,
                    hyp_context_span=hyp_context_span,
                    ref=ref_words,
                    hyp=hyp_words,
                )
                ref_words, hyp_words, ref_context_span, hyp_context_span = (
                    get_aligned_chunk_words(wer_word_output, chunk)
                )
                sample = SubsSample(
                    ref_context_span=ref_context_span,
                    hyp_context_span=hyp_context_span,
                    ref=ref_words,
                    hyp=hyp_words,
                )
                sample_to_store = merge_sub_samples([prev_sample, sample])
        if chunk.type == "substitute":
            if next_chunk and next_chunk.type in ["insert", "delete"]:
                pass  # allow the next chunk to capture this chunk
            else:
                prev_sample = None
                if prev_chunk and prev_chunk.type in ["insert", "delete"]:
                    ref_words, hyp_words, ref_context_span, hyp_context_span = (
                        get_aligned_chunk_words(wer_word_output, prev_chunk)
                    )
                    prev_sample = SubsSample(
                        ref_context_span=ref_context_span,
                        hyp_context_span=hyp_context_span,
                        ref=ref_words,
                        hyp=hyp_words,
                    )
                ref_words, hyp_words, ref_context_span, hyp_context_span = (
                    get_aligned_chunk_words(wer_word_output, chunk)
                )
                sample = SubsSample(
                    ref_context_span=ref_context_span,
                    hyp_context_span=hyp_context_span,
                    ref=ref_words,
                    hyp=hyp_words,
                )
                sample_to_store = (
                    merge_sub_samples([prev_sample, sample]) if prev_sample else sample
                )
        if sample_to_store:
            subs_samples.append(sample_to_store)
            prev_chunk = None  # consume once
        else:
            prev_chunk = chunk
    return subs_samples
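
# End-to-end sketch: normalize a reference/hypothesis pair, align the words
# with jiwer, and print the extracted substitution samples. The two sentences
# below are illustrative stand-ins for real transcripts.
if __name__ == "__main__":
    normalizer = HebrewTextNormalizer()
    ref_text = normalizer("הוּא אָמַר שָׁלוֹם לְכֻלָּם")
    hyp_text = normalizer("היא אמרה שלום לכולם")
    wer_word_output = process_words(ref_text, hyp_text)
    for sample in extract_substitution_samples(wer_word_output):
        print(sample.ref, "->", sample.hyp)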