Update translation.py
translation.py  CHANGED  (+5 -5)
@@ -27,13 +27,13 @@ def load_model(source_lang, target_lang):
         model = MarianMTModel.from_pretrained(model_name)
         return tokenizer, model
     except Exception:
-        # Pivot through English
+        # Pivot through English for non-English pairs
         if source_lang != "en" and target_lang != "en":
-            en_to_target_tokenizer, en_to_target_model = load_model("en", target_lang)
-            source_to_en_tokenizer, source_to_en_model = load_model(source_lang, "en")
             def combined_translate(text):
-
-
+                en_tokenizer, en_model = load_model(source_lang, "en")
+                en_text = en_tokenizer.decode(en_model.generate(**en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
+                target_tokenizer, target_model = load_model("en", target_lang)
+                return target_tokenizer.decode(target_model.generate(**target_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
             class CombinedModel:
                 def generate(self, **kwargs):
                     return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
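For context, load_model() falls back to pivoting through English whenever loading a direct model for the requested language pair fails. The sketch below restates that pivot idea as a small standalone script; it assumes the usual Helsinki-NLP opus-mt-{src}-{tgt} naming convention, and the helper names (load_direct, translate_once, pivot_translate) are illustrative rather than part of translation.py.

# Minimal, standalone sketch of the English-pivot fallback (illustrative names,
# not the actual translation.py code; assumes the Helsinki-NLP opus-mt naming).
from transformers import MarianMTModel, MarianTokenizer

def load_direct(source_lang, target_lang):
    # Load a direct model such as Helsinki-NLP/opus-mt-fr-es;
    # from_pretrained raises if no model exists for the pair.
    name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
    return MarianTokenizer.from_pretrained(name), MarianMTModel.from_pretrained(name)

def translate_once(tokenizer, model, text):
    # Tokenize, generate, and decode a single string.
    batch = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
    output_ids = model.generate(**batch)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def pivot_translate(text, source_lang, target_lang):
    # No direct pair available: go source -> en, then en -> target.
    src_tokenizer, src_model = load_direct(source_lang, "en")
    en_text = translate_once(src_tokenizer, src_model, text)
    tgt_tokenizer, tgt_model = load_direct("en", target_lang)
    return translate_once(tgt_tokenizer, tgt_model, en_text)

# Example: French -> Spanish via English.
# print(pivot_translate("Bonjour tout le monde", "fr", "es"))

Moving the load_model() calls inside combined_translate, as this commit does, defers loading the two pivot models until a pivot translation is actually requested. Note that CombinedModel.generate would still need to re-encode the translated strings into token IDs before wrapping them, since torch.tensor cannot be built from a list of Python strings.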