huzey commited on
Commit
0269da3
1 Parent(s): 3bd9672

add mlp align

Browse files
Files changed (1) hide show
  1. app.py +262 -77
app.py CHANGED
@@ -7,9 +7,9 @@ from functools import partial
7
  from io import BytesIO
8
  import json
9
  import os
10
- from pprint import pprint
11
  import uuid
12
  import zipfile
 
13
 
14
  from einops import rearrange
15
  from matplotlib import pyplot as plt
@@ -42,7 +42,6 @@ from ncut_pytorch.backbone import MODEL_DICT, LAYER_DICT, RES_DICT
42
  from ncut_pytorch import NCUT
43
  from ncut_pytorch import eigenvector_to_rgb, rotate_rgb_cube
44
 
45
- RUN_COUNT = 0
46
 
47
  DATASETS = {
48
  'Common': [
@@ -314,8 +313,7 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
314
  blended = (1 - opacity1) * image + opacity2 * heatmap
315
  return blended.astype(np.uint8)
316
 
317
- # preload the model
318
- load_model("CLIP(ViT-B-16/openai)")
319
  def segment_fg_bg(images):
320
 
321
  images = F.interpolate(images, (224, 224), mode="bilinear")
@@ -388,6 +386,8 @@ def segment_fg_bg(images):
388
 
389
 
390
  def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False, clusters=50, eig_idx=None, title='cluster'):
 
 
391
  progress = gr.Progress()
392
  progress(progess_start, desc="Finding Clusters by FPS")
393
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -462,7 +462,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
462
  # fps_heatmaps[idx.item()] = heatmap.cpu()
463
  fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu()
464
  top3_image_idx[idx.item()] = mask_sort_idx[:3]
465
- top10_image_idx[idx.item()] = mask_sort_idx[:5]
466
  # do the sorting
467
  _sort_idx = torch.tensor(sort_values).argsort(descending=True)
468
  fps_idx = fps_idx[_sort_idx]
@@ -480,51 +480,46 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
480
  fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
481
 
482
 
483
- fig_images = []
484
- i_cluster = 0
485
- num_plots = clusters // 5
486
- plot_step_float = (1.0 - progess_start) / num_plots
487
- for i_fig in range(num_plots):
488
- progress(progess_start + i_fig * plot_step_float, desc=f"Plotting {title}")
489
- if not advanced:
490
- fig, axs = plt.subplots(3, 5, figsize=(15, 9))
491
- if advanced:
492
- fig, axs = plt.subplots(5, 5, figsize=(15, 15))
493
  for ax in axs.flatten():
494
  ax.axis("off")
495
- for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
496
  heatmap = fps_heatmaps[idx.item()]
497
- # mask = (heatmap > 0.1).float()
498
- # sorted_image_idxs = torch.argsort(mask.mean((1, 2)), descending=True)
499
  size = (images.shape[1], images.shape[2])
500
  heatmap = apply_reds_colormap(heatmap, size)
501
- # for i, image_idx in enumerate(sorted_image_idxs[:3]):
502
  image_idxs = top3_image_idx[idx.item()] if not advanced else top10_image_idx[idx.item()]
503
  for i, image_idx in enumerate(image_idxs):
504
- # _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
505
  _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
506
  axs[i, j].imshow(_heatmap)
507
  if i == 0:
508
- axs[i, j].set_title(f"{title} {i_cluster+1}", fontsize=24)
509
- i_cluster += 1
510
  plt.tight_layout(h_pad=0.5, w_pad=0.3)
511
-
512
- filename = uuid.uuid4()
513
  tmp_path = f"/tmp/{filename}.png"
514
  plt.savefig(tmp_path, bbox_inches='tight', dpi=72)
515
-
516
- img = Image.open(tmp_path)
517
- img = img.convert("RGB")
518
- img = copy.deepcopy(img)
519
-
520
  os.remove(tmp_path)
521
-
522
- fig_images.append(img)
523
  plt.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
525
  return fig_images, ret_magnitude
526
 
527
- def make_cluster_plot_advanced(eigvecs, images, h=64, w=64, num_fg=100, num_bg=10, num_other=0, small_title=True):
528
  heatmap_fg, heatmap_bg = segment_fg_bg(images.clone())
529
  heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
530
  heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
@@ -545,15 +540,9 @@ def make_cluster_plot_advanced(eigvecs, images, h=64, w=64, num_fg=100, num_bg=1
545
  bg_idx = torch.arange(heatmap_bg.shape[0])[bg_mask]
546
  other_idx = torch.arange(heatmap_fg.shape[0])[other_mask]
547
 
548
- fg_images = []
549
- if num_fg > 0:
550
- fg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_fg, eig_idx=fg_idx, title="fg" if small_title else "cluster")
551
- bg_images = []
552
- if num_bg > 0:
553
- bg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_bg, eig_idx=bg_idx, title="bg" if small_title else "cluster")
554
- other_images = []
555
- if num_other > 0:
556
- other_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_other, eig_idx=other_idx, title="other" if small_title else "cluster")
557
 
