huzey committed
Commit 80937a3
1 Parent(s): 5e7bae6

add aligned+recursion

Files changed (1)
  1. app.py +116 -66
app.py CHANGED
@@ -419,9 +419,22 @@ def ncut_run(
                 metric="cosine" if i == 0 else recursion_metric,
             )
             logging_str += _logging_str
-            rgb = dont_use_too_much_green(rgb)
-            rgbs.append(to_pil_images(rgb))
-            inp = eigvecs.reshape(*features.shape[:3], -1)
+
+
+            if "AlignedThreeModelAttnNodes" == model_name:
+                # dirty patch for the alignedcut paper
+                start = time.time()
+                pil_images = []
+                for i_image in range(rgb.shape[0]):
+                    _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
+                    pil_images.append(_im)
+                rgbs.append(pil_images)
+                logging_str += f"plot time: {time.time() - start:.2f}s\n"
+            else:
+                rgb = dont_use_too_much_green(rgb)
+                rgbs.append(to_pil_images(rgb))
+
+            inp = eigvecs.reshape(*features.shape[:-1], -1)
             if recursion_metric == "cosine":
                 inp = F.normalize(inp, dim=-1)
         return rgbs[0], rgbs[1], rgbs[2], logging_str
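The recursion in this hunk works by feeding each level's eigenvectors back in as the next level's input features, re-normalized when the recursion metric is cosine. A minimal sketch of that pattern (`ncut_once` is a hypothetical stand-in for the app's `compute_ncut`):

```python
import torch.nn.functional as F

def recursive_ncut(features, n_eigs_per_level, ncut_once, recursion_metric="cosine"):
    # features: (n_images, h, w, c); ncut_once returns (rgb, eigvecs)
    rgbs = []
    inp = features
    for i, n_eig in enumerate(n_eigs_per_level):
        # the first level always uses cosine; deeper levels use the chosen metric
        rgb, eigvecs = ncut_once(inp, num_eig=n_eig,
                                 metric="cosine" if i == 0 else recursion_metric)
        rgbs.append(rgb)
        # eigenvectors become the features of the next level
        inp = eigvecs.reshape(*features.shape[:-1], -1)
        if recursion_metric == "cosine":
            inp = F.normalize(inp, dim=-1)  # unit rows, so dot product = cosine
    return rgbs
```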
@@ -472,15 +485,12 @@ def ncut_run(
 
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
-        # galleries = []
-        # for i_node in range(rgb.shape[1]):
-        #     _rgb = rgb[:, i_node]
-        #     galleries.append(to_pil_images(_rgb, target_size=56))
-        # return *galleries, logging_str
+        start = time.time()
         pil_images = []
         for i_image in range(rgb.shape[0]):
             _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
             pil_images.append(_im)
+        logging_str += f"plot time: {time.time() - start:.2f}s\n"
         return pil_images, logging_str
 
 
@@ -516,26 +526,26 @@ def ncut_run(
 
 def _ncut_run(*args, **kwargs):
     n_ret = kwargs.pop("n_ret", 1)
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    # try:
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
 
-        ret = ncut_run(*args, **kwargs)
+    #     ret = ncut_run(*args, **kwargs)
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
 
-        ret = list(ret)[:n_ret] + [ret[-1]]
-        return ret
-    except Exception as e:
-        gr.Error(str(e))
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        return *(None for _ in range(n_ret)), "Error: " + str(e)
-
-    # ret = ncut_run(*args, **kwargs)
-    # ret = list(ret)[:n_ret] + [ret[-1]]
-    # return ret
+    #     ret = list(ret)[:n_ret] + [ret[-1]]
+    #     return ret
+    # except Exception as e:
+    #     gr.Error(str(e))
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+    #     return *(None for _ in range(n_ret)), "Error: " + str(e)
+
+    ret = ncut_run(*args, **kwargs)
+    ret = list(ret)[:n_ret] + [ret[-1]]
+    return ret
 
 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=20)
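Two details in `_ncut_run` worth spelling out: the trimming line keeps the first `n_ret` values plus the trailing logging string, and the now-commented error path called `gr.Error(str(e))` without `raise`, which only constructs the exception (Gradio displays it only when raised), so errors were surfaced through the returned string instead. The trimming, on a made-up return tuple:

```python
n_ret = 3
ret = ("rgb_l1", "rgb_l2", "rgb_l3", "log...")  # hypothetical ncut_run() output
ret = list(ret)[:n_ret] + [ret[-1]]             # ['rgb_l1', 'rgb_l2', 'rgb_l3', 'log...']
```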
@@ -1018,19 +1028,28 @@ def make_dataset_images_section(advanced=False, is_random=False):
     return dataset_dropdown, num_images_slider, random_seed_slider, load_images_button
 
 
-def random_rotate_rgb_gallery(images):
-    if images is None or len(images) == 0:
-        gr.Warning("No images selected.")
-        return []
-    # read webp images
-    images = [Image.open(image[0]).convert("RGB") for image in images]
-    images = [np.array(image).astype(np.float32) for image in images]
-    images = np.stack(images)
-    images = torch.tensor(images) / 255
-    position = np.random.choice([1, 2, 4, 5, 6])
-    images = rotate_rgb_cube(images, position)
-    images = to_pil_images(images, resize=False)
-    return images
+# def random_rotate_rgb_gallery(images):
+#     if images is None or len(images) == 0:
+#         gr.Warning("No images selected.")
+#         return []
+#     # read webp images
+#     images = [Image.open(image[0]).convert("RGB") for image in images]
+#     images = [np.array(image).astype(np.float32) for image in images]
+#     images = np.stack(images)
+#     images = torch.tensor(images) / 255
+#     position = np.random.choice([1, 2, 4, 5, 6])
+#     images = rotate_rgb_cube(images, position)
+#     images = to_pil_images(images, resize=False)
+#     return images
+
+def protect_original_image_in_plot(original_image, rotated_images):
+    plot_h, plot_w = 332, 1542
+    image_h, image_w = original_image.shape[1], original_image.shape[2]
+    if not (plot_h == image_h and plot_w == image_w):
+        return rotated_images
+    protection_w = 190
+    rotated_images[:, :, :protection_w] = original_image[:, :, :protection_w]
+    return rotated_images
 
 def sequence_rotate_rgb_gallery(images):
     if images is None or len(images) == 0:
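The constants in `protect_original_image_in_plot` are tied to the 36-grid plots: a rendered plot is assumed to be 332x1542 pixels with the original image occupying the left 190 columns, so channel rotations and inversions leave that strip untouched. A quick sanity check under those assumptions, using the function from the hunk above:

```python
import torch

plots = torch.rand(4, 332, 1542, 3)                  # batch of rendered 36-grid plots
flipped = protect_original_image_in_plot(plots, 1 - plots)
assert torch.equal(flipped[:, :, :190], plots[:, :, :190])       # left strip preserved
assert torch.equal(flipped[:, :, 190:], 1 - plots[:, :, 190:])   # rest is inverted
```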
@@ -1041,8 +1060,10 @@ def sequence_rotate_rgb_gallery(images):
     images = [np.array(image).astype(np.float32) for image in images]
     images = np.stack(images)
     images = torch.tensor(images) / 255
+    original_images = images.clone()
     rotation_matrix = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).float()
     images = images @ rotation_matrix
+    images = protect_original_image_in_plot(original_images, images)
     images = to_pil_images(images, resize=False)
     return images
 
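The rotation matrix in `sequence_rotate_rgb_gallery` is a cyclic permutation of the RGB channels, so clicking the button three times returns the gallery to its original colors:

```python
import torch

M = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).float()
red = torch.tensor([[1.0, 0.0, 0.0]])
print(red @ M)          # tensor([[0., 1., 0.]])  red -> green
print(red @ M @ M)      # tensor([[0., 0., 1.]])  green -> blue
print(red @ M @ M @ M)  # tensor([[1., 0., 0.]])  back to red
```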
@@ -1055,7 +1076,9 @@ def flip_rgb_gallery(images, axis=0):
     images = [np.array(image).astype(np.float32) for image in images]
     images = np.stack(images)
     images = torch.tensor(images) / 255
+    original_images = images.clone()
     images = 1 - images
+    images = protect_original_image_in_plot(original_images, images)
     images = to_pil_images(images, resize=False)
     return images
 
@@ -1074,7 +1097,7 @@ def make_output_images_section():
     add_output_images_buttons(output_gallery)
     return output_gallery
 
-def make_parameters_section(is_lisa=False):
+def make_parameters_section(is_lisa=False, model_ratio=True):
     gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
     from ncut_pytorch.backbone import list_models, get_demo_model_names
     model_names = list_models()
@@ -1095,7 +1118,7 @@ def make_parameters_section(is_lisa=False):
         negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False)
         node_type_dropdown = gr.Dropdown(layer_names, label="LISA (SAM) decoder: Layer and Node", value="dec_1_block", elem_id="node_type")
     else:
-        model_radio = gr.Radio(["CLIP", "DiNO", "Diffusion", "ImageNet", "MAE", "SAM"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True)
+        model_radio = gr.Radio(["CLIP", "DiNO", "Diffusion", "ImageNet", "MAE", "SAM"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True, visible=model_ratio)
         model_dropdown = gr.Dropdown(get_filtered_model_names("DiNO"), label="", value="DiNO(dino_vitb8_448)", elem_id="model_name", show_label=False)
         model_radio.change(fn=lambda x: gr.update(choices=get_filtered_model_names(x), value=get_default_model_name(x)), inputs=model_radio, outputs=[model_dropdown])
         layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
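The new `model_ratio` flag (the name is presumably a typo for "model radio") keeps the Backbone radio in the component graph but hidden, so its `change` wiring and the dropdown it drives stay valid while tabs with a pinned model pass `model_ratio=False`. A minimal sketch of the same pattern, with a made-up model table:

```python
import gradio as gr

MODELS = {"CLIP": ["CLIP(ViT-B-16/openai)"], "DiNO": ["DiNO(dino_vitb8_448)"]}

def make_selector(show_radio=True):
    # The radio is always created so the .change() wiring stays valid;
    # tabs with a fixed backbone pass show_radio=False to hide it.
    radio = gr.Radio(list(MODELS), value="DiNO", label="Backbone", visible=show_radio)
    dropdown = gr.Dropdown(MODELS["DiNO"], value=MODELS["DiNO"][0], show_label=False)
    radio.change(fn=lambda x: gr.update(choices=MODELS[x], value=MODELS[x][0]),
                 inputs=radio, outputs=[dropdown])
    return radio, dropdown
```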
@@ -1292,6 +1315,7 @@ with demo:
                 dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
                 num_images_slider.value = 100
                 clear_images_button.visible = False
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
             with gr.Column(scale=5, min_width=200):
                 with gr.Accordion("➡️ Recursion config", open=True):
@@ -1311,14 +1335,6 @@ with demo:
                 ] = make_parameters_section()
                 num_eig_slider.visible = False
                 affinity_focal_gamma_slider.visible = False
-        # logging text box
-        with gr.Row():
-            with gr.Column(scale=5, min_width=200):
-                gr.Markdown(' ')
-            with gr.Column(scale=5, min_width=200):
-                gr.Markdown(' ')
-            with gr.Column(scale=5, min_width=200):
-                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
         true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder")
         true_placeholder.visible = False
         false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder")
@@ -1326,13 +1342,12 @@ with demo:
         number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
         number_placeholder.visible = False
         clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
-        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
         no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
 
         submit_button.click(
             partial(run_fn, n_ret=3),
             inputs=[
-                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
                 positive_prompt, negative_prompt,
                 false_placeholder, no_prompt, no_prompt, no_prompt,
                 affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
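For context on the placeholder widgets: `run_fn` has a single positional signature shared by every tab, and Gradio feeds it the current value of each component listed in `inputs`, in order. Tabs that do not use a given argument still must fill the slot, so invisible components act as constants; roughly:

```python
import gradio as gr

# Hidden components whose fixed values fill unused slots in run_fn's signature.
true_placeholder = gr.Checkbox(value=True, visible=False)
false_placeholder = gr.Checkbox(value=False, visible=False)
number_placeholder = gr.Number(value=0, visible=False)
no_prompt = gr.Textbox("", visible=False)
```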
@@ -1457,7 +1472,6 @@ with demo:
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                input_gallery, submit_button, clear_images_button = make_input_images_section()
-
                dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
                num_images_slider.value = 100
 
@@ -1476,7 +1490,7 @@ with demo:
                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                    perplexity_slider, n_neighbors_slider, min_dist_slider,
                    sampling_method_dropdown, positive_prompt, negative_prompt
-                ] = make_parameters_section()
+                ] = make_parameters_section(model_ratio=False)
                model_dropdown.value = "AlignedThreeModelAttnNodes"
                model_dropdown.visible = False
                layer_slider.visible = False
@@ -1505,7 +1519,7 @@ with demo:
            outputs=[output_gallery, logging_text],
        )
 
-    with gr.Tab('Model Aligned (+Recursive)'):
+    with gr.Tab('Model Aligned (+Recursion)'):
        gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
        gr.Markdown('---')
        gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
@@ -1513,29 +1527,58 @@ with demo:
        gr.Markdown('')
        gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
        gr.Markdown('---')
+
+        # with gr.Row():
+        #     with gr.Column(scale=5, min_width=200):
+        #         gr.Markdown('### Output (Recursion #1)')
+        #         l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=False, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+        #         add_output_images_buttons(l1_gallery)
+        #     with gr.Column(scale=5, min_width=200):
+        #         gr.Markdown('### Output (Recursion #2)')
+        #         l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=False, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+        #         add_output_images_buttons(l2_gallery)
+        #     with gr.Column(scale=5, min_width=200):
+        #         gr.Markdown('### Output (Recursion #3)')
+        #         l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=False, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+        #         add_output_images_buttons(l3_gallery)
+        gr.Markdown('### Output (Recursion #1)')
+        l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=False, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
+        add_output_images_buttons(l1_gallery)
+        gr.Markdown('### Output (Recursion #2)')
+        l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=False, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
+        add_output_images_buttons(l2_gallery)
+        gr.Markdown('### Output (Recursion #3)')
+        l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=False, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
+        add_output_images_buttons(l3_gallery)
+
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                input_gallery, submit_button, clear_images_button = make_input_images_section()
-
                dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
                num_images_slider.value = 100
 
 
            with gr.Column(scale=5, min_width=200):
-                output_gallery = make_output_images_section()
-                gr.Markdown('### TIP1: use the `full-screen` button, and use `arrow keys` to navigate')
+                with gr.Accordion("➡️ Recursion config", open=True):
+                    l1_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #1: N eigenvectors", value=100, elem_id="l1_num_eig")
+                    l2_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #2: N eigenvectors", value=50, elem_id="l2_num_eig")
+                    l3_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #3: N eigenvectors", value=50, elem_id="l3_num_eig")
+                    metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="Recursion distance metric", value="cosine", elem_id="recursion_metric")
+                    l1_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #1: Affinity focal gamma", value=0.5, elem_id="recursion_l1_gamma")
+                    l2_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #2: Affinity focal gamma", value=0.5, elem_id="recursion_l2_gamma")
+                    l3_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #3: Affinity focal gamma", value=0.5, elem_id="recursion_l3_gamma")
                gr.Markdown('---')
                gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
                gr.Markdown('Layer type: attention output (attn), without sum of residual')
-                gr.Markdown('### TIP2: for large image set, please increase the `num_sample` for t-SNE and NCUT')
-                gr.Markdown('---')
                [
                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                    perplexity_slider, n_neighbors_slider, min_dist_slider,
                    sampling_method_dropdown, positive_prompt, negative_prompt
-                ] = make_parameters_section()
+                ] = make_parameters_section(model_ratio=False)
+                num_eig_slider.visible = False
+                affinity_focal_gamma_slider.visible = False
                model_dropdown.value = "AlignedThreeModelAttnNodes"
                model_dropdown.visible = False
                layer_slider.visible = False
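The three galleries replace the single `make_output_images_section()` output: `columns=[100]`, `rows=[1]`, and `preview=True` lay each recursion level out as one long filmstrip that opens in preview mode, which suits the wide 36-grid plots. The accordion sliders supply the per-level settings that, on submit, flow into the recursion loop from the first hunk; schematically (defaults shown, names follow the elem_ids above):

```python
n_eigs_per_level = [100, 50, 50]    # l1/l2/l3 "N eigenvectors" sliders
gammas_per_level = [0.5, 0.5, 0.5]  # l1/l2/l3 "Affinity focal gamma" sliders
recursion_metric = "cosine"         # "Recursion distance metric" dropdown
```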
@@ -1545,23 +1588,30 @@ with demo:
        # logging text box
        logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
-        clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
+        clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
 
-        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder")
+        true_placeholder.visible = False
+        false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder")
+        false_placeholder.visible = False
+        number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
+        number_placeholder.visible = False
        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
 
        submit_button.click(
-            run_fn,
+            partial(run_fn, n_ret=3),
            inputs=[
-                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
                positive_prompt, negative_prompt,
                false_placeholder, no_prompt, no_prompt, no_prompt,
                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
-                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
+                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
+                false_placeholder, number_placeholder, true_placeholder,
+                l2_num_eig_slider, l3_num_eig_slider, metric_dropdown,
+                l1_affinity_focal_gamma_slider, l2_affinity_focal_gamma_slider, l3_affinity_focal_gamma_slider
            ],
-            # outputs=galleries + [logging_text],
-            outputs=[output_gallery, logging_text],
+            outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text],
        )
 
 
 