Krishna086 committed on
Commit
e95d8e4
·
verified ·
1 Parent(s): 0787c83

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +74 -63
translation.py CHANGED
@@ -2,13 +2,14 @@ import streamlit as st
2
  from transformers import MarianTokenizer, MarianMTModel
3
  import torch
4
 
 
5
  LANGUAGES = {
6
  "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
7
  "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
8
  "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
9
  }
10
 
11
- # Cache resource to load a specific translation model pair
12
  @st.cache_resource
13
  def _load_model_pair(source_lang, target_lang):
14
  try:
@@ -16,10 +17,11 @@ def _load_model_pair(source_lang, target_lang):
16
  tokenizer = MarianTokenizer.from_pretrained(model_name)
17
  model = MarianMTModel.from_pretrained(model_name)
18
  return tokenizer, model
19
- except Exception:
 
20
  return None, None
21
 
22
- # Cache resource to load all possible model combinations
23
  @st.cache_resource
24
  def _load_all_models():
25
  models = {}
@@ -29,40 +31,39 @@ def _load_all_models():
29
  models[(src, tgt)] = _load_model_pair(src, tgt)
30
  return models
31
 
 
32
  all_models = _load_all_models()
33
 
34
- # Define combined_translate outside load_model with explicit parameters
35
  def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
