import json
import gradio as gr
import spaces
import wbgtopic
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
import pandas as pd
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.sentiment import SentimentIntensityAnalyzer
from sklearn.cluster import KMeans
import torch
from functools import wraps
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize WBGDocTopic
clf = wbgtopic.WBGDocTopic(device=device)
try:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)  # newer NLTK releases tokenize with punkt_tab
    nltk.download('vader_lexicon', quiet=True)
except Exception:
    pass
SAMPLE_TEXT = """Your sample text here ..."""
def safe_process(func):
    """Decorator: log exceptions and return None instead of raising."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {str(e)}")
            return None  # callers must tolerate a None result
    return wrapper
################################################################
# 1) Convert Raw Results into a Consistent Format #
################################################################
@safe_process
def parse_wbg_results(raw_output):
"""
Example: raw_output might be something like:
[
{ 'Innovation and Entrepreneurship': 0.32,
'Digital Development': 0.27,
...}
]
or it might be [ [ {...}, {...} ] ]
    Either way, normalize to a list of dicts:
[
{'label': 'Innovation and Entrepreneurship', 'score_mean': 0.32, 'score_std': 0.0},
{'label': 'Digital Development', 'score_mean': 0.27, 'score_std': 0.0},
...
]
"""
if not raw_output:
return []
# If the library returns a list with a single dictionary:
# raw_output[0] might be a dict of {topic: score}
# or it might be a list of dicts with 'label'/'score_mean' keys
    first_item = raw_output[0]

    # Case 1: already a list of dicts with 'label', 'score_mean', etc.
    if isinstance(first_item, dict) and 'label' in first_item:
        return raw_output

    # Case 2: a nested list, e.g. [ [ {...}, {...} ] ] -- unwrap and retry
    if isinstance(first_item, list):
        return parse_wbg_results(first_item)

    # Case 3: a dict of {topic_label: numeric_score} -- convert it
    if isinstance(first_item, dict):
        parsed_list = []
        for label, val in first_item.items():
            parsed_list.append({
                'label': label,
                'score_mean': float(val),
                'score_std': 0.0  # no std provided; default to 0
            })
        return parsed_list

    # Unrecognized shape: return an empty list
    return []
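# Illustrative example (assuming the dict-of-scores shape described above):
#   parse_wbg_results([{'Innovation and Entrepreneurship': 0.32, 'Digital Development': 0.27}])
#   -> [{'label': 'Innovation and Entrepreneurship', 'score_mean': 0.32, 'score_std': 0.0},
#       {'label': 'Digital Development', 'score_mean': 0.27, 'score_std': 0.0}]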
################################################################
# 2) Section-based Analysis #
################################################################
@safe_process
def analyze_text_sections(text):
"""
Splits text into sections, calls clf.suggest_topics on each,
and returns a list-of-lists:
section_topics = [
[ {'label':'...', 'score_mean':...}, {...} ],
[ {'label':'...', 'score_mean':...}, {...} ],
...
]
"""
sentences = sent_tokenize(text)
# e.g. group 3 sentences per section
sections = [' '.join(sentences[i:i+3]) for i in range(0, len(sentences), 3)]
section_topics = []
for section in sections:
raw_sec = clf.suggest_topics(section)
parsed_sec = parse_wbg_results(raw_sec)
section_topics.append(parsed_sec)
return section_topics
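# Example: a 7-sentence text produces sections from sentences [0:3], [3:6],
# and [6:7], so section_topics will contain three per-section topic lists.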
################################################################
# 3) Basic Summaries (Correlation, Sentiment, Clusters etc.) #
################################################################
@safe_process
def calculate_topic_correlations(topic_dicts):
"""
If we only want a single dimension correlation (like score_mean),
we can do a simple correlation across different topics.
But typically you'd want multiple texts or some multi-dimensional approach.
"""
if len(topic_dicts) < 2:
# Not enough to do correlation
return np.array([[1.0]]), ["Insufficient topics"]
labels = [d['label'] for d in topic_dicts]
scores = [d['score_mean'] for d in topic_dicts] # single dimension
if len(scores) < 2:
return np.array([[1.0]]), ["Insufficient topics"]
corr_matrix = np.corrcoef(scores)
return corr_matrix, labels
@safe_process
def perform_sentiment_analysis(text):
sia = SentimentIntensityAnalyzer()
sents = sent_tokenize(text)
results = [sia.polarity_scores(s) for s in sents]
return pd.DataFrame(results)
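# Each DataFrame row holds VADER's per-sentence scores: 'neg', 'neu', 'pos'
# (proportions summing to ~1) and 'compound' (normalized score in [-1, 1]).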
@safe_process
def create_topic_clusters(topic_dicts):
if len(topic_dicts) < 3:
return [0]*len(topic_dicts) # trivial cluster
    # Cluster topics in a 2-D feature space: (score_mean, score_std)
    X = []
    for t in topic_dicts:
        X.append([t['score_mean'], t.get('score_std', 0.0)])
    X = np.array(X)
    if X.shape[0] < 3:
        return [0] * X.shape[0]

    # n_init pinned explicitly for consistent behavior across sklearn versions
    kmeans = KMeans(n_clusters=min(3, X.shape[0]), random_state=42, n_init=10)
clusters = kmeans.fit_predict(X)
return clusters.tolist() # safe to JSON-encode
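# The returned cluster ids align index-for-index with topic_dicts, so topics
# with similar (score_mean, score_std) profiles share a cluster id.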
################################################################
# 4) Charts (Bar, Radar, Correlation Heatmap, etc.) #
################################################################
@safe_process
def create_main_charts(topic_dicts):
"""
Expects a list of dicts with keys: 'label', 'score_mean', ...
We'll just use 'score_mean' (or a scaled version).
"""
if not topic_dicts:
return go.Figure(), go.Figure()
# Bar chart
labels = [t['label'] for t in topic_dicts]
scores = [t['score_mean']*100 for t in topic_dicts] # convert to %
bar_fig = go.Figure(
data=[go.Bar(x=labels, y=scores, marker_color='rgb(55, 83, 109)')]
)
bar_fig.update_layout(
        title='Topic Analysis Results',
        xaxis_title='Topic',
        yaxis_title='Relevance (%)',
template='plotly_white',
height=500,
)
# Radar chart
radar_fig = go.Figure()
radar_fig.add_trace(go.Scatterpolar(
r=scores,
theta=labels,
fill='toself',
        name='Topic Distribution'
))
radar_fig.update_layout(
        title='Topic Radar Chart',
template='plotly_white',
height=500,
polar=dict(radialaxis=dict(visible=True)),
showlegend=False
)
return bar_fig, radar_fig
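# Note: bar and radar values are score_mean * 100, so a topic scored 0.32
# displays as 32% relevance on both charts.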
@safe_process
def create_correlation_heatmap(corr_matrix, labels):
if corr_matrix.ndim == 0:
# It's a scalar => shape ()
corr_matrix = np.array([[corr_matrix]])
if corr_matrix.shape == (1,1):
# Usually means not enough data
fig = go.Figure()
fig.add_annotation(text="Not enough topics for correlation", showarrow=False)
return fig
fig = go.Figure(data=go.Heatmap(
z=corr_matrix,
x=labels,
y=labels,
colorscale='Viridis'
))
fig.update_layout(
        title='Topic Correlations',
height=500,
template='plotly_white'
)
return fig
@safe_process
def create_topic_evolution(section_topics):
"""
section_topics: list of [ {label:..., score_mean:...}, ...]
one element per section
"""
fig = go.Figure()
    if not section_topics:
        return fig

    # Use the first section's topic list as the reference set;
    # topics that first appear only in later sections are not tracked.
    if not section_topics[0]:
        return fig

    for topic_dict in section_topics[0]:
label = topic_dict['label']
score_list = []
for sec_list in section_topics:
# find matching label
match = next((d for d in sec_list if d['label'] == label), None)
if match:
score_list.append(match['score_mean'])
else:
score_list.append(0.0)
fig.add_trace(go.Scatter(
x=list(range(len(section_topics))),
y=score_list,
name=label,
mode='lines+markers'
))
fig.update_layout(
        title='Topic Evolution',
        xaxis_title='Section',
        yaxis_title='score_mean',
height=500,
template='plotly_white'
)
return fig
@safe_process
def create_confidence_gauge(topic_dicts):
"""
If your data doesnโ€™t actually have a separate confidence measure,
you may skip or adapt. For example, you might define confidence
= (1 - score_std)*100
"""
if not topic_dicts:
return go.Figure()
fig = go.Figure()
num_topics = len(topic_dicts)
for i, t in enumerate(topic_dicts):
        confidence_val = 100.0 * (1.0 - t.get('score_std', 0.0))  # heuristic: lower std => higher confidence
fig.add_trace(go.Indicator(
mode="gauge+number",
value=confidence_val,
title={'text': t['label']},
domain={'row': 0, 'column': i}
))
fig.update_layout(
grid={'rows': 1, 'columns': num_topics},
height=400,
template='plotly_white'
)
return fig
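# The gauges share a single grid row and narrow as the topic count grows,
# which is why process_all_analysis passes only the top 5 topics here.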
################################################################
# 5) Putting Everything into `process_all_analysis` #
################################################################
@spaces.GPU()
def process_all_analysis(text):
try:
# 1) Suggest topics on the entire text
        raw_results = clf.suggest_topics(text)
        all_topics = parse_wbg_results(raw_results) or []  # guard against None from safe_process
# 2) Top 5 (if you want to highlight them)
# Sort by score_mean descending
sorted_topics = sorted(all_topics, key=lambda x: x['score_mean'], reverse=True)
top_topics = sorted_topics[:5]
# 3) Section-based
section_topics = analyze_text_sections(text) # list of lists
# 4) Extra analyses
corr_matrix, corr_labels = calculate_topic_correlations(all_topics)
sentiments_df = perform_sentiment_analysis(text)
clusters = create_topic_clusters(all_topics)
# 5) Build charts
bar_chart, radar_chart = create_main_charts(top_topics) # show top 5 on bar
heatmap = create_correlation_heatmap(corr_matrix, corr_labels)
evolution_chart = create_topic_evolution(section_topics)
gauge_chart = create_confidence_gauge(top_topics)
# 6) Prepare output for the JSON field
# Make sure everything is JSON-serializable with string keys
        results = {
            "top_topics": top_topics,    # list of dicts
            "clusters": clusters or [],  # list of ints (guard against None)
            "sentiments": sentiments_df.to_dict(orient="records") if sentiments_df is not None else [],
        }
return (
results, # JSON output
bar_chart, # plot1
radar_chart, # plot2
heatmap, # plot3
evolution_chart,# plot4
gauge_chart, # plot5
go.Figure() # plot6 (placeholder for sentiment plot, or skip)
)
except Exception as e:
print(f"Analysis error: {str(e)}")
empty_fig = go.Figure()
return (
{"error": str(e), "topics": []},
empty_fig,
empty_fig,
empty_fig,
empty_fig,
empty_fig,
empty_fig
)
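# The 7-tuple returned above (JSON + six figures) must stay in the same order
# as the `outputs=` list wired to submit_btn.click below.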
################################################################
# 6) Gradio UI #
################################################################
with gr.Blocks(title="Advanced Document Topic Analyzer") as demo:
    gr.Markdown("## 📊 Advanced Document Topic Analyzer")
with gr.Row():
text_input = gr.Textbox(
value=SAMPLE_TEXT,
label="๋ถ„์„ํ•  ํ…์ŠคํŠธ",
lines=8
)
with gr.Row():
        submit_btn = gr.Button("Start Analysis", variant="primary")
    with gr.Tabs():
        with gr.TabItem("Main Analysis"):
            with gr.Row():
                plot1 = gr.Plot(label="Topic Distribution")
                plot2 = gr.Plot(label="Radar Chart")
        with gr.TabItem("Detailed Analysis"):
            with gr.Row():
                plot3 = gr.Plot(label="Correlation Heatmap")
                plot4 = gr.Plot(label="Topic Evolution")
        with gr.TabItem("Confidence Analysis"):
            plot5 = gr.Plot(label="Confidence Gauge")
        with gr.TabItem("Sentiment Analysis"):
            plot6 = gr.Plot(label="Sentiment Analysis Results")

    with gr.Row():
        output_json = gr.JSON(label="Detailed Analysis Results")
submit_btn.click(
fn=process_all_analysis,
inputs=[text_input],
outputs=[output_json, plot1, plot2, plot3, plot4, plot5, plot6]
)
if __name__ == "__main__":
demo.queue(max_size=1)
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
debug=True
)