Spaces:
Runtime error
Runtime error
David Wisdom
commited on
Commit
·
d256b25
1
Parent(s):
c05b194
plot the example stops on a map as well
Browse files
app.py
CHANGED
@@ -15,7 +15,9 @@ from sklearn.cluster import DBSCAN
|
|
15 |
|
16 |
def read_stops(p: str):
|
17 |
"""
|
18 |
-
|
|
|
|
|
19 |
"""
|
20 |
return pd.read_csv(p)
|
21 |
|
@@ -38,7 +40,12 @@ def read_encodings(p: str) -> tf.Tensor:
|
|
38 |
|
39 |
def cluster_encodings(encodings: tf.Tensor) -> np.ndarray:
|
40 |
"""
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
42 |
"""
|
43 |
# I know the hyperparams I want from the EDA I did in the notebook
|
44 |
clusterer = DBSCAN(eps=0.7, min_samples=100).fit(encodings)
|
@@ -47,7 +54,11 @@ def cluster_encodings(encodings: tf.Tensor) -> np.ndarray:
|
|
47 |
|
48 |
def cluster_lat_lon(df: pd.DataFrame) -> np.ndarray:
|
49 |
"""
|
50 |
-
|
|
|
|
|
|
|
|
|
51 |
"""
|
52 |
# I know the hyperparams I want from the EDA I did in the notebook
|
53 |
clusterer = DBSCAN(eps=0.025, min_samples=100).fit(df[['latitude', 'longitude']])
|
@@ -56,26 +67,28 @@ def cluster_lat_lon(df: pd.DataFrame) -> np.ndarray:
|
|
56 |
|
57 |
def plot_example(df: pd.DataFrame, labels: np.ndarray):
|
58 |
"""
|
59 |
-
|
|
|
|
|
|
|
60 |
"""
|
61 |
-
|
62 |
labels = labels.astype('str')
|
63 |
|
64 |
-
fig = px.
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
template='presentation',
|
70 |
-
width=plot_size,
|
71 |
-
height=plot_size)
|
72 |
-
# fig.show()
|
73 |
return fig
|
74 |
|
75 |
|
76 |
def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray):
|
77 |
"""
|
78 |
-
|
|
|
|
|
|
|
79 |
"""
|
80 |
px.set_mapbox_access_token(st.secrets['mapbox_token'])
|
81 |
venice_blvd = {'lat': 34.008350,
|
@@ -107,9 +120,31 @@ def main(data_path: str, enc_path: str):
|
|
107 |
|
108 |
# Display the plots with Streamlit
|
109 |
st.write('# Example of what DBSCAN does')
|
|
|
|
|
|
|
|
|
|
|
110 |
st.plotly_chart(example_fig, use_container_width=True)
|
111 |
|
112 |
st.write('# Venice Blvd')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
st.plotly_chart(venice_fig, use_container_width=True)
|
114 |
|
115 |
|
|
|
15 |
|
16 |
def read_stops(p: str):
|
17 |
"""
|
18 |
+
Read in the .csv file of metro stops
|
19 |
+
|
20 |
+
:param p: The path to the .csv file of metro stops
|
21 |
"""
|
22 |
return pd.read_csv(p)
|
23 |
|
|
|
40 |
|
41 |
def cluster_encodings(encodings: tf.Tensor) -> np.ndarray:
|
42 |
"""
|
43 |
+
Cluster the sentence encodings using DBSCAN.
|
44 |
+
|
45 |
+
:param encodings: A Tensor of sentence encodings with shape
|
46 |
+
(number of sentences, 512)
|
47 |
+
|
48 |
+
:returns: a NumPy array of the cluster labels
|
49 |
"""
|
50 |
# I know the hyperparams I want from the EDA I did in the notebook
|
51 |
clusterer = DBSCAN(eps=0.7, min_samples=100).fit(encodings)
|
|
|
54 |
|
55 |
def cluster_lat_lon(df: pd.DataFrame) -> np.ndarray:
|
56 |
"""
|
57 |
+
Cluster the metro stops by their latitude and longitude using DBSCAN.
|
58 |
+
|
59 |
+
:param df: A Pandas DataFrame of stops that has 'latitude` and 'longitude` columns
|
60 |
+
|
61 |
+
:returns: a NumPy array of the cluster labels
|
62 |
"""
|
63 |
# I know the hyperparams I want from the EDA I did in the notebook
|
64 |
clusterer = DBSCAN(eps=0.025, min_samples=100).fit(df[['latitude', 'longitude']])
|
|
|
67 |
|
68 |
def plot_example(df: pd.DataFrame, labels: np.ndarray):
|
69 |
"""
|
70 |
+
Plot the geographic clustering
|
71 |
+
|
72 |
+
:param df: A Pandas DataFrame of stops that has 'latitude` and 'longitude` columns
|
73 |
+
:param labels: a NumPy array of the cluster labels
|
74 |
"""
|
75 |
+
px.set_mapbox_access_token(st.secrets['mapbox_token'])
|
76 |
labels = labels.astype('str')
|
77 |
|
78 |
+
fig = px.scatter_mapbox(df, x='longitude', y='latitude',
|
79 |
+
hover_name='display_name',
|
80 |
+
color=labels,
|
81 |
+
zoom=10,
|
82 |
+
color_discrete_sequence=px.colors.qualitative.Safe,
|
|
|
|
|
|
|
|
|
83 |
return fig
|
84 |
|
85 |
|
86 |
def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray):
|
87 |
"""
|
88 |
+
Plot the metro stops and color them based on their names
|
89 |
+
|
90 |
+
:param df: A Pandas DataFrame of stops that has 'latitude` and 'longitude` columns
|
91 |
+
:param labels: a NumPy array of the cluster labels
|
92 |
"""
|
93 |
px.set_mapbox_access_token(st.secrets['mapbox_token'])
|
94 |
venice_blvd = {'lat': 34.008350,
|
|
|
120 |
|
121 |
# Display the plots with Streamlit
|
122 |
st.write('# Example of what DBSCAN does')
|
123 |
+
st.write("""As an example of a typical DBSCAN result, I've clustered the
|
124 |
+
stops by their geographic location.
|
125 |
+
The algorithm finds three clusters.
|
126 |
+
Points labeled `-1` aren't part of any cluster.
|
127 |
+
Clicking on `-1` in the legend will turn off those points."""
|
128 |
st.plotly_chart(example_fig, use_container_width=True)
|
129 |
|
130 |
st.write('# Venice Blvd')
|
131 |
+
st.write("""I encoded the names of all the stops using the Universal Sentence Encoder v4.
|
132 |
+
I then clustered those encodings so that I could group the stops based on their names
|
133 |
+
instead of their geographic position.
|
134 |
+
As I expected, stops on the same road end up close enough to each other that DBSCAN can cluster them together.
|
135 |
+
Sometimes, however, a stop has a name that means something to the encoder.
|
136 |
+
When that happens, the encoding ends up too far away from the rest of the stops on that road.
|
137 |
+
For example, the stops on Venice Blvd get clustered together,
|
138 |
+
but the stop `Venice / Lincoln` ends up somewhere else.
|
139 |
+
I assume it ends up somewhere else because the encoder recognizes "Lincoln"
|
140 |
+
and that meaning overpowers the "Venice" meaning enough that the encoding
|
141 |
+
is too far away from the rest of the "Venice" stops.
|
142 |
+
A few other examples on Venice Blvd are "Saint Andrews," "Harvard," and "Beethoven."
|
143 |
+
There are a few that I don't ascribe much meaning to, such as "Girard" and "Jasmine."
|
144 |
+
My mind first jumps to adversarial prompts that use famous names to move the encoding
|
145 |
+
around in the encoding space.
|
146 |
+
There's a lot more to dig into here but I'll leave it there for now.
|
147 |
+
"""
|
148 |
st.plotly_chart(venice_fig, use_container_width=True)
|
149 |
|
150 |
|