huzey committed

Commit 5dac5bc
Parent(s): 81bd021

add directed ncut (test)

Files changed (3):
  1. app.py +382 -39
  2. directed_ncut.py +287 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -183,6 +183,84 @@ def compute_ncut(
     return rgb, logging_str, eigvecs
 
 
+def compute_ncut_directed(
+    features_1,
+    features_2,
+    num_eig=100,
+    num_sample_ncut=10000,
+    affinity_focal_gamma=0.3,
+    knn_ncut=10,
+    knn_tsne=10,
+    embedding_method="UMAP",
+    embedding_metric='euclidean',
+    num_sample_tsne=300,
+    perplexity=150,
+    n_neighbors=150,
+    min_dist=0.1,
+    sampling_method="QuickFPS",
+    metric="cosine",
+    indirect_connection=False,
+    make_orthogonal=False,
+    make_symmetric=False,
+    progress_start=0.4,
+):
+    print("Using directed_ncut")
+    print("features_1.shape", features_1.shape)
+    print("features_2.shape", features_2.shape)
+    from directed_ncut import nystrom_ncut
+    progress = gr.Progress()
+    logging_str = ""
+
+    num_nodes = np.prod(features_1.shape[:-2])
+    if num_nodes / 2 < num_eig:
+        # raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
+        gr.Warning("Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.")
+        num_eig = num_nodes // 2 - 1
+        logging_str += f"Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.\n"
+
+    start = time.time()
+    progress(progress_start + 0.0, desc="NCut")
+    n_features = features_1.shape[-2]
+    _features_1 = rearrange(features_1, "b h w d c -> (b h w) (d c)")
+    _features_2 = rearrange(features_2, "b h w d c -> (b h w) (d c)")
+    eigvecs, eigvals, _ = nystrom_ncut(
+        _features_1,
+        features_B=_features_2,
+        num_eig=num_eig,
+        num_sample=num_sample_ncut,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        affinity_focal_gamma=affinity_focal_gamma,
+        knn=knn_ncut,
+        sample_method=sampling_method,
+        distance=metric,
+        normalize_features=False,
+        indirect_connection=indirect_connection,
+        make_orthogonal=make_orthogonal,
+        make_symmetric=make_symmetric,
+        n_features=n_features,
+    )
+    # print(f"NCUT time: {time.time() - start:.2f}s")
+    logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
+
+    start = time.time()
+    progress(progress_start + 0.01, desc="spectral-tSNE")
+    _, rgb = eigenvector_to_rgb(
+        eigvecs,
+        method=embedding_method,
+        metric=embedding_metric,
+        num_sample=num_sample_tsne,
+        perplexity=perplexity,
+        n_neighbors=n_neighbors,
+        min_distance=min_dist,
+        knn=knn_tsne,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+    )
+    logging_str += f"{embedding_method} time: {time.time() - start:.2f}s\n"
+
+    rgb = rgb.reshape(features_1.shape[:3] + (3,))
+    return rgb, logging_str, eigvecs
+
+
 def dont_use_too_much_green(image_rgb):
     # make sure the foveal 40% of the image is red-leading
     x1, x2 = int(image_rgb.shape[1] * 0.3), int(image_rgb.shape[1] * 0.7)
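
Note: compute_ncut_directed (hunk above) flattens the per-head token grid into one graph node per patch before running the directed NCut. A minimal sketch of that reshape, assuming the (batch, height, width, n_heads, head_dim) layout used above and stand-in tensor sizes:

    import torch
    from einops import rearrange

    features_1 = torch.randn(2, 32, 32, 6, 64)   # (b, h, w, n_heads, head_dim)
    nodes = rearrange(features_1, "b h w d c -> (b h w) (d c)")
    print(nodes.shape)  # torch.Size([2048, 384]): one node per patch, heads concatenated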
 
@@ -592,6 +670,8 @@ def ncut_run(
     **kwargs,
 ):
     advanced = kwargs.get("advanced", False)
+    directed = kwargs.get("directed", False)
+
     progress = gr.Progress()
     progress(0.2, desc="Feature Extraction")
 
@@ -640,6 +720,11 @@ def ncut_run(
     features = extract_features(
         images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
    )
+    if directed:
+        node_type2 = kwargs.get("node_type2", None)
+        features_B = extract_features(
+            images, model, node_type=node_type2, layer=layer-1, batch_size=BATCH_SIZE
+        )
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
     del model
 
@@ -768,25 +853,59 @@ def ncut_run(
 
 
     # alignedcut
-
-    rgb, _logging_str, eigvecs = compute_ncut(
-        features,
-        num_eig=num_eig,
-        num_sample_ncut=num_sample_ncut,
-        affinity_focal_gamma=affinity_focal_gamma,
-        knn_ncut=knn_ncut,
-        knn_tsne=knn_tsne,
-        num_sample_tsne=num_sample_tsne,
-        embedding_method=embedding_method,
-        embedding_metric=embedding_metric,
-        perplexity=perplexity,
-        n_neighbors=n_neighbors,
-        min_dist=min_dist,
-        sampling_method=sampling_method,
-        indirect_connection=indirect_connection,
-        make_orthogonal=make_orthogonal,
-        metric=ncut_metric,
-    )
+    if not directed:
+        rgb, _logging_str, eigvecs = compute_ncut(
+            features,
+            num_eig=num_eig,
+            num_sample_ncut=num_sample_ncut,
+            affinity_focal_gamma=affinity_focal_gamma,
+            knn_ncut=knn_ncut,
+            knn_tsne=knn_tsne,
+            num_sample_tsne=num_sample_tsne,
+            embedding_method=embedding_method,
+            embedding_metric=embedding_metric,
+            perplexity=perplexity,
+            n_neighbors=n_neighbors,
+            min_dist=min_dist,
+            sampling_method=sampling_method,
+            indirect_connection=indirect_connection,
+            make_orthogonal=make_orthogonal,
+            metric=ncut_metric,
+        )
+    if directed:
+        head_index_text = kwargs.get("head_index_text", None)
+        n_heads = features.shape[-2]  # (batch, h, w, n_heads, d)
+        if head_index_text == 'all':
+            head_idx = torch.arange(n_heads)
+        else:
+            _idxs = head_index_text.split(",")
+            head_idx = torch.tensor([int(idx) for idx in _idxs])
+        features_A = features[:, :, :, head_idx, :]
+        features_B = features_B[:, :, :, head_idx, :]
+
+        rgb, _logging_str, eigvecs = compute_ncut_directed(
+            features_A,
+            features_B,
+            num_eig=num_eig,
+            num_sample_ncut=num_sample_ncut,
+            affinity_focal_gamma=affinity_focal_gamma,
+            knn_ncut=knn_ncut,
+            knn_tsne=knn_tsne,
+            num_sample_tsne=num_sample_tsne,
+            embedding_method=embedding_method,
+            embedding_metric=embedding_metric,
+            perplexity=perplexity,
+            n_neighbors=n_neighbors,
+            min_dist=min_dist,
+            sampling_method=sampling_method,
+            indirect_connection=False,
+            make_orthogonal=make_orthogonal,
+            metric=ncut_metric,
+            make_symmetric=kwargs.get("make_symmetric", None),
+        )
+
+
     logging_str += _logging_str
 
     if "AlignedThreeModelAttnNodes" == model_name:
 
@@ -858,26 +977,26 @@ def ncut_run(
 
 def _ncut_run(*args, **kwargs):
     n_ret = kwargs.pop("n_ret", 1)
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    # try:
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
 
-        ret = ncut_run(*args, **kwargs)
+    #     ret = ncut_run(*args, **kwargs)
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
 
-        ret = list(ret)[:n_ret] + [ret[-1]]
-        return ret
-    except Exception as e:
-        gr.Error(str(e))
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        return *(None for _ in range(n_ret)), "Error: " + str(e)
+    #     ret = list(ret)[:n_ret] + [ret[-1]]
+    #     return ret
+    # except Exception as e:
+    #     gr.Error(str(e))
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+    #     return *(None for _ in range(n_ret)), "Error: " + str(e)
 
-    # ret = ncut_run(*args, **kwargs)
-    # ret = list(ret)[:n_ret] + [ret[-1]]
-    # return ret
+    ret = ncut_run(*args, **kwargs)
+    ret = list(ret)[:n_ret] + [ret[-1]]
+    return ret
 
 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=30)
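
Note: the hunk above disables the guarded wrapper while the directed path is under test, so exceptions surface as raw tracebacks. A minimal sketch of the guard that would be restored afterwards (assumes ncut_run, torch, and gr are in scope, as above):

    def _ncut_run_guarded(*args, n_ret=1, **kwargs):
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()    # free cached GPU memory before running
            ret = ncut_run(*args, **kwargs)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return list(ret)[:n_ret] + [ret[-1]]
        except Exception as e:
            gr.Error(str(e))                # surface the error in the UI
            return [None] * n_ret + ["Error: " + str(e)]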
 
@@ -1085,12 +1204,16 @@ def run_fn(
     recursion_l1_gamma=0.5,
     recursion_l2_gamma=0.5,
     recursion_l3_gamma=0.5,
+    node_type2="k",
+    head_index_text='all',
+    make_symmetric=False,
     n_ret=1,
     plot_clusters=False,
     alignedcut_eig_norm_plot=False,
     advanced=False,
+    directed=False,
 ):
-
+    print(node_type2, head_index_text, make_symmetric)
     progress=gr.Progress()
     progress(0, desc="Starting")
 
@@ -1222,6 +1345,10 @@ def run_fn(
         "plot_clusters": plot_clusters,
         "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
         "advanced": advanced,
+        "directed": directed,
+        "node_type2": node_type2,
+        "head_index_text": head_index_text,
+        "make_symmetric": make_symmetric,
     }
     # print(kwargs)
 
@@ -1379,7 +1506,7 @@ def fit_trans(rgb1, rgb2, num_layer=3, width=512, batch_size=256, lr=3e-4, fitti
     # Train the model
     trainer.fit(mlp, dataloader)
 
-
+    mlp.progress(0.99, desc="Applying MLP")
     results = trainer.predict(mlp, data_loader)
     A_transformed = torch.cat(results, dim=0)
 
@@ -2734,10 +2861,226 @@ with demo:
         buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
         buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
 
-    # add MLP fitting buttons
-
+    with gr.Tab('Directed (experimental)', visible=True) as tab_directed_ncut:
+
+        target_images = gr.State([])
+        input_images = gr.State([])
+        def add_mlp_fitting_buttons(output_gallery, mlp_gallery, target_images=target_images, input_images=input_images):
+            with gr.Row():
+                # mark_as_target_button = gr.Button("mark target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
+                # mark_as_input_button = gr.Button("mark input", elem_id=f"mark_as_input_button_{output_gallery.elem_id}", variant='secondary')
+                mark_as_target_button = gr.Button("🎯 Mark Target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
+                fit_to_target_button = gr.Button("🔴 [MLP] Fit", elem_id=f"fit_to_target_button_{output_gallery.elem_id}", variant='primary')
+            def mark_fn(images, text="target"):
+                if images is None:
+                    raise gr.Error("No images selected")
+                if len(images) == 0:
+                    raise gr.Error("No images selected")
+                num_images = len(images)
+                gr.Info(f"Marked {num_images} images as {text}")
+                images = [(Image.open(tup[0]), []) for tup in images]
+                return images
+            mark_as_target_button.click(partial(mark_fn, text="target"), inputs=[output_gallery], outputs=[target_images])
+            # mark_as_input_button.click(partial(mark_fn, text="input"), inputs=[output_gallery], outputs=[input_images])
+
+            with gr.Accordion("➡️ MLP Parameters", open=False):
+                num_layers_slider = gr.Slider(2, 10, step=1, label="Number of Layers", value=3, elem_id=f"num_layers_slider_{output_gallery.elem_id}")
+                width_slider = gr.Slider(128, 4096, step=128, label="Width", value=512, elem_id=f"width_slider_{output_gallery.elem_id}")
+                batch_size_slider = gr.Slider(32, 4096, step=32, label="Batch Size", value=128, elem_id=f"batch_size_slider_{output_gallery.elem_id}")
+                lr_slider = gr.Slider(1e-6, 1, step=1e-6, label="Learning Rate", value=3e-4, elem_id=f"lr_slider_{output_gallery.elem_id}")
+                fitting_steps_slider = gr.Slider(1000, 100000, step=1000, label="Fitting Steps", value=30000, elem_id=f"fitting_steps_slider_{output_gallery.elem_id}")
+                fps_sample_slider = gr.Slider(128, 50000, step=128, label="FPS Sample", value=10240, elem_id=f"fps_sample_slider_{output_gallery.elem_id}")
+                segmentation_loss_lambda_slider = gr.Slider(0, 100, step=0.01, label="Segmentation Preserving Loss Lambda", value=1, elem_id=f"segmentation_loss_lambda_slider_{output_gallery.elem_id}")
+
+            fit_to_target_button.click(
+                run_mlp_fit,
+                inputs=[output_gallery, target_images, num_layers_slider, width_slider, batch_size_slider, lr_slider, fitting_steps_slider, fps_sample_slider, segmentation_loss_lambda_slider],
+                outputs=[mlp_gallery],
+            )
+
+        def make_parameters_section_2model(model_ratio=True):
+            gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
+            from ncut_pytorch.backbone import list_models, get_demo_model_names
+            model_names = list_models()
+            model_names = sorted(model_names)
+            # only CLIP, DiNO, MAE are implemented for q k v
+            ok_models = ["CLIP(ViT", "DiNO(", "MAE("]
+            model_names = [m for m in model_names if any(ok in m for ok in ok_models)]
+
+            def get_filtered_model_names(name):
+                return [m for m in model_names if name.lower() in m.lower()]
+            def get_default_model_name(name):
+                lst = get_filtered_model_names(name)
+                if len(lst) > 1:
+                    return lst[1]
+                return lst[0]
+
+            model_radio = gr.Radio(["CLIP", "DiNO", "MAE"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True, visible=model_ratio)
+            model_dropdown = gr.Dropdown(get_filtered_model_names("DiNO"), label="", value="DiNO(dino_vitb8_448)", elem_id="model_name", show_label=False)
+            model_radio.change(fn=lambda x: gr.update(choices=get_filtered_model_names(x), value=get_default_model_name(x)), inputs=model_radio, outputs=[model_dropdown])
+            layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
+            positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Paul guitar'")
+            positive_prompt.visible = False
+            negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
+            negative_prompt.visible = False
+            node_type_dropdown = gr.Dropdown(['q', 'k', 'v'],
+                label="Left-side Node Type", value="q", elem_id="node_type", info="In the directed case, the left-side SVD eigenvector is taken")
+            node_type_dropdown2 = gr.Dropdown(['q', 'k', 'v'],
+                label="Right-side Node Type", value="k", elem_id="node_type2")
+            head_index_text = gr.Textbox(value='all', label="Head Index", elem_id="head_index", type="text", info="which attention heads to use, comma separated, e.g. 0,1,2")
+            make_symmetric = gr.Checkbox(label="Make Symmetric", value=False, elem_id="make_symmetric", info="make the graph symmetric by A = (A + A.T) / 2")
+
+            num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for smaller clusters')
+
+            def change_layer_slider(model_name):
+                # SD2, UNET
+                if "stable" in model_name.lower() and "diffusion" in model_name.lower():
+                    from ncut_pytorch.backbone import SD_KEY_DICT
+                    default_layer = 'up_2_resnets_1_block' if 'diffusion-3' not in model_name else 'block_23'
+                    return (gr.Slider(1, 49, step=1, label="Diffusion: Timestep (Noise)", value=5, elem_id="layer", visible=True, info="Noise level, 50 is max noise"),
+                            gr.Dropdown(SD_KEY_DICT[model_name], label="Diffusion: Layer and Node", value=default_layer, elem_id="node_type", info="U-Net (v1, v2) or DiT (v3)"))
+
+                if model_name == "LISSL(xinlai/LISSL-7B-v1)":
+                    layer_names = ["dec_0_input", "dec_0_attn", "dec_0_block", "dec_1_input", "dec_1_attn", "dec_1_block"]
+                    default_layer = "dec_1_block"
+                    return (gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False, info=""),
+                            gr.Dropdown(layer_names, label="LISA decoder: Layer and Node", value=default_layer, elem_id="node_type"))
+
+                layer_dict = LAYER_DICT
+                if model_name in layer_dict:
+                    value = layer_dict[model_name]
+                    return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True, info="")
+                else:
+                    value = 12
+                    return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True, info="")
+            model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=layer_slider)
+
+            def change_prompt_text(model_name):
+                if model_name in promptable_diffusion_models:
+                    return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Paul guitar'", visible=True),
+                            gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=True))
+                return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Paul guitar'", visible=False),
+                        gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
+            model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
+
+            with gr.Accordion("Advanced Parameters: NCUT", open=False):
+                gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
+                affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for sharper segmentation")
+                num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
+                # sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
+                sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
+                # ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
+                ncut_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
+                ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
+                ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=False, elem_id="ncut_indirect_connection", info="TODO: Indirect connection is not implemented for directed NCUT", interactive=False)
+                ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
+            with gr.Accordion("Advanced Parameters: Visualization", open=False):
+                # embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
+                embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
+                # embedding_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="t-SNE/UMAP metric", value="euclidean", elem_id="embedding_metric")
+                embedding_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="t-SNE/UMAP: metric", value="euclidean", elem_id="embedding_metric")
+                num_sample_tsne_slider = gr.Slider(100, 10000, step=100, label="t-SNE/UMAP: num_sample", value=300, elem_id="num_sample_tsne", info="Nyström approximation")
+                knn_tsne_slider = gr.Slider(1, 100, step=1, label="t-SNE/UMAP: KNN", value=10, elem_id="knn_tsne", info="Nyström approximation")
+                perplexity_slider = gr.Slider(10, 1000, step=10, label="t-SNE: perplexity", value=150, elem_id="perplexity")
+                n_neighbors_slider = gr.Slider(10, 1000, step=10, label="UMAP: n_neighbors", value=150, elem_id="n_neighbors")
+                min_dist_slider = gr.Slider(0.1, 1, step=0.1, label="UMAP: min_dist", value=0.1, elem_id="min_dist")
+            return [model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt]
+
+        def add_one_model(i_model=1):
+            with gr.Column(scale=5, min_width=200) as col:
+                gr.Markdown('### Output Images')
+                output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+                submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
+                add_rotate_flip_buttons(output_gallery)
+                add_download_button(output_gallery, "ncut_embed")
+                mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+                add_mlp_fitting_buttons(output_gallery, mlp_gallery)
+                add_download_button(mlp_gallery, "mlp_color_align")
+                norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+                add_download_button(norm_gallery, "eig_norm")
+                cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+                add_download_button(cluster_gallery, "clusters")
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section_2model()
+                # logging text box
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+
+                submit_button.click(
+                    partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True, advanced=True, directed=True),
+                    inputs=[
+                        input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                        positive_prompt, negative_prompt,
+                        false_placeholder, no_prompt, no_prompt, no_prompt,
+                        affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                        embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                        perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
+                        *[false_placeholder for _ in range(9)],
+                        node_type_dropdown2, head_index_text, make_symmetric
+                    ],
+                    outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text]
+                )
+
+                output_gallery.change(lambda x: gr.update(value=x), inputs=[output_gallery], outputs=[mlp_gallery])
+
+            return output_gallery
+
+        galleries = []
+
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
+                submit_button.visible = False
+
+            for i in range(3):
+                g = add_one_model()
+                galleries.append(g)
+
+        # Create rows and buttons in a loop
+        rows = []
+        buttons = []
+
+        for i in range(4):
+            row = gr.Row(visible=False)
+            rows.append(row)
+
+            with row:
+                for j in range(4):
+                    with gr.Column(scale=5, min_width=200):
+                        g = add_one_model()
+                        galleries.append(g)
+
+            button = gr.Button("➕ Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True, scale=3)
+            buttons.append(button)
+
+            if i > 0:
+                # Reveal the current row and next button
+                buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=row)
+                buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=button)
+
+                # Hide the current button
+                buttons[i - 1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[i - 1])
+
+        # Last button only reveals the last row and hides itself
+        buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
+        buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
+
 
     with gr.Tab('📄About'):
         with gr.Column():
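
Note: the "Make Symmetric" checkbox documents its own formula, A = (A + A.T) / 2. A quick numeric check of what that does to a directed affinity (stand-in values, not from the commit):

    import torch

    A = torch.rand(4, 4)       # directed (asymmetric) affinity
    A_sym = (A + A.T) / 2      # the symmetrization applied when the box is checked
    print(torch.allclose(A_sym, A_sym.T))  # True: the graph is now undirected
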
directed_ncut.py ADDED
@@ -0,0 +1,287 @@
+# %%
+import torch
+import torch.nn.functional as F
+
+
+def affinity_from_features(
+    features,
+    features_B=None,
+    affinity_focal_gamma=1.0,
+    distance="cosine",
+    normalize_features=False,
+    fill_diagonal=False,
+    n_features=1,
+):
+    """Compute an (optionally directed) affinity matrix from input features.
+
+    Args:
+        features (torch.Tensor): input features, shape (n_samples, n_features)
+        features_B (torch.Tensor, optional): if not None, compute the affinity
+            between features and features_B instead of features with itself
+        affinity_focal_gamma (float): affinity matrix parameter, lower gamma reduces
+            the edge weights on weak connections, default 1.0
+        distance (str): distance metric, 'cosine' (default) or 'euclidean'
+        normalize_features (bool): normalize input features before computing the
+            affinity matrix, default False
+        fill_diagonal (bool): zero out the diagonal of the affinity, default False
+        n_features (int): number of concatenated heads in the feature dimension, default 1
+
+    Returns:
+        (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
+    """
+    # compute affinity matrix from input features
+    features = features.clone()
+    if features_B is not None:
+        features_B = features_B.clone()
+        # the diagonal is only meaningful for a self-affinity
+        assert not fill_diagonal, "fill_diagonal should be False when features_B is not None"
+
+    # if features_B is not provided, compute affinity on features x features;
+    # if features_B is provided, compute affinity on features x features_B
+    features_B = features if features_B is None else features_B
+
+    if normalize_features:
+        features = F.normalize(features, dim=-1)
+        features_B = F.normalize(features_B, dim=-1)
+
+    if distance == "cosine":
+        # TODO: make sure features are normalized within each head
+        features = F.normalize(features, dim=-1)
+        features_B = F.normalize(features_B, dim=-1)
+        A = 1 - (features @ features_B.T) / n_features
+    elif distance == "euclidean":
+        A = torch.cdist(features, features_B, p=2) / n_features
+    else:
+        raise ValueError("distance should be 'cosine' or 'euclidean'")
+
+    if fill_diagonal:
+        A[torch.arange(A.shape[0]), torch.arange(A.shape[0])] = 0
+
+    # torch.exp makes the affinity matrix positive,
+    # lower affinity_focal_gamma reduces the weak edge weights
+    A = torch.exp(-(A / affinity_focal_gamma))
+    return A
+
+ from ncut_pytorch.ncut_pytorch import run_subgraph_sampling, propagate_knn, gram_schmidt
66
+ import logging
67
+
68
+ import torch
69
+
70
+ def ncut(
71
+ A,
72
+ num_eig=20,
73
+ eig_solver="svd_lowrank",
74
+ make_symmetric=True,
75
+ ):
76
+ """PyTorch implementation of Normalized cut without Nystrom-like approximation.
77
+
78
+ Args:
79
+ A (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
80
+ num_eig (int): number of eigenvectors to return
81
+ eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']
82
+
83
+ Returns:
84
+ (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
85
+ (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
86
+ """
87
+ if make_symmetric:
88
+ # make sure A is symmetric
89
+ A = (A + A.T) / 2
90
+
91
+ # symmetrical normalization; A = D^(-1/2) A D^(-1/2)
92
+ D_r = A.sum(dim=0).detach().clone()
93
+ D_c = A.sum(dim=1).detach().clone()
94
+ A /= torch.sqrt(D_r)[:, None]
95
+ A /= torch.sqrt(D_c)[None, :]
96
+
97
+ # compute eigenvectors
98
+ if eig_solver == "svd_lowrank": # default
99
+ # only top q eigenvectors, fastest
100
+ eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
101
+ elif eig_solver == "lobpcg":
102
+ # only top k eigenvectors, fast
103
+ eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
104
+ elif eig_solver == "svd":
105
+ # all eigenvectors, slow
106
+ eigen_vector, eigen_value, _ = torch.svd(A)
107
+ elif eig_solver == "eigh":
108
+ # all eigenvectors, slow
109
+ eigen_value, eigen_vector = torch.linalg.eigh(A)
110
+ else:
111
+ raise ValueError(
112
+ "eigen_solver should be 'lobpcg', 'svd_lowrank', 'svd' or 'eigh'"
113
+ )
114
+
115
+ # sort eigenvectors by eigenvalues, take top (descending order)
116
+ eigen_value = eigen_value.real
117
+ eigen_vector = eigen_vector.real
118
+
119
+ sort_order = torch.argsort(eigen_value, descending=True)[:num_eig]
120
+ eigen_value = eigen_value[sort_order]
121
+ eigen_vector = eigen_vector[:, sort_order]
122
+
123
+ if eigen_value.min() < 0:
124
+ logging.warning(
125
+ "negative eigenvalues detected, please make sure the affinity matrix is positive definite"
126
+ )
127
+
128
+ return eigen_vector, eigen_value
129
+
+
+def nystrom_ncut(
+    features,
+    features_B=None,
+    num_eig=100,
+    num_sample=10000,
+    knn=10,
+    sample_method="farthest",
+    distance="cosine",
+    affinity_focal_gamma=1.0,
+    indirect_connection=False,
+    indirect_pca_dim=100,
+    device=None,
+    eig_solver="svd_lowrank",
+    normalize_features=False,
+    matmul_chunk_size=8096,
+    make_orthogonal=False,
+    verbose=False,
+    no_propagation=False,
+    make_symmetric=False,
+    n_features=1,
+):
+    """PyTorch implementation of Faster Nystrom Normalized cut.
+
+    Args:
+        features (torch.Tensor): feature matrix, shape (n_samples, n_features)
+        features_B (torch.Tensor): second feature matrix, for an asymmetric affinity matrix, shape (n_samples2, n_features)
+        num_eig (int): default 100, number of top eigenvectors to return
+        num_sample (int): default 10000, number of samples for Nystrom-like approximation
+        knn (int): default 10, number of KNN for propagating eigenvectors from subgraph to full graph;
+            smaller knn results in sharper eigenvectors
+        sample_method (str): sample method, 'farthest' (default) or 'random';
+            'farthest' is recommended for better approximation
+        distance (str): distance metric, 'cosine' (default) or 'euclidean'
+        affinity_focal_gamma (float): affinity matrix parameter, lower gamma reduces the weak edge weights,
+            resulting in sharper eigenvectors, default 1.0
+        indirect_connection (bool): include indirect connection in the subgraph, default False
+            (not implemented for the directed case)
+        indirect_pca_dim (int): default 100, PCA dimension to reduce the node dimension, only applied to
+            the not sampled nodes, not applied to the sampled nodes
+        device (str): device to use for computation; if None, will not change device.
+            A good practice is to pass features on CPU since they are usually large,
+            and move the subgraph affinity to GPU to speed up eigenvector computation
+        eig_solver (str): eigen decompose solver, 'svd_lowrank' (default), 'lobpcg', 'svd', 'eigh';
+            'svd_lowrank' is recommended for large scale graphs, it's the fastest;
+            they correspond to torch.svd_lowrank, torch.lobpcg, torch.svd, torch.linalg.eigh
+        normalize_features (bool): normalize input features before computing affinity matrix,
+            default False
+        matmul_chunk_size (int): chunk size for matrix multiplication;
+            large matrix multiplication is chunked to reduce memory usage,
+            smaller chunk size will reduce memory usage but slow computation, default 8096
+        make_orthogonal (bool): make eigenvectors orthogonal after propagation, default False
+        verbose (bool): show progress bar when propagating eigenvectors from subgraph to full graph
+        no_propagation (bool): if True, skip the eigenvector propagation step, only return the subgraph eigenvectors
+        make_symmetric (bool): symmetrize the subgraph affinity by A = (A + A.T) / 2, default False
+        n_features (int): number of concatenated heads in the flattened feature dimension, default 1
+
+    Returns:
+        (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
+        (torch.Tensor): eigenvalues, sorted in descending order, shape (num_eig,)
+        (torch.Tensor): sampled_indices used by Nystrom-like approximation subgraph, shape (num_sample,)
+    """
+    # check if the number of nodes is large enough relative to num_eig
+    if eig_solver in ["svd_lowrank", "lobpcg"]:
+        assert features.shape[0] > (
+            num_eig * 2
+        ), "number of nodes should be greater than 2*num_eig"
+    if eig_solver in ["svd", "eigh"]:
+        assert (
+            features.shape[0] > num_eig
+        ), "number of nodes should be greater than num_eig"
+
+    features = features.clone()
+    if normalize_features:
+        # features need to be normalized for affinity matrix computation (cosine distance)
+        features = torch.nn.functional.normalize(features, dim=-1)
+
+    sampled_indices = run_subgraph_sampling(
+        features,
+        num_sample=num_sample,
+        sample_method=sample_method,
+    )
+    sampled_indices_B = run_subgraph_sampling(
+        features_B,
+        num_sample=num_sample,
+        sample_method=sample_method,
+    )
+
+    sampled_features = features[sampled_indices]
+    sampled_features_B = features_B[sampled_indices_B]
+    # move subgraph to gpu to speed up
+    original_device = sampled_features.device
+    device = original_device if device is None else device
+    sampled_features = sampled_features.to(device)
+    sampled_features_B = sampled_features_B.to(device)
+
+    # compute affinity matrix on subgraph
+    A = affinity_from_features(
+        sampled_features, features_B=sampled_features_B,
+        affinity_focal_gamma=affinity_focal_gamma, distance=distance,
+        n_features=n_features,
+    )
+
+    not_sampled = torch.tensor(
+        list(set(range(features.shape[0])) - set(sampled_indices))
+    )
+
+    if len(not_sampled) == 0:
+        # if all nodes are sampled, no need for Nyström approximation
+        eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver)
+        return eigen_vector, eigen_value, sampled_indices
+
+    # 1) PCA to reduce the node dimension for the not sampled nodes
+    # 2) compute indirect connection on the PC nodes
+    if len(not_sampled) > 0 and indirect_connection:
+        raise NotImplementedError("indirect_connection is not implemented yet")
+        # unreachable until indirect_connection is implemented for the directed case:
+        indirect_pca_dim = min(indirect_pca_dim, min(*features.shape))
+        U, S, V = torch.pca_lowrank(features[not_sampled].T, q=indirect_pca_dim)
+        feature_B = (features[not_sampled].T @ V).T  # project to PCA space
+        feature_B = feature_B.to(device)
+        B = affinity_from_features(
+            sampled_features,
+            feature_B,
+            affinity_focal_gamma=affinity_focal_gamma,
+            distance=distance,
+            fill_diagonal=False,
+        )
+        # P is 1-hop random walk matrix
+        B_row = B / B.sum(axis=1, keepdim=True)
+        B_col = B / B.sum(axis=0, keepdim=True)
+        P = B_row @ B_col.T
+        P = (P + P.T) / 2
+        # fill diagonal with 0
+        P[torch.arange(P.shape[0]), torch.arange(P.shape[0])] = 0
+        A = A + P
+
+    # compute normalized cut on the subgraph
+    eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver, make_symmetric=make_symmetric)
+    eigen_vector = eigen_vector.to(dtype=features.dtype, device=original_device)
+    eigen_value = eigen_value.to(dtype=features.dtype, device=original_device)
+
+    if no_propagation:
+        return eigen_vector, eigen_value, sampled_indices
+
+    # propagate eigenvectors from subgraph to full graph
+    eigen_vector = propagate_knn(
+        eigen_vector,
+        features,
+        sampled_features,
+        knn,
+        chunk_size=matmul_chunk_size,
+        device=device,
+        use_tqdm=verbose,
+    )
+
+    # post-hoc orthogonalization
+    if make_orthogonal:
+        eigen_vector = gram_schmidt(eigen_vector)
+
+    return eigen_vector, eigen_value, sampled_indices
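
Note: taken together, directed_ncut.py builds an asymmetric affinity between two feature sets, normalizes it with separate row and column degrees (A_norm = D_r^(-1/2) A D_c^(-1/2), which reduces to the usual D^(-1/2) A D^(-1/2) when A is symmetric), and Nyström-propagates the left singular vectors to all nodes. A minimal end-to-end sketch with stand-in tensors (shapes, sizes, and values are assumptions, not from the commit):

    import torch
    from directed_ncut import nystrom_ncut, affinity_from_features

    feats_q = torch.randn(4096, 384)   # left-side ('q') node features
    feats_k = torch.randn(4096, 384)   # right-side ('k') node features

    # the subgraph affinity is asymmetric: edges point from 'q' nodes to 'k' nodes
    A = affinity_from_features(feats_q[:256], features_B=feats_k[:256],
                               affinity_focal_gamma=0.3, distance="cosine")
    print(torch.allclose(A, A.T))  # False: the graph is directed

    eigvecs, eigvals, sampled = nystrom_ncut(
        feats_q, features_B=feats_k,
        num_eig=50, num_sample=1024,
        affinity_focal_gamma=0.3, distance="cosine",
        make_symmetric=False,
    )
    print(eigvecs.shape)  # torch.Size([4096, 50])
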
requirements.txt CHANGED
@@ -20,4 +20,4 @@ lisa @ git+https://github.com/huzeyann/LISA.git@7211e99
 timm==0.9.2
 open-clip-torch==2.20.0
 pytorch_lightning==1.9.4
-ncut-pytorch>=1.3.15
+ncut-pytorch>=1.4.1