import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch
LANGUAGES = {
    "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese"),
}

# Cache a single translation model pair; return (None, None) when the
# Helsinki-NLP checkpoint for this direction does not exist
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    try:
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception:
        return None, None

# Cache the full grid of language-pair models. Pairs without a published
# checkpoint are stored as (None, None); note this eagerly downloads every
# available pair on first run, which is slow but happens only once per session.
@st.cache_resource
def _load_all_models():
    models = {}
    for src in LANGUAGES.keys():
        for tgt in LANGUAGES.keys():
            if src != tgt:
                models[(src, tgt)] = _load_model_pair(src, tgt)
    return models

all_models = _load_all_models()

# Pivot translation, defined outside load_model with explicit parameters.
# A single pivot language is chosen for both hops; picking the two hops
# independently could feed the second model text in a language it does not
# expect (e.g. source -> fr followed by en -> target).
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    with torch.no_grad():
        if source_lang == target_lang:  # Only translate if languages differ
            return text
        for inter in ["en", "fr", "es", "de", "ru"]:  # Prefer well-covered pivots
            # A hop is a no-op when the language on that side already is the pivot
            src_pair = None if source_lang == inter else all_models.get((source_lang, inter))
            tgt_pair = None if target_lang == inter else all_models.get((inter, target_lang))
            src_ok = source_lang == inter or (src_pair and src_pair[0] and src_pair[1])
            tgt_ok = target_lang == inter or (tgt_pair and tgt_pair[0] and tgt_pair[1])
            if not (src_ok and tgt_ok):
                continue
            inter_text = text
            if source_lang != inter:
                tok, mod = src_pair
                inter_text = tok.decode(
                    mod.generate(**tok(text, return_tensors="pt", padding=True,
                                       truncation=True, max_length=500))[0],
                    skip_special_tokens=True)
            if target_lang == inter:
                return inter_text if inter_text.strip() else text
            tok, mod = tgt_pair
            translated = tok.decode(
                mod.generate(**tok(inter_text, return_tensors="pt", padding=True,
                                   truncation=True, max_length=1000))[0],
                skip_special_tokens=True)
            return translated if translated.strip() else text
        return text  # No pivot path found; return the input unchanged
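
# Illustrative walk-through (assuming no direct opus-mt-fr-hi checkpoint is
# published): combined_translate("Bonjour", "fr", "hi", ...) tries "en" first,
# hops fr -> en with opus-mt-fr-en, then en -> hi with opus-mt-en-hi.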

# Wrapper that mimics MarianMTModel.generate by decoding the input batch,
# pivot-translating each text, and re-encoding the results
class CombinedModel:
    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def generate(self, **kwargs):
        input_ids = kwargs.get("input_ids")
        # `not input_ids` raises for multi-element tensors, so test explicitly
        if input_ids is None or input_ids.numel() == 0:
            return torch.tensor([])
        texts = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        translated = [combined_translate(t, self.source_lang, self.target_lang,
                                         self.default_tokenizer, self.default_model)
                      for t in texts]
        # Re-encode as one padded batch; torch.tensor() over variable-length
        # rows would raise
        return self.default_tokenizer(translated, return_tensors="pt", padding=True,
                                      truncation=True, max_length=500)["input_ids"]

# Load the appropriate model for a language pair: the direct Helsinki-NLP
# pair when it exists, otherwise a CombinedModel that pivots
@st.cache_resource
def load_model(source_lang, target_lang):
    if source_lang == target_lang:
        return _load_default_model()
    pair = all_models.get((source_lang, target_lang))
    if pair and pair[0] and pair[1]:
        return pair
    # No direct model. Returning only the source->intermediate pair here
    # would leave the output in the intermediate language, so delegate both
    # hops to CombinedModel instead.
    default_tokenizer, default_model = _load_default_model()
    return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)

# Cache resource to load the default translation model
@st.cache_resource
def _load_default_model():
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model

# Cache translation results so repeated requests return instantly
@st.cache_data
def translate(text, source_lang, target_lang):
    if not text:
        return ""
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        if inputs['input_ids'].size(0) > 1:  # Defensive: keep only the first sequence
            inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
        with torch.no_grad():
            # Allow longer outputs for scripts that tend to expand (hi, zh, ja)
            translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500, num_beams=4, early_stopping=True)
        result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return result if result.strip() else text
    except Exception as e:
        st.error(f"Translation error: {e}")
        return text
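
# --- Usage sketch (hypothetical) ---
# A minimal Streamlit front end wired to translate(). The _demo_ui name,
# widget labels, and default selections are illustrative assumptions, not
# the app's actual UI.
def _demo_ui():
    codes = list(LANGUAGES.keys())
    src = st.selectbox("From", codes, format_func=lambda c: LANGUAGES[c][1], index=codes.index("en"))
    tgt = st.selectbox("To", codes, format_func=lambda c: LANGUAGES[c][1], index=codes.index("hi"))
    text = st.text_area("Text to translate")
    if st.button("Translate") and text:
        st.write(translate(text, src, tgt))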