Spaces:
Running
Running
import os
import sys
from urllib.parse import urlparse

import gradio as gr
import requests
from newspaper import Article
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
)
# Global model/tokenizer handles. They start out as None and are filled in
# on demand by load_models().
trans_model = trans_tokenizer = None  # English -> Korean translation (mBART-50)
summ_model = summ_tokenizer = None    # Korean summarization (BART)
def load_models():
    """Load the translation and summarization models into the module globals.

    Each tokenizer/model pair is loaded only while it is still missing. A
    second call after a partial failure therefore retries just the pair that
    failed. A call after everything is loaded is a cheap no-op that still
    returns the completion message; previously it fell through and returned
    None, which left the UI output empty.

    Returns:
        str: the completion message when both pairs are available, otherwise
        the error message of the step that failed.
    """
    global trans_model, trans_tokenizer, summ_model, summ_tokenizer
    try:
        if trans_model is None or trans_tokenizer is None:
            print("MBart ๋ค๊ตญ์ด ๋ฒ์ญ ๋ชจ๋ธ ๋ก๋ฉ ์๋ ์ค...")
            # English -> Korean translation. "facebook/mbart-large-50" is only
            # the denoising-pretrained base model. The "-one-to-many-mmt"
            # checkpoint is its fine-tune for translating English into the
            # other mBART-50 languages (ko_KR included).
            trans_checkpoint = "facebook/mbart-large-50-one-to-many-mmt"
            try:
                tokenizer = MBart50TokenizerFast.from_pretrained(trans_checkpoint)
                model = MBartForConditionalGeneration.from_pretrained(trans_checkpoint)
            except Exception as e:
                print(f"๋ฒ์ญ ๋ชจ๋ธ ๋ก๋ ์ค๋ฅ: {e}")
                return f"๋ฒ์ญ ๋ชจ๋ธ ๋ก๋ ์ค๋ฅ: {str(e)}"
            # Publish tokenizer and model together so a half-loaded pair is
            # never visible to other code.
            trans_tokenizer, trans_model = tokenizer, model
            print("๋ฒ์ญ ๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต")

        if summ_model is None or summ_tokenizer is None:
            print("์์ฝ ๋ชจ๋ธ ๋ก๋ฉ ์๋ ์ค...")
            # Korean summarization. AutoTokenizer instantiates whatever
            # tokenizer class the checkpoint declares instead of forcing the
            # slow, vocab.json/merges.txt-based BartTokenizer.
            # NOTE(review): KoBART-style checkpoints typically ship only a
            # tokenizer.json, which BartTokenizer cannot load; confirm for
            # this repo.
            summ_checkpoint = "digit82/bart-korean-summarization"
            try:
                tokenizer = AutoTokenizer.from_pretrained(summ_checkpoint)
                model = BartForConditionalGeneration.from_pretrained(summ_checkpoint)
            except Exception as e:
                print(f"์์ฝ ๋ชจ๋ธ ๋ก๋ ์ค๋ฅ: {e}")
                return f"์์ฝ ๋ชจ๋ธ ๋ก๋ ์ค๋ฅ: {str(e)}"
            summ_tokenizer, summ_model = tokenizer, model
            print("์์ฝ ๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต")

        return "๋ชจ๋ธ ๋ก๋ ์๋ฃ"
    except Exception as e:
        # Last-resort boundary: this function feeds a UI button, so report
        # any unexpected failure as a message instead of raising.
        print(f"๋ชจ๋ธ ๋ก๋ ์ค ์์์น ๋ชปํ ์ค๋ฅ: {e}")
        return f"๋ชจ๋ธ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}"
def is_valid_url(url, allowed_schemes=("http", "https")):
    """Return True if ``url`` is an absolute URL with an allowed scheme.

    The URL must parse, must carry a host (netloc), and must use one of
    ``allowed_schemes`` (compared case-insensitively). Only http(s) is allowed
    by default because requests, which performs the reachability check, only
    ships http/https transport adapters. Any other scheme could never pass
    that check and would only be rejected later with a misleading
    "cannot access" message.

    Args:
        url: candidate URL. Non-str input returns False.
        allowed_schemes: iterable of accepted URL schemes.

    Returns:
        bool: whether the URL is well-formed and uses an allowed scheme.
    """
    if not isinstance(url, str):
        return False
    try:
        parsed = urlparse(url)
    except ValueError:
        # e.g. an unbalanced IPv6 bracket such as "http://[::1"
        return False
    allowed = {scheme.lower() for scheme in allowed_schemes}
    # urlparse already lower-cases the parsed scheme.
    return parsed.scheme in allowed and bool(parsed.netloc)
