Spaces:
Running
on
Zero
Running
on
Zero
import json | |
import gradio as gr | |
import spaces | |
import wbgtopic | |
import plotly.graph_objects as go | |
import plotly.express as px | |
import numpy as np | |
import pandas as pd | |
import nltk | |
from nltk.tokenize import sent_tokenize, word_tokenize | |
from nltk.sentiment import SentimentIntensityAnalyzer | |
from sklearn.cluster import KMeans | |
import torch | |
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
# NOTE(review): `spaces` is imported by this file but nothing visible here
# applies `spaces.GPU` -- confirm whether GPU work needs that decorator on
# the target hardware.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize WBGDocTopic
clf = wbgtopic.WBGDocTopic(device=device)

# Best-effort fetch of the NLTK data (sentence tokenizer and VADER lexicon).
# A failure must not stop the app from starting, but it is reported instead
# of being swallowed silently; `Exception` (rather than a bare `except`)
# lets KeyboardInterrupt/SystemExit propagate.
try:
    nltk.download('punkt', quiet=True)
    nltk.download('vader_lexicon', quiet=True)
except Exception as e:
    print(f"NLTK resource download failed: {e}")

# Default content of the input textbox.
SAMPLE_TEXT = """Your sample text here ..."""
def safe_process(func):
    """Decorate *func* so that exceptions are printed and turned into ``None``.

    Args:
        func: The callable to protect.

    Returns:
        A wrapper that forwards all arguments to *func*. It returns whatever
        *func* returns, or ``None`` when *func* raises an ``Exception``
        (the error is printed together with the function name).
    """
    # Imported locally so this block does not depend on the file-level imports.
    from functools import wraps

    @wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {str(e)}")
            return None
    return wrapper
################################################################ | |
# 1) Convert Raw Results into a Consistent Format # | |
################################################################ | |
def parse_wbg_results(raw_output):
    """Normalize classifier output into a flat list of topic dicts.

    Accepted shapes of *raw_output*:

    * ``[{'label': ..., 'score_mean': ..., ...}, ...]`` -- already in the
      target format; returned unchanged.
    * ``[{topic_label: numeric_score, ...}]`` -- a single mapping; every item
      becomes ``{'label', 'score_mean', 'score_std'}`` with ``score_std``
      defaulting to 0.0.
    * ``[[...], ...]`` -- a nested list (presumably one entry per input
      document -- confirm against the wbgtopic version in use); the first
      inner list is unwrapped and parsed by the rules above.

    Args:
        raw_output: Whatever ``clf.suggest_topics`` returned.

    Returns:
        A list of dicts with at least the keys ``'label'`` and
        ``'score_mean'``; an empty list for empty or unrecognized input.
    """
    if not raw_output:
        return []

    first_item = raw_output[0]

    # Nested output: unwrap the first inner list instead of discarding it.
    if isinstance(first_item, (list, tuple)):
        return parse_wbg_results(first_item)

    if isinstance(first_item, dict):
        # Already a list of dicts with 'label', 'score_mean', etc.
        if 'label' in first_item:
            return raw_output
        # A dict of {topic_label: numeric_score}: convert each item.
        return [
            {
                'label': label,
                'score_mean': float(val),
                'score_std': 0.0,  # no std is given in this shape
            }
            for label, val in first_item.items()
        ]

    # If it's something else, there is nothing we know how to parse.
    return []
