Update app.py
Browse files
app.py
CHANGED
@@ -59,65 +59,57 @@ LANGUAGES = {
|
|
59 |
|
60 |
# Function to get the appropriate translation model and tokenizer
|
61 |
def get_translation_model(source_lang, target_lang):
|
62 |
-
# Use a generic model for translation if specific model is not available
|
63 |
model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
|
64 |
-
|
65 |
-
|
66 |
-
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
67 |
-
except OSError:
|
68 |
-
st.error(f"Model '{model_name}' not found. Please check the model name or use another language pair.")
|
69 |
-
return None, None
|
70 |
return model, tokenizer
|
71 |
|
72 |
# Function to translate text
|
73 |
def translate_text(text, source_lang, target_lang):
    """Translate ``text`` using the model for ``source_lang`` -> ``target_lang``.

    Returns an empty string when ``get_translation_model`` could not supply
    both a model and a tokenizer (it yields ``None`` for the missing part).
    Otherwise the text is tokenized as a one-element batch with truncation,
    run through ``generate`` with ``max_length=1024``, and the first output
    sequence is decoded with special tokens removed.
    """
    mt_model, mt_tokenizer = get_translation_model(source_lang, target_lang)

    # Nothing to translate with: signal failure with an empty result.
    if any(part is None for part in (mt_model, mt_tokenizer)):
        return ""

    encoded = mt_tokenizer([text], return_tensors="pt", truncation=True)
    output_ids = mt_model.generate(encoded['input_ids'], max_length=1024)
    return mt_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
81 |
|
82 |
# Summarization function with multi-language support
|
83 |
-
def summarize_text(text,
|
84 |
-
|
85 |
-
|
86 |
|
87 |
# If the input language is not English, translate to English
|
88 |
-
if
|
89 |
-
text = translate_text(text,
|
90 |
|
91 |
# Summarize the text using mBART
|
92 |
inputs = multilingual_summarization_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
|
93 |
summary_ids = multilingual_summarization_model.generate(
|
94 |
inputs['input_ids'],
|
95 |
-
num_beams=6, #
|
96 |
-
max_length=
|
97 |
-
min_length=
|
98 |
-
length_penalty=
|
99 |
early_stopping=True
|
100 |
)
|
101 |
summary = multilingual_summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
102 |
|
103 |
-
# Translate summary to the
|
104 |
-
if
|
105 |
-
summary = translate_text(summary, "en_XX",
|
106 |
|
107 |
return summary
|
108 |
|
109 |
# Streamlit interface
|
110 |
st.title("Multi-Language Text Summarization Tool")
|
111 |
-
st.write("Enter the text you want to summarize, select the input language, and choose the output language for the summary.")
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
|
117 |
if st.button("Summarize"):
|
118 |
-
if
|
119 |
-
summary = summarize_text(
|
120 |
-
st.
|
121 |
st.write(summary)
|
122 |
else:
|
123 |
-
st.warning("Please enter
|
|
|
59 |
|
60 |
# Function to get the appropriate translation model and tokenizer
|
61 |
def get_translation_model(source_lang, target_lang):
    """Load the Helsinki-NLP OPUS-MT model and tokenizer for a language pair.

    Callers in this file pass mBART-style codes such as ``"en_XX"``, while
    OPUS-MT checkpoints are named with bare language codes
    (``Helsinki-NLP/opus-mt-en-fr``).  A two-letter upper-case region suffix
    is therefore stripped from each code before the checkpoint name is
    built; codes without such a suffix are used unchanged.

    Args:
        source_lang: Source language code, e.g. ``"fr"`` or ``"fr_XX"``.
        target_lang: Target language code, e.g. ``"en"`` or ``"en_XX"``.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded via ``from_pretrained``.

    Raises:
        Whatever ``from_pretrained`` raises when the checkpoint cannot be
        loaded (transformers documents ``OSError`` for an unknown name).
    """
    src = _strip_region_suffix(source_lang)
    tgt = _strip_region_suffix(target_lang)
    model_name = f"Helsinki-NLP/opus-mt-{src}-{tgt}"
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer


def _strip_region_suffix(lang_code):
    """Return the bare language part of an mBART-style code (``fr_XX`` -> ``fr``)."""
    language, separator, region = lang_code.partition("_")
    # Only rewrite codes shaped like "<lowercase>_<two uppercase letters>";
    # anything else is passed through untouched.
    if separator and language.islower() and len(region) == 2 and region.isupper():
        return language
    return lang_code
|
66 |
|
67 |
# Function to translate text
|
68 |
def translate_text(text, source_lang, target_lang):
    """Translate ``text`` using the model for ``source_lang`` -> ``target_lang``.

    The text is tokenized as a one-element batch with truncation, run through
    ``generate`` with ``max_length=1024``, and the first output sequence is
    decoded with special tokens removed.
    """
    mt_model, mt_tokenizer = get_translation_model(source_lang, target_lang)
    encoded = mt_tokenizer([text], return_tensors="pt", truncation=True)
    output_ids = mt_model.generate(encoded['input_ids'], max_length=1024)
    return mt_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
74 |
|
75 |
# Summarization function with multi-language support
|
76 |
+
def summarize_text(text, source_language="English", target_language="English"):
    """Summarize ``text``, translating through English when needed.

    The language names are looked up in ``LANGUAGES`` to obtain their codes.
    Non-English input is first translated to English, the English text is
    summarized with the multilingual summarization model, and the summary is
    translated to the target language when that is not English.

    Generation lengths are bounded so they stay valid for any input:
    ``max_length`` is capped at the model's ``max_position_embeddings`` (when
    the config exposes it), and ``min_length`` never exceeds half the number
    of input tokens, so short inputs are not forced into a 400-token summary.

    Args:
        text: The text to summarize.
        source_language: Key of ``LANGUAGES`` naming the input language.
        target_language: Key of ``LANGUAGES`` naming the output language.

    Returns:
        The summary string in the target language.

    Raises:
        KeyError: If either language name is not a key of ``LANGUAGES``.
    """
    source_lang_code = LANGUAGES[source_language]
    target_lang_code = LANGUAGES[target_language]

    # If the input language is not English, translate to English
    if source_lang_code != "en_XX":
        text = translate_text(text, source_lang_code, "en_XX")

    # Summarize the text using mBART
    inputs = multilingual_summarization_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    input_ids = inputs['input_ids']

    # Never ask for more tokens than the model has positions for.
    position_limit = getattr(multilingual_summarization_model.config, "max_position_embeddings", None)
    max_length = 1500 if position_limit is None else min(1500, position_limit)
    # A summary should not be forced to be longer than (half of) its source.
    min_length = max(0, min(400, input_ids.shape[-1] // 2, max_length - 1))

    summary_ids = multilingual_summarization_model.generate(
        input_ids,
        attention_mask=inputs.get('attention_mask'),
        num_beams=6,  # Increased beams for better quality
        max_length=max_length,
        min_length=min_length,
        length_penalty=1.5,  # Adjust length penalty to control the length of the summary
        early_stopping=True
    )
    summary = multilingual_summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # Translate summary to the target language if needed
    if target_lang_code != "en_XX":
        summary = translate_text(summary, "en_XX", target_lang_code)

    return summary
|
101 |
|
102 |
# Streamlit interface
|
103 |
st.title("Multi-Language Text Summarization Tool")

text = st.text_area("Input Text")

# Build the option list once; both selectboxes default to English.
language_names = list(LANGUAGES)
default_language_index = language_names.index("English")
source_language = st.selectbox("Source Language", options=language_names, index=default_language_index)
target_language = st.selectbox("Target Language", options=language_names, index=default_language_index)

if st.button("Summarize"):
    # Treat whitespace-only input the same as empty input.
    if text and text.strip():
        try:
            summary = summarize_text(text, source_language, target_language)
        except OSError as err:
            # Loading a translation checkpoint fails with OSError when no
            # model exists for the selected pair; report it instead of
            # letting the app show a raw traceback.
            st.error(f"Could not load a model for {source_language} -> {target_language}: {err}")
        else:
            st.subheader("Summary")
            st.write(summary)
    else:
        st.warning("Please enter text to summarize.")
|