import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch

LANGUAGES = {
    "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese"),
}

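# Illustrative sketch (not part of the original app; the helper name
# _display_name is an assumption): one way the (native, English) pairs above
# could be rendered for a language picker, e.g. "Français (French)" for "fr".
def _display_name(code):
    native, english = LANGUAGES[code]
    return f"{native} ({english})"
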
# Cache resource to load a specific translation model pair.
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    try:
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception:
        # The pair is unavailable on the Hub or failed to download;
        # callers treat (None, None) as "no direct model".
        return None, None

# Cache resource to eagerly load every source/target model combination.
# Note: this attempts to fetch all 72 Helsinki-NLP pairs up front; pairs that
# do not exist on the Hub are cached as (None, None).
@st.cache_resource
def _load_all_models():
    models = {}
    for src in LANGUAGES:
        for tgt in LANGUAGES:
            if src != tgt:
                models[(src, tgt)] = _load_model_pair(src, tgt)
    return models

all_models = _load_all_models()

# Translate through an intermediate (pivot) language when no direct model exists.
# The same intermediate must serve both legs, otherwise the chain breaks (e.g. a
# hi->fr request must not run hi->en for the first leg and es->fr for the second).
def combined_translate(text, source_lang, target_lang):
    if source_lang == target_lang:  # Only translate if languages differ
        return text
    with torch.no_grad():
        for inter in ["en", "fr", "es", "de", "ru"]:  # Prefer English, then other well-covered pivots
            first = all_models.get((source_lang, inter)) if source_lang != inter else None
            second = all_models.get((inter, target_lang)) if target_lang != inter else None
            first_ok = source_lang == inter or (first and first[0] and first[1])
            second_ok = target_lang == inter or (second and second[0] and second[1])
            if not (first_ok and second_ok):
                continue
            inter_text = text
            if source_lang != inter:
                tokenizer, model = first
                ids = model.generate(**tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))
                inter_text = tokenizer.decode(ids[0], skip_special_tokens=True)
            if target_lang == inter:
                return inter_text if inter_text.strip() else text
            tokenizer, model = second
            ids = model.generate(**tokenizer(inter_text, return_tensors="pt", padding=True, truncation=True, max_length=1000))
            translated = tokenizer.decode(ids[0], skip_special_tokens=True)
            return translated if translated.strip() else text
    return text  # No pivot path found; fall back to the input unchanged

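# Minimal sketch (the helper name _translate_leg and its shape are assumptions
# for illustration, not part of the original app): one leg of the pivot above,
# run directly against a single cached opus-mt pair.
def _translate_leg(text, src, tgt, max_length=500):
    pair = all_models.get((src, tgt))
    if not pair or not pair[0] or not pair[1]:
        return None  # No direct model for this leg
    tokenizer, model = pair
    with torch.no_grad():
        ids = model.generate(**tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_length))
    return tokenizer.decode(ids[0], skip_special_tokens=True)
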
# Duck-typed stand-in for MarianMTModel that routes generate() through
# combined_translate for language pairs without a direct model.
class CombinedModel:
    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def generate(self, **kwargs):
        # Extra generation kwargs (num_beams, max_length, ...) are accepted but ignored.
        input_ids = kwargs.get('input_ids')
        if input_ids is None or input_ids.numel() == 0:  # Tensors have no plain truth value; test explicitly
            return torch.tensor([])
        texts = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        translated = [combined_translate(t, self.source_lang, self.target_lang) for t in texts]
        encoded = [self.default_tokenizer.encode(t, return_tensors="pt", truncation=True, max_length=500)[0] for t in translated]
        # Pad to a common length so the variable-length outputs stack into one batch tensor.
        return torch.nn.utils.rnn.pad_sequence(encoded, batch_first=True, padding_value=self.default_tokenizer.pad_token_id)

# Load the translation model for a pair: the direct model if available,
# otherwise a CombinedModel that pivots at generate() time.
@st.cache_resource
def load_model(source_lang, target_lang):
    if source_lang == target_lang:
        return _load_default_model()
    pair = all_models.get((source_lang, target_lang))
    if pair and pair[0] and pair[1]:
        return pair
    # No direct model: fall back to pivot translation. (Returning just the
    # source->intermediate half of a path would emit the wrong language.)
    default_tokenizer, default_model = _load_default_model()
    return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)

# Cache resource to load the default translation model (en-hi); its tokenizer
# also serves as the encoder/decoder inside CombinedModel.
@st.cache_resource
def _load_default_model():
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model

# Cache translation results to improve speed.
@st.cache_data
def translate(text, source_lang, target_lang):
    if not text:
        return ""
    if source_lang == target_lang:
        return text  # Nothing to translate
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        if inputs['input_ids'].size(0) > 1:
            # Keep only the first sequence if the tokenizer produced a batch.
            inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
        with torch.no_grad():
            translated_ids = model.generate(
                **inputs,
                max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500,
                num_beams=4,
                early_stopping=True,
            )
        result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return result if result.strip() else text
    except Exception as e:
        st.error(f"Translation error: {e}")
        return text
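
# Minimal usage sketch (assumptions: run as a plain script with network access so
# the Helsinki-NLP checkpoints can download; the French output is indicative only;
# Streamlit's cache decorators fall back to plain execution outside `streamlit run`).
if __name__ == "__main__":
    print(translate("Hello world", "en", "fr"))  # e.g. "Bonjour le monde"
    print(translate("Hello world", "en", "en"))  # Same language: returned unchanged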