import base64
import json
import os
import random  # noqa: F401  (no longer used by find_free_port; kept to avoid dropping a module-level name)
import socket

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import Dash, Input, Output, dcc, html, no_update
from PIL import Image  # noqa: F401  (not used in this file; kept to avoid dropping a module-level name)
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE


def generate_tsne_embedding(input_file, output_file):
    """Project the embeddings of a Parquet file to 2-D with t-SNE and save them as JSON.

    Args:
        input_file: Path to a Parquet file with an ``embedding`` column (one
            vector per row) and an ``image`` column holding bytes-like values
            (presumably encoded image files — they are base64-encoded as-is).
        output_file: Path of the JSON file to write.

    The output is a list of ``{"x": float, "y": float, "image": <base64 str>}``
    records, one per input row, in row order.
    """
    df = pd.read_parquet(input_file)

    # Stack the per-row embedding vectors into a 2-D numpy array.
    embeddings = np.array(df["embedding"].tolist())

    # scikit-learn requires perplexity < n_samples. Its default is 30.0, so this
    # clamp only changes behaviour for datasets of 30 rows or fewer.
    perplexity = min(30.0, float(len(embeddings) - 1))
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    tsne_results = tsne.fit_transform(embeddings)

    # Pair rows positionally with zip. The previous df["image"][i] was a
    # label-based lookup that fails, or misaligns, when the DataFrame index is
    # not a 0..n-1 RangeIndex.
    output_data = [
        {
            "x": float(x),
            "y": float(y),
            "image": base64.b64encode(image_bytes).decode("utf-8"),
        }
        for (x, y), image_bytes in zip(tsne_results, df["image"])
    ]

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f)


## ----------------------------
## Dash app
## ----------------------------


def _image_mime_type(image_base64, default="image/jpeg"):
    """Guess the MIME type of a base64-encoded image from its magic bytes.

    Recognises PNG, JPEG, GIF and WEBP. Returns ``default`` for anything else,
    including payloads too short or malformed to decode.
    """
    try:
        # 16 base64 characters decode to the first 12 bytes, which is enough
        # for every signature checked below.
        header = base64.b64decode(image_base64[:16])
    except ValueError:  # binascii.Error is a subclass of ValueError
        return default
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if header.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if header.startswith(b"GIF8"):
        return "image/gif"
    if header[:4] == b"RIFF" and header[8:12] == b"WEBP":
        return "image/webp"
    return default


