import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch

# ISO 639-1 code -> (native name, English name)
LANGUAGES = {
    "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
}

# Cache resource to load a specific translation model pair
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    try:
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception:
        return None, None
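# e.g. _load_model_pair("en", "fr") resolves to the Helsinki-NLP/opus-mt-en-fr checkpoint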

# Cache resource to eagerly load every ordered language-pair model; pairs that
# fail to load are stored as (None, None)
@st.cache_resource
def _load_all_models():
    models = {}
    for src in LANGUAGES.keys():
        for tgt in LANGUAGES.keys():
            if src != tgt:
                models[(src, tgt)] = _load_model_pair(src, tgt)
    return models

all_models = _load_all_models()

# Pivot translation through English: source -> en, then en -> target (defined at
# module level so CombinedModel can reuse it)
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    with torch.no_grad():
        if source_lang != "en":
            # Fall back to the default model if the src -> en pair failed to load
            pair = all_models.get((source_lang, "en"))
            src_tokenizer, src_model = pair if pair and pair[0] else (default_tokenizer, default_model)
            inputs = src_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
            en_text = src_tokenizer.decode(src_model.generate(**inputs)[0], skip_special_tokens=True)
        else:
            en_text = text
        if target_lang != "en":
            pair = all_models.get(("en", target_lang))
            tgt_tokenizer, tgt_model = pair if pair and pair[0] else (default_tokenizer, default_model)
            inputs = tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=1000)
            return tgt_tokenizer.decode(tgt_model.generate(**inputs)[0], skip_special_tokens=True)
        return en_text
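# Example pivot path: translating hi -> ja runs the text through the hi -> en
# model first, then through en -> ja (each leg falls back to the default model
# if its pair failed to load)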

# Class to handle combined translation through English pivot
class CombinedModel:
    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def generate(self, **kwargs):
        input_ids = kwargs.get("input_ids")
        if input_ids is None or input_ids.numel() == 0:
            return torch.tensor([])
        translations = [combined_translate(self.default_tokenizer.decode(ids, skip_special_tokens=True),
                                           self.source_lang, self.target_lang,
                                           self.default_tokenizer, self.default_model)
                        for ids in input_ids]
        # Re-encode the translations so callers can decode() this like a normal generate() output
        return self.default_tokenizer(translations, return_tensors="pt", padding=True).input_ids

# Load the appropriate translation model with cached lookups, pivoting through
# English when no direct pair is available
@st.cache_resource
def load_model(source_lang, target_lang):
    if source_lang == target_lang:
        return _load_default_model()
    pair = all_models.get((source_lang, target_lang))
    if pair and pair[0] and pair[1]:
        return pair  # a direct model exists for this pair
    # No direct model: CombinedModel pivots through English internally; returning
    # a bare src -> en or en -> tgt model would only translate halfway
    default_tokenizer, default_model = _load_default_model()
    return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)

# Cache resource to load the default translation model (en-hi), used as the
# last-resort fallback
@st.cache_resource
def _load_default_model():
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model

# Cache translation results to improve speed
@st.cache_data
def translate(text, source_lang, target_lang):
    if not text:
        return ""
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        with torch.no_grad():
            translated = model.generate(
                **inputs,
                max_length=1000 if target_lang == "hi" else 500,  # allow longer outputs for Hindi
                num_beams=4,  # 4 beams trades a little quality for speed
                early_stopping=True,
            )
        result = tokenizer.decode(translated[0], skip_special_tokens=True)
        return result if result.strip() else text
    except Exception as e:
        st.error(f"Translation error: {e}")
        return text
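
# Minimal UI sketch showing how the helpers above could be wired together in a
# Streamlit page; the widget labels and layout here are illustrative
# assumptions, not part of the original interface
st.title("Translator")
col_src, col_tgt = st.columns(2)
with col_src:
    source_lang = st.selectbox("From", list(LANGUAGES), format_func=lambda code: LANGUAGES[code][0])
with col_tgt:
    target_lang = st.selectbox("To", list(LANGUAGES), format_func=lambda code: LANGUAGES[code][0])
text = st.text_area("Text to translate")
if st.button("Translate") and text.strip():
    st.write(translate(text, source_lang, target_lang))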