File size: 6,378 Bytes
76b06b6
b912ba6
122790b
b912ba6
e95d8e4
a6fc19e
 
 
 
 
 
e95d8e4
71d9b1a
 
 
 
 
 
 
e95d8e4
 
71d9b1a
 
e95d8e4
71d9b1a
 
 
 
 
 
 
 
 
e95d8e4
71d9b1a
 
e95d8e4
3aee48b
e95d8e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b38d3f8
3aee48b
e95d8e4
83c5c51
3aee48b
 
 
 
 
 
83c5c51
e95d8e4
 
 
 
 
 
 
 
 
64afd54
83c5c51
e95d8e4
8d43482
71d9b1a
e95d8e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71d9b1a
e95d8e4
e22e364
 
e95d8e4
 
 
 
 
 
 
 
b912ba6
e95d8e4
8d43482
122790b
dc03f7a
e95d8e4
 
dd34156
 
b38d3f8
17b4050
dd34156
8f10e97
b38d3f8
aa7d7c2
501abbc
e95d8e4
dd34156
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch

# Define supported languages
# Maps ISO 639-1 language code -> (native-script display name, English name).
# These codes are combined pairwise to form Helsinki-NLP opus-mt model names.
LANGUAGES = {
    "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
}

# Load a specific translation model pair with caching
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    """Return (tokenizer, model) for the opus-mt source->target direction.

    Returns (None, None) and surfaces a Streamlit error if the model
    cannot be downloaded or loaded.
    """
    model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
    try:
        return (
            MarianTokenizer.from_pretrained(model_name),
            MarianMTModel.from_pretrained(model_name),
        )
    except Exception as e:
        st.error(f"Failed to load model pair ({source_lang} to {target_lang}): {e}")
        return None, None

# Load all possible model combinations with caching
@st.cache_resource
def _load_all_models():
    """Build {(src, tgt): (tokenizer, model)} for every ordered language pair.

    Entries for unavailable directions hold (None, None), as produced by
    _load_model_pair.
    """
    return {
        (src, tgt): _load_model_pair(src, tgt)
        for src in LANGUAGES
        for tgt in LANGUAGES
        if src != tgt
    }

# Preload all models
# NOTE(review): with 9 languages this eagerly loads up to 72 directional models
# at import time — very heavy on first run; confirm this is intended rather
# than lazy per-pair loading.
all_models = _load_all_models()

# Perform combined translation through intermediate languages
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    """Translate ``text`` from source_lang to target_lang via a pivot language.

    Leg 1 translates source -> pivot (first available of en/fr/es/de/ru),
    leg 2 translates that SAME pivot -> target. On any failure the original
    text is returned unchanged.

    Note: ``default_tokenizer`` and ``default_model`` are accepted for
    interface compatibility with callers but are not used here.

    Bug fix vs. previous version: the two legs each searched the pivot list
    independently, so leg 1 could produce e.g. French text that leg 2 then fed
    into an en->target model; and when target was "en" a non-English pivot
    result was returned untranslated. We now record the pivot chosen in leg 1
    and require leg 2 to start from that exact language.
    """
    try:
        if source_lang == target_lang:  # No translation needed if languages are same
            return text

        pivots = ("en", "fr", "es", "de", "ru")

        # Leg 1: source -> pivot. Track which language inter_text is actually in.
        inter_lang = source_lang
        inter_text = text
        if source_lang != "en":
            for pivot in pivots:
                pair = all_models.get((source_lang, pivot))
                if pair and pair[0] and pair[1]:
                    tok, mod = pair
                    encoded = tok(text, return_tensors="pt", padding=True,
                                  truncation=True, max_length=500)
                    inter_text = tok.decode(mod.generate(**encoded)[0],
                                            skip_special_tokens=True)
                    inter_lang = pivot
                    break

        # Already in the target language (e.g. pivot happened to be the target).
        if inter_lang == target_lang:
            return inter_text

        # Leg 2: pivot -> target, using the SAME pivot language as leg 1.
        pair = all_models.get((inter_lang, target_lang))
        if pair and pair[0] and pair[1]:
            tok, mod = pair
            encoded = tok(inter_text, return_tensors="pt", padding=True,
                          truncation=True, max_length=1000)
            translated = tok.decode(mod.generate(**encoded)[0],
                                    skip_special_tokens=True)
            # Guard against an empty generation: fall back to the original text.
            return translated if translated.strip() else text

        # No usable second-leg model: best effort, return what we have.
        return inter_text
    except Exception as e:
        st.error(f"Translation error in combined_translate: {e}")
        return text

