File size: 4,337 Bytes
76b06b6
b912ba6
122790b
b912ba6
a6fc19e
 
 
 
 
 
83c5c51
71d9b1a
 
 
 
 
 
 
 
 
 
83c5c51
71d9b1a
 
 
 
 
 
 
 
 
 
 
3aee48b
 
 
 
 
 
 
 
 
 
 
 
 
83c5c51
 
3aee48b
 
 
 
 
 
83c5c51
3aee48b
83c5c51
8d43482
 
71d9b1a
 
 
 
410ea95
 
 
3aee48b
288ca30
3aee48b
71d9b1a
83c5c51
e22e364
 
dd34156
122790b
 
 
b912ba6
8d43482
 
122790b
 
dd34156
dc03f7a
dd34156
 
 
288ca30
aa7d7c2
 
501abbc
 
dd34156
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import logging
from collections.abc import Mapping

import streamlit as st
import torch
from transformers import MarianMTModel, MarianTokenizer

# Supported language codes -> (native display name, English display name).
# The codes are also spliced into "Helsinki-NLP/opus-mt-{src}-{tgt}" model ids,
# and their order determines the order in which language pairs are enumerated.
LANGUAGES = {
    "en": ("English", "English"),
    "fr": ("Français", "French"),
    "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"),
    "hi": ("हिन्दी", "Hindi"),
    "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"),
    "ru": ("Русский", "Russian"),
    "ja": ("日本語", "Japanese"),
}

# Cache resource to load a specific translation model pair
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    """Load the Helsinki-NLP opus-mt tokenizer and model for one direction.

    Args:
        source_lang: Source language code (e.g. a key of LANGUAGES).
        target_lang: Target language code.

    Returns:
        ``(MarianTokenizer, MarianMTModel)`` on success, or ``(None, None)`` if
        loading raised for any reason (presumably most often because no
        checkpoint exists for that pair). The failure is logged rather than
        raised so callers can fall back to an English pivot. Because of
        ``st.cache_resource`` the ``(None, None)`` result is cached as well.
    """
    model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
    try:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
    except Exception as exc:  # deliberate best-effort: callers handle (None, None)
        logging.getLogger(__name__).warning(
            "Could not load translation model %s: %s", model_name, exc
        )
        return None, None
    return tokenizer, model

class _LazyModelRegistry(Mapping):
    """Read-only mapping ``(source, target) -> (tokenizer, model)`` loaded on demand.

    The key set is every ordered pair of distinct codes from ``codes``, in the
    same order as a nested ``for src ... for tgt ...`` loop. A value is produced
    by ``_load_model_pair`` only when that key is looked up, so startup no longer
    pays for instantiating every language pair. Values for pairs that fail to
    load are ``(None, None)``, exactly as ``_load_model_pair`` returns them.
    """

    def __init__(self, codes):
        codes = list(codes)
        self._keys = tuple((src, tgt) for src in codes for tgt in codes if src != tgt)
        self._key_set = frozenset(self._keys)

    def __getitem__(self, key):
        if key not in self._key_set:
            raise KeyError(key)
        source_lang, target_lang = key
        # _load_model_pair is itself cached, so repeated lookups are cheap.
        return _load_model_pair(source_lang, target_lang)

    def __contains__(self, key):
        # Answer membership from the key set so `in` never triggers a model load.
        return key in self._key_set

    def __iter__(self):
        return iter(self._keys)

    def __len__(self):
        return len(self._keys)


# Cache resource holding the registry of all possible model combinations
@st.cache_resource
def _load_all_models():
    """Return a mapping of every ``(src, tgt)`` pair of LANGUAGES codes (src != tgt).

    Supports ``.get``, ``[]``, ``in``, iteration and ``len`` like the plain dict
    this used to build. Models are loaded lazily on first lookup instead of all
    at import time.
    """
    return _LazyModelRegistry(LANGUAGES)


all_models = _load_all_models()

