huzey committed on
Commit
fa3878c
1 Parent(s): f7365bb

add tree+image

Browse files
Files changed (1) hide show
  1. app.py +108 -11
app.py CHANGED
@@ -16,6 +16,7 @@ import multiprocessing as mp
16
  from einops import rearrange
17
  from matplotlib import pyplot as plt
18
  import matplotlib
 
19
  USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
20
  DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
21
 
@@ -2342,13 +2343,99 @@ with demo:
2342
 
2343
  pil_image = Image.fromarray(image)
2344
  return pil_image
2345
-
2346
- def run_fps_tsne_hierarchical(eigvecs, num_sample_fps, perplexity_tsne, tsne3d_rgb, seed=0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2347
  if len(eigvecs) == 0:
2348
  gr.Warning("Please run NCUT first.")
2349
  return
 
 
 
 
2350
  eigvecs = torch.tensor(eigvecs)
2351
- eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])
2352
  gr.Info("Running FPS, t-SNE, and Hierarchical Clustering...", 3)
2353
  from ncut_pytorch.ncut_pytorch import farthest_point_sampling
2354
  from sklearn.manifold import TSNE
@@ -2357,8 +2444,8 @@ with demo:
2357
  torch.manual_seed(seed)
2358
  np.random.seed(seed)
2359
 
2360
- fps_idx = farthest_point_sampling(eigvecs, num_sample_fps)
2361
- fps_eigvecs = eigvecs[fps_idx]
2362
  fps_eigvecs = fps_eigvecs.numpy()
2363
 
2364
  tsne3d_rgb = tsne3d_rgb.reshape(-1, 3)
@@ -2371,19 +2458,29 @@ with demo:
2371
  metric='cosine',
2372
  random_state=seed,
2373
  ).fit_transform(fps_eigvecs)
 
 
 
2374
 
2375
  edges = build_tree(tsne_embed)
 
 
2376
 
2377
  # Plot the t-SNE points
2378
  pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
2379
 
2380
- return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image
 
 
 
2381
 
2382
- run_hierarchical_button.click(
2383
- run_fps_tsne_hierarchical,
2384
- inputs=[eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider],
2385
- outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot],
2386
- )
 
 
2387
  gr.Markdown('---')
2388
  gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
2389
  gr.Markdown('---')
 
16
  from einops import rearrange
17
  from matplotlib import pyplot as plt
18
  import matplotlib
19
+ from matplotlib.offsetbox import AnnotationBbox, OffsetImage
20
  USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
21
  DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
22
 
 
2343
 
2344
  pil_image = Image.fromarray(image)
2345
  return pil_image
