Update translation.py
Browse files- translation.py +7 -6
translation.py
CHANGED
@@ -36,14 +36,15 @@ class CombinedModel:
|
|
36 |
# Mimics a model's generate() call: each entry of kwargs['input_ids'] is decoded
# with `tokenizer` (special tokens skipped), the decoded text is passed to
# combined_translate, and the list of results is wrapped in torch.tensor.
# `self` is not used. `tokenizer` is not defined in the lines shown here —
# verify it is in scope where this method runs.
# NOTE(review): combined_translate appears to return decoded text (str);
# torch.tensor likely cannot be built from Python strings, so this call
# probably raises for non-empty input — confirm, and re-encode the
# translations to token ids if callers expect ids.
def generate(self, **kwargs):
|
37 |
return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
|
38 |
|
39 |
-
# Function to load appropriate translation model
|
40 |
def load_model(source_lang, target_lang):
|
41 |
if source_lang == target_lang:
|
42 |
return _load_default_model()
|
43 |
model_key = (source_lang, target_lang)
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
47 |
def combined_translate(text):
|
48 |
en_tokenizer, en_model = all_models.get(("en", "en"), _load_default_model())
|
49 |
if source_lang != "en":
|
@@ -55,8 +56,8 @@ def load_model(source_lang, target_lang):
|
|
55 |
en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_model_pair("en", target_lang)) or _load_default_model()
|
56 |
return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
|
57 |
return en_text
|
58 |
-
|
59 |
-
return
|
60 |
|
61 |
# Load the default translation model, cached via st.cache_resource
|
62 |
@st.cache_resource
|
|
|
36 |
# Mimics a model's generate() call: each entry of kwargs['input_ids'] is decoded
# with `tokenizer` (special tokens skipped), the decoded text is passed to
# combined_translate, and the list of results is wrapped in torch.tensor.
# `self` is not used. `tokenizer` is not defined in the lines shown here —
# verify it is in scope where this method runs.
# NOTE(review): combined_translate appears to return decoded text (str);
# torch.tensor likely cannot be built from Python strings, so this call
# probably raises for non-empty input — confirm, and re-encode the
# translations to token ids if callers expect ids.
def generate(self, **kwargs):
|
37 |
return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
|
38 |
|
39 |
+
# Function to load appropriate translation model with optimized pivot
|
40 |
def load_model(source_lang, target_lang):
|
41 |
if source_lang == target_lang:
|
42 |
return _load_default_model()
|
43 |
model_key = (source_lang, target_lang)
|
44 |
+
tokenizer_model_pair = all_models.get(model_key)
|
45 |
+
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
46 |
+
return tokenizer_model_pair
|
47 |
+
# Optimize pivot through English using preloaded models
|
48 |
def combined_translate(text):
|
49 |
en_tokenizer, en_model = all_models.get(("en", "en"), _load_default_model())
|
50 |
if source_lang != "en":
|
|
|
56 |
en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_model_pair("en", target_lang)) or _load_default_model()
|
57 |
return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
|
58 |
return en_text
|
59 |
+
default_tokenizer, _ = _load_default_model()
|
60 |
+
return default_tokenizer, CombinedModel()
|
61 |
|
62 |
# Load the default translation model, cached via st.cache_resource
|
63 |
@st.cache_resource
|