import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch

# Supported languages: ISO 639-1 code -> (native name, English name)
LANGUAGES = {
    "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
}
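
# Note: a directed pair maps onto the OPUS-MT naming scheme, e.g. ("en", "fr")
# resolves to "Helsinki-NLP/opus-mt-en-fr"; not every directed pair is
# published, which is why the loaders below tolerate missing checkpoints.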

# Cache resource to load a specific translation model pair
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    try:
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception:
        return None, None
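
# Usage sketch (hypothetical): probe whether a direct pair exists
#   tok, mdl = _load_model_pair("en", "fr")
#   if tok is None:
#       ...  # no direct checkpoint; fall back to pivot translation below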

# Cache resource to eagerly load every directed model combination; pairs with
# no published checkpoint are stored as (None, None). Note this triggers up to
# 72 model downloads on first run.
@st.cache_resource
def _load_all_models():
    models = {}
    for src in LANGUAGES.keys():
        for tgt in LANGUAGES.keys():
            if src != tgt:
                models[(src, tgt)] = _load_model_pair(src, tgt)
    return models

all_models = _load_all_models()

# Pivot translation: route text through an intermediate language when no
# direct source->target model exists. default_tokenizer/default_model are
# accepted to mirror CombinedModel's call signature but are not needed here.
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    if source_lang == target_lang:
        return text
    with torch.no_grad():
        pivot, inter_text = source_lang, text
        if source_lang != "en":
            # First hop: prefer English, then other well-covered languages
            for inter in ["en", "fr", "es", "de", "ru"]:
                pair = all_models.get((source_lang, inter))
                if pair and pair[0] and pair[1]:
                    tok, mdl = pair
                    ids = mdl.generate(**tok(text, return_tensors="pt", padding=True, truncation=True, max_length=500))
                    inter_text = tok.decode(ids[0], skip_special_tokens=True)
                    pivot = inter
                    break
        if pivot == target_lang:
            return inter_text
        # The second hop must start from the language the text is actually in
        # now; feeding it to a model trained on a different source would garble it
        pair = all_models.get((pivot, target_lang))
        if pair and pair[0] and pair[1]:
            tok, mdl = pair
            ids = mdl.generate(**tok(inter_text, return_tensors="pt", padding=True, truncation=True, max_length=500))
            translated = tok.decode(ids[0], skip_special_tokens=True)
            return translated if translated.strip() else text
        return inter_text  # no usable second hop; return the best text we have
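
# Usage sketch (hypothetical; tok/mdl can be any tokenizer-model pair, since
# the default arguments are unused here):
#   combined_translate("Guten Morgen", "de", "hi", tok, mdl)
#   # pivots German through English on the way to Hindi, assuming both
#   # checkpoints loaded at startup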

# Wrapper that mimics the Hugging Face generate() interface while routing
# text through combined_translate (pivot translation)
class CombinedModel:
    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def generate(self, **kwargs):
        input_ids = kwargs.get('input_ids')
        # A multi-element tensor has no truth value, so compare against None explicitly
        if input_ids is None or input_ids.size(0) == 0:
            return torch.tensor([])
        inputs = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        translated = [combined_translate(text, self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for text in inputs]
        # Re-encode as one padded batch; stacking ragged per-sentence tensors
        # with torch.tensor() would fail for unequal lengths
        return self.default_tokenizer(translated, return_tensors="pt", padding=True, truncation=True, max_length=500)["input_ids"]
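
# Because CombinedModel exposes generate(), translate() below can treat a
# direct MarianMTModel and a pivot wrapper interchangeably. Sketch (hypothetical):
#   tok, mdl = _load_default_model()
#   wrapper = CombinedModel("de", "hi", tok, mdl)
#   ids = wrapper.generate(**tok("Guten Morgen", return_tensors="pt"))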

# Load the right model for a pair: the direct model when available, otherwise
# a CombinedModel that pivots through an intermediate language
@st.cache_resource
def load_model(source_lang, target_lang):
    if source_lang == target_lang:
        return _load_default_model()
    pair = all_models.get((source_lang, target_lang))
    if pair and pair[0] and pair[1]:
        return pair
    # No direct model: returning only the source->intermediate half would yield
    # intermediate-language output, so delegate both hops to CombinedModel
    default_tokenizer, default_model = _load_default_model()
    return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
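
# Usage sketch (hypothetical): a published pair resolves directly; an
# unpublished one falls back to the pivot wrapper
#   tokenizer, model = load_model("en", "fr")   # direct opus-mt-en-fr if loaded
#   tokenizer, model = load_model("hi", "ja")   # likely a CombinedModel pivot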

# Cache resource to load the default en->hi model, which also serves as the
# fallback tokenizer/model for CombinedModel
@st.cache_resource
def _load_default_model():
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model

# Cache translation results to improve speed
@st.cache_data
def translate(text, source_lang, target_lang):
    if not text:
        return ""
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        # Defensive: keep only the first sequence, since translate() handles one string at a time
        if inputs['input_ids'].size(0) > 1:
            inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
        with torch.no_grad():
            translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500, num_beams=4, early_stopping=True)
        if translated_ids.numel() == 0:  # CombinedModel may return an empty tensor
            return text
        result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return result if result.strip() else text
    except Exception as e:
        st.error(f"Translation error: {e}")
        return text
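
# A minimal UI sketch showing how translate() might be wired into Streamlit.
# The widget labels, layout, and defaults below are illustrative assumptions,
# not part of the original interface.
if __name__ == "__main__":
    st.title("Multilingual Translator")
    col1, col2 = st.columns(2)
    with col1:
        source_lang = st.selectbox("From", list(LANGUAGES.keys()),
                                   format_func=lambda c: LANGUAGES[c][1], index=0)
    with col2:
        target_lang = st.selectbox("To", list(LANGUAGES.keys()),
                                   format_func=lambda c: LANGUAGES[c][1], index=4)
    text = st.text_area("Text to translate")
    if st.button("Translate") and text.strip():
        st.write(translate(text, source_lang, target_lang))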