File size: 6,103 Bytes
38c216c 4d7bcf1 38c216c 4d7bcf1 38c216c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import torch
import gradio as gr
import requests
from bs4 import BeautifulSoup
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5TokenizerFast, AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
# Download NLTK 'punkt' data (needed for sentence tokenization).
nltk.download('punkt')
# Accepts either raw article text or a URL. Input that starts with "http" is
# treated as a Naver News URL and the article body is scraped from the page;
# any other input is returned as-is (whitespace-stripped).
def extract_naver_news_article(text_or_url):
    """Return the article body for a Naver News URL, or the input text itself.

    Args:
        text_or_url: Raw article text, or an http(s) URL of a Naver News page.

    Returns:
        str: The extracted (or passed-through) article text, or a (Korean)
        error-message string when the HTTP request fails or no known article
        container is found in the page.
    """
    try:
        text_or_url = text_or_url.strip()
        if not text_or_url.startswith("http"):
            # Plain-text input: pass it through untouched.
            return text_or_url
        headers = {
            # A browser-like User-Agent avoids being blocked by the site.
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
        # Fix: the original call had no timeout, so an unresponsive server
        # could hang the whole app indefinitely.
        response = requests.get(text_or_url, headers=headers, timeout=10)
        if response.status_code != 200:
            return f"ํ์ด์ง ์์ฒญ ์คํจ (HTTP {response.status_code})"
        soup = BeautifulSoup(response.text, "html.parser")
        # Primary selector for current Naver News layout, with a fallback
        # for the older article markup.
        article_content = soup.select_one("#dic_area")
        if not article_content:
            article_content = soup.select_one(".go_trans._article_content")
        if not article_content:
            return "๊ธฐ์ฌ ๋ณธ๋ฌธ์ ์ฐพ์ ์ ์์ต๋๋ค."
        paragraphs = article_content.find_all("p")
        # Join <p> paragraphs when present; otherwise fall back to the
        # container's full text.
        article_text = "\n".join([p.get_text(strip=True) for p in paragraphs]) if paragraphs else article_content.get_text(strip=True)
        return article_text
    except Exception as e:
        return f"์ค๋ฅ ๋ฐ์: {e}"
# Loader for the fine-tuned T5 title-generation model.
def load_title_model(checkpoint_path):
    """Load the fine-tuned T5 title model from *checkpoint_path*.

    Returns a ``(model, tokenizer)`` pair, or ``(None, None)`` if loading
    fails (the error is printed, not raised).
    """
    try:
        tokenizer = T5TokenizerFast.from_pretrained(checkpoint_path)
        model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
    except Exception as exc:
        print(f"Error loading title model: {exc}")
        return None, None
    return model, tokenizer
# Loader for the news-article summarization model.
def load_summarization_model(model_dir):
    """Load the seq2seq summarization model from *model_dir*.

    Returns a ``(model, tokenizer)`` pair, or ``(None, None)`` if loading
    fails (the error is printed, not raised).
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
    except Exception as exc:
        print(f"Error loading summarization model: {exc}")
        return None, None
    return model, tokenizer
# Model locations (Hugging Face Hub IDs — change to match your environment).
title_path = 'onebeans/keT5-news-title-gen'  # fine-tuned T5 title-generation model
abs_path = 'onebeans/keT5-news-summarizer'  # news-article summarization model
# Load both models at import time; each pair is (model, tokenizer),
# or (None, None) if loading failed.
title_model, title_tokenizer = load_title_model(title_path)
summarizer_model, summarizer_tokenizer = load_summarization_model(abs_path)
# Generate a headline for a news article with the fine-tuned T5 model.
def generate_title(article_text, max_length=20, num_beams=10, early_stopping=True):
    """Beam-search a title for *article_text*.

    Returns the decoded title string, or an ``"Error generating title: ..."``
    string if tokenization/generation fails.
    """
    try:
        # The title model was fine-tuned with a "summarize:" task prefix.
        encoded = title_tokenizer.encode(
            f"summarize: {article_text}", return_tensors="pt", truncation=True
        )
        generated = title_model.generate(
            encoded,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping,
        )
        return title_tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as exc:
        return f"Error generating title: {exc}"
# Generate an abstractive summary of a news article.
def generate_summary(article_text, max_length=40, num_beams=10, early_stopping=True):
    """Beam-search a summary for *article_text*.

    Returns the decoded summary string, or an ``"Error generating summary: ..."``
    string if tokenization/generation fails.
    """
    try:
        encoded = summarizer_tokenizer.encode(
            article_text, return_tensors="pt", truncation=True
        )
        generated = summarizer_model.generate(
            encoded,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping,
        )
        return summarizer_tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as exc:
        return f"Error generating summary: {exc}"
# Full pipeline: take article text or a URL, extract the body, then produce
# a generated title and a generated summary.
def process_news(input_text, title_max_length, title_num_beams, title_early_stopping,
                 summary_max_length, summary_num_beams, summary_early_stopping):
    """Run extraction -> title generation -> summary generation.

    Returns a ``(body, title, summary)`` tuple of strings.
    """
    body = extract_naver_news_article(input_text)
    title = generate_title(
        body,
        max_length=title_max_length,
        num_beams=title_num_beams,
        early_stopping=title_early_stopping,
    )
    summary = generate_summary(
        body,
        max_length=summary_max_length,
        num_beams=summary_num_beams,
        early_stopping=summary_early_stopping,
    )
    return body, title, summary
# Build and launch the Gradio web UI for the title/summary generator.
def launch_gradio_interface():
    """Wire process_news into a gr.Interface and launch it.

    Inputs map positionally onto process_news's parameters:
    article text/URL, then (max length, beam count, early stopping) for the
    title, then the same three controls for the summary.
    """
    interface = gr.Interface(
        fn=process_news,
        inputs=[
            # Article body or URL (a URL triggers scraping).
            gr.Textbox(label="๋ด์ค ๊ธฐ์ฌ ๋ณธ๋ฌธ ๋๋ URL", placeholder="๋ด์ค ๊ธฐ์ฌ ๋ณธ๋ฌธ์ ์ง์ ์๋ ฅํ๊ฑฐ๋ URL์ ์๋ ฅํ์ธ์.", lines=10),
            # Title generation controls: max length, beam count, early stopping.
            gr.Slider(0, 40, value=20, step=1, label="์ ๋ชฉ ์์ฑ - ์ต๋ ๊ธธ์ด"),
            gr.Slider(1, 10, value=10, step=1, label="์ ๋ชฉ ์์ฑ - ํ์ ๋น ์"),
            gr.Checkbox(value=True, label="์ ๋ชฉ ์์ฑ - ์กฐ๊ธฐ ์ข๋ฃ"),
            # Summary generation controls: max length, beam count, early stopping.
            gr.Slider(0, 256, value=128, step=1, label="์์ฝ ์์ฑ - ์ต๋ ๊ธธ์ด"),
            gr.Slider(1, 10, value=10, step=1, label="์์ฝ ์์ฑ - ํ์ ๋น ์"),
            gr.Checkbox(value=True, label="์์ฝ ์์ฑ - ์กฐ๊ธฐ ์ข๋ฃ")
        ],
        outputs=[
            # Extracted (or passed-through) body, generated title, generated summary.
            gr.Textbox(label="์ถ์ถ๋๊ฑฐ๋ ์๋ ฅ๋ ๊ธฐ์ฌ ๋ณธ๋ฌธ"),
            gr.Textbox(label="์์ฑ๋ ์ ๋ชฉ"),
            gr.Textbox(label="์์ฑ๋ ์์ฝ๋ฌธ")
        ],
        title="๋ค์ด๋ฒ ๋ด์ค ์ ๋ชฉ ๋ฐ ์์ฝ ์์ฑ๊ธฐ",
        description="๋ด์ค ๊ธฐ์ฌ ๋ณธ๋ฌธ ๋๋ URL์ ์๋ ฅํ๋ฉด, ๋ณธ๋ฌธ์ ์๋ ์ถ์ถํ ํ ์ ๋ชฉ๊ณผ ์์ฝ๋ฌธ์ ์์ฑํฉ๋๋ค."
    )
    # share=True exposes a temporary public URL via Gradio's tunnel service.
    interface.launch(share=True)
# Script entry point: start the web UI.
if __name__ == '__main__':
    launch_gradio_interface()
|