Spaces:
Running
on
Zero
Running
on
Zero
add directed ncut (test)
Browse files- app.py +382 -39
- directed_ncut.py +287 -0
- requirements.txt +1 -1
app.py
CHANGED
@@ -183,6 +183,84 @@ def compute_ncut(
|
|
183 |
return rgb, logging_str, eigvecs
|
184 |
|
185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
def dont_use_too_much_green(image_rgb):
|
187 |
# make sure the foval 40% of the image is red leading
|
188 |
x1, x2 = int(image_rgb.shape[1] * 0.3), int(image_rgb.shape[1] * 0.7)
|
@@ -592,6 +670,8 @@ def ncut_run(
|
|
592 |
**kwargs,
|
593 |
):
|
594 |
advanced = kwargs.get("advanced", False)
|
|
|
|
|
595 |
progress = gr.Progress()
|
596 |
progress(0.2, desc="Feature Extraction")
|
597 |
|
@@ -640,6 +720,11 @@ def ncut_run(
|
|
640 |
features = extract_features(
|
641 |
images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
|
642 |
)
|
|
|
|
|
|
|
|
|
|
|
643 |
# print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
|
644 |
logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
|
645 |
del model
|
@@ -768,25 +853,59 @@ def ncut_run(
|
|
768 |
|
769 |
|
770 |
# ailgnedcut
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
|
783 |
-
|
784 |
-
|
785 |
-
|
786 |
-
|
787 |
-
|
788 |
-
|
789 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
790 |
logging_str += _logging_str
|
791 |
|
792 |
if "AlignedThreeModelAttnNodes" == model_name:
|
@@ -858,26 +977,26 @@ def ncut_run(
|
|
858 |
|
859 |
def _ncut_run(*args, **kwargs):
|
860 |
n_ret = kwargs.pop("n_ret", 1)
|
861 |
-
try:
|
862 |
-
|
863 |
-
|
864 |
|
865 |
-
|
866 |
|
867 |
-
|
868 |
-
|
869 |
|
870 |
-
|
871 |
-
|
872 |
-
except Exception as e:
|
873 |
-
|
874 |
-
|
875 |
-
|
876 |
-
|
877 |
|
878 |
-
|
879 |
-
|
880 |
-
|
881 |
|
882 |
if USE_HUGGINGFACE_ZEROGPU:
|
883 |
@spaces.GPU(duration=30)
|
@@ -1085,12 +1204,16 @@ def run_fn(
|
|
1085 |
recursion_l1_gamma=0.5,
|
1086 |
recursion_l2_gamma=0.5,
|
1087 |
recursion_l3_gamma=0.5,
|
|
|
|
|
|
|
1088 |
n_ret=1,
|
1089 |
plot_clusters=False,
|
1090 |
alignedcut_eig_norm_plot=False,
|
1091 |
advanced=False,
|
|
|
1092 |
):
|
1093 |
-
|
1094 |
progress=gr.Progress()
|
1095 |
progress(0, desc="Starting")
|
1096 |
|
@@ -1222,6 +1345,10 @@ def run_fn(
|
|
1222 |
"plot_clusters": plot_clusters,
|
1223 |
"alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
|
1224 |
"advanced": advanced,
|
|
|
|
|
|
|
|
|
1225 |
}
|
1226 |
# print(kwargs)
|
1227 |
|
@@ -1379,7 +1506,7 @@ def fit_trans(rgb1, rgb2, num_layer=3, width=512, batch_size=256, lr=3e-4, fitti
|
|
1379 |
# Train the model
|
1380 |
trainer.fit(mlp, dataloader)
|
1381 |
|
1382 |
-
|
1383 |
results = trainer.predict(mlp, data_loader)
|
1384 |
A_transformed = torch.cat(results, dim=0)
|
1385 |
|
@@ -2734,10 +2861,226 @@ with demo:
|
|
2734 |
buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
|
2735 |
buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
|
2736 |
|
|
|
|
|
2737 |
|
2738 |
-
|
2739 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2740 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2741 |
|
2742 |
with gr.Tab('📄About'):
|
2743 |
with gr.Column():
|
|
|
183 |
return rgb, logging_str, eigvecs
|
184 |
|
185 |
|
186 |
+
def compute_ncut_directed(
|
187 |
+
features_1,
|
188 |
+
features_2,
|
189 |
+
num_eig=100,
|
190 |
+
num_sample_ncut=10000,
|
191 |
+
affinity_focal_gamma=0.3,
|
192 |
+
knn_ncut=10,
|
193 |
+
knn_tsne=10,
|
194 |
+
embedding_method="UMAP",
|
195 |
+
embedding_metric='euclidean',
|
196 |
+
num_sample_tsne=300,
|
197 |
+
perplexity=150,
|
198 |
+
n_neighbors=150,
|
199 |
+
min_dist=0.1,
|
200 |
+
sampling_method="QuickFPS",
|
201 |
+
metric="cosine",
|
202 |
+
indirect_connection=False,
|
203 |
+
make_orthogonal=False,
|
204 |
+
make_symmetric=False,
|
205 |
+
progess_start=0.4,
|
206 |
+
):
|
207 |
+
print("Using directed_ncut")
|
208 |
+
print("features_1.shape", features_1.shape)
|
209 |
+
print("features_2.shape", features_2.shape)
|
210 |
+
from directed_ncut import nystrom_ncut
|
211 |
+
progress = gr.Progress()
|
212 |
+
logging_str = ""
|
213 |
+
|
214 |
+
num_nodes = np.prod(features_1.shape[:-2])
|
215 |
+
if num_nodes / 2 < num_eig:
|
216 |
+
# raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
|
217 |
+
gr.Warning("Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.")
|
218 |
+
num_eig = num_nodes // 2 - 1
|
219 |
+
logging_str += f"Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.\n"
|
220 |
+
|
221 |
+
start = time.time()
|
222 |
+
progress(progess_start+0.0, desc="NCut")
|
223 |
+
n_features = features_1.shape[-2]
|
224 |
+
_features_1 = rearrange(features_1, "b h w d c -> (b h w) (d c)")
|
225 |
+
_features_2 = rearrange(features_2, "b h w d c -> (b h w) (d c)")
|
226 |
+
eigvecs, eigvals, _ = nystrom_ncut(
|
227 |
+
_features_1,
|
228 |
+
features_B=_features_2,
|
229 |
+
num_eig=num_eig,
|
230 |
+
num_sample=num_sample_ncut,
|
231 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
232 |
+
affinity_focal_gamma=affinity_focal_gamma,
|
233 |
+
knn=knn_ncut,
|
234 |
+
sample_method=sampling_method,
|
235 |
+
distance=metric,
|
236 |
+
normalize_features=False,
|
237 |
+
indirect_connection=indirect_connection,
|
238 |
+
make_orthogonal=make_orthogonal,
|
239 |
+
make_symmetric=make_symmetric,
|
240 |
+
n_features=n_features,
|
241 |
+
)
|
242 |
+
# print(f"NCUT time: {time.time() - start:.2f}s")
|
243 |
+
logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
|
244 |
+
|
245 |
+
start = time.time()
|
246 |
+
progress(progess_start+0.01, desc="spectral-tSNE")
|
247 |
+
_, rgb = eigenvector_to_rgb(
|
248 |
+
eigvecs,
|
249 |
+
method=embedding_method,
|
250 |
+
metric=embedding_metric,
|
251 |
+
num_sample=num_sample_tsne,
|
252 |
+
perplexity=perplexity,
|
253 |
+
n_neighbors=n_neighbors,
|
254 |
+
min_distance=min_dist,
|
255 |
+
knn=knn_tsne,
|
256 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
257 |
+
)
|
258 |
+
logging_str += f"{embedding_method} time: {time.time() - start:.2f}s\n"
|
259 |
+
|
260 |
+
rgb = rgb.reshape(features_1.shape[:3] + (3,))
|
261 |
+
return rgb, logging_str, eigvecs
|
262 |
+
|
263 |
+
|
264 |
def dont_use_too_much_green(image_rgb):
|
265 |
# make sure the foval 40% of the image is red leading
|
266 |
x1, x2 = int(image_rgb.shape[1] * 0.3), int(image_rgb.shape[1] * 0.7)
|
|
|
670 |
**kwargs,
|
671 |
):
|
672 |
advanced = kwargs.get("advanced", False)
|
673 |
+
directed = kwargs.get("directed", False)
|
674 |
+
|
675 |
progress = gr.Progress()
|
676 |
progress(0.2, desc="Feature Extraction")
|
677 |
|
|
|
720 |
features = extract_features(
|
721 |
images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
|
722 |
)
|
723 |
+
if directed:
|
724 |
+
node_type2 = kwargs.get("node_type2", None)
|
725 |
+
features_B = extract_features(
|
726 |
+
images, model, node_type=node_type2, layer=layer-1, batch_size=BATCH_SIZE
|
727 |
+
)
|
728 |
# print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
|
729 |
logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
|
730 |
del model
|
|
|
853 |
|
854 |
|
855 |
# ailgnedcut
|
856 |
+
if not directed:
|
857 |
+
rgb, _logging_str, eigvecs = compute_ncut(
|
858 |
+
features,
|
859 |
+
num_eig=num_eig,
|
860 |
+
num_sample_ncut=num_sample_ncut,
|
861 |
+
affinity_focal_gamma=affinity_focal_gamma,
|
862 |
+
knn_ncut=knn_ncut,
|
863 |
+
knn_tsne=knn_tsne,
|
864 |
+
num_sample_tsne=num_sample_tsne,
|
865 |
+
embedding_method=embedding_method,
|
866 |
+
embedding_metric=embedding_metric,
|
867 |
+
perplexity=perplexity,
|
868 |
+
n_neighbors=n_neighbors,
|
869 |
+
min_dist=min_dist,
|
870 |
+
sampling_method=sampling_method,
|
871 |
+
indirect_connection=indirect_connection,
|
872 |
+
make_orthogonal=make_orthogonal,
|
873 |
+
metric=ncut_metric,
|
874 |
+
)
|
875 |
+
if directed:
|
876 |
+
head_index_text = kwargs.get("head_index_text", None)
|
877 |
+
n_heads = features.shape[-2] # (batch, h, w, n_heads, d)
|
878 |
+
if head_index_text == 'all':
|
879 |
+
head_idx = torch.arange(n_heads)
|
880 |
+
else:
|
881 |
+
_idxs = head_index_text.split(",")
|
882 |
+
head_idx = torch.tensor([int(idx) for idx in _idxs])
|
883 |
+
features_A = features[:, :, :, head_idx, :]
|
884 |
+
features_B = features_B[:, :, :, head_idx, :]
|
885 |
+
|
886 |
+
rgb, _logging_str, eigvecs = compute_ncut_directed(
|
887 |
+
features_A,
|
888 |
+
features_B,
|
889 |
+
num_eig=num_eig,
|
890 |
+
num_sample_ncut=num_sample_ncut,
|
891 |
+
affinity_focal_gamma=affinity_focal_gamma,
|
892 |
+
knn_ncut=knn_ncut,
|
893 |
+
knn_tsne=knn_tsne,
|
894 |
+
num_sample_tsne=num_sample_tsne,
|
895 |
+
embedding_method=embedding_method,
|
896 |
+
embedding_metric=embedding_metric,
|
897 |
+
perplexity=perplexity,
|
898 |
+
n_neighbors=n_neighbors,
|
899 |
+
min_dist=min_dist,
|
900 |
+
sampling_method=sampling_method,
|
901 |
+
indirect_connection=False,
|
902 |
+
make_orthogonal=make_orthogonal,
|
903 |
+
metric=ncut_metric,
|
904 |
+
make_symmetric=kwargs.get("make_symmetric", None),
|
905 |
+
)
|
906 |
+
|
907 |
+
|
908 |
+
|
909 |
logging_str += _logging_str
|
910 |
|
911 |
if "AlignedThreeModelAttnNodes" == model_name:
|
|
|
977 |
|
978 |
def _ncut_run(*args, **kwargs):
|
979 |
n_ret = kwargs.pop("n_ret", 1)
|
980 |
+
# try:
|
981 |
+
# if torch.cuda.is_available():
|
982 |
+
# torch.cuda.empty_cache()
|
983 |
|
984 |
+
# ret = ncut_run(*args, **kwargs)
|
985 |
|
986 |
+
# if torch.cuda.is_available():
|
987 |
+
# torch.cuda.empty_cache()
|
988 |
|
989 |
+
# ret = list(ret)[:n_ret] + [ret[-1]]
|
990 |
+
# return ret
|
991 |
+
# except Exception as e:
|
992 |
+
# gr.Error(str(e))
|
993 |
+
# if torch.cuda.is_available():
|
994 |
+
# torch.cuda.empty_cache()
|
995 |
+
# return *(None for _ in range(n_ret)), "Error: " + str(e)
|
996 |
|
997 |
+
ret = ncut_run(*args, **kwargs)
|
998 |
+
ret = list(ret)[:n_ret] + [ret[-1]]
|
999 |
+
return ret
|
1000 |
|
1001 |
if USE_HUGGINGFACE_ZEROGPU:
|
1002 |
@spaces.GPU(duration=30)
|
|
|
1204 |
recursion_l1_gamma=0.5,
|
1205 |
recursion_l2_gamma=0.5,
|
1206 |
recursion_l3_gamma=0.5,
|
1207 |
+
node_type2="k",
|
1208 |
+
head_index_text='all',
|
1209 |
+
make_symmetric=False,
|
1210 |
n_ret=1,
|
1211 |
plot_clusters=False,
|
1212 |
alignedcut_eig_norm_plot=False,
|
1213 |
advanced=False,
|
1214 |
+
directed=False,
|
1215 |
):
|
1216 |
+
print(node_type2, head_index_text, make_symmetric)
|
1217 |
progress=gr.Progress()
|
1218 |
progress(0, desc="Starting")
|
1219 |
|
|
|
1345 |
"plot_clusters": plot_clusters,
|
1346 |
"alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
|
1347 |
"advanced": advanced,
|
1348 |
+
"directed": directed,
|
1349 |
+
"node_type2": node_type2,
|
1350 |
+
"head_index_text": head_index_text,
|
1351 |
+
"make_symmetric": make_symmetric,
|
1352 |
}
|
1353 |
# print(kwargs)
|
1354 |
|
|
|
1506 |
# Train the model
|
1507 |
trainer.fit(mlp, dataloader)
|
1508 |
|
1509 |
+
mlp.progress(0.99, desc="Applying MLP")
|
1510 |
results = trainer.predict(mlp, data_loader)
|
1511 |
A_transformed = torch.cat(results, dim=0)
|
1512 |
|
|
|
2861 |
buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
|
2862 |
buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
|
2863 |
|
2864 |
+
|
2865 |
+
with gr.Tab('Directed (experimental)', visible=True) as tab_directed_ncut:
|
2866 |
|
2867 |
+
target_images = gr.State([])
|
2868 |
+
input_images = gr.State([])
|
2869 |
+
def add_mlp_fitting_buttons(output_gallery, mlp_gallery, target_images=target_images, input_images=input_images):
|
2870 |
+
with gr.Row():
|
2871 |
+
# mark_as_target_button = gr.Button("mark target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
|
2872 |
+
# mark_as_input_button = gr.Button("mark input", elem_id=f"mark_as_input_button_{output_gallery.elem_id}", variant='secondary')
|
2873 |
+
mark_as_target_button = gr.Button("🎯 Mark Target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
|
2874 |
+
fit_to_target_button = gr.Button("🔴 [MLP] Fit", elem_id=f"fit_to_target_button_{output_gallery.elem_id}", variant='primary')
|
2875 |
+
def mark_fn(images, text="target"):
|
2876 |
+
if images is None:
|
2877 |
+
raise gr.Error("No images selected")
|
2878 |
+
if len(images) == 0:
|
2879 |
+
raise gr.Error("No images selected")
|
2880 |
+
num_images = len(images)
|
2881 |
+
gr.Info(f"Marked {num_images} images as {text}")
|
2882 |
+
images = [(Image.open(tup[0]), []) for tup in images]
|
2883 |
+
return images
|
2884 |
+
mark_as_target_button.click(partial(mark_fn, text="target"), inputs=[output_gallery], outputs=[target_images])
|
2885 |
+
# mark_as_input_button.click(partial(mark_fn, text="input"), inputs=[output_gallery], outputs=[input_images])
|
2886 |
+
|
2887 |
+
with gr.Accordion("➡️ MLP Parameters", open=False):
|
2888 |
+
num_layers_slider = gr.Slider(2, 10, step=1, label="Number of Layers", value=3, elem_id=f"num_layers_slider_{output_gallery.elem_id}")
|
2889 |
+
width_slider = gr.Slider(128, 4096, step=128, label="Width", value=512, elem_id=f"width_slider_{output_gallery.elem_id}")
|
2890 |
+
batch_size_slider = gr.Slider(32, 4096, step=32, label="Batch Size", value=128, elem_id=f"batch_size_slider_{output_gallery.elem_id}")
|
2891 |
+
lr_slider = gr.Slider(1e-6, 1, step=1e-6, label="Learning Rate", value=3e-4, elem_id=f"lr_slider_{output_gallery.elem_id}")
|
2892 |
+
fitting_steps_slider = gr.Slider(1000, 100000, step=1000, label="Fitting Steps", value=30000, elem_id=f"fitting_steps_slider_{output_gallery.elem_id}")
|
2893 |
+
fps_sample_slider = gr.Slider(128, 50000, step=128, label="FPS Sample", value=10240, elem_id=f"fps_sample_slider_{output_gallery.elem_id}")
|
2894 |
+
segmentation_loss_lambda_slider = gr.Slider(0, 100, step=0.01, label="Segmentation Preserving Loss Lambda", value=1, elem_id=f"segmentation_loss_lambda_slider_{output_gallery.elem_id}")
|
2895 |
+
|
2896 |
+
fit_to_target_button.click(
|
2897 |
+
run_mlp_fit,
|
2898 |
+
inputs=[output_gallery, target_images, num_layers_slider, width_slider, batch_size_slider, lr_slider, fitting_steps_slider, fps_sample_slider, segmentation_loss_lambda_slider],
|
2899 |
+
outputs=[mlp_gallery],
|
2900 |
+
)
|
2901 |
+
|
2902 |
+
def make_parameters_section_2model(model_ratio=True):
|
2903 |
+
gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
|
2904 |
+
from ncut_pytorch.backbone import list_models, get_demo_model_names
|
2905 |
+
model_names = list_models()
|
2906 |
+
model_names = sorted(model_names)
|
2907 |
+
# only CLIP DINO MAE is implemented for q k v
|
2908 |
+
ok_models = ["CLIP(ViT", "DiNO(", "MAE("]
|
2909 |
+
model_names = [m for m in model_names if any(ok in m for ok in ok_models)]
|
2910 |
+
|
2911 |
+
def get_filtered_model_names(name):
|
2912 |
+
return [m for m in model_names if name.lower() in m.lower()]
|
2913 |
+
def get_default_model_name(name):
|
2914 |
+
lst = get_filtered_model_names(name)
|
2915 |
+
if len(lst) > 1:
|
2916 |
+
return lst[1]
|
2917 |
+
return lst[0]
|
2918 |
+
|
2919 |
+
|
2920 |
+
model_radio = gr.Radio(["CLIP", "DiNO", "MAE"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True, visible=model_ratio)
|
2921 |
+
model_dropdown = gr.Dropdown(get_filtered_model_names("DiNO"), label="", value="DiNO(dino_vitb8_448)", elem_id="model_name", show_label=False)
|
2922 |
+
model_radio.change(fn=lambda x: gr.update(choices=get_filtered_model_names(x), value=get_default_model_name(x)), inputs=model_radio, outputs=[model_dropdown])
|
2923 |
+
layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
|
2924 |
+
positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'")
|
2925 |
+
positive_prompt.visible = False
|
2926 |
+
negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
|
2927 |
+
negative_prompt.visible = False
|
2928 |
+
node_type_dropdown = gr.Dropdown(['q', 'k', 'v'],
|
2929 |
+
label="Left-side Node Type", value="q", elem_id="node_type", info="In directed case, left-side SVD eigenvector is taken")
|
2930 |
+
node_type_dropdown2 = gr.Dropdown(['q', 'k', 'v'],
|
2931 |
+
label="Right-side Node Type", value="k", elem_id="node_type2")
|
2932 |
+
head_index_text = gr.Textbox(value='all', label="Head Index", elem_id="head_index", type="text", info="which attention heads to use, comma separated, e.g. 0,1,2")
|
2933 |
+
make_symmetric = gr.Checkbox(label="Make Symmetric", value=False, elem_id="make_symmetric", info="make the graph symmetric by A = (A + A.T) / 2")
|
2934 |
+
|
2935 |
+
num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for smaller clusters')
|
2936 |
+
|
2937 |
+
def change_layer_slider(model_name):
|
2938 |
+
# SD2, UNET
|
2939 |
+
if "stable" in model_name.lower() and "diffusion" in model_name.lower():
|
2940 |
+
from ncut_pytorch.backbone import SD_KEY_DICT
|
2941 |
+
default_layer = 'up_2_resnets_1_block' if 'diffusion-3' not in model_name else 'block_23'
|
2942 |
+
return (gr.Slider(1, 49, step=1, label="Diffusion: Timestep (Noise)", value=5, elem_id="layer", visible=True, info="Noise level, 50 is max noise"),
|
2943 |
+
gr.Dropdown(SD_KEY_DICT[model_name], label="Diffusion: Layer and Node", value=default_layer, elem_id="node_type", info="U-Net (v1, v2) or DiT (v3)"))
|
2944 |
+
|
2945 |
+
if model_name == "LISSL(xinlai/LISSL-7B-v1)":
|
2946 |
+
layer_names = ["dec_0_input", "dec_0_attn", "dec_0_block", "dec_1_input", "dec_1_attn", "dec_1_block"]
|
2947 |
+
default_layer = "dec_1_block"
|
2948 |
+
return (gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False, info=""),
|
2949 |
+
gr.Dropdown(layer_names, label="LISA decoder: Layer and Node", value=default_layer, elem_id="node_type"))
|
2950 |
+
|
2951 |
+
layer_dict = LAYER_DICT
|
2952 |
+
if model_name in layer_dict:
|
2953 |
+
value = layer_dict[model_name]
|
2954 |
+
return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True, info="")
|
2955 |
+
else:
|
2956 |
+
value = 12
|
2957 |
+
return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True, info="")
|
2958 |
+
model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=layer_slider)
|
2959 |
+
|
2960 |
+
def change_prompt_text(model_name):
|
2961 |
+
if model_name in promptable_diffusion_models:
|
2962 |
+
return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=True),
|
2963 |
+
gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=True))
|
2964 |
+
return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=False),
|
2965 |
+
gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
|
2966 |
+
model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
|
2967 |
+
|
2968 |
+
with gr.Accordion("Advanced Parameters: NCUT", open=False):
|
2969 |
+
gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
|
2970 |
+
affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
|
2971 |
+
num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
|
2972 |
+
# sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
|
2973 |
+
sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
|
2974 |
+
# ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
|
2975 |
+
ncut_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
|
2976 |
+
ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
|
2977 |
+
ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=False, elem_id="ncut_indirect_connection", info="TODO: Indirect connection is not implemented for directed NCUT", interactive=False)
|
2978 |
+
ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
|
2979 |
+
with gr.Accordion("Advanced Parameters: Visualization", open=False):
|
2980 |
+
# embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
|
2981 |
+
embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
|
2982 |
+
# embedding_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="t-SNE/UMAP metric", value="euclidean", elem_id="embedding_metric")
|
2983 |
+
embedding_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="t-SNE/UMAP: metric", value="euclidean", elem_id="embedding_metric")
|
2984 |
+
num_sample_tsne_slider = gr.Slider(100, 10000, step=100, label="t-SNE/UMAP: num_sample", value=300, elem_id="num_sample_tsne", info="Nyström approximation")
|
2985 |
+
knn_tsne_slider = gr.Slider(1, 100, step=1, label="t-SNE/UMAP: KNN", value=10, elem_id="knn_tsne", info="Nyström approximation")
|
2986 |
+
perplexity_slider = gr.Slider(10, 1000, step=10, label="t-SNE: perplexity", value=150, elem_id="perplexity")
|
2987 |
+
n_neighbors_slider = gr.Slider(10, 1000, step=10, label="UMAP: n_neighbors", value=150, elem_id="n_neighbors")
|
2988 |
+
min_dist_slider = gr.Slider(0.1, 1, step=0.1, label="UMAP: min_dist", value=0.1, elem_id="min_dist")
|
2989 |
+
return [model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
|
2990 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
2991 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
2992 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
2993 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt]
|
2994 |
+
|
2995 |
+
def add_one_model(i_model=1):
|
2996 |
+
with gr.Column(scale=5, min_width=200) as col:
|
2997 |
+
gr.Markdown(f'### Output Images')
|
2998 |
+
output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
2999 |
+
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
3000 |
+
add_rotate_flip_buttons(output_gallery)
|
3001 |
+
add_download_button(output_gallery, f"ncut_embed")
|
3002 |
+
mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
3003 |
+
add_mlp_fitting_buttons(output_gallery, mlp_gallery)
|
3004 |
+
add_download_button(mlp_gallery, f"mlp_color_align")
|
3005 |
+
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
3006 |
+
add_download_button(norm_gallery, f"eig_norm")
|
3007 |
+
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
3008 |
+
add_download_button(cluster_gallery, f"clusters")
|
3009 |
+
[
|
3010 |
+
model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
|
3011 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
3012 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
3013 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
3014 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
3015 |
+
] = make_parameters_section_2model()
|
3016 |
+
# logging text box
|
3017 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
3018 |
+
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
3019 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
3020 |
+
|
3021 |
+
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
3022 |
+
|
3023 |
+
submit_button.click(
|
3024 |
+
partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True, advanced=True, directed=True),
|
3025 |
+
inputs=[
|
3026 |
+
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
3027 |
+
positive_prompt, negative_prompt,
|
3028 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
3029 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
3030 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
3031 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
|
3032 |
+
*[false_placeholder for _ in range(9)],
|
3033 |
+
node_type_dropdown2, head_index_text, make_symmetric
|
3034 |
+
],
|
3035 |
+
outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text]
|
3036 |
+
)
|
3037 |
+
|
3038 |
+
output_gallery.change(lambda x: gr.update(value=x), inputs=[output_gallery], outputs=[mlp_gallery])
|
3039 |
+
|
3040 |
+
return output_gallery
|
3041 |
+
|
3042 |
+
galleries = []
|
3043 |
|
3044 |
+
with gr.Row():
|
3045 |
+
with gr.Column(scale=5, min_width=200):
|
3046 |
+
input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
|
3047 |
+
submit_button.visible = False
|
3048 |
+
|
3049 |
+
|
3050 |
+
for i in range(3):
|
3051 |
+
g = add_one_model()
|
3052 |
+
galleries.append(g)
|
3053 |
+
|
3054 |
+
# Create rows and buttons in a loop
|
3055 |
+
rows = []
|
3056 |
+
buttons = []
|
3057 |
+
|
3058 |
+
for i in range(4):
|
3059 |
+
row = gr.Row(visible=False)
|
3060 |
+
rows.append(row)
|
3061 |
+
|
3062 |
+
with row:
|
3063 |
+
for j in range(4):
|
3064 |
+
with gr.Column(scale=5, min_width=200):
|
3065 |
+
g = add_one_model()
|
3066 |
+
galleries.append(g)
|
3067 |
+
|
3068 |
+
button = gr.Button("➕ Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True, scale=3)
|
3069 |
+
buttons.append(button)
|
3070 |
+
|
3071 |
+
if i > 0:
|
3072 |
+
# Reveal the current row and next button
|
3073 |
+
buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=row)
|
3074 |
+
buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=button)
|
3075 |
+
|
3076 |
+
# Hide the current button
|
3077 |
+
buttons[i - 1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[i - 1])
|
3078 |
+
|
3079 |
+
# Last button only reveals the last row and hides itself
|
3080 |
+
buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
|
3081 |
+
buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
|
3082 |
+
|
3083 |
+
|
3084 |
|
3085 |
with gr.Tab('📄About'):
|
3086 |
with gr.Column():
|
directed_ncut.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
def affinity_from_features(
|
6 |
+
features,
|
7 |
+
features_B=None,
|
8 |
+
affinity_focal_gamma=1.0,
|
9 |
+
distance="cosine",
|
10 |
+
normalize_features=False,
|
11 |
+
fill_diagonal=False,
|
12 |
+
n_features=1,
|
13 |
+
):
|
14 |
+
"""Compute affinity matrix from input features.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
features (torch.Tensor): input features, shape (n_samples, n_features)
|
18 |
+
feature_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
|
19 |
+
affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the edge weights
|
20 |
+
on weak connections, default 1.0
|
21 |
+
distance (str): distance metric, 'cosine' (default) or 'euclidean'.
|
22 |
+
apply_normalize (bool): normalize input features before computing affinity matrix,
|
23 |
+
default True
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
(torch.Tensor): affinity matrix, shape (n_samples, n_samples)
|
27 |
+
"""
|
28 |
+
# compute affinity matrix from input features
|
29 |
+
features = features.clone()
|
30 |
+
if features_B is not None:
|
31 |
+
features_B = features_B.clone()
|
32 |
+
|
33 |
+
# if feature_B is not provided, compute affinity matrix on features x features
|
34 |
+
# if feature_B is provided, compute affinity matrix on features x feature_B
|
35 |
+
if features_B is not None:
|
36 |
+
assert not fill_diagonal, "fill_diagonal should be False when feature_B is None"
|
37 |
+
features_B = features if features_B is None else features_B
|
38 |
+
|
39 |
+
if normalize_features:
|
40 |
+
features = F.normalize(features, dim=-1)
|
41 |
+
features_B = F.normalize(features_B, dim=-1)
|
42 |
+
|
43 |
+
if distance == "cosine":
|
44 |
+
# if not check_if_normalized(features):
|
45 |
+
|
46 |
+
# TODO: make sure features are normalized within each head
|
47 |
+
|
48 |
+
features = F.normalize(features, dim=-1)
|
49 |
+
# if not check_if_normalized(features_B):
|
50 |
+
features_B = F.normalize(features_B, dim=-1)
|
51 |
+
A = 1 - (features @ features_B.T) / n_features
|
52 |
+
elif distance == "euclidean":
|
53 |
+
A = torch.cdist(features, features_B, p=2) / n_features
|
54 |
+
else:
|
55 |
+
raise ValueError("distance should be 'cosine' or 'euclidean'")
|
56 |
+
|
57 |
+
if fill_diagonal:
|
58 |
+
A[torch.arange(A.shape[0]), torch.arange(A.shape[0])] = 0
|
59 |
+
|
60 |
+
# torch.exp make affinity matrix positive definite,
|
61 |
+
# lower affinity_focal_gamma reduce the weak edge weights
|
62 |
+
A = torch.exp(-((A / affinity_focal_gamma)))
|
63 |
+
return A
|
64 |
+
|
65 |
+
from ncut_pytorch.ncut_pytorch import run_subgraph_sampling, propagate_knn, gram_schmidt
|
66 |
+
import logging
|
67 |
+
|
68 |
+
import torch
|
69 |
+
|
70 |
+
def ncut(
|
71 |
+
A,
|
72 |
+
num_eig=20,
|
73 |
+
eig_solver="svd_lowrank",
|
74 |
+
make_symmetric=True,
|
75 |
+
):
|
76 |
+
"""PyTorch implementation of Normalized cut without Nystrom-like approximation.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
A (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
|
80 |
+
num_eig (int): number of eigenvectors to return
|
81 |
+
eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
(torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
|
85 |
+
(torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
|
86 |
+
"""
|
87 |
+
if make_symmetric:
|
88 |
+
# make sure A is symmetric
|
89 |
+
A = (A + A.T) / 2
|
90 |
+
|
91 |
+
# symmetrical normalization; A = D^(-1/2) A D^(-1/2)
|
92 |
+
D_r = A.sum(dim=0).detach().clone()
|
93 |
+
D_c = A.sum(dim=1).detach().clone()
|
94 |
+
A /= torch.sqrt(D_r)[:, None]
|
95 |
+
A /= torch.sqrt(D_c)[None, :]
|
96 |
+
|
97 |
+
# compute eigenvectors
|
98 |
+
if eig_solver == "svd_lowrank": # default
|
99 |
+
# only top q eigenvectors, fastest
|
100 |
+
eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
|
101 |
+
elif eig_solver == "lobpcg":
|
102 |
+
# only top k eigenvectors, fast
|
103 |
+
eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
|
104 |
+
elif eig_solver == "svd":
|
105 |
+
# all eigenvectors, slow
|
106 |
+
eigen_vector, eigen_value, _ = torch.svd(A)
|
107 |
+
elif eig_solver == "eigh":
|
108 |
+
# all eigenvectors, slow
|
109 |
+
eigen_value, eigen_vector = torch.linalg.eigh(A)
|
110 |
+
else:
|
111 |
+
raise ValueError(
|
112 |
+
"eigen_solver should be 'lobpcg', 'svd_lowrank', 'svd' or 'eigh'"
|
113 |
+
)
|
114 |
+
|
115 |
+
# sort eigenvectors by eigenvalues, take top (descending order)
|
116 |
+
eigen_value = eigen_value.real
|
117 |
+
eigen_vector = eigen_vector.real
|
118 |
+
|
119 |
+
sort_order = torch.argsort(eigen_value, descending=True)[:num_eig]
|
120 |
+
eigen_value = eigen_value[sort_order]
|
121 |
+
eigen_vector = eigen_vector[:, sort_order]
|
122 |
+
|
123 |
+
if eigen_value.min() < 0:
|
124 |
+
logging.warning(
|
125 |
+
"negative eigenvalues detected, please make sure the affinity matrix is positive definite"
|
126 |
+
)
|
127 |
+
|
128 |
+
return eigen_vector, eigen_value
|
129 |
+
|
130 |
+
def nystrom_ncut(
|
131 |
+
features,
|
132 |
+
features_B=None,
|
133 |
+
num_eig=100,
|
134 |
+
num_sample=10000,
|
135 |
+
knn=10,
|
136 |
+
sample_method="farthest",
|
137 |
+
distance="cosine",
|
138 |
+
affinity_focal_gamma=1.0,
|
139 |
+
indirect_connection=False,
|
140 |
+
indirect_pca_dim=100,
|
141 |
+
device=None,
|
142 |
+
eig_solver="svd_lowrank",
|
143 |
+
normalize_features=False,
|
144 |
+
matmul_chunk_size=8096,
|
145 |
+
make_orthogonal=False,
|
146 |
+
verbose=False,
|
147 |
+
no_propagation=False,
|
148 |
+
make_symmetric=False,
|
149 |
+
n_features=1,
|
150 |
+
):
|
151 |
+
"""PyTorch implementation of Faster Nystrom Normalized cut.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
features (torch.Tensor): feature matrix, shape (n_samples, n_features)
|
155 |
+
features_2 (torch.Tensor): feature matrix 2, for asymmetric affinity matrix, shape (n_samples2, n_features)
|
156 |
+
num_eig (int): default 20, number of top eigenvectors to return
|
157 |
+
num_sample (int): default 30000, number of samples for Nystrom-like approximation
|
158 |
+
knn (int): default 3, number of KNN for propagating eigenvectors from subgraph to full graph,
|
159 |
+
smaller knn will result in more sharp eigenvectors,
|
160 |
+
sample_method (str): sample method, 'farthest' (default) or 'random'
|
161 |
+
'farthest' is recommended for better approximation
|
162 |
+
distance (str): distance metric, 'cosine' (default) or 'euclidean'
|
163 |
+
affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the weak edge weights,
|
164 |
+
resulting in more sharp eigenvectors, default 1.0
|
165 |
+
indirect_connection (bool): include indirect connection in the subgraph, default True
|
166 |
+
indirect_pca_dim (int): default 100, PCA dimension to reduce the node dimension, only applied to
|
167 |
+
the not sampled nodes, not applied to the sampled nodes
|
168 |
+
device (str): device to use for computation, if None, will not change device
|
169 |
+
a good practice is to pass features by CPU since it's usually large,
|
170 |
+
and move subgraph affinity to GPU to speed up eigenvector computation
|
171 |
+
eig_solver (str): eigen decompose solver, 'svd_lowrank' (default), 'lobpcg', 'svd', 'eigh'
|
172 |
+
'svd_lowrank' is recommended for large scale graph, it's the fastest
|
173 |
+
they correspond to torch.svd_lowrank, torch.lobpcg, torch.svd, torch.linalg.eigh
|
174 |
+
normalize_features (bool): normalize input features before computing affinity matrix,
|
175 |
+
default True
|
176 |
+
matmul_chunk_size (int): chunk size for matrix multiplication
|
177 |
+
large matrix multiplication is chunked to reduce memory usage,
|
178 |
+
smaller chunk size will reduce memory usage but slower computation, default 8096
|
179 |
+
make_orthogonal (bool): make eigenvectors orthogonal after propagation, default True
|
180 |
+
verbose (bool): show progress bar when propagating eigenvectors from subgraph to full graph
|
181 |
+
no_propagation (bool): if True, skip the eigenvector propagation step, only return the subgraph eigenvectors
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
(torch.Tensor): eigenvectors, shape (n_samples, num_eig)
|
185 |
+
(torch.Tensor): eigenvalues, sorted in descending order, shape (num_eig,)
|
186 |
+
(torch.Tensor): sampled_indices used by Nystrom-like approximation subgraph, shape (num_sample,)
|
187 |
+
"""
|
188 |
+
|
189 |
+
# check if features dimension greater than num_eig
|
190 |
+
if eig_solver in ["svd_lowrank", "lobpcg"]:
|
191 |
+
assert features.shape[0] > (
|
192 |
+
num_eig * 2
|
193 |
+
), "number of nodes should be greater than 2*num_eig"
|
194 |
+
if eig_solver in ["svd", "eigh"]:
|
195 |
+
assert (
|
196 |
+
features.shape[0] > num_eig
|
197 |
+
), "number of nodes should be greater than num_eig"
|
198 |
+
|
199 |
+
features = features.clone()
|
200 |
+
if normalize_features:
|
201 |
+
# features need to be normalized for affinity matrix computation (cosine distance)
|
202 |
+
features = torch.nn.functional.normalize(features, dim=-1)
|
203 |
+
|
204 |
+
sampled_indices = run_subgraph_sampling(
|
205 |
+
features,
|
206 |
+
num_sample=num_sample,
|
207 |
+
sample_method=sample_method,
|
208 |
+
)
|
209 |
+
|
210 |
+
sampled_indices_B = run_subgraph_sampling(
|
211 |
+
features_B,
|
212 |
+
num_sample=num_sample,
|
213 |
+
sample_method=sample_method,
|
214 |
+
)
|
215 |
+
|
216 |
+
sampled_features = features[sampled_indices]
|
217 |
+
sampled_features_B = features_B[sampled_indices_B]
|
218 |
+
# move subgraph gpu to speed up
|
219 |
+
original_device = sampled_features.device
|
220 |
+
device = original_device if device is None else device
|
221 |
+
sampled_features = sampled_features.to(device)
|
222 |
+
sampled_features_B = sampled_features_B.to(device)
|
223 |
+
|
224 |
+
# compute affinity matrix on subgraph
|
225 |
+
A = affinity_from_features(
|
226 |
+
sampled_features, features_B=sampled_features_B,
|
227 |
+
affinity_focal_gamma=affinity_focal_gamma, distance=distance,
|
228 |
+
n_features=n_features,
|
229 |
+
)
|
230 |
+
|
231 |
+
not_sampled = torch.tensor(
|
232 |
+
list(set(range(features.shape[0])) - set(sampled_indices))
|
233 |
+
)
|
234 |
+
|
235 |
+
if len(not_sampled) == 0:
|
236 |
+
# if sampled all nodes, no need for nyström approximation
|
237 |
+
eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver)
|
238 |
+
return eigen_vector, eigen_value, sampled_indices
|
239 |
+
|
240 |
+
# 1) PCA to reduce the node dimension for the not sampled nodes
|
241 |
+
# 2) compute indirect connection on the PC nodes
|
242 |
+
if len(not_sampled) > 0 and indirect_connection:
|
243 |
+
raise NotImplementedError("indirect_connection is not implemented yet")
|
244 |
+
indirect_pca_dim = min(indirect_pca_dim, min(*features.shape))
|
245 |
+
U, S, V = torch.pca_lowrank(features[not_sampled].T, q=indirect_pca_dim)
|
246 |
+
feature_B = (features[not_sampled].T @ V).T # project to PCA space
|
247 |
+
feature_B = feature_B.to(device)
|
248 |
+
B = affinity_from_features(
|
249 |
+
sampled_features,
|
250 |
+
feature_B,
|
251 |
+
affinity_focal_gamma=affinity_focal_gamma,
|
252 |
+
distance=distance,
|
253 |
+
fill_diagonal=False,
|
254 |
+
)
|
255 |
+
# P is 1-hop random walk matrix
|
256 |
+
B_row = B / B.sum(axis=1, keepdim=True)
|
257 |
+
B_col = B / B.sum(axis=0, keepdim=True)
|
258 |
+
P = B_row @ B_col.T
|
259 |
+
P = (P + P.T) / 2
|
260 |
+
# fill diagonal with 0
|
261 |
+
P[torch.arange(P.shape[0]), torch.arange(P.shape[0])] = 0
|
262 |
+
A = A + P
|
263 |
+
|
264 |
+
# compute normalized cut on the subgraph
|
265 |
+
eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver, make_symmetric=make_symmetric)
|
266 |
+
eigen_vector = eigen_vector.to(dtype=features.dtype, device=original_device)
|
267 |
+
eigen_value = eigen_value.to(dtype=features.dtype, device=original_device)
|
268 |
+
|
269 |
+
if no_propagation:
|
270 |
+
return eigen_vector, eigen_value, sampled_indices
|
271 |
+
|
272 |
+
# propagate eigenvectors from subgraph to full graph
|
273 |
+
eigen_vector = propagate_knn(
|
274 |
+
eigen_vector,
|
275 |
+
features,
|
276 |
+
sampled_features,
|
277 |
+
knn,
|
278 |
+
chunk_size=matmul_chunk_size,
|
279 |
+
device=device,
|
280 |
+
use_tqdm=verbose,
|
281 |
+
)
|
282 |
+
|
283 |
+
# post-hoc orthogonalization
|
284 |
+
if make_orthogonal:
|
285 |
+
eigen_vector = gram_schmidt(eigen_vector)
|
286 |
+
|
287 |
+
return eigen_vector, eigen_value, sampled_indices
|
requirements.txt
CHANGED
@@ -20,4 +20,4 @@ lisa @ git+https://github.com/huzeyann/LISA.git@7211e99
|
|
20 |
timm==0.9.2
|
21 |
open-clip-torch==2.20.0
|
22 |
pytorch_lightning==1.9.4
|
23 |
-
ncut-pytorch>=1.
|
|
|
20 |
timm==0.9.2
|
21 |
open-clip-torch==2.20.0
|
22 |
pytorch_lightning==1.9.4
|
23 |
+
ncut-pytorch>=1.4.1
|