Update translation.py
Browse files
- translation.py +4 -4
translation.py
CHANGED
@@ -13,7 +13,7 @@ def _load_default_model():
|
|
13 |
# Cache other models dynamically
|
14 |
@st.cache_resource
|
15 |
def load_model(src_lang, tgt_lang):
|
16 |
-
"""Load MarianMT model for a
|
17 |
model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
|
18 |
try:
|
19 |
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
@@ -37,16 +37,16 @@ def translate(text, source_lang, target_lang):
|
|
37 |
if src_code == "en" and tgt_code == "fr":
|
38 |
tokenizer, model = DEFAULT_TOKENIZER, DEFAULT_MODEL
|
39 |
else:
|
40 |
-
tokenizer, model = load_model(src_code,
|
41 |
|
42 |
# Perform translation
|
43 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=400)
|
44 |
translated = model.generate(**inputs)
|
45 |
return tokenizer.decode(translated[0], skip_special_tokens=True)
|
46 |
|
47 |
-
#
|
48 |
LANGUAGES = {
|
49 |
-
|
50 |
"French": "fr",
|
51 |
"Spanish": "es",
|
52 |
"German": "de",
|
|
|
13 |
# Cache other models dynamically
|
14 |
@st.cache_resource
|
15 |
def load_model(src_lang, tgt_lang):
|
16 |
+
"""Load the MarianMT model and tokenizer for a language pair."""
|
17 |
model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
|
18 |
try:
|
19 |
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
|
|
37 |
if src_code == "en" and tgt_code == "fr":
|
38 |
tokenizer, model = DEFAULT_TOKENIZER, DEFAULT_MODEL
|
39 |
else:
|
40 |
+
tokenizer, model = load_model(src_code, tgt_lang)
|
41 |
|
42 |
# Perform translation
|
43 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=400)
|
44 |
translated = model.generate(**inputs)
|
45 |
return tokenizer.decode(translated[0], skip_special_tokens=True)
|
46 |
|
47 |
+
# Dictionary of supported languages with MarianMT codes
|
48 |
LANGUAGES = {
|
49 |
+
"English": "en",
|
50 |
"French": "fr",
|
51 |
"Spanish": "es",
|
52 |
"German": "de",
|