Krishna086 committed on
Commit
23bc295
·
verified ·
1 Parent(s): 63de0ad

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +16 -12
translation.py CHANGED
@@ -45,21 +45,25 @@ def load_model(source_lang, target_lang):
45
  tokenizer_model_pair = all_models.get(model_key)
46
  if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
47
  return tokenizer_model_pair
48
- # Optimized pivot through English using preloaded models
 
 
 
 
 
 
 
49
  def combined_translate(text):
50
- en_tokenizer, en_model = all_models.get(("en", "en"), _load_default_model())
51
- if source_lang != "en":
52
- src_to_en_tokenizer, src_to_en_model = all_models.get((source_lang, "en"), _load_model_pair(source_lang, "en")) or _load_default_model()
53
- with torch.no_grad():
54
  en_text = src_to_en_tokenizer.decode(src_to_en_model.generate(**src_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
55
- else:
56
- en_text = text
57
- if target_lang != "en":
58
- en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_model_pair("en", target_lang)) or _load_default_model()
59
- with torch.no_grad():
60
  return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
61
- return en_text
62
- default_tokenizer, _ = _load_default_model()
63
  return default_tokenizer, CombinedModel()
64
 
65
  # Cache resource to load default translation model
 
45
  tokenizer_model_pair = all_models.get(model_key)
46
  if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
47
  return tokenizer_model_pair
48
+ # Use direct English pivot only if direct model unavailable
49
+ if source_lang != "en" and target_lang != "en":
50
+ en_pivot_pair = all_models.get(("en", target_lang)) or _load_model_pair("en", target_lang)
51
+ if en_pivot_pair[0] and en_pivot_pair[1]:
52
+ src_to_en_pair = all_models.get((source_lang, "en")) or _load_model_pair(source_lang, "en")
53
+ if src_to_en_pair[0] and src_to_en_pair[1]:
54
+ return en_pivot_pair # Prefer direct pivot chain
55
+ default_tokenizer, _ = _load_default_model()
56
  def combined_translate(text):
57
+ with torch.no_grad():
58
+ if source_lang != "en":
59
+ src_to_en_tokenizer, src_to_en_model = all_models.get((source_lang, "en"), _load_default_model())
 
60
  en_text = src_to_en_tokenizer.decode(src_to_en_model.generate(**src_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
61
+ else:
62
+ en_text = text
63
+ if target_lang != "en":
64
+ en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_default_model())
 
65
  return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
66
+ return en_text
 
67
  return default_tokenizer, CombinedModel()
68
 
69
  # Cache resource to load default translation model