huzey committed
Commit 5e7bae6
Parent(s): 5f0daa6

fix align3model load

Files changed (1): app.py (+67 -22)
app.py CHANGED
@@ -678,13 +678,14 @@ def plot_one_image_36_grid(original_image, tsne_rgb_images):
     return img
 
 def load_alignedthreemodel():
-
-    os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
-    # pull
-    os.system("git -C alignedthreeattn pull >> /dev/null 2>&1")
-    # add to path
     import sys
-    sys.path.append("alignedthreeattn")
+
+    if "alignedthreeattn" not in sys.path:
+        for _ in range(3):
+            os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
+            os.system("git -C alignedthreeattn pull >> /dev/null 2>&1")
+        # add to path
+        sys.path.append("alignedthreeattn")
 
 
     from alignedthreeattn.alignedthreeattn_model import ThreeAttnNodes
@@ -692,11 +693,6 @@ def load_alignedthreemodel():
     align_weights = torch.load("alignedthreeattn/align_weights.pth")
     model = ThreeAttnNodes(align_weights)
 
-    # url = 'https://huggingface.co/huzey/aligned_model_test/resolve/main/3attn_nodes.pth'
-    # save_path = "alignedthreemodel.pth"
-    # if not os.path.exists(save_path):
-    #     os.system(f"wget {url} -O {save_path} -q")
-    # model = torch.load(save_path)
     return model
 
 promptable_diffusion_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
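For context on the fix: the old loader cloned unconditionally on every call, so a transient network failure (or an already-present checkout) could leave `alignedthreeattn` unimportable. The new version guards on `sys.path` and retries the clone/pull up to three times. Below is a minimal standalone sketch of the same idea, assuming nothing beyond the repo URL from the diff; `ensure_repo` is a hypothetical helper, and it swaps `os.system` for `subprocess.run` so failures surface as return codes rather than being silenced by the `>> /dev/null` redirect.

```python
import subprocess
import sys
from pathlib import Path

REPO_URL = "https://huggingface.co/huzey/alignedthreeattn"
REPO_DIR = "alignedthreeattn"

def ensure_repo(retries: int = 3) -> None:
    """Clone the repo if missing, otherwise pull; retry to absorb flaky networks."""
    for _ in range(retries):
        if Path(REPO_DIR, ".git").exists():
            cmd = ["git", "-C", REPO_DIR, "pull"]
        else:
            cmd = ["git", "clone", REPO_URL, REPO_DIR]
        if subprocess.run(cmd, capture_output=True).returncode == 0:
            break
    else:
        # for/else: reached only when no attempt succeeded
        raise RuntimeError(f"could not fetch {REPO_URL} after {retries} attempts")
    if REPO_DIR not in sys.path:
        sys.path.append(REPO_DIR)  # mirrors the diff's sys.path bookkeeping
```

The `for`/`else` falls through to `raise` only when every attempt failed, which is the retry semantics the diff's bare three-iteration loop approximates.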
@@ -1174,7 +1170,7 @@ with demo:
     with gr.Column(scale=5, min_width=200):
         input_gallery, submit_button, clear_images_button = make_input_images_section()
         dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section()
-        logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+        logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
 
     with gr.Column(scale=5, min_width=200):
         output_gallery = make_output_images_section()
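A side note on the changed line: `autofocus` and `autoscroll` are `gr.Textbox` constructor flags, and setting both to `False` presumably stops the logging box from grabbing keyboard focus on load or scrolling as new log text streams in.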
@@ -1490,17 +1486,65 @@ with demo:
     # logging text box
     logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
-    # galleries = []
-    # for i_model, model_name in enumerate(["CLIP", "DINO", "MAE"]):
-    #     with gr.Row():
-    #         for i_layer in range(1, 13):
-    #             with gr.Column(scale=5, min_width=200):
-    #                 gr.Markdown(f'### {model_name} Layer {i_layer}')
-    #                 output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True)
-    #                 galleries.append(output_gallery)
-
+    clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
+
+    false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+    no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+    submit_button.click(
+        run_fn,
+        inputs=[
+            input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+            positive_prompt, negative_prompt,
+            false_placeholder, no_prompt, no_prompt, no_prompt,
+            affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+            embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+            perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
+        ],
+        # outputs=galleries + [logging_text],
+        outputs=[output_gallery, logging_text],
+    )
 
-    # clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+    with gr.Tab('Model Aligned (+Recursive)'):
+        gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
+        gr.Markdown('---')
+        gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
+        gr.Markdown('NCUT is computed on the concatenated graph of all models, layers, and images. Color is **aligned** across all models and layers.')
+        gr.Markdown('')
+        gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
+        gr.Markdown('---')
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button = make_input_images_section()
+
+                dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
+                num_images_slider.value = 100
+
+
+            with gr.Column(scale=5, min_width=200):
+                output_gallery = make_output_images_section()
+                gr.Markdown('### TIP1: use the `full-screen` button, and use `arrow keys` to navigate')
+                gr.Markdown('---')
+                gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
+                gr.Markdown('Layer type: attention output (attn), without sum of residual')
+                gr.Markdown('### TIP2: for large image set, please increase the `num_sample` for t-SNE and NCUT')
+                gr.Markdown('---')
+        [
+            model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+            affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+            embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+            perplexity_slider, n_neighbors_slider, min_dist_slider,
+            sampling_method_dropdown, positive_prompt, negative_prompt
+        ] = make_parameters_section()
+        model_dropdown.value = "AlignedThreeModelAttnNodes"
+        model_dropdown.visible = False
+        layer_slider.visible = False
+        node_type_dropdown.visible = False
+        num_sample_ncut_slider.value = 10000
+        num_sample_tsne_slider.value = 1000
+        # logging text box
+        logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+
     clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
 
     false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
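An aside on the commented-out handler this hunk deletes: `[] * (len(galleries) + 1)` would not have worked even if revived, because multiplying an empty list yields an empty list, not one empty value per output. The live replacement returns one empty value per output component. A quick self-contained check of that Python pitfall (the `galleries` stand-ins are illustrative, not the app's components):

```python
galleries = ["g1", "g2", "g3"]  # stand-ins for the Gallery components

# The deleted handler's expression: multiplying [] gives [] again,
# so Gradio would receive a single empty list, not len+1 values.
assert [] * (len(galleries) + 1) == []

# One empty value per output component, positionally matched by Gradio:
cleared = [[] for _ in range(len(galleries) + 1)]
assert cleared == [[], [], [], []]
```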
@@ -1520,6 +1564,7 @@ with demo:
         outputs=[output_gallery, logging_text],
     )
 
+
     with gr.Tab('Compare Models'):
         def add_one_model(i_model=1):
             with gr.Column(scale=5, min_width=200) as col:
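One pattern worth calling out in the new 'Model Aligned (+Recursive)' tab: `run_fn` has a fixed parameter list, so a tab that lacks some controls feeds invisible placeholder components (`false_placeholder`, `no_prompt`) in those input slots. A minimal sketch of the trick, with hypothetical component names rather than the ones in app.py:

```python
import gradio as gr

def run(text: str, use_flag: bool) -> str:
    # `use_flag` always arrives as False, supplied by the hidden checkbox
    return text.upper() if use_flag else text

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    # Invisible component pads run()'s fixed signature for this tab.
    false_placeholder = gr.Checkbox(value=False, visible=False)
    gr.Button("Run").click(run, inputs=[inp, false_placeholder], outputs=[out])

demo.launch()
```

The hidden component still participates in the event wiring, so the callback signature stays identical across tabs that do and do not expose the control.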
 