def validate_url_accessibility(url, timeout=5):
    """Best-effort check that ``url`` answers with a non-error HTTP status.

    A HEAD request is tried first. Some servers reject or mishandle HEAD
    (e.g. 405 Method Not Allowed) while serving GET normally. A HEAD answer
    outside [200, 400) is therefore double-checked with a GET using
    ``stream=True``, so the response body is not downloaded before the
    connection is released by the ``with`` block.

    Args:
        url: absolute http(s) URL to probe.
        timeout: per-request timeout in seconds (default 5, as before).

    Returns:
        bool: True if either request returned a status in [200, 400). False
        for any other status, or when the request fails (connection error,
        timeout, malformed URL).
    """
    try:
        response = requests.head(url, timeout=timeout)
        if 200 <= response.status_code < 400:
            return True
        with requests.get(url, timeout=timeout, stream=True) as response:
            return 200 <= response.status_code < 400
    except (requests.RequestException, ValueError):
        # Network failures and URLs requests/urllib3 cannot parse all mean
        # "not reachable"; unlike the former bare except, KeyboardInterrupt
        # and SystemExit are no longer swallowed.
        return False
def translate_text(text, src_lang="en_XX", tgt_lang="ko_KR", batch_size=8):
    """Translate ``text`` with the global mBART-50 model (English -> Korean by default).

    The text is split into paragraphs (non-empty lines). Each paragraph is
    translated as its own sequence, ``batch_size`` paragraphs per
    ``generate`` call. Translating the whole article as a single sequence
    truncated the input at 1024 tokens and capped the output at 1024 tokens,
    silently dropping the end of long articles. A single paragraph longer
    than 1024 tokens is still truncated.

    Args:
        text: source text; paragraphs are separated by line breaks.
        src_lang: mBART-50 language code of the source text.
        tgt_lang: mBART-50 language code to translate into.
        batch_size: number of paragraphs per generate() call (minimum 1).

    Returns:
        str: the translated paragraphs joined with newlines, or "" when
        ``text`` has no non-whitespace content.
    """
    paragraphs = [line.strip() for line in text.splitlines() if line.strip()]
    if not paragraphs:
        return ""
    batch_size = max(1, int(batch_size))

    # Tell the tokenizer which language the input is in.
    trans_tokenizer.src_lang = src_lang
    # Forcing the first generated token to the target-language code selects
    # the output language.
    target_bos_id = trans_tokenizer.lang_code_to_id[tgt_lang]

    translated = []
    for start in range(0, len(paragraphs), batch_size):
        batch = paragraphs[start:start + batch_size]
        encoded = trans_tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            max_length=1024,
            truncation=True,
        )
        generated_tokens = trans_model.generate(
            **encoded,
            forced_bos_token_id=target_bos_id,
            max_length=1024,
            num_beams=5,
            early_stopping=True,
        )
        translated.extend(
            trans_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        )
    return "\n".join(translated)
def summarize_news(url):
    """Fetch an English news article, translate it to Korean and summarize it.

    Args:
        url: article URL entered in the UI. Surrounding whitespace is ignored.

    Returns:
        str: the article title plus the Korean summary on success. Otherwise a
        user-facing error message; exceptions raised while fetching or running
        the models are caught and reported as a message.
    """
    # Input validation. A whitespace-only string counts as "no URL".
    if not isinstance(url, str) or not url.strip():
        return "์ค๋ฅ: ์ ํจํ URL์ ์ ๋ ฅํด์ฃผ์ธ์."
    # Pasted URLs often carry a trailing newline/space; don't let that fail
    # the format check or be sent verbatim to the server.
    url = url.strip()
    if not is_valid_url(url):
        return "์ค๋ฅ: ์ ๋ ฅํ ํ ์คํธ๊ฐ ์ ํจํ URL ํ์์ด ์๋๋๋ค."
    if not validate_url_accessibility(url):
        return "์ค๋ฅ: ์ ๋ ฅํ URL์ ์ ๊ทผํ ์ ์์ต๋๋ค. URL์ด ์ฌ๋ฐ๋ฅธ์ง ํ์ธํด์ฃผ์ธ์."

    # Load the models if either one is missing. Checking only the translation
    # model skipped loading after a partial failure and left the summarizer None.
    if trans_model is None or summ_model is None:
        load_result = load_models()
        # load_models' error messages all contain this error marker.
        if isinstance(load_result, str) and "์ค๋ฅ" in load_result:
            return load_result

    try:
        # Download and parse the article.
        article = Article(url)
        article.download()
        article.parse()
        text = (article.text or "").strip()
        if not text:
            return "์ค๋ฅ: ๊ธฐ์ฌ์์ ํ ์คํธ๋ฅผ ์ถ์ถํ ์ ์์ต๋๋ค."

        # Cap very long articles at 5000 characters.
        if len(text) > 5000:
            print(f"ํ ์คํธ๊ฐ ๋๋ฌด ๊น๋๋ค. ์๋ณธ ๊ธธ์ด: {len(text)}์, 5000์๋ก ์ ํํฉ๋๋ค.")
            text = text[:5000]

        print("ํ ์คํธ ๋ฒ์ญ ์ค...")
        # English -> Korean translation (mBART model)
        translated_text = translate_text(text)

        print("๋ฒ์ญ๋ ํ ์คํธ ์์ฝ ์ค...")
        # Summarize the Korean text. Only input_ids are passed to generate().
        inputs = summ_tokenizer([translated_text], return_tensors="pt", max_length=1024, truncation=True)
        summary_ids = summ_model.generate(
            inputs["input_ids"],
            num_beams=4,
            max_length=200,
            min_length=50,
            early_stopping=True,
        )
        summary = summ_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # Result includes the article title.
        return f"์ ๋ชฉ: {article.title}\n\n์์ฝ: {summary}"
    except Exception as e:
        # UI boundary: report any fetch/model failure as a message.
        print(f"์ฒ๋ฆฌ ์ค ์ค๋ฅ ๋ฐ์: {e}")
        return f"์ฒ๋ฆฌ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}"
