# CSD / examples / tsne_visualization.py
# Author: yuxi-liu-wired
# Commit: "example usage" (1b1d8c3)
import pandas as pd
import numpy as np
from sklearn.manifold import TSNE
import json
import base64
def generate_tsne_embedding(input_file, output_file):
    """Project image embeddings to 2-D with t-SNE and save them as JSON.

    Reads a Parquet file with an ``embedding`` column and an ``image``
    column, runs t-SNE (2 components, ``random_state=42``) on the
    embeddings, and writes a JSON list with one record per row:
    ``{"x": float, "y": float, "image": <base64 text of the image value>}``.

    Args:
        input_file: Path of the Parquet file to read.
        output_file: Path of the JSON file to write (overwritten).
    """
    # Load the Parquet file
    df = pd.read_parquet(input_file)

    # Stack the per-row embeddings into a single array for t-SNE.
    # assumes every embedding has the same length — TODO confirm
    embeddings = np.array(df['embedding'].tolist())

    # Perform t-SNE
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(embeddings)

    # Pair each coordinate with its image *by position*. The previous
    # df['image'][i] was a label lookup, which selects the wrong row (or
    # raises KeyError) whenever the frame's index is not exactly 0..n-1.
    # assumes each image value is a bytes-like object — TODO confirm
    output_data = [
        {
            'x': float(x),
            'y': float(y),
            'image': base64.b64encode(image_bytes).decode('utf-8'),
        }
        for (x, y), image_bytes in zip(tsne_results, df['image'])
    ]

    # Save results to JSON file
    with open(output_file, 'w') as f:
        json.dump(output_data, f)
## ----------------------------
## Dash app
## ----------------------------
import os
import base64
import json
import numpy as np
from dash import dcc, html, Input, Output, no_update, Dash
import numpy as np
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
import plotly.graph_objects as go
from PIL import Image
import random
import socket
def find_free_port(max_attempts=100):
    """Return a TCP port in the dynamic/private range that can be bound now.

    Random ports in 49152-65535 are probed by binding a temporary socket on
    all interfaces; the first port that binds successfully is returned. The
    probe socket is closed before returning, so another process could still
    take the port before the caller starts listening on it.

    Args:
        max_attempts: Maximum number of random ports to probe.

    Returns:
        The free port number as an int.

    Raises:
        RuntimeError: If no port could be bound within ``max_attempts``
            tries; chained from the last ``OSError`` when there was one.
    """
    last_error = None
    for _ in range(max_attempts):
        port = random.randint(49152, 65535)  # Use dynamic/private port range
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('', port))
            except OSError as exc:
                # Typically "address already in use": remember it and retry
                # with another random port.
                last_error = exc
            else:
                return port
    # An unbounded retry loop would spin forever if binding keeps failing
    # for a reason unrelated to the chosen port, so give up explicitly.
    raise RuntimeError(
        f"Could not find a free port after {max_attempts} attempts"
    ) from last_error
def create_dash_app(fig, images):
    """Build a Dash app that shows ``fig`` with an image tooltip on hover.

    Args:
        fig: Figure passed to the ``dcc.Graph`` component.
        images: Sequence of base64-encoded image strings, indexed by the
            ``pointNumber`` of the hovered point.

    Returns:
        The configured ``Dash`` application (not yet running).
    """
    dash_app = Dash(__name__)

    graph = dcc.Graph(id="graph", figure=fig, clear_on_unhover=True)
    tooltip = dcc.Tooltip(id="graph-tooltip", direction='bottom')
    dash_app.layout = html.Div(
        className="container",
        children=[graph, tooltip],
    )

    @dash_app.callback(
        Output("graph-tooltip", "show"),
        Output("graph-tooltip", "bbox"),
        Output("graph-tooltip", "children"),
        Input("graph", "hoverData"),
    )
    def display_hover(hoverData):
        # Nothing hovered: hide the tooltip and leave its other props alone.
        if hoverData is None:
            return False, no_update, no_update

        # Only the first hovered point is used.
        point = hoverData["points"][0]
        box = point["bbox"]
        encoded = images[point["pointNumber"]]

        thumbnail = html.Img(
            src=f"data:image/jpeg;base64,{encoded}",
            style={
                "width": "200px",
                "height": "200px",
                'display': 'block',
                'margin': '0 auto',
            },
        )
        return True, box, [html.Div([thumbnail])]

    return dash_app
def perform_kmeans(data, k=20):
    """Cluster the 2-D points in ``data`` with k-means.

    Args:
        data: Iterable of mappings that each have ``'x'`` and ``'y'`` keys.
        k: Number of clusters.

    Returns:
        The fitted ``KMeans`` estimator (``random_state=42``).
    """
    # One (x, y) row per point.
    xy = np.array([[entry['x'], entry['y']] for entry in data])

    model = KMeans(n_clusters=k, random_state=42)
    model.fit(xy)
    return model
