|
import os |
|
|
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' |
|
import pickle |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import plotly.express as px |
|
import streamlit as st |
|
import tensorflow as tf |
|
import tensorflow_hub as hub |
|
|
|
from sklearn.cluster import DBSCAN |
|
|
|
|
|
def read_stops(p: str):
    """
    Load the metro stops table from a .csv file.

    :param p: The path to the .csv file of metro stops

    :returns: A Pandas DataFrame with one row per line of the .csv file
    """
    stops = pd.read_csv(p)
    return stops
|
|
|
|
|
def read_encodings(p: str) -> tf.Tensor:
    """
    Load the pickled Universal Sentence Encoder v4 encodings from disk.

    WARNING: the file is deserialised with `pickle` and nothing here guards
    against its security holes, so only pass paths to files you trust.

    :param p: Path to the encodings

    :returns: A Tensor of the encodings with shape (number of sentences, 512)
    """
    with open(p, 'rb') as pickled:
        return pickle.load(pickled)
|
|
|
|
|
def cluster_encodings(encodings: tf.Tensor, *,
                      eps: float = 0.7,
                      min_samples: int = 100) -> np.ndarray:
    """
    Cluster the sentence encodings using DBSCAN.

    :param encodings: A Tensor of sentence encodings with shape
                      (number of sentences, 512)
    :param eps: DBSCAN neighbourhood radius, passed straight through to
                `sklearn.cluster.DBSCAN`. Defaults to 0.7.
    :param min_samples: DBSCAN minimum neighbourhood size, passed straight
                        through to `sklearn.cluster.DBSCAN`. Defaults to 100.

    :returns: a NumPy array of the cluster labels, one per encoding
    """
    # The defaults reproduce the values that used to be hard-coded here, so
    # existing callers get exactly the same clustering as before.
    clusterer = DBSCAN(eps=eps, min_samples=min_samples)
    clusterer.fit(encodings)
    return clusterer.labels_
|
|
|
|
|
def cluster_lat_lon(df: pd.DataFrame, *,
                    eps: float = 0.025,
                    min_samples: int = 100) -> np.ndarray:
    """
    Cluster the metro stops by their latitude and longitude using DBSCAN.

    :param df: A Pandas DataFrame of stops that has `latitude` and `longitude` columns
    :param eps: DBSCAN neighbourhood radius, passed straight through to
                `sklearn.cluster.DBSCAN`. It is applied to the raw column
                values (presumably decimal degrees -- confirm against the
                dataset). Defaults to 0.025.
    :param min_samples: DBSCAN minimum neighbourhood size, passed straight
                        through to `sklearn.cluster.DBSCAN`. Defaults to 100.

    :returns: a NumPy array of the cluster labels, one per row of `df`
    """
    # Only the two coordinate columns take part in the clustering.
    coordinates = df[['latitude', 'longitude']]

    # The defaults reproduce the values that used to be hard-coded here, so
    # existing callers get exactly the same clustering as before.
    clusterer = DBSCAN(eps=eps, min_samples=min_samples)
    clusterer.fit(coordinates)
    return clusterer.labels_
|
|
|
|
|
def plot_example(df: pd.DataFrame, labels: np.ndarray):
    """
    Build the map of stops coloured by their geographic cluster.

    :param df: A Pandas DataFrame of stops that has `latitude`, `longitude`
               and `display_name` columns
    :param labels: a NumPy array of the cluster labels

    :returns: the Plotly figure produced by `px.scatter_mapbox`
    """
    # The Mapbox token is read from the Streamlit secrets store.
    px.set_mapbox_access_token(st.secrets['mapbox_token'])

    # Labels are cast to strings before being used as the colour; this goes
    # together with the discrete colour sequence chosen below.
    scatter_options = dict(
        lon='longitude',
        lat='latitude',
        hover_name='display_name',
        color=labels.astype('str'),
        zoom=8,
        color_discrete_sequence=px.colors.qualitative.Dark24,
    )
    return px.scatter_mapbox(df, **scatter_options)
