Update translation.py
translation.py  +6 -7
@@ -45,13 +45,11 @@ def load_model(source_lang, target_lang):
     tokenizer_model_pair = all_models.get(model_key)
     if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
         return tokenizer_model_pair
-    # Use direct English pivot only if
+    # Use direct English pivot only if necessary
     if source_lang != "en" and target_lang != "en":
-        en_pivot_pair = all_models.get(("en"
+        en_pivot_pair = all_models.get((source_lang, "en")) or _load_model_pair(source_lang, "en")
         if en_pivot_pair[0] and en_pivot_pair[1]:
-
-            if src_to_en_pair[0] and src_to_en_pair[1]:
-                return en_pivot_pair  # Prefer direct pivot chain
+            return en_pivot_pair
     default_tokenizer, _ = _load_default_model()
     def combined_translate(text):
         with torch.no_grad():
@@ -83,8 +81,9 @@ def translate(text, source_lang, target_lang):
         tokenizer, model = load_model(source_lang, target_lang)
         inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
         with torch.no_grad():
-            translated = model.generate(**inputs, max_length=500, num_beams=
-
+            translated = model.generate(**inputs, max_length=500, num_beams=4, early_stopping=True)  # Increased beams for better accuracy
+            result = tokenizer.decode(translated[0], skip_special_tokens=True)
+            return result if result.strip() else text
     except Exception as e:
         st.error(f"Translation error: {e}")
         return text
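For context, the updated translate() body follows the standard Hugging Face transformers pattern: tokenize, generate with beam search (num_beams=4, early_stopping=True), then decode with special tokens skipped. The sketch below illustrates that pattern together with the English-pivot fallback the first hunk gates on. It is a minimal standalone sketch, not the app's code: the Helsinki-NLP/opus-mt-* checkpoint names and the pivot_translate/_load_pair helpers are assumptions for illustration, while the commit itself routes through all_models, _load_model_pair and combined_translate.

    # Minimal sketch (not the app's code): direct translation via a MarianMT pair,
    # with an English pivot when neither side is English and no direct pair loads.
    import torch
    from transformers import MarianMTModel, MarianTokenizer

    def _load_pair(src, tgt):
        # Assumed checkpoint naming; the app's all_models registry may differ.
        name = f"Helsinki-NLP/opus-mt-{src}-{tgt}"
        return MarianTokenizer.from_pretrained(name), MarianMTModel.from_pretrained(name)

    def _translate_with(tokenizer, model, text):
        # Same tokenize -> generate -> decode steps as the updated translate().
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        with torch.no_grad():
            out = model.generate(**inputs, max_length=500, num_beams=4, early_stopping=True)
        return tokenizer.decode(out[0], skip_special_tokens=True)

    def pivot_translate(text, source_lang, target_lang):
        # Hypothetical helper mirroring the commit's ordering: try the direct pair
        # first; if neither language is English, chain source -> en and en -> target.
        try:
            result = _translate_with(*_load_pair(source_lang, target_lang), text)
            return result if result.strip() else text
        except Exception:
            if source_lang != "en" and target_lang != "en":
                english = _translate_with(*_load_pair(source_lang, "en"), text)
                result = _translate_with(*_load_pair("en", target_lang), english)
                return result if result.strip() else text
            return text

Returning the original text whenever the result is empty or an error occurs matches the behavior the commit keeps in translate(), where the except Exception handler reports the error via st.error and falls back to the input string.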