Spaces: Running on Zero

add paper
Browse files
- app.py +151 -17
- packages.txt +2 -0
app.py
CHANGED
@@ -183,6 +183,29 @@ downscaled_outputs = default_outputs
 example_items = downscaled_images[:3] + downscaled_outputs[:3]
 
 
+def run_alignedthreemodelattnnodes(images, model, batch_size=1):
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    if use_cuda:
+        model = model.to(device)
+
+    chunked_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
+
+    outputs = []
+    for idxs in chunked_idxs:
+        inp = images[idxs]
+        if use_cuda:
+            inp = inp.to(device)
+        out = model(inp)
+        # normalize before save
+        out = F.normalize(out, dim=-1)
+        outputs.append(out.cpu().float())
+    outputs = torch.cat(outputs, dim=0)
+
+    return outputs
+
 
 def ncut_run(
     model,
@@ -212,7 +235,11 @@ def ncut_run(
     video_output=False,
 ):
     logging_str = ""
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        resolution = (672, 672)
+    else:
+        resolution = RES_DICT[model_name]
     logging_str += f"Resolution: {resolution}\n"
     if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne:
         # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
@@ -227,9 +254,13 @@ def ncut_run(
     node_type = node_type.split(":")[0].strip()
 
     start = time.time()
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        features = run_alignedthreemodelattnnodes(images, model, batch_size=BATCH_SIZE)
+    else:
+        features = extract_features(
+            images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
+        )
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
 
@@ -301,8 +332,25 @@ def ncut_run(
     )
     logging_str += _logging_str
 
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        galleries = []
+        for i_node in range(rgb.shape[1]):
+            _rgb = rgb[:, i_node]
+            galleries.append(to_pil_images(_rgb))
+        return *galleries, logging_str
+
     rgb = dont_use_too_much_green(rgb)
 
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        print("AlignedThreeModelAttnNodes")
+        galleries = []
+        for i_node in range(rgb.shape[1]):
+            _rgb = rgb[:, i_node]
+            print(_rgb.shape)
+            galleries.append(to_pil_images(_rgb))
+        return *galleries, logging_str
 
     if video_output:
         video_path = get_random_path()
@@ -313,16 +361,19 @@ def ncut_run(
     return to_pil_images(rgb), logging_str
 
 def _ncut_run(*args, **kwargs):
+    # try:
+    #     ret = ncut_run(*args, **kwargs)
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+    #     return ret
+    # except Exception as e:
+    #     gr.Error(str(e))
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+    #     return [], "Error: " + str(e)
+
+    ret = ncut_run(*args, **kwargs)
+    return ret
 
 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=20)
@@ -376,6 +427,28 @@ def transform_image(image, resolution=(1024, 1024)):
     image = (image - 0.5) / 0.5
     return image
 
+def load_alignedthreemodel():
+
+    os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
+    # pull
+    os.system("git -C alignedthreeattn pull >> /dev/null 2>&1")
+    # add to path
+    import sys
+    sys.path.append("alignedthreeattn")
+
+
+    from alignedthreeattn.alignedthreeattn_model import ThreeAttnNodes
+
+    align_weights = torch.load("alignedthreeattn/align_weights.pth")
+    model = ThreeAttnNodes(align_weights)
+
+    # url = 'https://huggingface.co/huzey/aligned_model_test/resolve/main/3attn_nodes.pth'
+    # save_path = "alignedthreemodel.pth"
+    # if not os.path.exists(save_path):
+    #     os.system(f"wget {url} -O {save_path} -q")
+    # model = torch.load(save_path)
+    return model
+
 def run_fn(
     images,
     model_name="SAM(sam_vit_b)",
@@ -416,12 +489,21 @@ def run_fn(
     sampling_method = "farthest"
 
     # resize the images before acquiring GPU
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        resolution = (672, 672)
+    else:
+        resolution = RES_DICT[model_name]
     images = [tup[0] for tup in images]
     images = [transform_image(image, resolution=resolution) for image in images]
     images = torch.stack(images)
 
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        model = load_alignedthreemodel()
+    else:
+        model = load_model(model_name)
+
     if "stable" in model_name.lower() and "diffusion" in model_name.lower():
         model.timestep = layer
         layer = 1
@@ -932,7 +1014,59 @@ with demo:
     # Last button only reveals the last row and hides itself
     buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
     buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
+
+    with gr.Tab('Compare (Aligned)'):
+        gr.Markdown('This page reproduces the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
+        gr.Markdown('---')
+        gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer; the learning signal comes from 1) fMRI brain activation and 2) segmentation-preserving eigen-constraints.')
+        gr.Markdown('NCUT is computed on the concatenated graph of all models, layers, and images. Color is **aligned** across all models and layers.')
+        gr.Markdown('---')
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button = make_input_images_section()
+
+                dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
+                num_images_slider.value = 100
+
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
+                gr.Markdown('Layer type: attention output (attn), without sum of residual')
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown
+                ] = make_parameters_section()
+                model_dropdown.value = "AlignedThreeModelAttnNodes"
+                model_dropdown.visible = False
+                layer_slider.visible = False
+                node_type_dropdown.visible = False
+                # logging text box
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+
+        galleries = []
+        for i_model, model_name in enumerate(["CLIP", "DINO", "MAE"]):
+            with gr.Row():
+                for i_layer in range(1, 13):
+                    with gr.Column(scale=5, min_width=200):
+                        gr.Markdown(f'### {model_name} Layer {i_layer}')
+                        output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
+                        galleries.append(output_gallery)
+
+
+        clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+        submit_button.click(
+            run_fn,
+            inputs=[
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
+            ],
+            outputs=galleries + [logging_text],
+        )
+
     with gr.Row():
         with gr.Column():
             gr.Markdown("##### POWERED BY [ncut-pytorch](https://ncut-pytorch.readthedocs.io/) ")
packages.txt
ADDED
@@ -0,0 +1,2 @@
+git-all
+git-lfs
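
For reference, a minimal sketch (not part of the commit) of how the new aligned-model path in app.py fits together, using the helpers added above (load_alignedthreemodel, transform_image, run_alignedthreemodelattnnodes). The dummy blank images, the batch size, and the assumption that transform_image accepts PIL images directly are illustration-only; the sketch also assumes the huzey/alignedthreeattn repo can be cloned at runtime.

import torch
from PIL import Image

# Hypothetical standalone driver mirroring what run_fn does when
# model_name == "AlignedThreeModelAttnNodes": resize inputs to the
# hard-coded 672x672, build the ThreeAttnNodes model, then extract
# L2-normalized features in small batches.
pil_images = [Image.new("RGB", (672, 672)) for _ in range(2)]  # placeholder inputs
images = torch.stack(
    [transform_image(im, resolution=(672, 672)) for im in pil_images]
)

model = load_alignedthreemodel()  # clones huzey/alignedthreeattn, loads align_weights.pth
features = run_alignedthreemodelattnnodes(images, model, batch_size=1)
print(features.shape)  # features are concatenated along dim 0, one block per image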