|
import torch |
|
import gradio as gr |
|
import requests |
|
from bs4 import BeautifulSoup |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5TokenizerFast, AutoTokenizer, AutoModelForSeq2SeqLM |
|
import nltk |
|
|
|
|
|
# Make sure NLTK's "punkt" tokenizer data is present; this runs at import time
# and may reach the network.
# NOTE(review): nltk is not otherwise used in the code visible here — confirm
# this download is still needed.
nltk.download("punkt")
|
|
|
|
|
|
|
|
|
def extract_naver_news_article(text_or_url, timeout=10):
    """Return article text from a Naver News URL, or pass raw text through.

    If ``text_or_url`` (after stripping surrounding whitespace) starts with
    ``http://`` or ``https://`` (case-insensitive), the page is fetched and the
    body is read from ``#dic_area``, falling back to
    ``.go_trans._article_content``. Any other input is treated as article text
    and returned stripped.

    Args:
        text_or_url: Article text or a news article URL.
        timeout: Seconds passed to ``requests.get`` so an unresponsive server
            cannot block the call indefinitely.

    Returns:
        The extracted (or pasted) article text. Failures do not raise: a
        Korean error message string is returned instead (non-200 status,
        article body not found, or any exception), so callers cannot tell it
        apart from article text by type.
    """
    try:
        text_or_url = text_or_url.strip()

        # Only real URL schemes count as URLs; pasted text that merely begins
        # with the letters "http" must not trigger a network request.
        if not text_or_url.lower().startswith(("http://", "https://")):
            return text_or_url

        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
        response = requests.get(text_or_url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            return f"페이지 요청 실패 (HTTP {response.status_code})"

        soup = BeautifulSoup(response.text, "html.parser")

        # Primary body container first, then the alternate selector.
        article_content = soup.select_one("#dic_area")
        if article_content is None:
            article_content = soup.select_one(".go_trans._article_content")
        if article_content is None:
            return "기사 본문을 찾을 수 없습니다."

        paragraphs = article_content.find_all("p")
        if paragraphs:
            return "\n".join(p.get_text(strip=True) for p in paragraphs)

        # Without <p> tags the text is split into separate text nodes (e.g.
        # around <br> or inline tags). get_text(strip=True) with its default
        # empty separator would concatenate them with no whitespace at all;
        # a newline separator keeps adjacent lines apart.
        return article_content.get_text(separator="\n", strip=True)

    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"오류 발생: {e}"
|
|
|
|
|
def load_title_model(checkpoint_path):
    """Load the T5 title-generation model together with its fast tokenizer.

    Args:
        checkpoint_path: Identifier or directory handed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` on success. If either load raises, the error is
        printed and ``(None, None)`` is returned instead of propagating.
    """
    try:
        tokenizer = T5TokenizerFast.from_pretrained(checkpoint_path)
        model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
    except Exception as err:
        print(f"Error loading title model: {err}")
        return None, None
    return model, tokenizer
|
|
|
|
|
def load_summarization_model(model_dir):
    """Load the seq2seq summarization model and its tokenizer via the Auto classes.

    Args:
        model_dir: Identifier or directory handed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` on success. If either load raises, the error is
        printed and ``(None, None)`` is returned instead of propagating.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
    except Exception as err:
        print(f"Error loading summarization model: {err}")
        return None, None
    return model, tokenizer
|
|
|
|
|
# Checkpoint identifiers passed to the loaders' from_pretrained calls.
title_path = "onebeans/keT5-news-title-gen"
abs_path = "onebeans/keT5-news-summarizer"

# Both models are loaded once, at import time. generate_title() and
# generate_summary() read these module-level globals; a failed load leaves the
# corresponding pair set to (None, None).
title_model, title_tokenizer = load_title_model(title_path)
summarizer_model, summarizer_tokenizer = load_summarization_model(abs_path)
|
|
|
|
|
def generate_title(article_text, max_length=20, num_beams=10, early_stopping=True):
    """Generate a headline for ``article_text`` with the module-level title model.

    The text is prefixed with ``"summarize: "`` before tokenization (the prompt
    format this block has always used) and truncated to the tokenizer's limit.

    Args:
        article_text: Article body to title.
        max_length: Maximum generated length, forwarded to ``generate``.
        num_beams: Beam-search width, forwarded to ``generate``.
        early_stopping: Forwarded to ``generate`` unchanged.

    Returns:
        The decoded title. Returns ``""`` for ``None`` or whitespace-only input,
        so the model never generates from a bare prompt. Returns
        ``"Error generating title: ..."`` if tokenization or generation raises.
    """
    try:
        if article_text is None or not str(article_text).strip():
            return ""

        input_ids = title_tokenizer.encode(
            f"summarize: {article_text}", return_tensors="pt", truncation=True
        )
        outputs = title_model.generate(
            input_ids,
            # UI sliders may deliver numbers as floats; generate() expects ints.
            max_length=int(max_length),
            num_beams=int(num_beams),
            early_stopping=early_stopping,
        )
        return title_tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating title: {e}"
|
|
|
|
|
def generate_summary(article_text, max_length=40, num_beams=10, early_stopping=True):
    """Summarize ``article_text`` with the module-level summarization model.

    The text is tokenized as-is (no prompt prefix) and truncated to the
    tokenizer's limit.

    Args:
        article_text: Article body to summarize.
        max_length: Maximum generated length, forwarded to ``generate``.
        num_beams: Beam-search width, forwarded to ``generate``.
        early_stopping: Forwarded to ``generate`` unchanged.

    Returns:
        The decoded summary. Returns ``""`` for ``None`` or whitespace-only
        input instead of running the model on nothing. Returns
        ``"Error generating summary: ..."`` if tokenization or generation raises.
    """
    try:
        if article_text is None or not str(article_text).strip():
            return ""

        input_ids = summarizer_tokenizer.encode(
            article_text, return_tensors="pt", truncation=True
        )
        outputs = summarizer_model.generate(
            input_ids,
            # UI sliders may deliver numbers as floats; generate() expects ints.
            max_length=int(max_length),
            num_beams=int(num_beams),
            early_stopping=early_stopping,
        )
        return summarizer_tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating summary: {e}"
|
|
|
|
|
def process_news(input_text, title_max_length, title_num_beams, title_early_stopping,
                 summary_max_length, summary_num_beams, summary_early_stopping):
    """Gradio callback: resolve the article text, then generate a title and a summary.

    ``input_text`` may be raw article text or a URL; it goes through
    ``extract_naver_news_article`` first. Whatever string that returns —
    including its error messages — is fed to both generators.

    Returns:
        ``(article_text, title, summary)``, matching the interface's three
        output textboxes.
    """
    body = extract_naver_news_article(input_text)

    generated_title = generate_title(
        body,
        max_length=title_max_length,
        num_beams=title_num_beams,
        early_stopping=title_early_stopping,
    )
    generated_summary = generate_summary(
        body,
        max_length=summary_max_length,
        num_beams=summary_num_beams,
        early_stopping=summary_early_stopping,
    )

    return body, generated_title, generated_summary
|
|
|
|
|
def launch_gradio_interface(): |
|
interface = gr.Interface( |
|
fn=process_news, |
|
inputs=[ |
|
gr.Textbox(label="๋ด์ค ๊ธฐ์ฌ ๋ณธ๋ฌธ ๋๋ URL", placeholder="๋ด์ค ๊ธฐ์ฌ ๋ณธ๋ฌธ์ ์ง์ ์
๋ ฅํ๊ฑฐ๋ URL์ ์
๋ ฅํ์ธ์.", lines=10), |
|
gr.Slider(0, 40, value=20, step=1, label="์ ๋ชฉ ์์ฑ - ์ต๋ ๊ธธ์ด"), |
|
gr.Slider(1, 10, value=10, step=1, label="์ ๋ชฉ ์์ฑ - ํ์ ๋น ์"), |
|
gr.Checkbox(value=True, label="์ ๋ชฉ ์์ฑ - ์กฐ๊ธฐ ์ข
๋ฃ"), |
|
gr.Slider(0, 256, value=128, step=1, label="์์ฝ ์์ฑ - ์ต๋ ๊ธธ์ด"), |
|
gr.Slider(1, 10, value=10, step=1, label="์์ฝ ์์ฑ - ํ์ ๋น ์"), |
|
gr.Checkbox(value=True, label="์์ฝ ์์ฑ - ์กฐ๊ธฐ ์ข
๋ฃ") |
|
], |
|
outputs=[ |
|
gr.Textbox(label="์ถ์ถ๋๊ฑฐ๋ ์
๋ ฅ๋ ๊ธฐ์ฌ ๋ณธ๋ฌธ"), |
|
gr.Textbox(label="์์ฑ๋ ์ ๋ชฉ"), |
|
gr.Textbox(label="์์ฑ๋ ์์ฝ๋ฌธ") |
|
], |
|
title="๋ค์ด๋ฒ ๋ด์ค ์ ๋ชฉ ๋ฐ ์์ฝ ์์ฑ๊ธฐ", |
|
description="๋ด์ค ๊ธฐ์ฌ ๋ณธ๋ฌธ ๋๋ URL์ ์
๋ ฅํ๋ฉด, ๋ณธ๋ฌธ์ ์๋ ์ถ์ถํ ํ ์ ๋ชฉ๊ณผ ์์ฝ๋ฌธ์ ์์ฑํฉ๋๋ค." |
|
) |
|
interface.launch(share=True) |
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    launch_gradio_interface()
|
|