import re
from dataclasses import dataclass

from hebrew import Hebrew
from jiwer import process_words
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

class HebrewTextNormalizer(BasicTextNormalizer):
    """Hebrew text normalizer: strips niqqud, quote characters, and invisible
    Unicode control symbols before applying the base Whisper normalizer."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        superfluous_chars_to_remove = "\u061c"  # Arabic letter mark
        superfluous_chars_to_remove += (
            "\u200b\u200c\u200d"  # Zero-width space, non-joiner, joiner
        )
        superfluous_chars_to_remove += "\u200e\u200f"  # LTR and RTL marks
        superfluous_chars_to_remove += (
            "\u202a\u202b\u202c\u202d\u202e"  # LTR/RTL embedding, pop, override
        )
        superfluous_chars_to_remove += "\u2066\u2067\u2068\u2069"  # Isolate controls
        superfluous_chars_to_remove += "\ufeff"  # Zero-width no-break space
        self.superfluous_hebrew_unicode_symbols_translator = str.maketrans(
            {ord(c): None for c in superfluous_chars_to_remove}
        )
        quotes_str = "\"'״׳"  # ASCII quotes plus Hebrew gershayim and geresh
        self.quotes_translator = str.maketrans({ord(c): None for c in quotes_str})
        # A hyphen immediately preceding a quote is dropped together with
        # the quote itself.
        self.pre_quote_hyphen_removal_pattern = re.compile(f"-[{quotes_str}]")

    def __remove_niqqud(self, text: str) -> str:
        return Hebrew(text).no_niqqud().string

    def __remove_superfluous_hebrew_unicode_symbols(self, text: str) -> str:
        return text.translate(self.superfluous_hebrew_unicode_symbols_translator)

    def __remove_quotes(self, text: str) -> str:
        text = self.pre_quote_hyphen_removal_pattern.sub("", text)
        return text.translate(self.quotes_translator)

    def __call__(self, text):
        text = self.__remove_niqqud(text)
        text = self.__remove_superfluous_hebrew_unicode_symbols(text)
        text = self.__remove_quotes(text)
        text = super().__call__(text)
        return text
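
# Usage sketch (the sample string is illustrative, not from the original
# module): niqqud, quote characters, and bidi control marks are stripped
# before the base Whisper normalization runs.
#
#   normalizer = HebrewTextNormalizer()
#   normalizer('שָׁלוֹם, צה"ל')  # expected: 'שלום צהל'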

# Number of alignment positions to include on each side of a chunk when
# recording its surrounding context.
context_expansion_size = 4


@dataclass
class SubsSample:
    """A substitution sample: the misrecognized words plus the spans of the
    surrounding context in the reference and hypothesis word lists."""

    ref_context_span: tuple[int, int]
    hyp_context_span: tuple[int, int]
    ref: list[str]
    hyp: list[str]


def merge_spans(span1, span2):
    """Return the smallest span covering both input (start, end) spans."""
    return (min(span1[0], span2[0]), max(span1[1], span2[1]))

def merge_sub_samples(sub_samples: list[SubsSample]):
    """Fold a list of adjacent samples into one, taking the union of their
    context spans and concatenating their word lists in order. Returns
    None for an empty list."""
    merged_sample = None
    for sample in sub_samples:
        if not merged_sample:
            merged_sample = sample
            continue
        merged_sample = SubsSample(
            ref_context_span=merge_spans(
                merged_sample.ref_context_span, sample.ref_context_span
            ),
            hyp_context_span=merge_spans(
                merged_sample.hyp_context_span, sample.hyp_context_span
            ),
            ref=merged_sample.ref + sample.ref,
            hyp=merged_sample.hyp + sample.hyp,
        )
    return merged_sample
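
# Sketch of the merge behavior (values are illustrative): context spans are
# unioned and word lists concatenated.
#
#   merge_sub_samples([
#       SubsSample(ref_context_span=(0, 6), hyp_context_span=(0, 5),
#                  ref=["שלום"], hyp=["שלם"]),
#       SubsSample(ref_context_span=(2, 9), hyp_context_span=(1, 8),
#                  ref=["עולם"], hyp=[""]),
#   ])
#   # -> SubsSample(ref_context_span=(0, 9), hyp_context_span=(0, 8),
#   #               ref=["שלום", "עולם"], hyp=["שלם", ""])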

def get_aligned_chunk_words(wer_word_output, chunk):
    """Extract the aligned reference/hypothesis words for a single jiwer
    alignment chunk, padding the missing side of insertions and deletions
    with empty strings, along with context spans widened by
    context_expansion_size words on each side (clamped to the sentence)."""
    ref_words = None
    hyp_words = None
    ref_context_span = (
        max(0, chunk.ref_start_idx - context_expansion_size),
        min(
            chunk.ref_end_idx + context_expansion_size,
            len(wer_word_output.references[0]),
        ),
    )
    hyp_context_span = (
        max(0, chunk.hyp_start_idx - context_expansion_size),
        min(
            chunk.hyp_end_idx + context_expansion_size,
            len(wer_word_output.hypotheses[0]),
        ),
    )
    if chunk.type == "equal":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
    elif chunk.type == "delete":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = [""] * len(ref_words)
    elif chunk.type == "insert":
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
        ref_words = [""] * len(hyp_words)
    elif chunk.type == "substitute":
        ref_words = wer_word_output.references[0][
            chunk.ref_start_idx : chunk.ref_end_idx
        ]
        hyp_words = wer_word_output.hypotheses[0][
            chunk.hyp_start_idx : chunk.hyp_end_idx
        ]
    return ref_words, hyp_words, ref_context_span, hyp_context_span
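
# Worked example (numbers are illustrative): with context_expansion_size == 4,
# a "substitute" chunk spanning reference words [10, 12) and hypothesis
# words [9, 11) in a 30-word reference and 28-word hypothesis yields
# ref_context_span == (6, 16) and hyp_context_span == (5, 15).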

def extract_substitution_samples(wer_word_output) -> list[SubsSample]:
    """Walk the alignment chunks of a jiwer word-level output and collect
    substitution samples, merging each substitution with an adjacent insert
    or delete chunk so that split/merge errors land in a single sample."""
    subs_samples = []
    prev_chunk = None
    all_chunks = wer_word_output.alignments[0]
    for chunk, next_chunk in zip(all_chunks, all_chunks[1:] + [None]):
        sample_to_store = None
        if chunk.type in ["delete", "insert"]:
            # An insert/delete directly after a substitution belongs to it:
            # merge the two chunks into a single sample.
            if prev_chunk and prev_chunk.type in ["substitute"]:
                ref_words, hyp_words, ref_context_span, hyp_context_span = (
                    get_aligned_chunk_words(wer_word_output, prev_chunk)
                )
                prev_sample = SubsSample(
                    ref_context_span=ref_context_span,
                    hyp_context_span=hyp_context_span,
                    ref=ref_words,
                    hyp=hyp_words,
                )
                ref_words, hyp_words, ref_context_span, hyp_context_span = (
                    get_aligned_chunk_words(wer_word_output, chunk)
                )
                sample = SubsSample(
                    ref_context_span=ref_context_span,
                    hyp_context_span=hyp_context_span,
                    ref=ref_words,
                    hyp=hyp_words,
                )
                sample_to_store = merge_sub_samples([prev_sample, sample])
        if chunk.type == "substitute":
            if next_chunk and next_chunk.type in ["insert", "delete"]:
                pass  # allow the next chunk to capture this chunk
            else:
                # Merge with a preceding insert/delete chunk if present.
                prev_sample = None
                if prev_chunk and prev_chunk.type in ["insert", "delete"]:
                    ref_words, hyp_words, ref_context_span, hyp_context_span = (
                        get_aligned_chunk_words(wer_word_output, prev_chunk)
                    )
                    prev_sample = SubsSample(
                        ref_context_span=ref_context_span,
                        hyp_context_span=hyp_context_span,
                        ref=ref_words,
                        hyp=hyp_words,
                    )
                ref_words, hyp_words, ref_context_span, hyp_context_span = (
                    get_aligned_chunk_words(wer_word_output, chunk)
                )
                sample = SubsSample(
                    ref_context_span=ref_context_span,
                    hyp_context_span=hyp_context_span,
                    ref=ref_words,
                    hyp=hyp_words,
                )
                sample_to_store = (
                    merge_sub_samples([prev_sample, sample]) if prev_sample else sample
                )
        if sample_to_store:
            subs_samples.append(sample_to_store)
            prev_chunk = None  # consume once
        else:
            prev_chunk = chunk
    return subs_samples
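

if __name__ == "__main__":
    # Minimal end-to-end sketch (assumed usage; the sentences are made up):
    # normalize a reference/hypothesis pair, align the pair at the word
    # level with jiwer, and print the extracted substitution samples.
    normalizer = HebrewTextNormalizer()
    reference = normalizer("שלום לכם חברים יקרים")
    hypothesis = normalizer("שלום לכן חברים")
    wer_word_output = process_words(reference, hypothesis)
    for sample in extract_substitution_samples(wer_word_output):
        print(sample.ref, "->", sample.hyp)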