huzey committed
Commit: b74468d
Parent(s): cb0930a

update aligned, fix z-score

Files changed (1): app.py (+121 −25)
app.py CHANGED
@@ -1,7 +1,10 @@
 # Author: Huzheng Yang
 # %%
 import copy
+from io import BytesIO
 import os
+
+from matplotlib import pyplot as plt
 USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
 DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
 
@@ -241,7 +244,7 @@ def ncut_run(
     logging_str = ""
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
-        resolution = (672, 672)
+        resolution = (224, 224)
     else:
         resolution = RES_DICT[model_name]
     logging_str += f"Resolution: {resolution}\n"
@@ -357,11 +360,18 @@ def ncut_run(
 
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
-        galleries = []
-        for i_node in range(rgb.shape[1]):
-            _rgb = rgb[:, i_node]
-            galleries.append(to_pil_images(_rgb, target_size=56))
-        return *galleries, logging_str
+        # galleries = []
+        # for i_node in range(rgb.shape[1]):
+        #     _rgb = rgb[:, i_node]
+        #     galleries.append(to_pil_images(_rgb, target_size=56))
+        # return *galleries, logging_str
+        pil_images = []
+        for i_image in range(rgb.shape[0]):
+            _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
+            pil_images.append(_im)
+        return pil_images, logging_str
+
+
 
     if is_lisa == True:
         # dirty patch for the LISA model
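A note on the new return path above: the rewritten loop slices `rgb` per image (`rgb[i_image]`), where the old code sliced per node (`rgb[:, i_node]`), so `rgb` must carry the 36 per-layer maps (3 models × 12 layers, per `plot_one_image_36_grid` below) on its second axis. A minimal sketch of the implied slicing, with hypothetical dummy sizes:

import numpy as np

# Hypothetical sizes; only the 36 on axis 1 is implied by the diff.
n_images, n_maps = 2, 36
rgb = np.zeros((n_images, n_maps, 56, 56, 3))

per_node = rgb[:, 0]   # old slicing: one map across all images
per_image = rgb[0]     # new slicing: all 36 maps for one image
assert per_node.shape == (n_images, 56, 56, 3)
assert per_image.shape == (n_maps, 56, 56, 3)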
@@ -457,9 +467,78 @@ def transform_image(image, resolution=(1024, 1024)):
     image = torch.tensor(np.array(image).transpose(2, 0, 1)).float()
     image = image / 255
     # Normalize
-    image = (image - 0.5) / 0.5
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    image = (image - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)
     return image
 
+def plot_one_image_36_grid(original_image, tsne_rgb_images):
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    original_image = original_image * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)
+    original_image = torch.clamp(original_image, 0, 1)
+
+    fig = plt.figure(figsize=(20, 4))
+    grid = plt.GridSpec(3, 14, hspace=0.1, wspace=0.1)
+
+    ax1 = fig.add_subplot(grid[0:2, 0:2])
+    img = original_image.cpu().float().numpy().transpose(1, 2, 0)
+
+    def convert_and_pad_image(np_array, pad_size=20):
+        """
+        Converts a NumPy array of shape (height, width, 3) to a PNG image
+        and pads the right and bottom sides with a transparent background.
+
+        Args:
+            np_array (numpy.ndarray): Input NumPy array of shape (height, width, 3)
+            pad_size (int, optional): Number of pixels to pad on the right and bottom sides. Default is 20.
+
+        Returns:
+            PIL.Image: Padded PNG image with transparent background
+        """
+        # Convert NumPy array to PIL Image
+        img = Image.fromarray(np_array)
+
+        # Get the original size
+        width, height = img.size
+
+        # Create a new image with padding and transparent background
+        new_width = width + pad_size
+        new_height = height + pad_size
+        padded_img = Image.new('RGBA', (new_width, new_height), color=(255, 255, 255, 0))
+
+        # Paste the original image onto the padded image
+        padded_img.paste(img, (0, 0))
+
+        return padded_img
+
+    img = convert_and_pad_image((img * 255).astype(np.uint8))
+    ax1.imshow(img)
+    ax1.axis('off')
+
+    model_names = ['CLIP', 'DINO', 'MAE']
+
+    for i_model, model_name in enumerate(model_names):
+        for i_layer in range(12):
+            ax = fig.add_subplot(grid[i_model, i_layer + 2])
+            ax.imshow(tsne_rgb_images[i_layer + 12 * i_model].cpu().float().numpy())
+            ax.axis('off')
+            if i_model == 0:
+                ax.set_title(f'Layer{i_layer}', fontsize=16)
+            if i_layer == 0:
+                ax.text(-0.1, 0.5, model_name, va="center", ha="center", fontsize=16, transform=ax.transAxes, rotation=90)
+    plt.tight_layout()
+    buf = BytesIO()
+    plt.savefig(buf, bbox_inches='tight', pad_inches=0, dpi=100)
+
+    buf.seek(0)  # Move to the start of the BytesIO buffer
+    img = Image.open(buf)
+    img = img.convert("RGB")
+    img = copy.deepcopy(img)
+    buf.close()
+    plt.close()
+    return img
+
 def load_alignedthreemodel():
 
     os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
@@ -687,10 +766,10 @@ def make_input_video_section():
     clear_images_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
     return input_gallery, submit_button, clear_images_button, max_frames_number
 
-def make_dataset_images_section(advanced=False):
+def make_dataset_images_section(advanced=False, is_random=False):
 
     gr.Markdown('### Load Datasets')
-    load_images_button = gr.Button("Load", elem_id="load-images-button", variant='secondary')
+    load_images_button = gr.Button("🟢 Load Images", elem_id="load-images-button", variant='primary')
     advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio")
     with gr.Column() as basic_block:
         example_gallery = gr.Gallery(value=example_items, label="Example Set A", show_label=False, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
@@ -700,10 +779,17 @@ def make_dataset_images_section(advanced=False):
     with gr.Row():
         dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset", min_width=300)
         num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
-    filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
-    filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
-    is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
-    random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
+    if not is_random:
+        filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
+        filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
+        is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
+        random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
+    if is_random:
+        filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
+        filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
+        is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
+        random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=42, elem_id="random_seed", visible=True)
+
 
     if advanced:
         advanced_block.visible = True
@@ -1168,12 +1254,18 @@ with demo:
         with gr.Column(scale=5, min_width=200):
             input_gallery, submit_button, clear_images_button = make_input_images_section()
 
-            dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
+            dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
             num_images_slider.value = 100
 
+
         with gr.Column(scale=5, min_width=200):
+            output_gallery = make_output_images_section()
+            gr.Markdown('### TIP1: use the `full-screen` button, and use `arrow keys` to navigate')
+            gr.Markdown('---')
             gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
             gr.Markdown('Layer type: attention output (attn), without sum of residual')
+            gr.Markdown('### TIP2: for large image set, please increase the `num_sample` for t-SNE and NCUT')
+            gr.Markdown('---')
             [
                 model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
                 affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
@@ -1185,20 +1277,23 @@ with demo:
             model_dropdown.visible = False
             layer_slider.visible = False
             node_type_dropdown.visible = False
+            num_sample_ncut_slider.value = 10000
+            num_sample_tsne_slider.value = 1000
             # logging text box
             logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
-            galleries = []
-            for i_model, model_name in enumerate(["CLIP", "DINO", "MAE"]):
-                with gr.Row():
-                    for i_layer in range(1, 13):
-                        with gr.Column(scale=5, min_width=200):
-                            gr.Markdown(f'### {model_name} Layer {i_layer}')
-                            output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
-                            galleries.append(output_gallery)
+            # galleries = []
+            # for i_model, model_name in enumerate(["CLIP", "DINO", "MAE"]):
+            #     with gr.Row():
+            #         for i_layer in range(1, 13):
+            #             with gr.Column(scale=5, min_width=200):
+            #                 gr.Markdown(f'### {model_name} Layer {i_layer}')
+            #                 output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
+            #                 galleries.append(output_gallery)
 
 
-            clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+            # clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+            clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
 
             false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
             no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
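The rewired clear handler above also steps around a Python pitfall in the line it comments out: multiplying an empty list yields an empty list, so `[] * (len(galleries) + 1)` was always `[]`, a single value for 37 output components. A small self-contained sketch of the pitfall and the shape of the fix:

galleries = list(range(36))              # stand-in for the 36 gallery components
assert [] * (len(galleries) + 1) == []   # old lambda: always one empty list

# new handler: one empty value per output component
assert (lambda x: ([], []))(None) == ([], [])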
@@ -1213,7 +1308,8 @@ with demo:
                     embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                     perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
                 ],
-                outputs=galleries + [logging_text],
+                # outputs=galleries + [logging_text],
+                outputs=[output_gallery, logging_text],
             )
 
         with gr.Tab('Compare Models'):
@@ -1320,4 +1416,4 @@ if DOWNLOAD_ALL_MODELS_DATASETS:
 demo.launch(share=True)
 
 
-# %%
+# %%