|
import pandas as pd |
|
import numpy as np |
|
from sklearn.manifold import TSNE |
|
import json |
|
import base64 |
|
|
|
def generate_tsne_embedding(input_file, output_file):
    """Project stored embeddings to 2-D with t-SNE and save them as JSON.

    Reads ``input_file`` as a Parquet table, feeds its ``embedding`` column to
    t-SNE, and writes a JSON list to ``output_file`` with one
    ``{'x', 'y', 'image'}`` object per row. ``image`` is the base64 text of
    that row's ``image`` value.

    Args:
        input_file: Path of the Parquet file to read. It must provide
            ``embedding`` and ``image`` columns (``image`` is assumed to hold
            raw bytes, since it is passed to ``base64.b64encode`` -- confirm
            against the producer of the file).
        output_file: Path of the JSON file to create or overwrite.
    """
    df = pd.read_parquet(input_file)

    # One row per sample; each cell of the column becomes one row of the matrix.
    embeddings = np.array(df['embedding'].tolist())

    # Fixed seed so the 2-D layout does not change between runs.
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(embeddings)

    # Pair coordinates with images by position. Indexing the column as
    # ``df['image'][i]`` is a label lookup, which selects the wrong row (or
    # raises KeyError) whenever the frame's index is not exactly 0..n-1.
    output_data = [
        {
            'x': float(x),
            'y': float(y),
            'image': base64.b64encode(image_bytes).decode('utf-8'),
        }
        for (x, y), image_bytes in zip(tsne_results, df['image'])
    ]

    with open(output_file, 'w') as f:
        json.dump(output_data, f)
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import base64 |
|
import json |
|
import numpy as np |
|
from dash import dcc, html, Input, Output, no_update, Dash |
|
import numpy as np |
|
from sklearn.cluster import KMeans |
|
from scipy.spatial.distance import cdist |
|
import plotly.graph_objects as go |
|
from PIL import Image |
|
import random |
|
import socket |
|
|
|
def find_free_port():
    """Return a randomly chosen TCP port that could be bound just now.

    Candidates are drawn from 49152-65535 and probed by binding a throwaway
    socket on all interfaces; a candidate that fails to bind is discarded and
    another one is drawn. The probe socket is closed before returning, so the
    port is only known to have been free at the moment of the check.

    Returns:
        int: The port number that was bound successfully.
    """
    lowest, highest = 49152, 65535

    while True:
        candidate = random.randint(lowest, highest)
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with probe:
            try:
                probe.bind(('', candidate))
            except OSError:
                # Bind failed for this candidate -- draw another one.
                continue
            return candidate
|
|
|
def create_dash_app(fig, images):
    """Build a Dash app that shows ``fig`` and an image tooltip on hover.

    Args:
        fig: Figure displayed in the app's single graph component.
        images: Sequence of base64-encoded image strings, looked up with the
            hovered point's ``pointNumber``.

    Returns:
        The configured ``Dash`` application; the server is not started here.
    """
    app = Dash(__name__)

    graph = dcc.Graph(id="graph", figure=fig, clear_on_unhover=True)
    tooltip = dcc.Tooltip(id="graph-tooltip", direction='bottom')
    app.layout = html.Div(className="container", children=[graph, tooltip])

    @app.callback(
        Output("graph-tooltip", "show"),
        Output("graph-tooltip", "bbox"),
        Output("graph-tooltip", "children"),
        Input("graph", "hoverData"),
    )
    def display_hover(hoverData):
        # Nothing under the cursor: hide the tooltip and leave its position
        # and contents untouched.
        if hoverData is None:
            return False, no_update, no_update

        point = hoverData["points"][0]
        bbox = point["bbox"]
        image_base64 = images[point["pointNumber"]]

        thumbnail = html.Img(
            src=f"data:image/jpeg;base64,{image_base64}",
            style={
                "width": "200px",
                "height": "200px",
                'display': 'block',
                'margin': '0 auto',
            },
        )
        return True, bbox, [html.Div([thumbnail])]

    return app
|
|
|
def perform_kmeans(data, k=20):
    """Cluster 2-D points with k-means.

    Args:
        data: Sequence of mappings, each with numeric ``'x'`` and ``'y'``
            entries.
        k: Requested number of clusters. When it exceeds the number of points
            it is reduced to that number (with a warning), because k-means
            cannot be fitted with more clusters than samples.

    Returns:
        The fitted ``KMeans`` estimator.

    Raises:
        ValueError: If ``data`` contains no points.
    """
    import warnings

    coords = np.array([[point['x'], point['y']] for point in data])

    n_points = len(coords)
    if n_points == 0:
        raise ValueError("perform_kmeans requires at least one data point")

    if k > n_points:
        warnings.warn(
            f"k={k} exceeds the number of points ({n_points}); "
            f"using {n_points} clusters instead",
            stacklevel=2,
        )
        k = n_points

    # Fixed seed so cluster assignments are repeatable between runs.
    kmeans = KMeans(n_clusters=k, random_state=42)
    kmeans.fit(coords)

    return kmeans
