Update translation.py
Browse files- translation.py +13 -6
translation.py
CHANGED
@@ -20,18 +20,24 @@ def load_model(source_lang, target_lang):
|
|
20 |
try:
|
21 |
if source_lang == target_lang:
|
22 |
return _load_default_model()
|
23 |
-
# Try direct model
|
24 |
model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
|
25 |
try:
|
26 |
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
27 |
model = MarianMTModel.from_pretrained(model_name)
|
28 |
return tokenizer, model
|
29 |
except Exception:
|
30 |
-
# Pivot through English
|
31 |
if source_lang != "en" and target_lang != "en":
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
return _load_default_model()
|
36 |
except Exception:
|
37 |
return _load_default_model()
|
@@ -45,5 +51,6 @@ def translate(text, source_lang, target_lang):
|
|
45 |
with torch.no_grad():
|
46 |
translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
|
47 |
return tokenizer.decode(translated[0], skip_special_tokens=True)
|
48 |
-
except Exception:
|
|
|
49 |
return text
|
|
|
def load_model(source_lang, target_lang):
    """Load a (tokenizer, model) pair translating source_lang -> target_lang.

    Tries the direct Helsinki-NLP opus-mt checkpoint first.  If no direct
    model exists and neither language is English, builds a pivot translator
    that goes source -> en -> target and wraps it in an object that
    duck-types ``MarianMTModel.generate``.  Any failure falls back to
    ``_load_default_model()``.

    Returns:
        (tokenizer, model): objects usable as ``model.generate(**tokenizer(...))``
        followed by ``tokenizer.decode(ids)``.
    """
    try:
        # Same language: nothing to translate, hand back the default pair.
        if source_lang == target_lang:
            return _load_default_model()

        # Try the direct model first.
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        try:
            tokenizer = MarianTokenizer.from_pretrained(model_name)
            model = MarianMTModel.from_pretrained(model_name)
            return tokenizer, model
        except Exception:
            # No direct model available: pivot through English when
            # neither side already is English.
            if source_lang != "en" and target_lang != "en":
                source_to_en_tokenizer, source_to_en_model = load_model(source_lang, "en")
                en_to_target_tokenizer, en_to_target_model = load_model("en", target_lang)

                def combined_translate(text):
                    # source -> en
                    en_inputs = source_to_en_tokenizer(
                        text, return_tensors="pt", padding=True,
                        truncation=True, max_length=500,
                    )
                    en_text = source_to_en_tokenizer.decode(
                        source_to_en_model.generate(**en_inputs)[0],
                        skip_special_tokens=True,
                    )
                    # en -> target
                    target_inputs = en_to_target_tokenizer(
                        en_text, return_tensors="pt", padding=True,
                        truncation=True, max_length=500,
                    )
                    return en_to_target_tokenizer.decode(
                        en_to_target_model.generate(**target_inputs)[0],
                        skip_special_tokens=True,
                    )

                class CombinedModel:
                    """Duck-types MarianMTModel.generate for the pivot path."""

                    def generate(self, **kwargs):
                        # BUG FIX: the original referenced the undefined name
                        # `tokenizer` (its assignment above had just failed,
                        # which is why we are in this except branch) and built
                        # torch.tensor(...) out of *strings*, which raises.
                        # Instead: decode each input row with the tokenizer
                        # this function returns, pivot-translate the text, and
                        # re-encode the result so callers can decode the ids
                        # as they would for a real Marian model.
                        # NOTE(review): decoding source-language text with the
                        # en->target SentencePiece tokenizer is assumed to
                        # round-trip the raw text — confirm against callers.
                        out_rows = []
                        for ids in kwargs["input_ids"]:
                            text = en_to_target_tokenizer.decode(ids, skip_special_tokens=True)
                            translated_ids = en_to_target_tokenizer(
                                combined_translate(text),
                                return_tensors="pt", truncation=True, max_length=500,
                            )["input_ids"][0]
                            out_rows.append(translated_ids)
                        return torch.nn.utils.rnn.pad_sequence(
                            out_rows, batch_first=True,
                            padding_value=en_to_target_tokenizer.pad_token_id,
                        )

                # BUG FIX: return the en->target tokenizer that was already
                # loaded, not a hard-coded "Helsinki-NLP/opus-mt-en-hi"
                # tokenizer, which was only correct when target_lang == "hi".
                return en_to_target_tokenizer, CombinedModel()
            return _load_default_model()
    except Exception:
        return _load_default_model()
|
|
51 |
with torch.no_grad():
|
52 |
translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
|
53 |
return tokenizer.decode(translated[0], skip_special_tokens=True)
|
54 |
+
except Exception as e:
|
55 |
+
st.error(f"Translation error: {e}")
|
56 |
return text
|