def find_free_port():
    """Return a TCP port number that is currently unused on this machine.

    Binding to port 0 makes the OS choose a free ephemeral port. The socket is
    closed before returning, so another process could still claim the port
    before the caller binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def create_dash_app(fig, images):
    """Build a Dash app that shows ``fig`` with an image tooltip on hover.

    Args:
        fig: Plotly figure whose first trace's points line up index-for-index
            with ``images``.
        images: Base64-encoded image strings, one per plotted point.

    Returns:
        The configured ``Dash`` app (not yet running).
    """
    app = Dash(__name__)

    app.layout = html.Div(
        className="container",
        children=[
            dcc.Graph(id="graph", figure=fig, clear_on_unhover=True),
            dcc.Tooltip(id="graph-tooltip", direction="bottom"),
        ],
    )

    @app.callback(
        Output("graph-tooltip", "show"),
        Output("graph-tooltip", "bbox"),
        Output("graph-tooltip", "children"),
        Input("graph", "hoverData"),
    )
    def display_hover(hoverData):
        # hoverData is None when the pointer leaves the graph
        # (clear_on_unhover=True).
        if hoverData is None:
            return False, no_update, no_update

        hover_data = hoverData["points"][0]
        bbox = hover_data["bbox"]
        num = hover_data["pointNumber"]
        image_base64 = images[num]

        children = [
            html.Div([
                html.Img(
                    src=f"data:{_image_mime_type(image_base64)};base64,{image_base64}",
                    style={
                        "width": "200px",
                        "height": "200px",
                        "display": "block",
                        "margin": "0 auto",
                    },
                ),
            ])
        ]

        return True, bbox, children

    return app


def perform_kmeans(data, k=20):
    """Cluster the 2-D points in ``data`` with k-means.

    Args:
        data: Sequence of dicts with numeric ``"x"`` and ``"y"`` keys.
        k: Requested number of clusters. It is clamped to the number of
            points, because KMeans rejects n_clusters > n_samples.

    Returns:
        The fitted ``sklearn.cluster.KMeans`` instance.

    Raises:
        ValueError: If ``data`` is empty.
    """
    coords = np.array([[point["x"], point["y"]] for point in data])
    if len(coords) == 0:
        raise ValueError("perform_kmeans needs at least one point")

    kmeans = KMeans(n_clusters=min(k, len(coords)), random_state=42)
    kmeans.fit(coords)
    return kmeans


def find_nearest_images(data, kmeans):
    """For each k-means centre, pick the image of the closest data point.

    Returns:
        ``(nearest_images, cluster_centers)``. ``nearest_images[c]`` is the
        ``"image"`` value of the point nearest to centre ``c``.
    """
    coords = np.array([[point["x"], point["y"]] for point in data])
    images = [point["image"] for point in data]

    # distances has shape (n_points, n_clusters).
    distances = cdist(coords, kmeans.cluster_centers_, metric="euclidean")

    # argmin over axis 0 gives, for each cluster, the index of its nearest point.
    nearest_indices = distances.argmin(axis=0)

    nearest_images = [images[i] for i in nearest_indices]
    return nearest_images, kmeans.cluster_centers_


def create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title, image_size=10):
    """Build a square scatter plot of ``data`` coloured by cluster label.

    Each cluster centre is overlaid with its nearest image.

    Args:
        data: Sequence of dicts with ``"x"``, ``"y"`` and ``"image"`` keys.
        kmeans_result: Fitted KMeans whose ``labels_`` colour the points.
        nearest_images: Base64 image per cluster centre.
        cluster_centers: Iterable of ``(cx, cy)`` centre coordinates.
        title: Figure title.
        image_size: Width and height of each centre image, in data units.

    Returns:
        ``(fig, images)``, where ``images`` lists each point's base64 image in
        plotting order.
    """
    x = [point["x"] for point in data]
    y = [point["y"] for point in data]
    images = [point["image"] for point in data]

    # Use a square view: both axes span the larger of the two data extents.
    max_range = max(max(x) - min(x), max(y) - min(y)) / 2
    if max_range == 0:
        max_range = 1.0  # all points coincide; avoid a zero-width axis range
    center_x = (max(x) + min(x)) / 2
    center_y = (max(y) + min(y)) / 2

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        mode="markers",
        marker=dict(
            size=5,
            color=kmeans_result.labels_,
            colorscale="Viridis",
            showscale=False,
        ),
        name="Data Points",
    ))

    fig.update_layout(
        title=title,
        width=1000,
        height=1000,
        xaxis=dict(
            range=[center_x - max_range, center_x + max_range],
            scaleanchor="y",
            scaleratio=1,
        ),
        yaxis=dict(
            range=[center_y - max_range, center_y + max_range],
        ),
        showlegend=False,
    )

    # Turn off Plotly's built-in hover label; the Dash tooltip replaces it.
    fig.update_traces(
        hoverinfo="none",
        hovertemplate=None,
    )

    for (cx, cy), image_base64 in zip(cluster_centers, nearest_images):
        fig.add_layout_image(
            dict(
                source=f"data:{_image_mime_type(image_base64)};base64,{image_base64}",
                x=cx,
                y=cy,
                xref="x",
                yref="y",
                sizex=image_size,
                sizey=image_size,
                sizing="contain",
                opacity=1,
                layer="below",
            )
        )

    # Hide the axes. update_layout merges, so the ranges set above are kept.
    fig.update_layout(xaxis=dict(visible=False), yaxis=dict(visible=False))

    return fig, images


def make_dash_kmeans(data, title, k=40):
    """Cluster ``data``, build the figure and app, and serve it (blocking).

    The server runs on a free local port.

    Returns:
        The Dash app, once the server has stopped.
    """
    kmeans_result = perform_kmeans(data, k=k)
    nearest_images, cluster_centers = find_nearest_images(data, kmeans_result)
    fig, images = create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title)
    app = create_dash_app(fig, images)

    port = find_free_port()
    print(f"Serving on http://127.0.0.1:{port}/")
    print(f"To serve this over the Internet, run `ngrok http {port}`")

    # Dash.run_server is deprecated and was removed in Dash 3. Prefer
    # Dash.run and fall back to run_server only for older Dash releases.
    run = getattr(app, "run", None) or app.run_server
    run(port=port)

    return app


if __name__ == "__main__":
    dataset_folder = "."
    name = "style"
    image_embedding_path = os.path.join(dataset_folder, "processed_dataset.parquet")
    tsne_path = os.path.join(dataset_folder, "processed_dataset.json")

    generate_tsne_embedding(image_embedding_path, tsne_path)

    with open(tsne_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    make_dash_kmeans(data, name, k=40)