36
- with torch.no_grad():
37
- if source_lang != target_lang: # Only translate if languages differ
38
- if source_lang != "en":
39
- src_to_inter_tokenizer, src_to_inter_model = None, None
40
- # Try multiple intermediates, prefer English first
41
- for inter in ["en", "fr", "es", "de", "ru"]: # Prioritize common languages
42
- pair = all_models.get((source_lang, inter))
43
- if pair and pair[0] and pair[1]:
44
- src_to_inter_tokenizer, src_to_inter_model = pair
45
- break
46
- if src_to_inter_tokenizer and src_to_inter_model:
47
- inter_text = src_to_inter_tokenizer.decode(src_to_inter_model.generate(**src_to_inter_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
48
- else:
49
- inter_text = text # Fallback to input if no path found
50
- else:
51
- inter_text = text
52
- if target_lang != "en":
53
- inter_to_tgt_tokenizer, inter_to_tgt_model = None, None
54
- for inter in ["en", "fr", "es", "de", "ru"]:
55
- pair = all_models.get((inter, target_lang))
56
- if pair and pair[0] and pair[1]:
57
- inter_to_tgt_tokenizer, inter_to_tgt_model = pair
58
- break
59
- if inter_to_tgt_tokenizer and inter_to_tgt_model:
60
- translated = inter_to_tgt_tokenizer.decode(inter_to_tgt_model.generate(**inter_to_tgt_tokenizer(inter_text, return_tensors="pt", padding=True, truncation=True, max_length=1000))[0], skip_special_tokens=True)
61
- return translated if translated.strip() else text
62
- return inter_text
63
  return text
64
 
65
- # Class to handle combined translation through multiple intermediates
66
  class CombinedModel:
67
  def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
68
  self.source_lang = source_lang
@@ -71,47 +72,57 @@ class CombinedModel:
71
  self.default_model = default_model
72
 
73
  def generate(self, **kwargs):
74
- input_ids = kwargs.get('input_ids')
75
- if not input_ids or input_ids.size(0) == 0:
 
 
 
 
 
 
 
76
  return torch.tensor([])
77
- inputs = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
78
- translated = [combined_translate(text, self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for text in inputs]
79
- return torch.tensor([self.default_tokenizer.encode(t, return_tensors="pt", padding=True, truncation=True, max_length=500)[0] for t in translated])
80
 
81
- # Function to load appropriate translation model with optimized caching
82
  @st.cache_resource
83
  def load_model(source_lang, target_lang):
84
- if source_lang == target_lang:
85
- return _load_default_model()
86
- model_key = (source_lang, target_lang)
87
- tokenizer_model_pair = all_models.get(model_key)
88
- if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
89
- return tokenizer_model_pair
90
- # Try to find the best path through any intermediate language
91
- for inter in LANGUAGES.keys():
92
- if inter != source_lang and inter != target_lang:
93
- pair1 = all_models.get((source_lang, inter))
94
- pair2 = all_models.get((inter, target_lang))
95
- if pair1 and pair1[0] and pair1[1] and pair2 and pair2[0] and pair2[1]:
96
- return pair1
97
- # Fallback to default model with CombinedModel
98
- default_tokenizer, default_model = _load_default_model()
99
- return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
 
 
100
 
101
- # Cache resource to load default translation model
102
  @st.cache_resource
103
  def _load_default_model():
104
- model_name = "Helsinki-NLP/opus-mt-en-hi"
105
- tokenizer = MarianTokenizer.from_pretrained(model_name)
106
- model = MarianMTModel.from_pretrained(model_name)
107
- return tokenizer, model
 
 
 
 
108
 
109
- # Cache translation results to improve speed
110
  @st.cache_data
111
  def translate(text, source_lang, target_lang):
112
- if not text:
113
- return ""
114
  try:
 
 
115
  tokenizer, model = load_model(source_lang, target_lang)
116
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
117
  if inputs['input_ids'].size(0) > 1:
@@ -121,5 +132,5 @@ def translate(text, source_lang, target_lang):
121
  result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
122
  return result if result.strip() else text
123
  except Exception as e:
124
- st.error(f"Translation error: {e}")
125
  return text
 
2
  from transformers import MarianTokenizer, MarianMTModel
3
  import torch
4
 
5
+ # Define supported languages
6
  LANGUAGES = {
7
  "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
8
  "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
9
  "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
10
  }
11
 
12
+ # Load a specific translation model pair with caching
13
  @st.cache_resource
14
  def _load_model_pair(source_lang, target_lang):
15
  try:
 
17
  tokenizer = MarianTokenizer.from_pretrained(model_name)
18
  model = MarianMTModel.from_pretrained(model_name)
19
  return tokenizer, model
20
+ except Exception as e:
21
+ st.error(f"Failed to load model pair ({source_lang} to {target_lang}): {e}")
22
  return None, None
23
 
24
+ # Load all possible model combinations with caching
25
  @st.cache_resource
26
  def _load_all_models():
27
  models = {}
 
31
  models[(src, tgt)] = _load_model_pair(src, tgt)
32
  return models
33
 
34
+ # Preload all models
35
  all_models = _load_all_models()
36
 
37
+ # Perform combined translation through intermediate languages
38
  def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
39
+ try:
40
+ if source_lang == target_lang: # No translation needed if languages are same
41
+ return text
42
+ if source_lang != "en":
43
+ src_to_inter_tokenizer, src_to_inter_model = None, None
44
+ for inter in ["en", "fr", "es", "de", "ru"]: # Try multiple intermediates
45
+ pair = all_models.get((source_lang, inter))
46
+ if pair and pair[0] and pair[1]:
47
+ src_to_inter_tokenizer, src_to_inter_model = pair
48
+ break
49
+ inter_text = src_to_inter_tokenizer.decode(src_to_inter_model.generate(**src_to_inter_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True) if src_to_inter_tokenizer else text
50
+ else:
51
+ inter_text = text
52
+ if target_lang != "en":
53
+ inter_to_tgt_tokenizer, inter_to_tgt_model = None, None
54
+ for inter in ["en", "fr", "es", "de", "ru"]:
55
+ pair = all_models.get((inter, target_lang))
56
+ if pair and pair[0] and pair[1]:
57
+ inter_to_tgt_tokenizer, inter_to_tgt_model = pair
58
+ break
59
+ translated = inter_to_tgt_tokenizer.decode(inter_to_tgt_model.generate(**inter_to_tgt_tokenizer(inter_text, return_tensors="pt", padding=True, truncation=True, max_length=1000))[0], skip_special_tokens=True) if inter_to_tgt_tokenizer else inter_text
60
+ return translated if translated.strip() else text
61
+ return inter_text
62
+ except Exception as e:
63
+ st.error(f"Translation error in combined_translate: {e}")
 
 
64
  return text
65
 
66
+ # Class to handle combined translation
67
  class CombinedModel:
68
  def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
69
  self.source_lang = source_lang
 
72
  self.default_model = default_model
73
 
74
  def generate(self, **kwargs):
75
+ try:
76
+ input_ids = kwargs.get('input_ids')
77
+ if not input_ids or input_ids.size(0) == 0:
78
+ return torch.tensor([])
79
+ inputs = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
80
+ translated = [combined_translate(text, self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for text in inputs]
81
+ return torch.tensor([self.default_tokenizer.encode(t, return_tensors="pt", padding=True, truncation=True, max_length=500)[0] for t in translated])
82
+ except Exception as e:
83
+ st.error(f"Generation error in CombinedModel: {e}")
84
  return torch.tensor([])
 
 
 
85
 
86
+ # Load appropriate translation model with caching
87
  @st.cache_resource
88
  def load_model(source_lang, target_lang):
89
+ try:
90
+ if source_lang == target_lang:
91
+ return _load_default_model()
92
+ model_key = (source_lang, target_lang)
93
+ tokenizer_model_pair = all_models.get(model_key)
94
+ if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
95
+ return tokenizer_model_pair
96
+ for inter in LANGUAGES.keys():
97
+ if inter != source_lang and inter != target_lang:
98
+ pair1 = all_models.get((source_lang, inter))
99
+ pair2 = all_models.get((inter, target_lang))
100
+ if pair1 and pair1[0] and pair1[1] and pair2 and pair2[0] and pair2[1]:
101
+ return pair1
102
+ default_tokenizer, default_model = _load_default_model()
103
+ return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
104
+ except Exception as e:
105
+ st.error(f"Failed to load model: {e}")
106
+ raise
107
 
108
+ # Load default translation model with caching
109
  @st.cache_resource
110
  def _load_default_model():
111
+ try:
112
+ model_name = "Helsinki-NLP/opus-mt-en-hi"
113
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
114
+ model = MarianMTModel.from_pretrained(model_name)
115
+ return tokenizer, model
116
+ except Exception as e:
117
+ st.error(f"Failed to load default model: {e}")
118
+ raise
119
 
120
+ # Translate text with caching
121
  @st.cache_data
122
  def translate(text, source_lang, target_lang):
 
 
123
  try:
124
+ if not text:
125
+ return ""
126
  tokenizer, model = load_model(source_lang, target_lang)
127
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
128
  if inputs['input_ids'].size(0) > 1:
 
132
  result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
133
  return result if result.strip() else text
134
  except Exception as e:
135
+ st.error(f"Translation failed: {e}")
136
  return text