2346
+
2347
def get_top1_heatmap_for_each_dot(images, eigvecs, fps_eigvecs, max_display_dots, fps_tsne_rgb, tsne_embed):
    """Build one bordered image+heatmap thumbnail per displayed t-SNE dot.

    For every selected dot, the dot's eigenvector is multiplied against
    ``eigvecs`` to get a heatmap per image. The image with the highest
    spatially averaged response is chosen, blended 50/50 with the
    'hot'-colormapped heatmap at 256x256, and padded with a 20-pixel border
    in the dot's color.

    Args:
        images: sequence of PIL images, indexed by the batch axis of
            ``eigvecs``.
        eigvecs: per-pixel eigenvectors; per the matmul annotation below
            they are laid out as [B, H, W, C].
        fps_eigvecs: [N, C] eigenvectors of the sampled dots.
        max_display_dots: upper bound on the number of dots that get a
            thumbnail.
        fps_tsne_rgb: per-dot colors; multiplied by 255 for the border, so
            presumably in [0, 1] -- TODO confirm against caller.
        tsne_embed: 2D embedding of the dots, used to spread the displayed
            subset when there are more dots than ``max_display_dots``.

    Returns:
        ``(thumbnails, dots_idx)`` where ``thumbnails`` is a list of uint8
        arrays (one per displayed dot) and ``dots_idx`` holds the indices of
        the displayed dots.
    """
    n_dots = fps_eigvecs.shape[0]
    if n_dots > max_display_dots:
        # Too many dots to draw: keep a subset chosen by farthest point
        # sampling on the embedding. The third argument (5) is forwarded to
        # fpsample unchanged; presumably the KD-tree height -- confirm.
        import fpsample
        dots_idx = fpsample.bucket_fps_kdline_sampling(tsne_embed, max_display_dots, 5).astype(np.int64)
    else:
        dots_idx = np.arange(n_dots)
    fps_eigvecs = fps_eigvecs[dots_idx]
    fps_tsne_rgb = fps_tsne_rgb[dots_idx]

    heatmaps = eigvecs @ fps_eigvecs.T  # [B, H, W, C] @ [N, C] -> [B, H, W, N]
    value = heatmaps.mean(1).mean(1)  # [B, N]
    top1_image_idxs = value.argmax(axis=0)  # [N]

    def pad_image_with_border(image, border_color, border_width):
        """Return ``image`` centered on a canvas filled with ``border_color``."""
        height, width = image.shape[0], image.shape[1]
        new_image = np.ones(
            (height + 2 * border_width, width + 2 * border_width, image.shape[2]),
            dtype=image.dtype,
        )
        new_image[:, :] = border_color
        # Explicit end indices keep this correct even for border_width == 0,
        # where a ``bw:-bw`` slice would be empty.
        new_image[border_width:border_width + height, border_width:border_width + width] = image
        return new_image

    border_width = 20
    cm = matplotlib.colormaps['hot']
    top1_image_blended = []
    for i_fps, image_idx in enumerate(top1_image_idxs):
        image_idx = int(image_idx)
        heatmap = cm(heatmaps[image_idx, :, :, i_fps])
        # The colormap returns RGBA floats; keep RGB and convert to uint8.
        heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
        image = images[image_idx].convert("RGB").resize((256, 256))
        heatmap = Image.fromarray(heatmap).resize((256, 256)).convert("RGB")
        blended = 0.5 * np.array(image) + 0.5 * np.array(heatmap)
        blended = np.clip(blended, 0, 255).astype(np.uint8)
        border_color = fps_tsne_rgb[i_fps, :3] * 255
        top1_image_blended.append(pad_image_with_border(blended, border_color, border_width))

    return top1_image_blended, dots_idx
2386
+
2387
+
2388
+
2389
def plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne_rgb, max_display_dots=100):
    """Render the t-SNE scatter with image+heatmap thumbnails placed on the dots.

    Args:
        images, eigvecs, fps_eigvecs, fps_tsne_rgb: forwarded to
            ``get_top1_heatmap_for_each_dot`` to build the thumbnails;
            ``fps_tsne_rgb`` is also used as the scatter colors.
        tsne_embed: 2D coordinates of the dots (columns 0 and 1 are plotted).
        max_display_dots: maximum number of dots that receive a thumbnail.

    Returns:
        A PIL image of the rendered figure.
    """
    top1_image_blended, dots_idx = get_top1_heatmap_for_each_dot(
        images, eigvecs, fps_eigvecs, max_display_dots, fps_tsne_rgb, tsne_embed
    )

    xs, ys = tsne_embed[:, 0], tsne_embed[:, 1]
    fig, ax = plt.subplots(1, 1, figsize=(15, 15))
    try:
        # Plot the t-SNE points
        ax.scatter(xs, ys, s=20, c=fps_tsne_rgb)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        # Scaling min/max by 1.1 widens the view only when min < 0 < max.
        ax.set_xlim(xs.min() * 1.1, xs.max() * 1.1)
        ax.set_ylim(ys.min() * 1.1, ys.max() * 1.1)

        # Add the blended thumbnails at the positions of the displayed dots
        for (x, y), img in zip(tsne_embed[dots_idx], top1_image_blended):
            imgbox = OffsetImage(np.array(img), zoom=0.15)
            ax.add_artist(AnnotationBbox(imgbox, (x, y), frameon=False))
        ax.scatter(xs, ys, s=20, c=fps_tsne_rgb)

        # Remove the white space around the plot
        fig.tight_layout(pad=0)

        # Save this specific figure (not pyplot's global "current" figure)
        # to an in-memory buffer and decode it before the buffer is closed.
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            buf.seek(0)
            image = np.array(Image.open(buf))
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    return Image.fromarray(image)
2427
+
2428
+
2429
+ def run_fps_tsne_hierarchical(image_gallery, eigvecs, num_sample_fps, perplexity_tsne, tsne3d_rgb, seed=0, max_display_dots=300):
2430
  if len(eigvecs) == 0:
