Krishna086 committed on
Commit
f2db407
·
verified ·
1 Parent(s): f6b831a

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +5 -5
translation.py CHANGED
@@ -27,13 +27,13 @@ def load_model(source_lang, target_lang):
27
  model = MarianMTModel.from_pretrained(model_name)
28
  return tokenizer, model
29
  except Exception:
30
- # Pivot through English
31
  if source_lang != "en" and target_lang != "en":
32
- en_to_target_tokenizer, en_to_target_model = load_model("en", target_lang)
33
- source_to_en_tokenizer, source_to_en_model = load_model(source_lang, "en")
34
  # Pivot translation fallback (pre-change version): translate source -> English,
  # then English -> target, using the two model/tokenizer pairs loaded just above.
  def combined_translate(text):
35
- # Step 1: source -> English; inputs padded/truncated to max_length=500 tokens.
- en_text = source_to_en_tokenizer.decode(source_to_en_model.generate(**source_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
36
- # Step 2: English -> target; only the first generated sequence ([0]) is decoded.
- return en_to_target_tokenizer.decode(en_to_target_model.generate(**en_to_target_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
 
 
37
  # Adapter exposing a MarianMT-like .generate() interface on top of the pivot
  # pipeline, so callers that expect (tokenizer, model) keep working.
  class CombinedModel:
38
  def generate(self, **kwargs):
39
  # NOTE(review): combined_translate returns a str, and torch.tensor() over a
  # list of strings raises ValueError — this path almost certainly crashes at
  # runtime; the translated text should be re-tokenized to ids instead.
  # NOTE(review): `tokenizer` here is resolved from the enclosing scope (not
  # visible in this hunk) — confirm it is the intended source-language tokenizer.
  return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
 
27
  model = MarianMTModel.from_pretrained(model_name)
28
  return tokenizer, model
29
  except Exception:
30
+ # Pivot through English for non-English pairs
31
  if source_lang != "en" and target_lang != "en":
 
 
32
  # Pivot translation fallback (post-change version): translate source -> English,
  # then English -> target, loading each model pair lazily inside the call.
  def combined_translate(text):
33
+ # NOTE(review): both pivot model pairs are (re)loaded on EVERY call — confirm
+ # load_model caches internally, otherwise this is extremely slow per sentence.
+ en_tokenizer, en_model = load_model(source_lang, "en")
34
+ # Step 1: source -> English; inputs padded/truncated to max_length=500 tokens,
+ # only the first generated sequence ([0]) is decoded.
+ en_text = en_tokenizer.decode(en_model.generate(**en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
35
+ target_tokenizer, target_model = load_model("en", target_lang)
36
+ # Step 2: English -> target, same generate/decode pattern as step 1.
+ return target_tokenizer.decode(target_model.generate(**target_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
37
  # Adapter exposing a MarianMT-like .generate() interface on top of the pivot
  # pipeline; returned by the except-branch so callers keep the same API.
  class CombinedModel:
38
  def generate(self, **kwargs):
39
  # NOTE(review): unchanged by this commit and still defective —
  # combined_translate returns a str, and torch.tensor() over a list of strings
  # raises ValueError; the translated text needs re-tokenizing to ids instead.
  # NOTE(review): `tokenizer` comes from the enclosing scope (not in this hunk);
  # verify it is defined when the except branch is taken.
  return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])