openfree committed on
Commit
ea6037b
·
verified ·
1 Parent(s): 8824b67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -56
app.py CHANGED
@@ -10,24 +10,42 @@ import numpy as np
10
  import pandas as pd
11
  from collections import Counter
12
  from scipy import stats
 
13
  from wordcloud import WordCloud
14
  from topic_translator import translate_topics
15
  from nltk.tokenize import sent_tokenize, word_tokenize
16
  from nltk.sentiment import SentimentIntensityAnalyzer
 
 
 
 
17
 
18
  # NLTK 필요 데이터 다운로드
19
- nltk.download('punkt')
20
- nltk.download('vader_lexicon')
 
 
 
21
 
22
  SAMPLE_TEXT = """
23
  The three reportedly discussed the Stargate Project, a large-scale AI initiative led by OpenAI, SoftBank, and U.S. software giant Oracle. The project aims to invest $500 billion over the next four years in building new AI infrastructure in the U.S. The U.S. government has shown a strong commitment to the initiative, with President Donald Trump personally announcing it at the White House the day after his inauguration last month. If Samsung participates, the project will lead to a Korea-U.S.-Japan AI alliance.
24
  The AI sector requires massive investments and extensive resources, including advanced models, high-performance AI chips to power the models, and large-scale data centers to operate them. Nvidia and TSMC currently dominate the AI sector, but a partnership between Samsung, SoftBank, and OpenAI could pave the way for a competitive alternative.
25
- """
 
 
 
26
 
27
- clf = wbgtopic.WBGDocTopic()
 
 
 
 
 
 
 
28
 
 
29
  def analyze_text_sections(text):
30
- # 문단별 분석
31
  sentences = sent_tokenize(text)
32
  sections = [' '.join(sentences[i:i+3]) for i in range(0, len(sentences), 3)]
33
  section_topics = []
@@ -38,33 +56,37 @@ def analyze_text_sections(text):
38
 
39
  return section_topics
40
 
 
41
  def calculate_topic_correlations(topics):
42
- # 주제 간 상관관계 계산
43
  topic_scores = {}
44
  for topic in topics:
45
  topic_scores[topic['label']] = topic['score_mean']
46
 
 
 
 
47
  correlation_matrix = np.corrcoef(list(topic_scores.values()))
48
  return correlation_matrix, list(topic_scores.keys())
49
 
 
50
  def perform_sentiment_analysis(text):
51
- # 감성 분석
52
  sia = SentimentIntensityAnalyzer()
53
  sentences = sent_tokenize(text)
54
  sentiments = [sia.polarity_scores(sent) for sent in sentences]
55
  return pd.DataFrame(sentiments)
56
 
 
57
  def create_topic_clusters(topics):
58
- # 주제 군집화
59
- from sklearn.cluster import KMeans
 
60
  X = np.array([[t['score_mean'], t['score_std']] for t in topics])
61
- kmeans = KMeans(n_clusters=3, random_state=42)
62
  clusters = kmeans.fit_predict(X)
63
  return clusters
64
 
65
-
66
  def create_main_charts(topics):
67
- # 1. 기본 막대 차트
68
  bar_fig = go.Figure()
69
  bar_fig.add_trace(go.Bar(
70
  x=[t['label'] for t in topics],
@@ -72,9 +94,14 @@ def create_main_charts(topics):
72
  name='관련도',
73
  marker_color='rgb(55, 83, 109)'
74
  ))
75
- bar_fig.update_layout(title='주제 분석 결과', height=500)
 
 
 
 
 
 
76
 
77
- # 2. 레이더 차트
78
  radar_fig = go.Figure()
79
  radar_fig.add_trace(go.Scatterpolar(
80
  r=[t['score'] for t in topics],
@@ -82,10 +109,15 @@ def create_main_charts(topics):
82
  fill='toself',
83
  name='주제 분포'
84
  ))
85
- radar_fig.update_layout(title='주제 레이더 차트')
 
 
 
 
86
 
87
  return bar_fig, radar_fig
88
 
 
89
  def create_correlation_heatmap(corr_matrix, labels):
90
  fig = go.Figure(data=go.Heatmap(
91
  z=corr_matrix,
@@ -93,72 +125,136 @@ def create_correlation_heatmap(corr_matrix, labels):
93
  y=labels,
94
  colorscale='Viridis'
95
  ))
96
- fig.update_layout(title='주제 간 상관관계')
 
 
 
 
97
  return fig
