yenniejun commited on
Commit
55eef5c
Β·
1 Parent(s): 4e5b267

Adding plotly plot remove plot if loading

Browse files
Files changed (1) hide show
  1. app.py +36 -51
app.py CHANGED
@@ -1,35 +1,20 @@
1
- """
2
- HuggingFace Spaces that:
3
- - loads in HanmunRoBERTa model https://huggingface.co/bdsl/HanmunRoBERTa
4
- - optionally strips text of punctuation and unwanted charactesr
5
- - predicts century for the input text
6
- - Visualizes prediction scores for each century
7
-
8
- # https://huggingface.co/blog/streamlit-spaces
9
- # https://huggingface.co/docs/hub/en/spaces-sdks-streamlit
10
-
11
- """
12
-
13
  import streamlit as st
14
  from transformers import pipeline
15
  from string import punctuation
16
  import pandas as pd
17
  import plotly.express as px
18
  import plotly.graph_objects as go
19
- colors = px.colors.qualitative.Plotly
20
 
21
- # from huggingface_hub import InferenceClient
22
- # client = InferenceClient(model="bdsl/HanmunRoBERTa")
 
23
 
24
  def strip_input_str(x):
25
- # Specify the characters to remove
26
  characters_to_remove = "β—‹β–‘()〔〕:\"。·, ?ㆍ" + punctuation
27
  translating = str.maketrans('', '', characters_to_remove)
28
  x = x.translate(translating)
29
-
30
  return x.strip()
31
 
32
-
33
  # Load the pipeline with the HanmunRoBERTa model
34
  model_pipeline = pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")
35
 
@@ -47,42 +32,42 @@ input_str = st.text_area(
47
  height=150,
48
  value="權ηŸ₯ ι«˜ιΊ— εœ‹δΊ‹θ‡£ζŸθ¨€γ€‚ δΌζƒŸε°ι‚¦, θ‡ͺ ζ­ζ„ηŽ‹ η„‘ε—£θ–¨ι€δΉ‹εΎŒ, θΎ›ζ—½ 子 禑 ε†’ε§“η«Šδ½θ€….")
49
 
50
- # Remove punctuation if checkbox is selected
51
  if remove_punct and input_str:
52
  input_str = strip_input_str(input_str)
53
-
54
- # Display the input text after processing
55
  st.write("Processed input:", input_str)
56
 
57
- # Predict and display the classification scores if input is provided
58
  if st.button("Classify"):
 
59
  if input_str:
60
- predictions = model_pipeline(input_str, top_k=None)
61
- data = pd.DataFrame(predictions)
62
- data=data.sort_values(by='score', ascending=True)
63
- data.label = data.label.astype(str)
64
-
65
-
66
- # Displaying predictions as a bar chart
67
- fig = go.Figure(
68
- go.Bar(
69
- x=data.score.values,
70
- y=[f'{i}th Century' for i in data.label.values],
71
- orientation='h',
72
- text=[f'{score:.3f}' for score in data['score'].values], # Format text with 2 decimal points
73
- textposition='outside', # Position the text outside the bars
74
- hoverinfo='text', # Use custom text for hover info
75
- hovertext=[f'{i}th Century<br>Score: {score:.3f}' for i, score in zip(data['label'], data['score'])], # Custom hover text
76
- marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]), # Cycle through colors
77
-
78
- ))
79
-
80
- fig.update_layout(
81
- height=300, # Custom height
82
- xaxis_title='Score',
83
- yaxis_title='',
84
- title='Model predictions and scores',
85
- uniformtext_minsize=8,
86
- uniformtext_mode='hide',
87
- )
88
- st.plotly_chart(figure_or_data=fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from string import punctuation
4
  import pandas as pd
5
  import plotly.express as px
6
  import plotly.graph_objects as go
 
7
 
8
+ # Initialize or retrieve the session state variable
9
+ if 'plot_visible' not in st.session_state:
10
+ st.session_state.plot_visible = False # Initially, the plot is not visible
11
 
12
  def strip_input_str(x):
 
13
  characters_to_remove = "β—‹β–‘()〔〕:\"。·, ?ㆍ" + punctuation
14
  translating = str.maketrans('', '', characters_to_remove)
15
  x = x.translate(translating)
 
16
  return x.strip()
17
 
 
18
  # Load the pipeline with the HanmunRoBERTa model
19
  model_pipeline = pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")
20
 
 
32
  height=150,
33
  value="權ηŸ₯ ι«˜ιΊ— εœ‹δΊ‹θ‡£ζŸθ¨€γ€‚ δΌζƒŸε°ι‚¦, θ‡ͺ ζ­ζ„ηŽ‹ η„‘ε—£θ–¨ι€δΉ‹εΎŒ, θΎ›ζ—½ 子 禑 ε†’ε§“η«Šδ½θ€….")
34
 
 
35
  if remove_punct and input_str:
36
  input_str = strip_input_str(input_str)
 
 
37
  st.write("Processed input:", input_str)
38
 
39
+ # Button to classify the text and toggle the visibility of the plot
40
  if st.button("Classify"):
41
+ st.session_state.plot_visible = not st.session_state.plot_visible # Toggle the plot visibility
42
  if input_str:
43
+ with st.spinner("Classifying..."):
44
+ predictions = model_pipeline(input_str, top_k=None)
45
+ data = pd.DataFrame(predictions)
46
+ data = data.sort_values(by='score', ascending=True)
47
+ data.label = data.label.astype(str)
48
+
49
+ # Ensure the plot is only displayed when `plot_visible` is True
50
+ if st.session_state.plot_visible:
51
+ colors = px.colors.qualitative.Plotly
52
+ fig = go.Figure(
53
+ go.Bar(
54
+ x=data.score.values,
55
+ y=[f'{i}th Century' for i in data.label.values],
56
+ orientation='h',
57
+ text=[f'{score:.3f}' for score in data['score'].values],
58
+ textposition='outside',
59
+ hoverinfo='text',
60
+ hovertext=[f'{i}th Century<br>Score: {score:.3f}' for i, score in zip(data['label'], data['score'])],
61
+ marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]),
62
+ ))
63
+
64
+ fig.update_layout(
65
+ height=300,
66
+ xaxis_title='Score',
67
+ yaxis_title='',
68
+ title='Model predictions and scores',
69
+ uniformtext_minsize=8,
70
+ uniformtext_mode='hide',
71
+ )
72
+ st.plotly_chart(figure_or_data=fig, use_container_width=True)
73
+ st.session_state.plot_visible = False # Reset to False after displaying