Spaces:
Running
on
Zero
Running
on
Zero
import heapq
import json

import gradio as gr
import plotly.graph_objects as go
import spaces
import wbgtopic

from topic_translator import translate_topics
# Example abstract used to pre-fill the input textbox of the UI.
SAMPLE_TEXT = (
    "A growing literature attributes gender inequality in labor market "
    "outcomes in part to the reduction in female labor supply after "
    "childbirth, the child penalty..."
)

# Single classifier instance, created once when the module is executed and
# reused by every call to fn().
clf = wbgtopic.WBGDocTopic()
def create_chart(topics):
    """Build a Plotly bar chart of topic relevance scores.

    Args:
        topics: Sequence of dicts with at least ``'label'`` and ``'score'``
            keys. ``'score'`` is expected to be a percentage, because the
            y-axis is fixed to the range 0-100. Any other keys (e.g.
            ``'confidence'``) are ignored.

    Returns:
        A ``plotly.graph_objects.Figure`` containing a single bar trace with
        one bar per topic, in the order given.
    """
    labels = [t['label'] for t in topics]
    scores = [t['score'] for t in topics]

    fig = go.Figure()
    # One bar per topic; x = topic label, y = relevance score.
    fig.add_trace(go.Bar(
        x=labels,
        y=scores,
        name='๊ด๋ จ๋',
        marker_color='rgb(55, 83, 109)',
    ))
    # NOTE(review): the non-ASCII name/title/axis strings in this function
    # look garbled (possibly UTF-8 text decoded with the wrong codec) --
    # confirm the file's encoding.
    fig.update_layout(
        title='๋ฌธ์ ์ฃผ์ ๋ถ์ ๊ฒฐ๊ณผ',
        xaxis_title='์ฃผ์ ',
        yaxis_title='๊ด๋ จ๋ (%)',
        yaxis_range=[0, 100],  # scores are plotted as percentages
        height=500,
        font=dict(size=14),
    )
    return fig
def process_results(results, top_k=5):
    """Turn raw classifier output into the top topics plus a bar chart.

    Args:
        results: Output of the topic classifier. Only ``results[0]`` is read;
            its items are expected to be dicts carrying ``'label'``,
            ``'score_mean'`` and ``'score_std'`` keys.
        top_k: Maximum number of topics to keep, highest ``'score_mean'``
            first. Defaults to 5.

    Returns:
        A ``(formatted_topics, chart)`` tuple. Each formatted topic is a dict
        with these keys:

        - ``'label'``: passed through ``translate_topics`` when it has an
          entry, otherwise the raw label.
        - ``'score'``: ``score_mean`` scaled to a percentage, rounded to one
          decimal.
        - ``'confidence'``: ``1 - score_std`` scaled to a percentage,
          rounded to one decimal.

        ``chart`` is the figure built by ``create_chart``. When
        ``results`` or ``results[0]`` is empty/None, ``([], None)`` is
        returned.
    """
    if not results or not results[0]:
        return [], None

    # heapq.nlargest(k, it, key=f) is documented as equivalent to
    # sorted(it, key=f, reverse=True)[:k], including the order of ties.
    top_topics = heapq.nlargest(top_k, results[0], key=lambda t: t['score_mean'])

    formatted_topics = [
        {
            # Fall back to the raw label when no translation exists.
            'label': translate_topics.get(topic['label'], topic['label']),
            'score': round(topic['score_mean'] * 100, 1),
            # NOTE(review): this assumes score_std lies in [0, 1]; a larger
            # value would yield a negative "confidence" -- confirm.
            'confidence': round((1 - topic['score_std']) * 100, 1),
        }
        for topic in top_topics
    ]
    return formatted_topics, create_chart(formatted_topics)
def fn(inputs):
    """Gradio click handler: classify the submitted text and format results.

    Args:
        inputs: Text from the input textbox.

    Returns:
        The ``(formatted_topics, chart)`` pair produced by
        ``process_results``. For empty, None, or whitespace-only input,
        ``([], None)`` is returned without calling the classifier.
    """
    # Blank submissions never reach the model; ([], None) mirrors the
    # "nothing to show" value process_results returns.
    if not inputs or (isinstance(inputs, str) and not inputs.strip()):
        return [], None
    # NOTE(review): `spaces` is imported by this file but no @spaces.GPU
    # decorator is applied here; if the classifier needs a GPU on ZeroGPU
    # hardware, this handler may need it -- confirm.
    raw_results = clf.suggest_topics(inputs)
    return process_results(raw_results)
# Build the Gradio user interface.
with gr.Blocks(title="๋ฌธ์ ์ฃผ์ ๋ถ์๊ธฐ") as demo:
    # Page header followed by a one-line description.
    gr.Markdown("## ๐ ๋ฌธ์ ์ฃผ์ ๋ถ์๊ธฐ")
    gr.Markdown("๋ฌธ์๋ฅผ ์ ๋ ฅํ๋ฉด ๊ด๋ จ๋ ์ฃผ์ ๋ค์ ๋ถ์ํ์ฌ ๋ณด์ฌ์ค๋๋ค.")

    # Input area, pre-filled with SAMPLE_TEXT.
    with gr.Row():
        text = gr.Textbox(
            label="๋ถ์ํ ํ ์คํธ",
            placeholder="์ฌ๊ธฐ์ ๋ถ์ํ ํ ์คํธ๋ฅผ ์ ๋ ฅํ์ธ์",
            value=SAMPLE_TEXT,
            lines=5,
        )

    with gr.Row():
        submit_btn = gr.Button("๋ถ์ ์์", variant="primary")

    # Plot component that displays the topic chart.
    with gr.Row():
        plot = gr.Plot(label="์ฃผ์ ๋ถ์ ์ฐจํธ")

    # Formatted topic list shown as JSON.
    with gr.Row():
        output = gr.JSON(label="์์ธ ๋ถ์ ๊ฒฐ๊ณผ")

    # fn returns (formatted_topics, chart), so the outputs are listed in the
    # order [JSON, Plot].
    submit_btn.click(fn, inputs=[text], outputs=[output, plot])

demo.launch(debug=True)