huzey committed on
Commit
2ade645
1 Parent(s): a3d5c5a

add advanced tab for recursive

Browse files
Files changed (1) hide show
  1. app.py +136 -26
app.py CHANGED
@@ -325,9 +325,9 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
325
  mask = mask[mask_sort_idx[:3]]
326
  sort_values.append(mask.mean().item())
327
  # fps_heatmaps[idx.item()] = heatmap.cpu()
328
- fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:10]].cpu()
329
  top3_image_idx[idx.item()] = mask_sort_idx[:3]
330
- top10_image_idx[idx.item()] = mask_sort_idx[:10]
331
  # do the sorting
332
  _sort_idx = torch.tensor(sort_values).argsort(descending=True)
333
  fps_idx = fps_idx[_sort_idx]
@@ -427,6 +427,7 @@ def ncut_run(
427
  alignedcut_eig_norm_plot=False,
428
  **kwargs,
429
  ):
 
430
  progress = gr.Progress()
431
  progress(0.2, desc="Feature Extraction")
432
 
@@ -483,6 +484,7 @@ def ncut_run(
483
 
484
  if recursion:
485
  rgbs = []
 
486
  recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
487
  inp = features
488
  progress_start = 0.4
@@ -509,6 +511,7 @@ def ncut_run(
509
  progess_start=progress_start,
510
  )
511
  logging_str += _logging_str
 
512
 
513
 
514
  if "AlignedThreeModelAttnNodes" == model_name:
@@ -528,8 +531,42 @@ def ncut_run(
528
  inp = eigvecs.reshape(*features.shape[:-1], -1)
529
  if recursion_metric == "cosine":
530
  inp = F.normalize(inp, dim=-1)
531
- return rgbs[0], rgbs[1], rgbs[2], logging_str
532
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  if old_school_ncut: # individual images
534
  logging_str += "Running NCut for each image independently\n"
535
  rgb = []
@@ -643,7 +680,7 @@ def ncut_run(
643
  norm_images = to_pil_images(norm_images)
644
  logging_str += "Eigenvector Magnitude\n"
645
  logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
646
- gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=0)
647
 
648
  return to_pil_images(rgb), cluster_images, norm_images, logging_str
649
 
@@ -651,26 +688,26 @@ def ncut_run(
651
 
652
  def _ncut_run(*args, **kwargs):
653
  n_ret = kwargs.pop("n_ret", 1)
654
- # try:
655
- # if torch.cuda.is_available():
656
- # torch.cuda.empty_cache()
657
 
658
- # ret = ncut_run(*args, **kwargs)
659
 
660
- # if torch.cuda.is_available():
661
- # torch.cuda.empty_cache()
662
 
663
- # ret = list(ret)[:n_ret] + [ret[-1]]
664
- # return ret
665
- # except Exception as e:
666
- # gr.Error(str(e))
667
- # if torch.cuda.is_available():
668
- # torch.cuda.empty_cache()
669
- # return *(None for _ in range(n_ret)), "Error: " + str(e)
670
-
671
- ret = ncut_run(*args, **kwargs)
672
- ret = list(ret)[:n_ret] + [ret[-1]]
673
- return ret
674
 
675
  if USE_HUGGINGFACE_ZEROGPU:
676
  @spaces.GPU(duration=30)
@@ -1407,7 +1444,7 @@ with demo:
1407
  scroll_to_output=True,
1408
  )
1409
 
1410
- with gr.Tab('AlignedCut (+Norm Plot)', visible=False) as tab_alignedcut_norm:
1411
 
1412
  with gr.Row():
1413
  with gr.Column(scale=5, min_width=200):
@@ -1582,6 +1619,78 @@ with demo:
1582
  outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text],
1583
  api_name="API_RecursiveCut"
1584
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1585
 
1586
 
1587
  with gr.Tab('Video'):
@@ -1741,7 +1850,7 @@ with demo:
1741
  outputs=[output_gallery, logging_text],
1742
  )
1743
 
1744
- with gr.Tab('Model Aligned (+Rrecursion)', visible=False) as tab_model_aligned_recursion:
1745
  gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
1746
  gr.Markdown('---')
1747
  gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
@@ -1965,8 +2074,9 @@ with demo:
1965
  return gr.update()
1966
 
1967
  hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
1968
- hidden_button.change(unlock_tabs_with_info, n_smiles, tab_alignedcut_norm)
1969
- hidden_button.change(unlock_tabs, n_smiles, tab_model_aligned_recursion)
 
1970
 
1971
  with gr.Row():
1972
  with gr.Column():
 
325
  mask = mask[mask_sort_idx[:3]]
326
  sort_values.append(mask.mean().item())
327
  # fps_heatmaps[idx.item()] = heatmap.cpu()
