Krishna086 commited on
Commit
501abbc
·
verified ·
1 Parent(s): 538b10c

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +13 -6
translation.py CHANGED
@@ -20,18 +20,24 @@ def load_model(source_lang, target_lang):
20
  try:
21
  if source_lang == target_lang:
22
  return _load_default_model()
23
- # Try direct model first
24
  model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
25
  try:
26
  tokenizer = MarianTokenizer.from_pretrained(model_name)
27
  model = MarianMTModel.from_pretrained(model_name)
28
  return tokenizer, model
29
  except Exception:
30
- # Pivot through English for non-English pairs
31
  if source_lang != "en" and target_lang != "en":
32
- en_to_target = load_model("en", target_lang)
33
- source_to_en = load_model(source_lang, "en")
34
- return source_to_en if source_lang == "en" else en_to_target
 
 
 
 
 
 
35
  return _load_default_model()
36
  except Exception:
37
  return _load_default_model()
@@ -45,5 +51,6 @@ def translate(text, source_lang, target_lang):
45
  with torch.no_grad():
46
  translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
47
  return tokenizer.decode(translated[0], skip_special_tokens=True)
48
- except Exception:
 
49
  return text
 
20
  try:
21
  if source_lang == target_lang:
22
  return _load_default_model()
23
+ # Try direct model
24
  model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
25
  try:
26
  tokenizer = MarianTokenizer.from_pretrained(model_name)
27
  model = MarianMTModel.from_pretrained(model_name)
28
  return tokenizer, model
29
  except Exception:
30
+ # Pivot through English
31
  if source_lang != "en" and target_lang != "en":
32
+ en_to_target_tokenizer, en_to_target_model = load_model("en", target_lang)
33
+ source_to_en_tokenizer, source_to_en_model = load_model(source_lang, "en")
34
+ def combined_translate(text):
35
+ en_text = source_to_en_tokenizer.decode(source_to_en_model.generate(**source_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
36
+ return en_to_target_tokenizer.decode(en_to_target_model.generate(**en_to_target_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
37
+ class CombinedModel:
38
+ def generate(self, **kwargs):
39
+ return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
40
+ return MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi"), CombinedModel()
41
  return _load_default_model()
42
  except Exception:
43
  return _load_default_model()
 
51
  with torch.no_grad():
52
  translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
53
  return tokenizer.decode(translated[0], skip_special_tokens=True)
54
+ except Exception as e:
55
+ st.error(f"Translation error: {e}")
56
  return text