2431
  gr.Warning("Please run NCUT first.")
2432
  return
2433
+ images = [image[0] for image in image_gallery]
2434
+ if isinstance(images[0], str):
2435
+ images = [Image.open(image) for image in images]
2436
+
2437
  eigvecs = torch.tensor(eigvecs)
2438
+ _eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])
2439
  gr.Info("Running FPS, t-SNE, and Hierarchical Clustering...", 3)
2440
  from ncut_pytorch.ncut_pytorch import farthest_point_sampling
2441
  from sklearn.manifold import TSNE
 
2444
  torch.manual_seed(seed)
2445
  np.random.seed(seed)
2446
 
2447
+ fps_idx = farthest_point_sampling(_eigvecs, num_sample_fps)
2448
+ fps_eigvecs = _eigvecs[fps_idx]
2449
  fps_eigvecs = fps_eigvecs.numpy()
2450
 
2451
  tsne3d_rgb = tsne3d_rgb.reshape(-1, 3)
 
2458
  metric='cosine',
2459
  random_state=seed,
2460
  ).fit_transform(fps_eigvecs)
2461
+ # normalize = [-1, 1]
2462
+ tsne_embed[:, 0] = (tsne_embed[:, 0] - tsne_embed[:, 0].min()) / (tsne_embed[:, 0].max() - tsne_embed[:, 0].min()) * 2 - 1
2463
+ tsne_embed[:, 1] = (tsne_embed[:, 1] - tsne_embed[:, 1].min()) / (tsne_embed[:, 1].max() - tsne_embed[:, 1].min()) * 2 - 1
2464
 
2465
  edges = build_tree(tsne_embed)
2466
+ # edges = build_tree(fps_eigvecs, dist='cosine')
2467
+ # edges = build_tree(fps_tsne3d_rgb)
2468
 
2469
  # Plot the t-SNE points
2470
  pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
2471
 
2472
+ # Plot the t-SNE points with image heatmaps
2473
+ big_pil_image = plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne3d_rgb, max_display_dots)
2474
+
2475
+ return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image, big_pil_image
2476
 
2477
# Output image showing the t-SNE tree with per-dot cluster heatmap thumbnails.
big_tsne_plot = gr.Image(
    label="spectral-tSNE tree [+ Cluster Heatmap]",
    elem_id="big_tsne_plot",
    interactive=False,
    format='png',
)

# Wire the button to the FPS + t-SNE + hierarchical clustering pipeline.
run_hierarchical_button.click(
    run_fps_tsne_hierarchical,
    inputs=[
        input_gallery,
        eigvecs,
        num_sample_fps_slider,
        tsne_perplexity_slider,
        tsne3d_rgb,
        fps_hc_seed_slider,
    ],
    outputs=[
        tsne_2d_points,
        edges,
        fps_eigvecs,
        fps_tsne_rgb,
        fps_indices,
        tsne_plot,
        big_tsne_plot,
    ],
)
2484
  gr.Markdown('---')
2485
  gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
2486
  gr.Markdown('---')