# NewsBrief / app.py
# Author: onebeans — commit 4d7bcf1 ("Update app.py", verified)
import torch
import gradio as gr
import requests
from bs4 import BeautifulSoup
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5TokenizerFast, AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
# Download NLTK's 'punkt' data (for sentence tokenization).
# quiet=True keeps server logs clean; the call is a no-op when the data is
# already cached locally.
# NOTE(review): punkt does not appear to be used anywhere in this file —
# confirm it is needed before removing the download.
nltk.download('punkt', quiet=True)
# ๋‰ด์Šค ๊ธฐ์‚ฌ ๋ณธ๋ฌธ ๋˜๋Š” URL์„ ์ž…๋ ฅ๋ฐ›์•„,
# ๋งŒ์•ฝ ์ž…๋ ฅ๊ฐ’์ด "http"๋กœ ์‹œ์ž‘ํ•˜๋ฉด ๋„ค์ด๋ฒ„ ๋‰ด์Šค URL๋กœ ๊ฐ„์ฃผํ•˜์—ฌ ๋ณธ๋ฌธ์„ ์ถ”์ถœํ•˜๊ณ ,
# ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด ๊ทธ๋Œ€๋กœ ํ…์ŠคํŠธ๋ฅผ ๋ฐ˜ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜
def extract_naver_news_article(text_or_url):
try:
text_or_url = text_or_url.strip()
if text_or_url.startswith("http"):
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
response = requests.get(text_or_url, headers=headers)
if response.status_code != 200:
return f"ํŽ˜์ด์ง€ ์š”์ฒญ ์‹คํŒจ (HTTP {response.status_code})"
soup = BeautifulSoup(response.text, "html.parser")
article_content = soup.select_one("#dic_area")
if not article_content:
article_content = soup.select_one(".go_trans._article_content")
if not article_content:
return "๊ธฐ์‚ฌ ๋ณธ๋ฌธ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
paragraphs = article_content.find_all("p")
article_text = "\n".join([p.get_text(strip=True) for p in paragraphs]) if paragraphs else article_content.get_text(strip=True)
return article_text
else:
return text_or_url
except Exception as e:
return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}"
# ์ œ๋ชฉ ์ƒ์„ฑ ๋ชจ๋ธ ๋กœ๋“œ ํ•จ์ˆ˜ (ํŒŒ์ธํŠœ๋‹๋œ T5 ๋ชจ๋ธ)
def load_title_model(checkpoint_path):
try:
title_tokenizer = T5TokenizerFast.from_pretrained(checkpoint_path)
title_model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
return title_model, title_tokenizer
except Exception as e:
print(f"Error loading title model: {e}")
return None, None
# ๋‰ด์Šค ๊ธฐ์‚ฌ ์š”์•ฝ ๋ชจ๋ธ ๋กœ๋“œ ํ•จ์ˆ˜
def load_summarization_model(model_dir):
try:
summarizer_tokenizer = AutoTokenizer.from_pretrained(model_dir)
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
return summarizer_model, summarizer_tokenizer
except Exception as e:
print(f"Error loading summarization model: {e}")
return None, None
# ๋ชจ๋ธ ๊ฒฝ๋กœ ์„ค์ • (์‚ฌ์šฉ์ž ํ™˜๊ฒฝ์— ๋งž๊ฒŒ ๋ณ€๊ฒฝ)
title_path = 'onebeans/keT5-news-title-gen' # ํŒŒ์ธํŠœ๋‹๋œ T5 ์ œ๋ชฉ ์ƒ์„ฑ ๋ชจ๋ธ ๊ฒฝ๋กœ
abs_path = 'onebeans/keT5-news-summarizer' # ๋‰ด์Šค ๊ธฐ์‚ฌ ์š”์•ฝ ๋ชจ๋ธ ๊ฒฝ๋กœ
title_model, title_tokenizer = load_title_model(title_path)
summarizer_model, summarizer_tokenizer = load_summarization_model(abs_path)
# ๋‰ด์Šค ๊ธฐ์‚ฌ ์ œ๋ชฉ ์ƒ์„ฑ ํ•จ์ˆ˜
def generate_title(article_text, max_length=20, num_beams=10, early_stopping=True):
try:
input_ids = title_tokenizer.encode(f"summarize: {article_text}", return_tensors="pt", truncation=True)
outputs = title_model.generate(
input_ids,
max_length=max_length,
num_beams=num_beams,
early_stopping=early_stopping
)
title = title_tokenizer.decode(outputs[0], skip_special_tokens=True)
return title
except Exception as e:
return f"Error generating title: {e}"
# ๋‰ด์Šค ๊ธฐ์‚ฌ ์š”์•ฝ ์ƒ์„ฑ ํ•จ์ˆ˜
def generate_summary(article_text, max_length=40, num_beams=10, early_stopping=True):
try:
input_ids = summarizer_tokenizer.encode(article_text, return_tensors="pt", truncation=True)
outputs = summarizer_model.generate(
input_ids,
max_length=max_length,
num_beams=num_beams,
early_stopping=early_stopping
)
summary = summarizer_tokenizer.decode(outputs[0], skip_special_tokens=True)
return summary
except Exception as e:
return f"Error generating summary: {e}"
# ๋‰ด์Šค ๊ธฐ์‚ฌ ๋ณธ๋ฌธ ๋˜๋Š” URL์„ ์ž…๋ ฅ๋ฐ›์•„ ๋ณธ๋ฌธ ์ถ”์ถœ ํ›„ ์ œ๋ชฉ๊ณผ ์š”์•ฝ ์ƒ์„ฑํ•˜๋Š” ํ•จ์ˆ˜
def process_news(input_text, title_max_length, title_num_beams, title_early_stopping,
summary_max_length, summary_num_beams, summary_early_stopping):
article_text = extract_naver_news_article(input_text)
title = generate_title(article_text, max_length=title_max_length, num_beams=title_num_beams, early_stopping=title_early_stopping)
summary = generate_summary(article_text, max_length=summary_max_length, num_beams=summary_num_beams, early_stopping=summary_early_stopping)
return article_text, title, summary
# Build and launch the Gradio UI around process_news.
def launch_gradio_interface():
    """Construct the Gradio interface and launch it with a public share link."""
    input_widgets = [
        gr.Textbox(label="뉴스 기사 본문 또는 URL", placeholder="뉴스 기사 본문을 직접 입력하거나 URL을 입력하세요.", lines=10),
        gr.Slider(0, 40, value=20, step=1, label="제목 생성 - 최대 길이"),
        gr.Slider(1, 10, value=10, step=1, label="제목 생성 - 탐색 빔 수"),
        gr.Checkbox(value=True, label="제목 생성 - 조기 종료"),
        gr.Slider(0, 256, value=128, step=1, label="요약 생성 - 최대 길이"),
        gr.Slider(1, 10, value=10, step=1, label="요약 생성 - 탐색 빔 수"),
        gr.Checkbox(value=True, label="요약 생성 - 조기 종료"),
    ]
    output_widgets = [
        gr.Textbox(label="추출되거나 입력된 기사 본문"),
        gr.Textbox(label="생성된 제목"),
        gr.Textbox(label="생성된 요약문"),
    ]
    demo = gr.Interface(
        fn=process_news,
        inputs=input_widgets,
        outputs=output_widgets,
        title="네이버 뉴스 제목 및 요약 생성기",
        description="뉴스 기사 본문 또는 URL을 입력하면, 본문을 자동 추출한 후 제목과 요약문을 생성합니다.",
    )
    # share=True exposes a temporary public Gradio URL in addition to the
    # local server.
    demo.launch(share=True)


if __name__ == '__main__':
    launch_gradio_interface()