Update translation.py
Browse files- translation.py +5 -2
translation.py
CHANGED
@@ -31,7 +31,7 @@ def _load_all_models():
|
|
31 |
|
32 |
all_models = _load_all_models()
|
33 |
|
34 |
-
# Define combined_translate outside load_model
|
35 |
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
|
36 |
with torch.no_grad():
|
37 |
if source_lang != "en":
|
@@ -53,7 +53,10 @@ class CombinedModel:
|
|
53 |
self.default_model = default_model
|
54 |
|
55 |
def generate(self, **kwargs):
|
56 |
-
|
|
|
|
|
|
|
57 |
|
58 |
# Function to load appropriate translation model with optimized caching
|
59 |
@st.cache_resource
|
|
|
31 |
|
32 |
all_models = _load_all_models()
|
33 |
|
34 |
+
# Define combined_translate outside load_model with explicit parameters
|
35 |
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
|
36 |
with torch.no_grad():
|
37 |
if source_lang != "en":
|
|
|
53 |
self.default_model = default_model
|
54 |
|
55 |
def generate(self, **kwargs):
|
56 |
+
input_ids = kwargs.get('input_ids')
|
57 |
+
if not input_ids:
|
58 |
+
return torch.tensor([])
|
59 |
+
return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True), self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for x in input_ids])
|
60 |
|
61 |
# Function to load appropriate translation model with optimized caching
|
62 |
@st.cache_resource
|