# HanmunRoBERTa / app.py
# Last change (commit a92f417, by winniehcy, verified): show the scores of all
# classes instead of only the score of the predicted class. (File size: 2.98 kB.)
import streamlit as st
from transformers import pipeline
from string import punctuation
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
# Create the per-session "plot_visible" flag on the first run only; later
# reruns of the script find the key already present and leave it untouched.
if "plot_visible" not in st.session_state:
    st.session_state["plot_visible"] = False  # the plot starts out hidden
# Characters deleted when "Remove punctuation" is ticked: the circle and square
# marks, ASCII brackets/colon/quote, CJK brackets and full stop, the middle dots,
# comma, a plain space, question mark, and every ASCII punctuation character.
_CHARACTERS_TO_REMOVE = "○□()〔〕:\"。·, ?ㆍ" + punctuation
# Built once at import time rather than on every call.
_STRIP_TABLE = str.maketrans("", "", _CHARACTERS_TO_REMOVE)


def strip_input_str(x):
    """Return *x* with punctuation and spaces removed and outer whitespace trimmed.

    Every character in ``_CHARACTERS_TO_REMOVE`` (including the ASCII space) is
    deleted wherever it occurs; any remaining leading/trailing whitespace such
    as tabs or newlines is then stripped.

    Args:
        x: The raw input string.

    Returns:
        The cleaned string (possibly empty).
    """
    return x.translate(_STRIP_TABLE).strip()
# Load the HanmunRoBERTa text-classification pipeline once per server process.
# Streamlit re-executes this whole script on every widget interaction, so the
# pipeline is cached with st.cache_resource instead of being rebuilt each rerun.
# show_spinner=False: a cache spinner would be a Streamlit element emitted before
# st.set_page_config below, which some Streamlit versions may reject.
@st.cache_resource(show_spinner=False)
def load_model_pipeline():
    """Return a text-classification pipeline for ``bdsl/HanmunRoBERTa``.

    ``top_k=None`` requests the scores of every label; it replaces the
    deprecated ``return_all_scores=True`` argument.
    """
    return pipeline(task="text-classification", model="bdsl/HanmunRoBERTa", top_k=None)


model_pipeline = load_model_pipeline()
# Streamlit app layout: browser-tab title and icon, plus the page heading.
title = "HanmunRoBERTa Century Classifier"
# The icon was previously stored as mis-decoded bytes; this is the intended emoji.
st.set_page_config(page_title=title, page_icon="📚")
st.title(title)
# Checkbox controlling whether the input is passed through strip_input_str
# before classification.
remove_punct = st.checkbox(label="Remove punctuation", value=True)

# Text area for the text to classify, pre-filled with an example passage
# (restored from a mis-decoded copy of the same CJK text).
input_str = st.text_area(
    "Input text",
    height=150,
    value="權知 高麗 國事臣某言。 伏惟小邦, 自 恭愍王 無嗣薨逝之後, 辛旽 子 禑 冒姓竊位者.",
    max_chars=500,
)

if remove_punct and input_str:
    input_str = strip_input_str(input_str)
    # Show the user exactly what will be sent to the model.
    st.write("Processed input:", input_str)
def _century_name(label):
    """Format a class label such as ``"15"`` as ``"15th Century"``.

    Decimal labels get the correct English ordinal suffix (1st, 2nd, 3rd,
    11th-13th, 21st, ...). Any other label keeps the plain ``th`` suffix.
    """
    if not label.isdecimal():
        return f"{label}th Century"
    number = int(label)
    if 10 <= number % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(number % 10, "th")
    return f"{label}{suffix} Century"


# Classify the (optionally stripped) input when the button is pressed and plot
# the score of every label. Each click renders the chart directly; no
# session-state toggle is involved.
if st.button("Classify"):
    if not input_str:
        # An empty box, or input consisting only of stripped characters, has
        # nothing to classify; say so instead of silently doing nothing.
        st.warning("Please enter some text to classify.")
    else:
        with st.spinner("Classifying..."):
            # top_k=None asks the pipeline for the scores of all labels.
            predictions = model_pipeline(input_str, top_k=None)
        data = pd.DataFrame(predictions)
        # Ascending sort, presumably so the highest score ends up at the top of
        # the horizontal bar chart.
        data = data.sort_values(by="score", ascending=True)
        data["label"] = data["label"].astype(str)

        centuries = [_century_name(label) for label in data["label"]]
        scores = data["score"].tolist()
        colors = px.colors.qualitative.Plotly

        fig = go.Figure(
            go.Bar(
                x=scores,
                y=centuries,
                orientation="h",
                text=[f"{score:.3f}" for score in scores],
                textposition="outside",
                hoverinfo="text",
                hovertext=[
                    f"{century}<br>Score: {score:.3f}"
                    for century, score in zip(centuries, scores)
                ],
                # One colour per bar, cycling through the qualitative palette.
                marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]),
            )
        )
        fig.update_layout(
            height=300,
            xaxis_title="Score",
            yaxis_title="",
            title="Model predictions and scores",
            uniformtext_minsize=8,
            uniformtext_mode="hide",
        )
        st.plotly_chart(figure_or_data=fig, use_container_width=True)