huzey committed on
Commit 4e7b524
1 Parent(s): 8f2d7ad

update cluster plot

Files changed (1):
  1. app.py +160 -28

app.py CHANGED
@@ -1,10 +1,13 @@
  # Author: Huzheng Yang
  # %%
  import copy
+ from functools import partial
  from io import BytesIO
  import os
  
+ from einops import rearrange
  from matplotlib import pyplot as plt
+ import matplotlib
  USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
  DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
  
@@ -219,17 +222,111 @@ def run_alignedthreemodelattnnodes(images, model, batch_size=16):
      return outputs
  
  
+ def _reds_colormap(image):
+     # normed_data = image / image.max() # Normalize to [0, 1]
+     normed_data = image
+     colormap = matplotlib.colormaps['inferno'] # Get the Reds colormap
+     colored_image = colormap(normed_data) # Apply colormap
+     return (colored_image[..., :3] * 255).astype(np.uint8) # Convert to RGB
+ 
+ # heatmap images
+ def apply_reds_colormap(images, size):
+     # for i_image in range(images.shape[0]):
+     #     images[i_image] -= images[i_image].min()
+     #     images[i_image] /= images[i_image].max()
+     # normed_data = [_reds_colormap(images[i]) for i in range(images.shape[0])]
+     # normed_data = np.stack(normed_data)
+     normed_data = _reds_colormap(images)
+     normed_data = torch.tensor(normed_data).float()
+     normed_data = rearrange(normed_data, "b h w c -> b c h w")
+     normed_data = torch.nn.functional.interpolate(normed_data, size=size, mode="nearest")
+     normed_data = rearrange(normed_data, "b c h w -> b h w c")
+     normed_data = normed_data.cpu().numpy().astype(np.uint8)
+     return normed_data
+ 
+ # Blend heatmap with the original image
+ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
+     blended = (1 - opacity1) * image + opacity2 * heatmap
+     return blended.astype(np.uint8)
+ 
+ def make_cluster_plot(eigvecs, images, h=64, w=64):
+     from ncut_pytorch.ncut_pytorch import farthest_point_sampling
+     magnitude = torch.norm(eigvecs, dim=-1)
+     p = 0.5
+     top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
+     num_samples = 50
+     fps_idx = farthest_point_sampling(eigvecs[top_p_idx], num_samples)
+     fps_idx = top_p_idx[fps_idx]
+ 
+     # downsample to 256x256
+     images = F.interpolate(images, (256, 256), mode="bilinear")
+     images = images.cpu().numpy()
+     images = images.transpose(0, 2, 3, 1)
+     images = images * 255
+     images = images.astype(np.uint8)
+ 
+ 
+     # sort the fps_idx by the mean of the heatmap
+     fps_heatmaps = {}
+     sort_values = []
+     for _, idx in enumerate(fps_idx):
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         eigvecs = eigvecs.to(device)
+         heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
+         heatmap = heatmap.reshape(-1, h, w)
+         mask = (heatmap > 0.5).float()
+         sort_values.append(mask.mean().item())
+         fps_heatmaps[idx.item()] = heatmap.cpu()
+ 
+     fig_images = []
+     i_cluster = 0
+     for i_fig in range(10):
+         fig, axs = plt.subplots(3, 5, figsize=(15, 9))
+         for ax in axs.flatten():
+             ax.axis("off")
+         for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
+             heatmap = fps_heatmaps[idx.item()]
+             mask = (heatmap > 0.1).float()
+             sorted_image_idxs = torch.argsort(mask.mean((1, 2)), descending=True)
+             size = (images.shape[1], images.shape[2])
+             heatmap = apply_reds_colormap(heatmap, size)
+             for i, image_idx in enumerate(sorted_image_idxs[:3]):
+                 _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
+                 axs[i, j].imshow(_heatmap)
+                 if i == 0:
+                     axs[i, j].set_title(f"cluster {i_cluster+1}", fontsize=24)
+             i_cluster += 1
+         plt.tight_layout(h_pad=0.5, w_pad=0.3)
+ 
+         buf = BytesIO()
+         plt.savefig(buf, bbox_inches='tight', dpi=72)
+ 
+         buf.seek(0) # Move to the start of the BytesIO buffer
+         img = Image.open(buf)
+         img = img.convert("RGB")
+         img = copy.deepcopy(img)
+         buf.close()
+ 
+         fig_images.append(img)
+         plt.close()
+ 
+     # plt.imshow(img)
+     # plt.axis("off")
+     # plt.show()
+     return fig_images
+ 
+ 
  def ncut_run(
      model,
      images,
-     model_name="SAM(sam_vit_b)",
-     layer=-1,
+     model_name="DiNO(dino_vitb8_448)",
+     layer=10,
      num_eig=100,
      node_type="block",
-     affinity_focal_gamma=0.3,
+     affinity_focal_gamma=0.5,
      num_sample_ncut=10000,
      knn_ncut=10,
-     embedding_method="UMAP",
+     embedding_method="tsne_3d",
      embedding_metric='euclidean',
      num_sample_tsne=1000,
      knn_tsne=10,
@@ -353,8 +450,10 @@ def ncut_run(
          logging_str += _logging_str
          rgb.append(_rgb[0])
  
-     if not old_school_ncut: # joint across all images
-         rgb, _logging_str, _ = compute_ncut(
+ 
+     cluster_images = None
+     if not old_school_ncut: # ailgnedcut, joint across all images
+         rgb, _logging_str, eigvecs = compute_ncut(
              features,
              num_eig=num_eig,
              num_sample_ncut=num_sample_ncut,
@@ -384,7 +483,6 @@ def ncut_run(
          pil_images.append(_im)
          return pil_images, logging_str
  
- 
  
      if is_lisa == True:
          # dirty patch for the LISA model
@@ -396,16 +494,26 @@ def ncut_run(
  
      rgb = dont_use_too_much_green(rgb)
  
+     if not video_output:
+         start = time.time()
+         h, w = features.shape[1], features.shape[2]
+         _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
+         cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w)
+         logging_str += f"Plot time: {time.time() - start:.2f}s\n"
+ 
  
      if video_output:
          video_path = get_random_path()
          video_cache.add_video(video_path)
          pil_images_to_video(to_pil_images(rgb), video_path)
          return video_path, logging_str
-     else:
-         return to_pil_images(rgb), logging_str
+ 
+ 
+     return to_pil_images(rgb), cluster_images, logging_str
+ 
  
  def _ncut_run(*args, **kwargs):
+     n_ret = kwargs.pop("n_ret", 1)
      try:
          if torch.cuda.is_available():
              torch.cuda.empty_cache()
@@ -414,15 +522,17 @@ def _ncut_run(*args, **kwargs):
  
          if torch.cuda.is_available():
              torch.cuda.empty_cache()
- 
+ 
+         ret = list(ret)[:n_ret] + [ret[-1]]
          return ret
      except Exception as e:
          gr.Error(str(e))
         if torch.cuda.is_available():
              torch.cuda.empty_cache()
-         return [], "Error: " + str(e)
+         return *(None for _ in range(n_ret)), "Error: " + str(e)
  
      # ret = ncut_run(*args, **kwargs)
+     # ret = list(ret)[:n_ret] + [ret[-1]]
      # return ret
  
  if USE_HUGGINGFACE_ZEROGPU:
@@ -488,6 +598,16 @@ def transform_image(image, resolution=(1024, 1024), stablediffusion=False):
          image = image * 2 - 1
      return image
  
+ def reverse_transform_image(image, stablediffusion=False):
+     if stablediffusion:
+         image = (image + 1) / 2
+     else:
+         mean = [0.485, 0.456, 0.406]
+         std = [0.229, 0.224, 0.225]
+         image = image * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)
+     image = torch.clamp(image, 0, 1)
+     return image
+ 
  def plot_one_image_36_grid(original_image, tsne_rgb_images):
      mean = [0.485, 0.456, 0.406]
      std = [0.229, 0.224, 0.225]
@@ -583,8 +703,8 @@ promptable_segmentation_models = ["LISA(xinlai/LISA-7B-v1)"]
  
  def run_fn(
      images,
-     model_name="SAM(sam_vit_b)",
-     layer=-1,
+     model_name="DiNO(dino_vitb8_448)",
+     layer=10,
      num_eig=100,
      node_type="block",
      positive_prompt="",
@@ -593,15 +713,15 @@ def run_fn(
      lisa_prompt1="",
      lisa_prompt2="",
      lisa_prompt3="",
-     affinity_focal_gamma=0.3,
+     affinity_focal_gamma=0.5,
      num_sample_ncut=10000,
      knn_ncut=10,
-     embedding_method="UMAP",
+     embedding_method="tsne_3d",
      embedding_metric='euclidean',
-     num_sample_tsne=1000,
+     num_sample_tsne=300,
      knn_tsne=10,
-     perplexity=500,
-     n_neighbors=500,
+     perplexity=150,
+     n_neighbors=150,
      min_dist=0.1,
      sampling_method="fps",
      old_school_ncut=False,
@@ -613,11 +733,12 @@ def run_fn(
      recursion_l1_gamma=0.5,
      recursion_l2_gamma=0.5,
      recursion_l3_gamma=0.5,
+     n_ret=1,
  ):
  
      if images is None:
          gr.Warning("No images selected.")
-         return [], "No images selected."
+         return *(None for _ in range(n_ret)), "No images selected."
  
      video_output = False
      if isinstance(images, str):
@@ -733,6 +854,7 @@ def run_fn(
          "lisa_prompt2": lisa_prompt2,
          "lisa_prompt3": lisa_prompt3,
          "is_lisa": is_lisa,
+         "n_ret": n_ret,
      }
      # print(kwargs)
  
@@ -1042,9 +1164,11 @@ with demo:
          with gr.Column(scale=5, min_width=200):
              input_gallery, submit_button, clear_images_button = make_input_images_section()
              dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section()
+             logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
  
          with gr.Column(scale=5, min_width=200):
              output_gallery = make_output_images_section()
+             cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=False, elem_id="clusters", columns=[2], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=True)
              [
                  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
                  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
@@ -1052,16 +1176,15 @@ with demo:
                  perplexity_slider, n_neighbors_slider, min_dist_slider,
                  sampling_method_dropdown, positive_prompt, negative_prompt
              ] = make_parameters_section()
-             # logging text box
-             logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+             num_eig_slider.value = 30
  
-         clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
+         clear_images_button.click(lambda x: ([], [], []), outputs=[input_gallery, output_gallery, cluster_gallery])
  
          false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
          no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
  
          submit_button.click(
-             run_fn,
+             partial(run_fn, n_ret=2),
              inputs=[
                  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                  positive_prompt, negative_prompt,
@@ -1070,7 +1193,7 @@ with demo:
                  embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
              ],
-             outputs=[output_gallery, logging_text],
+             outputs=[output_gallery, cluster_gallery, logging_text],
              api_name="API_AlignedCut"
          )
  
@@ -1201,7 +1324,7 @@ with demo:
          no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
  
          submit_button.click(
-             run_fn,
+             partial(run_fn, n_ret=3),
              inputs=[
                  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                  positive_prompt, negative_prompt,
@@ -1305,7 +1428,7 @@ with demo:
          galleries = [l1_gallery, l2_gallery, l3_gallery]
          true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder", visible=False)
          submit_button.click(
-             run_fn,
+             partial(run_fn, n_ret=len(galleries)),
              inputs=[
                  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                  positive_prompt, negative_prompt,
@@ -1465,6 +1588,7 @@ with demo:
      gr.Markdown("**This demo is for the Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")
      gr.Markdown("**All the models and functions used for this demo are in the Python package `ncut-pytorch`**")
      gr.Markdown("---")
+     gr.Markdown("---")
      gr.Markdown("**Normalized Cuts**, aka. spectral clustering, is a graphical method to analyze data grouping in the affinity eigenvector space. It has been widely used for unsupervised segmentation in the 2000s.")
      gr.Markdown("*Normalized Cuts and Image Segmentation, Jianbo Shi and Jitendra Malik, 2000*")
      gr.Markdown("---")
@@ -1473,7 +1597,9 @@ with demo:
      gr.Markdown("- **spectral-tSNE** visualization, a new method to visualize the high-dimensional eigenvector space with 3D RGB cube. Color is aligned across images, color infers distance in representation.")
      gr.Markdown("*paper in prep, Yang 2024*")
      gr.Markdown("*AlignedCut: Visual Concepts Discovery on Brain-Guided Universal Feature Space, Huzheng Yang, James Gee\*, and Jianbo Shi\*, 2024*")
- 
+     gr.Markdown("---")
+     gr.Markdown("---")
+     gr.Markdown('<p style="text-align: center;">We thank the HuggingFace team for hosting this demo.</p>')
  
  
      with gr.Row():
@@ -1497,4 +1623,10 @@ if DOWNLOAD_ALL_MODELS_DATASETS:
      demo.launch(share=True)
  
  
- # %%
+ # # %%
+ # # debug
+ # # change working directory to "/"
+ # os.chdir("/")
+ # images = [(Image.open(image), None) for image in default_images]
+ # ret = run_fn(images, num_eig=30)
+ # # %%