# NOTE: the lines below are residue from the web file viewer, not part of app.py.
# openfree's picture
# Update app.py
# bc928c9 verified
# raw
# history blame
# 2.76 kB
import json
import gradio as gr
import spaces
import wbgtopic
import plotly.graph_objects as go
from topic_translator import translate_topics
# Sample passage used as the default value of the input textbox.
SAMPLE_TEXT = """A growing literature attributes gender inequality in labor market outcomes in part to the reduction in female labor supply after childbirth, the child penalty..."""
# Topic classifier instance, created once at import time and shared by all calls to `fn`.
clf = wbgtopic.WBGDocTopic()
def create_chart(topics):
    """Build a bar chart of topic relevance scores.

    Args:
        topics: Sequence of dicts. Each dict must provide ``'label'`` (used
            as the x-axis category) and ``'score'`` (used as the bar height).
            Any other keys, such as ``'confidence'``, are ignored here.

    Returns:
        A ``plotly.graph_objects.Figure`` containing a single bar trace,
        with the y-axis fixed to the range 0-100.
    """
    labels = [t['label'] for t in topics]
    scores = [t['score'] for t in topics]

    fig = go.Figure()

    # Single bar series; the Korean trace name reads "relevance".
    fig.add_trace(go.Bar(
        x=labels,
        y=scores,
        name='관련도',
        marker_color='rgb(55, 83, 109)',
    ))

    # Korean UI strings: title "Document topic analysis results",
    # x-axis "topic", y-axis "relevance (%)".
    fig.update_layout(
        title='문서 주제 분석 결과',
        xaxis_title='주제',
        yaxis_title='관련도 (%)',
        yaxis_range=[0, 100],
        height=500,
        font=dict(size=14),
    )
    return fig
def process_results(results, top_n=5):
    """Pick the highest-scoring topics and format them for display.

    Args:
        results: Raw classifier output. Only ``results[0]`` is read; it is
            expected to be a list of dicts with the keys ``'label'``,
            ``'score_mean'`` and ``'score_std'``.
        top_n: Maximum number of topics to keep, ranked by ``'score_mean'``
            in descending order. Defaults to 5.

    Returns:
        A ``(formatted_topics, chart)`` tuple. ``formatted_topics`` is a list
        of dicts with ``'label'`` (looked up in ``translate_topics``, falling
        back to the original label), ``'score'`` (``score_mean`` as a
        percentage, one decimal) and ``'confidence'`` (``1 - score_std`` as a
        percentage, one decimal). ``chart`` is the value returned by
        ``create_chart``. When ``results`` or ``results[0]`` is empty,
        ``([], None)`` is returned.
    """
    if not results or not results[0]:
        return [], None

    ranked = sorted(results[0], key=lambda x: x['score_mean'], reverse=True)
    formatted_topics = [
        {
            'label': translate_topics.get(topic['label'], topic['label']),
            'score': round(topic['score_mean'] * 100, 1),
            'confidence': round((1 - topic['score_std']) * 100, 1),
        }
        for topic in ranked[:top_n]
    ]

    chart = create_chart(formatted_topics)
    return formatted_topics, chart
@spaces.GPU(enable_queue=True, duration=50)
def fn(inputs):
    """Run the topic classifier on ``inputs`` and format its output.

    Passes ``inputs`` to ``clf.suggest_topics`` and returns whatever
    ``process_results`` produces for that raw output.
    """
    return process_results(clf.suggest_topics(inputs))
# Build the Gradio interface. UI strings are Korean; the app title reads
# "Document Topic Analyzer".
with gr.Blocks(title="문서 주제 분석기") as demo:
    gr.Markdown("## 📚 문서 주제 분석기")
    # "Enter a document and the related topics are analyzed and shown."
    gr.Markdown("문서를 입력하면 관련된 주제들을 분석하여 보여줍니다.")

    with gr.Row():
        # Input box, pre-filled with the sample text.
        text = gr.Textbox(
            value=SAMPLE_TEXT,
            label="분석할 텍스트",
            placeholder="여기에 분석할 텍스트를 입력하세요",
            lines=5,
        )

    with gr.Row():
        # Button label reads "Start analysis".
        submit_btn = gr.Button("분석 시작", variant="primary")

    with gr.Row():
        # Plot component that displays the chart returned by `fn`.
        plot = gr.Plot(label="주제 분석 차트")

    with gr.Row():
        # JSON view of the formatted topic list returned by `fn`.
        output = gr.JSON(label="상세 분석 결과")

    # Wire the button: `fn` returns (topics, chart), matching the output order.
    submit_btn.click(
        fn=fn,
        inputs=[text],
        outputs=[output, plot],
    )

demo.launch(debug=True)