# pjbmask's picture
# Update app.py
# 994131d verified
import os
import sys
import gradio as gr
from transformers import BartForConditionalGeneration, BartTokenizer, MBartForConditionalGeneration, MBart50TokenizerFast
from newspaper import Article
import requests
from urllib.parse import urlparse
# Global variables: model/tokenizer handles, all unset until load_models() runs.
trans_model = trans_tokenizer = None
summ_model = summ_tokenizer = None
def load_models():
    """Load the translation and summarization models into the module globals.

    Each tokenizer/model pair is loaded only when it is missing, so calling
    this again after a partial failure retries just the pair that failed, and
    calling it when everything is already loaded does no loading at all.

    Returns:
        str: A status message suitable for display in the UI. The completion
        message is returned on success (including when everything was already
        loaded); otherwise an error message naming the step that failed.
    """
    global trans_model, trans_tokenizer, summ_model, summ_tokenizer
    try:
        if trans_model is None or trans_tokenizer is None:
            print("MBart ๋‹ค๊ตญ์–ด ๋ฒˆ์—ญ ๋ชจ๋ธ ๋กœ๋”ฉ ์‹œ๋„ ์ค‘...")
            # English -> Korean translation model (multilingual MBart).
            # NOTE(review): "facebook/mbart-large-50" looks like the base
            # pre-trained checkpoint rather than a translation fine-tune
            # (e.g. the "-many-to-many-mmt" variant) - confirm before relying
            # on translation quality.
            try:
                tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
                model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
            except Exception as e:
                print(f"๋ฒˆ์—ญ ๋ชจ๋ธ ๋กœ๋“œ ์˜ค๋ฅ˜: {e}")
                return f"๋ฒˆ์—ญ ๋ชจ๋ธ ๋กœ๋“œ ์˜ค๋ฅ˜: {str(e)}"
            # Publish the pair only once both halves loaded, so the globals
            # never hold a tokenizer without its model.
            trans_tokenizer, trans_model = tokenizer, model
            print("๋ฒˆ์—ญ ๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต")

        if summ_model is None or summ_tokenizer is None:
            print("์š”์•ฝ ๋ชจ๋ธ ๋กœ๋”ฉ ์‹œ๋„ ์ค‘...")
            # Korean summarization model.
            try:
                tokenizer = BartTokenizer.from_pretrained("digit82/bart-korean-summarization")
                model = BartForConditionalGeneration.from_pretrained("digit82/bart-korean-summarization")
            except Exception as e:
                print(f"์š”์•ฝ ๋ชจ๋ธ ๋กœ๋“œ ์˜ค๋ฅ˜: {e}")
                return f"์š”์•ฝ ๋ชจ๋ธ ๋กœ๋“œ ์˜ค๋ฅ˜: {str(e)}"
            summ_tokenizer, summ_model = tokenizer, model
            print("์š”์•ฝ ๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต")

        # Reached both on a fresh load and when nothing needed loading, so the
        # caller always gets a status string instead of None.
        return "๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ"
    except Exception as e:
        # Last-resort boundary: report anything unexpected as a status string
        # rather than letting it propagate into the UI callback.
        print(f"๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ˆ์ƒ์น˜ ๋ชปํ•œ ์˜ค๋ฅ˜: {e}")
        return f"๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
def is_valid_url(url):
    """Return True if *url* parses with both a scheme and a network location.

    Args:
        url: Candidate URL, normally a string.

    Returns:
        bool: False when either component is missing or when ``urlparse``
        rejects the value.
    """
    try:
        parsed = urlparse(url)
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed URL (e.g. unbalanced IPv6 brackets);
        # TypeError / AttributeError: the value is not str or bytes.
        return False
    return bool(parsed.scheme and parsed.netloc)
