|
import torch
|
|
import gradio as gr
|
|
import requests
|
|
from bs4 import BeautifulSoup
|
|
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5TokenizerFast, AutoTokenizer, AutoModelForSeq2SeqLM
|
|
import nltk
|
|
|
|
|
|
# Fetch the Punkt sentence-tokenizer data at import time (no-op if cached).
# NOTE(review): nltk is not used anywhere else in this file — confirm whether
# this download (and the nltk import) is still required.
nltk.download('punkt')
|
|
|
|
|
|
|
|
|
|
def extract_naver_news_article(text_or_url):
    """Return the article body for a Naver news URL, or the input text as-is.

    If *text_or_url* starts with "http" it is treated as a URL: the page is
    fetched and the article body is extracted from the known Naver news
    containers (``#dic_area``, falling back to ``.go_trans._article_content``).
    Otherwise the stripped input is assumed to already be article text and is
    returned unchanged.

    Never raises: on any failure a Korean error-message string is returned
    instead (this function feeds a Gradio textbox).
    """
    try:
        text_or_url = text_or_url.strip()
        if text_or_url.startswith("http"):
            # Browser-like User-Agent: the site may reject the default
            # python-requests UA.
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }
            # FIX: timeout added — without it a stalled connection would hang
            # the whole app indefinitely.
            response = requests.get(text_or_url, headers=headers, timeout=10)
            if response.status_code != 200:
                return f"ํ์ด์ง ์์ฒญ ์คํจ (HTTP {response.status_code})"
            soup = BeautifulSoup(response.text, "html.parser")
            # Primary article container on news.naver.com, then legacy fallback.
            article_content = soup.select_one("#dic_area")
            if not article_content:
                article_content = soup.select_one(".go_trans._article_content")
            if not article_content:
                return "๊ธฐ์ฌ ๋ณธ๋ฌธ์ ์ฐพ์ ์ ์์ต๋๋ค."
            paragraphs = article_content.find_all("p")
            # Prefer <p> paragraphs; fall back to the container's full text.
            article_text = "\n".join([p.get_text(strip=True) for p in paragraphs]) if paragraphs else article_content.get_text(strip=True)
            return article_text
        else:
            # Plain text was pasted directly — pass it through.
            return text_or_url
    except Exception as e:
        # Best-effort UI helper: surface the error as text rather than crash.
        return f"์ค๋ฅ ๋ฐ์: {e}"
|
|
|
|
|
|
def load_title_model(checkpoint_path):
    """Load the fine-tuned T5 title-generation model and its tokenizer.

    Returns ``(model, tokenizer)`` on success, or ``(None, None)`` if loading
    fails for any reason (the error is printed, not raised).
    """
    try:
        tokenizer = T5TokenizerFast.from_pretrained(checkpoint_path)
        model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
    except Exception as e:
        print(f"Error loading title model: {e}")
        return None, None
    return model, tokenizer
