Krishna086 committed on
Commit
aa7d7c2
·
verified ·
1 Parent(s): 23bc295

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +6 -7
translation.py CHANGED
@@ -45,13 +45,11 @@ def load_model(source_lang, target_lang):
45
  tokenizer_model_pair = all_models.get(model_key)
46
  if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
47
  return tokenizer_model_pair
48
- # Use direct English pivot only if direct model unavailable
49
  if source_lang != "en" and target_lang != "en":
50
- en_pivot_pair = all_models.get(("en", target_lang)) or _load_model_pair("en", target_lang)
51
  if en_pivot_pair[0] and en_pivot_pair[1]:
52
- src_to_en_pair = all_models.get((source_lang, "en")) or _load_model_pair(source_lang, "en")
53
- if src_to_en_pair[0] and src_to_en_pair[1]:
54
- return en_pivot_pair # Prefer direct pivot chain
55
  default_tokenizer, _ = _load_default_model()
56
  def combined_translate(text):
57
  with torch.no_grad():
@@ -83,8 +81,9 @@ def translate(text, source_lang, target_lang):
83
  tokenizer, model = load_model(source_lang, target_lang)
84
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
85
  with torch.no_grad():
86
- translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
87
- return tokenizer.decode(translated[0], skip_special_tokens=True)
 
88
  except Exception as e:
89
  st.error(f"Translation error: {e}")
90
  return text
 
45
  tokenizer_model_pair = all_models.get(model_key)
46
  if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
47
  return tokenizer_model_pair
48
+ # Use direct English pivot only if necessary
49
  if source_lang != "en" and target_lang != "en":
50
+ en_pivot_pair = all_models.get((source_lang, "en")) or _load_model_pair(source_lang, "en")
51
  if en_pivot_pair[0] and en_pivot_pair[1]:
52
+ return en_pivot_pair
 
 
53
  default_tokenizer, _ = _load_default_model()
54
  def combined_translate(text):
55
  with torch.no_grad():
 
81
  tokenizer, model = load_model(source_lang, target_lang)
82
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
83
  with torch.no_grad():
84
+ translated = model.generate(**inputs, max_length=500, num_beams=4, early_stopping=True) # Increased beams for better accuracy
85
+ result = tokenizer.decode(translated[0], skip_special_tokens=True)
86
+ return result if result.strip() else text
87
  except Exception as e:
88
  st.error(f"Translation error: {e}")
89
  return text