onebeans committed
Commit 38c216c · verified · 1 Parent(s): d1892db

Update app.py

Files changed (1)
  1. app.py +127 -127
app.py CHANGED
@@ -1,127 +1,127 @@
- import torch
- import gradio as gr
- import requests
- from bs4 import BeautifulSoup
- from transformers import T5Tokenizer, T5ForConditionalGeneration, T5TokenizerFast, AutoTokenizer, AutoModelForSeq2SeqLM
- import nltk
-
- # Download nltk's 'punkt' data (needed for sentence tokenization)
- nltk.download('punkt')
-
- # Takes a news article body or a URL as input:
- # if the input starts with "http", treat it as a Naver News URL and extract the article body;
- # otherwise, return the text as-is
- def extract_naver_news_article(text_or_url):
-     try:
-         text_or_url = text_or_url.strip()
-         if text_or_url.startswith("http"):
-             headers = {
-                 "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
-             }
-             response = requests.get(text_or_url, headers=headers)
-             if response.status_code != 200:
-                 return f"Page request failed (HTTP {response.status_code})"
-             soup = BeautifulSoup(response.text, "html.parser")
-             article_content = soup.select_one("#dic_area")
-             if not article_content:
-                 article_content = soup.select_one(".go_trans._article_content")
-             if not article_content:
-                 return "Could not find the article body."
-             paragraphs = article_content.find_all("p")
-             article_text = "\n".join([p.get_text(strip=True) for p in paragraphs]) if paragraphs else article_content.get_text(strip=True)
-             return article_text
-         else:
-             return text_or_url
-     except Exception as e:
-         return f"Error occurred: {e}"
-
- # Loads the title generation model (a fine-tuned T5 model)
- def load_title_model(checkpoint_path):
-     try:
-         title_tokenizer = T5TokenizerFast.from_pretrained(checkpoint_path)
-         title_model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
-         return title_model, title_tokenizer
-     except Exception as e:
-         print(f"Error loading title model: {e}")
-         return None, None
-
- # Loads the news article summarization model
- def load_summarization_model(model_dir):
-     try:
-         summarizer_tokenizer = AutoTokenizer.from_pretrained(model_dir)
-         summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
-         return summarizer_model, summarizer_tokenizer
-     except Exception as e:
-         print(f"Error loading summarization model: {e}")
-         return None, None
-
- # Model paths (change to match your environment)
- checkpoint_path = r"C:\Users\onebe\_ABH\prac\NLP\final\title_gen"  # path to the fine-tuned T5 title generation model
- model_dir = r"C:\Users\onebe\_ABH\prac\NLP\final\abs_gen"  # path to the news article summarization model
-
- title_model, title_tokenizer = load_title_model(checkpoint_path)
- summarizer_model, summarizer_tokenizer = load_summarization_model(model_dir)
-
- # Generates a title for a news article
- def generate_title(article_text, max_length=20, num_beams=10, early_stopping=True):
-     try:
-         input_ids = title_tokenizer.encode(f"summarize: {article_text}", return_tensors="pt", truncation=True)
-         outputs = title_model.generate(
-             input_ids,
-             max_length=max_length,
-             num_beams=num_beams,
-             early_stopping=early_stopping
-         )
-         title = title_tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return title
-     except Exception as e:
-         return f"Error generating title: {e}"
-
- # Generates a summary of a news article
- def generate_summary(article_text, max_length=40, num_beams=10, early_stopping=True):
-     try:
-         input_ids = summarizer_tokenizer.encode(article_text, return_tensors="pt", truncation=True)
-         outputs = summarizer_model.generate(
-             input_ids,
-             max_length=max_length,
-             num_beams=num_beams,
-             early_stopping=early_stopping
-         )
-         summary = summarizer_tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return summary
-     except Exception as e:
-         return f"Error generating summary: {e}"
-
- # Takes a news article body or URL, extracts the body, then generates a title and summary
- def process_news(input_text, title_max_length, title_num_beams, title_early_stopping,
-                  summary_max_length, summary_num_beams, summary_early_stopping):
-     article_text = extract_naver_news_article(input_text)
-     title = generate_title(article_text, max_length=title_max_length, num_beams=title_num_beams, early_stopping=title_early_stopping)
-     summary = generate_summary(article_text, max_length=summary_max_length, num_beams=summary_num_beams, early_stopping=summary_early_stopping)
-     return article_text, title, summary
-
- # Launches the Gradio interface
- def launch_gradio_interface():
-     interface = gr.Interface(
-         fn=process_news,
-         inputs=[
-             gr.Textbox(label="News article body or URL", placeholder="Enter the news article body directly, or enter a URL.", lines=10),
-             gr.Slider(0, 40, value=20, step=1, label="Title generation - max length"),
-             gr.Slider(1, 10, value=10, step=1, label="Title generation - number of beams"),
-             gr.Checkbox(value=True, label="Title generation - early stopping"),
-             gr.Slider(0, 256, value=128, step=1, label="Summary generation - max length"),
-             gr.Slider(1, 10, value=10, step=1, label="Summary generation - number of beams"),
-             gr.Checkbox(value=True, label="Summary generation - early stopping")
-         ],
-         outputs=[
-             gr.Textbox(label="Extracted or input article body"),
-             gr.Textbox(label="Generated title"),
-             gr.Textbox(label="Generated summary")
-         ],
-         title="Naver News Title and Summary Generator",
-         description="Enter a news article body or URL; the body is extracted automatically and a title and summary are generated."
-     )
-     interface.launch(share=True)
-
- if __name__ == '__main__':
-     launch_gradio_interface()
 