|
|
|
|
|
|
def load_summarization_model(model_dir):
    """Load the abstractive-summarization seq2seq model and its tokenizer.

    Returns ``(model, tokenizer)`` on success, or ``(None, None)`` if loading
    fails for any reason (the error is printed, not raised).
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
    except Exception as e:
        print(f"Error loading summarization model: {e}")
        return None, None
    return model, tokenizer
|
|
|
|
|
|
# Local checkpoint directories for the two fine-tuned models.
# NOTE(review): hard-coded absolute Windows paths — these only work on the
# author's machine; consider reading them from an env var or CLI argument.
checkpoint_path = r"C:\Users\onebe\_ABH\prac\NLP\final\title_gen"
model_dir = r"C:\Users\onebe\_ABH\prac\NLP\final\abs_gen"

# Models are loaded once at import time; each pair is (None, None) on failure,
# in which case the generate_* functions below will return error strings.
title_model, title_tokenizer = load_title_model(checkpoint_path)
summarizer_model, summarizer_tokenizer = load_summarization_model(model_dir)
|
|
|
|
|
|
def generate_title(article_text, max_length=20, num_beams=10, early_stopping=True):
    """Generate a headline for *article_text* with the fine-tuned T5 model.

    Uses beam search via ``model.generate``; returns the decoded title, or an
    error-message string if generation fails (e.g. the model failed to load).
    """
    try:
        # The title model follows the T5 convention of a "summarize:" prefix.
        encoded = title_tokenizer.encode(
            f"summarize: {article_text}", return_tensors="pt", truncation=True
        )
        generated = title_model.generate(
            encoded,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping,
        )
        return title_tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating title: {e}"
|
|
|
|
|
|
def generate_summary(article_text, max_length=40, num_beams=10, early_stopping=True):
    """Generate an abstractive summary of *article_text*.

    Uses beam search via ``model.generate``; returns the decoded summary, or
    an error-message string if generation fails (e.g. the model failed to load).
    """
    try:
        encoded = summarizer_tokenizer.encode(
            article_text, return_tensors="pt", truncation=True
        )
        generated = summarizer_model.generate(
            encoded,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping,
        )
        return summarizer_tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating summary: {e}"
|
|
|
|
|
|
def process_news(input_text, title_max_length, title_num_beams, title_early_stopping,
                 summary_max_length, summary_num_beams, summary_early_stopping):
    """Gradio callback: extract the article, then generate a title and summary.

    Returns ``(article_text, title, summary)``. Each generator reports its own
    failures as strings, so this function never raises into the UI.
    """
    article_text = extract_naver_news_article(input_text)
    title = generate_title(
        article_text,
        max_length=title_max_length,
        num_beams=title_num_beams,
        early_stopping=title_early_stopping,
    )
    summary = generate_summary(
        article_text,
        max_length=summary_max_length,
        num_beams=summary_num_beams,
        early_stopping=summary_early_stopping,
    )
    return article_text, title, summary
|
|
|
|
|
|
def launch_gradio_interface():
    """Build and launch the Gradio UI wired to process_news.

    Inputs: the article body or URL, plus per-task generation controls
    (max length, beam count, early stopping) for both title and summary.
    Outputs: extracted/passed-through article text, generated title, and
    generated summary. Launches with a public share link.

    NOTE(review): the Korean label/description strings below appear
    mojibake'd — likely an encoding artifact of this file; verify they render
    correctly before shipping. They are preserved here unchanged.
    """
    interface = gr.Interface(
        fn=process_news,
        # Input widget order must match process_news's parameter order.
        inputs=[
            gr.Textbox(label="๋ด์ค ๊ธฐ์ฌ ๋ณธ๋ฌธ ๋๋ URL", placeholder="๋ด์ค ๊ธฐ์ฌ ๋ณธ๋ฌธ์ ์ง์ ์๋ ฅํ๊ฑฐ๋ URL์ ์๋ ฅํ์ธ์.", lines=10),
            # Title controls: default matches generate_title's max_length=20.
            gr.Slider(0, 40, value=20, step=1, label="์ ๋ชฉ ์์ฑ - ์ต๋ ๊ธธ์ด"),
            gr.Slider(1, 10, value=10, step=1, label="์ ๋ชฉ ์์ฑ - ํ์ ๋น ์"),
            gr.Checkbox(value=True, label="์ ๋ชฉ ์์ฑ - ์กฐ๊ธฐ ์ข๋ฃ"),
            # Summary controls: UI default (128) is larger than
            # generate_summary's own default (40); slider value wins.
            gr.Slider(0, 256, value=128, step=1, label="์์ฝ ์์ฑ - ์ต๋ ๊ธธ์ด"),
            gr.Slider(1, 10, value=10, step=1, label="์์ฝ ์์ฑ - ํ์ ๋น ์"),
            gr.Checkbox(value=True, label="์์ฝ ์์ฑ - ์กฐ๊ธฐ ์ข๋ฃ")
        ],
        # Output order matches process_news's return tuple.
        outputs=[
            gr.Textbox(label="์ถ์ถ๋๊ฑฐ๋ ์๋ ฅ๋ ๊ธฐ์ฌ ๋ณธ๋ฌธ"),
            gr.Textbox(label="์์ฑ๋ ์ ๋ชฉ"),
            gr.Textbox(label="์์ฑ๋ ์์ฝ๋ฌธ")
        ],
        title="๋ค์ด๋ฒ ๋ด์ค ์ ๋ชฉ ๋ฐ ์์ฝ ์์ฑ๊ธฐ",
        description="๋ด์ค ๊ธฐ์ฌ ๋ณธ๋ฌธ ๋๋ URL์ ์๋ ฅํ๋ฉด, ๋ณธ๋ฌธ์ ์๋ ์ถ์ถํ ํ ์ ๋ชฉ๊ณผ ์์ฝ๋ฌธ์ ์์ฑํฉ๋๋ค."
    )
    # share=True exposes a temporary public gradio.live URL.
    interface.launch(share=True)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: start the Gradio app (blocks until the server stops).
    launch_gradio_interface()
|
|
|