Update translation.py
Browse files - translation.py +25 -9
translation.py
CHANGED
@@ -36,17 +36,33 @@ def combined_translate(text, source_lang, target_lang, default_tokenizer, defaul
|
|
36 |
with torch.no_grad():
|
37 |
if source_lang != target_lang: # Only translate if languages differ
|
38 |
if source_lang != "en":
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
else:
|
42 |
-
|
43 |
if target_lang != "en":
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
return text
|
48 |
|
49 |
-
# Class to handle combined translation through
|
50 |
class CombinedModel:
|
51 |
def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
|
52 |
self.source_lang = source_lang
|
@@ -78,7 +94,7 @@ def load_model(source_lang, target_lang):
|
|
78 |
pair2 = all_models.get((inter, target_lang))
|
79 |
if pair1 and pair1[0] and pair1[1] and pair2 and pair2[0] and pair2[1]:
|
80 |
return pair1
|
81 |
-
# Fallback to
|
82 |
default_tokenizer, default_model = _load_default_model()
|
83 |
return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
|
84 |
|
@@ -101,7 +117,7 @@ def translate(text, source_lang, target_lang):
|
|
101 |
if inputs['input_ids'].size(0) > 1:
|
102 |
inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
|
103 |
with torch.no_grad():
|
104 |
-
translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "ja"] else 500, num_beams=4, early_stopping=True)
|
105 |
result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
|
106 |
return result if result.strip() else text
|
107 |
except Exception as e:
|
|
|
36 |
with torch.no_grad():
|
37 |
if source_lang != target_lang: # Only translate if languages differ
|
38 |
if source_lang != "en":
|
39 |
+
src_to_inter_tokenizer, src_to_inter_model = None, None
|
40 |
+
# Try multiple intermediates, prefer English first
|
41 |
+
for inter in ["en", "fr", "es", "de", "ru"]: # Prioritize common languages
|
42 |
+
pair = all_models.get((source_lang, inter))
|
43 |
+
if pair and pair[0] and pair[1]:
|
44 |
+
src_to_inter_tokenizer, src_to_inter_model = pair
|
45 |
+
break
|
46 |
+
if src_to_inter_tokenizer and src_to_inter_model:
|
47 |
+
inter_text = src_to_inter_tokenizer.decode(src_to_inter_model.generate(**src_to_inter_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
|
48 |
+
else:
|
49 |
+
inter_text = text # Fallback to input if no path found
|
50 |
else:
|
51 |
+
inter_text = text
|
52 |
if target_lang != "en":
|
53 |
+
inter_to_tgt_tokenizer, inter_to_tgt_model = None, None
|
54 |
+
for inter in ["en", "fr", "es", "de", "ru"]:
|
55 |
+
pair = all_models.get((inter, target_lang))
|
56 |
+
if pair and pair[0] and pair[1]:
|
57 |
+
inter_to_tgt_tokenizer, inter_to_tgt_model = pair
|
58 |
+
break
|
59 |
+
if inter_to_tgt_tokenizer and inter_to_tgt_model:
|
60 |
+
translated = inter_to_tgt_tokenizer.decode(inter_to_tgt_model.generate(**inter_to_tgt_tokenizer(inter_text, return_tensors="pt", padding=True, truncation=True, max_length=1000))[0], skip_special_tokens=True)
|
61 |
+
return translated if translated.strip() else text
|
62 |
+
return inter_text
|
63 |
return text
|
64 |
|
65 |
+
# Class to handle combined translation through multiple intermediates
|
66 |
class CombinedModel:
|
67 |
def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
|
68 |
self.source_lang = source_lang
|
|
|
94 |
pair2 = all_models.get((inter, target_lang))
|
95 |
if pair1 and pair1[0] and pair1[1] and pair2 and pair2[0] and pair2[1]:
|
96 |
return pair1
|
97 |
+
# Fallback to default model with CombinedModel
|
98 |
default_tokenizer, default_model = _load_default_model()
|
99 |
return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
|
100 |
|
|
|
117 |
if inputs['input_ids'].size(0) > 1:
|
118 |
inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
|
119 |
with torch.no_grad():
|
120 |
+
translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500, num_beams=4, early_stopping=True)
|
121 |
result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
|
122 |
return result if result.strip() else text
|
123 |
except Exception as e:
|