+ import torch
+ import gradio as gr
+ import requests
+ from bs4 import BeautifulSoup
+ from transformers import T5Tokenizer, T5ForConditionalGeneration, T5TokenizerFast, AutoTokenizer, AutoModelForSeq2SeqLM
+ import nltk
+
+ # Download nltk's 'punkt' data (needed for sentence tokenization)
+ nltk.download('punkt')
+
+ # Takes a news article body or a URL as input:
+ # if the input starts with "http", treat it as a Naver News URL and extract the article body;
+ # otherwise, return the text as-is
+ def extract_naver_news_article(text_or_url):
+     try:
+         text_or_url = text_or_url.strip()
+         if text_or_url.startswith("http"):
+             headers = {
+                 "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
+             }
+             response = requests.get(text_or_url, headers=headers)
+             if response.status_code != 200:
+                 return f"Page request failed (HTTP {response.status_code})"
+             soup = BeautifulSoup(response.text, "html.parser")
+             article_content = soup.select_one("#dic_area")
+             if not article_content:
+                 article_content = soup.select_one(".go_trans._article_content")
+             if not article_content:
+                 return "Could not find the article body."
+             paragraphs = article_content.find_all("p")
+             article_text = "\n".join([p.get_text(strip=True) for p in paragraphs]) if paragraphs else article_content.get_text(strip=True)
+             return article_text
+         else:
+             return text_or_url
+     except Exception as e:
+         return f"Error occurred: {e}"
+
+ # Loads the title generation model (a fine-tuned T5 model)
+ def load_title_model(checkpoint_path):
+     try:
+         title_tokenizer = T5TokenizerFast.from_pretrained(checkpoint_path)
+         title_model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
+         return title_model, title_tokenizer
+     except Exception as e:
+         print(f"Error loading title model: {e}")
+         return None, None
+
+ # Loads the news article summarization model
+ def load_summarization_model(model_dir):
+     try:
+         summarizer_tokenizer = AutoTokenizer.from_pretrained(model_dir)
+         summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
+         return summarizer_model, summarizer_tokenizer
+     except Exception as e:
+         print(f"Error loading summarization model: {e}")
+         return None, None
+
+ # Model paths (change to match your environment)
+ checkpoint_path = 'onebeans/keT5-news-summarizer'  # path to the fine-tuned T5 title generation model
+ model_dir = 'onebeans/keT5-news-title-gen'  # path to the news article summarization model
+
+ title_model, title_tokenizer = load_title_model(checkpoint_path)
+ summarizer_model, summarizer_tokenizer = load_summarization_model(model_dir)
+
+ # Generates a title for a news article
+ def generate_title(article_text, max_length=20, num_beams=10, early_stopping=True):
+     try:
+         input_ids = title_tokenizer.encode(f"summarize: {article_text}", return_tensors="pt", truncation=True)
+         outputs = title_model.generate(
+             input_ids,
+             max_length=max_length,
+             num_beams=num_beams,
+             early_stopping=early_stopping
+         )
+         title = title_tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return title
+     except Exception as e:
+         return f"Error generating title: {e}"
+
+ # Generates a summary of a news article
+ def generate_summary(article_text, max_length=40, num_beams=10, early_stopping=True):
+     try:
+         input_ids = summarizer_tokenizer.encode(article_text, return_tensors="pt", truncation=True)
+         outputs = summarizer_model.generate(
+             input_ids,
+             max_length=max_length,
+             num_beams=num_beams,
+             early_stopping=early_stopping
+         )
+         summary = summarizer_tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return summary
+     except Exception as e:
+         return f"Error generating summary: {e}"
+
+ # Takes a news article body or URL, extracts the body, then generates a title and summary
+ def process_news(input_text, title_max_length, title_num_beams, title_early_stopping,
+                  summary_max_length, summary_num_beams, summary_early_stopping):
+     article_text = extract_naver_news_article(input_text)
+     title = generate_title(article_text, max_length=title_max_length, num_beams=title_num_beams, early_stopping=title_early_stopping)
+     summary = generate_summary(article_text, max_length=summary_max_length, num_beams=summary_num_beams, early_stopping=summary_early_stopping)
+     return article_text, title, summary
+
+ # Launches the Gradio interface
+ def launch_gradio_interface():
+     interface = gr.Interface(
+         fn=process_news,
+         inputs=[
+             gr.Textbox(label="News article body or URL", placeholder="Enter the news article body directly, or enter a URL.", lines=10),
+             gr.Slider(0, 40, value=20, step=1, label="Title generation - max length"),
+             gr.Slider(1, 10, value=10, step=1, label="Title generation - number of beams"),
+             gr.Checkbox(value=True, label="Title generation - early stopping"),
+             gr.Slider(0, 256, value=128, step=1, label="Summary generation - max length"),
+             gr.Slider(1, 10, value=10, step=1, label="Summary generation - number of beams"),
+             gr.Checkbox(value=True, label="Summary generation - early stopping")
+         ],
+         outputs=[
+             gr.Textbox(label="Extracted or input article body"),
+             gr.Textbox(label="Generated title"),
+             gr.Textbox(label="Generated summary")
+         ],
+         title="Naver News Title and Summary Generator",
+         description="Enter a news article body or URL; the body is extracted automatically and a title and summary are generated."
+     )
+     interface.launch(share=True)
+
+ if __name__ == '__main__':
+     launch_gradio_interface()