# Translate through an English pivot using the direct pairs in all_models
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    """Translate ``text`` from ``source_lang`` to ``target_lang`` via English.

    Runs source->en (skipped when the source is already English), then en->target
    (skipped when the target is English). Each hop uses the direct pair from
    ``all_models``. If that pair is missing or its tokenizer/model is ``None``
    (failed to load), the hop uses ``default_tokenizer``/``default_model`` instead.

    Returns:
        The translated string (``text`` itself when both languages are English).
    """

    def _resolve(pair_key):
        # Pairs that failed to load are stored as (None, None), so a plain
        # ``.get(key, default)`` would return Nones rather than the default.
        tok, mdl = all_models.get(pair_key) or (None, None)
        if tok is None or mdl is None:
            return default_tokenizer, default_model
        return tok, mdl

    def _run_hop(pair_key, src_text):
        tok, mdl = _resolve(pair_key)
        # Same 500-token input cap on both hops (and as translate() uses).
        # Presumably opus-mt checkpoints only have ~512 positions - TODO confirm.
        inputs = tok(src_text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        output_ids = mdl.generate(**inputs)
        return tok.decode(output_ids[0], skip_special_tokens=True)

    with torch.no_grad():
        en_text = text if source_lang == "en" else _run_hop((source_lang, "en"), text)
        if target_lang == "en":
            return en_text
        return _run_hop(("en", target_lang), en_text)

# Class to handle combined translation through English pivot
class CombinedModel:
    """Model-like stand-in used when no direct opus-mt pair is available.

    It exposes ``generate`` so it can sit where a MarianMTModel is expected. The
    actual translation is delegated to ``combined_translate``, which pivots
    through English.

    Attributes:
        source_lang: Source language code of the requested direction.
        target_lang: Target language code of the requested direction.
        default_tokenizer: Fallback tokenizer. It is passed to combined_translate
            and is also used here to decode input ids and encode output ids.
        default_model: Fallback model passed to combined_translate.
    """

    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def translate_text(self, text):
        """Translate raw ``text`` through the English pivot (no id round-trip)."""
        return combined_translate(
            text, self.source_lang, self.target_lang, self.default_tokenizer, self.default_model
        )

    def generate(self, **kwargs):
        """Mimic ``MarianMTModel.generate`` for callers passing ``input_ids``.

        Each row of ``input_ids`` is decoded with ``default_tokenizer`` and
        translated via ``translate_text``. The translations are then re-encoded
        as target-side ids of ``default_tokenizer``, so that
        ``default_tokenizer.decode(result[i])`` yields translation ``i``.
        Other generation kwargs (max_length, num_beams, ...) are accepted and
        ignored.

        NOTE(review): the decode/encode round-trip likely drops characters the
        default tokenizer cannot represent. Prefer ``translate_text`` when the
        raw text is available.
        """
        texts = [
            self.default_tokenizer.decode(ids, skip_special_tokens=True)
            for ids in kwargs["input_ids"]
        ]
        translations = [self.translate_text(t) for t in texts]
        # Return token ids (not strings): callers index and decode this like a
        # generate() output tensor.
        encoded = self.default_tokenizer(text_target=translations, return_tensors="pt", padding=True)
        return encoded["input_ids"]

# Pick the (tokenizer, model) pair to use for a given translation direction
@st.cache_resource
def load_model(source_lang, target_lang):
    """Return a ``(tokenizer, model)`` pair for ``source_lang`` -> ``target_lang``.

    Resolution order:
      1. Same source and target: the default pair from ``_load_default_model()``.
      2. A direct pair in ``all_models`` whose tokenizer and model are both truthy.
      3. Otherwise: the default tokenizer plus a ``CombinedModel`` that pivots
         through English using the default pair as fallback.

    NOTE(review): case 1 hands back the en->hi checkpoint, so running it on the
    text would produce Hindi rather than an identity translation.
    """
    if source_lang == target_lang:
        return _load_default_model()

    candidate = all_models.get((source_lang, target_lang))
    has_direct_pair = bool(candidate) and bool(candidate[0]) and bool(candidate[1])
    if has_direct_pair:
        return candidate

    # No usable direct checkpoint: wrap the default pair in an English-pivot model.
    fallback_tokenizer, fallback_model = _load_default_model()
    pivot = CombinedModel(source_lang, target_lang, fallback_tokenizer, fallback_model)
    return fallback_tokenizer, pivot

# Cache resource to load default translation model
@st.cache_resource
def _load_default_model():
    """Load the fallback English->Hindi opus-mt checkpoint.

    Returns:
        ``(MarianTokenizer, MarianMTModel)`` for ``Helsinki-NLP/opus-mt-en-hi``,
        with the tokenizer loaded first.
    """
    checkpoint = "Helsinki-NLP/opus-mt-en-hi"
    return (
        MarianTokenizer.from_pretrained(checkpoint),
        MarianMTModel.from_pretrained(checkpoint),
    )

# Cache translation results to improve speed
@st.cache_data
def translate(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Args:
        text: Input string.
        source_lang: Source language code (a key of LANGUAGES).
        target_lang: Target language code (a key of LANGUAGES).

    Returns:
        The translation, or:
        - ``""`` for empty input;
        - ``text`` unchanged when source and target are the same;
        - ``text`` unchanged when the model output is blank;
        - ``text`` unchanged (after ``st.error``) if translation raised.

    NOTE(review): results are cached by ``st.cache_data``, so the fallback
    ``text`` returned after a transient error is cached too.
    """
    if not text:
        return ""
    if source_lang == target_lang:
        # load_model() returns the en->hi default pair for this case, which
        # would turn a same-language request into Hindi output.
        return text
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        if isinstance(model, CombinedModel):
            # Pivot on the raw text. Tokenizing with the default (en->hi)
            # tokenizer first may lose characters it cannot represent.
            result = combined_translate(
                text, source_lang, target_lang, model.default_tokenizer, model.default_model
            )
        else:
            is_hindi = target_lang == "hi"
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
            with torch.no_grad():
                # Hindi output gets a longer length budget and a wider beam.
                translated = model.generate(
                    **inputs,
                    max_length=1000 if is_hindi else 500,
                    num_beams=6 if is_hindi else 4,
                    early_stopping=True,
                )
            result = tokenizer.decode(translated[0], skip_special_tokens=True)
        return result if result.strip() else text
    except Exception as e:
        st.error(f"Translation error: {e}")
        return text