328
+ fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu()
329
  top3_image_idx[idx.item()] = mask_sort_idx[:3]
330
+ top10_image_idx[idx.item()] = mask_sort_idx[:6]
331
  # do the sorting
332
  _sort_idx = torch.tensor(sort_values).argsort(descending=True)
333
  fps_idx = fps_idx[_sort_idx]
 
427
  alignedcut_eig_norm_plot=False,
428
  **kwargs,
429
  ):
430
+ advanced = kwargs.get("advanced", False)
431
  progress = gr.Progress()
432
  progress(0.2, desc="Feature Extraction")
433
 
 
484
 
485
  if recursion:
486
  rgbs = []
487
+ all_eigvecs = []
488
  recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
489
  inp = features
490
  progress_start = 0.4
 
511
  progess_start=progress_start,
512
  )
513
  logging_str += _logging_str
514
+ all_eigvecs.append(eigvecs.cpu().clone())
515
 
516
 
517
  if "AlignedThreeModelAttnNodes" == model_name:
 
531
  inp = eigvecs.reshape(*features.shape[:-1], -1)
532
  if recursion_metric == "cosine":
533
  inp = F.normalize(inp, dim=-1)
534
+
535
+ if not advanced:
536
+ return rgbs[0], rgbs[1], rgbs[2], logging_str
537
+ if advanced:
538
+ cluster_plots, norm_plots = [], []
539
+ for i in range(3):
540
+ eigvecs = all_eigvecs[i]
541
+ # add norm plot, cluster plot
542
+ start = time.time()
543
+ progress_start = 0.6
544
+ progress(progress_start, desc=f"Plotting Clusters Recursion #{i+1}")
545
+ h, w = features.shape[1], features.shape[2]
546
+ if torch.cuda.is_available():
547
+ images = images.cuda()
548
+ _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
549
+ cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=advanced)
550
+ logging_str += f"Recursion #{i+1} plot time: {time.time() - start:.2f}s\n"
551
+
552
+ norm_images = []
553
+ vmin, vmax = eig_magnitude.min(), eig_magnitude.max()
554
+ eig_magnitude = (eig_magnitude - vmin) / (vmax - vmin)
555
+ eig_magnitude = eig_magnitude.cpu().numpy()
556
+ colormap = matplotlib.colormaps['Reds']
557
+ for i_image in range(eig_magnitude.shape[0]):
558
+ norm_image = colormap(eig_magnitude[i_image])
559
+ norm_images.append(torch.tensor(norm_image[..., :3]))
560
+ norm_images = to_pil_images(norm_images)
561
+ logging_str += f"Recursion #{i+1} Eigenvector Magnitude: [{vmin:.2f}, {vmax:.2f}]\n"
562
+ gr.Info(f"Recursion #{i+1} Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
563
+
564
+ cluster_plots.append(cluster_images)
565
+ norm_plots.append(norm_images)
566
+
567
+ return *rgbs, *norm_plots, *cluster_plots, logging_str
568
+
569
+
570
  if old_school_ncut: # individual images
571
  logging_str += "Running NCut for each image independently\n"
572
  rgb = []
 
680
  norm_images = to_pil_images(norm_images)
681
  logging_str += "Eigenvector Magnitude\n"
682
  logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
683
+ gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
684
 
685
  return to_pil_images(rgb), cluster_images, norm_images, logging_str
686
 
 
688
 
689
def _ncut_run(*args, **kwargs):
    """Safety wrapper around ``ncut_run``.

    Clears the CUDA allocator cache before and after the run, trims the
    returned tuple down to ``n_ret`` gallery outputs plus the trailing
    logging string, and converts any exception into ``None`` placeholders
    with an error message so the Gradio UI does not hang.
    """
    n_ret = kwargs.pop("n_ret", 1)

    def _release_cuda_cache():
        # Drop cached GPU blocks between runs to reduce OOM risk on shared GPUs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    try:
        _release_cuda_cache()

        outputs = ncut_run(*args, **kwargs)

        _release_cuda_cache()

        # First n_ret visual outputs, always followed by the logging string
        # (last element of whatever ncut_run returned).
        return list(outputs)[:n_ret] + [outputs[-1]]
    except Exception as err:
        # NOTE(review): gr.Error is a Gradio exception type; constructing it
        # without `raise` shows nothing in the UI — confirm whether
        # `raise gr.Error(...)` or `gr.Warning(...)` was intended here.
        gr.Error(str(err))
        _release_cuda_cache()
        # None placeholders keep the output galleries consistent; the error
        # text lands in the logging textbox.
        return *(None for _ in range(n_ret)), "Error: " + str(err)
711
 
712
  if USE_HUGGINGFACE_ZEROGPU:
713
  @spaces.GPU(duration=30)
 
1444
  scroll_to_output=True,
1445
  )
