import pandas as pd |
import numpy as np |
from sklearn.manifold import TSNE |
import json |
import base64 |
def generate_tsne_embedding(input_file, output_file):
    """Project per-row embeddings to 2-D with t-SNE and write them as JSON.

    Reads a parquet file, runs t-SNE on its ``embedding`` column and writes a
    JSON list holding one ``{"x", "y", "image"}`` object per input row, in row
    order. ``image`` is the base64 text of that row's ``image`` value.

    Args:
        input_file: Path to a parquet file with ``embedding`` and ``image``
            columns. Every ``embedding`` entry must have the same length; every
            ``image`` entry must be bytes-like (it is fed to
            ``base64.b64encode``) — presumably encoded image file bytes,
            confirm against whatever produces the parquet.
        output_file: Path of the JSON file to create or overwrite.
    """
    df = pd.read_parquet(input_file)
    embeddings = np.array(df['embedding'].tolist())

    # TSNE requires perplexity < n_samples; its default of 30 would make small
    # datasets fail outright, so cap it (no effect for more than 30 rows).
    perplexity = min(30.0, len(embeddings) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    tsne_results = tsne.fit_transform(embeddings)

    # Pair rows positionally. Indexing ``df['image'][i]`` would be a *label*
    # lookup and picks the wrong row (or raises KeyError) whenever the parquet
    # carries a non-default index.
    output_data = [
        {
            'x': float(x),
            'y': float(y),
            'image': base64.b64encode(image).decode('utf-8'),
        }
        for (x, y), image in zip(tsne_results, df['image'])
    ]

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f)
import os |
import base64 |
import json |
import numpy as np |
from dash import dcc, html, Input, Output, no_update, Dash |
import numpy as np |
from sklearn.cluster import KMeans |
from scipy.spatial.distance import cdist |
import plotly.graph_objects as go |
from PIL import Image |
import random |
import socket |
def find_free_port():
    """Return a TCP port number that is currently unused on this host.

    A throwaway socket is bound to port 0, which makes the operating system
    pick an available port; the socket is then closed and the number returned.
    Letting the OS choose avoids a probe-and-retry loop that could never
    terminate, and any real bind failure surfaces as ``OSError``.

    Note: the port is only known to be free at the moment of the check; another
    process may claim it before the caller binds it (inherent to this approach).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Port 0 = "any free port"; getsockname() reports which one we got.
        s.bind(('', 0))
        return s.getsockname()[1]
def create_dash_app(fig, images):
    """Build a Dash app that shows *fig* and an image tooltip on hover.

    Args:
        fig: Figure rendered by the ``graph`` component.
        images: Base64 strings indexed by the hovered point's ``pointNumber``;
            each is embedded in the tooltip as a JPEG data URI.

    Returns:
        The configured ``Dash`` app (not yet running).
    """
    app = Dash(__name__)
    graph = dcc.Graph(id="graph", figure=fig, clear_on_unhover=True)
    tooltip = dcc.Tooltip(id="graph-tooltip", direction='bottom')
    app.layout = html.Div(className="container", children=[graph, tooltip])

    @app.callback(
        Output("graph-tooltip", "show"),
        Output("graph-tooltip", "bbox"),
        Output("graph-tooltip", "children"),
        Input("graph", "hoverData"),
    )
    def display_hover(hoverData):
        # No hover (or the pointer left the graph): hide the tooltip and leave
        # its position/content untouched.
        if hoverData is None:
            return False, no_update, no_update

        point = hoverData["points"][0]
        bbox = point["bbox"]
        encoded = images[point["pointNumber"]]

        thumbnail = html.Img(
            src=f"data:image/jpeg;base64,{encoded}",
            style={
                "width": "200px",
                "height": "200px",
                'display': 'block',
                'margin': '0 auto',
            },
        )
        return True, bbox, [html.Div([thumbnail])]

    return app
def perform_kmeans(data, k=20):
    """Cluster the 2-D points in *data* with k-means.

    Args:
        data: Sequence of dicts with numeric ``'x'`` and ``'y'`` keys.
        k: Requested number of clusters. It is capped at the number of points,
            because scikit-learn's ``KMeans`` refuses more clusters than
            samples.

    Returns:
        The fitted ``KMeans`` estimator (exposes ``labels_`` and
        ``cluster_centers_``).
    """
    coords = np.array([[point['x'], point['y']] for point in data])
    # Cap the cluster count so datasets smaller than k still cluster
    # instead of raising ValueError.
    n_clusters = min(k, len(coords))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(coords)
    return kmeans
def find_nearest_images(data, kmeans):
    """Pick, for every cluster centre, the image of the closest data point.

    Args:
        data: Sequence of dicts with ``'x'``, ``'y'`` and ``'image'`` keys.
        kmeans: Fitted estimator exposing ``cluster_centers_`` (one 2-D row
            per centre).

    Returns:
        ``(nearest_images, cluster_centers)``: one image per centre in centre
        order, together with the centres array itself.
    """
    points = np.array([(p['x'], p['y']) for p in data])
    thumbnails = [p['image'] for p in data]
    centers = kmeans.cluster_centers_

    # dist[i, j] is the distance from point i to centre j; minimising down each
    # column finds the point closest to that centre.
    dist = cdist(points, centers, metric='euclidean')
    closest = np.argmin(dist, axis=0)

    return [thumbnails[idx] for idx in closest], centers
def create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title):
    """Build a square scatter plot of *data* with one thumbnail per cluster.

    Args:
        data: Sequence of dicts with ``'x'``, ``'y'`` and base64 ``'image'``
            keys.
        kmeans_result: Fitted estimator whose ``labels_`` colour the points.
        nearest_images: Base64 image per cluster centre, parallel to
            *cluster_centers* (indexed by centre position).
        cluster_centers: ``(cx, cy)`` pairs in data coordinates.
        title: Figure title.

    Returns:
        ``(fig, images)``. ``images`` lists every point's base64 image in the
        same order as the scatter trace, so a hover's ``pointNumber`` indexes
        into it.
    """
    x = [point['x'] for point in data]
    y = [point['y'] for point in data]
    images = [point['image'] for point in data]

    # Use the larger of the two spans as a shared half-width so both axes cover
    # the same distance; together with scaleanchor this keeps the plot square.
    max_range = max(max(x) - min(x), max(y) - min(y)) / 2
    center_x = (max(x) + min(x)) / 2
    center_y = (max(y) + min(y)) / 2

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        mode='markers',
        marker=dict(
            size=5,
            color=kmeans_result.labels_,
            colorscale='Viridis',
            showscale=False,
        ),
        name='Data Points',
    ))
    fig.update_layout(
        title=title,
        width=1000, height=1000,
        xaxis=dict(
            range=[center_x - max_range, center_x + max_range],
            scaleanchor="y",
            scaleratio=1,
        ),
        yaxis=dict(
            range=[center_y - max_range, center_y + max_range],
        ),
        showlegend=False,
    )
    # Turn off Plotly's built-in hover label so a custom tooltip driven by
    # hoverData can be shown instead.
    fig.update_traces(
        hoverinfo="none",
        hovertemplate=None,
    )

    for i, (cx, cy) in enumerate(cluster_centers):
        fig.add_layout_image(
            dict(
                # image/jpeg is the registered MIME type ("image/jpg" is not)
                # and matches the hover tooltip's data URI.
                source=f"data:image/jpeg;base64,{nearest_images[i]}",
                x=cx,
                y=cy,
                xref="x",
                yref="y",
                # Centre the thumbnail on the cluster centre; Plotly's default
                # left/top anchor would hang it below-right of the centre.
                xanchor="center",
                yanchor="middle",
                # Size is in data units because xref/yref are the axes.
                # NOTE(review): a fixed 10 units may look off if the
                # coordinate range differs a lot from typical t-SNE output.
                sizex=10,
                sizey=10,
                sizing="contain",
                opacity=1,
                layer="below",
            )
        )

    # Hide both axes; the ranges set above still apply.
    fig.update_layout(xaxis=dict(visible=False), yaxis=dict(visible=False))
    return fig, images
def make_dash_kmeans(data, title, k=40):
    """Cluster *data*, build the thumbnail scatter figure and serve it with Dash.

    Args:
        data: Sequence of dicts with ``'x'``, ``'y'`` and base64 ``'image'``
            keys.
        title: Figure title.
        k: Number of k-means clusters, i.e. thumbnails drawn on the plot.

    Returns:
        The Dash app. The server call blocks, so this only returns after the
        server stops.
    """
    kmeans_result = perform_kmeans(data, k=k)
    nearest_images, cluster_centers = find_nearest_images(data, kmeans_result)
    fig, images = create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title)
    app = create_dash_app(fig, images)

    port = find_free_port()
    # Dash's development server listens on 127.0.0.1 by default.
    print(f"Serving on http://127.0.0.1:{port}/")
    print(f"To serve this over the Internet, run `ngrok http {port}`")

    # ``run_server`` is deprecated in Dash 2.x and removed in Dash 3; use
    # ``run`` and fall back only on old versions that lack it.
    run = getattr(app, "run", None) or app.run_server
    run(port=port)
    return app
if __name__ == "__main__":
    # End-to-end run against files in the current directory:
    # embeddings parquet -> t-SNE JSON -> clustered Dash viewer.
    base_dir = os.path.dirname('./')  # evaluates to '.'
    parquet_path = os.path.join(base_dir, "processed_dataset.parquet")
    json_path = os.path.join(base_dir, "processed_dataset.json")
    plot_title = "style"

    generate_tsne_embedding(parquet_path, json_path)

    with open(json_path) as fh:
        points = json.load(fh)

    make_dash_kmeans(points, plot_title, k=40)