98
 
 
99
  def create_topic_evolution(section_topics):
 
 
 
100
  fig = go.Figure()
101
  for topic in section_topics[0]:
102
- topic_scores = [topics[topic['label']]['score_mean']
103
- for topics in section_topics]
104
- fig.add_trace(go.Scatter(
105
- x=list(range(len(section_topics))),
106
- y=topic_scores,
107
- name=topic['label'],
108
- mode='lines+markers'
109
- ))
110
- fig.update_layout(title='주제 변화 추이')
 
 
 
 
 
 
 
 
 
 
 
111
  return fig
112
 
 
113
  def create_confidence_gauge(topics):
114
  fig = go.Figure()
115
- for topic in topics:
116
  fig.add_trace(go.Indicator(
117
  mode="gauge+number",
118
  value=topic['confidence'],
119
  title={'text': topic['label']},
120
- domain={'row': 0, 'column': 0}
121
  ))
122
- fig.update_layout(grid={'rows': 1, 'columns': len(topics)})
 
 
 
 
123
  return fig
124
 
125
-
126
- def process_all_analysis(text):
127
- # 기본 주제 분석
128
- raw_results = clf.suggest_topics(text)
129
- topics = process_results(raw_results)
130
 
131
- # 추가 분석
132
- section_topics = analyze_text_sections(text)
133
- corr_matrix, labels = calculate_topic_correlations(topics)
134
- sentiments = perform_sentiment_analysis(text)
135
- clusters = create_topic_clusters(topics)
136
 
137
- # 차트 생성
138
- bar_chart, radar_chart = create_main_charts(topics)
139
- heatmap = create_correlation_heatmap(corr_matrix, labels)
140
- evolution_chart = create_topic_evolution(section_topics)
141
- gauge_chart = create_confidence_gauge(topics)
 
 
 
142
 
143
- return {
144
- 'topics': topics,
145
- 'bar_chart': bar_chart,
146
- 'radar_chart': radar_chart,
147
- 'heatmap': heatmap,
148
- 'evolution': evolution_chart,
149
- 'gauge': gauge_chart,
150
- 'sentiments': sentiments.to_dict(),
151
- 'clusters': clusters.tolist()
152
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
 
154
  with gr.Blocks(title="고급 문서 주제 분석기") as demo:
155
  gr.Markdown("## 📊 고급 문서 주제 분석기")
 
156
 
157
  with gr.Row():
158
- text = gr.Textbox(value=SAMPLE_TEXT, label="분석할 텍스트", lines=5)
 
 
 
 
 
159
 
160
  with gr.Row():
161
- submit_btn = gr.Button("분석 시작")
162
 
163
  with gr.Tabs():
164
  with gr.TabItem("주요 분석"):
@@ -186,4 +282,11 @@ with gr.Blocks(title="고급 문서 주제 분석기") as demo:
186
  outputs=[output, plot1, plot2, plot3, plot4, plot5, plot6]
187
  )
188
 
189
- demo.launch(debug=True)
 
 
 
 
 
 
 
 
10
  import pandas as pd
11
  from collections import Counter
12
  from scipy import stats
13
+ import torch
14
  from wordcloud import WordCloud
15
  from topic_translator import translate_topics
16
  from nltk.tokenize import sent_tokenize, word_tokenize
17
  from nltk.sentiment import SentimentIntensityAnalyzer
18
+ from sklearn.cluster import KMeans
19
+
20
+ # GPU 설정
21
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
 
23
  # NLTK 필요 데이터 다운로드
24
+ try:
25
+ nltk.download('punkt', quiet=True)
26
+ nltk.download('vader_lexicon', quiet=True)
27
+ except Exception as e:
28
+ print(f"NLTK 데이터 다운로드 중 오류 발생: {e}")
29
 
30
  SAMPLE_TEXT = """
31
  The three reportedly discussed the Stargate Project, a large-scale AI initiative led by OpenAI, SoftBank, and U.S. software giant Oracle. The project aims to invest $500 billion over the next four years in building new AI infrastructure in the U.S. The U.S. government has shown a strong commitment to the initiative, with President Donald Trump personally announcing it at the White House the day after his inauguration last month. If Samsung participates, the project will lead to a Korea-U.S.-Japan AI alliance.
32
  The AI sector requires massive investments and extensive resources, including advanced models, high-performance AI chips to power the models, and large-scale data centers to operate them. Nvidia and TSMC currently dominate the AI sector, but a partnership between Samsung, SoftBank, and OpenAI could pave the way for a competitive alternative.
33
+ """
34
+
35
+ # WBGDocTopic 초기화 시 device 지정
36
+ clf = wbgtopic.WBGDocTopic(device=device)
37
 