################################################################ | |
# 2) Section-based Analysis # | |
################################################################ | |
def analyze_text_sections(text, sentences_per_section=3):
    """Split *text* into sections and suggest topics for each one.

    The text is sentence-tokenized with ``sent_tokenize`` and consecutive
    sentences are joined (with a space) into sections; the last section may
    be shorter. Each section is passed to ``clf.suggest_topics`` and the
    result is normalized with ``parse_wbg_results``.

    Args:
        text: The document to analyze.
        sentences_per_section: Number of sentences grouped into one section.
            Defaults to 3.

    Returns:
        A list with one entry per section, each entry being the list of
        topic dicts for that section::

            [
                [{'label': '...', 'score_mean': ...}, {...}],
                [{'label': '...', 'score_mean': ...}, {...}],
                ...
            ]

    Raises:
        ValueError: If *sentences_per_section* is smaller than 1.
    """
    if sentences_per_section < 1:
        raise ValueError("sentences_per_section must be at least 1")

    sentences = sent_tokenize(text)
    sections = [
        ' '.join(sentences[start:start + sentences_per_section])
        for start in range(0, len(sentences), sentences_per_section)
    ]
    return [
        parse_wbg_results(clf.suggest_topics(section))
        for section in sections
    ]
################################################################ | |
# 3) Basic Summaries (Correlation, Sentiment, Clusters etc.) # | |
################################################################ | |
def calculate_topic_correlations(topic_dicts):
    """Compute a correlation result over the topics' ``score_mean`` values.

    Each topic contributes a single number, so ``np.corrcoef`` sees one
    variable and yields a single coefficient rather than a topic-by-topic
    matrix. That value is wrapped so callers always receive a 2-D array;
    a meaningful matrix would need several observations per topic (for
    example scores from multiple texts).

    Args:
        topic_dicts: List of dicts with the keys ``'label'`` and
            ``'score_mean'``.

    Returns:
        A ``(matrix, labels)`` tuple. With fewer than two topics the matrix
        is ``[[1.0]]`` and labels is ``["Insufficient topics"]``; otherwise
        the matrix is the 2-D wrapped ``np.corrcoef`` result and labels
        holds the topic labels in input order.
    """
    if len(topic_dicts) < 2:
        # Not enough to do correlation
        return np.array([[1.0]]), ["Insufficient topics"]

    labels = [d['label'] for d in topic_dicts]
    scores = [d['score_mean'] for d in topic_dicts]  # single dimension

    # corrcoef returns a 0-d value for 1-D input; normalize it to 2-D.
    corr_matrix = np.atleast_2d(np.corrcoef(scores))
    return corr_matrix, labels
def perform_sentiment_analysis(text):
    """Score every sentence of *text* and tabulate the scores.

    Args:
        text: The document to analyze; it is split with ``sent_tokenize``.

    Returns:
        A ``pandas.DataFrame`` with one row per sentence, built from the
        dicts returned by ``SentimentIntensityAnalyzer.polarity_scores``.
    """
    analyzer = SentimentIntensityAnalyzer()
    rows = []
    for sentence in sent_tokenize(text):
        rows.append(analyzer.polarity_scores(sentence))
    return pd.DataFrame(rows)
def create_topic_clusters(topic_dicts, n_clusters=3):
    """Assign each topic to a K-means cluster.

    Topics are clustered on the two features ``score_mean`` and
    ``score_std`` (the latter defaulting to 0.0 when absent).

    Args:
        topic_dicts: List of dicts with the key ``'score_mean'`` and
            optionally ``'score_std'``.
        n_clusters: Number of clusters to form. Defaults to 3.

    Returns:
        A plain list of int cluster ids, one per topic and in input order
        (safe to JSON-encode). When there are fewer topics than clusters,
        every topic is placed in cluster 0.

    Raises:
        ValueError: If *n_clusters* is smaller than 1.
    """
    if n_clusters < 1:
        raise ValueError("n_clusters must be at least 1")
    if len(topic_dicts) < n_clusters:
        return [0] * len(topic_dicts)  # trivial cluster

    features = np.array(
        [[t['score_mean'], t.get('score_std', 0.0)] for t in topic_dicts]
    )
    # n_init is given explicitly because its default differs between
    # scikit-learn releases; random_state keeps the result reproducible.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    return kmeans.fit_predict(features).tolist()
