yenniejun commited on
Commit
adf186a
Β·
1 Parent(s): fc75bb7

Adding plotly plot

Browse files
Files changed (1) hide show
  1. app.py +59 -39
app.py CHANGED
@@ -42,42 +42,62 @@ if remove_punct and input_str:
42
  translating = str.maketrans('', '', characters_to_remove)
43
  input_str = input_str.translate(translating)
44
 
45
- # Display the input text after processing
46
- st.write("Processed input:", input_str)
47
-
48
- # Predict and display the classification scores if input is provided
49
- if st.button("Classify"):
50
- if input_str:
51
- predictions = model_pipeline(input_str)
52
- data = pd.DataFrame(predictions)
53
- data=data.sort_values(by='score', ascending=True)
54
- data.label = data.label.astype(str)
55
-
56
-
57
- # Displaying predictions as a bar chart
58
- fig = go.Figure(
59
- go.Bar(
60
- x=data.score.values,
61
- y=[f'{i}th Century' for i in data.label.values],
62
- orientation='h',
63
- text=[f'{score:.3f}' for score in data['score'].values], # Format text with 2 decimal points
64
- textposition='outside', # Position the text outside the bars
65
- hoverinfo='text', # Use custom text for hover info
66
- hovertext=[f'{i}th Century<br>Score: {score:.3f}' for i, score in zip(data['label'], data['score'])], # Custom hover text
67
- marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]), # Cycle through colors
68
-
69
- ))
70
- fig.update_traces(width=0.4)
71
-
72
- fig.update_layout(
73
- height=300, # Custom height
74
- xaxis_title='Score',
75
- yaxis_title='',
76
- title='Model predictions and scores',
77
- margin=dict(l=100, r=200, t=50, b=50),
78
- uniformtext_minsize=8,
79
- uniformtext_mode='hide',
80
- )
81
- st.plotly_chart(fig=fig)
82
- else:
83
- st.write("Please enter some text to classify.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  translating = str.maketrans('', '', characters_to_remove)
43
  input_str = input_str.translate(translating)
44
 
45
+
46
+ # Create a two-column layout
47
+ col1, col2 = st.columns([2, 3]) # Adjust the width ratio as needed
48
+
49
+ with col1:
50
+ # Checkbox to remove punctuation
51
+ remove_punct = st.checkbox(label="Remove punctuation", value=True)
52
+
53
+ # Text area for user input
54
+ input_str = st.text_area("Input text", height=275)
55
+
56
+ # Remove punctuation if checkbox is selected
57
+ if remove_punct and input_str:
58
+ # Specify the characters to remove
59
+ characters_to_remove = "β—‹β–‘()〔〕:\"。·, ?ㆍ" + punctuation
60
+ translating = str.maketrans('', '', characters_to_remove)
61
+ input_str = input_str.translate(translating)
62
+
63
+ # Display the input text after processing
64
+ st.write("Processed input:", input_str)
65
+
66
+ # Button for prediction
67
+ classify_button = st.button("Classify")
68
+
69
+ # Predict and display the classification scores if input is provided and button is clicked
70
+ if classify_button and input_str:
71
+ predictions = model_pipeline(input_str)
72
+ data = pd.DataFrame(predictions)
73
+ data = data.sort_values(by='score', ascending=True)
74
+ data.label = data.label.astype(str)
75
+
76
+ # Displaying predictions as a bar chart
77
+ fig = go.Figure(
78
+ go.Bar(
79
+ x=data.score.values,
80
+ y=[f'{i}th Century' for i in data.label.values],
81
+ orientation='h',
82
+ text=[f'{score:.3f}' for score in data['score'].values],
83
+ textposition='outside',
84
+ hoverinfo='text',
85
+ hovertext=[f'{i}th Century<br>Score: {score:.3f}' for i, score in zip(data['label'], data['score'])],
86
+ marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]),
87
+ ))
88
+ fig.update_traces(width=0.4)
89
+ fig.update_layout(
90
+ height=300, # Custom height
91
+ xaxis_title='Score',
92
+ yaxis_title='',
93
+ title='Model predictions and scores',
94
+ margin=dict(l=100, r=200, t=50, b=50),
95
+ uniformtext_minsize=8,
96
+ uniformtext_mode='hide',
97
+ )
98
+
99
+ with col2:
100
+ st.plotly_chart(fig, use_container_width=True)
101
+ else:
102
+ with col2:
103
+ st.write("Please enter some text to classify and click 'Classify'.")