Krishna086 committed on
Commit
8f10e97
·
verified ·
1 Parent(s): b38d3f8

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +25 -9
translation.py CHANGED
@@ -36,17 +36,33 @@ def combined_translate(text, source_lang, target_lang, default_tokenizer, defaul
36
  with torch.no_grad():
37
  if source_lang != target_lang: # Only translate if languages differ
38
  if source_lang != "en":
39
- src_to_en_tokenizer, src_to_en_model = all_models.get((source_lang, "en"), (default_tokenizer, default_model))
40
- en_text = src_to_en_tokenizer.decode(src_to_en_model.generate(**src_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
41
  else:
42
- en_text = text
43
  if target_lang != "en":
44
- en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), (default_tokenizer, default_model))
45
- translated = en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=1000))[0], skip_special_tokens=True)
46
- return translated if translated.strip() else text
 
 
 
 
 
 
 
47
  return text
48
 
49
- # Class to handle combined translation through English pivot
50
  class CombinedModel:
51
  def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
52
  self.source_lang = source_lang
@@ -78,7 +94,7 @@ def load_model(source_lang, target_lang):
78
  pair2 = all_models.get((inter, target_lang))
79
  if pair1 and pair1[0] and pair1[1] and pair2 and pair2[0] and pair2[1]:
80
  return pair1
81
- # Fallback to pivot through English
82
  default_tokenizer, default_model = _load_default_model()
83
  return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
84
 
@@ -101,7 +117,7 @@ def translate(text, source_lang, target_lang):
101
  if inputs['input_ids'].size(0) > 1:
102
  inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
103
  with torch.no_grad():
104
- translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "ja"] else 500, num_beams=4, early_stopping=True)
105
  result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
106
  return result if result.strip() else text
107
  except Exception as e:
 
36
  with torch.no_grad():
37
  if source_lang != target_lang: # Only translate if languages differ
38
  if source_lang != "en":
39
+ src_to_inter_tokenizer, src_to_inter_model = None, None
40
+ # Try multiple intermediates, prefer English first
41
+ for inter in ["en", "fr", "es", "de", "ru"]: # Prioritize common languages
42
+ pair = all_models.get((source_lang, inter))
43
+ if pair and pair[0] and pair[1]:
44
+ src_to_inter_tokenizer, src_to_inter_model = pair
45
+ break
46
+ if src_to_inter_tokenizer and src_to_inter_model:
47
+ inter_text = src_to_inter_tokenizer.decode(src_to_inter_model.generate(**src_to_inter_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
48
+ else:
49
+ inter_text = text # Fallback to input if no path found
50
  else:
51
+ inter_text = text
52
  if target_lang != "en":
53
+ inter_to_tgt_tokenizer, inter_to_tgt_model = None, None
54
+ for inter in ["en", "fr", "es", "de", "ru"]:
55
+ pair = all_models.get((inter, target_lang))
56
+ if pair and pair[0] and pair[1]:
57
+ inter_to_tgt_tokenizer, inter_to_tgt_model = pair
58
+ break
59
+ if inter_to_tgt_tokenizer and inter_to_tgt_model:
60
+ translated = inter_to_tgt_tokenizer.decode(inter_to_tgt_model.generate(**inter_to_tgt_tokenizer(inter_text, return_tensors="pt", padding=True, truncation=True, max_length=1000))[0], skip_special_tokens=True)
61
+ return translated if translated.strip() else text
62
+ return inter_text
63
  return text
64
 
65
+ # Class to handle combined translation through multiple intermediates
66
  class CombinedModel:
67
  def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
68
  self.source_lang = source_lang
 
94
  pair2 = all_models.get((inter, target_lang))
95
  if pair1 and pair1[0] and pair1[1] and pair2 and pair2[0] and pair2[1]:
96
  return pair1
97
+ # Fallback to default model with CombinedModel
98
  default_tokenizer, default_model = _load_default_model()
99
  return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
100
 
 
117
  if inputs['input_ids'].size(0) > 1:
118
  inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
119
  with torch.no_grad():
120
+ translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500, num_beams=4, early_stopping=True)
121
  result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
122
  return result if result.strip() else text
123
  except Exception as e: