Spaces:
Runtime error
Runtime error
add cache
Browse files
app.py
CHANGED
|
@@ -198,7 +198,7 @@ def topic_distribution(weights, smoothing=0.01):
|
|
| 198 |
# return(topic_frequencies[category], topic_frequencies_spotlight[category], topic_ratios[category], category)
|
| 199 |
|
| 200 |
def populate_session(dataset,model):
|
| 201 |
-
data_df =
|
| 202 |
if model == 'albert-base-v2-yelp-polarity':
|
| 203 |
tokenizer = AutoTokenizer.from_pretrained('textattack/'+model)
|
| 204 |
else:
|
|
@@ -208,7 +208,9 @@ def populate_session(dataset,model):
|
|
| 208 |
if "selected_slice" not in st.session_state:
|
| 209 |
st.session_state["selected_slice"] = None
|
| 210 |
|
| 211 |
-
|
|
|
|
|
|
|
| 212 |
|
| 213 |
if __name__ == "__main__":
|
| 214 |
### STREAMLIT APP CONFIG ###
|
|
@@ -235,7 +237,7 @@ if __name__ == "__main__":
|
|
| 235 |
### LOAD DATA AND SESSION VARIABLES ###
|
| 236 |
##uncomment the next line to run dynamically and not from file
|
| 237 |
#populate_session(dataset, model)
|
| 238 |
-
data_df =
|
| 239 |
loss_quantile = st.sidebar.slider(
|
| 240 |
"Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.95
|
| 241 |
)
|
|
@@ -250,7 +252,7 @@ if __name__ == "__main__":
|
|
| 250 |
st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
|
| 251 |
#uncomment the next two lines to run dynamically and not from file
|
| 252 |
#commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
|
| 253 |
-
commontokens =
|
| 254 |
with st.expander("How to read the table:"):
|
| 255 |
st.markdown("* The table displays the most frequent tokens in error slices, relative to their frequencies in the val set.")
|
| 256 |
st.write(commontokens)
|
|
@@ -260,20 +262,22 @@ if __name__ == "__main__":
|
|
| 260 |
num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)
|
| 261 |
|
| 262 |
if run_kmeans == 'True':
|
| 263 |
-
|
|
|
|
| 264 |
with lcol:
|
| 265 |
st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
#uncomment the next lines to run dynamically and not from file
|
| 268 |
# dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
|
| 269 |
# by=['loss'], ascending=False)
|
| 270 |
# table_html = dataframe.to_html(
|
| 271 |
# columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
|
| 272 |
# table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
|
| 273 |
-
with st.expander("How to read the table:"):
|
| 274 |
-
st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
|
| 275 |
-
st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
|
| 276 |
-
st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
|
| 277 |
st.write(dataframe,width=900, height=300)
|
| 278 |
|
| 279 |
quant_panel(merged)
|
|
|
|
| 198 |
# return(topic_frequencies[category], topic_frequencies_spotlight[category], topic_ratios[category], category)
|
| 199 |
|
| 200 |
def populate_session(dataset,model):
|
| 201 |
+
data_df = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'.parquet')
|
| 202 |
if model == 'albert-base-v2-yelp-polarity':
|
| 203 |
tokenizer = AutoTokenizer.from_pretrained('textattack/'+model)
|
| 204 |
else:
|
|
|
|
| 208 |
if "selected_slice" not in st.session_state:
|
| 209 |
st.session_state["selected_slice"] = None
|
| 210 |
|
| 211 |
+
@st.cache(ttl=600)
|
| 212 |
+
def read_file_to_df(file):
|
| 213 |
+
return pd.read_parquet(file)
|
| 214 |
|
| 215 |
if __name__ == "__main__":
|
| 216 |
### STREAMLIT APP CONFIG ###
|
|
|
|
| 237 |
### LOAD DATA AND SESSION VARIABLES ###
|
| 238 |
##uncomment the next line to run dynamically and not from file
|
| 239 |
#populate_session(dataset, model)
|
| 240 |
+
data_df = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'.parquet')
|
| 241 |
loss_quantile = st.sidebar.slider(
|
| 242 |
"Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.95
|
| 243 |
)
|
|
|
|
| 252 |
st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
|
| 253 |
#uncomment the next two lines to run dynamically and not from file
|
| 254 |
#commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
|
| 255 |
+
commontokens = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'_commontokens.parquet')
|
| 256 |
with st.expander("How to read the table:"):
|
| 257 |
st.markdown("* The table displays the most frequent tokens in error slices, relative to their frequencies in the val set.")
|
| 258 |
st.write(commontokens)
|
|
|
|
| 262 |
num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)
|
| 263 |
|
| 264 |
if run_kmeans == 'True':
|
| 265 |
+
with st.spinner(text='running kmeans...'):
|
| 266 |
+
merged = kmeans(data_df,num_clusters=num_clusters)
|
| 267 |
with lcol:
|
| 268 |
st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
|
| 269 |
+
with st.expander("How to read the table:"):
|
| 270 |
+
st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
|
| 271 |
+
st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
|
| 272 |
+
st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
|
| 273 |
+
with st.spinner(text='loading error slice...'):
|
| 274 |
+
dataframe=read_file_to_df('./assets/data/'+dataset+ '_'+ model+'_error-slices.parquet')
|
| 275 |
#uncomment the next lines to run dynamically and not from file
|
| 276 |
# dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
|
| 277 |
# by=['loss'], ascending=False)
|
| 278 |
# table_html = dataframe.to_html(
|
| 279 |
# columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
|
| 280 |
# table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
st.write(dataframe,width=900, height=300)
|
| 282 |
|
| 283 |
quant_panel(merged)
|