File size: 2,761 Bytes
dfa0bd7
2b1e4b7
d27df0e
96070b5
bc928c9
 
2b1e4b7
74b6cd5
4efedce
96070b5
2b1e4b7
bc928c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74b6cd5
 
bc928c9
74b6cd5
 
 
 
 
 
 
 
 
 
 
 
 
bc928c9
 
 
 
96070b5
d27df0e
96070b5
74b6cd5
 
96070b5
74b6cd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc928c9
 
 
 
 
74b6cd5
 
4efedce
74b6cd5
 
bc928c9
4efedce
 
74b6cd5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import json
import gradio as gr
import spaces
import wbgtopic
import plotly.graph_objects as go
from topic_translator import translate_topics

# Default text pre-filled in the input textbox (used as `value=` of the
# Textbox in the UI below). It is truncated with "..." in this file.
SAMPLE_TEXT = """A growing literature attributes gender inequality in labor market outcomes in part to the reduction in female labor supply after childbirth, the child penalty..."""

# Topic classifier, instantiated once at import time and reused by every call
# to `fn` (which invokes `clf.suggest_topics`).
clf = wbgtopic.WBGDocTopic()

def create_chart(topics):
    """Build a Plotly bar chart of topic relevance scores.

    Args:
        topics: List of dicts, each with at least a ``'label'`` and a
            ``'score'`` key, as produced by ``process_results`` (where
            ``'score'`` is already a percentage rounded to one decimal).

    Returns:
        A ``plotly.graph_objects.Figure`` with one bar per topic and the
        y-axis fixed to the 0-100 range.
    """
    # Prepare the x (topic label) and y (score) series.
    labels = [t['label'] for t in topics]
    scores = [t['score'] for t in topics]

    # Create the bar chart.
    fig = go.Figure()

    # Add the main bar trace.
    fig.add_trace(go.Bar(
        x=labels,
        y=scores,
        name='관련도',  # "Relevance"
        marker_color='rgb(55, 83, 109)'
    ))

    # Configure the layout; scores are percentages, hence the fixed 0-100 range.
    fig.update_layout(
        title='문서 주제 분석 결과',  # "Document topic analysis results"
        xaxis_title='주제',  # "Topic"
        yaxis_title='관련도 (%)',  # "Relevance (%)"
        yaxis_range=[0, 100],
        height=500,
        font=dict(size=14)
    )

    return fig

def process_results(results, top_k=5):
    """Convert raw classifier output into display-ready topics and a chart.

    Args:
        results: Output of ``clf.suggest_topics``. Only ``results[0]`` is
            used; it is expected to be an iterable of dicts with
            ``'label'``, ``'score_mean'`` and ``'score_std'`` keys.
        top_k: Maximum number of highest-scoring topics to keep. Defaults
            to 5, the previously hard-coded value. Must be >= 1.

    Returns:
        A ``(formatted_topics, chart)`` tuple. ``formatted_topics`` is a list
        of dicts with ``'label'`` (translated when ``translate_topics`` has an
        entry, otherwise the raw label), ``'score'`` and ``'confidence'``
        (both percentages rounded to one decimal). ``chart`` is the figure
        built by ``create_chart``. When ``results`` or ``results[0]`` is
        empty, returns ``([], None)``.

    Raises:
        ValueError: If ``top_k`` is less than 1. A negative value would
            otherwise silently slice off the *lowest* topics instead of
            keeping the top ones.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    if not results or not results[0]:
        return [], None

    topics = results[0]
    # Highest mean score first; keep only the top_k topics.
    top_topics = sorted(topics, key=lambda x: x['score_mean'], reverse=True)[:top_k]

    formatted_topics = [
        {
            # Fall back to the raw label when no translation exists.
            'label': translate_topics.get(topic['label'], topic['label']),
            'score': round(topic['score_mean'] * 100, 1),
            # Confidence is 1 - score_std, so a smaller std means higher confidence.
            'confidence': round((1 - topic['score_std']) * 100, 1),
        }
        for topic in top_topics
    ]

    # Build the chart.
    chart = create_chart(formatted_topics)

    return formatted_topics, chart

# NOTE(review): spaces.GPU requests a Hugging Face ZeroGPU device for this call;
# `duration` is presumably the per-call budget in seconds, and `enable_queue`
# may be ignored by newer `spaces` releases -- confirm against the installed version.
@spaces.GPU(enable_queue=True, duration=50)
def fn(inputs):
    """Run topic suggestion on the submitted text and format the results.

    Args:
        inputs: Text from the input textbox.

    Returns:
        ``(formatted_topics, chart)`` as returned by ``process_results``.
        Missing, empty or whitespace-only input short-circuits to
        ``([], None)`` (the same "no results" value ``process_results`` uses)
        without calling the classifier.
    """
    # Avoid a pointless model call on blank input.
    if not inputs or not inputs.strip():
        return [], None

    raw_results = clf.suggest_topics(inputs)
    return process_results(raw_results)

# Build the Gradio interface.
with gr.Blocks(title="문서 주제 분석기") as demo:  # "Document Topic Analyzer"
    gr.Markdown("## 📚 문서 주제 분석기")
    # "Enter a document and its related topics will be analyzed and shown."
    gr.Markdown("문서를 입력하면 관련된 주제들을 분석하여 보여줍니다.")

    with gr.Row():
        text = gr.Textbox(
            value=SAMPLE_TEXT,
            label="분석할 텍스트",  # "Text to analyze"
            placeholder="여기에 분석할 텍스트를 입력하세요",  # "Enter the text to analyze here"
            lines=5
        )

    with gr.Row():
        submit_btn = gr.Button("분석 시작", variant="primary")  # "Start analysis"

    with gr.Row():
        # Plot component that displays the topic chart ("Topic analysis chart").
        plot = gr.Plot(label="주제 분석 차트")

    with gr.Row():
        output = gr.JSON(label="상세 분석 결과")  # "Detailed analysis results"

    # Wire the button to the analysis function. The outputs order must match
    # the (formatted_topics, chart) tuple that fn returns.
    submit_btn.click(
        fn=fn,
        inputs=[text],
        outputs=[output, plot]
    )

# Launch only when executed as a script, so importing this module (e.g. from
# tests or tooling) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)