################################################################ | |
# 4) Charts (Bar, Radar, Correlation Heatmap, etc.) # | |
################################################################ | |
def create_main_charts(topic_dicts):
    """Build the bar chart and the radar chart for a list of topics.

    The user-facing titles are Korean; they were restored here from
    mis-encoded (garbled) text.

    Args:
        topic_dicts: List of dicts with the keys ``'label'`` and
            ``'score_mean'``. Only ``score_mean`` is plotted, scaled by 100.

    Returns:
        A ``(bar_fig, radar_fig)`` tuple of plotly figures; two empty
        figures when *topic_dicts* is empty.
    """
    if not topic_dicts:
        return go.Figure(), go.Figure()

    labels = [t['label'] for t in topic_dicts]
    scores = [t['score_mean'] * 100 for t in topic_dicts]  # convert to %

    # Bar chart
    bar_fig = go.Figure(
        data=[go.Bar(x=labels, y=scores, marker_color='rgb(55, 83, 109)')]
    )
    bar_fig.update_layout(
        title='주제 분석 결과',      # "Topic analysis results"
        xaxis_title='주제',          # "Topic"
        yaxis_title='관련도 (%)',    # "Relevance (%)"
        template='plotly_white',
        height=500,
    )

    # Radar chart
    radar_fig = go.Figure()
    radar_fig.add_trace(go.Scatterpolar(
        r=scores,
        theta=labels,
        fill='toself',
        name='주제 분포',            # "Topic distribution"
    ))
    radar_fig.update_layout(
        title='주제 레이더 차트',    # "Topic radar chart"
        template='plotly_white',
        height=500,
        polar=dict(radialaxis=dict(visible=True)),
        showlegend=False,
    )
    return bar_fig, radar_fig
def create_correlation_heatmap(corr_matrix, labels):
    """Render a correlation matrix as a heatmap.

    The user-facing title is Korean; it was restored here from mis-encoded
    (garbled) text.

    Args:
        corr_matrix: The correlation values. Anything ``np.asarray`` accepts
            is allowed, including a bare scalar.
        labels: Axis labels, used for both x and y.

    Returns:
        A plotly figure. When the matrix is a scalar or 1x1, the figure only
        carries a "Not enough topics for correlation" annotation.
    """
    # Coerce first so plain floats/lists work as well as ndarrays.
    corr_matrix = np.asarray(corr_matrix)
    if corr_matrix.ndim == 0:
        # It's a scalar => shape ()
        corr_matrix = corr_matrix.reshape(1, 1)

    if corr_matrix.shape == (1, 1):
        # Usually means not enough data
        fig = go.Figure()
        fig.add_annotation(text="Not enough topics for correlation", showarrow=False)
        return fig

    fig = go.Figure(data=go.Heatmap(
        z=corr_matrix,
        x=labels,
        y=labels,
        colorscale='Viridis',
    ))
    fig.update_layout(
        title='주제 간 상관관계',    # "Correlation between topics"
        height=500,
        template='plotly_white',
    )
    return fig
def create_topic_evolution(section_topics):
    """Plot how each topic's score changes from section to section.

    The topics of the first section define which lines are drawn; a topic
    missing from a later section is plotted as 0.0 there. The user-facing
    titles are Korean; they were restored here from mis-encoded (garbled)
    text.

    Args:
        section_topics: One element per section, each a list of dicts with
            the keys ``'label'`` and ``'score_mean'``.

    Returns:
        A plotly figure with one line per topic of the first section; an
        empty figure when there are no sections or the first one is empty.
    """
    fig = go.Figure()
    if not section_topics or not section_topics[0]:
        return fig

    # Index each section by label once, instead of scanning it for every
    # topic. setdefault keeps the first dict per label, like a first-match
    # scan would.
    sections_by_label = []
    for sec_list in section_topics:
        by_label = {}
        for d in sec_list:
            by_label.setdefault(d['label'], d)
        sections_by_label.append(by_label)

    section_index = list(range(len(section_topics)))

    # For each topic in the first section, gather its evolution
    for topic_dict in section_topics[0]:
        label = topic_dict['label']
        score_list = []
        for by_label in sections_by_label:
            match = by_label.get(label)
            score_list.append(match['score_mean'] if match else 0.0)
        fig.add_trace(go.Scatter(
            x=section_index,
            y=score_list,
            name=label,
            mode='lines+markers',
        ))

    fig.update_layout(
        title='주제 변화 추이',      # "Topic trend"
        xaxis_title='섹션',          # "Section"
        yaxis_title='score_mean',
        height=500,
        template='plotly_white',
    )
    return fig