38
+ def safe_process(func):
39
+ def wrapper(*args, **kwargs):
40
+ try:
41
+ return func(*args, **kwargs)
42
+ except Exception as e:
43
+ print(f"Error in {func.__name__}: {str(e)}")
44
+ return None
45
+ return wrapper
46
 
47
+ @safe_process
48
  def analyze_text_sections(text):
 
49
  sentences = sent_tokenize(text)
50
  sections = [' '.join(sentences[i:i+3]) for i in range(0, len(sentences), 3)]
51
  section_topics = []
 
56
 
57
  return section_topics
58
 
59
+ @safe_process
60
  def calculate_topic_correlations(topics):
 
61
  topic_scores = {}
62
  for topic in topics:
63
  topic_scores[topic['label']] = topic['score_mean']
64
 
65
+ if len(topic_scores) < 2:
66
+ return np.array([[1]]), list(topic_scores.keys())
67
+
68
  correlation_matrix = np.corrcoef(list(topic_scores.values()))
69
  return correlation_matrix, list(topic_scores.keys())
70
 
71
+ @safe_process
72
  def perform_sentiment_analysis(text):
 
73
  sia = SentimentIntensityAnalyzer()
74
  sentences = sent_tokenize(text)
75
  sentiments = [sia.polarity_scores(sent) for sent in sentences]
76
  return pd.DataFrame(sentiments)
77
 
78
+ @safe_process
79
  def create_topic_clusters(topics):
80
+ if len(topics) < 3:
81
+ return np.zeros(len(topics))
82
+
83
  X = np.array([[t['score_mean'], t['score_std']] for t in topics])
84
+ kmeans = KMeans(n_clusters=min(3, len(topics)), random_state=42)
85
  clusters = kmeans.fit_predict(X)
86
  return clusters
87
 
88
+ @safe_process
89
  def create_main_charts(topics):
 
90
  bar_fig = go.Figure()
91
  bar_fig.add_trace(go.Bar(
92
  x=[t['label'] for t in topics],
 
94
  name='관련도',
95
  marker_color='rgb(55, 83, 109)'
96
  ))
97
+ bar_fig.update_layout(
98
+ title='주제 분석 결과',
99
+ height=500,
100
+ xaxis_title='주제',
101
+ yaxis_title='관련도 (%)',
102
+ template='plotly_white'
103
+ )
104
 
 
105
  radar_fig = go.Figure()
106
  radar_fig.add_trace(go.Scatterpolar(
107
  r=[t['score'] for t in topics],
 
109
  fill='toself',
110
  name='주제 분포'
111
  ))
112
+ radar_fig.update_layout(
113
+ title='주제 레이더 차트',
114
+ height=500,
115
+ template='plotly_white'
116
+ )
117
 
118
  return bar_fig, radar_fig
119
 
120
+ @safe_process
121
  def create_correlation_heatmap(corr_matrix, labels):
122
  fig = go.Figure(data=go.Heatmap(
123
  z=corr_matrix,
 
125
  y=labels,
126
  colorscale='Viridis'
127
  ))
128
+ fig.update_layout(
129
+ title='주제 간 상관관계',
130
+ height=500,
131
+ template='plotly_white'
132
+ )
133
  return fig
134
 
135
+ @safe_process
136
  def create_topic_evolution(section_topics):
137
+ if not section_topics or len(section_topics) == 0:
138
+ return go.Figure()
139
+
140
  fig = go.Figure()
141
  for topic in section_topics[0]:
142
+ try:
143
+ topic_scores = [topics[topic['label']]['score_mean']
144
+ for topics in section_topics]
145
+ fig.add_trace(go.Scatter(
146
+ x=list(range(len(section_topics))),
147
+ y=topic_scores,
148
+ name=topic['label'],
149
+ mode='lines+markers'
150
+ ))
151
+ except Exception as e:
152
+ print(f"Error processing topic {topic['label']}: {e}")
153
+ continue
154
+
155
+ fig.update_layout(
156
+ title='주제 변화 추이',
157
+ xaxis_title='섹션',
158
+ yaxis_title='관련도',
159
+ height=500,
160
+ template='plotly_white'
161
+ )
162
  return fig