# Class to handle combined translation
class CombinedModel:
    """Duck-types a MarianMTModel's ``generate`` by pivot-translating text.

    Decodes the incoming input_ids with the default tokenizer, routes each
    text through combined_translate, then re-encodes the results.
    """

    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang        # language code of incoming text
        self.target_lang = target_lang        # language code of desired output
        self.default_tokenizer = default_tokenizer  # used to decode/encode batches
        self.default_model = default_model    # passed through to combined_translate

    def generate(self, **kwargs):
        """Return encoded translations for the batch in ``kwargs['input_ids']``.

        Returns an empty tensor when input is missing/empty or on any error.
        """
        try:
            input_ids = kwargs.get('input_ids')
            # Bug fix: ``not input_ids`` raises "truth value is ambiguous" for
            # tensors with more than one element; test emptiness explicitly.
            if input_ids is None or input_ids.numel() == 0:
                return torch.tensor([])
            texts = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
            translated = [
                combined_translate(t, self.source_lang, self.target_lang,
                                   self.default_tokenizer, self.default_model)
                for t in texts
            ]
            # Bug fix: torch.tensor over per-text encodings fails on ragged
            # lengths; batch-encode with padding so rows stack cleanly.
            encoded = self.default_tokenizer(translated, return_tensors="pt",
                                             padding=True, truncation=True,
                                             max_length=500)
            return encoded["input_ids"]
        except Exception as e:
            st.error(f"Generation error in CombinedModel: {e}")
            return torch.tensor([])

# Load appropriate translation model with caching
@st.cache_resource
def load_model(source_lang, target_lang):
    """Return a (tokenizer, model) pair able to translate source -> target.

    Resolution order:
      1. identical languages -> the default model (identity-ish fallback),
      2. a direct opus-mt pair if one loaded successfully,
      3. a CombinedModel wrapper that pivots through an intermediate language.

    Raises whatever the underlying loaders raise after surfacing a UI error.

    Bug fix vs. previous version: when only a pivot route existed, the old
    code returned the source->intermediate pair, so callers would get output
    in the intermediate language rather than the target. We now always fall
    back to CombinedModel, which performs both legs of the pivot.
    """
    try:
        if source_lang == target_lang:
            return _load_default_model()
        direct = all_models.get((source_lang, target_lang))
        if direct and direct[0] and direct[1]:
            return direct
        # No direct model: pivot via CombinedModel (handles both legs).
        default_tokenizer, default_model = _load_default_model()
        return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        raise

# Load default translation model with caching
@st.cache_resource
def _load_default_model():
    """Load the fallback opus-mt en->hi (tokenizer, model) pair.

    Unlike _load_model_pair, a load failure here is fatal: the error is
    shown in the UI and then re-raised.
    """
    name = "Helsinki-NLP/opus-mt-en-hi"
    try:
        return MarianTokenizer.from_pretrained(name), MarianMTModel.from_pretrained(name)
    except Exception as e:
        st.error(f"Failed to load default model: {e}")
        raise

# Translate text with caching
@st.cache_data
def translate(text, source_lang, target_lang):
    """Translate ``text`` from source_lang to target_lang.

    Returns "" for empty input and falls back to the original text when
    translation fails or produces only whitespace.
    """
    try:
        if not text:
            return ""
        tokenizer, model = load_model(source_lang, target_lang)
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        # If tokenization somehow yielded a batch, keep only the first row.
        if encoded['input_ids'].size(0) > 1:
            encoded = {key: val[0].unsqueeze(0) for key, val in encoded.items()}
        # Longer generation budget for scripts that tend to expand in length.
        budget = 1000 if target_lang in ["hi", "zh", "ja"] else 500
        with torch.no_grad():
            output_ids = model.generate(**encoded, max_length=budget, num_beams=4, early_stopping=True)
        decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return decoded if decoded.strip() else text
    except Exception as e:
        st.error(f"Translation failed: {e}")
        return text