Krishna086 committed on
Commit 64afd54 · verified · 1 Parent(s): 3aee48b

Update translation.py

Files changed (1)
  1. translation.py +5 -2
translation.py CHANGED
@@ -31,7 +31,7 @@ def _load_all_models():
 
 all_models = _load_all_models()
 
-# Define combined_translate outside load_model
+# Define combined_translate outside load_model with explicit parameters
 def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
     with torch.no_grad():
         if source_lang != "en":
@@ -53,7 +53,10 @@ class CombinedModel:
         self.default_model = default_model
 
     def generate(self, **kwargs):
-        return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True), self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for x in kwargs['input_ids']])
+        input_ids = kwargs.get('input_ids')
+        if not input_ids:
+            return torch.tensor([])
+        return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True), self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for x in input_ids])
 
 # Function to load appropriate translation model with optimized caching
 @st.cache_resource
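
The new generate() no longer indexes kwargs['input_ids'] directly, which raises a KeyError when the argument is missing; it fetches the value with kwargs.get() and returns an empty tensor early when there is nothing to translate. Below is a minimal, self-contained sketch of that guard pattern; GuardedGenerate and echo_translate are hypothetical stand-ins for CombinedModel and combined_translate, and the sketch checks None/length explicitly rather than relying on tensor truthiness.

import torch

def echo_translate(token_ids):
    # Hypothetical stand-in for combined_translate(); it just echoes the ids
    # so the sketch stays self-contained and runnable.
    return token_ids

class GuardedGenerate:
    # Hypothetical stand-in for CombinedModel, reduced to the guarded generate().
    def generate(self, **kwargs):
        input_ids = kwargs.get('input_ids')           # None when the key is absent
        if input_ids is None or len(input_ids) == 0:  # nothing to translate
            return torch.tensor([])
        return torch.tensor([echo_translate(x.tolist()) for x in input_ids])

model = GuardedGenerate()
print(model.generate())                                  # tensor([])
print(model.generate(input_ids=torch.tensor([[1, 2]])))  # tensor([[1, 2]])

The explicit None/length check is used in the sketch because PyTorch only defines truthiness for single-element tensors, so it sidesteps any ambiguity when input_ids arrives as a batched tensor.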