Update app.py
app.py CHANGED
The old and new revisions of the file are identical except for the two model-path assignments on file lines 59-60, which this commit fills in:

```diff
@@ -59,2 +59,2 @@
-checkpoint_path =
-model_dir =
+checkpoint_path = 'onebeans/keT5-news-summarizer'  # path to the fine-tuned T5 title generation model
+model_dir = 'onebeans/keT5-news-title-gen'         # path to the news article summarization model
```

Note that the two repo ids appear swapped relative to their own comments and to how they are used (`checkpoint_path` feeds the title model but points at the summarizer repo, and vice versa); the listing below keeps the committed values as-is.

The full updated file follows, with the Korean comments and UI strings (mojibake in the source) translated to English:
```python
import torch
import gradio as gr
import requests
from bs4 import BeautifulSoup
from transformers import T5ForConditionalGeneration, T5TokenizerFast, AutoTokenizer, AutoModelForSeq2SeqLM
import nltk

# Download nltk's 'punkt' data (for sentence tokenization)
nltk.download('punkt')

# Takes a news article body or a URL. If the input starts with "http", it is
# treated as a Naver News URL and the article body is scraped; otherwise the
# text is returned unchanged.
def extract_naver_news_article(text_or_url):
    try:
        text_or_url = text_or_url.strip()
        if text_or_url.startswith("http"):
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }
            response = requests.get(text_or_url, headers=headers)
            if response.status_code != 200:
                return f"Page request failed (HTTP {response.status_code})"
            soup = BeautifulSoup(response.text, "html.parser")
            article_content = soup.select_one("#dic_area")
            if not article_content:
                article_content = soup.select_one(".go_trans._article_content")
            if not article_content:
                return "Could not find the article body."
            paragraphs = article_content.find_all("p")
            article_text = "\n".join([p.get_text(strip=True) for p in paragraphs]) if paragraphs else article_content.get_text(strip=True)
            return article_text
        else:
            return text_or_url
    except Exception as e:
        return f"Error: {e}"
```
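Called with plain text the function is a pass-through; called with a Naver News URL it returns the scraped paragraphs. A quick sketch (the article URL is a hypothetical placeholder):

```python
# Pass-through: non-URL input comes back unchanged
print(extract_naver_news_article("Just some article text."))

# Scrape: a Naver News article URL (hypothetical placeholder id)
print(extract_naver_news_article("https://n.news.naver.com/article/001/0000000000")[:200])
```

The `#dic_area` selector matches the article container on current Naver News pages, with `.go_trans._article_content` as a fallback for the older layout; if Naver changes its markup, both selectors may need updating.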
```python
# Load the title generation model (a fine-tuned T5 model)
def load_title_model(checkpoint_path):
    try:
        title_tokenizer = T5TokenizerFast.from_pretrained(checkpoint_path)
        title_model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
        return title_model, title_tokenizer
    except Exception as e:
        print(f"Error loading title model: {e}")
        return None, None

# Load the news article summarization model
def load_summarization_model(model_dir):
    try:
        summarizer_tokenizer = AutoTokenizer.from_pretrained(model_dir)
        summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        return summarizer_model, summarizer_tokenizer
    except Exception as e:
        print(f"Error loading summarization model: {e}")
        return None, None

# Model path settings (change to match your environment)
checkpoint_path = 'onebeans/keT5-news-summarizer'  # path to the fine-tuned T5 title generation model
model_dir = 'onebeans/keT5-news-title-gen'         # path to the news article summarization model

title_model, title_tokenizer = load_title_model(checkpoint_path)
summarizer_model, summarizer_tokenizer = load_summarization_model(model_dir)
```
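Both loaders swallow exceptions and return `None`, so a failed download only surfaces later as a confusing error string from the generators. A minimal fail-fast guard (a sketch, not part of the commit):

```python
# Abort at startup if either model failed to load
if title_model is None or summarizer_model is None:
    raise SystemExit("Model loading failed; check the repo ids above and your network access.")
```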
```python
# Generate a title for a news article
def generate_title(article_text, max_length=20, num_beams=10, early_stopping=True):
    try:
        input_ids = title_tokenizer.encode(f"summarize: {article_text}", return_tensors="pt", truncation=True)
        outputs = title_model.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping
        )
        title = title_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return title
    except Exception as e:
        return f"Error generating title: {e}"

# Generate a summary of a news article
def generate_summary(article_text, max_length=40, num_beams=10, early_stopping=True):
    try:
        input_ids = summarizer_tokenizer.encode(article_text, return_tensors="pt", truncation=True)
        outputs = summarizer_model.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping
        )
        summary = summarizer_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return summary
    except Exception as e:
        return f"Error generating summary: {e}"
```
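Both generators call `encode(..., truncation=True)` without an explicit `max_length`, so inputs are silently cut at the tokenizer's `model_max_length` (typically 512 tokens for T5-family checkpoints); long articles are therefore summarized from their opening portion only. Making the budget explicit is a one-line change (assuming a 512-token limit for these checkpoints):

```python
# Drop-in replacement inside generate_summary: cap the input length explicitly
input_ids = summarizer_tokenizer.encode(
    article_text, return_tensors="pt", truncation=True, max_length=512
)
```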
```python
# Take an article body or URL, extract the body, then generate a title and summary
def process_news(input_text, title_max_length, title_num_beams, title_early_stopping,
                 summary_max_length, summary_num_beams, summary_early_stopping):
    article_text = extract_naver_news_article(input_text)
    title = generate_title(article_text, max_length=title_max_length, num_beams=title_num_beams, early_stopping=title_early_stopping)
    summary = generate_summary(article_text, max_length=summary_max_length, num_beams=summary_num_beams, early_stopping=summary_early_stopping)
    return article_text, title, summary
```
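The same pipeline can be exercised without the UI; a sketch using the interface's default parameter values (the sample text is a placeholder):

```python
body, title, summary = process_news(
    "Some pasted article text...",  # or a Naver News URL
    20, 10, True,    # title: max_length, num_beams, early_stopping
    128, 10, True,   # summary: max_length, num_beams, early_stopping
)
print(title)
print(summary)
```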
```python
# Build and launch the Gradio interface
def launch_gradio_interface():
    interface = gr.Interface(
        fn=process_news,
        inputs=[
            gr.Textbox(label="News article body or URL", placeholder="Paste the article text directly, or enter a URL.", lines=10),
            gr.Slider(0, 40, value=20, step=1, label="Title - max length"),
            gr.Slider(1, 10, value=10, step=1, label="Title - number of beams"),
            gr.Checkbox(value=True, label="Title - early stopping"),
            gr.Slider(0, 256, value=128, step=1, label="Summary - max length"),
            gr.Slider(1, 10, value=10, step=1, label="Summary - number of beams"),
            gr.Checkbox(value=True, label="Summary - early stopping")
        ],
        outputs=[
            gr.Textbox(label="Extracted or entered article body"),
            gr.Textbox(label="Generated title"),
            gr.Textbox(label="Generated summary")
        ],
        title="Naver News Title & Summary Generator",
        description="Enter a news article body or URL; the body is extracted automatically, then a title and summary are generated."
    )
    interface.launch(share=True)

if __name__ == '__main__':
    launch_gradio_interface()
```
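`launch(share=True)` also opens a temporary public gradio.live tunnel in addition to the local server; for a local-only run the flag can simply be dropped:

```python
interface.launch()  # local only; pass share=True to get a temporary public link
```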