1446
 
1447
+ with gr.Tab('AlignedCut (Advanced)', visible=False) as tab_alignedcut_advanced:
1448
 
1449
  with gr.Row():
1450
  with gr.Column(scale=5, min_width=200):
 
1619
  outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text],
1620
  api_name="API_RecursiveCut"
1621
  )
1622
+
1623
+ with gr.Tab('Recursive Cut (Advanced)', visible=False) as tab_recursivecut_advanced:
1624
+
1625
+ with gr.Row():
1626
+ with gr.Column(scale=5, min_width=200):
1627
+ gr.Markdown('### Output (Recursion #1)')
1628
+ l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
1629
+ add_output_images_buttons(l1_gallery)
1630
+ l1_norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1631
+ l1_cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
1632
+ with gr.Column(scale=5, min_width=200):
1633
+ gr.Markdown('### Output (Recursion #2)')
1634
+ l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
1635
+ add_output_images_buttons(l2_gallery)
1636
+ l2_norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1637
+ l2_cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
1638
+ with gr.Column(scale=5, min_width=200):
1639
+ gr.Markdown('### Output (Recursion #3)')
1640
+ l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
1641
+ add_output_images_buttons(l3_gallery)
1642
+ l3_norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1643
+ l3_cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
1644
+ with gr.Row():
1645
+ with gr.Column(scale=5, min_width=200):
1646
+ input_gallery, submit_button, clear_images_button = make_input_images_section()
1647
+ dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
1648
+ num_images_slider.value = 100
1649
+ clear_images_button.visible = False
1650
+ logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
1651
+
1652
+ with gr.Column(scale=5, min_width=200):
1653
+ with gr.Accordion("➡️ Recursion config", open=True):
1654
+ l1_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #1: N eigenvectors", value=100, elem_id="l1_num_eig")
1655
+ l2_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #2: N eigenvectors", value=50, elem_id="l2_num_eig")
1656
+ l3_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #3: N eigenvectors", value=50, elem_id="l3_num_eig")
1657
+ metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="Recursion distance metric", value="cosine", elem_id="recursion_metric")
1658
+ l1_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #1: Affinity focal gamma", value=0.5, elem_id="recursion_l1_gamma")
1659
+ l2_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #2: Affinity focal gamma", value=0.5, elem_id="recursion_l2_gamma")
1660
+ l3_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #3: Affinity focal gamma", value=0.5, elem_id="recursion_l3_gamma")
1661
+ [
1662
+ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
1663
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
1664
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
1665
+ perplexity_slider, n_neighbors_slider, min_dist_slider,
1666
+ sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
1667
+ ] = make_parameters_section()
1668
+ num_eig_slider.visible = False
1669
+ affinity_focal_gamma_slider.visible = False
1670
+ true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder")
1671
+ true_placeholder.visible = False
1672
+ false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder")
1673
+ false_placeholder.visible = False
1674
+ number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
1675
+ number_placeholder.visible = False
1676
+ clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
1677
+ no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
1678
+
1679
+ submit_button.click(
1680
+ partial(run_fn, n_ret=9, advanced=True),
1681
+ inputs=[
1682
+ input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
1683
+ positive_prompt, negative_prompt,
1684
+ false_placeholder, no_prompt, no_prompt, no_prompt,
1685
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
1686
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
1687
+ perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
1688
+ false_placeholder, number_placeholder, true_placeholder,
1689
+ l2_num_eig_slider, l3_num_eig_slider, metric_dropdown,
1690
+ l1_affinity_focal_gamma_slider, l2_affinity_focal_gamma_slider, l3_affinity_focal_gamma_slider
1691
+ ],
1692
+ outputs=[l1_gallery, l2_gallery, l3_gallery, l1_norm_gallery, l2_norm_gallery, l3_norm_gallery, l1_cluster_gallery, l2_cluster_gallery, l3_cluster_gallery, logging_text],
1693
+ )
1694
 
1695
 
1696
  with gr.Tab('Video'):
 
1850
  outputs=[output_gallery, logging_text],
1851
  )
1852
 
1853
+ with gr.Tab('Model Aligned (Advanced)', visible=False) as tab_model_aligned_advanced:
1854
  gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
1855
  gr.Markdown('---')
1856
  gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
 
2074
  return gr.update()
2075
 
2076
  hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
2077
+ hidden_button.change(unlock_tabs_with_info, n_smiles, tab_alignedcut_advanced)
2078
+ hidden_button.change(unlock_tabs, n_smiles, tab_model_aligned_advanced)
2079
+ hidden_button.change(unlock_tabs, n_smiles, tab_recursivecut_advanced)
2080
 
2081
  with gr.Row():
2082
  with gr.Column():