	k means clustering
app.py CHANGED
@@ -1,6 +1,5 @@
 ### LIBRARIES ###
 # # Data
-from matplotlib.pyplot import legend
 import numpy as np
 import pandas as pd
 import torch
@@ -10,11 +9,15 @@ from math import floor
 from datasets import load_dataset
 from collections import defaultdict
 from transformers import AutoTokenizer
+pd.options.display.float_format = '${:,.2f}'.format

 # Analysis
 # from gensim.models.doc2vec import Doc2Vec
 # from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
-
+import nltk
+from nltk.cluster import KMeansClusterer
+import scipy.spatial.distance as sdist
+from scipy.spatial import distance_matrix
 # nltk.download('punkt') #make sure that punkt is downloaded

 # App & Visualization
@@ -23,11 +26,11 @@ import altair as alt
 import plotly.graph_objects as go
 from streamlit_vega_lite import altair_component

+
+
 # utils
 from random import sample
-from 
-import os
-
+# from PIL import Image


 def down_samp(embedding):
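Note on the new pd.options.display.float_format line above: it only changes how floats are rendered when pandas objects are printed, not the stored values. A minimal standalone sketch (not part of the commit, the example frame is made up):

import pandas as pd

pd.options.display.float_format = '${:,.2f}'.format   # every float repr goes through this formatter
df = pd.DataFrame({'loss': [0.1234, 2.5]})
print(df)   # the loss column now prints as $0.12 and $2.50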
@@ -61,12 +64,14 @@ def down_samp(embedding):
 def data_comparison(df):
     # set up a dropdown select bindinf
     # input_dropdown = alt.binding_select(options=['Negative Sentiment','Positive Sentiment'])
-
-
+        #data_kmeans['distance_from_centroid'] = data_kmeans.apply(distance_from_centroid, axis=1)
+
+    selection = alt.selection_multi(fields=['cluster','label'])
+    color = alt.condition(alt.datum.slice == 'high-loss', alt.Color('cluster:N', scale = alt.Scale(domain=df.cluster.tolist())), alt.value("lightgray"))
     # color = alt.condition(selection, 
-    #
-    #
-    #
+    #                        alt.Color('cluster:Q', legend=None),
+    #                         # scale = alt.Scale(domain = pop_domain,range=color_range)),
+    #                         alt.value('lightgray'))
     opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))

     # basic chart
@@ -75,7 +80,7 @@ def data_comparison(df):
         y=alt.Y('y', axis=None),
         color=color,
         shape=alt.Shape('label', scale=alt.Scale(range=['circle', 'diamond'])),
-        tooltip=['slice','content','label','pred'],
+        tooltip=['cluster','slice','content','label','pred'],
         opacity=opacity
     ).properties(
         width=1500,
@@ -83,28 +88,21 @@ def data_comparison(df):
     ).interactive()

     legend = alt.Chart(df).mark_point().encode(
-        y=alt.Y('
+        y=alt.Y('cluster:O', axis=alt.Axis(orient='right'), title=""),
         x=alt.X("label"),
         shape=alt.Shape('label', scale=alt.Scale(
-
-        color=color
+        range=['circle', 'diamond']), legend=None),
+        color=color,
     ).add_selection(
         selection
     )
-
-    layered = 
+
+    layered = scatter |legend 

     layered = layered.configure_axis(
         grid=False
     ).configure_view(
         strokeOpacity=0
-    ).configure_legend(
-        strokeColor='gray',
-        fillColor='#EEEEEE',
-        padding=10,
-        cornerRadius=10,
-        orient='top-right'
-
     )

     return layered
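The chart changes above follow Altair's selection-plus-condition pattern: a multi-selection bound to the legend chart drives opacity, while a datum predicate keeps cluster colours only for the high-loss slice, and the scatter and legend are concatenated side by side. A minimal sketch of the same pattern on toy data (Altair 4-style API; the data values and column contents below are illustrative, not taken from the repo):

import altair as alt
import pandas as pd

# Toy stand-in for the UMAP scatter data: x/y coordinates, cluster ids, slice tags.
df = pd.DataFrame({
    'x': [0.1, 0.5, 0.9, 0.3], 'y': [0.2, 0.8, 0.4, 0.6],
    'cluster': [0, 1, 2, 0], 'label': ['pos', 'neg', 'pos', 'neg'],
    'slice': ['high-loss', 'high-loss', 'low-loss', 'low-loss']})

selection = alt.selection_multi(fields=['cluster', 'label'])          # toggled by clicks on the legend chart
color = alt.condition(alt.datum.slice == 'high-loss',                 # colour only the high-loss points
                      alt.Color('cluster:N'), alt.value('lightgray'))
opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))   # fade out unselected points

scatter = alt.Chart(df).mark_point().encode(
    x='x', y='y', color=color, opacity=opacity,
    tooltip=['cluster', 'slice', 'label']).interactive()

legend = alt.Chart(df).mark_point().encode(
    y=alt.Y('cluster:O', title=''), x='label', color=color).add_selection(selection)

chart = (scatter | legend).configure_axis(grid=False)                 # scatter and clickable legend side by side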
@@ -166,7 +164,36 @@ def get_data(spotlight, emb):
     return pd.concat([pd.DataFrame(np.transpose(np.vstack([dataset[:num_examples]['content'], 
                    dataset[:num_examples]['label'], preds, losses])), columns=['content', 'label', 'pred', 'loss']), embeddings], axis=1)

+@st.cache(ttl=600)
+def clustering(data,num_clusters):
+
+    X = np.array(data['embedding'].tolist())
+
+    kclusterer = KMeansClusterer(
+        num_clusters, distance=nltk.cluster.util.cosine_distance,
+        repeats=25,avoid_empty_clusters=True)
+
+    assigned_clusters = kclusterer.cluster(X, assign_clusters=True)
+
+    data['cluster'] = pd.Series(assigned_clusters, index=data.index).astype('int')
+    data['centroid'] = data['cluster'].apply(lambda x: kclusterer.means()[x])
+
+
+    return data, assigned_clusters
+
+def kmeans(df, num_clusters=3):
+    data_hl = df.loc[df['slice'] == 'high-loss']
+    data_kmeans,clusters = clustering(data_hl,num_clusters)
+    merged = pd.merge(df, data_kmeans, left_index=True, right_index=True, how='outer', suffixes=('', '_y'))
+    merged.drop(merged.filter(regex='_y$').columns.tolist(),axis=1,inplace=True)
+    merged['cluster'] = merged['cluster'].fillna(num_clusters).astype('int')
+    return merged

+@st.cache(ttl=600)
+def distance_from_centroid(row):
+    return sdist.norm(row['embedding'] - row['centroid'].tolist())
+
+@st.cache(ttl=600)
 def topic_distribution(weights, smoothing=0.01):
     topic_frequencies = defaultdict(float)
     topic_frequencies_spotlight = defaultdict(float)
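The new clustering helper above runs k-means with cosine distance over the embedding vectors via NLTK, then records each point's cluster id and centroid. A minimal standalone sketch of that call pattern on random vectors (the array shape, repeats value, and the plain L2 distance at the end are illustrative, not from the commit):

import numpy as np
import nltk
from nltk.cluster import KMeansClusterer

X = np.random.RandomState(0).rand(20, 8)                  # stand-in for the sentence embeddings

kclusterer = KMeansClusterer(
    3,                                                     # number of clusters
    distance=nltk.cluster.util.cosine_distance,            # cosine rather than Euclidean k-means
    repeats=5, avoid_empty_clusters=True)

assigned = kclusterer.cluster(X, assign_clusters=True)     # one cluster id per row of X
centroids = kclusterer.means()                             # one mean vector per cluster

# Distance of each point from its assigned centroid (plain L2 norm here).
dists = [np.linalg.norm(x - centroids[c]) for x, c in zip(X, assigned)]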
@@ -196,15 +223,10 @@ def topic_distribution(weights, smoothing=0.01):

 if __name__ == "__main__":
     ### STREAMLIT APP CONGFIG ###
-    os.system("pip --ignore-installed streamlit ")
     st.set_page_config(layout="wide", page_title="Error Slice Analysis")

-
-
-    lcol, rcol = st.columns([2, 3])
+    lcol, rcol = st.columns([2, 2])
     # ******* loading the mode and the data
-    with st.sidebar:
-        st.title('Error Analysis')
     dataset = st.sidebar.selectbox(
         "Dataset",
         ["amazon_polarity", "squad", "movielens", "waterbirds"],
@@ -221,15 +243,19 @@ if __name__ == "__main__":
         index=0
     )

-    loss_quantile = st.sidebar.
-        "Loss Quantile",
-        [0.98, 0.95, 0.9, 0.8, 0.75],
-        index = 1
+    loss_quantile = st.sidebar.slider(
+        "Loss Quantile", min_value=0.0, max_value=1.0,step=0.1,value=0.95
     )
+
+    run_kmeans = st.sidebar.radio("Cluster error slice?", ('True', 'False'), index=0)
+
+    num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)
+
     ### LOAD DATA AND SESSION VARIABLES ###
-
-
-
+    data = pd.read_parquet('./assets/data/amazon_polarity.test.parquet')
+    embedding_umap = data[['x','y']]
+    emb_df = pd.read_parquet('./assets/data/amazon_test_emb.parquet')
+    data_df = pd.DataFrame([data['content'], data['label'], data['pred'], data['loss'], emb_df['embedding'], data['x'], data['y']]).transpose()
     if "user_data" not in st.session_state:
         st.session_state["user_data"] = data_df
     if "selected_slice" not in st.session_state:
@@ -237,26 +263,30 @@ if __name__ == "__main__":
     if "embedding" not in st.session_state:
         st.session_state["embedding"] = embedding_umap

+    data_df['loss'] = data_df['loss'].astype(float)
+    losses = data_df['loss']
+    high_loss = losses.quantile(loss_quantile)
+    data_df['slice'] = 'high-loss'
+    data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
+
+    if run_kmeans == 'True':
+        merged = kmeans(data_df,num_clusters=num_clusters)
     with lcol:
         st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
-        dataframe = 
+        dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
             by=['loss'], ascending=False)
         table_html = dataframe.to_html(
-            columns=['content', 'label', 'pred', 'loss'], max_rows=
+            columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
         # table_html = table_html.replace("<th>", '<th align="left">')  # left-align the headers
         st.write(dataframe)
-        st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
-        commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
-        st.write(commontokens)
     # st_aggrid.AgGrid(dataframe)
     # table_html = dataframe.to_html(columns=['content', 'label', 'pred', 'loss'], max_rows=100)
     # table_html = table_html.replace("<th>", '<th align="left">')  # left-align the headers
     # st.write(table_html)

-    with rcol:
-
-
-
-
-
-        quant_panel(data_df)
+    with rcol:
+        st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
+        commontokens = frequent_tokens(merged, tokenizer, loss_quantile=loss_quantile)
+        st.write(commontokens)
+
+    quant_panel(merged)
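Putting the new main-block pieces together: rows whose loss sits above the chosen quantile are tagged 'high-loss', and only that slice is passed to the k-means step. A minimal sketch of the quantile slicing on synthetic losses (column and variable names mirror the commit, the data itself is made up):

import numpy as np
import pandas as pd

df = pd.DataFrame({'loss': np.random.RandomState(0).exponential(size=100)})

loss_quantile = 0.95
high_loss = df['loss'].quantile(loss_quantile)

df['slice'] = 'high-loss'
df['slice'] = df['slice'].where(df['loss'] > high_loss, 'low-loss')   # keep 'high-loss' only above the threshold

print(df['slice'].value_counts())   # roughly 5 high-loss rows vs 95 low-loss rows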