558
  cluster_images = fg_images + bg_images + other_images
559
 
@@ -842,7 +831,7 @@ def ncut_run(
842
  if advanced:
843
  cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
844
  else:
845
- cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w, num_fg=20, num_bg=0, num_other=0, small_title=False)
846
  logging_str += f"plot time: {time.time() - start:.2f}s\n"
847
 
848
  norm_images = None
@@ -903,10 +892,6 @@ if USE_HUGGINGFACE_ZEROGPU:
903
  def longer_run(*args, **kwargs):
904
  return _ncut_run(*args, **kwargs)
905
 
906
- @spaces.GPU(duration=90)
907
- def quite_long_run(*args, **kwargs):
908
- return _ncut_run(*args, **kwargs)
909
-
910
  @spaces.GPU(duration=120)
911
  def super_duper_long_run(*args, **kwargs):
912
  return _ncut_run(*args, **kwargs)
@@ -924,9 +909,6 @@ if not USE_HUGGINGFACE_ZEROGPU:
924
  def longer_run(*args, **kwargs):
925
  return _ncut_run(*args, **kwargs)
926
 
927
- def quite_long_run(*args, **kwargs):
928
- return _ncut_run(*args, **kwargs)
929
-
930
  def super_duper_long_run(*args, **kwargs):
931
  return _ncut_run(*args, **kwargs)
932
 
@@ -1241,10 +1223,7 @@ def run_fn(
1241
  "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
1242
  "advanced": advanced,
1243
  }
1244
- global RUN_COUNT
1245
- RUN_COUNT += 1
1246
- print(f"Run Count: {RUN_COUNT}")
1247
- pprint(kwargs)
1248
 
1249
  try:
1250
  # try to aquiare GPU, can fail if the user is out of GPU quota