163
 
164
+ @safe_process
165
  def create_confidence_gauge(topics):
166
  fig = go.Figure()
167
+ for i, topic in enumerate(topics):
168
  fig.add_trace(go.Indicator(
169
  mode="gauge+number",
170
  value=topic['confidence'],
171
  title={'text': topic['label']},
172
+ domain={'row': 0, 'column': i, 'x': [i/len(topics), (i+1)/len(topics)]}
173
  ))
174
+ fig.update_layout(
175
+ grid={'rows': 1, 'columns': len(topics)},
176
+ height=400,
177
+ template='plotly_white'
178
+ )
179
  return fig
180
 
181
+ @safe_process
182
+ def process_results(results):
183
+ if not results or not results[0]:
184
+ return []
 
185
 
186
+ topics = results[0]
187
+ top_topics = sorted(topics, key=lambda x: x['score_mean'], reverse=True)[:5]
 
 
 
188
 
189
+ formatted_topics = []
190
+ for topic in top_topics:
191
+ formatted_topic = {
192
+ 'label': translate_topics.get(topic['label'], topic['label']),
193
+ 'score': round(topic['score_mean'] * 100, 1),
194
+ 'confidence': round((1 - topic['score_std']) * 100, 1)
195
+ }
196
+ formatted_topics.append(formatted_topic)
197
 
198
+ return formatted_topics
199
+
200
+ @spaces.GPU(enable_queue=True, duration=50)
201
+ def process_all_analysis(text):
202
+ try:
203
+ # 기본 주제 분석
204
+ raw_results = clf.suggest_topics(text)
205
+ topics = process_results(raw_results)
206
+
207
+ # 추가 분석
208
+ section_topics = analyze_text_sections(text)
209
+ corr_matrix, labels = calculate_topic_correlations(topics)
210
+ sentiments = perform_sentiment_analysis(text)
211
+ clusters = create_topic_clusters(topics)
212
+
213
+ # 차트 생성
214
+ bar_chart, radar_chart = create_main_charts(topics)
215
+ heatmap = create_correlation_heatmap(corr_matrix, labels)
216
+ evolution_chart = create_topic_evolution(section_topics)
217
+ gauge_chart = create_confidence_gauge(topics)
218
+
219
+ return {
220
+ 'topics': topics,
221
+ 'bar_chart': bar_chart,
222
+ 'radar_chart': radar_chart,
223
+ 'heatmap': heatmap,
224
+ 'evolution': evolution_chart,
225
+ 'gauge': gauge_chart,
226
+ 'sentiments': sentiments.to_dict() if sentiments is not None else {},
227
+ 'clusters': clusters.tolist() if clusters is not None else []
228
+ }
229
+ except Exception as e:
230
+ print(f"Analysis error: {str(e)}")
231
+ return {
232
+ 'error': str(e),
233
+ 'topics': [],
234
+ 'bar_chart': go.Figure(),
235
+ 'radar_chart': go.Figure(),
236
+ 'heatmap': go.Figure(),
237
+ 'evolution': go.Figure(),
238
+ 'gauge': go.Figure(),
239
+ 'sentiments': {},
240
+ 'clusters': []
241
+ }
242
 
243
+ # Gradio 인터페이스
244
  with gr.Blocks(title="고급 문서 주제 분석기") as demo:
245
  gr.Markdown("## 📊 고급 문서 주제 분석기")
246
+ gr.Markdown("문서를 입력하면 다양한 분석 결과를 시각화하여 보여줍니다.")
247
 
248
  with gr.Row():
249
+ text = gr.Textbox(
250
+ value=SAMPLE_TEXT,
251
+ label="분석할 텍스트",
252
+ placeholder="여기에 분석할 텍스트를 입력하세요",
253
+ lines=8
254
+ )
255
 
256
  with gr.Row():
257
+ submit_btn = gr.Button("분석 시작", variant="primary")
258
 
259
  with gr.Tabs():
260
  with gr.TabItem("주요 분석"):
 
282
  outputs=[output, plot1, plot2, plot3, plot4, plot5, plot6]
283
  )
284
 
285
+ if __name__ == "__main__":
286
+ demo.queue(concurrency_count=1)
287
+ demo.launch(
288
+ server_name="0.0.0.0",
289
+ server_port=7860,
290
+ share=False,
291
+ debug=True
292
+ )