def find_nearest_images(data, kmeans):
    """Pick, for every cluster centre, the image of the closest data point.

    Args:
        data: Sequence of mappings with ``'x'``, ``'y'`` and ``'image'`` keys.
        kmeans: Object exposing ``cluster_centers_`` (one row per centre).

    Returns:
        A tuple ``(nearest_images, cluster_centers)`` where
        ``nearest_images[j]`` is the image of the point with the smallest
        Euclidean distance to centre ``j`` and ``cluster_centers`` is
        ``kmeans.cluster_centers_``.
    """
    points = np.array([[entry['x'], entry['y']] for entry in data])
    all_images = [entry['image'] for entry in data]

    # Rows are data points, columns are cluster centres.
    centers = kmeans.cluster_centers_
    dist_matrix = cdist(points, centers, metric='euclidean')

    # Column-wise argmin: index of the closest point for each centre.
    closest_idx = np.argmin(dist_matrix, axis=0)

    picked = [all_images[idx] for idx in closest_idx]
    return picked, centers
def create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title,
                    image_size=10):
    """Build the scatter figure with one thumbnail per cluster centre.

    Args:
        data: Non-empty sequence of mappings with ``'x'``, ``'y'`` and
            ``'image'`` keys.
        kmeans_result: Object whose ``labels_`` give one cluster label per
            point; used to colour the markers.
        nearest_images: Base64-encoded image strings, one per cluster centre.
        cluster_centers: Iterable of ``(x, y)`` cluster-centre coordinates,
            in the same order as ``nearest_images``.
        title: Figure title.
        image_size: Width and height of each thumbnail, in data units.

    Returns:
        A tuple ``(fig, images)`` where ``fig`` is the Plotly figure and
        ``images`` lists the ``'image'`` value of every point in ``data``
        order.
    """
    # Extract x, y coordinates
    x = [point['x'] for point in data]
    y = [point['y'] for point in data]
    images = [point['image'] for point in data]

    # Use one square window, centred on the data, for both axes.
    min_x, max_x = min(x), max(x)
    min_y, max_y = min(y), max(y)
    max_range = max(max_x - min_x, max_y - min_y) / 2
    center_x = (max_x + min_x) / 2
    center_y = (max_y + min_y) / 2

    # Create the scatter plot
    fig = go.Figure()

    # Add data points, coloured by cluster label
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        mode='markers',
        marker=dict(
            size=5,
            color=kmeans_result.labels_,
            colorscale='Viridis',
            showscale=False,
        ),
        name='Data Points',
    ))

    # Square, equal-aspect layout with the axes hidden
    fig.update_layout(
        title=title,
        width=1000, height=1000,
        xaxis=dict(
            range=[center_x - max_range, center_x + max_range],
            scaleanchor="y",
            scaleratio=1,
            visible=False,
        ),
        yaxis=dict(
            range=[center_y - max_range, center_y + max_range],
            visible=False,
        ),
        showlegend=False,
    )

    # Suppress Plotly's built-in hover label on the traces.
    fig.update_traces(
        hoverinfo="none",
        hovertemplate=None,
    )

    # Add one thumbnail per cluster centre
    for (cx, cy), image_base64 in zip(cluster_centers, nearest_images):
        fig.add_layout_image(
            dict(
                # "image/jpeg" is the registered MIME type (and the one the
                # hover tooltip in this file uses); "image/jpg" is not.
                source=f"data:image/jpeg;base64,{image_base64}",
                x=cx,
                y=cy,
                xref="x",
                yref="y",
                # Plotly anchors layout images at their top-left corner by
                # default; centre them so the thumbnail sits on the centre.
                xanchor="center",
                yanchor="middle",
                sizex=image_size,
                sizey=image_size,
                sizing="contain",
                opacity=1,
                layer="below",
            )
        )

    return fig, images
def make_dash_kmeans(data, title, k=40):
    """Cluster ``data``, build the figure and serve it in a Dash app.

    Runs k-means on the points, picks the image nearest to each cluster
    centre, builds the figure and the Dash app, then starts the server on a
    randomly chosen free port. The call blocks while the server is running.

    Args:
        data: Sequence of mappings with ``'x'``, ``'y'`` and ``'image'`` keys.
        title: Figure title.
        k: Number of k-means clusters.

    Returns:
        The Dash app, once the server has stopped.
    """
    kmeans_result = perform_kmeans(data, k=k)
    nearest_images, cluster_centers = find_nearest_images(data, kmeans_result)
    fig, images = create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title)
    app = create_dash_app(fig, images)

    port = find_free_port()
    print(f"Serving on http://127.0.0.1:{port}/")
    print(f"To serve this over the Internet, run `ngrok http {port}`")

    # Newer Dash releases expose app.run and have deprecated/dropped
    # app.run_server; prefer run and fall back only on old versions.
    run = getattr(app, "run", None)
    if run is None:
        run = app.run_server
    run(port=port)
    return app
if __name__ == "__main__":
    # Script entry point: compute the t-SNE layout for the processed dataset
    # in the current directory, then serve the interactive k-means viewer.
    name = "style"
    dataset_folder = os.path.dirname('./')

    image_embedding_path = os.path.join(dataset_folder, "processed_dataset.parquet")
    tsne_path = os.path.join(dataset_folder, "processed_dataset.json")

    # Write the 2-D embedding to JSON, then read it back for plotting.
    generate_tsne_embedding(image_embedding_path, tsne_path)
    with open(tsne_path, "r") as tsne_file:
        data = json.load(tsne_file)

    make_dash_kmeans(data, name, k=40)