openfree committed
Commit cbda16b · verified · 1 Parent(s): 5ad5437

Update app.py

Files changed (1):
  1. app.py +92 -97

app.py CHANGED
@@ -12,20 +12,20 @@ from nltk.sentiment import SentimentIntensityAnalyzer
  from sklearn.cluster import KMeans
  import torch

- # Set up GPU if available
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

- # Initialize WBGDocTopic
  clf = wbgtopic.WBGDocTopic(device=device)

- # Download NLTK data (if needed)
  try:
      nltk.download('punkt', quiet=True)
      nltk.download('vader_lexicon', quiet=True)
  except Exception as e:
-     print(f"NLTK data download error: {e}")

- # Sample text
  SAMPLE_TEXT = """
  The three reportedly discussed the Stargate Project, a large-scale AI initiative led by OpenAI, SoftBank, and U.S. software giant Oracle. The project aims to invest $500 billion over the next four years in building new AI infrastructure in the U.S. The U.S. government has shown a strong commitment to the initiative, with President Donald Trump personally announcing it at the White House the day after his inauguration last month. If Samsung participates, the project will lead to a Korea-U.S.-Japan AI alliance.
  The AI sector requires massive investments and extensive resources, including advanced models, high-performance AI chips to power the models, and large-scale data centers to operate them. Nvidia and TSMC currently dominate the AI sector, but a partnership between Samsung, SoftBank, and OpenAI could pave the way for a competitive alternative.
@@ -33,8 +33,7 @@ The AI sector requires massive investments and extensive resources, including ad

  def safe_process(func):
      """
-     A decorator that logs the error and returns None when an exception occurs.
-     Helps keep the Gradio interface from being interrupted by exceptions.
      """
      def wrapper(*args, **kwargs):
          try:
@@ -47,10 +46,10 @@ def safe_process(func):
  @safe_process
  def parse_wbg_results(raw_output):
      """
-     Normalizes the output of wbgtopic.WBGDocTopic's suggest_topics() into a list of
-     dicts with the structure 'label', 'score_mean', 'score_std'.

-     Example return structure:
      [
          {
              "label": "Agriculture",
@@ -60,30 +59,36 @@ def parse_wbg_results(raw_output):
              ...
      ]
      """
-     # Debug: check the actual result structure
      print(">>> DEBUG: raw_output =", raw_output)

-     # If the result is empty, return an empty list
      if not raw_output:
          return []

      first_item = raw_output[0]

-     # (1) Already a dict that has a 'label' key,
-     # e.g. [{"label": "...", "score": ...}, ...] or {"label": "...", "score_mean": ...}
      if isinstance(first_item, dict) and ("label" in first_item):
          parsed_list = []
          for item in raw_output:
              label = item.get("label", "")
-             # Use score_mean / score_std if they already exist;
-             # otherwise infer them from 'score' and the like
              score_mean = item.get("score_mean", None)
              score_std = item.get("score_std", None)

-             # e.g. the case where only 'score' exists
              if score_mean is None and "score" in item:
-                 # Need to check whether the score is on a 0-1 or a 0-100 scale;
-                 # for now, simply cast it to float
                  score_mean = float(item["score"])
              if score_mean is None:
                  score_mean = 0.0
@@ -98,17 +103,12 @@ def parse_wbg_results(raw_output):
              })
          return parsed_list

-     # (2) Dicts of topic name -> score,
-     # e.g. [{"Agriculture": 0.22, "Climate Change": 0.55}, ...]
      if isinstance(first_item, dict):
-         # raw_output may contain several dicts, so we have to decide whether to
-         # merge them or to parse only the first one.
-         # Here, for demonstration, we merge:
          merged = {}
          for d in raw_output:
              for k, v in d.items():
-                 # On duplicate keys, overwrite with the last value
-                 merged[k] = v

          parsed_list = []
          for label, val in merged.items():
@@ -119,17 +119,17 @@ def parse_wbg_results(raw_output):
              })
          return parsed_list

-     # Unexpected structure
      return []

  @safe_process
  def analyze_text_sections(text):
      """
-     Splits the text into several sections (e.g. 3 sentences each) and, for each
-     section, parses the suggest_topics() output with parse_wbg_results().
      """
      sentences = sent_tokenize(text)
-     # Group every 3 sentences into one section
      sections = [' '.join(sentences[i:i+3]) for i in range(0, len(sentences), 3)]

      section_topics = []
@@ -143,11 +143,8 @@ def analyze_text_sections(text):
  @safe_process
  def calculate_topic_correlations(topic_dicts):
      """
-     topic_dicts: [{'label': ..., 'score_mean': ..., 'score_std': ...}, ...]
-
-     Computes correlations from the per-topic score_mean values only.
-     Strictly, it would make more sense to correlate across different documents;
-     here, as an example, we correlate the scores of the different topics of a single text.
      """
      if len(topic_dicts) < 2:
          return np.array([[1.0]]), ["Insufficient topics"]
@@ -164,8 +161,8 @@ def calculate_topic_correlations(topic_dicts):
  @safe_process
  def perform_sentiment_analysis(text):
      """
-     Computes per-sentence sentiment scores using NLTK VADER.
-     Returns a pandas DataFrame.
      """
      sia = SentimentIntensityAnalyzer()
      sents = sent_tokenize(text)
@@ -175,8 +172,8 @@ def perform_sentiment_analysis(text):
  @safe_process
  def create_topic_clusters(topic_dicts):
      """
-     KMeans clustering on the two dimensions score_mean and score_std.
-     With fewer than 3 topics, trivially assign everything to cluster 0.
      """
      if len(topic_dicts) < 3:
          return [0] * len(topic_dicts)
@@ -196,8 +193,8 @@ def create_topic_clusters(topic_dicts):
  @safe_process
  def create_main_charts(topic_dicts):
      """
-     Creates a bar chart and a radar chart.
-     Treats 'score_mean' as a 0-1 value and multiplies it by 100 to display percentages.
      """
      if not topic_dicts:
          return go.Figure(), go.Figure()
@@ -205,28 +202,28 @@ def create_main_charts(topic_dicts):
      labels = [t['label'] for t in topic_dicts]
      scores = [t['score_mean'] * 100 for t in topic_dicts]

-     # Bar chart
      bar_fig = go.Figure(
          data=[go.Bar(x=labels, y=scores, marker_color='rgb(55, 83, 109)')]
      )
      bar_fig.update_layout(
-         title='Topic Analysis Results',
-         xaxis_title='Topics',
-         yaxis_title='Relevance (%)',
          template='plotly_white',
          height=500,
      )

-     # Radar chart
      radar_fig = go.Figure()
      radar_fig.add_trace(go.Scatterpolar(
          r=scores,
          theta=labels,
          fill='toself',
-         name='Topic Distribution'
      ))
      radar_fig.update_layout(
-         title='Topic Radar Chart',
          template='plotly_white',
          height=500,
          polar=dict(radialaxis=dict(visible=True)),
@@ -237,15 +234,13 @@ def create_main_charts(topic_dicts):
  @safe_process
  def create_correlation_heatmap(corr_matrix, labels):
      """
-     Visualizes the correlation matrix as a heatmap.
-     If there is not enough data, only shows a notice.
      """
      if corr_matrix.ndim == 0:
-         # If it is a scalar (0-D), convert it to a 2-D array
          corr_matrix = np.array([[corr_matrix]])

      if corr_matrix.shape == (1, 1):
-         # Not enough data
          fig = go.Figure()
          fig.add_annotation(text="Not enough topics for correlation", showarrow=False)
          return fig
@@ -257,7 +252,7 @@ def create_correlation_heatmap(corr_matrix, labels):
          colorscale='Viridis'
      ))
      fig.update_layout(
-         title='Topic Correlation Heatmap',
          height=500,
          template='plotly_white'
      )
@@ -266,8 +261,8 @@ def create_correlation_heatmap(corr_matrix, labels):
  @safe_process
  def create_topic_evolution(section_topics):
      """
-     Plots the change of topic scores across sections as a line chart.
-     section_topics: [[{'label':..., 'score_mean':...}, ...], [...], ...]
      """
      fig = go.Figure()
      if not section_topics or len(section_topics) == 0:
@@ -276,7 +271,7 @@ def create_topic_evolution(section_topics):
      if not section_topics[0]:
          return fig

-     # Using the first section's topics as the reference, extract each topic's
-     # score from every section in which it appears
      for topic_dict in section_topics[0]:
          label = topic_dict['label']
          score_list = []
@@ -295,9 +290,9 @@ def create_topic_evolution(section_topics):
          ))

      fig.update_layout(
-         title='Per-section Topic Evolution',
-         xaxis_title='Section',
-         yaxis_title='score_mean',
          height=500,
          template='plotly_white'
      )
@@ -306,8 +301,8 @@ def create_topic_evolution(section_topics):
  @safe_process
  def create_confidence_gauge(topic_dicts):
      """
-     Displays each topic's confidence as a gauge.
-     Here we simply use the formula (1 - score_std) * 100.
      """
      if not topic_dicts:
          return go.Figure()
@@ -334,47 +329,47 @@ def create_confidence_gauge(topic_dicts):
  @spaces.GPU()
  def process_all_analysis(text):
      """
-     Runs topic analysis over the whole text, per-section analysis, correlations,
-     sentiment analysis, clustering, and so on, then returns the JSON results and Plotly charts.
      """
      try:
-         # 1) Topic analysis over the whole text
          raw_results = clf.suggest_topics(text)
          all_topics = parse_wbg_results(raw_results)

-         # 2) Sort by score_mean in descending order and keep the top 5
          sorted_topics = sorted(all_topics, key=lambda x: x['score_mean'], reverse=True)
          top_topics = sorted_topics[:5]

-         # 3) Per-section analysis
          section_topics = analyze_text_sections(text)

-         # 4) Additional analyses (correlations, sentiment, clusters)
          corr_matrix, corr_labels = calculate_topic_correlations(all_topics)
          sentiments_df = perform_sentiment_analysis(text)
          clusters = create_topic_clusters(all_topics)

-         # 5) Build the charts
          bar_chart, radar_chart = create_main_charts(top_topics)
          heatmap = create_correlation_heatmap(corr_matrix, corr_labels)
          evolution_chart = create_topic_evolution(section_topics)
          gauge_chart = create_confidence_gauge(top_topics)

-         # 6) Bundle into a JSON-friendly dict (string keys only)
          results = {
-             "top_topics": top_topics,        # top 5 topics
-             "clusters": clusters,            # clustering results
-             "sentiments": sentiments_df.to_dict(orient="records")  # sentiment analysis
          }

          return (
-             results,          # JSON output
-             bar_chart,        # plot1
-             radar_chart,      # plot2
-             heatmap,          # plot3
-             evolution_chart,  # plot4
-             gauge_chart,      # plot5
-             go.Figure()       # plot6 (a sentiment chart can go here if needed)
          )

      except Exception as e:
@@ -394,38 +389,38 @@ def process_all_analysis(text):
  # Gradio UI Definition #
  ######################################################

- with gr.Blocks(title="Advanced Document Topic Analyzer") as demo:
-     gr.Markdown("## Advanced Document Topic Analyzer")
      gr.Markdown(
-         "Enter some text, then press the **Start Analysis** button. "
-         "You can then review the main topic analysis, correlations, confidence gauges, sentiment results, and more."
      )

      with gr.Row():
          text_input = gr.Textbox(
              value=SAMPLE_TEXT,
-             label="Text to analyze",
              lines=8
          )
      with gr.Row():
-         submit_btn = gr.Button("Start Analysis", variant="primary")

      with gr.Tabs():
-         with gr.TabItem("Main Analysis"):
              with gr.Row():
-                 plot1 = gr.Plot(label="Topic Distribution (Bar Chart)")
-                 plot2 = gr.Plot(label="Radar Chart")
-         with gr.TabItem("Detailed Analysis"):
              with gr.Row():
-                 plot3 = gr.Plot(label="Correlation Heatmap")
-                 plot4 = gr.Plot(label="Per-section Topic Evolution")
-         with gr.TabItem("Confidence Analysis"):
-             plot5 = gr.Plot(label="Confidence Gauge")
-         with gr.TabItem("Sentiment Analysis"):
-             plot6 = gr.Plot(label="Sentiment Analysis Results")

      with gr.Row():
-         output_json = gr.JSON(label="Detailed Analysis Results (JSON)")

      submit_btn.click(
          fn=process_all_analysis,
@@ -438,6 +433,6 @@ if __name__ == "__main__":
      demo.launch(
          server_name="0.0.0.0",
          server_port=7860,
-         share=False,  # set to True if a public link is needed
          debug=True
      )
 
  from sklearn.cluster import KMeans
  import torch

+ # Set up GPU if available
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

+ # Initialize WBGDocTopic
  clf = wbgtopic.WBGDocTopic(device=device)

+ # Download NLTK data if necessary
  try:
      nltk.download('punkt', quiet=True)
      nltk.download('vader_lexicon', quiet=True)
  except Exception as e:
+     print(f"NLTK data download error: {e}")

+ # Sample Text
  SAMPLE_TEXT = """
  The three reportedly discussed the Stargate Project, a large-scale AI initiative led by OpenAI, SoftBank, and U.S. software giant Oracle. The project aims to invest $500 billion over the next four years in building new AI infrastructure in the U.S. The U.S. government has shown a strong commitment to the initiative, with President Donald Trump personally announcing it at the White House the day after his inauguration last month. If Samsung participates, the project will lead to a Korea-U.S.-Japan AI alliance.
  The AI sector requires massive investments and extensive resources, including advanced models, high-performance AI chips to power the models, and large-scale data centers to operate them. Nvidia and TSMC currently dominate the AI sector, but a partnership between Samsung, SoftBank, and OpenAI could pave the way for a competitive alternative.

  def safe_process(func):
      """
+     Decorator to log exceptions and return None to prevent Gradio interface crashes.
      """
      def wrapper(*args, **kwargs):
          try:

  @safe_process
  def parse_wbg_results(raw_output):
      """
+     Standardize the output of wbgtopic.WBGDocTopic's suggest_topics() into a list of
+     dictionaries with the keys 'label', 'score_mean', and 'score_std'.

+     Example return structure:
      [
          {
              "label": "Agriculture",

              ...
      ]
      """
      print(">>> DEBUG: raw_output =", raw_output)

+     # If raw_output is a dict (instead of a list), convert it to the expected list format.
+     if isinstance(raw_output, dict):
+         parsed_list = []
+         for k, v in raw_output.items():
+             parsed_list.append({
+                 "label": k,
+                 "score_mean": float(v) if v is not None else 0.0,
+                 "score_std": 0.0
+             })
+         return parsed_list
+
+     # If the result is empty, return an empty list.
      if not raw_output:
          return []

+     # Assume raw_output is a list; get the first item.
      first_item = raw_output[0]

+     # Case 1: Already in the form of a dictionary with a 'label' key.
      if isinstance(first_item, dict) and ("label" in first_item):
          parsed_list = []
          for item in raw_output:
              label = item.get("label", "")
              score_mean = item.get("score_mean", None)
              score_std = item.get("score_std", None)

+             # If only 'score' exists, use it as score_mean.
              if score_mean is None and "score" in item:
                  score_mean = float(item["score"])
              if score_mean is None:
                  score_mean = 0.0

              })
          return parsed_list

+     # Case 2: Dictionary with topic names as keys and scores as values.
      if isinstance(first_item, dict):
          merged = {}
          for d in raw_output:
              for k, v in d.items():
+                 merged[k] = v  # Overwrite duplicates with the last occurrence.

          parsed_list = []
          for label, val in merged.items():

              })
          return parsed_list

+     # If the structure is unexpected, return an empty list.
      return []

  @safe_process
  def analyze_text_sections(text):
      """
+     Split the text into multiple sections (e.g., every 3 sentences) and analyze topics
+     for each section using suggest_topics(). Returns a list of topic lists per section.
      """
      sentences = sent_tokenize(text)
+     # Group every 3 sentences into one section.
      sections = [' '.join(sentences[i:i+3]) for i in range(0, len(sentences), 3)]

      section_topics = []

  @safe_process
  def calculate_topic_correlations(topic_dicts):
      """
+     Calculate correlation among topic score_means from a list of topic dictionaries.
+     Note: ideally, correlations should be calculated across different documents,
+     but here we use topics from a single text.
      """
      if len(topic_dicts) < 2:
          return np.array([[1.0]]), ["Insufficient topics"]

  @safe_process
  def perform_sentiment_analysis(text):
      """
+     Perform sentiment analysis on each sentence using NLTK's VADER.
+     Returns a pandas DataFrame of sentiment scores.
      """
      sia = SentimentIntensityAnalyzer()
      sents = sent_tokenize(text)

  @safe_process
  def create_topic_clusters(topic_dicts):
      """
+     Perform KMeans clustering on topics based on score_mean and score_std.
+     If there are fewer than 3 topics, assign all to cluster 0.
      """
      if len(topic_dicts) < 3:
          return [0] * len(topic_dicts)

  @safe_process
  def create_main_charts(topic_dicts):
      """
+     Generate a bar chart and radar chart.
+     'score_mean' is assumed to be in the range 0-1 and is converted to a percentage.
      """
      if not topic_dicts:
          return go.Figure(), go.Figure()

      labels = [t['label'] for t in topic_dicts]
      scores = [t['score_mean'] * 100 for t in topic_dicts]

+     # Bar chart
      bar_fig = go.Figure(
          data=[go.Bar(x=labels, y=scores, marker_color='rgb(55, 83, 109)')]
      )
      bar_fig.update_layout(
+         title='Topic Analysis Results',
+         xaxis_title='Topics',
+         yaxis_title='Relevance (%)',
          template='plotly_white',
          height=500,
      )

+     # Radar chart
      radar_fig = go.Figure()
      radar_fig.add_trace(go.Scatterpolar(
          r=scores,
          theta=labels,
          fill='toself',
+         name='Topic Distribution'
      ))
      radar_fig.update_layout(
+         title='Topic Radar Chart',
          template='plotly_white',
          height=500,
          polar=dict(radialaxis=dict(visible=True)),

  @safe_process
  def create_correlation_heatmap(corr_matrix, labels):
      """
+     Visualize the correlation matrix as a heatmap.
+     If data is insufficient, display a message.
      """
      if corr_matrix.ndim == 0:
          corr_matrix = np.array([[corr_matrix]])

      if corr_matrix.shape == (1, 1):
          fig = go.Figure()
          fig.add_annotation(text="Not enough topics for correlation", showarrow=False)
          return fig

          colorscale='Viridis'
      ))
      fig.update_layout(
+         title='Topic Correlation Heatmap',
          height=500,
          template='plotly_white'
      )

  @safe_process
  def create_topic_evolution(section_topics):
      """
+     Create a line chart showing topic score evolution across different sections.
+     section_topics: list of lists containing topic dictionaries for each section.
      """
      fig = go.Figure()
      if not section_topics or len(section_topics) == 0:

      if not section_topics[0]:
          return fig

+     # Use topics from the first section as reference.
      for topic_dict in section_topics[0]:
          label = topic_dict['label']
          score_list = []

          ))

      fig.update_layout(
+         title='Section-wise Topic Evolution',
+         xaxis_title='Section',
+         yaxis_title='Score Mean',
          height=500,
          template='plotly_white'
      )

  @safe_process
  def create_confidence_gauge(topic_dicts):
      """
+     Display each topic's confidence as a gauge.
+     Confidence is calculated using a simple formula: (1 - score_std) * 100.
      """
      if not topic_dicts:
          return go.Figure()

  @spaces.GPU()
  def process_all_analysis(text):
      """
+     Perform comprehensive analysis on the input text, including topic analysis,
+     section analysis, correlation, sentiment analysis, and clustering, and generate
+     the corresponding JSON results and Plotly charts.
      """
      try:
+         # 1) Analyze topics for the entire text.
          raw_results = clf.suggest_topics(text)
          all_topics = parse_wbg_results(raw_results)

+         # 2) Sort topics by score_mean in descending order and take the top 5.
          sorted_topics = sorted(all_topics, key=lambda x: x['score_mean'], reverse=True)
          top_topics = sorted_topics[:5]

+         # 3) Analyze topics by sections.
          section_topics = analyze_text_sections(text)

+         # 4) Additional analyses (correlation, sentiment, clustering).
          corr_matrix, corr_labels = calculate_topic_correlations(all_topics)
          sentiments_df = perform_sentiment_analysis(text)
          clusters = create_topic_clusters(all_topics)

+         # 5) Generate charts.
          bar_chart, radar_chart = create_main_charts(top_topics)
          heatmap = create_correlation_heatmap(corr_matrix, corr_labels)
          evolution_chart = create_topic_evolution(section_topics)
          gauge_chart = create_confidence_gauge(top_topics)

+         # 6) Return results as JSON and charts.
          results = {
+             "top_topics": top_topics,  # Top 5 topics.
+             "clusters": clusters,  # Cluster results.
+             "sentiments": sentiments_df.to_dict(orient="records")  # Sentiment analysis results.
          }

          return (
+             results,          # JSON output.
+             bar_chart,        # Plot 1: Topic Distribution (Bar Chart).
+             radar_chart,      # Plot 2: Radar Chart.
+             heatmap,          # Plot 3: Correlation Heatmap.
+             evolution_chart,  # Plot 4: Section Topic Evolution.
+             gauge_chart,      # Plot 5: Confidence Gauge.
+             go.Figure()       # Plot 6: placeholder for a sentiment analysis chart.
          )

      except Exception as e:

  # Gradio UI Definition #
  ######################################################

+ with gr.Blocks(title="Advanced Document Topic Analyzer") as demo:
+     gr.Markdown("## Advanced Document Topic Analyzer")
      gr.Markdown(
+         "Enter the text below and click **Start Analysis**. "
+         "The tool will analyze key topics, correlations, confidence gauges, sentiment analysis, and more."
      )

      with gr.Row():
          text_input = gr.Textbox(
              value=SAMPLE_TEXT,
+             label="Enter Text for Analysis",
              lines=8
          )
      with gr.Row():
+         submit_btn = gr.Button("Start Analysis", variant="primary")

      with gr.Tabs():
+         with gr.TabItem("Main Analysis"):
              with gr.Row():
+                 plot1 = gr.Plot(label="Topic Distribution (Bar Chart)")
+                 plot2 = gr.Plot(label="Radar Chart")
+         with gr.TabItem("Detailed Analysis"):
              with gr.Row():
+                 plot3 = gr.Plot(label="Correlation Heatmap")
+                 plot4 = gr.Plot(label="Section Topic Evolution")
+         with gr.TabItem("Confidence Analysis"):
+             plot5 = gr.Plot(label="Confidence Gauge")
+         with gr.TabItem("Sentiment Analysis"):
+             plot6 = gr.Plot(label="Sentiment Analysis Result")

      with gr.Row():
+         output_json = gr.JSON(label="Detailed Analysis Result (JSON)")

      submit_btn.click(
          fn=process_all_analysis,

      demo.launch(
          server_name="0.0.0.0",
          server_port=7860,
+         share=False,  # Set to True if you need a public shareable link.
          debug=True
      )
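
Note: the main behavioral change in this commit is that parse_wbg_results() now also accepts a plain dict from suggest_topics(). The following is a minimal, self-contained sketch of that normalization contract; the helper name normalize and the sample inputs are hypothetical, and only the parsing logic visible in the diff is exercised.

# normalize() is a stand-in for parse_wbg_results(); the sample data is made up.
def normalize(raw_output):
    """Normalize topic output into [{'label', 'score_mean', 'score_std'}, ...]."""
    # Shape 1: a single dict of {topic_name: score}.
    if isinstance(raw_output, dict):
        return [{"label": k,
                 "score_mean": float(v) if v is not None else 0.0,
                 "score_std": 0.0}
                for k, v in raw_output.items()]
    if not raw_output:
        return []
    first = raw_output[0]
    # Shape 2: a list of dicts that already carry a 'label' key.
    if isinstance(first, dict) and "label" in first:
        return [{"label": item.get("label", ""),
                 "score_mean": float(item.get("score_mean", item.get("score", 0.0)) or 0.0),
                 "score_std": float(item.get("score_std", 0.0) or 0.0)}
                for item in raw_output]
    # Shape 3: a list of {topic_name: score} dicts, merged with last-value-wins semantics.
    if isinstance(first, dict):
        merged = {}
        for d in raw_output:
            merged.update(d)
        return [{"label": k, "score_mean": float(v), "score_std": 0.0}
                for k, v in merged.items()]
    return []

assert normalize({"Agriculture": 0.22}) == [
    {"label": "Agriculture", "score_mean": 0.22, "score_std": 0.0}]
assert normalize([{"label": "Energy", "score": 0.5}])[0]["score_mean"] == 0.5
assert normalize([{"A": 0.1}, {"A": 0.3}])[0]["score_mean"] == 0.3  # last value wins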