Amelia-James commited on
Commit
e74ef23
·
verified ·
1 Parent(s): 7bfe2aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -29
app.py CHANGED
@@ -59,65 +59,57 @@ LANGUAGES = {
59
 
60
  # Function to get the appropriate translation model and tokenizer
61
  def get_translation_model(source_lang, target_lang):
62
- # Use a generic model for translation if specific model is not available
63
  model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
64
- try:
65
- model = MarianMTModel.from_pretrained(model_name)
66
- tokenizer = MarianTokenizer.from_pretrained(model_name)
67
- except OSError:
68
- st.error(f"Model '{model_name}' not found. Please check the model name or use another language pair.")
69
- return None, None
70
  return model, tokenizer
71
 
72
  # Function to translate text
73
  def translate_text(text, source_lang, target_lang):
74
  model, tokenizer = get_translation_model(source_lang, target_lang)
75
- if model is None or tokenizer is None:
76
- return ""
77
  inputs = tokenizer([text], return_tensors="pt", truncation=True)
78
  translated_ids = model.generate(inputs['input_ids'], max_length=1024)
79
  translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
80
  return translated_text
81
 
82
  # Summarization function with multi-language support
83
- def summarize_text(text, input_language="English", output_language="English"):
84
- input_lang_code = LANGUAGES[input_language]
85
- output_lang_code = LANGUAGES[output_language]
86
 
87
  # If the input language is not English, translate to English
88
- if input_lang_code != "en_XX":
89
- text = translate_text(text, input_lang_code, "en_XX")
90
 
91
  # Summarize the text using mBART
92
  inputs = multilingual_summarization_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
93
  summary_ids = multilingual_summarization_model.generate(
94
  inputs['input_ids'],
95
- num_beams=6, # Increase the number of beams for better quality
96
- max_length=1024, # Increase the maximum length
97
- min_length=256, # Set a minimum length for the summary
98
- length_penalty=2.0, # Adjust length penalty to control the length of the summary
99
  early_stopping=True
100
  )
101
  summary = multilingual_summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
102
 
103
- # Translate summary to the output language if needed
104
- if output_lang_code != "en_XX":
105
- summary = translate_text(summary, "en_XX", output_lang_code)
106
 
107
  return summary
108
 
109
  # Streamlit interface
110
  st.title("Multi-Language Text Summarization Tool")
111
- st.write("Enter the text you want to summarize, select the input language, and choose the output language for the summary.")
112
 
113
- text_input = st.text_area("Input Text")
114
- input_language = st.selectbox("Input Language", options=list(LANGUAGES.keys()), index=list(LANGUAGES.keys()).index("English"))
115
- output_language = st.selectbox("Output Language", options=list(LANGUAGES.keys()), index=list(LANGUAGES.keys()).index("English"))
116
 
117
  if st.button("Summarize"):
118
- if text_input:
119
- summary = summarize_text(text_input, input_language, output_language)
120
- st.write("Summary:")
121
  st.write(summary)
122
  else:
123
- st.warning("Please enter some text to summarize.")
 
59
 
60
  # Function to get the appropriate translation model and tokenizer
61
  def get_translation_model(source_lang, target_lang):
 
62
  model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
63
+ model = MarianMTModel.from_pretrained(model_name)
64
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
 
 
 
 
65
  return model, tokenizer
66
 
67
  # Function to translate text
68
  def translate_text(text, source_lang, target_lang):
69
  model, tokenizer = get_translation_model(source_lang, target_lang)
 
 
70
  inputs = tokenizer([text], return_tensors="pt", truncation=True)
71
  translated_ids = model.generate(inputs['input_ids'], max_length=1024)
72
  translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
73
  return translated_text
74
 
75
  # Summarization function with multi-language support
76
+ def summarize_text(text, source_language="English", target_language="English"):
77
+ source_lang_code = LANGUAGES[source_language]
78
+ target_lang_code = LANGUAGES[target_language]
79
 
80
  # If the input language is not English, translate to English
81
+ if source_lang_code != "en_XX":
82
+ text = translate_text(text, source_lang_code, "en_XX")
83
 
84
  # Summarize the text using mBART
85
  inputs = multilingual_summarization_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
86
  summary_ids = multilingual_summarization_model.generate(
87
  inputs['input_ids'],
88
+ num_beams=6, # Increased beams for better quality
89
+ max_length=1500, # Increased maximum length for longer summaries
90
+ min_length=400, # Set a minimum length for the summary
91
+ length_penalty=1.5, # Adjust length penalty to control the length of the summary
92
  early_stopping=True
93
  )
94
  summary = multilingual_summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
95
 
96
+ # Translate summary to the target language if needed
97
+ if target_lang_code != "en_XX":
98
+ summary = translate_text(summary, "en_XX", target_lang_code)
99
 
100
  return summary
101
 
102
  # Streamlit interface
103
  st.title("Multi-Language Text Summarization Tool")
 
104
 
105
+ text = st.text_area("Input Text")
106
+ source_language = st.selectbox("Source Language", options=list(LANGUAGES.keys()), index=list(LANGUAGES.keys()).index("English"))
107
+ target_language = st.selectbox("Target Language", options=list(LANGUAGES.keys()), index=list(LANGUAGES.keys()).index("English"))
108
 
109
  if st.button("Summarize"):
110
+ if text:
111
+ summary = summarize_text(text, source_language, target_language)
112
+ st.subheader("Summary")
113
  st.write(summary)
114
  else:
115
+ st.warning("Please enter text to summarize.")