import json import gradio as gr import spaces import wbgtopic import plotly.graph_objects as go import plotly.express as px import numpy as np import pandas as pd import nltk from nltk.tokenize import sent_tokenize, word_tokenize from nltk.sentiment import SentimentIntensityAnalyzer from sklearn.cluster import KMeans import torch device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Initialize WBGDocTopic clf = wbgtopic.WBGDocTopic(device=device) try: nltk.download('punkt', quiet=True) nltk.download('vader_lexicon', quiet=True) except: pass SAMPLE_TEXT = """Your sample text here ...""" def safe_process(func): def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: print(f"Error in {func.__name__}: {str(e)}") return None return wrapper ################################################################ # 1) Convert Raw Results into a Consistent Format # ################################################################ @safe_process def parse_wbg_results(raw_output): """ Example: raw_output might be something like: [ { 'Innovation and Entrepreneurship': 0.32, 'Digital Development': 0.27, ...} ] or it might be [ [ {...}, {...} ] ] Adjust logic so we end up with a list of dicts: [ {'label': 'Innovation and Entrepreneurship', 'score_mean': 0.32, 'score_std': 0.0}, {'label': 'Digital Development', 'score_mean': 0.27, 'score_std': 0.0}, ... ] """ if not raw_output: return [] # If the library returns a list with a single dictionary: # raw_output[0] might be a dict of {topic: score} # or it might be a list of dicts with 'label'/'score_mean' keys first_item = raw_output[0] # If it's already a list of dicts with 'label', 'score_mean', etc. if isinstance(first_item, dict) and 'label' in first_item: # Possibly we already have the correct format return raw_output # If it's a dict of {topic_label: numeric_score} if isinstance(first_item, dict): # Then let's convert it parsed_list = [] for label, val in first_item.items(): parsed_list.append({ 'label': label, 'score_mean': float(val), 'score_std': 0.0 # If no std is given, default 0 }) return parsed_list # If it’s something else, handle it return [] ################################################################ # 2) Section-based Analysis # ################################################################ @safe_process def analyze_text_sections(text): """ Splits text into sections, calls clf.suggest_topics on each, and returns a list-of-lists: section_topics = [ [ {'label':'...', 'score_mean':...}, {...} ], [ {'label':'...', 'score_mean':...}, {...} ], ... ] """ sentences = sent_tokenize(text) # e.g. group 3 sentences per section sections = [' '.join(sentences[i:i+3]) for i in range(0, len(sentences), 3)] section_topics = [] for section in sections: raw_sec = clf.suggest_topics(section) parsed_sec = parse_wbg_results(raw_sec) section_topics.append(parsed_sec) return section_topics ################################################################ # 3) Basic Summaries (Correlation, Sentiment, Clusters etc.) # ################################################################ @safe_process def calculate_topic_correlations(topic_dicts): """ If we only want a single dimension correlation (like score_mean), we can do a simple correlation across different topics. But typically you'd want multiple texts or some multi-dimensional approach. """ if len(topic_dicts) < 2: # Not enough to do correlation return np.array([[1.0]]), ["Insufficient topics"] labels = [d['label'] for d in topic_dicts] scores = [d['score_mean'] for d in topic_dicts] # single dimension if len(scores) < 2: return np.array([[1.0]]), ["Insufficient topics"] corr_matrix = np.corrcoef(scores) return corr_matrix, labels @safe_process def perform_sentiment_analysis(text): sia = SentimentIntensityAnalyzer() sents = sent_tokenize(text) results = [sia.polarity_scores(s) for s in sents] return pd.DataFrame(results) @safe_process def create_topic_clusters(topic_dicts): if len(topic_dicts) < 3: return [0]*len(topic_dicts) # trivial cluster # Must have 'score_mean' and 'score_std' or something else X = [] for t in topic_dicts: X.append([t['score_mean'], t.get('score_std', 0.0)]) X = np.array(X) if X.shape[0] < 3: return [0]*X.shape[0] kmeans = KMeans(n_clusters=min(3, X.shape[0]), random_state=42) clusters = kmeans.fit_predict(X) return clusters.tolist() # safe to JSON-encode ################################################################ # 4) Charts (Bar, Radar, Correlation Heatmap, etc.) # ################################################################ @safe_process def create_main_charts(topic_dicts): """ Expects a list of dicts with keys: 'label', 'score_mean', ... We'll just use 'score_mean' (or a scaled version). """ if not topic_dicts: return go.Figure(), go.Figure() # Bar chart labels = [t['label'] for t in topic_dicts] scores = [t['score_mean']*100 for t in topic_dicts] # convert to % bar_fig = go.Figure( data=[go.Bar(x=labels, y=scores, marker_color='rgb(55, 83, 109)')] ) bar_fig.update_layout( title='주제 분석 결과', xaxis_title='주제', yaxis_title='관련도 (%)', template='plotly_white', height=500, ) # Radar chart radar_fig = go.Figure() radar_fig.add_trace(go.Scatterpolar( r=scores, theta=labels, fill='toself', name='주제 분포' )) radar_fig.update_layout( title='주제 레이더 차트', template='plotly_white', height=500, polar=dict(radialaxis=dict(visible=True)), showlegend=False ) return bar_fig, radar_fig @safe_process def create_correlation_heatmap(corr_matrix, labels): if corr_matrix.ndim == 0: # It's a scalar => shape () corr_matrix = np.array([[corr_matrix]]) if corr_matrix.shape == (1,1): # Usually means not enough data fig = go.Figure() fig.add_annotation(text="Not enough topics for correlation", showarrow=False) return fig fig = go.Figure(data=go.Heatmap( z=corr_matrix, x=labels, y=labels, colorscale='Viridis' )) fig.update_layout( title='주제 간 상관관계', height=500, template='plotly_white' ) return fig @safe_process def create_topic_evolution(section_topics): """ section_topics: list of [ {label:..., score_mean:...}, ...] one element per section """ fig = go.Figure() if not section_topics or len(section_topics) == 0: return fig # Take the first section’s list as reference if not section_topics[0]: return fig # For each topic in the first section, gather its evolution for topic_dict in section_topics[0]: label = topic_dict['label'] score_list = [] for sec_list in section_topics: # find matching label match = next((d for d in sec_list if d['label'] == label), None) if match: score_list.append(match['score_mean']) else: score_list.append(0.0) fig.add_trace(go.Scatter( x=list(range(len(section_topics))), y=score_list, name=label, mode='lines+markers' )) fig.update_layout( title='주제 변화 추이', xaxis_title='섹션', yaxis_title='score_mean', height=500, template='plotly_white' ) return fig @safe_process def create_confidence_gauge(topic_dicts): """ If your data doesn’t actually have a separate confidence measure, you may skip or adapt. For example, you might define confidence = (1 - score_std)*100 """ if not topic_dicts: return go.Figure() fig = go.Figure() num_topics = len(topic_dicts) for i, t in enumerate(topic_dicts): confidence_val = 100.0*(1.0 - t.get('score_std', 0.0)) # an example fig.add_trace(go.Indicator( mode="gauge+number", value=confidence_val, title={'text': t['label']}, domain={'row': 0, 'column': i} )) fig.update_layout( grid={'rows': 1, 'columns': num_topics}, height=400, template='plotly_white' ) return fig ################################################################ # 5) Putting Everything into `process_all_analysis` # ################################################################ @spaces.GPU() def process_all_analysis(text): try: # 1) Suggest topics on the entire text raw_results = clf.suggest_topics(text) all_topics = parse_wbg_results(raw_results) # keep full list of dicts # 2) Top 5 (if you want to highlight them) # Sort by score_mean descending sorted_topics = sorted(all_topics, key=lambda x: x['score_mean'], reverse=True) top_topics = sorted_topics[:5] # 3) Section-based section_topics = analyze_text_sections(text) # list of lists # 4) Extra analyses corr_matrix, corr_labels = calculate_topic_correlations(all_topics) sentiments_df = perform_sentiment_analysis(text) clusters = create_topic_clusters(all_topics) # 5) Build charts bar_chart, radar_chart = create_main_charts(top_topics) # show top 5 on bar heatmap = create_correlation_heatmap(corr_matrix, corr_labels) evolution_chart = create_topic_evolution(section_topics) gauge_chart = create_confidence_gauge(top_topics) # 6) Prepare output for the JSON field # Make sure everything is JSON-serializable with string keys results = { "top_topics": top_topics, # list of dict "clusters": clusters, # list of ints "sentiments": sentiments_df.to_dict(orient="records"), } return ( results, # JSON output bar_chart, # plot1 radar_chart, # plot2 heatmap, # plot3 evolution_chart,# plot4 gauge_chart, # plot5 go.Figure() # plot6 (placeholder for sentiment plot, or skip) ) except Exception as e: print(f"Analysis error: {str(e)}") empty_fig = go.Figure() return ( {"error": str(e), "topics": []}, empty_fig, empty_fig, empty_fig, empty_fig, empty_fig, empty_fig ) ################################################################ # 6) Gradio UI # ################################################################ with gr.Blocks(title="고급 문서 주제 분석기") as demo: gr.Markdown("## 📊 고급 문서 주제 분석기") with gr.Row(): text_input = gr.Textbox( value=SAMPLE_TEXT, label="분석할 텍스트", lines=8 ) with gr.Row(): submit_btn = gr.Button("분석 시작", variant="primary") with gr.Tabs(): with gr.TabItem("주요 분석"): with gr.Row(): plot1 = gr.Plot(label="주제 분포") plot2 = gr.Plot(label="레이더 차트") with gr.TabItem("상세 분석"): with gr.Row(): plot3 = gr.Plot(label="상관관계 히트맵") plot4 = gr.Plot(label="주제 변화 추이") with gr.TabItem("신뢰도 분석"): plot5 = gr.Plot(label="신뢰도 게이지") with gr.TabItem("감성 분석"): plot6 = gr.Plot(label="감성 분석 결과") with gr.Row(): output_json = gr.JSON(label="상세 분석 결과") submit_btn.click( fn=process_all_analysis, inputs=[text_input], outputs=[output_json, plot1, plot2, plot3, plot4, plot5, plot6] ) if __name__ == "__main__": demo.queue(max_size=1) demo.launch( server_name="0.0.0.0", server_port=7860, share=False, debug=True )