huzey commited on
Commit
a3d5c5a
1 Parent(s): 734e8a4

remove mask size filter

Browse files
Files changed (1) hide show
  1. app.py +31 -27
app.py CHANGED
@@ -287,7 +287,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
287
  right = F.normalize(right, dim=-1)
288
  heatmap = left @ right.T
289
  heatmap = F.normalize(heatmap, dim=-1)
290
- num_samples = 80 if not advanced else 130
291
  if num_samples > fps_idx.shape[0]:
292
  num_samples = fps_idx.shape[0]
293
  r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
@@ -305,6 +305,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
305
  fps_heatmaps = {}
306
  sort_values = []
307
  top3_image_idx = {}
 
308
  for _, idx in enumerate(fps_idx):
309
  heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
310
 
@@ -314,7 +315,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
314
  # tensor = tensor[torch.randperm(tensor.shape[0])[:max_size]]
315
  # return tensor.quantile(p)
316
  # top_p = top_percentile(heatmap, p=0.5)
317
- top_p = 0.5
318
 
319
  heatmap = heatmap.reshape(-1, h, w)
320
  mask = (heatmap > top_p).float()
@@ -324,8 +325,9 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
324
  mask = mask[mask_sort_idx[:3]]
325
  sort_values.append(mask.mean().item())
326
  # fps_heatmaps[idx.item()] = heatmap.cpu()
327
- fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:3]].cpu()
328
  top3_image_idx[idx.item()] = mask_sort_idx[:3]
 
329
  # do the sorting
330
  _sort_idx = torch.tensor(sort_values).argsort(descending=True)
331
  fps_idx = fps_idx[_sort_idx]
@@ -342,13 +344,17 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
342
  # shuffle the fps_idx
343
  fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
344
 
 
345
  fig_images = []
346
  i_cluster = 0
347
  num_plots = 10 if not advanced else 20
348
  plot_step_float = (1.0 - progess_start) / num_plots
349
  for i_fig in range(num_plots):
350
  progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
351
- fig, axs = plt.subplots(3, 5, figsize=(15, 9))
 
 
 
352
  for ax in axs.flatten():
353
  ax.axis("off")
354
  for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
@@ -358,7 +364,8 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
358
  size = (images.shape[1], images.shape[2])
359
  heatmap = apply_reds_colormap(heatmap, size)
360
  # for i, image_idx in enumerate(sorted_image_idxs[:3]):
361
- for i, image_idx in enumerate(top3_image_idx[idx.item()]):
 
362
  # _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
363
  _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
364
  axs[i, j].imshow(_heatmap)
@@ -378,10 +385,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
378
 
379
  fig_images.append(img)
380
  plt.close()
381
-
382
- # plt.imshow(img)
383
- # plt.axis("off")
384
- # plt.show()
385
  return fig_images, ret_magnitude
386
 
387
 