@@ -1256,15 +1235,13 @@ def run_fn(
1256
  return super_duper_long_run(model, images, **kwargs)
1257
 
1258
  num_images = len(images)
1259
- if num_images > 100:
1260
  return super_duper_long_run(model, images, **kwargs)
1261
  if 'diffusion' in model_name.lower():
1262
  return super_duper_long_run(model, images, **kwargs)
1263
  if recursion:
1264
  return longer_run(model, images, **kwargs)
1265
- if num_images > 50:
1266
- return quite_long_run(model, images, **kwargs)
1267
- if num_images > 30:
1268
  return longer_run(model, images, **kwargs)
1269
  if old_school_ncut:
1270
  return longer_run(model, images, **kwargs)
@@ -1284,7 +1261,7 @@ def run_fn(
1284
  except gr.Error as e:
1285
  # I assume this is a GPU quota error
1286
 
1287
- info1 = 'Running out of HuggingFace GPU Quota?</br> Please try <a style="white-space: nowrap;text-underline-offset: 2px;color: var(--body-text-color)" href="https://ncut-pytorch.readthedocs.io/en/latest/demo/">Demo hosted at UPenn</a>, passcode is: `158.130.50.41`</br>'
1288
  info2 = 'Or try use the Python package that powers this app: <a style="white-space: nowrap;text-underline-offset: 2px;color: var(--body-text-color)" href="https://ncut-pytorch.readthedocs.io/en/latest/">ncut-pytorch</a>'
1289
  info = info1 + info2
1290
 
@@ -1292,6 +1269,165 @@ def run_fn(
1292
  raise gr.Error(message, duration=0)
1293
 
1294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1295
  def make_input_video_section():
1296
  # gr.Markdown('### Input Video')
1297
  input_gallery = gr.Video(value=None, label="Select video", elem_id="video-input", height="auto", show_share_button=False, interactive=True)
@@ -1426,7 +1562,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
1426
  def make_example(name, images, dataset_name):
1427
  with gr.Row():
1428
  button = gr.Button("Load\n"+name, elem_id=f"example-{name}", elem_classes="small-button", variant='secondary', size="sm", scale=1, min_width=60)
1429
- gallery = gr.Gallery(value=images, label=name, show_label=True, columns=[3], rows=[1], interactive=False, height=80, scale=8, object_fit="cover", min_width=140)
1430
  button.click(fn=lambda: gr.update(value=load_dataset_images(True, dataset_name, 100, is_random=True, seed=42)), outputs=[input_gallery])
1431
  return gallery, button
1432
  example_items = [
@@ -1641,7 +1777,7 @@ def flip_rgb_gallery(images, axis=0):
1641
  images = to_pil_images(images, resize=False)
1642
  return images
1643
 
1644
- def add_output_images_buttons(output_gallery):
1645
  with gr.Row():
1646
  rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
1647
  rotate_button.click(sequence_rotate_rgb_gallery, inputs=[output_gallery], outputs=[output_gallery])
@@ -1760,7 +1896,7 @@ def add_download_button(gallery, filename_prefix="output"):
1760
  def make_output_images_section():
1761
  gr.Markdown('### Output Images')
1762
  output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
1763
- add_output_images_buttons(output_gallery)
1764
  return output_gallery
1765
 
1766
  def make_parameters_section(is_lisa=False, model_ratio=True):
@@ -1880,7 +2016,7 @@ with demo:
1880
 
1881
  with gr.Column(scale=5, min_width=200):
1882
  output_gallery = make_output_images_section()
1883
- cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=True, interactive=True)
1884
  [
1885
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
1886
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
@@ -2024,15 +2160,15 @@ with demo:
2024
  with gr.Column(scale=5, min_width=200):
2025
  gr.Markdown('### Output (Recursion #1)')
2026
  l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2027
- add_output_images_buttons(l1_gallery)
2028
  with gr.Column(scale=5, min_width=200):
2029
  gr.Markdown('### Output (Recursion #2)')
2030
  l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2031
- add_output_images_buttons(l2_gallery)
2032
  with gr.Column(scale=5, min_width=200):
2033
  gr.Markdown('### Output (Recursion #3)')
2034
  l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2035
- add_output_images_buttons(l3_gallery)
2036
  with gr.Row():
2037
 
2038
  with gr.Column(scale=5, min_width=200):
@@ -2089,7 +2225,7 @@ with demo:
2089
  with gr.Column(scale=5, min_width=200):
2090
  gr.Markdown('### Output (Recursion #1)')
2091
  l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2092
- add_output_images_buttons(l1_gallery)
2093
  add_download_button(l1_gallery, "ncut_embed_recur1")
2094
  l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2095
  add_download_button(l1_norm_gallery, "eig_norm_recur1")
@@ -2098,7 +2234,7 @@ with demo:
2098
  with gr.Column(scale=5, min_width=200):
2099
  gr.Markdown('### Output (Recursion #2)')
2100
  l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2101
- add_output_images_buttons(l2_gallery)
2102
  add_download_button(l2_gallery, "ncut_embed_recur2")
2103
  l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2104
  add_download_button(l2_norm_gallery, "eig_norm_recur2")
@@ -2107,7 +2243,7 @@ with demo:
2107
  with gr.Column(scale=5, min_width=200):
2108
  gr.Markdown('### Output (Recursion #3)')
2109
  l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2110
- add_output_images_buttons(l3_gallery)
2111
  add_download_button(l3_gallery, "ncut_embed_recur3")
2112
  l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2113
  add_download_button(l3_norm_gallery, "eig_norm_recur3")
@@ -2335,15 +2471,15 @@ with demo:
2335
  # add_output_images_buttons(l3_gallery)
2336
  gr.Markdown('### Output (Recursion #1)')
2337
  l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
2338
- add_output_images_buttons(l1_gallery)
2339
  add_download_button(l1_gallery, "modelaligned_recur1")
2340
  gr.Markdown('### Output (Recursion #2)')
2341
  l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
2342
- add_output_images_buttons(l2_gallery)
2343
  add_download_button(l2_gallery, "modelaligned_recur2")
2344
  gr.Markdown('### Output (Recursion #3)')
2345
  l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
2346
- add_output_images_buttons(l3_gallery)
2347
  add_download_button(l3_gallery, "modelaligned_recur3")
2348
 
2349
  with gr.Row():
@@ -2413,7 +2549,7 @@ with demo:
2413
  gr.Markdown(f'### Output Images')
2414
  output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2415
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
2416
- add_output_images_buttons(output_gallery)
2417
  [
2418
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
2419
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
@@ -2479,13 +2615,52 @@ with demo:
2479
  buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
2480
 
2481
  with gr.Tab('Compare Models (Advanced)', visible=False) as tab_compare_models_advanced:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2482
  def add_one_model(i_model=1):
2483
  with gr.Column(scale=5, min_width=200) as col:
2484
  gr.Markdown(f'### Output Images')
2485
  output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2486
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
2487
- add_output_images_buttons(output_gallery)
2488
  add_download_button(output_gallery, f"ncut_embed")
 
 
 
2489
  norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2490
  add_download_button(norm_gallery, f"eig_norm")
2491
  cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
@@ -2515,8 +2690,12 @@ with demo:
2515
  outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text]
2516
  )
2517
 
2518
- return col
 
 
2519
 
 
 
2520
  with gr.Row():
2521
  with gr.Column(scale=5, min_width=200):
2522
  input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
@@ -2524,7 +2703,8 @@ with demo:
2524
 
2525
 
2526
  for i in range(3):
2527
- add_one_model()
 
2528
 
2529
  # Create rows and buttons in a loop
2530
  rows = []
@@ -2537,7 +2717,8 @@ with demo:
2537
  with row:
2538
  for j in range(4):
2539
  with gr.Column(scale=5, min_width=200):
2540
- add_one_model()
 
2541
 
2542
  button = gr.Button("➕ Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True, scale=3)
2543
  buttons.append(button)
@@ -2553,7 +2734,11 @@ with demo:
2553
  # Last button only reveals the last row and hides itself
2554
  buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
2555
  buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
2556
-
 
 
 
 
2557
 
2558
  with gr.Tab('📄About'):
2559
  with gr.Column():
 
7
  from io import BytesIO
8
  import json
9
  import os
 
10
  import uuid
11
  import zipfile
12
+ import multiprocessing as mp
13
 
14
  from einops import rearrange
15
  from matplotlib import pyplot as plt
 
42
  from ncut_pytorch import NCUT
43
  from ncut_pytorch import eigenvector_to_rgb, rotate_rgb_cube
44
 
 
45
 
46
  DATASETS = {
47
  'Common': [
 
313
  blended = (1 - opacity1) * image + opacity2 * heatmap
314
  return blended.astype(np.uint8)
315
 
316
+
 
317
  def segment_fg_bg(images):
318
 
319
  images = F.interpolate(images, (224, 224), mode="bilinear")
 
386
 
387
 
388
  def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False, clusters=50, eig_idx=None, title='cluster'):
389
+ if clusters == 0:
390
+ return [], []
391
  progress = gr.Progress()
392
  progress(progess_start, desc="Finding Clusters by FPS")
393
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
462
  # fps_heatmaps[idx.item()] = heatmap.cpu()
463
  fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu()
464
  top3_image_idx[idx.item()] = mask_sort_idx[:3]
465
+ top10_image_idx[idx.item()] = mask_sort_idx[:6]
466
  # do the sorting
467
  _sort_idx = torch.tensor(sort_values).argsort(descending=True)
468
  fps_idx = fps_idx[_sort_idx]
 
480
  fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
481
 
482
 
483
+ def plot_cluster_images(fps_idx_chunk, chunk_idx):
484
+ fig, axs = plt.subplots(3, 5, figsize=(15, 9)) if not advanced else plt.subplots(6, 5, figsize=(15, 18))
 
 
 
 
 
 
 
 
485
  for ax in axs.flatten():
486
  ax.axis("off")
487
+ for j, idx in enumerate(fps_idx_chunk):
488
  heatmap = fps_heatmaps[idx.item()]
 
 
489
  size = (images.shape[1], images.shape[2])
490
  heatmap = apply_reds_colormap(heatmap, size)
 
491
  image_idxs = top3_image_idx[idx.item()] if not advanced else top10_image_idx[idx.item()]
492
  for i, image_idx in enumerate(image_idxs):
 
493
  _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
494
  axs[i, j].imshow(_heatmap)
495
  if i == 0:
496
+ axs[i, j].set_title(f"{title} {chunk_idx * 5 + j + 1}", fontsize=24)
 
497
  plt.tight_layout(h_pad=0.5, w_pad=0.3)
498
+ filename = f"{datetime.now():%Y%m%d%H%M%S%f}_{uuid.uuid4().hex}"
 
499
  tmp_path = f"/tmp/{filename}.png"
500
  plt.savefig(tmp_path, bbox_inches='tight', dpi=72)
501
+ img = Image.open(tmp_path).convert("RGB")
 
 
 
 
502
  os.remove(tmp_path)
 
 
503
  plt.close()
504
+ return img
505
+
506
+ fig_images = []
507
+ num_plots = clusters // 5
508
+ plot_step_float = (1.0 - progess_start) / num_plots
509
+ fps_idx_chunks = [fps_idx[i*5:(i+1)*5] for i in range(num_plots)]
510
+
511
+ # with mp.Pool(processes=mp.cpu_count()) as pool:
512
+ # results = [pool.apply_async(plot_cluster_images, args=(chunk, i)) for i, chunk in enumerate(fps_idx_chunks)]
513
+ # for i, result in enumerate(results):
514
+ # progress(progess_start + i * plot_step_float, desc=f"Plotted {title}")
515
+ # fig_images.append(result.get())
516
+ for i, chunk in enumerate(fps_idx_chunks):
517
+ progress(progess_start + i * plot_step_float, desc=f"Plotted {title}")
518
+ fig_images.append(plot_cluster_images(chunk, i))
519
 
520
  return fig_images, ret_magnitude
521
 
522
+ def make_cluster_plot_advanced(eigvecs, images, h=64, w=64):
523
  heatmap_fg, heatmap_bg = segment_fg_bg(images.clone())
524
  heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
525
  heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
 
540
  bg_idx = torch.arange(heatmap_bg.shape[0])[bg_mask]
541
  other_idx = torch.arange(heatmap_fg.shape[0])[other_mask]
542
 
543
+ fg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=fg_idx, title="fg")
544
+ bg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=20, eig_idx=bg_idx, title="bg")
545
+ other_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=0, eig_idx=other_idx, title="other")
 
 
 
 
 
 
546
 
547
  cluster_images = fg_images + bg_images + other_images
548
 
 
831
  if advanced:
832
  cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
833
  else:
834
+ cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=False)
835
  logging_str += f"plot time: {time.time() - start:.2f}s\n"
836
 
837
  norm_images = None
 
892
  def longer_run(*args, **kwargs):
893
  return _ncut_run(*args, **kwargs)
894
 
 
 
 
 
895
  @spaces.GPU(duration=120)
896
  def super_duper_long_run(*args, **kwargs):
897
  return _ncut_run(*args, **kwargs)
 
909
  def longer_run(*args, **kwargs):
910
  return _ncut_run(*args, **kwargs)
911
 
 
 
 
912
  def super_duper_long_run(*args, **kwargs):
913
  return _ncut_run(*args, **kwargs)
914
 
 
1223
  "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
1224
  "advanced": advanced,
1225
  }
