# NewsBrief / app.py
import torch
import gradio as gr
import requests
from bs4 import BeautifulSoup
from transformers import T5ForConditionalGeneration, T5TokenizerFast, AutoTokenizer, AutoModelForSeq2SeqLM
import nltk

# Download nltk's 'punkt' data (for sentence tokenization)
nltk.download('punkt')
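
# Hedged note (added, not from the original file): the imports above imply roughly
# this environment; exact package versions are not pinned here.
#   pip install gradio requests beautifulsoup4 transformers torch nltk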

# Takes a news article body or a URL as input.
# If the input starts with "http", it is treated as a Naver news URL and the
# article body is scraped from the page; otherwise the text is returned as-is.
def extract_naver_news_article(text_or_url):
    try:
        text_or_url = text_or_url.strip()
        if text_or_url.startswith("http"):
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }
            # timeout added so a slow or unresponsive page cannot hang the app
            response = requests.get(text_or_url, headers=headers, timeout=10)
            if response.status_code != 200:
                return f"Page request failed (HTTP {response.status_code})"
            soup = BeautifulSoup(response.text, "html.parser")
            article_content = soup.select_one("#dic_area")
            if not article_content:
                article_content = soup.select_one(".go_trans._article_content")
            if not article_content:
                return "Could not find the article body."
            paragraphs = article_content.find_all("p")
            article_text = "\n".join([p.get_text(strip=True) for p in paragraphs]) if paragraphs else article_content.get_text(strip=True)
            return article_text
        else:
            return text_or_url
    except Exception as e:
        return f"Error: {e}"

# Load the title-generation model (a fine-tuned T5 model)
def load_title_model(checkpoint_path):
    try:
        title_tokenizer = T5TokenizerFast.from_pretrained(checkpoint_path)
        title_model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
        return title_model, title_tokenizer
    except Exception as e:
        print(f"Error loading title model: {e}")
        return None, None

# Load the news-article summarization model
def load_summarization_model(model_dir):
    try:
        summarizer_tokenizer = AutoTokenizer.from_pretrained(model_dir)
        summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        return summarizer_model, summarizer_tokenizer
    except Exception as e:
        print(f"Error loading summarization model: {e}")
        return None, None

# Model paths (change to match your environment)
title_path = 'onebeans/keT5-news-title-gen'  # fine-tuned T5 title-generation model
abs_path = 'onebeans/keT5-news-summarizer'   # news-article summarization model

title_model, title_tokenizer = load_title_model(title_path)
summarizer_model, summarizer_tokenizer = load_summarization_model(abs_path)
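
# A small, hedged addition (not in the original script): make a failed checkpoint
# load visible at startup, rather than only as per-request error strings from the
# generation helpers below.
if title_model is None or summarizer_model is None:
    print("Warning: one or both models failed to load; generation requests will return error messages.")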

# Generate a title for a news article
def generate_title(article_text, max_length=20, num_beams=10, early_stopping=True):
    try:
        input_ids = title_tokenizer.encode(f"summarize: {article_text}", return_tensors="pt", truncation=True)
        outputs = title_model.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping
        )
        title = title_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return title
    except Exception as e:
        return f"Error generating title: {e}"

# Generate a summary for a news article
def generate_summary(article_text, max_length=40, num_beams=10, early_stopping=True):
    try:
        input_ids = summarizer_tokenizer.encode(article_text, return_tensors="pt", truncation=True)
        outputs = summarizer_model.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping
        )
        summary = summarizer_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return summary
    except Exception as e:
        return f"Error generating summary: {e}"

# Take a news article body or URL, extract the body, then generate a title and a summary
def process_news(input_text, title_max_length, title_num_beams, title_early_stopping,
                 summary_max_length, summary_num_beams, summary_early_stopping):
    article_text = extract_naver_news_article(input_text)
    title = generate_title(article_text, max_length=title_max_length, num_beams=title_num_beams, early_stopping=title_early_stopping)
    summary = generate_summary(article_text, max_length=summary_max_length, num_beams=summary_num_beams, early_stopping=summary_early_stopping)
    return article_text, title, summary

# Launch the Gradio interface
def launch_gradio_interface():
    interface = gr.Interface(
        fn=process_news,
        inputs=[
            gr.Textbox(label="News article body or URL", placeholder="Paste the article body directly or enter a URL.", lines=10),
            gr.Slider(0, 40, value=20, step=1, label="Title generation - max length"),
            gr.Slider(1, 10, value=10, step=1, label="Title generation - number of beams"),
            gr.Checkbox(value=True, label="Title generation - early stopping"),
            gr.Slider(0, 256, value=128, step=1, label="Summary generation - max length"),
            gr.Slider(1, 10, value=10, step=1, label="Summary generation - number of beams"),
            gr.Checkbox(value=True, label="Summary generation - early stopping")
        ],
        outputs=[
            gr.Textbox(label="Extracted or directly entered article body"),
            gr.Textbox(label="Generated title"),
            gr.Textbox(label="Generated summary")
        ],
        title="Naver News Title and Summary Generator",
        description="Enter a news article body or URL; the body is extracted automatically, then a title and a summary are generated."
    )
    # share=True additionally creates a temporary public Gradio link
    interface.launch(share=True)


if __name__ == '__main__':
    launch_gradio_interface()