huzey committed on
Commit 94eb803
1 Parent(s): 976cf1e
Files changed (2)
  1. app.py +151 -17
  2. packages.txt +2 -0
app.py CHANGED
@@ -183,6 +183,29 @@ downscaled_outputs = default_outputs
 example_items = downscaled_images[:3] + downscaled_outputs[:3]


+def run_alignedthreemodelattnnodes(images, model, batch_size=1):
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    if use_cuda:
+        model = model.to(device)
+
+    chunked_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
+
+    outputs = []
+    for idxs in chunked_idxs:
+        inp = images[idxs]
+        if use_cuda:
+            inp = inp.to(device)
+        out = model(inp)
+        # normalize before save
+        out = F.normalize(out, dim=-1)
+        outputs.append(out.cpu().float())
+    outputs = torch.cat(outputs, dim=0)
+
+    return outputs
+

 def ncut_run(
     model,
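For readers skimming the diff: the new run_alignedthreemodelattnnodes helper above is a standard chunked-inference loop. Below is a minimal, self-contained sketch of the same pattern; the toy backbone and tensor shapes are illustrative stand-ins, not part of this commit.

import torch
import torch.nn.functional as F

def extract_in_chunks(images, model, batch_size=1):
    # Same pattern as run_alignedthreemodelattnnodes: split the batch indices,
    # run each chunk through the backbone, L2-normalize, and gather on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    outputs = []
    for idxs in torch.split(torch.arange(images.shape[0]), batch_size):
        out = model(images[idxs].to(device))
        out = F.normalize(out, dim=-1)  # unit-norm features before collecting
        outputs.append(out.cpu().float())
    return torch.cat(outputs, dim=0)

# Toy stand-in backbone; the (672, 672) input size matches the hard-coded resolution below.
toy_model = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(), torch.nn.Linear(3, 16))
feats = extract_in_chunks(torch.randn(4, 3, 672, 672), toy_model, batch_size=2)
print(feats.shape)  # torch.Size([4, 16])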
@@ -212,7 +235,11 @@ def ncut_run(
     video_output=False,
 ):
     logging_str = ""
-    resolution = RES_DICT[model_name]
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        resolution = (672, 672)
+    else:
+        resolution = RES_DICT[model_name]
     logging_str += f"Resolution: {resolution}\n"
     if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne:
         # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
@@ -227,9 +254,13 @@ def ncut_run(
     node_type = node_type.split(":")[0].strip()

     start = time.time()
-    features = extract_features(
-        images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
-    )
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        features = run_alignedthreemodelattnnodes(images, model, batch_size=BATCH_SIZE)
+    else:
+        features = extract_features(
+            images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
+        )
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"

@@ -301,8 +332,25 @@ def ncut_run(
     )
     logging_str += _logging_str

+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        galleries = []
+        for i_node in range(rgb.shape[1]):
+            _rgb = rgb[:, i_node]
+            galleries.append(to_pil_images(_rgb))
+        return *galleries, logging_str
+
     rgb = dont_use_too_much_green(rgb)

+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        print("AlignedThreeModelAttnNodes")
+        galleries = []
+        for i_node in range(rgb.shape[1]):
+            _rgb = rgb[:, i_node]
+            print(_rgb.shape)
+            galleries.append(to_pil_images(_rgb))
+        return *galleries, logging_str

     if video_output:
         video_path = get_random_path()
@@ -313,16 +361,19 @@ def ncut_run(
     return to_pil_images(rgb), logging_str

 def _ncut_run(*args, **kwargs):
-    try:
-        ret = ncut_run(*args, **kwargs)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        return ret
-    except Exception as e:
-        gr.Error(str(e))
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        return [], "Error: " + str(e)
+    # try:
+    #     ret = ncut_run(*args, **kwargs)
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+    #     return ret
+    # except Exception as e:
+    #     gr.Error(str(e))
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+    #     return [], "Error: " + str(e)
+
+    ret = ncut_run(*args, **kwargs)
+    return ret

 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=20)
@@ -376,6 +427,28 @@ def transform_image(image, resolution=(1024, 1024)):
     image = (image - 0.5) / 0.5
     return image

+def load_alignedthreemodel():
+
+    os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
+    # pull
+    os.system("git -C alignedthreeattn pull >> /dev/null 2>&1")
+    # add to path
+    import sys
+    sys.path.append("alignedthreeattn")
+
+
+    from alignedthreeattn.alignedthreeattn_model import ThreeAttnNodes
+
+    align_weights = torch.load("alignedthreeattn/align_weights.pth")
+    model = ThreeAttnNodes(align_weights)
+
+    # url = 'https://huggingface.co/huzey/aligned_model_test/resolve/main/3attn_nodes.pth'
+    # save_path = "alignedthreemodel.pth"
+    # if not os.path.exists(save_path):
+    #     os.system(f"wget {url} -O {save_path} -q")
+    # model = torch.load(save_path)
+    return model
+
 def run_fn(
     images,
     model_name="SAM(sam_vit_b)",
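load_alignedthreemodel above shells out to git (hence the new packages.txt entries) and imports the cloned repo from sys.path. An alternative sketch, not what this commit does, would fetch the weights file through huggingface_hub; it assumes align_weights.pth sits at the top level of the huzey/alignedthreeattn repo.

# Alternative sketch (assumption: align_weights.pth is a top-level file in the repo).
from huggingface_hub import hf_hub_download
import torch

weights_path = hf_hub_download(repo_id="huzey/alignedthreeattn", filename="align_weights.pth")
align_weights = torch.load(weights_path, map_location="cpu")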
@@ -416,12 +489,21 @@ def run_fn(
         sampling_method = "farthest"

     # resize the images before acquiring GPU
-    resolution = RES_DICT[model_name]
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        resolution = (672, 672)
+    else:
+        resolution = RES_DICT[model_name]
     images = [tup[0] for tup in images]
     images = [transform_image(image, resolution=resolution) for image in images]
     images = torch.stack(images)

-    model = load_model(model_name)
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        model = load_alignedthreemodel()
+    else:
+        model = load_model(model_name)
+
     if "stable" in model_name.lower() and "diffusion" in model_name.lower():
         model.timestep = layer
         layer = 1
@@ -932,7 +1014,59 @@ with demo:
         # Last button only reveals the last row and hides itself
         buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
         buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
-
+
+    with gr.Tab('Compare (Aligned)'):
+        gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
+        gr.Markdown('---')
+        gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
+        gr.Markdown('NCUT is computed on the concatenated graph of all models, layers, and images. Color is **aligned** across all models and layers.')
+        gr.Markdown('---')
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button = make_input_images_section()
+
+                dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
+                num_images_slider.value = 100
+
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
+                gr.Markdown('Layer type: attention output (attn), without sum of residual')
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown
+                ] = make_parameters_section()
+                model_dropdown.value = "AlignedThreeModelAttnNodes"
+                model_dropdown.visible = False
+                layer_slider.visible = False
+                node_type_dropdown.visible = False
+                # logging text box
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+
+        galleries = []
+        for i_model, model_name in enumerate(["CLIP", "DINO", "MAE"]):
+            with gr.Row():
+                for i_layer in range(1, 13):
+                    with gr.Column(scale=5, min_width=200):
+                        gr.Markdown(f'### {model_name} Layer {i_layer}')
+                        output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
+                        galleries.append(output_gallery)
+
+
+        clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+        submit_button.click(
+            run_fn,
+            inputs=[
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
+            ],
+            outputs=galleries + [logging_text],
+        )
+
     with gr.Row():
         with gr.Column():
             gr.Markdown("##### POWERED BY [ncut-pytorch](https://ncut-pytorch.readthedocs.io/) ")
 
packages.txt ADDED
@@ -0,0 +1,2 @@
+git-all
+git-lfs
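git-all and git-lfs are system packages; they plausibly exist to support the git clone of the weights repo in load_alignedthreemodel, whose .pth files would be fetched through LFS. A small, hypothetical pre-flight check in the same spirit (the helper name is not from this commit):

import shutil
import subprocess

def ensure_git_tooling():
    # Hypothetical helper: verify both tools from packages.txt are on PATH
    # before load_alignedthreemodel tries to clone the weights repository.
    for tool in ("git", "git-lfs"):
        if shutil.which(tool) is None:
            raise RuntimeError(f"{tool} is not installed")
    subprocess.run(["git", "lfs", "install"], check=True)

ensure_git_tooling()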