1226
+ # print(kwargs)
 
 
 
1227
 
1228
  try:
1229
  # try to aquiare GPU, can fail if the user is out of GPU quota
 
1235
  return super_duper_long_run(model, images, **kwargs)
1236
 
1237
  num_images = len(images)
1238
+ if num_images >= 100:
1239
  return super_duper_long_run(model, images, **kwargs)
1240
  if 'diffusion' in model_name.lower():
1241
  return super_duper_long_run(model, images, **kwargs)
1242
  if recursion:
1243
  return longer_run(model, images, **kwargs)
1244
+ if num_images >= 50:
 
 
1245
  return longer_run(model, images, **kwargs)
1246
  if old_school_ncut:
1247
  return longer_run(model, images, **kwargs)
 
1261
  except gr.Error as e:
1262
  # I assume this is a GPU quota error
1263
 
1264
+ info1 = 'Running out of HuggingFace GPU Quota?</br> Please try <a style="white-space: nowrap;text-underline-offset: 2px;color: var(--body-text-color)" href="https://ncut-pytorch.readthedocs.io/en/latest/demo/">Demo hosted at UPenn</a></br>'
1265
  info2 = 'Or try use the Python package that powers this app: <a style="white-space: nowrap;text-underline-offset: 2px;color: var(--body-text-color)" href="https://ncut-pytorch.readthedocs.io/en/latest/">ncut-pytorch</a>'
