Update translation.py
Browse files- translation.py +22 -14
translation.py
CHANGED
@@ -31,10 +31,29 @@ def _load_all_models():
|
|
31 |
|
32 |
all_models = _load_all_models()
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
# Class to handle combined translation through English pivot
|
35 |
class CombinedModel:
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
def generate(self, **kwargs):
|
37 |
-
return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
|
38 |
|
39 |
# Function to load appropriate translation model with optimized caching
|
40 |
@st.cache_resource
|
@@ -45,20 +64,9 @@ def load_model(source_lang, target_lang):
|
|
45 |
tokenizer_model_pair = all_models.get(model_key)
|
46 |
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
47 |
return tokenizer_model_pair
|
48 |
-
#
|
49 |
default_tokenizer, default_model = _load_default_model()
|
50 |
-
|
51 |
-
with torch.no_grad():
|
52 |
-
if source_lang != "en":
|
53 |
-
src_to_en_tokenizer, src_to_en_model = all_models.get((source_lang, "en"), (default_tokenizer, default_model))
|
54 |
-
en_text = src_to_en_tokenizer.decode(src_to_en_model.generate(**src_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
|
55 |
-
else:
|
56 |
-
en_text = text
|
57 |
-
if target_lang != "en":
|
58 |
-
en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), (default_tokenizer, default_model))
|
59 |
-
return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=1000))[0], skip_special_tokens=True)
|
60 |
-
return en_text
|
61 |
-
return default_tokenizer, CombinedModel()
|
62 |
|
63 |
# Cache resource to load default translation model
|
64 |
@st.cache_resource
|
|
|
31 |
|
32 |
all_models = _load_all_models()
|
33 |
|
34 |
+
# Pivot translation helper, defined at module level so CombinedModel can reuse it.
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    """Translate `text` from source_lang to target_lang via an English pivot.

    Looks up direct (src, "en") and ("en", tgt) pairs in the module-level
    `all_models` cache, falling back to the supplied default tokenizer/model
    pair when a direction is missing. Returns the translated string.
    """
    fallback_pair = (default_tokenizer, default_model)
    with torch.no_grad():  # inference only — no autograd bookkeeping
        # Step 1: bring the text into English (no-op when already English).
        en_text = text
        if source_lang != "en":
            src_tok, src_mdl = all_models.get((source_lang, "en"), fallback_pair)
            encoded_src = src_tok(
                text, return_tensors="pt", padding=True, truncation=True, max_length=500
            )
            en_text = src_tok.decode(
                src_mdl.generate(**encoded_src)[0], skip_special_tokens=True
            )
        # Step 2: English is already the target — nothing more to do.
        if target_lang == "en":
            return en_text
        # Step 3: English pivot -> target language.
        tgt_tok, tgt_mdl = all_models.get(("en", target_lang), fallback_pair)
        encoded_en = tgt_tok(
            en_text, return_tensors="pt", padding=True, truncation=True, max_length=1000
        )
        return tgt_tok.decode(
            tgt_mdl.generate(**encoded_en)[0], skip_special_tokens=True
        )
46 |
+
|
47 |
# Class to handle combined translation through English pivot
class CombinedModel:
    """Drop-in stand-in for a HF seq2seq model that pivots through English.

    `load_model` returns an instance of this class (paired with the default
    tokenizer) when no direct (source, target) model exists, so `generate`
    must accept tokenized input and return a tensor of token ids that the
    caller can pass to `tokenizer.decode`.
    """

    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        # Language pair this pivot translator serves.
        self.source_lang = source_lang
        self.target_lang = target_lang
        # Fallback tokenizer/model used when a pivot direction is missing
        # from `all_models`, and used here to decode/re-encode batches.
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def generate(self, **kwargs):
        """Translate each sequence in kwargs['input_ids'] via the English pivot.

        Returns a padded LongTensor of token ids encoding the translations.

        Fixes two defects in the original:
        - it referenced an undefined free variable `tokenizer` (NameError);
          the instance's `default_tokenizer` is used instead;
        - it wrapped Python strings directly in `torch.tensor(...)`, which
          raises TypeError; translations are re-encoded to id tensors so
          callers can still run `tokenizer.decode(output[i], ...)`.
        NOTE(review): assumes callers decode the returned ids with the same
        default tokenizer returned by `load_model` — confirm against caller.
        """
        translations = [
            combined_translate(
                self.default_tokenizer.decode(ids, skip_special_tokens=True),
                self.source_lang,
                self.target_lang,
                self.default_tokenizer,
                self.default_model,
            )
            for ids in kwargs["input_ids"]
        ]
        encoded = self.default_tokenizer(
            translations, return_tensors="pt", padding=True, truncation=True
        )
        return encoded["input_ids"]
|
57 |
|
58 |
# Function to load appropriate translation model with optimized caching
|
59 |
@st.cache_resource
|
|
|
64 |
tokenizer_model_pair = all_models.get(model_key)
|
65 |
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
66 |
return tokenizer_model_pair
|
67 |
+
# Use simplified pivot through English with CombinedModel
|
68 |
default_tokenizer, default_model = _load_default_model()
|
69 |
+
return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
# Cache resource to load default translation model
|
72 |
@st.cache_resource
|