@@ -647,26 +651,26 @@ def ncut_run(
647
 
648
  def _ncut_run(*args, **kwargs):
649
  n_ret = kwargs.pop("n_ret", 1)
650
- try:
651
- if torch.cuda.is_available():
652
- torch.cuda.empty_cache()
653
 
654
- ret = ncut_run(*args, **kwargs)
655
 
656
- if torch.cuda.is_available():
657
- torch.cuda.empty_cache()
658
 
659
- ret = list(ret)[:n_ret] + [ret[-1]]
660
- return ret
661
- except Exception as e:
662
- gr.Error(str(e))
663
- if torch.cuda.is_available():
664
- torch.cuda.empty_cache()
665
- return *(None for _ in range(n_ret)), "Error: " + str(e)
666
-
667
- # ret = ncut_run(*args, **kwargs)
668
- # ret = list(ret)[:n_ret] + [ret[-1]]
669
- # return ret
670
 
671
  if USE_HUGGINGFACE_ZEROGPU:
672
  @spaces.GPU(duration=30)
@@ -1415,7 +1419,7 @@ with demo:
1415
  with gr.Column(scale=5, min_width=200):
1416
  output_gallery = make_output_images_section()
1417
  norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1418
- cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[5], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=True, interactive=False)
1419
  [
1420
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
1421
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
 
287
  right = F.normalize(right, dim=-1)
288
  heatmap = left @ right.T
289
  heatmap = F.normalize(heatmap, dim=-1)
290
+ num_samples = 50 if not advanced else 100
291
  if num_samples > fps_idx.shape[0]:
292
  num_samples = fps_idx.shape[0]
293
  r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
 
305
  fps_heatmaps = {}
306
  sort_values = []
307
  top3_image_idx = {}
308
+ top10_image_idx = {}
309
  for _, idx in enumerate(fps_idx):
310
  heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
311
 
 
315
  # tensor = tensor[torch.randperm(tensor.shape[0])[:max_size]]
316
  # return tensor.quantile(p)
317
  # top_p = top_percentile(heatmap, p=0.5)
318
+ top_p = 0.8
319
 
320
  heatmap = heatmap.reshape(-1, h, w)
321
  mask = (heatmap > top_p).float()
 
325
  mask = mask[mask_sort_idx[:3]]
326
  sort_values.append(mask.mean().item())
327
  # fps_heatmaps[idx.item()] = heatmap.cpu()
328
+ fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:10]].cpu()
329
  top3_image_idx[idx.item()] = mask_sort_idx[:3]
330
+ top10_image_idx[idx.item()] = mask_sort_idx[:10]
331
  # do the sorting
332
  _sort_idx = torch.tensor(sort_values).argsort(descending=True)
333
  fps_idx = fps_idx[_sort_idx]
 
344
  # shuffle the fps_idx
345
  fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
346
 
347
+
348
  fig_images = []
349
  i_cluster = 0
350
  num_plots = 10 if not advanced else 20
351
  plot_step_float = (1.0 - progess_start) / num_plots
352
  for i_fig in range(num_plots):
353
  progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
354
+ if not advanced:
355
+ fig, axs = plt.subplots(3, 5, figsize=(15, 9))
356
+ if advanced:
357
+ fig, axs = plt.subplots(6, 5, figsize=(15, 18))
358
  for ax in axs.flatten():
359
  ax.axis("off")
360
  for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
 
364
  size = (images.shape[1], images.shape[2])
365
  heatmap = apply_reds_colormap(heatmap, size)
366
  # for i, image_idx in enumerate(sorted_image_idxs[:3]):
367
+ image_idxs = top3_image_idx[idx.item()] if not advanced else top10_image_idx[idx.item()]
368
+ for i, image_idx in enumerate(image_idxs):
369
  # _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
370
  _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
371
  axs[i, j].imshow(_heatmap)
 
385
 
386
  fig_images.append(img)
387
  plt.close()
388
+
 
 
 
389
  return fig_images, ret_magnitude
390
 
391
 
 
651
 
652
  def _ncut_run(*args, **kwargs):
653
  n_ret = kwargs.pop("n_ret", 1)
654
+ # try:
655
+ # if torch.cuda.is_available():
656
+ # torch.cuda.empty_cache()
657
 
658
+ # ret = ncut_run(*args, **kwargs)
659
 
660
+ # if torch.cuda.is_available():
661
+ # torch.cuda.empty_cache()
662
 
663
+ # ret = list(ret)[:n_ret] + [ret[-1]]
664
+ # return ret
665
+ # except Exception as e:
666
+ # gr.Error(str(e))
667
+ # if torch.cuda.is_available():
668
+ # torch.cuda.empty_cache()
669
+ # return *(None for _ in range(n_ret)), "Error: " + str(e)
670
+
671
+ ret = ncut_run(*args, **kwargs)
672
+ ret = list(ret)[:n_ret] + [ret[-1]]
673
+ return ret
674
 
675
  if USE_HUGGINGFACE_ZEROGPU:
676
  @spaces.GPU(duration=30)
 
1419
  with gr.Column(scale=5, min_width=200):
1420
  output_gallery = make_output_images_section()
1421
  norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1422
+ cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
1423
  [
1424
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
1425
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,