def create_confidence_gauge(topic_dicts):
    """Draw one gauge per topic, laid out side by side in a single row.

    The data has no separate confidence measure, so the gauge value is
    derived as ``100 * (1 - score_std)``, with ``score_std`` defaulting to
    0.0 when a topic does not provide it. Adapt or skip this chart if a
    different definition is wanted.

    Args:
        topic_dicts: List of dicts with the key ``'label'`` and optionally
            ``'score_std'``.

    Returns:
        A plotly figure holding one indicator trace per topic; an empty
        figure when *topic_dicts* is empty.
    """
    fig = go.Figure()
    if not topic_dicts:
        return fig

    for column, topic in enumerate(topic_dicts):
        spread = topic.get('score_std', 0.0)
        gauge = go.Indicator(
            mode="gauge+number",
            value=100.0 * (1.0 - spread),  # an example confidence measure
            title={'text': topic['label']},
            domain={'row': 0, 'column': column},
        )
        fig.add_trace(gauge)

    # One grid column per gauge so the indicators do not overlap.
    layout_grid = {'rows': 1, 'columns': len(topic_dicts)}
    fig.update_layout(grid=layout_grid, height=400, template='plotly_white')
    return fig
################################################################ | |
# 5) Putting Everything into `process_all_analysis` # | |
################################################################ | |
def _analysis_error_outputs(message):
    """Build the 7-item output tuple shown when the analysis cannot run."""
    empty_fig = go.Figure()
    return ({"error": message, "topics": []},) + (empty_fig,) * 6


def _create_sentiment_chart(sentiments_df):
    """Plot every score column of *sentiments_df* as a line over sentences."""
    fig = go.Figure()
    if sentiments_df.empty:
        return fig

    sentence_index = list(range(len(sentiments_df)))
    for column in sentiments_df.columns:
        fig.add_trace(go.Scatter(
            x=sentence_index,
            y=sentiments_df[column].tolist(),
            name=str(column),
            mode='lines+markers',
        ))
    fig.update_layout(
        title='감성 분석 결과',      # "Sentiment analysis results"
        xaxis_title='문장',          # "Sentence"
        yaxis_title='score',
        height=500,
        template='plotly_white',
    )
    return fig