|
|
|
|
|
def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray):
    """
    Build the map of stops coloured by their name-based cluster,
    centred on Venice Blvd.

    :param df: A Pandas DataFrame of stops that has `latitude`, `longitude`
               and `display_name` columns
    :param labels: a NumPy array of the cluster labels

    :returns: the Plotly figure produced by `px.scatter_mapbox`
    """
    # The Mapbox token is read from the Streamlit secrets store.
    px.set_mapbox_access_token(st.secrets['mapbox_token'])

    # Map centre: a fixed point used as the Venice Blvd location.
    map_center = dict(lat=34.008350, lon=-118.425362)

    # Labels are cast to strings before being used as the colour; this goes
    # together with the discrete colour sequence chosen below.
    cluster_names = labels.astype('str')

    figure = px.scatter_mapbox(
        df,
        lat='latitude',
        lon='longitude',
        color=cluster_names,
        hover_name='display_name',
        center=map_center,
        zoom=12,
        color_discrete_sequence=px.colors.qualitative.Dark24,
    )
    return figure
|
|
|
|
|
def main(data_path: str, enc_path: str):
    """
    Build both cluster maps and render the Streamlit page.

    All data loading, clustering and plotting happens before anything is
    written to the page; the two sections are then emitted in order.

    :param data_path: Path to the .csv file of metro stops
    :param enc_path: Path to the pickled sentence encodings
    """
    position_text = """First, I clustered the
stops by their geographic location.
The DBSCAN algorithm finds three clusters.
Points labeled `-1` aren't part of any cluster.
Clicking on `-1` in the legend will turn off those points."""

    name_text = """I encoded the names of all the stops using the Universal Sentence Encoder v4.
I then clustered those encodings so that I could group the stops based on their names
instead of their geographic position.
As I expected, stops on the same road end up close enough to each other that DBSCAN can cluster them together.


Sometimes, however, a stop has a name that means something to the encoder.
When that happens, the encoding ends up too far away from the rest of the stops on that road.
For example, the stops on Venice Blvd get clustered together,
but the stop "Venice / Lincoln" ends up somewhere else.


I assume it ends up somewhere else because the encoder recognizes "Lincoln"
and that meaning overpowers the "Venice" meaning enough that the encoding
is too far away from the rest of the "Venice" stops.
A few other examples on Venice Blvd are "Saint Andrews," "Harvard," and "Beethoven."
There are also a few that I don't ascribe much meaning to, such as "Girard" and "Robertson."


There's a lot more to dig into here but I'll leave it there for now.
My mind first jumps to adversarial prompts that use famous names to move the encoding
around in the encoding space.
"""

    stops = read_stops(data_path)

    # Geographic clustering first, then the name-based clustering; both
    # figures are built before the page is touched.
    position_fig = plot_example(stops, cluster_lat_lon(stops))
    name_fig = plot_venice_blvd(
        stops, cluster_encodings(read_encodings(enc_path)))

    sections = [
        ('# Cluster the stops by their position', position_text, position_fig),
        ('# Cluster the stops by their name', name_text, name_fig),
    ]
    for heading, body, figure in sections:
        st.write(heading)
        st.write(body)
        st.plotly_chart(figure, use_container_width=True)
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the two optional paths and hand them
    # to main() as keyword arguments.
    import argparse

    default_data_path = 'data/stops.csv'
    default_enc_path = 'data/encodings.pkl'

    p = argparse.ArgumentParser()
    # nargs='?' allows a flag to be given without a value. argparse then
    # stores `const`, which is None unless set, and main() would receive a
    # None path. Setting `const` to the default keeps a bare flag working.
    p.add_argument('--data_path',
                   nargs='?',
                   const=default_data_path,
                   default=default_data_path,
                   help="Path to the dataset of LA Metro stops. Defaults to 'data/stops.csv'")
    p.add_argument('--enc_path',
                   nargs='?',
                   const=default_enc_path,
                   default=default_enc_path,
                   help="Path to the pickled encodings. Defaults to 'data/encodings.pkl'")
    args = p.parse_args()

    # The argument names match main()'s parameter names, so the namespace
    # can be unpacked directly.
    main(**vars(args))
|
|
|
|