"""Transformer-based transliteration engine for Indic languages, built on fairseq."""

from ast import literal_eval as ast_literal_eval
from time import time as get_time
from math import log as math_log
from re import compile as re_compile, sub as re_sub
from json import load as json_load
from argparse import Namespace
from collections import namedtuple

from urduhack import normalize as shahmukhi_normalize
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory

from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output

# Languages written right-to-left (Perso-Arabic and Thaana scripts).
RTL_LANG_CODES = {'ks', 'pnb', 'sd', 'skr', 'ur', 'dv'}

LANG_CODE_TO_SCRIPT_CODE = {
    "as": "Beng",
    "bn": "Beng",
    "doi": "Deva",
    "dv": "Thaa",
    "gom": "Deva",
    "gu": "Gujr",
    "hi": "Deva",
    "ks": "Aran",
    "mai": "Deva",
    "mr": "Deva",
    "ne": "Deva",
    "or": "Orya",
    "pa": "Guru",
    "pnb": "Aran",
    "sa": "Deva",
    "sd": "Arab",
    "si": "Sinh",
    "skr": "Aran",
    "ur": "Aran",
    "kn": "Knda",
    "ml": "Mlym",
    "ta": "Taml",
    "te": "Telu",
    "brx": "Deva",
    "mni": "Mtei",
    "sat": "Olck",
    "en": "Latn",
}

SCRIPT_CODE_TO_UNICODE_CHARS_RANGE_STR = {
    "Beng": "\u0980-\u09FF",
    "Deva": "\u0900-\u097F",
    "Gujr": "\u0A80-\u0AFF",
    "Guru": "\u0A00-\u0A7F",
    "Orya": "\u0B00-\u0B7F",
    "Knda": "\u0C80-\u0CFF",
    "Mlym": "\u0D00-\u0D7F",
    "Sinh": "\u0D80-\u0DFF",
    "Taml": "\u0B80-\u0BFF",
    "Telu": "\u0C00-\u0C7F",
    "Mtei": "\uABC0-\uABFF",
    "Arab": "\u0600-\u06FF\u0750-\u077F\u0870-\u089F\u08A0-\u08FF",
    "Aran": "\u0600-\u06FF\u0750-\u077F\u0870-\u089F\u08A0-\u08FF",
    "Latn": "\u0041-\u005A\u0061-\u007A",
    "Olck": "\u1C50-\u1C7F",
    "Thaa": "\u0780-\u07BF",
}

INDIC_TO_LATIN_PUNCT = {
    '।': '.',
    '॥': "..",
    '෴': '.',
    '꫰': ',',
    '꯫': '.',
    '᱾': '.',
    '᱿': "..",
    '۔': '.',
    '؟': '?',
    '،': ',',
    '؛': ';',
    '۝': "..",
}
INDIC_TO_LATIN_PUNCT_TRANSLATOR = str.maketrans(INDIC_TO_LATIN_PUNCT)

NON_LATIN_FULLSTOP_LANGS = {
    'as': '।',
    'bn': '।',
    'brx': '।',
    'doi': '।',
    'hi': '।',
    'mai': '।',
    'mni': '꯫',
    'ne': '।',
    'or': '।',
    'pa': '।',
    'sa': '।',
    'sat': '᱾',
    'ks': '۔',
    'pnb': '۔',
    'skr': '۔',
    'ur': '۔',
}

ENDS_WITH_LATIN_FULLSTOP_REGEX = re_compile(r"(^|.*[^.])\.$")


def nativize_latin_fullstop(text, lang_code):
    # Replace a single trailing Latin full stop with the target script's sentence terminator.
    if lang_code in NON_LATIN_FULLSTOP_LANGS and ENDS_WITH_LATIN_FULLSTOP_REGEX.match(text):
        return text[:-1] + NON_LATIN_FULLSTOP_LANGS[lang_code]
    return text


LATIN_TO_PERSOARABIC_PUNCTUATIONS = {
    '?': '؟',
    ',': '،',
    ';': '؛',
}
LATIN_TO_PERSOARABIC_PUNC_TRANSLATOR = str.maketrans(LATIN_TO_PERSOARABIC_PUNCTUATIONS)

SCRIPT_CODE_TO_NUMERALS = {
    "Beng": "০১২৩৪৫৬৭৮৯",
    "Deva": "०१२३४५६७८९",
    "Gujr": "૦૧૨૩૪૫૬૭૮૯",
    "Guru": "੦੧੨੩੪੫੬੭੮੯",
    "Orya": "୦୧୨୩୪୫୬୭୮୯",
    "Knda": "೦೧೨೩೪೫೬೭೮೯",
    "Mlym": "൦൧൨൩൪൫൬൭൮൯",
    "Sinh": "෦෧෨෩෪෫෬෭෮෯",
    "Taml": "௦௧௨௩௪௫௬௭௮௯",
    "Telu": "౦౧౨౩౪౫౬౭౮౯",
    "Mtei": "꯰꯱꯲꯳꯴꯵꯶꯷꯸꯹",
    "Arab": "۰۱۲۳۴۵۶۷۸۹",
    "Aran": "۰۱۲۳۴۵۶۷۸۹",
    "Latn": "0123456789",
    "Olck": "᱐᱑᱒᱓᱔᱕᱖᱗᱘᱙",
    "Thaa": "٠١٢٣٤٥٦٧٨٩",
}

LANG_CODE_TO_NUMERALS = {
    lang_code: SCRIPT_CODE_TO_NUMERALS[script_code]
    for lang_code, script_code in LANG_CODE_TO_SCRIPT_CODE.items()
}

# Map every native digit (across all supported scripts) to its ASCII equivalent.
INDIC_TO_STANDARD_NUMERALS_GLOBAL_MAP = {}
for lang_code, lang_numerals in LANG_CODE_TO_NUMERALS.items():
    map_dict = {
        lang_numeral: en_numeral
        for lang_numeral, en_numeral in zip(lang_numerals, LANG_CODE_TO_NUMERALS["en"])
    }
    INDIC_TO_STANDARD_NUMERALS_GLOBAL_MAP.update(map_dict)
INDIC_TO_STANDARD_NUMERALS_TRANSLATOR = str.maketrans(INDIC_TO_STANDARD_NUMERALS_GLOBAL_MAP)

NATIVE_TO_LATIN_NUMERALS_TRANSLATORS = {
    lang_code: str.maketrans({
        lang_numeral: en_numeral
        for lang_numeral, en_numeral in zip(lang_numerals, LANG_CODE_TO_NUMERALS["en"])
    })
    for lang_code, lang_numerals in LANG_CODE_TO_NUMERALS.items()
    if lang_code != "en"
}
LATIN_TO_NATIVE_NUMERALS_TRANSLATORS = {
    lang_code: str.maketrans({
        en_numeral: lang_numeral
        for en_numeral, lang_numeral in zip(LANG_CODE_TO_NUMERALS["en"], lang_numerals)
    })
    for lang_code, lang_numerals in LANG_CODE_TO_NUMERALS.items()
    if lang_code != "en"
}

WORDFINAL_INDIC_VIRAMA_REGEX = re_compile(
    "(\u09cd|\u094d|\u0acd|\u0a4d|\u0b4d|\u0ccd|\u0d4d|\u0dca|\u0bcd|\u0c4d|\uaaf6)$"
)


def hardfix_wordfinal_virama(word):
    # Append ZWNJ after a word-final virama so the explicit halant is preserved in rendering.
    return WORDFINAL_INDIC_VIRAMA_REGEX.sub("\\1\u200c", word)


ODIA_CONFUSING_YUKTAKSHARA_REGEX = re_compile("(\u0b4d)(ବ|ଵ|ୱ|ଯ|ୟ)")


def fix_odia_confusing_ambiguous_yuktakshara(word):
    # Insert ZWNJ between a virama and a following ବ/ଵ/ୱ/ଯ/ୟ to avoid ambiguous Odia conjuncts.
    return ODIA_CONFUSING_YUKTAKSHARA_REGEX.sub("\\1\u200c\\2", word)


LATIN_WORDFINAL_CONSONANTS_CHECKER_REGEX = re_compile(".*([bcdfghjklmnpqrstvwxyz])$")
DEVANAGARI_WORDFINAL_CONSONANTS_REGEX = re_compile("([\u0915-\u0939\u0958-\u095f\u0979-\u097c\u097e-\u097f])$")


def explicit_devanagari_wordfinal_schwa_delete(roman_word, indic_word):
    # If the Roman source ends in a consonant, add a virama to the Devanagari word-final
    # consonant to suppress the inherent schwa.
    if LATIN_WORDFINAL_CONSONANTS_CHECKER_REGEX.match(roman_word):
        indic_word = DEVANAGARI_WORDFINAL_CONSONANTS_REGEX.sub("\\1\u094d", indic_word)
    return indic_word


def rreplace(text, find_pattern, replace_pattern, match_count=1):
    # Replace the last `match_count` occurrences of `find_pattern` in `text`.
    splits = text.rsplit(find_pattern, match_count)
    return replace_pattern.join(splits)


LANG_WORD_REGEXES = {
    lang_name: re_compile(f"[{SCRIPT_CODE_TO_UNICODE_CHARS_RANGE_STR[script_name]}]+")
    for lang_name, script_name in LANG_CODE_TO_SCRIPT_CODE.items()
}

normalizer_factory = IndicNormalizerFactory()

Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")


def make_batches(lines, cfg, task, max_positions, encode_fn):
    def encode_fn_target(x):
        return encode_fn(x)

    if cfg.generation.constraints:
        # Strip tab-separated constraints from the input lines and encode them.
        batch_constraints = [list() for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")
        for i, constraint_list in enumerate(batch_constraints):
            batch_constraints[i] = [
                task.target_dictionary.encode_line(
                    encode_fn_target(constraint),
                    append_eos=False,
                    add_if_not_exist=False,
                )
                for constraint in constraint_list
            ]

    if cfg.generation.constraints:
        constraints_tensor = pack_constraints(batch_constraints)
    else:
        constraints_tensor = None

    tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)

    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor
        ),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        ids = batch["id"]
        src_tokens = batch["net_input"]["src_tokens"]
        src_lengths = batch["net_input"]["src_lengths"]
        constraints = batch.get("constraints", None)
        yield Batch(
            ids=ids,
            src_tokens=src_tokens,
            src_lengths=src_lengths,
            constraints=constraints,
        )


class Transliterator:

    def __init__(
        self,
        data_bin_dir,
        model_checkpoint_path,
        lang_pairs_csv,
        lang_list_file,
        beam,
        device,
        batch_size=32,
    ):
        self.parser = options.get_interactive_generation_parser()
        self.parser.set_defaults(
            path=model_checkpoint_path,
            num_workers=-1,
            batch_size=batch_size,
            buffer_size=batch_size + 1,
            task="translation_multi_simple_epoch",
            beam=beam,
        )
        self.args = options.parse_args_and_arch(self.parser, input_args=[data_bin_dir])
        self.args.skip_invalid_size_inputs_valid_test = False
        self.args.lang_pairs = lang_pairs_csv
        self.args.lang_dict = lang_list_file

        self.cfg = convert_namespace_to_omegaconf(self.args)
        if isinstance(self.cfg, Namespace):
            self.cfg = convert_namespace_to_omegaconf(self.cfg)

        self.total_translate_time = 0

        utils.import_user_module(self.cfg.common)

        if self.cfg.interactive.buffer_size < 1:
            self.cfg.interactive.buffer_size = 1
        if self.cfg.dataset.max_tokens is None and self.cfg.dataset.batch_size is None:
            self.cfg.dataset.batch_size = 1

        assert (
            not self.cfg.generation.sampling
            or self.cfg.generation.nbest == self.cfg.generation.beam
        ), "--sampling requires --nbest to be equal to --beam"
        assert (
            not self.cfg.dataset.batch_size
            or self.cfg.dataset.batch_size <= self.cfg.interactive.buffer_size
        ), "--batch-size cannot be larger than --buffer-size"

        self.use_cuda = device.type == "cuda"

        self.task = tasks.setup_task(self.cfg.task)

        overrides = ast_literal_eval(self.cfg.common_eval.model_overrides)
        self.models, _model_args = checkpoint_utils.load_model_ensemble(
            utils.split_paths(self.cfg.common_eval.path),
            arg_overrides=overrides,
            task=self.task,
            suffix=self.cfg.checkpoint.checkpoint_suffix,
            strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
            num_shards=self.cfg.checkpoint.checkpoint_shard_count,
        )

        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        for i in range(len(self.models)):
            if self.models[i] is None:
                continue
            if self.cfg.common.fp16:
                self.models[i].half()
            if self.use_cuda and not self.cfg.distributed_training.pipeline_model_parallel:
                self.models[i].cuda()
            self.models[i].prepare_for_inference_(self.cfg)

        self.generator = self.task.build_generator(self.models, self.cfg.generation)
        self.tokenizer = self.task.build_tokenizer(self.cfg.tokenizer)
        self.bpe = self.task.build_bpe(self.cfg.bpe)

        self.align_dict = utils.load_align_dict(self.cfg.generation.replace_unk)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in self.models]
        )

    def encode_fn(self, x):
        if self.tokenizer is not None:
            x = self.tokenizer.encode(x)
        if self.bpe is not None:
            x = self.bpe.encode(x)
        return x

    def decode_fn(self, x):
        if self.bpe is not None:
            x = self.bpe.decode(x)
        if self.tokenizer is not None:
            x = self.tokenizer.decode(x)
        return x

    def translate(self, inputs, nbest=1):
        start_id = 0
        results = []
        for batch in make_batches(inputs, self.cfg, self.task, self.max_positions, self.encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }

            translate_start_time = get_time()
            translations = self.task.inference_step(
                self.generator, self.models, sample, constraints=constraints
            )
            translate_time = get_time() - translate_start_time
            self.total_translate_time += translate_time

            list_constraints = [[] for _ in range(bsz)]
            if self.cfg.generation.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]

            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                constraints = list_constraints[i]
                results.append(
                    (
                        start_id + id,
                        src_tokens_i,
                        hypos,
                        {
                            "constraints": constraints,
                            "time": translate_time / len(translations),
                        },
                    )
                )
        result_str = ""
        for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
            src_str = ""
            if self.src_dict is not None:
                src_str = self.src_dict.string(src_tokens, self.cfg.common_eval.post_process)
            result_str += "S-{}\t{}".format(id_, src_str) + '\n'
            result_str += "W-{}\t{:.3f}\tseconds".format(id_, info["time"]) + '\n'
            for constraint in info["constraints"]:
                result_str += "C-{}\t{}".format(
                    id_,
                    self.tgt_dict.string(constraint, self.cfg.common_eval.post_process),
                ) + '\n'
            for hypo in hypos[: min(len(hypos), nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=self.align_dict,
                    tgt_dict=self.tgt_dict,
                    remove_bpe=self.cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(self.generator),
                )
                detok_hypo_str = self.decode_fn(hypo_str)
                score = hypo["score"] / math_log(2)
                result_str += "H-{}\t{}\t{}".format(id_, score, hypo_str) + '\n'
                result_str += "D-{}\t{}\t{}".format(id_, score, detok_hypo_str) + '\n'
                result_str += "P-{}\t{}".format(
                    id_,
                    " ".join(
                        map(
                            lambda x: "{:.4f}".format(x),
                            hypo["positional_scores"].div_(math_log(2)).tolist(),
                        )
                    ),
                ) + '\n'
                if self.cfg.generation.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment]
                    )
                    result_str += "A-{}\t{}".format(id_, alignment_str) + '\n'
        return result_str


class BaseEngineTransformer:

    def __init__(self, word_prob_dicts_files, corpus_bin_dir, lang_list, model_file,
                 tgt_langs, beam_width, rescore, device):
        self.all_supported_langs = {
            'as', 'bn', 'brx', 'gom', 'gu', 'hi', 'kn', 'ks', 'mai', 'ml', 'mni',
            'mr', 'ne', 'or', 'pa', 'sa', 'sd', 'si', 'ta', 'te', 'ur',
        }
        print("Initializing Multilingual model for transliteration")
        if 'en' in tgt_langs:
            lang_pairs_csv = ','.join([lang + "-en" for lang in self.all_supported_langs])
        else:
            lang_pairs_csv = ','.join(["en-" + lang for lang in self.all_supported_langs])
        self.transliterator = Transliterator(
            corpus_bin_dir,
            model_file,
            lang_pairs_csv=lang_pairs_csv,
            lang_list_file=lang_list,
            device=device,
            beam=beam_width,
            batch_size=32,
        )
        self.beam_width = beam_width
        self._rescore = rescore
        if self._rescore:
            self.word_prob_dicts = {
                lang: json_load(open(word_prob_dicts_files[lang]))
                for lang in tgt_langs
            }

    def indic_normalize(self, words, lang_code):
        if lang_code not in ['gom', 'ks', 'ur', 'mai', 'brx', 'mni']:
            normalizer = normalizer_factory.get_normalizer(lang_code)
            words = [normalizer.normalize(word) for word in words]
        if lang_code in ['mai', 'brx']:
            normalizer = normalizer_factory.get_normalizer('hi')
            words = [normalizer.normalize(word) for word in words]
        if lang_code in ['ur']:
            words = [shahmukhi_normalize(word) for word in words]
        if lang_code == 'gom':
            normalizer = normalizer_factory.get_normalizer('kK')
            words = [normalizer.normalize(word) for word in words]
        return words

    def pre_process(self, words, src_lang, tgt_lang):
        if src_lang != 'en':
            words = self.indic_normalize(words, src_lang)
        # Character-level tokenization plus a target-language tag, as the model expects.
        words = [' '.join(list(word.lower())) for word in words]
        lang_code = tgt_lang if src_lang == 'en' else src_lang
        words = ['__' + lang_code + '__ ' + word for word in words]
        return words

    def rescore(self, res_dict, result_dict, tgt_lang, alpha):
        # Interpolate the model score with a corpus word-frequency prior, weighted by alpha.
        word_prob_dict = self.word_prob_dicts[tgt_lang]
        candidate_word_prob_norm_dict = {}
        candidate_word_result_norm_dict = {}
        input_data = {}
        for i in res_dict.keys():
            input_data[res_dict[i]['S']] = []
            for j in range(len(res_dict[i]['H'])):
                input_data[res_dict[i]['S']].append(res_dict[i]['H'][j][0])
        for src_word in input_data.keys():
            candidates = input_data[src_word]
            candidates = [' '.join(word.split(' ')) for word in candidates]
            total_score = 0
            if src_word.lower() in result_dict.keys():
                for candidate_word in candidates:
                    if candidate_word in result_dict[src_word.lower()].keys():
                        total_score += result_dict[src_word.lower()][candidate_word]
                candidate_word_result_norm_dict[src_word.lower()] = {}
                for candidate_word in candidates:
                    candidate_word_result_norm_dict[src_word.lower()][candidate_word] = (
                        result_dict[src_word.lower()][candidate_word] / total_score
                    )
            candidates = [''.join(word.split(' ')) for word in candidates]
            total_prob = 0
            for candidate_word in candidates:
                if candidate_word in word_prob_dict.keys():
                    total_prob += word_prob_dict[candidate_word]
            candidate_word_prob_norm_dict[src_word.lower()] = {}
            for candidate_word in candidates:
                if candidate_word in word_prob_dict.keys():
                    candidate_word_prob_norm_dict[src_word.lower()][candidate_word] = (
                        word_prob_dict[candidate_word] / total_prob
                    )

        output_data = {}
        for src_word in input_data.keys():
            temp_candidates_tuple_list = []
            candidates = input_data[src_word]
            candidates = [''.join(word.split(' ')) for word in candidates]
            for candidate_word in candidates:
                if candidate_word in word_prob_dict.keys():
                    temp_candidates_tuple_list.append((
                        candidate_word,
                        alpha * candidate_word_result_norm_dict[src_word.lower()][' '.join(list(candidate_word))]
                        + (1 - alpha) * candidate_word_prob_norm_dict[src_word.lower()][candidate_word],
                    ))
                else:
                    temp_candidates_tuple_list.append((candidate_word, 0))
            temp_candidates_tuple_list.sort(key=lambda x: x[1], reverse=True)
            temp_candidates_list = []
            for candidate_tuple in temp_candidates_tuple_list:
                temp_candidates_list.append(' '.join(list(candidate_tuple[0])))
            output_data[src_word] = temp_candidates_list
        return output_data

    def post_process(self, translation_str, tgt_lang):
        lines = translation_str.split('\n')
        list_s = [line for line in lines if 'S-' in line]
        list_h = [line for line in lines if 'H-' in line]
        list_s.sort(key=lambda x: int(x.split('\t')[0].split('-')[1]))
        list_h.sort(key=lambda x: int(x.split('\t')[0].split('-')[1]))

        # Group hypotheses (with scores mapped back to probability space) under their source line.
        res_dict = {}
        for s in list_s:
            s_id = int(s.split('\t')[0].split('-')[1])
            res_dict[s_id] = {'S': s.split('\t')[1]}
            res_dict[s_id]['H'] = []
            for h in list_h:
                h_id = int(h.split('\t')[0].split('-')[1])
                if s_id == h_id:
                    res_dict[s_id]['H'].append(
                        (h.split('\t')[2], pow(2, float(h.split('\t')[1])))
                    )
        for r in res_dict.keys():
            res_dict[r]['H'].sort(key=lambda x: float(x[1]), reverse=True)

        result_dict = {}
        for i in res_dict.keys():
            result_dict[res_dict[i]['S']] = {}
            for j in range(len(res_dict[i]['H'])):
                result_dict[res_dict[i]['S']][res_dict[i]['H'][j][0]] = res_dict[i]['H'][j][1]

        transliterated_word_list = []
        if self._rescore:
            output_data = self.rescore(res_dict, result_dict, tgt_lang, alpha=0.9)
            for src_word in output_data.keys():
                for j in range(len(output_data[src_word])):
                    transliterated_word_list.append(output_data[src_word][j])
        else:
            for i in res_dict.keys():
                for j in range(len(res_dict[i]['H'])):
                    transliterated_word_list.append(res_dict[i]['H'][j][0])
        transliterated_word_list = [''.join(word.split(' ')) for word in transliterated_word_list]
        return transliterated_word_list

    def _transliterate_word(self, text, src_lang, tgt_lang, topk=4,
                            nativize_punctuations=True, nativize_numerals=False):
        if not text:
            return text
        text = text.lower().strip()
        if src_lang != 'en':
            text = text.translate(INDIC_TO_LATIN_PUNCT_TRANSLATOR)
            text = text.translate(INDIC_TO_STANDARD_NUMERALS_TRANSLATOR)
        else:
            if nativize_punctuations:
                if tgt_lang in RTL_LANG_CODES:
                    text = text.translate(LATIN_TO_PERSOARABIC_PUNC_TRANSLATOR)
                text = nativize_latin_fullstop(text, tgt_lang)
            if nativize_numerals:
                text = text.translate(LATIN_TO_NATIVE_NUMERALS_TRANSLATORS[tgt_lang])

        matches = LANG_WORD_REGEXES[src_lang].findall(text)
        if not matches:
            return [text]
        src_word = matches[-1]
        transliteration_list = self.batch_transliterate_words([src_word], src_lang, tgt_lang, topk=topk)[0]
        if tgt_lang not in ('en', 'sa'):
            for i in range(len(transliteration_list)):
                transliteration_list[i] = hardfix_wordfinal_virama(transliteration_list[i])
        if src_word == text:
            return transliteration_list
        return [rreplace(text, src_word, tgt_word) for tgt_word in transliteration_list]

    def batch_transliterate_words(self, words, src_lang, tgt_lang, topk=4):
        preprocessed_words = self.pre_process(words, src_lang, tgt_lang)
        translation_str = self.transliterator.translate(preprocessed_words, nbest=topk)
        transliteration_list = self.post_process(translation_str, tgt_lang)
        if tgt_lang == 'mr':
            for i in range(len(transliteration_list)):
                transliteration_list[i] = transliteration_list[i].replace("अॅ", 'ॲ')
        if tgt_lang == 'or':
            for i in range(len(transliteration_list)):
                transliteration_list[i] = fix_odia_confusing_ambiguous_yuktakshara(transliteration_list[i])
        if tgt_lang == 'sa':
            for i in range(len(transliteration_list)):
                # Callers pass a single word, so words[0] is the Roman source for every candidate.
                transliteration_list[i] = explicit_devanagari_wordfinal_schwa_delete(words[0], transliteration_list[i])
        # Drop duplicate candidates while preserving order.
        transliteration_list = list(dict.fromkeys(transliteration_list))
        return [transliteration_list]

    def _transliterate_sentence(self, text, src_lang, tgt_lang,
                                nativize_punctuations=True, nativize_numerals=False):
        if not text:
            return text
        text = text.lower().strip()
        if src_lang != 'en':
            text = text.translate(INDIC_TO_LATIN_PUNCT_TRANSLATOR)
            text = text.translate(INDIC_TO_STANDARD_NUMERALS_TRANSLATOR)
        else:
            if nativize_punctuations:
                if tgt_lang in RTL_LANG_CODES:
                    text = text.translate(LATIN_TO_PERSOARABIC_PUNC_TRANSLATOR)
                text = nativize_latin_fullstop(text, tgt_lang)
            if nativize_numerals:
                text = text.translate(LATIN_TO_NATIVE_NUMERALS_TRANSLATORS[tgt_lang])
        matches = LANG_WORD_REGEXES[src_lang].findall(text)
        if not matches:
            return text
        out_str = text
        for match in matches:
            result = self.batch_transliterate_words([match], src_lang, tgt_lang)[0][0]
            out_str = re_sub(match, result, out_str, 1)
        return out_str
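

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how the classes above are expected to be
# wired together. The paths, language list file, checkpoint name, and word
# probability dictionary below are hypothetical placeholders, not files
# shipped with this module; the real locations depend on how the fairseq
# data-bin and model checkpoint were prepared.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    # Hypothetical resource locations; replace with real ones.
    corpus_bin_dir = "corpus-bin"
    model_file = "transformer/model.pt"
    lang_list = "lang_list.txt"
    word_prob_dicts_files = {"hi": "word_prob_dicts/hi.json"}

    engine = BaseEngineTransformer(
        word_prob_dicts_files=word_prob_dicts_files,
        corpus_bin_dir=corpus_bin_dir,
        lang_list=lang_list,
        model_file=model_file,
        tgt_langs=["hi"],  # en -> hi direction
        beam_width=4,
        rescore=True,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )

    # Top-k candidates for a single word, then a full-sentence transliteration.
    print(engine._transliterate_word("namaste", src_lang="en", tgt_lang="hi", topk=4))
    print(engine._transliterate_sentence("namaste duniya.", src_lang="en", tgt_lang="hi"))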