def process_all_analysis(text):
    """Run the full analysis pipeline on *text* and build all UI outputs.

    Steps: suggest topics for the whole text, keep the five highest by
    ``score_mean``, analyze the text section by section, compute the
    correlation, sentiment and cluster summaries, then build the charts.

    Args:
        text: The document entered by the user.

    Returns:
        A 7-tuple ``(results, bar_chart, radar_chart, heatmap,
        evolution_chart, gauge_chart, sentiment_chart)``. ``results`` is a
        JSON-serializable dict with the keys ``"top_topics"``,
        ``"clusters"`` and ``"sentiments"``. When the text is empty or any
        step raises, ``results`` is ``{"error": ..., "topics": []}`` and all
        six figures are empty.
    """
    # Nothing to analyze: skip the classifier call entirely.
    if not text or not text.strip():
        return _analysis_error_outputs("No text provided for analysis")

    try:
        # 1) Suggest topics on the entire text
        raw_results = clf.suggest_topics(text)
        all_topics = parse_wbg_results(raw_results)  # keep full list of dicts

        # 2) Top 5, sorted by score_mean descending
        sorted_topics = sorted(all_topics, key=lambda x: x['score_mean'], reverse=True)
        top_topics = sorted_topics[:5]

        # 3) Section-based
        section_topics = analyze_text_sections(text)  # list of lists

        # 4) Extra analyses
        corr_matrix, corr_labels = calculate_topic_correlations(all_topics)
        sentiments_df = perform_sentiment_analysis(text)
        clusters = create_topic_clusters(all_topics)

        # 5) Build charts
        bar_chart, radar_chart = create_main_charts(top_topics)  # show top 5 on bar
        heatmap = create_correlation_heatmap(corr_matrix, corr_labels)
        evolution_chart = create_topic_evolution(section_topics)
        gauge_chart = create_confidence_gauge(top_topics)
        sentiment_chart = _create_sentiment_chart(sentiments_df)

        # 6) Prepare output for the JSON field
        # Make sure everything is JSON-serializable with string keys
        results = {
            "top_topics": top_topics,   # list of dict
            "clusters": clusters,       # list of ints
            "sentiments": sentiments_df.to_dict(orient="records"),
        }

        return (
            results,          # JSON output
            bar_chart,        # plot1
            radar_chart,      # plot2
            heatmap,          # plot3
            evolution_chart,  # plot4
            gauge_chart,      # plot5
            sentiment_chart,  # plot6
        )
    except Exception as e:
        # Top-level UI boundary: report the failure and show empty outputs
        # instead of letting the callback crash.
        print(f"Analysis error: {str(e)}")
        return _analysis_error_outputs(str(e))
################################################################ | |
# 6) Gradio UI # | |
################################################################ | |
# Gradio layout: text input, a submit button, four tabs holding the six
# plots, and a JSON panel with the raw results. The Korean UI strings were
# restored from mis-encoded (garbled) text; English glosses are given in the
# comments. Nesting was reconstructed from the call structure -- confirm
# against the running app.
with gr.Blocks(title="고급 문서 주제 분석기") as demo:  # "Advanced document topic analyzer"
    # NOTE(review): the heading emoji could not be recovered from the garbled
    # text; the chart emoji below is a stand-in -- confirm the intended one.
    gr.Markdown("## 📊 고급 문서 주제 분석기")

    with gr.Row():
        text_input = gr.Textbox(
            value=SAMPLE_TEXT,
            label="분석할 텍스트",  # "Text to analyze"
            lines=8,
        )

    with gr.Row():
        submit_btn = gr.Button("분석 시작", variant="primary")  # "Start analysis"

    with gr.Tabs():
        with gr.TabItem("주제 분석"):  # "Topic analysis"
            with gr.Row():
                plot1 = gr.Plot(label="주제 분포")      # "Topic distribution"
                plot2 = gr.Plot(label="레이더 차트")    # "Radar chart"
        with gr.TabItem("상세 분석"):  # "Detailed analysis"
            with gr.Row():
                plot3 = gr.Plot(label="상관관계 히트맵")  # "Correlation heatmap"
                plot4 = gr.Plot(label="주제 변화 추이")   # "Topic trend"
        with gr.TabItem("신뢰도 분석"):  # "Confidence analysis"
            plot5 = gr.Plot(label="신뢰도 게이지")      # "Confidence gauge"
        with gr.TabItem("감성 분석"):  # "Sentiment analysis"
            plot6 = gr.Plot(label="감성 분석 결과")     # "Sentiment analysis results"

    with gr.Row():
        output_json = gr.JSON(label="상세 분석 결과")  # "Detailed analysis results"

    # The outputs order must match the tuple returned by process_all_analysis.
    submit_btn.click(
        fn=process_all_analysis,
        inputs=[text_input],
        outputs=[output_json, plot1, plot2, plot3, plot4, plot5, plot6],
    )
# Script entry point: enable the request queue (max_size=1), then start the
# server with the options below.
if __name__ == "__main__":
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
    )
    demo.queue(max_size=1)
    demo.launch(**launch_options)