# Hugging Face Space — reported status: Runtime error
import streamlit as st
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    MBart50Tokenizer,
    MBartForConditionalGeneration,
    MBartTokenizer,
)
# Load the multilingual summarization model and tokenizer once per server process.
@st.cache_resource
def _load_summarization_resources():
    """Load the mBART-50 checkpoint together with its matching tokenizer.

    Streamlit re-executes the script on every widget interaction;
    ``st.cache_resource`` keeps the loaded objects alive across those reruns
    so the checkpoint is not reloaded each time.

    Returns:
        A ``(model, tokenizer)`` tuple for ``facebook/mbart-large-50``.
    """
    checkpoint = "facebook/mbart-large-50"
    model = MBartForConditionalGeneration.from_pretrained(checkpoint)
    # mBART-50 checkpoints use their own tokenizer class; ``MBartTokenizer``
    # targets the older 25-language checkpoints.
    tokenizer = MBart50Tokenizer.from_pretrained(checkpoint)
    return model, tokenizer


(
    multilingual_summarization_model,
    multilingual_summarization_tokenizer,
) = _load_summarization_resources()
# Display name -> language code, in the order offered by the UI selectboxes.
# "en_XX" is the pivot code that summarize_text compares against.
# NOTE(review): several values (e.g. "pt_PT", "nl_NL", "bn_BD") differ from the
# codes published for mBART-50 ("pt_XX", "nl_XX", "bn_IN") — confirm before
# handing these values to the mBART tokenizer.
_LANGUAGE_CODE_PAIRS = (
    ("English", "en_XX"),
    ("French", "fr_XX"),
    ("Spanish", "es_XX"),
    ("German", "de_DE"),
    ("Chinese", "zh_CN"),
    ("Russian", "ru_RU"),
    ("Arabic", "ar_AR"),
    ("Portuguese", "pt_PT"),
    ("Hindi", "hi_IN"),
    ("Italian", "it_IT"),
    ("Japanese", "ja_XX"),
    ("Korean", "ko_KR"),
    ("Dutch", "nl_NL"),
    ("Polish", "pl_PL"),
    ("Turkish", "tr_TR"),
    ("Swedish", "sv_SE"),
    ("Greek", "el_EL"),
    ("Finnish", "fi_FI"),
    ("Hungarian", "hu_HU"),
    ("Danish", "da_DK"),
    ("Norwegian", "no_NO"),
    ("Czech", "cs_CZ"),
    ("Romanian", "ro_RO"),
    ("Thai", "th_TH"),
    ("Hebrew", "he_IL"),
    ("Vietnamese", "vi_VN"),
    ("Indonesian", "id_ID"),
    ("Malay", "ms_MY"),
    ("Bengali", "bn_BD"),
    ("Ukrainian", "uk_UA"),
    ("Urdu", "ur_PK"),
    ("Swahili", "sw_KE"),
    ("Serbian", "sr_SR"),
    ("Croatian", "hr_HR"),
    ("Slovak", "sk_SK"),
    ("Lithuanian", "lt_LT"),
    ("Latvian", "lv_LV"),
    ("Estonian", "et_EE"),
    ("Bulgarian", "bg_BG"),
    ("Macedonian", "mk_MK"),
    ("Albanian", "sq_AL"),
    ("Georgian", "ka_GE"),
    ("Armenian", "hy_AM"),
    ("Kazakh", "kk_KZ"),
    ("Uzbek", "uz_UZ"),
    ("Tajik", "tg_TJ"),
    ("Kyrgyz", "ky_KG"),
    ("Turkmen", "tk_TM"),
)
LANGUAGES = dict(_LANGUAGE_CODE_PAIRS)
# Function to get the appropriate translation model and tokenizer
def _to_opus_mt_code(lang_code):
    """Reduce an mBART-style ``xx_YY`` code to the bare ``xx`` language code."""
    language, separator, region = lang_code.partition("_")
    # Only strip a single upper-case region suffix ("fr_XX", "de_DE"); any
    # other spelling is passed through untouched so existing callers that
    # already supply OPUS-MT identifiers keep working.
    if separator and language.islower() and region.isupper() and "_" not in region:
        return language
    return lang_code