def validate_url_accessibility(url):
    """Check whether *url* answers an HTTP HEAD request with a non-error status.

    Args:
        url: Absolute URL to probe.

    Returns:
        bool: True if a response with a 2xx or 3xx status code came back
        (the request uses ``timeout=5``); False for any other status or if
        the request itself failed.
    """
    try:
        response = requests.head(url, timeout=5)
    except (requests.RequestException, ValueError):
        # Connection failure, timeout, or a URL requests cannot handle:
        # report it as unreachable rather than raising into the UI.
        return False
    return 200 <= response.status_code < 400
def translate_text(text):
    """Translate English text to Korean using the global MBart model.

    Args:
        text: English source text; input is truncated to 1024 tokens.

    Returns:
        The decoded translation of the first (only) generated sequence.
    """
    tokenizer = trans_tokenizer
    # Tell the tokenizer that the source side is English.
    tokenizer.src_lang = "en_XX"

    model_inputs = tokenizer(
        text, return_tensors="pt", max_length=1024, truncation=True
    )

    # Force Korean as the first generated token so the decoder emits Korean.
    output_ids = trans_model.generate(
        **model_inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["ko_KR"],
        max_length=1024,
        num_beams=5,
        early_stopping=True,
    )

    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
def summarize_news(url):
    """Fetch a news article, translate it to Korean and summarize it.

    Args:
        url: Article address as typed into the UI textbox.

    Returns:
        str: The article title followed by the generated summary, or a
        user-facing error message when validation, model loading, download
        or generation fails.
    """
    # Input validation.
    if not url or not isinstance(url, str):
        return "์˜ค๋ฅ˜: ์œ ํšจํ•œ URL์„ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."
    # Pasted URLs often carry stray surrounding whitespace.
    url = url.strip()
    if not is_valid_url(url):
        return "์˜ค๋ฅ˜: ์ž…๋ ฅํ•œ ํ…์ŠคํŠธ๊ฐ€ ์œ ํšจํ•œ URL ํ˜•์‹์ด ์•„๋‹™๋‹ˆ๋‹ค."
    if not validate_url_accessibility(url):
        return "์˜ค๋ฅ˜: ์ž…๋ ฅํ•œ URL์— ์ ‘๊ทผํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. URL์ด ์˜ฌ๋ฐ”๋ฅธ์ง€ ํ™•์ธํ•ด์ฃผ์„ธ์š”."

    # Lazy model loading. Check every handle, not just the translation model,
    # so a summarizer that failed to load earlier gets another attempt.
    handles = (trans_model, trans_tokenizer, summ_model, summ_tokenizer)
    if any(handle is None for handle in handles):
        load_result = load_models()
        # load_models reports failure through a message containing this marker.
        if isinstance(load_result, str) and "์˜ค๋ฅ˜" in load_result:
            return load_result

    try:
        # Download and parse the article.
        article = Article(url)
        article.download()
        article.parse()
        text = article.text
        if not text:
            return "์˜ค๋ฅ˜: ๊ธฐ์‚ฌ์—์„œ ํ…์ŠคํŠธ๋ฅผ ์ถ”์ถœํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."

        # Cap very long articles at 5000 characters.
        if len(text) > 5000:
            print(f"ํ…์ŠคํŠธ๊ฐ€ ๋„ˆ๋ฌด ๊น๋‹ˆ๋‹ค. ์›๋ณธ ๊ธธ์ด: {len(text)}์ž, 5000์ž๋กœ ์ œํ•œํ•ฉ๋‹ˆ๋‹ค.")
            text = text[:5000]

        print("ํ…์ŠคํŠธ ๋ฒˆ์—ญ ์ค‘...")
        # English -> Korean translation (MBart model).
        translated_text = translate_text(text)

        print("๋ฒˆ์—ญ๋œ ํ…์ŠคํŠธ ์š”์•ฝ ์ค‘...")
        # Summarize the Korean text.
        inputs = summ_tokenizer([translated_text], return_tensors="pt", max_length=1024, truncation=True)
        summary_ids = summ_model.generate(
            inputs["input_ids"],
            num_beams=4,
            max_length=200,
            min_length=50,
            early_stopping=True
        )
        summary = summ_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # Return the result, including the article title.
        result = f"์ œ๋ชฉ: {article.title}\n\n์š”์•ฝ: {summary}"
        return result
    except Exception as e:
        # UI boundary: turn any download/translation/generation failure into
        # a message for the output box.
        print(f"์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
        return f"์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
# Gradio UI
# Builds the page: a title and a first-run notice, a URL textbox with two
# buttons, a result textbox, usage notes, and the two button callbacks.
# NOTE(review): indentation is missing in this copy of the file, so the nesting
# of the Row/Column layout contexts below cannot be read off the code - confirm
# it against the original before running.
with gr.Blocks(theme="soft") as demo:
gr.Markdown("# ์˜์–ด ๋‰ด์Šค ์š”์•ฝ๊ธฐ (EN โ†’ KO โ†’ ์š”์•ฝ)")
gr.Markdown("### ์ฒซ ์‹คํ–‰ ์‹œ ๋ชจ๋ธ์„ ๋‹ค์šด๋กœ๋“œํ•˜๋ฏ€๋กœ ์‹œ๊ฐ„์ด ์†Œ์š”๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
with gr.Row():
with gr.Column():
# Textbox that receives the article URL passed to summarize_news.
input_url = gr.Textbox(
label="๋‰ด์Šค ๊ธฐ์‚ฌ URL ์ž…๋ ฅ",
placeholder="https://example.com/news-article"
)
# Primary action button and the optional model preload button.
submit_btn = gr.Button("์š”์•ฝํ•˜๊ธฐ", variant="primary")
model_load_btn = gr.Button("๋ชจ๋ธ ๋ฏธ๋ฆฌ ๋กœ๋“œํ•˜๊ธฐ")
with gr.Row():
# Single output box; both button callbacks write their result here.
output_text = gr.Textbox(
label="์š”์•ฝ ๊ฒฐ๊ณผ",
lines=10
)
# Usage instructions and notes rendered as Markdown.
gr.Markdown("""
### ์‚ฌ์šฉ ๋ฐฉ๋ฒ•
1. '๋ชจ๋ธ ๋ฏธ๋ฆฌ ๋กœ๋“œํ•˜๊ธฐ' ๋ฒ„ํŠผ์„ ํด๋ฆญํ•˜์—ฌ ๋ชจ๋ธ์„ ๋ฏธ๋ฆฌ ๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
2. ์˜์–ด ๋‰ด์Šค ๊ธฐ์‚ฌ URL์„ ์ž…๋ ฅํ•ฉ๋‹ˆ๋‹ค.
3. '์š”์•ฝํ•˜๊ธฐ' ๋ฒ„ํŠผ์„ ํด๋ฆญํ•ฉ๋‹ˆ๋‹ค.
4. ํ•œ๊ตญ์–ด๋กœ ๋ฒˆ์—ญ ๋ฐ ์š”์•ฝ๋œ ๊ฒฐ๊ณผ๊ฐ€ ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.
### ์ฐธ๊ณ  ์‚ฌํ•ญ
- ์ฒ˜๋ฆฌ ์‹œ๊ฐ„์€ ๊ธฐ์‚ฌ ๊ธธ์ด์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
- ๋ชจ๋ธ์ด ์ฒ˜์Œ ๋กœ๋“œ๋  ๋•Œ ๋‹ค์†Œ ์‹œ๊ฐ„์ด ์†Œ์š”๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
- 5,000์ž๊ฐ€ ๋„˜๋Š” ๊ธด ๊ธฐ์‚ฌ๋Š” ์ผ๋ถ€๋งŒ ์ฒ˜๋ฆฌ๋ฉ๋‹ˆ๋‹ค.
""")
# Clicking the submit button runs summarize_news on the URL textbox value.
submit_btn.click(
fn=summarize_news,
inputs=input_url,
outputs=output_text
)
# Clicking the preload button runs load_models; its status string is shown
# in the same output box.
model_load_btn.click(
fn=load_models,
inputs=None,
outputs=output_text
)
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
demo.launch()