1266
  info = info1 + info2
1267
 
 
1269
  raise gr.Error(message, duration=0)
1270
 
1271
 
1272
+ import torch
1273
+ from torch import nn
1274
+ from torch.utils.data import Dataset, DataLoader
1275
+ import pytorch_lightning as pl
1276
+
1277
+ # Custom Dataset
1278
+ class TwoTensorDataset(Dataset):
1279
+ def __init__(self, A, B):
1280
+ self.A = A
1281
+ self.B = B
1282
+
1283
+ def __len__(self):
1284
+ return len(self.A)
1285
+
1286
+ def __getitem__(self, idx):
1287
+ return self.A[idx], self.B[idx]
1288
+
1289
+ # MLP model
1290
+ class MLP(pl.LightningModule):
1291
+ def __init__(self, num_layer=3, width=512, lr=3e-4, fitting_steps=10000, seg_loss_lambda=1.0):
1292
+ super().__init__()
1293
+ layers = [nn.Linear(3, width), nn.GELU()]
1294
+ for _ in range(num_layer - 1):
1295
+ layers.append(nn.Linear(width, width))
1296
+ layers.append(nn.GELU())
1297
+ layers.append(nn.Linear(width, 3))
1298
+ self.layers = nn.Sequential(*layers)
1299
+ self.mse_loss = nn.MSELoss()
1300
+ self.lr = lr
1301
+ self.fitting_steps = fitting_steps
1302
+ self.seg_loss_lambda = seg_loss_lambda
1303
+ self.progress = gr.Progress()
1304
+
1305
+ def forward(self, x):
1306
+ return self.layers(x)
1307
+
1308
+ def training_step(self, batch, batch_idx):
1309
+ x, y = batch
1310
+ y_hat = self.forward(x)
1311
+ loss = self.mse_loss(y_hat, y)
1312
+ # loss = torch.nn.functional.mse_loss(torch.log(y_hat), torch.log(y))
1313
+ self.log("train_loss", loss)
1314
+
1315
+ # add segmentation constraint
1316
+ bsz = x.shape[0]
1317
+ sample_size = 1000
1318
+ if bsz > sample_size:
1319
+ idx = torch.randperm(bsz)[:sample_size]
1320
+ x = x[idx]
1321
+ y_hat = y_hat[idx]
1322
+
1323
+ old_dist = torch.pdist(x, p=2)
1324
+ new_dist = torch.pdist(y_hat, p=2)
1325
+ # seg_loss = torch.log((old_dist - new_dist)).pow(2).mean()
1326
+ seg_loss = self.mse_loss(old_dist, new_dist)
1327
+ self.log("seg_loss", seg_loss)
1328
+ loss += seg_loss * self.seg_loss_lambda
1329
+
1330
+ step = self.global_step
1331
+ if step % 100 == 0:
1332
+ self.progress(step / self.fitting_steps, desc="Fitting MLP")
1333
+
1334
+ return loss
1335
+
1336
+ def predict_step(self, batch, batch_idx, dataloader_idx=None):
1337
+ x = batch[0]
1338
+ y_hat = self.forward(x)
1339
+ return y_hat
1340
+
1341
+ def configure_optimizers(self):
1342
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
1343
+ return optimizer
1344
+
1345
+
1346
+ def fit_trans(rgb1, rgb2, num_layer=3, width=512, batch_size=256, lr=3e-4, fitting_steps=10000, fps_sample=4096, seg_loss_lambda=1.0):
1347
+ A = rgb1.clone()
1348
+ B = rgb2.clone()
1349
+
1350
+ # FPS sample on the data
1351
+ from ncut_pytorch.ncut_pytorch import farthest_point_sampling
1352
+ A_idx = farthest_point_sampling(A, fps_sample)
1353
+ B_idx = farthest_point_sampling(B, fps_sample)
1354
+ A_B_idx = np.concatenate([A_idx, B_idx])
1355
+ A = A[A_B_idx]
1356
+ B = B[A_B_idx]
1357
+
1358
+ from torch.utils.data import DataLoader, TensorDataset
1359
+ # Dataset and DataLoader
1360
+ dataset = TwoTensorDataset(A, B)
1361
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
1362
+
1363
+ # Initialize model and trainer
1364
+ mlp = MLP(num_layer=num_layer, width=width, lr=lr, fitting_steps=fitting_steps, seg_loss_lambda=seg_loss_lambda)
1365
+ trainer = pl.Trainer(
1366
+ auto_scale_batch_size='power',
1367
+ max_epochs=100000,
1368
+ gpus=1,
1369
+ max_steps=fitting_steps,
1370
+ enable_checkpointing=False,
1371
+ enable_progress_bar=False,
1372
+ gradient_clip_val=1.0
1373
+ )
1374
+
1375
+ # Create a DataLoader for tensor A
1376
+ batch_size = 256 # Define your batch size
1377
+ data_loader = DataLoader(TensorDataset(rgb1), batch_size=batch_size, shuffle=False)
1378
+
1379
+
1380
+ # Train the model
1381
+ trainer.fit(mlp, dataloader)
1382
+
1383
+
1384
+ results = trainer.predict(mlp, data_loader)
1385
+ A_transformed = torch.cat(results, dim=0)
1386
+
1387
+ return A_transformed
1388
+
1389
+ if USE_HUGGINGFACE_ZEROGPU:
1390
+ @spaces.GPU(duration=60)
1391
+ def _run_mlp_fit(*args, **kwargs):
1392
+ return fit_trans(*args, **kwargs)
1393
+ else:
1394
+ def _run_mlp_fit(*args, **kwargs):
1395
+ return fit_trans(*args, **kwargs)
1396
+
1397
+
1398
+ def run_mlp_fit(input_gallery, target_gallery, num_layer=3, width=512, batch_size=256, lr=3e-4, fitting_steps=10000, fps_sample=4096, seg_loss_lambda=1.0):
1399
+ # print("Fitting MLP")
1400
+ # print("Target Gallery Length:", len(target_gallery))
1401
+ # print("Input Gallery Length:", len(input_gallery))
1402
+ if target_gallery is None or len(target_gallery) == 0:
1403
+ raise gr.Error("No target images selected. Please use the Mark button to select the target images.")
1404
+ if input_gallery is None or len(input_gallery) == 0:
1405
+ raise gr.Error("No input images selected.")
1406
+ def gallery_to_rgb(gallery):
1407
+ images = [tup[0] for tup in gallery]
1408
+ rgb = []
1409
+ for image in images:
1410
+ if isinstance(image, str):
1411
+ image = Image.open(image)
1412
+ image = image.convert('RGB')
1413
+ image = torch.tensor(np.array(image)).float()
1414
+ image = image / 255
1415
+ rgb.append(image)
1416
+ rgb = torch.stack(rgb)
1417
+ shape = rgb.shape
1418
+ rgb = rgb.reshape(-1, 3)
1419
+ return rgb, shape
1420
+
1421
+ target_rgb, target_shape = gallery_to_rgb(target_gallery)
1422
+ input_rgb, input_shape = gallery_to_rgb(input_gallery)
1423
+
1424
+ input_transformed = _run_mlp_fit(input_rgb, target_rgb, num_layer=num_layer, width=width, batch_size=batch_size, lr=lr,
1425
+ fitting_steps=fitting_steps, fps_sample=fps_sample, seg_loss_lambda=seg_loss_lambda)
1426
+ input_transformed = input_transformed.reshape(*input_shape)
1427
+ pil_images = to_pil_images(input_transformed, resize=False)
1428
+ return pil_images
1429
+
1430
+
1431
  def make_input_video_section():