def get_translation_model(source_lang, target_lang):
    """Load the Helsinki-NLP OPUS-MT model and tokenizer for a language pair.

    OPUS-MT checkpoints are named with bare language codes (for example
    ``Helsinki-NLP/opus-mt-fr-en``), while callers in this file pass
    mBART-style codes such as ``fr_XX``. The region suffix is therefore
    dropped before the checkpoint name is built.

    Args:
        source_lang: Source language code, e.g. ``"fr"`` or ``"fr_XX"``.
        target_lang: Target language code, e.g. ``"en"`` or ``"en_XX"``.

    Returns:
        A ``(model, tokenizer)`` tuple for the requested pair.

    Raises:
        Whatever ``from_pretrained`` raises when the checkpoint cannot be
        loaded (an ``OSError`` in current transformers releases), e.g. when
        no OPUS-MT checkpoint is published for the pair.
    """
    source_code = _to_opus_mt_code(source_lang)
    target_code = _to_opus_mt_code(target_lang)
    model_name = f"Helsinki-NLP/opus-mt-{source_code}-{target_code}"
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer
# Function to translate text
def translate_text(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    The model and tokenizer for the pair come from ``get_translation_model``.
    The input is tokenized as a one-element batch with truncation enabled,
    and only the first generated sequence is decoded.

    Args:
        text: The string to translate.
        source_lang: Language code of ``text``.
        target_lang: Language code to translate into.

    Returns:
        The decoded translation with special tokens removed.
    """
    marian_model, marian_tokenizer = get_translation_model(source_lang, target_lang)
    encoded = marian_tokenizer([text], return_tensors="pt", truncation=True)
    generated = marian_model.generate(encoded["input_ids"], max_length=1024)
    first_sequence = generated[0]
    return marian_tokenizer.decode(first_sequence, skip_special_tokens=True)
# Summarization function with multi-language support
def summarize_text(text, source_language="English", target_language="English"):
    """Summarize ``text``, pivoting through English when necessary.

    Non-English input is first translated to English, the English text is
    summarized with the module-level mBART model, and the summary is
    translated into the target language when that is not English.

    Args:
        text: The text to summarize.
        source_language: Display name of the input language; must be a key
            of ``LANGUAGES``.
        target_language: Display name of the output language; must be a key
            of ``LANGUAGES``.

    Returns:
        The summary as a string, or ``""`` when ``text`` is empty or
        contains only whitespace.

    Raises:
        KeyError: If either language name is not a key of ``LANGUAGES``.
    """
    pivot_code = "en_XX"
    source_lang_code = LANGUAGES[source_language]
    target_lang_code = LANGUAGES[target_language]

    # Nothing to summarize: avoid running generation on an empty prompt.
    if not text or not text.strip():
        return ""

    # If the input language is not English, translate to English
    if source_lang_code != pivot_code:
        text = translate_text(text, source_lang_code, pivot_code)

    model = multilingual_summarization_model
    tokenizer = multilingual_summarization_tokenizer

    # Summarize the text using mBART. The tokenizer tags the input with the
    # source-language code, so set it explicitly to the pivot language.
    tokenizer.src_lang = pivot_code
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    input_length = inputs["input_ids"].shape[1]

    # The decoder cannot generate past the model's position-embedding table,
    # so the requested 1500-token ceiling is capped to that limit.
    position_limit = getattr(model.config, "max_position_embeddings", 1024)
    max_length = min(1500, position_limit)
    # A fixed 400-token floor forces the output to be longer than a short
    # input; never demand more tokens than the input itself contains.
    min_length = max(0, min(400, input_length, max_length - 1))

    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # mBART-50 expects the target-language code as the first generated token.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(pivot_code),
        num_beams=6,
        max_length=max_length,
        min_length=min_length,
        length_penalty=1.5,
        early_stopping=True,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # Translate summary to the target language if needed
    if target_lang_code != pivot_code:
        summary = translate_text(summary, pivot_code, target_lang_code)
    return summary
# Streamlit interface: collect the text and language choices, then show the
# summary (or a readable error) when the button is pressed.
st.title("Multi-Language Text Summarization Tool")

text = st.text_area("Input Text")

# Build the option list once and share it between both selectboxes.
language_names = list(LANGUAGES)
default_language_index = language_names.index("English")
source_language = st.selectbox(
    "Source Language", options=language_names, index=default_language_index
)
target_language = st.selectbox(
    "Target Language", options=language_names, index=default_language_index
)

if st.button("Summarize"):
    # Treat whitespace-only input the same as empty input.
    if text and text.strip():
        try:
            summary = summarize_text(text, source_language, target_language)
        except OSError as error:
            # Raised by from_pretrained when a checkpoint cannot be loaded,
            # e.g. no translation model is published for the chosen pair.
            st.error(
                f"Could not load the models needed for {source_language} "
                f"to {target_language}: {error}"
            )
        else:
            st.subheader("Summary")
            st.write(summary)
    else:
        st.warning("Please enter text to summarize.")