Krishna086 committed
Commit dd34156 · verified · 1 Parent(s): ff5aa1c

Update translation.py

Files changed (1)
  1. translation.py +10 -18
translation.py CHANGED
@@ -1,4 +1,3 @@
-import streamlit as st
 from transformers import MarianTokenizer, MarianMTModel
 import torch
 
@@ -10,7 +9,7 @@ LANGUAGES = {
 
 @st.cache_resource
 def _load_default_model():
-    model_name = "Helsinki-NLP/opus-mt-en-hi"
+    model_name = "Helsinki-NLP/opus-mt-en-hi"
     tokenizer = MarianTokenizer.from_pretrained(model_name)
     model = MarianMTModel.from_pretrained(model_name)
     return tokenizer, model
@@ -24,24 +23,17 @@ def load_model(source_lang, target_lang):
         tokenizer = MarianTokenizer.from_pretrained(model_name)
         model = MarianMTModel.from_pretrained(model_name)
         return tokenizer, model
-    except Exception as e:
-        st.warning(f"No direct model for {source_lang} to {target_lang}. Using en-hi fallback.")
+    except Exception:
         return _load_default_model()
 
-@st.cache_data(ttl=3600)
-def translate_cached(text, source_lang, target_lang):
-    tokenizer, model = load_model(source_lang, target_lang)
-    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
-    with torch.no_grad():
-        translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
-    translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
-    return translated_text if translated_text.strip() and len(translated_text.split()) >= 2 else text
-
 def translate(text, source_lang, target_lang):
     if not text:
-        return "No text provided."
+        return ""
     try:
-        return translate_cached(text, source_lang, target_lang)
-    except Exception as e:
-        st.error(f"Translation error: {str(e)}. Using input as fallback.")
-        return text
+        tokenizer, model = load_model(source_lang, target_lang)
+        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
+        with torch.no_grad():
+            translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
+        return tokenizer.decode(translated[0], skip_special_tokens=True)
+    except Exception:
+        return text
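
For context, a minimal, self-contained sketch of the translation path the updated file converges on: try the direct MarianMT pair, fall back to the en-hi checkpoint, and translate inside torch.no_grad(). The body of load_model() and the contents of LANGUAGES are not visible in this diff, so the f-string checkpoint name and the _load_pair() helper below are assumptions for illustration, not the repository's actual code; the Streamlit caching decorators are omitted here.

# Sketch only: mirrors the post-commit flow under the assumptions stated above.
from transformers import MarianTokenizer, MarianMTModel
import torch

def _load_pair(source_lang, target_lang):
    # Hypothetical stand-in for load_model(): try the direct OPUS-MT pair,
    # fall back to en-hi exactly as _load_default_model() does in the diff.
    try:
        name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        return MarianTokenizer.from_pretrained(name), MarianMTModel.from_pretrained(name)
    except Exception:
        name = "Helsinki-NLP/opus-mt-en-hi"
        return MarianTokenizer.from_pretrained(name), MarianMTModel.from_pretrained(name)

def translate(text, source_lang, target_lang):
    if not text:
        return ""                      # the commit now returns an empty string, not a message
    try:
        tokenizer, model = _load_pair(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        with torch.no_grad():          # inference only, no gradients needed
            out = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
        return tokenizer.decode(out[0], skip_special_tokens=True)
    except Exception:
        return text                    # silent fallback: echo the input on any failure

if __name__ == "__main__":
    print(translate("Hello, how are you?", "en", "hi"))

Note that the commit also drops the @st.cache_data translate_cached() layer, so with this shape every call re-runs generation; only model loading remains cached via @st.cache_resource.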