1432
  # gr.Markdown('### Input Video')
1433
  input_gallery = gr.Video(value=None, label="Select video", elem_id="video-input", height="auto", show_share_button=False, interactive=True)
 
1562
  def make_example(name, images, dataset_name):
1563
  with gr.Row():
1564
  button = gr.Button("Load\n"+name, elem_id=f"example-{name}", elem_classes="small-button", variant='secondary', size="sm", scale=1, min_width=60)
1565
+ gallery = gr.Gallery(value=images, label=name, show_label=True, columns=[3], rows=[1], interactive=False, height=80, scale=8, object_fit="cover", min_width=140, allow_preview=False)
1566
  button.click(fn=lambda: gr.update(value=load_dataset_images(True, dataset_name, 100, is_random=True, seed=42)), outputs=[input_gallery])
1567
  return gallery, button
1568
  example_items = [
 
1777
  images = to_pil_images(images, resize=False)
1778
  return images
1779
 
1780
+ def add_rotate_flip_buttons(output_gallery):
1781
  with gr.Row():
1782
  rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
1783
  rotate_button.click(sequence_rotate_rgb_gallery, inputs=[output_gallery], outputs=[output_gallery])
 
1896
  def make_output_images_section():
1897
  gr.Markdown('### Output Images')
1898
  output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
1899
+ add_rotate_flip_buttons(output_gallery)
1900
  return output_gallery
1901
 
1902
  def make_parameters_section(is_lisa=False, model_ratio=True):
 
2016
 
2017
  with gr.Column(scale=5, min_width=200):
2018
  output_gallery = make_output_images_section()
2019
+ cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2020
  [
2021
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
2022
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
 
2160
  with gr.Column(scale=5, min_width=200):
2161
  gr.Markdown('### Output (Recursion #1)')
2162
  l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2163
+ add_rotate_flip_buttons(l1_gallery)
2164
  with gr.Column(scale=5, min_width=200):
2165
  gr.Markdown('### Output (Recursion #2)')
2166
  l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2167
+ add_rotate_flip_buttons(l2_gallery)
2168
  with gr.Column(scale=5, min_width=200):
2169
  gr.Markdown('### Output (Recursion #3)')
2170
  l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2171
+ add_rotate_flip_buttons(l3_gallery)
2172
  with gr.Row():
2173
 
2174
  with gr.Column(scale=5, min_width=200):
 
2225
  with gr.Column(scale=5, min_width=200):
2226
  gr.Markdown('### Output (Recursion #1)')
2227
  l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2228
+ add_rotate_flip_buttons(l1_gallery)
2229
  add_download_button(l1_gallery, "ncut_embed_recur1")
2230
  l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2231
  add_download_button(l1_norm_gallery, "eig_norm_recur1")
 
2234
  with gr.Column(scale=5, min_width=200):
2235
  gr.Markdown('### Output (Recursion #2)')
2236
  l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2237
+ add_rotate_flip_buttons(l2_gallery)
2238
  add_download_button(l2_gallery, "ncut_embed_recur2")
2239
  l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2240
  add_download_button(l2_norm_gallery, "eig_norm_recur2")
 
2243
  with gr.Column(scale=5, min_width=200):
2244
  gr.Markdown('### Output (Recursion #3)')
2245
  l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2246
+ add_rotate_flip_buttons(l3_gallery)
2247
  add_download_button(l3_gallery, "ncut_embed_recur3")
2248
  l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2249
  add_download_button(l3_norm_gallery, "eig_norm_recur3")
 
2471
  # add_output_images_buttons(l3_gallery)
2472
  gr.Markdown('### Output (Recursion #1)')
2473
  l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
2474
+ add_rotate_flip_buttons(l1_gallery)
2475
  add_download_button(l1_gallery, "modelaligned_recur1")
2476
  gr.Markdown('### Output (Recursion #2)')
2477
  l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
2478
+ add_rotate_flip_buttons(l2_gallery)
2479
  add_download_button(l2_gallery, "modelaligned_recur2")
2480
  gr.Markdown('### Output (Recursion #3)')
2481
  l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
2482
+ add_rotate_flip_buttons(l3_gallery)
2483
  add_download_button(l3_gallery, "modelaligned_recur3")
2484
 
2485
  with gr.Row():
 
2549
  gr.Markdown(f'### Output Images')
2550
  output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2551
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
2552
+ add_rotate_flip_buttons(output_gallery)
2553
  [
2554
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
2555
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
 
2615
  buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
2616
 
2617
  with gr.Tab('Compare Models (Advanced)', visible=False) as tab_compare_models_advanced:
2618
+
2619
+ target_images = gr.State([])
2620
+ input_images = gr.State([])
2621
+ def add_mlp_fitting_buttons(output_gallery, mlp_gallery, target_images=target_images, input_images=input_images):
2622
+ with gr.Row():
2623
+ # mark_as_target_button = gr.Button("mark target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
2624
+ # mark_as_input_button = gr.Button("mark input", elem_id=f"mark_as_input_button_{output_gallery.elem_id}", variant='secondary')
2625
+ mark_as_target_button = gr.Button("🎯 Mark Target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
2626
+ fit_to_target_button = gr.Button("🔴 [MLP] Fit", elem_id=f"fit_to_target_button_{output_gallery.elem_id}", variant='primary')
2627
+ def mark_fn(images, text="target"):
2628
+ if images is None:
2629
+ raise gr.Error("No images selected")
2630
+ if len(images) == 0:
2631
+ raise gr.Error("No images selected")
2632
+ num_images = len(images)
2633
+ gr.Info(f"Marked {num_images} images as {text}")
2634
+ images = [(Image.open(tup[0]), []) for tup in images]
2635
+ return images
2636
+ mark_as_target_button.click(partial(mark_fn, text="target"), inputs=[output_gallery], outputs=[target_images])
2637
+ # mark_as_input_button.click(partial(mark_fn, text="input"), inputs=[output_gallery], outputs=[input_images])
2638
+
2639
+ with gr.Accordion("➡️ MLP Parameters", open=False):
2640
+ num_layers_slider = gr.Slider(2, 10, step=1, label="Number of Layers", value=3, elem_id=f"num_layers_slider_{output_gallery.elem_id}")
2641
+ width_slider = gr.Slider(128, 4096, step=128, label="Width", value=512, elem_id=f"width_slider_{output_gallery.elem_id}")
2642
+ batch_size_slider = gr.Slider(32, 4096, step=32, label="Batch Size", value=128, elem_id=f"batch_size_slider_{output_gallery.elem_id}")
2643
+ lr_slider = gr.Slider(1e-6, 1, step=1e-6, label="Learning Rate", value=3e-4, elem_id=f"lr_slider_{output_gallery.elem_id}")
2644
+ fitting_steps_slider = gr.Slider(1000, 100000, step=1000, label="Fitting Steps", value=30000, elem_id=f"fitting_steps_slider_{output_gallery.elem_id}")
2645
+ fps_sample_slider = gr.Slider(128, 50000, step=128, label="FPS Sample", value=10240, elem_id=f"fps_sample_slider_{output_gallery.elem_id}")
2646
+ segmentation_loss_lambda_slider = gr.Slider(0, 100, step=0.01, label="Segmentation Preserving Loss Lambda", value=1, elem_id=f"segmentation_loss_lambda_slider_{output_gallery.elem_id}")
2647
+
2648
+ fit_to_target_button.click(
2649
+ run_mlp_fit,
2650
+ inputs=[output_gallery, target_images, num_layers_slider, width_slider, batch_size_slider, lr_slider, fitting_steps_slider, fps_sample_slider, segmentation_loss_lambda_slider],
2651
+ outputs=[mlp_gallery],
2652
+ )
2653
+
2654
  def add_one_model(i_model=1):
2655
  with gr.Column(scale=5, min_width=200) as col:
2656
  gr.Markdown(f'### Output Images')
2657
  output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2658
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
2659
+ add_rotate_flip_buttons(output_gallery)
2660
  add_download_button(output_gallery, f"ncut_embed")
2661
+ mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2662
+ add_mlp_fitting_buttons(output_gallery, mlp_gallery)
2663
+ add_download_button(mlp_gallery, f"mlp_color_align")
2664
  norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2665
  add_download_button(norm_gallery, f"eig_norm")
2666
  cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
 
2690
  outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text]
2691
  )
2692
 
2693
+ output_gallery.change(lambda x: gr.update(value=x), inputs=[output_gallery], outputs=[mlp_gallery])
2694
+
2695
+ return output_gallery
2696
 
2697
+ galleries = []
2698
+
2699
  with gr.Row():
2700
  with gr.Column(scale=5, min_width=200):
2701
  input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
 
2703
 
2704
 
2705
  for i in range(3):
2706
+ g = add_one_model()
2707
+ galleries.append(g)
2708
 
2709
  # Create rows and buttons in a loop
2710
  rows = []
 
2717
  with row:
2718
  for j in range(4):
2719
  with gr.Column(scale=5, min_width=200):
2720
+ g = add_one_model()
2721
+ galleries.append(g)
2722
 
2723
  button = gr.Button("➕ Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True, scale=3)
2724
  buttons.append(button)
 
2734
  # Last button only reveals the last row and hides itself
2735
  buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
2736
  buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
2737
+
2738
+
2739
+ # add MLP fitting buttons
2740
+
2741
+
2742
 
2743
  with gr.Tab('📄About'):
2744
  with gr.Column():