huzey committed on
Commit
753a147
1 Parent(s): 8a63e65

added more tabs

Files changed (1)
app.py +192 -76
app.py CHANGED
@@ -266,6 +266,10 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
     magnitude = torch.norm(eigvecs, dim=-1)
     p = 0.8
     top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
+
+    ret_magnitude = magnitude.reshape(-1, h, w)
+
+
     num_samples = 300
     if num_samples > top_p_idx.shape[0]:
         num_samples = top_p_idx.shape[0]
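Note: this hunk adds `ret_magnitude` so the cluster plot can also report per-pixel eigenvector magnitude. A minimal sketch of what that value holds, assuming `eigvecs` is flattened to one row per pixel as the surrounding code suggests (shapes below are assumptions, not taken from app.py):

```python
import torch

num_images, h, w, num_eig = 2, 64, 64, 30
eigvecs = torch.randn(num_images * h * w, num_eig)  # one row per pixel (assumed layout)

magnitude = torch.norm(eigvecs, dim=-1)       # L2 norm per pixel -> (num_images*h*w,)
ret_magnitude = magnitude.reshape(-1, h, w)   # one heatmap per image -> (num_images, h, w)
```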
@@ -368,7 +372,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
     # plt.imshow(img)
     # plt.axis("off")
     # plt.show()
-    return fig_images
+    return fig_images, ret_magnitude
 
 
 def ncut_run(
@@ -405,6 +409,9 @@ def ncut_run(
     lisa_prompt1="",
     lisa_prompt2="",
     lisa_prompt3="",
+    plot_clusters=False,
+    alignedcut_eig_norm_plot=False,
+    **kwargs,
 ):
     progress = gr.Progress()
     progress(0.2, desc="Feature Extraction")
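Note: the new `**kwargs` catch-all lets `ncut_run` ignore keyword arguments it does not declare, which matters because `run_fn` below forwards its whole kwargs dict (including keys such as `"n_ret"` that `ncut_run` never consumes). A minimal sketch with a hypothetical name:

```python
# Sketch: **kwargs absorbs undeclared keys instead of raising TypeError.
def ncut_run_sketch(plot_clusters=False, **kwargs):
    return plot_clusters, kwargs

print(ncut_run_sketch(plot_clusters=True, n_ret=3))  # (True, {'n_ret': 3})
```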
@@ -538,63 +545,53 @@ def ncut_run(
             )
             logging_str += _logging_str
             rgb.append(_rgb[0])
+        return rgb, logging_str
 
 
-    cluster_images = None
-    if not old_school_ncut: # ailgnedcut, joint across all images
-        rgb, _logging_str, eigvecs = compute_ncut(
-            features,
-            num_eig=num_eig,
-            num_sample_ncut=num_sample_ncut,
-            affinity_focal_gamma=affinity_focal_gamma,
-            knn_ncut=knn_ncut,
-            knn_tsne=knn_tsne,
-            num_sample_tsne=num_sample_tsne,
-            embedding_method=embedding_method,
-            embedding_metric=embedding_metric,
-            perplexity=perplexity,
-            n_neighbors=n_neighbors,
-            min_dist=min_dist,
-            sampling_method=sampling_method,
-            indirect_connection=indirect_connection,
-            make_orthogonal=make_orthogonal,
-            metric=ncut_metric,
-        )
-        logging_str += _logging_str
-
-        if "AlignedThreeModelAttnNodes" == model_name:
-            # dirty patch for the alignedcut paper
-            start = time.time()
-            progress(0.6, desc="Plotting")
-            pil_images = []
-            for i_image in range(rgb.shape[0]):
-                _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
-                pil_images.append(_im)
-            logging_str += f"plot time: {time.time() - start:.2f}s\n"
-            return pil_images, logging_str
-
-
-        if is_lisa == True:
-            # dirty patch for the LISA model
-            galleries = []
-            for i_prompt in range(len(lisa_prompts)):
-                _rgb = rgb[i_prompt]
-                galleries.append(to_pil_images(_rgb))
-            return *galleries, logging_str
-
-        rgb = dont_use_too_much_green(rgb)
-
-        if not video_output:
-            start = time.time()
-            progress_start = 0.6
-            progress(progress_start, desc="Plotting Clusters")
-            h, w = features.shape[1], features.shape[2]
-            if torch.cuda.is_available():
-                images = images.cuda()
-            _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
-            cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start)
-            logging_str += f"plot time: {time.time() - start:.2f}s\n"
-
+    # ailgnedcut
+
+    rgb, _logging_str, eigvecs = compute_ncut(
+        features,
+        num_eig=num_eig,
+        num_sample_ncut=num_sample_ncut,
+        affinity_focal_gamma=affinity_focal_gamma,
+        knn_ncut=knn_ncut,
+        knn_tsne=knn_tsne,
+        num_sample_tsne=num_sample_tsne,
+        embedding_method=embedding_method,
+        embedding_metric=embedding_metric,
+        perplexity=perplexity,
+        n_neighbors=n_neighbors,
+        min_dist=min_dist,
+        sampling_method=sampling_method,
+        indirect_connection=indirect_connection,
+        make_orthogonal=make_orthogonal,
+        metric=ncut_metric,
+    )
+    logging_str += _logging_str
+
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        start = time.time()
+        progress(0.6, desc="Plotting")
+        pil_images = []
+        for i_image in range(rgb.shape[0]):
+            _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
+            pil_images.append(_im)
+        logging_str += f"plot time: {time.time() - start:.2f}s\n"
+        return pil_images, logging_str
+
+
+    if is_lisa == True:
+        # dirty patch for the LISA model
+        galleries = []
+        for i_prompt in range(len(lisa_prompts)):
+            _rgb = rgb[i_prompt]
+            galleries.append(to_pil_images(_rgb))
+        return *galleries, logging_str
+
+    rgb = dont_use_too_much_green(rgb)
 
     if video_output:
         progress(0.8, desc="Saving Video")
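Note: the restructuring in this hunk is a guard-clause refactor: the legacy `old_school_ncut` branch now ends with `return rgb, logging_str`, so the AlignedCut path drops out of the `if not old_school_ncut:` block and loses one indentation level. Schematically (simplified, hypothetical names):

```python
# Before: both paths nested under a condition. After: early return, flat main path.
def run_sketch(old_school_ncut=False):
    if old_school_ncut:
        return "legacy result"    # early return replaces the nested branch
    return "alignedcut result"    # main path, no longer indented under a condition

print(run_sketch(), run_sketch(old_school_ncut=True))
```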
@@ -602,9 +599,37 @@ def ncut_run(
         video_cache.add_video(video_path)
         pil_images_to_video(to_pil_images(rgb), video_path, fps=5)
         return video_path, logging_str
-
-
-    return to_pil_images(rgb), cluster_images, logging_str
+
+    cluster_images = None
+    if plot_clusters:
+        start = time.time()
+        progress_start = 0.6
+        progress(progress_start, desc="Plotting Clusters")
+        h, w = features.shape[1], features.shape[2]
+        if torch.cuda.is_available():
+            images = images.cuda()
+        _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
+        cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start)
+        logging_str += f"plot time: {time.time() - start:.2f}s\n"
+
+    norm_images = None
+    if alignedcut_eig_norm_plot:
+        norm_images = []
+        # eig_magnitude = torch.clamp(eig_magnitude, 0, 1)
+        vmin, vmax = eig_magnitude.min(), eig_magnitude.max()
+        eig_magnitude = (eig_magnitude - vmin) / (vmax - vmin)
+        eig_magnitude = eig_magnitude.cpu().numpy()
+        colormap = matplotlib.colormaps['Reds']
+        for i_image in range(eig_magnitude.shape[0]):
+            norm_image = colormap(eig_magnitude[i_image])
+            norm_image = (norm_image[..., :3] * 255).astype(np.uint8)
+            norm_images.append(Image.fromarray(norm_image))
+        logging_str += "Eigenvector Magnitude\n"
+        logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
+        gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=0)
+
+    return to_pil_images(rgb), cluster_images, norm_images, logging_str
+
 
 
 def _ncut_run(*args, **kwargs):
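Note: the new `alignedcut_eig_norm_plot` branch renders the magnitude maps as red heatmaps. A standalone sketch of that rendering (shapes are assumed; `matplotlib.colormaps` requires matplotlib >= 3.6):

```python
import matplotlib
import numpy as np
import torch
from PIL import Image

eig_magnitude = torch.rand(2, 64, 64)                    # (num_images, h, w), assumed
vmin, vmax = eig_magnitude.min(), eig_magnitude.max()
eig_magnitude = ((eig_magnitude - vmin) / (vmax - vmin)).cpu().numpy()  # min-max to [0, 1]

colormap = matplotlib.colormaps['Reds']                  # maps [0, 1] -> RGBA floats
norm_images = [
    Image.fromarray((colormap(m)[..., :3] * 255).astype(np.uint8))  # drop alpha, to uint8 RGB
    for m in eig_magnitude
]
```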
@@ -828,6 +853,8 @@ def run_fn(
     recursion_l2_gamma=0.5,
     recursion_l3_gamma=0.5,
     n_ret=1,
+    plot_clusters=False,
+    alignedcut_eig_norm_plot=False,
 ):
 
     progress=gr.Progress()
@@ -958,6 +985,8 @@ def run_fn(
         "lisa_prompt3": lisa_prompt3,
         "is_lisa": is_lisa,
         "n_ret": n_ret,
+        "plot_clusters": plot_clusters,
+        "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
     }
     # print(kwargs)
 
@@ -1303,10 +1332,19 @@ def make_parameters_section(is_lisa=False, model_ratio=True):
             perplexity_slider, n_neighbors_slider, min_dist_slider,
             sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt]
 
+custom_css = """
+#unlock_button {
+    all: unset !important;
+}
+.form:has(#unlock_button) {
+    all: unset !important;
+}
+"""
 demo = gr.Blocks(
     theme=gr.themes.Base(spacing_size='md', text_size='lg', primary_hue='blue', neutral_hue='slate', secondary_hue='pink'),
     # fill_width=False,
     # title="ncut-pytorch",
+    css=custom_css,
 )
 with demo:
     with gr.Tab('AlignedCut'):
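Note: `gr.Blocks(css=...)` accepts a raw CSS string, and the `elem_id` set on a component becomes its HTML id, which is what the `#unlock_button` selector (and the `.form:has(#unlock_button)` rule for its wrapping form) targets. A minimal sketch of the hookup:

```python
import gradio as gr

css = "#unlock_button { all: unset !important; }"
with gr.Blocks(css=css) as sketch:
    gr.Checkbox(label="🤗", elem_id="unlock_button")  # elem_id -> HTML id for the CSS hook
```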
@@ -1336,7 +1374,7 @@ with demo:
         no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
 
         submit_button.click(
-            partial(run_fn, n_ret=2),
+            partial(run_fn, n_ret=2, plot_clusters=True),
             inputs=[
                 input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                 positive_prompt, negative_prompt,
@@ -1350,7 +1388,46 @@ with demo:
             scroll_to_output=True,
         )
 
+    with gr.Tab('AlignedCut (+Norm Plot)', visible=False) as tab_alignedcut_norm:
+
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button = make_input_images_section()
+                dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section()
+                num_images_slider.value = 30
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+
+            with gr.Column(scale=5, min_width=200):
+                cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[5], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=True, interactive=False)
+                output_gallery = make_output_images_section()
+                norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section()
+                num_eig_slider.value = 30
+
+        clear_images_button.click(lambda x: ([], [], [], []), outputs=[input_gallery, output_gallery, cluster_gallery, norm_gallery])
+
+        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
 
+        submit_button.click(
+            partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True),
+            inputs=[
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                positive_prompt, negative_prompt,
+                false_placeholder, no_prompt, no_prompt, no_prompt,
+                affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
+            ],
+            outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text],
+            scroll_to_output=True,
+        )
 
     with gr.Tab('NCut'):
         gr.Markdown('#### NCut (Legacy), not aligned, no Nyström approximation')
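Note: in both click handlers above, `functools.partial` pre-binds the keyword flags (`n_ret`, `plot_clusters`, `alignedcut_eig_norm_plot`) while Gradio later calls the wrapped handler with the component values as positional arguments. A minimal sketch with a hypothetical signature:

```python
from functools import partial

def run_fn_sketch(*component_values, n_ret=1, plot_clusters=False, alignedcut_eig_norm_plot=False):
    return n_ret, plot_clusters, alignedcut_eig_norm_plot

handler = partial(run_fn_sketch, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True)
print(handler("gallery", "model"))  # component values arrive positionally; flags stay bound
```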
@@ -1645,7 +1722,7 @@ with demo:
             outputs=[output_gallery, logging_text],
         )
 
-    with gr.Tab('Model Aligned (+Rrecursion)'):
+    with gr.Tab('Model Aligned (+Rrecursion)', visible=False) as tab_model_aligned_recursion:
         gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
         gr.Markdown('---')
         gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
@@ -1816,23 +1893,62 @@ with demo:
 
 
     with gr.Tab('📄About'):
-        gr.Markdown("**This demo is for the Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")
-        gr.Markdown("**All the models and functions used for this demo are in the Python package `ncut-pytorch`**")
-        gr.Markdown("---")
-        gr.Markdown("---")
-        gr.Markdown("**Normalized Cuts**, aka. spectral clustering, is a graphical method to analyze data grouping in the affinity eigenvector space. It has been widely used for unsupervised segmentation in the 2000s.")
-        gr.Markdown("*Normalized Cuts and Image Segmentation, Jianbo Shi and Jitendra Malik, 2000*")
-        gr.Markdown("---")
-        gr.Markdown("**We have improved NCut, with some advanced features:**")
-        gr.Markdown("- **Nyström** Normalized Cut, is a new approximation algorithm developed for large-scale graph cuts, a large-graph of million nodes can be processed in under 10s (cpu) or 2s (gpu).")
-        gr.Markdown("- **spectral-tSNE** visualization, a new method to visualize the high-dimensional eigenvector space with 3D RGB cube. Color is aligned across images, color infers distance in representation.")
-        gr.Markdown("*paper in prep, Yang 2024*")
-        gr.Markdown("*AlignedCut: Visual Concepts Discovery on Brain-Guided Universal Feature Space, Huzheng Yang, James Gee\*, and Jianbo Shi\*, 2024*")
-        gr.Markdown("---")
-        gr.Markdown("---")
-        gr.Markdown('<p style="text-align: center;">We thank the HuggingFace team for hosting this demo.</p>')
+        with gr.Column():
+            gr.Markdown("**This demo is for the Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")
+            gr.Markdown("**All the models and functions used for this demo are in the Python package `ncut-pytorch`**")
+            gr.Markdown("---")
+            gr.Markdown("---")
+            gr.Markdown("**Normalized Cuts**, aka. spectral clustering, is a graphical method to analyze data grouping in the affinity eigenvector space. It has been widely used for unsupervised segmentation in the 2000s.")
+            gr.Markdown("*Normalized Cuts and Image Segmentation, Jianbo Shi and Jitendra Malik, 2000*")
+            gr.Markdown("---")
+            gr.Markdown("**We have improved NCut, with some advanced features:**")
+            gr.Markdown("- **Nyström** Normalized Cut, is a new approximation algorithm developed for large-scale graph cuts, a large-graph of million nodes can be processed in under 10s (cpu) or 2s (gpu).")
+            gr.Markdown("- **spectral-tSNE** visualization, a new method to visualize the high-dimensional eigenvector space with 3D RGB cube. Color is aligned across images, color infers distance in representation.")
+            gr.Markdown("*paper in prep, Yang 2024*")
+            gr.Markdown("*AlignedCut: Visual Concepts Discovery on Brain-Guided Universal Feature Space, Huzheng Yang, James Gee\*, and Jianbo Shi\*, 2024*")
+            gr.Markdown("---")
+            gr.Markdown("---")
+            gr.Markdown('<p style="text-align: center;">We thank the HuggingFace team for hosting this demo.</p>')
 
+        # unlock the hidden tab
+        with gr.Row():
+            with gr.Column(scale=5):
+                gr.Markdown("")
+            with gr.Column(scale=5):
+                hidden_button = gr.Checkbox(label="🤗", value=False, elem_id="unlock_button", visible=True, interactive=True)
+            with gr.Column(scale=5):
+                gr.Markdown("")
+
+        n_smiles = gr.State(0)
+        unlock_value = 6
+
+        def update_smile(n_smiles):
+            n_smiles = n_smiles + 1
+            n_smiles = unlock_value if n_smiles > unlock_value else n_smiles
+            if n_smiles == unlock_value - 2:
+                gr.Info("click one more time to unlock", 2)
+            if n_smiles == unlock_value:
+                label = "🔓 unlocked"
+                return n_smiles, gr.update(label=label, value=True, interactive=False)
+            label = ["😊"] * n_smiles
+            label = "".join(label)
+            return n_smiles, gr.update(label=label, value=False)
 
+        def unlock_tabs_with_info(n_smiles):
+            if n_smiles == unlock_value:
+                gr.Info("🔓 unlocked tabs", 2)
+                return gr.update(visible=True)
+            return gr.update()
+
+        def unlock_tabs(n_smiles):
+            if n_smiles == unlock_value:
+                return gr.update(visible=True)
+            return gr.update()
+
+        hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
+        hidden_button.change(unlock_tabs_with_info, n_smiles, tab_alignedcut_norm)
+        hidden_button.change(unlock_tabs, n_smiles, tab_model_aligned_recursion)
+
     with gr.Row():
         with gr.Column():
             gr.Markdown("##### This demo is for `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/) ")
 