File size: 6,103 Bytes
38c216c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d7bcf1
 
38c216c
4d7bcf1
 
38c216c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import gradio as gr
import requests
from bs4 import BeautifulSoup
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5TokenizerFast, AutoTokenizer, AutoModelForSeq2SeqLM
import nltk

# Download NLTK's 'punkt' data (used for sentence tokenization).
# NOTE(review): this runs at import time; no NLTK tokenizer call is visible in
# this part of the file — confirm the download is still needed.
nltk.download('punkt')

# Takes either the body text of a news article or a URL.
# If the input starts with "http" it is treated as a Naver News URL and the article body is extracted;
# otherwise the text is returned as is.
def extract_naver_news_article(text_or_url, timeout=10):
    """Return the article text for a news URL, or the input text unchanged.

    The input is stripped of surrounding whitespace. If it starts with
    ``http://`` or ``https://`` the page is fetched and the article body is
    read from the ``#dic_area`` element, falling back to
    ``.go_trans._article_content``. Any other input is returned as the
    stripped text itself.

    Args:
        text_or_url: Raw article text or an article URL.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The article text on success. Failures (non-200 response, missing
        article element, or any raised exception) are reported as a Korean
        message string instead of being raised.
    """
    try:
        text_or_url = text_or_url.strip()
        # Only real URLs are fetched; prose that merely begins with the
        # letters "http" is treated as article text.
        if not text_or_url.startswith(("http://", "https://")):
            return text_or_url

        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
        # Without a timeout an unresponsive server would block the caller forever.
        response = requests.get(text_or_url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            return f"페이지 요청 실패 (HTTP {response.status_code})"

        soup = BeautifulSoup(response.text, "html.parser")
        # Primary article container first, then the alternative layout.
        article_content = soup.select_one("#dic_area")
        if not article_content:
            article_content = soup.select_one(".go_trans._article_content")
        if not article_content:
            return "기사 본문을 찾을 수 없습니다."

        # Prefer one line per <p>; fall back to the container's whole text
        # when the article has no paragraph tags.
        paragraphs = article_content.find_all("p")
        if paragraphs:
            return "\n".join(p.get_text(strip=True) for p in paragraphs)
        return article_content.get_text(strip=True)
    except Exception as e:
        # Broad on purpose: the result is displayed as text by the caller,
        # so every failure is turned into a message.
        return f"오류 발생: {e}"

# Loader for the title-generation model (fine-tuned T5 model)
def load_title_model(checkpoint_path):
    """Load the title-generation T5 model and its fast tokenizer.

    Args:
        checkpoint_path: Identifier or path handed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` on success. If loading raises, the error is
        printed and ``(None, None)`` is returned.
    """
    loaded = (None, None)
    try:
        # Tokenizer first, then the model weights.
        tokenizer = T5TokenizerFast.from_pretrained(checkpoint_path)
        model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
    except Exception as exc:
        print(f"Error loading title model: {exc}")
    else:
        loaded = (model, tokenizer)
    return loaded

# Loader for the news-article summarization model
def load_summarization_model(model_dir):
    """Load the summarization seq2seq model and its tokenizer.

    Args:
        model_dir: Identifier or path handed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` on success. If loading raises, the error is
        printed and ``(None, None)`` is returned.
    """
    try:
        # Tuple elements are evaluated left to right: tokenizer, then model.
        tokenizer, model = (
            AutoTokenizer.from_pretrained(model_dir),
            AutoModelForSeq2SeqLM.from_pretrained(model_dir),
        )
    except Exception as exc:
        print(f"Error loading summarization model: {exc}")
        return None, None
    return model, tokenizer

# Model path settings (change to match your environment)
title_path = 'onebeans/keT5-news-title-gen'  # Path of the fine-tuned T5 title-generation model
abs_path =  'onebeans/keT5-news-summarizer'                 # Path of the news-article summarization model

# Both models are loaded once, at import time. Each loader returns
# (None, None) when loading fails, so these globals may be None.
title_model, title_tokenizer = load_title_model(title_path)
summarizer_model, summarizer_tokenizer = load_summarization_model(abs_path)

# Generates a title for a news article
def generate_title(article_text, max_length=20, num_beams=10, early_stopping=True):
    """Generate a title for an article with the module-level title model.

    The article is prefixed with ``"summarize: "``, encoded with truncation,
    and decoded from the first sequence returned by beam search.

    Args:
        article_text: Article body to generate a title for.
        max_length: ``max_length`` passed to ``generate``.
        num_beams: Number of beams passed to ``generate``.
        early_stopping: ``early_stopping`` flag passed to ``generate``.

    Returns:
        The decoded title, or a string starting with
        ``"Error generating title:"`` if anything fails.
    """
    try:
        # The module-level loader leaves both globals as None when loading
        # failed; report that directly instead of an AttributeError on None.
        if title_model is None or title_tokenizer is None:
            return "Error generating title: title model is not loaded"
        input_ids = title_tokenizer.encode(f"summarize: {article_text}", return_tensors="pt", truncation=True)
        outputs = title_model.generate(
            input_ids,
            # Defensive: callers may hand over whole numbers as floats.
            max_length=int(max_length),
            num_beams=int(num_beams),
            early_stopping=early_stopping,
        )
        return title_tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating title: {e}"

# Generates a summary of a news article
def generate_summary(article_text, max_length=40, num_beams=10, early_stopping=True):
    """Generate a summary of an article with the module-level summarizer.

    The article is encoded with truncation (no task prefix is added) and the
    summary is decoded from the first sequence returned by beam search.

    Args:
        article_text: Article body to summarize.
        max_length: ``max_length`` passed to ``generate``.
        num_beams: Number of beams passed to ``generate``.
        early_stopping: ``early_stopping`` flag passed to ``generate``.

    Returns:
        The decoded summary, or a string starting with
        ``"Error generating summary:"`` if anything fails.
    """
    try:
        # The module-level loader leaves both globals as None when loading
        # failed; report that directly instead of an AttributeError on None.
        if summarizer_model is None or summarizer_tokenizer is None:
            return "Error generating summary: summarization model is not loaded"
        input_ids = summarizer_tokenizer.encode(article_text, return_tensors="pt", truncation=True)
        outputs = summarizer_model.generate(
            input_ids,
            # Defensive: callers may hand over whole numbers as floats.
            max_length=int(max_length),
            num_beams=int(num_beams),
            early_stopping=early_stopping,
        )
        return summarizer_tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating summary: {e}"

# Takes article body text or a URL, extracts the body, then generates a title and a summary
def process_news(input_text, title_max_length, title_num_beams, title_early_stopping,
                 summary_max_length, summary_num_beams, summary_early_stopping):
    """Extract the article text, then generate its title and summary.

    Args:
        input_text: Article body or article URL.
        title_max_length: ``max_length`` for title generation.
        title_num_beams: ``num_beams`` for title generation.
        title_early_stopping: ``early_stopping`` for title generation.
        summary_max_length: ``max_length`` for summary generation.
        summary_num_beams: ``num_beams`` for summary generation.
        summary_early_stopping: ``early_stopping`` for summary generation.

    Returns:
        A ``(article_text, title, summary)`` tuple of strings. When there is
        no article text, the title and summary are empty strings.
    """
    article_text = extract_naver_news_article(input_text)
    # With nothing to read, skip the models rather than generating a title
    # and summary from empty text.
    if not article_text:
        return article_text, "", ""
    title = generate_title(
        article_text,
        max_length=title_max_length,
        num_beams=title_num_beams,
        early_stopping=title_early_stopping,
    )
    summary = generate_summary(
        article_text,
        max_length=summary_max_length,
        num_beams=summary_num_beams,
        early_stopping=summary_early_stopping,
    )
    return article_text, title, summary

# Builds and launches the Gradio interface
def launch_gradio_interface():
    interface = gr.Interface(
        fn=process_news,
        inputs=[
            gr.Textbox(label="๋‰ด์Šค ๊ธฐ์‚ฌ ๋ณธ๋ฌธ ๋˜๋Š” URL", placeholder="๋‰ด์Šค ๊ธฐ์‚ฌ ๋ณธ๋ฌธ์„ ์ง์ ‘ ์ž…๋ ฅํ•˜๊ฑฐ๋‚˜ URL์„ ์ž…๋ ฅํ•˜์„ธ์š”.", lines=10),
            gr.Slider(0, 40, value=20, step=1, label="์ œ๋ชฉ ์ƒ์„ฑ - ์ตœ๋Œ€ ๊ธธ์ด"),
            gr.Slider(1, 10, value=10, step=1, label="์ œ๋ชฉ ์ƒ์„ฑ - ํƒ์ƒ‰ ๋น” ์ˆ˜"),
            gr.Checkbox(value=True, label="์ œ๋ชฉ ์ƒ์„ฑ - ์กฐ๊ธฐ ์ข…๋ฃŒ"),
            gr.Slider(0, 256, value=128, step=1, label="์š”์•ฝ ์ƒ์„ฑ - ์ตœ๋Œ€ ๊ธธ์ด"),
            gr.Slider(1, 10, value=10, step=1, label="์š”์•ฝ ์ƒ์„ฑ - ํƒ์ƒ‰ ๋น” ์ˆ˜"),
            gr.Checkbox(value=True, label="์š”์•ฝ ์ƒ์„ฑ - ์กฐ๊ธฐ ์ข…๋ฃŒ")
        ],
        outputs=[
            gr.Textbox(label="์ถ”์ถœ๋˜๊ฑฐ๋‚˜ ์ž…๋ ฅ๋œ ๊ธฐ์‚ฌ ๋ณธ๋ฌธ"),
            gr.Textbox(label="์ƒ์„ฑ๋œ ์ œ๋ชฉ"),
            gr.Textbox(label="์ƒ์„ฑ๋œ ์š”์•ฝ๋ฌธ")
        ],
        title="๋„ค์ด๋ฒ„ ๋‰ด์Šค ์ œ๋ชฉ ๋ฐ ์š”์•ฝ ์ƒ์„ฑ๊ธฐ",
        description="๋‰ด์Šค ๊ธฐ์‚ฌ ๋ณธ๋ฌธ ๋˜๋Š” URL์„ ์ž…๋ ฅํ•˜๋ฉด, ๋ณธ๋ฌธ์„ ์ž๋™ ์ถ”์ถœํ•œ ํ›„ ์ œ๋ชฉ๊ณผ ์š”์•ฝ๋ฌธ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค."
    )
    interface.launch(share=True)

# Launch the web UI only when this file is run as a script, not when imported.
if __name__ == '__main__':
    launch_gradio_interface()