|
|
|
def find_nearest_images(data, kmeans):
    """Pick, for every cluster centre, the image of the closest data point.

    Args:
        data: Sequence of mappings with ``'x'``, ``'y'`` and ``'image'``
            entries.
        kmeans: Object exposing ``cluster_centers_`` as an array of 2-D
            centre coordinates.

    Returns:
        A ``(nearest_images, cluster_centers)`` pair: the ``'image'`` value of
        the point nearest (Euclidean distance) to each centre, in centre
        order, and ``kmeans.cluster_centers_`` itself.
    """
    centers = kmeans.cluster_centers_
    points = np.array([[entry['x'], entry['y']] for entry in data])
    all_images = [entry['image'] for entry in data]

    # Rows are data points and columns are centres, so the column-wise argmin
    # is the index of the closest point for each centre.
    point_to_center = cdist(points, centers, metric='euclidean')
    closest_point = point_to_center.argmin(axis=0)

    nearest_images = [all_images[idx] for idx in closest_point]
    return nearest_images, centers
|
|
|
def create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title):
    """Build the scatter figure with one representative image per cluster.

    Args:
        data: Sequence of mappings with ``'x'``, ``'y'`` and ``'image'``
            entries; must not be empty.
        kmeans_result: Object whose ``labels_`` supplies the marker colour
            value of each point.
        nearest_images: Base64-encoded image strings, one per cluster centre
            and in the same order as ``cluster_centers``.
        cluster_centers: Iterable of ``(x, y)`` centre coordinates.
        title: Figure title.

    Returns:
        A ``(fig, images)`` pair: the Plotly figure and the list of every
        point's ``'image'`` value in ``data`` order.
    """
    x = [point['x'] for point in data]
    y = [point['y'] for point in data]
    images = [point['image'] for point in data]

    # Square viewport centred on the data: half the larger of the two extents
    # is used as the half-width on both axes.
    x_min, x_max = min(x), max(x)
    y_min, y_max = min(y), max(y)
    max_range = max(x_max - x_min, y_max - y_min) / 2
    center_x = (x_max + x_min) / 2
    center_y = (y_max + y_min) / 2

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        mode='markers',
        marker=dict(
            size=5,
            color=kmeans_result.labels_,
            colorscale='Viridis',
            showscale=False,
        ),
        name='Data Points',
    ))

    fig.update_layout(
        title=title,
        width=1000, height=1000,
        xaxis=dict(
            range=[center_x - max_range, center_x + max_range],
            # Lock x to y at a 1:1 ratio so distances look the same both ways.
            scaleanchor="y",
            scaleratio=1,
            visible=False,
        ),
        yaxis=dict(
            range=[center_y - max_range, center_y + max_range],
            visible=False,
        ),
        showlegend=False,
    )

    # Turn off the built-in hover label for the traces.
    fig.update_traces(
        hoverinfo="none",
        hovertemplate=None,
    )

    for i, (cx, cy) in enumerate(cluster_centers):
        fig.add_layout_image(
            dict(
                # ``image/jpeg`` is the registered MIME type (``image/jpg`` is
                # not). Assumes the payload is JPEG -- TODO confirm.
                source=f"data:image/jpeg;base64,{nearest_images[i]}",
                x=cx,
                y=cy,
                xref="x",
                yref="y",
                # Layout images anchor at their top-left corner by default;
                # centre the thumbnail on the cluster centre instead.
                xanchor="center",
                yanchor="middle",
                # Size is in data units because xref/yref are the axes.
                sizex=10,
                sizey=10,
                sizing="contain",
                opacity=1,
                layer="below",
            )
        )

    return fig, images
|
|
|
def make_dash_kmeans(data, title, k=40):
    """Cluster the points, build the figure and serve it in a Dash app.

    Args:
        data: Sequence of mappings with ``'x'``, ``'y'`` and ``'image'``
            entries.
        title: Figure title.
        k: Number of k-means clusters to request.

    Returns:
        The ``Dash`` application. The server call comes before the return, so
        the app is only returned once the server has stopped.
    """
    kmeans_result = perform_kmeans(data, k=k)
    nearest_images, cluster_centers = find_nearest_images(data, kmeans_result)
    fig, images = create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title)
    app = create_dash_app(fig, images)

    port = find_free_port()
    print(f"Serving on http://127.0.0.1:{port}/")
    print(f"To serve this over the Internet, run `ngrok http {port}`")

    # Dash 3 dropped ``run_server`` in favour of ``run``, while early 2.x
    # releases only provide ``run_server``. Prefer ``run`` and fall back when
    # it is absent so both generations work.
    serve = getattr(app, "run", None)
    if serve is None:
        serve = app.run_server
    serve(port=port)

    return app
|
|
|
if __name__ == "__main__":
    # Input and output files are looked up relative to the working directory.
    dataset_folder = os.path.dirname('./')
    name = "style"
    image_embedding_path = os.path.join(dataset_folder, "processed_dataset.parquet")
    tsne_path = os.path.join(dataset_folder, "processed_dataset.json")

    # Step 1: compute the 2-D projection and write it to JSON.
    generate_tsne_embedding(image_embedding_path, tsne_path)

    # Step 2: read the projection back and serve the interactive viewer.
    with open(tsne_path, "r") as f:
        data = json.load(f)

    make_dash_kmeans(data, name, k=40)