import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch

# ISO 639-1 code -> (native name, English name)
LANGUAGES = {
    "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
}

# Cache resource to load a specific translation model pair
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    try:
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception:
        return None, None
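
# Note: Helsinki-NLP publishes opus-mt checkpoints for only a subset of the
# 72 directed pairs below; pairs without a direct checkpoint cache as
# (None, None) and are handled by the English-pivot fallback in load_model().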

# Cache resource to load all possible model combinations
@st.cache_resource
def _load_all_models():
    models = {}
    for src in LANGUAGES:
        for tgt in LANGUAGES:
            if src != tgt:
                models[(src, tgt)] = _load_model_pair(src, tgt)
    return models

# Built once at startup: downloads and caches every available pair up front,
# which can take a while on the first run
all_models = _load_all_models()

# Class to handle combined translation through an English pivot. It mimics
# MarianMTModel.generate(): decode the input ids, pivot-translate the text,
# then re-encode the result so callers can decode it as usual (round-tripping
# through the default tokenizer is approximate for scripts outside its vocabulary)
class CombinedModel:
    def __init__(self, translate_fn, tokenizer):
        self.translate_fn = translate_fn
        self.tokenizer = tokenizer

    def generate(self, input_ids, **kwargs):
        texts = [self.translate_fn(self.tokenizer.decode(ids, skip_special_tokens=True)) for ids in input_ids]
        return self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True).input_ids

# Function to load the appropriate translation model, pivoting through
# English when no direct model exists for the pair
@st.cache_resource
def load_model(source_lang, target_lang):
    if source_lang == target_lang:
        return _load_default_model()
    pair = all_models.get((source_lang, target_lang))
    if pair and pair[0] is not None and pair[1] is not None:
        return pair
    # Pivot through English using the preloaded models
    def combined_translate(text):
        en_text = text
        if source_lang != "en":
            src_tokenizer, src_model = all_models.get((source_lang, "en"), (None, None))
            if src_tokenizer is None or src_model is None:
                return text  # no route into English; give the input back unchanged
            inputs = src_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
            with torch.no_grad():
                en_text = src_tokenizer.decode(src_model.generate(**inputs)[0], skip_special_tokens=True)
        if target_lang != "en":
            tgt_tokenizer, tgt_model = all_models.get(("en", target_lang), (None, None))
            if tgt_tokenizer is None or tgt_model is None:
                return en_text  # no route out of English; return the pivot text
            inputs = tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500)
            with torch.no_grad():
                return tgt_tokenizer.decode(tgt_model.generate(**inputs)[0], skip_special_tokens=True)
        return en_text
    default_tokenizer, _ = _load_default_model()
    return default_tokenizer, CombinedModel(combined_translate, default_tokenizer)

# Cache resource to load the default translation model (en-hi), used when
# source and target match and as the tokenizer for the pivot wrapper
@st.cache_resource
def _load_default_model():
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model

# Cache translation results to improve speed
@st.cache_data
def translate(text, source_lang, target_lang):
    if not text:
        return ""
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        with torch.no_grad():
            translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
        return tokenizer.decode(translated[0], skip_special_tokens=True)
    except Exception as e:
        st.error(f"Translation error: {e}")
        return text
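

# --- Usage sketch (illustrative, not from the original file) ---
# A minimal Streamlit front end wired to translate() above, assuming this
# module is launched with `streamlit run app.py`. The widget labels, layout,
# and default language choices here are assumptions for illustration only.
def main():
    st.title("Translator")
    codes = list(LANGUAGES)
    col1, col2 = st.columns(2)
    with col1:
        source_lang = st.selectbox("From", codes, index=codes.index("en"),
                                   format_func=lambda c: LANGUAGES[c][0])
    with col2:
        target_lang = st.selectbox("To", codes, index=codes.index("hi"),
                                   format_func=lambda c: LANGUAGES[c][0])
    text = st.text_area("Enter text")
    if st.button("Translate"):
        st.write(translate(text, source_lang, target_lang))


if __name__ == "__main__":
    main()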