Update translation.py
Browse files — translation.py (+16 −12)
translation.py
CHANGED
@@ -45,21 +45,25 @@ def load_model(source_lang, target_lang):
|
|
45 |
tokenizer_model_pair = all_models.get(model_key)
|
46 |
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
47 |
return tokenizer_model_pair
|
48 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def combined_translate(text):
    """Translate *text* by pivoting through English.

    First leg: source -> English with the src_to_en pair; second leg:
    English -> target with the en_to_tgt pair. Both legs run without
    gradient tracking since this is inference only.
    """
    # Shared tokenizer options for both legs of the pivot.
    def _encode(tokenizer, sentence):
        return tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=500)

    with torch.no_grad():
        src_batch = _encode(src_to_en_tokenizer, text)
        en_ids = src_to_en_model.generate(**src_batch)[0]
        en_text = src_to_en_tokenizer.decode(en_ids, skip_special_tokens=True)
    with torch.no_grad():
        tgt_batch = _encode(en_to_tgt_tokenizer, en_text)
        tgt_ids = en_to_tgt_model.generate(**tgt_batch)[0]
        return en_to_tgt_tokenizer.decode(tgt_ids, skip_special_tokens=True)
|
61 |
-
|
62 |
-
default_tokenizer, _ = _load_default_model()
|
63 |
return default_tokenizer, CombinedModel()
|
64 |
|
65 |
# Cache resource to load default translation model
|
|
|
45 |
tokenizer_model_pair = all_models.get(model_key)
|
46 |
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
47 |
return tokenizer_model_pair
|
48 |
+
# Use direct English pivot only if direct model unavailable
|
49 |
+
if source_lang != "en" and target_lang != "en":
|
50 |
+
en_pivot_pair = all_models.get(("en", target_lang)) or _load_model_pair("en", target_lang)
|
51 |
+
if en_pivot_pair[0] and en_pivot_pair[1]:
|
52 |
+
src_to_en_pair = all_models.get((source_lang, "en")) or _load_model_pair(source_lang, "en")
|
53 |
+
if src_to_en_pair[0] and src_to_en_pair[1]:
|
54 |
+
return en_pivot_pair # Prefer direct pivot chain
|
55 |
+
default_tokenizer, _ = _load_default_model()
|
56 |
def combined_translate(text):
    """Translate *text* from source_lang to target_lang, pivoting through English.

    When source_lang is not English, the (source_lang, "en") pair produces an
    English intermediate; when target_lang is not English, the ("en",
    target_lang) pair produces the final text. If either leg is missing from
    all_models, the default model pair is used for that leg.
    """
    with torch.no_grad():  # inference only — no gradient tracking needed
        if source_lang != "en":
            # Use `or` instead of dict.get(key, default): get() evaluates its
            # default eagerly, which would load the default model on EVERY
            # call even when the cached pair exists. This also matches the
            # lazy-fallback idiom used by the pivot-loading code above.
            src_to_en_tokenizer, src_to_en_model = (
                all_models.get((source_lang, "en")) or _load_default_model()
            )
            en_text = src_to_en_tokenizer.decode(
                src_to_en_model.generate(
                    **src_to_en_tokenizer(
                        text, return_tensors="pt", padding=True, truncation=True, max_length=500
                    )
                )[0],
                skip_special_tokens=True,
            )
        else:
            en_text = text
        if target_lang != "en":
            # Same lazy-fallback fix as the first leg.
            en_to_tgt_tokenizer, en_to_tgt_model = (
                all_models.get(("en", target_lang)) or _load_default_model()
            )
            return en_to_tgt_tokenizer.decode(
                en_to_tgt_model.generate(
                    **en_to_tgt_tokenizer(
                        en_text, return_tensors="pt", padding=True, truncation=True, max_length=500
                    )
                )[0],
                skip_special_tokens=True,
            )
        return en_text
|
|
|
67 |
return default_tokenizer, CombinedModel()
|
68 |
|
69 |
# Cache resource to load default translation model
|