# Gradio UI: a URL textbox, a "summarize" button, a "preload models" button
# and a textbox that shows either the summary or a status/error message.
# NOTE(review): the Row/Column nesting below was reconstructed (the source's
# indentation was lost); confirm it matches the intended layout.
with gr.Blocks(theme="soft") as demo:
    # Page title and first-run download notice
    gr.Markdown("# ์์ด ๋ด์ค ์์ฝ๊ธฐ (EN โ KO โ ์์ฝ)")
    gr.Markdown("### ์ฒซ ์คํ ์ ๋ชจ๋ธ์ ๋ค์ด๋ก๋ํ๋ฏ๋ก ์๊ฐ์ด ์์๋ ์ ์์ต๋๋ค.")
    # Input area: article URL plus the two action buttons
    with gr.Row():
        with gr.Column():
            input_url = gr.Textbox(
                label="๋ด์ค ๊ธฐ์ฌ URL ์ ๋ ฅ",
                placeholder="https://example.com/news-article"
            )
            submit_btn = gr.Button("์์ฝํ๊ธฐ", variant="primary")
            model_load_btn = gr.Button("๋ชจ๋ธ ๋ฏธ๋ฆฌ ๋ก๋ํ๊ธฐ")
    # Output area, shared by both buttons
    with gr.Row():
        output_text = gr.Textbox(
            label="์์ฝ ๊ฒฐ๊ณผ",
            lines=10
        )
    # Usage instructions and notes (rendered as Markdown)
    gr.Markdown("""
### ์ฌ์ฉ ๋ฐฉ๋ฒ
1. '๋ชจ๋ธ ๋ฏธ๋ฆฌ ๋ก๋ํ๊ธฐ' ๋ฒํผ์ ํด๋ฆญํ์ฌ ๋ชจ๋ธ์ ๋ฏธ๋ฆฌ ๋ก๋ํ ์ ์์ต๋๋ค.
2. ์์ด ๋ด์ค ๊ธฐ์ฌ URL์ ์ ๋ ฅํฉ๋๋ค.
3. '์์ฝํ๊ธฐ' ๋ฒํผ์ ํด๋ฆญํฉ๋๋ค.
4. ํ๊ตญ์ด๋ก ๋ฒ์ญ ๋ฐ ์์ฝ๋ ๊ฒฐ๊ณผ๊ฐ ํ์๋ฉ๋๋ค.
### ์ฐธ๊ณ ์ฌํญ
- ์ฒ๋ฆฌ ์๊ฐ์ ๊ธฐ์ฌ ๊ธธ์ด์ ๋ฐ๋ผ ๋ฌ๋ผ์ง ์ ์์ต๋๋ค.
- ๋ชจ๋ธ์ด ์ฒ์ ๋ก๋๋ ๋ ๋ค์ ์๊ฐ์ด ์์๋ ์ ์์ต๋๋ค.
- 5,000์๊ฐ ๋๋ ๊ธด ๊ธฐ์ฌ๋ ์ผ๋ถ๋ง ์ฒ๋ฆฌ๋ฉ๋๋ค.
""")
    # "Summarize": run the fetch -> translate -> summarize pipeline on the URL.
    submit_btn.click(
        fn=summarize_news,
        inputs=input_url,
        outputs=output_text
    )
    # "Preload models": takes no input; load_models' status message is shown
    # in the same output textbox.
    model_load_btn.click(
        fn=load_models,
        inputs=None,
        outputs=output_text
    )

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()