import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch
# Language code -> (native name, English name)
LANGUAGES = {
    "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
}
# Cache resource to load a specific translation model pair
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    try:
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception:
        # Not every language pair has a direct OPUS-MT model
        return None, None
# Cache resource to load all possible model combinations;
# pairs without a direct model are stored as (None, None)
@st.cache_resource
def _load_all_models():
    models = {}
    for src in LANGUAGES:
        for tgt in LANGUAGES:
            if src != tgt:
                models[(src, tgt)] = _load_model_pair(src, tgt)
    return models

all_models = _load_all_models()
# Mimics MarianMTModel.generate by routing text through an English pivot,
# so translate() can treat direct and pivoted models the same way
class CombinedModel:
    def __init__(self, tokenizer, pivot_fn):
        self.tokenizer, self.pivot_fn = tokenizer, pivot_fn
    def generate(self, input_ids=None, **kwargs):
        # Decode, pivot, then re-encode so the caller can decode the result as usual
        texts = [self.pivot_fn(self.tokenizer.decode(ids, skip_special_tokens=True)) for ids in input_ids]
        return self.tokenizer(texts, return_tensors="pt", padding=True).input_ids
# Function to load the appropriate translation path: a direct model when
# one exists, otherwise a pivot through English built from preloaded models
@st.cache_resource
def load_model(source_lang, target_lang):
    if source_lang == target_lang:
        return _load_default_model()
    pair = all_models.get((source_lang, target_lang))
    if pair and pair[0] and pair[1]:
        return pair

    def _get_pair(src, tgt):
        # Use the preloaded pair if it exists; otherwise try an on-demand load
        pair = all_models.get((src, tgt))
        if not (pair and pair[0] and pair[1]):
            pair = _load_model_pair(src, tgt)
        return pair

    # Optimized pivot through English using preloaded models
    def combined_translate(text):
        if source_lang != "en":
            src_tokenizer, src_model = _get_pair(source_lang, "en")
            if src_model is None:
                return text  # no route to English; return the input unchanged
            with torch.no_grad():
                inputs = src_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
                text = src_tokenizer.decode(src_model.generate(**inputs)[0], skip_special_tokens=True)
        if target_lang != "en":
            tgt_tokenizer, tgt_model = _get_pair("en", target_lang)
            if tgt_model is None:
                return text  # no route from English; return what we have
            with torch.no_grad():
                inputs = tgt_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
                text = tgt_tokenizer.decode(tgt_model.generate(**inputs)[0], skip_special_tokens=True)
        return text

    default_tokenizer, _ = _load_default_model()
    return default_tokenizer, CombinedModel(default_tokenizer, combined_translate)
# Cache resource to load the default translation model (en->hi)
@st.cache_resource
def _load_default_model():
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model
# Cache translation results to improve speed
@st.cache_data
def translate(text, source_lang, target_lang):
    # Empty input and same-language requests need no model at all
    if not text or source_lang == target_lang:
        return text
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        with torch.no_grad():
            translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
        return tokenizer.decode(translated[0], skip_special_tokens=True)
    except Exception as e:
        st.error(f"Translation error: {e}")
        return text
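
# ---------------------------------------------------------------------------
# Minimal UI sketch (an assumption -- the original file stops at translate()).
# It shows one way the helpers above could be wired into Streamlit widgets;
# the widget labels and layout are illustrative, not part of the original app.
st.title("Multilingual Translator")
col1, col2 = st.columns(2)
with col1:
    source_lang = st.selectbox("From", list(LANGUAGES), format_func=lambda c: LANGUAGES[c][0])
with col2:
    target_lang = st.selectbox("To", list(LANGUAGES), index=4, format_func=lambda c: LANGUAGES[c][0])
text = st.text_area("Text to translate")
if st.button("Translate"):
    st.write(translate(text, source_lang, target_lang))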