Krishna086 commited on
Commit
410ea95
·
verified ·
1 Parent(s): cff3443

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +7 -6
translation.py CHANGED
@@ -36,14 +36,15 @@ class CombinedModel:
36
  def generate(self, **kwargs):
37
  return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
38
 
39
- # Function to load appropriate translation model
40
  def load_model(source_lang, target_lang):
41
  if source_lang == target_lang:
42
  return _load_default_model()
43
  model_key = (source_lang, target_lang)
44
- if all_models.get(model_key) and all_models[model_key][0] and all_models[model_key][1]:
45
- return all_models[model_key]
46
- # Pivot through English
 
47
  def combined_translate(text):
48
  en_tokenizer, en_model = all_models.get(("en", "en"), _load_default_model())
49
  if source_lang != "en":
@@ -55,8 +56,8 @@ def load_model(source_lang, target_lang):
55
  en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_model_pair("en", target_lang)) or _load_default_model()
56
  return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
57
  return en_text
58
- tokenizer, _ = _load_default_model()
59
- return tokenizer, CombinedModel()
60
 
61
  # Cache resource to load default translation model
62
  @st.cache_resource
 
36
  def generate(self, **kwargs):
37
  return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
38
 
39
+ # Function to load appropriate translation model with optimized pivot
40
  def load_model(source_lang, target_lang):
41
  if source_lang == target_lang:
42
  return _load_default_model()
43
  model_key = (source_lang, target_lang)
44
+ tokenizer_model_pair = all_models.get(model_key)
45
+ if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
46
+ return tokenizer_model_pair
47
+ # Optimize pivot through English using preloaded models
48
  def combined_translate(text):
49
  en_tokenizer, en_model = all_models.get(("en", "en"), _load_default_model())
50
  if source_lang != "en":
 
56
  en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_model_pair("en", target_lang)) or _load_default_model()
57
  return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
58
  return en_text
59
+ default_tokenizer, _ = _load_default_model()
60
+ return default_tokenizer, CombinedModel()
61
 
62
  # Cache resource to load default translation model
63
  @st.cache_resource