Update translation.py
Browse files — translation.py (+10, −18 lines)
translation.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import streamlit as st
|
2 |
from transformers import MarianTokenizer, MarianMTModel
|
3 |
import torch
|
4 |
|
@@ -10,7 +9,7 @@ LANGUAGES = {
|
|
10 |
|
11 |
@st.cache_resource
def _load_default_model():
    """Load and cache the fallback English→Hindi MarianMT pair.

    Returns:
        tuple: ``(MarianTokenizer, MarianMTModel)`` for the
        ``Helsinki-NLP/opus-mt-en-hi`` checkpoint. Cached once per
        process by Streamlit's resource cache.
    """
    # Fixed fallback checkpoint used when no direct language pair exists.
    checkpoint = "Helsinki-NLP/opus-mt-en-hi"
    return (
        MarianTokenizer.from_pretrained(checkpoint),
        MarianMTModel.from_pretrained(checkpoint),
    )
|
@@ -24,24 +23,17 @@ def load_model(source_lang, target_lang):
|
|
24 |
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
25 |
model = MarianMTModel.from_pretrained(model_name)
|
26 |
return tokenizer, model
|
27 |
-
except Exception
|
28 |
-
st.warning(f"No direct model for {source_lang} to {target_lang}. Using en-hi fallback.")
|
29 |
return _load_default_model()
|
30 |
|
31 |
-
@st.cache_data(ttl=3600)
def translate_cached(text, source_lang, target_lang):
    """Translate *text* and cache the result for one hour.

    Args:
        text: Source-language string to translate.
        source_lang: Source language code (interpreted by ``load_model``).
        target_lang: Target language code.

    Returns:
        The decoded translation, or the original *text* when the model
        output is blank or shorter than two words.
        NOTE(review): the two-word heuristic rejects legitimate one-word
        translations — confirm this filter is intended.
    """
    tokenizer, model = load_model(source_lang, target_lang)
    batch = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(**batch, max_length=500, num_beams=2, early_stopping=True)
    result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Fall back to the input when the model produced nothing usable.
    if result.strip() and len(result.split()) >= 2:
        return result
    return text
|
39 |
-
|
40 |
def translate(text, source_lang, target_lang):
|
41 |
if not text:
|
42 |
-
return "
|
43 |
try:
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import MarianTokenizer, MarianMTModel
|
2 |
import torch
|
3 |
|
|
|
9 |
|
10 |
@st.cache_resource
def _load_default_model():
    """Return the cached default (en→hi) tokenizer/model pair.

    Serves as the fallback when no model exists for a requested
    language pair; Streamlit caches the loaded objects per process.
    """
    name = "Helsinki-NLP/opus-mt-en-hi"
    tok = MarianTokenizer.from_pretrained(name)
    mdl = MarianMTModel.from_pretrained(name)
    return tok, mdl
|
|
|
23 |
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
24 |
model = MarianMTModel.from_pretrained(model_name)
|
25 |
return tokenizer, model
|
26 |
+
except Exception:
|
|
|
27 |
return _load_default_model()
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def translate(text, source_lang, target_lang):
    """Translate *text* from ``source_lang`` to ``target_lang``.

    Best-effort entry point: returns ``""`` for empty/falsy input, and
    returns the input unchanged if any step (model loading,
    tokenization, generation, decoding) raises.

    Args:
        text: String to translate; falsy values short-circuit to "".
        source_lang: Source language code (interpreted by ``load_model``).
        target_lang: Target language code.

    Returns:
        The translated string, "" for empty input, or *text* on failure.
    """
    if not text:
        return ""
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        # Inference only — no gradients needed.
        with torch.no_grad():
            generated = model.generate(**encoded, max_length=500, num_beams=2, early_stopping=True)
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception:
        # Deliberate broad catch: never crash the UI; hand back the input.
        return text
|