huzey committed on
Commit
319a391
1 Parent(s): a551f9e

update cluster fg bg

Browse files
Files changed (1) hide show
  1. app.py +142 -189
app.py CHANGED
@@ -308,7 +308,79 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
308
  blended = (1 - opacity1) * image + opacity2 * heatmap
309
  return blended.astype(np.uint8)
310
 
311
- def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  progress = gr.Progress()
313
  progress(progess_start, desc="Finding Clusters by FPS")
314
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -318,10 +390,13 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
318
 
319
  # gr.Info("Finding Clusters by FPS, no magnitude filtering")
320
  top_p_idx = torch.arange(eigvecs.shape[0])
 
 
321
  # gr.Info("Finding Clusters by FPS, with magnitude filtering")
322
  # p = 0.8
323
  # top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
324
 
 
325
  ret_magnitude = magnitude.reshape(-1, h, w)
326
 
327
 
@@ -338,7 +413,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
338
  right = F.normalize(right, dim=-1)
339
  heatmap = left @ right.T
340
  heatmap = F.normalize(heatmap, dim=-1)
341
- num_samples = 50 if not advanced else 100
342
  if num_samples > fps_idx.shape[0]:
343
  num_samples = fps_idx.shape[0]
344
  r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
@@ -398,10 +473,10 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
398
 
399
  fig_images = []
400
  i_cluster = 0
401
- num_plots = 10 if not advanced else 20
402
  plot_step_float = (1.0 - progess_start) / num_plots
403
  for i_fig in range(num_plots):
404
- progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
405
  if not advanced:
406
  fig, axs = plt.subplots(3, 5, figsize=(15, 9))
407
  if advanced:
@@ -421,7 +496,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
421
  _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
422
  axs[i, j].imshow(_heatmap)
423
  if i == 0:
424
- axs[i, j].set_title(f"cluster {i_cluster+1}", fontsize=24)
425
  i_cluster += 1
426
  plt.tight_layout(h_pad=0.5, w_pad=0.3)
427
 
@@ -440,6 +515,39 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
440
 
441
  return fig_images, ret_magnitude
442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
  def ncut_run(
445
  model,
@@ -601,7 +709,7 @@ def ncut_run(
601
  if torch.cuda.is_available():
602
  images = images.cuda()
603
  _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
604
- cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=advanced)
605
  logging_str += f"Recursion #{i+1} plot time: {time.time() - start:.2f}s\n"
606
 
607
  norm_images = []
@@ -716,7 +824,10 @@ def ncut_run(
716
  images = images.cuda()
717
  _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
718
  advanced = kwargs.get("advanced", False)
719
- cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=advanced)
 
 
 
720
  logging_str += f"plot time: {time.time() - start:.2f}s\n"
721
 
722
  norm_images = None
@@ -736,33 +847,33 @@ def ncut_run(
736
  logging_str += "Eigenvector Magnitude\n"
737
  logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
738
  gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
739
-
740
  return to_pil_images(rgb), cluster_images, norm_images, logging_str
741
 
742
 
743
 
744
  def _ncut_run(*args, **kwargs):
745
  n_ret = kwargs.pop("n_ret", 1)
746
- try:
747
- if torch.cuda.is_available():
748
- torch.cuda.empty_cache()
749
 
750
- ret = ncut_run(*args, **kwargs)
751
 
752
- if torch.cuda.is_available():
753
- torch.cuda.empty_cache()
754
 
755
- ret = list(ret)[:n_ret] + [ret[-1]]
756
- return ret
757
- except Exception as e:
758
- gr.Error(str(e))
759
- if torch.cuda.is_available():
760
- torch.cuda.empty_cache()
761
- return *(None for _ in range(n_ret)), "Error: " + str(e)
762
-
763
- # ret = ncut_run(*args, **kwargs)
764
- # ret = list(ret)[:n_ret] + [ret[-1]]
765
- # return ret
766
 
767
  if USE_HUGGINGFACE_ZEROGPU:
768
  @spaces.GPU(duration=30)
@@ -1186,7 +1297,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
1186
  images += [Image.open(new_image) for new_image in new_images]
1187
  if isinstance(new_images, str):
1188
  images.append(Image.open(new_images))
1189
- new_images = None
1190
  return images
1191
  upload_button.upload(convert_to_pil_and_append, inputs=[input_gallery, upload_button], outputs=[input_gallery])
1192
 
@@ -1402,6 +1513,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
1402
  if existing_images is None:
1403
  existing_images = []
1404
  existing_images += new_images
 
1405
  return existing_images
1406
 
1407
  load_images_button.click(load_and_append,
@@ -1416,165 +1528,6 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
1416
 
1417
 
1418
 
1419
- # def make_input_images_section(rows=1, cols=3, height="auto"):
1420
- # gr.Markdown('### Input Images')
1421
- # input_gallery = gr.Gallery(value=None, label="Select images", show_label=True, elem_id="images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False)
1422
- # submit_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary')
1423
- # clear_images_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
1424
- # return input_gallery, submit_button, clear_images_button
1425
-
1426
-
1427
- # def make_dataset_images_section(advanced=False, is_random=False):
1428
-
1429
- # gr.Markdown('### Load Datasets')
1430
- # load_images_button = gr.Button("🔴 Load Images", elem_id="load-images-button", variant='primary')
1431
- # advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio", show_label=True)
1432
- # with gr.Column() as basic_block:
1433
- # example_gallery = gr.Gallery(value=example_items, label="Example Images", show_label=True, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
1434
- # with gr.Column() as advanced_block:
1435
- # # dataset_names = DATASET_NAMES
1436
- # # dataset_classes = DATASET_CLASSES
1437
- # dataset_categories = list(DATASETS.keys())
1438
- # defualt_cat = dataset_categories[0]
1439
- # def get_choices(cat):
1440
- # return [tup[0] for tup in DATASETS[cat]]
1441
- # defualt_choices = get_choices(defualt_cat)
1442
- # with gr.Row():
1443
- # dataset_radio = gr.Radio(dataset_categories, label="Dataset Category", value=defualt_cat, elem_id="dataset-radio", show_label=True, min_width=600)
1444
- # # dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset", min_width=300)
1445
- # dataset_dropdown = gr.Dropdown(defualt_choices, label="Dataset name", value=defualt_choices[0], elem_id="dataset", min_width=400)
1446
- # dataset_radio.change(fn=lambda x: gr.update(choices=get_choices(x), value=get_choices(x)[0]), inputs=dataset_radio, outputs=dataset_dropdown)
1447
- # # num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
1448
- # num_images_slider = gr.Slider(1, 1000, step=1, label="Number of images", value=10, elem_id="num_images", min_width=200)
1449
- # if not is_random:
1450
- # filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
1451
- # filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
1452
- # # is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
1453
- # # random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
1454
- # is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
1455
- # random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=True)
1456
- # if is_random:
1457
- # filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
1458
- # filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
1459
- # is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
1460
- # random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=42, elem_id="random_seed", visible=True)
1461
-
1462
-
1463
- # if advanced:
1464
- # advanced_block.visible = True
1465
- # basic_block.visible = False
1466
- # else:
1467
- # advanced_block.visible = False
1468
- # basic_block.visible = True
1469
-
1470
- # # change visibility
1471
- # advanced_radio.change(fn=lambda x: gr.update(visible=x=="Advanced"), inputs=advanced_radio, outputs=[advanced_block])
1472
- # advanced_radio.change(fn=lambda x: gr.update(visible=x=="Basic"), inputs=advanced_radio, outputs=[basic_block])
1473
-
1474
- # def find_num_classes(dataset_name):
1475
- # num_classes = None
1476
- # for cat, datasets in DATASETS.items():
1477
- # datasets = [tup[0] for tup in datasets]
1478
- # if dataset_name in datasets:
1479
- # num_classes = DATASETS[cat][datasets.index(dataset_name)][1]
1480
- # break
1481
- # return num_classes
1482
-
1483
- # def change_filter_options(dataset_name):
1484
- # num_classes = find_num_classes(dataset_name)
1485
- # if num_classes is None:
1486
- # return (gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox", visible=False),
1487
- # gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info="e.g. `0,1,2`. This dataset has no class label", visible=False))
1488
- # return (gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox", visible=True),
1489
- # gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=True))
1490
- # dataset_dropdown.change(fn=change_filter_options, inputs=dataset_dropdown, outputs=[filter_by_class_checkbox, filter_by_class_text])
1491
-
1492
- # def change_filter_by_class(is_filter, dataset_name):
1493
- # num_classes = find_num_classes(dataset_name)
1494
- # return gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=is_filter)
1495
- # filter_by_class_checkbox.change(fn=change_filter_by_class, inputs=[filter_by_class_checkbox, dataset_dropdown], outputs=filter_by_class_text)
1496
-
1497
- # def change_random_seed(is_random):
1498
- # return gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=is_random)
1499
- # is_random_checkbox.change(fn=change_random_seed, inputs=is_random_checkbox, outputs=random_seed_slider)
1500
-
1501
-
1502
- # def load_dataset_images(is_advanced, dataset_name, num_images=10,
1503
- # is_filter=True, filter_by_class_text="0,1,2",
1504
- # is_random=False, seed=1):
1505
- # progress = gr.Progress()
1506
- # progress(0, desc="Loading Images")
1507
- # if is_advanced == "Basic":
1508
- # gr.Info("Loaded images from Ego-Exo4D")
1509
- # return default_images
1510
- # try:
1511
- # progress(0.5, desc="Downloading Dataset")
1512
- # dataset = load_dataset(dataset_name, trust_remote_code=True)
1513
- # key = list(dataset.keys())[0]
1514
- # dataset = dataset[key]
1515
- # except Exception as e:
1516
- # gr.Error(f"Error loading dataset {dataset_name}: {e}")
1517
- # return None
1518
- # if num_images > len(dataset):
1519
- # num_images = len(dataset)
1520
-
1521
- # if is_filter:
1522
- # progress(0.8, desc="Filtering Images")
1523
- # classes = [int(i) for i in filter_by_class_text.split(",")]
1524
- # labels = np.array(dataset['label'])
1525
- # unique_labels = np.unique(labels)
1526
- # valid_classes = [i for i in classes if i in unique_labels]
1527
- # invalid_classes = [i for i in classes if i not in unique_labels]
1528
- # if len(invalid_classes) > 0:
1529
- # gr.Warning(f"Classes {invalid_classes} not found in the dataset.")
1530
- # if len(valid_classes) == 0:
1531
- # gr.Error(f"Classes {classes} not found in the dataset.")
1532
- # return None
1533
- # # shuffle each class
1534
- # chunk_size = num_images // len(valid_classes)
1535
- # image_idx = []
1536
- # for i in valid_classes:
1537
- # idx = np.where(labels == i)[0]
1538
- # if is_random:
1539
- # idx = np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
1540
- # else:
1541
- # idx = idx[:chunk_size]
1542
- # image_idx.extend(idx.tolist())
1543
- # if not is_filter:
1544
- # if is_random:
1545
- # image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
1546
- # else:
1547
- # image_idx = list(range(num_images))
1548
- # key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
1549
- # images = [dataset[i][key] for i in image_idx]
1550
- # gr.Info(f"Loaded {len(images)} images from {dataset_name}")
1551
- # del dataset
1552
-
1553
- # if dataset_name in CENTER_CROP_DATASETS:
1554
- # def center_crop_image(img):
1555
- # # image: PIL image
1556
- # w, h = img.size
1557
- # min_hw = min(h, w)
1558
- # # center crop
1559
- # left = (w - min_hw) // 2
1560
- # top = (h - min_hw) // 2
1561
- # right = left + min_hw
1562
- # bottom = top + min_hw
1563
- # img = img.crop((left, top, right, bottom))
1564
- # return img
1565
- # images = [center_crop_image(image) for image in images]
1566
-
1567
- # return images
1568
-
1569
- # load_images_button.click(load_dataset_images,
1570
- # inputs=[advanced_radio, dataset_dropdown, num_images_slider,
1571
- # filter_by_class_checkbox, filter_by_class_text,
1572
- # is_random_checkbox, random_seed_slider],
1573
- # outputs=[input_gallery])
1574
-
1575
- # return dataset_dropdown, num_images_slider, random_seed_slider, load_images_button
1576
-
1577
-
1578
  # def random_rotate_rgb_gallery(images):
1579
  # if images is None or len(images) == 0:
1580
  # gr.Warning("No images selected.")
@@ -1969,19 +1922,19 @@ with demo:
1969
  l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
1970
  add_output_images_buttons(l1_gallery)
1971
  l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1972
- l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
1973
  with gr.Column(scale=5, min_width=200):
1974
  gr.Markdown('### Output (Recursion #2)')
1975
  l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
1976
  add_output_images_buttons(l2_gallery)
1977
  l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1978
- l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
1979
  with gr.Column(scale=5, min_width=200):
1980
  gr.Markdown('### Output (Recursion #3)')
1981
  l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
1982
  add_output_images_buttons(l3_gallery)
1983
  l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1984
- l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
1985
 
1986
  with gr.Row():
1987
  with gr.Column(scale=5, min_width=200):
@@ -2352,7 +2305,7 @@ with demo:
2352
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
2353
  add_output_images_buttons(output_gallery)
2354
  norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2355
- cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
2356
  [
2357
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
2358
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
 
308
  blended = (1 - opacity1) * image + opacity2 * heatmap
309
  return blended.astype(np.uint8)
310
 
311
def segment_fg_bg(images):
    """Score every patch of every image for foreground vs. background.

    Runs a CLIP(ViT-B-16/openai) backbone and takes its layer-6 ``'attn'``
    features.  The activation at patch (6, 6) (roughly the image center) is
    used as the foreground reference and the activation at patch (0, 0)
    (top-left corner) as the background reference, each averaged over the
    whole batch.  Every patch is then scored by dot product against the two
    L2-normalized references.

    Args:
        images: float tensor of shape (B, 3, H, W); resized to 224x224
            internally and normalized with ImageNet statistics here, so the
            input is expected to be un-normalized RGB in [0, 1] —
            NOTE(review): confirm against callers.

    Returns:
        (heatmap_fg, heatmap_bg): per-patch similarity tensors on CPU.
        Assuming the 'attn' feature is laid out (B, grid_h, grid_w, C) —
        which the (6, 6)/(0, 0) indexing implies — each heatmap has shape
        (B, grid_h, grid_w, 1).
    """
    images = F.interpolate(images, (224, 224), mode="bilinear")

    # ViT-B/16 at 224x224 yields a 14x14 patch grid; resample the positional
    # embeddings to match.
    model = load_model("CLIP(ViT-B-16/openai)")
    from ncut_pytorch.backbone import resample_position_embeddings
    pos_embed = model.model.visual.positional_embedding
    pos_embed = resample_position_embeddings(pos_embed, 14, 14)
    model.model.visual.positional_embedding = torch.nn.Parameter(pos_embed)

    batch_size = 4
    chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    # ImageNet normalization constants, broadcastable over (B, 3, H, W).
    means = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

    # Pass 1: gather the reference activation of every image, then average
    # once over the full batch.  (Averaging per-chunk means — as a previous
    # revision did — weights a smaller final chunk the same as a full chunk
    # when B % batch_size != 0, biasing the references.)
    fg_acts, bg_acts = [], []
    for chunk_idx in chunk_idxs:
        with torch.no_grad():
            input_images = images[chunk_idx].to(device)
            input_images = (input_images - means) / stds
            output = model(input_images)['attn'][6]
            fg_acts.append(output[:, 6, 6])   # center patch, one row per image
            bg_acts.append(output[:, 0, 0])   # corner patch, one row per image
    fg_act = torch.cat(fg_acts, dim=0).mean(0)
    bg_act = torch.cat(bg_acts, dim=0).mean(0)
    fg_act = F.normalize(fg_act, dim=-1)
    bg_act = F.normalize(bg_act, dim=-1)

    # Pass 2: score every patch against the two references, chunk by chunk,
    # moving results to CPU to keep GPU memory bounded.
    fg_act, bg_act = fg_act.to(device), bg_act.to(device)
    chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
    heatmap_fgs, heatmap_bgs = [], []
    for chunk_idx in chunk_idxs:
        with torch.no_grad():
            input_images = images[chunk_idx].to(device)
            input_images = (input_images - means) / stds
            output = model(input_images)['attn'][6]
            output = F.normalize(output, dim=-1)
            heatmap_fgs.append((output @ fg_act[:, None]).cpu())
            heatmap_bgs.append((output @ bg_act[:, None]).cpu())
    heatmap_fg = torch.cat(heatmap_fgs, dim=0)
    heatmap_bg = torch.cat(heatmap_bgs, dim=0)
    return heatmap_fg, heatmap_bg
381
+
382
+
383
+ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False, clusters=50, eig_idx=None, title='cluster'):
384
  progress = gr.Progress()
385
  progress(progess_start, desc="Finding Clusters by FPS")
386
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
390
 
391
  # gr.Info("Finding Clusters by FPS, no magnitude filtering")
392
  top_p_idx = torch.arange(eigvecs.shape[0])
393
+ if eig_idx is not None:
394
+ top_p_idx = eig_idx
395
  # gr.Info("Finding Clusters by FPS, with magnitude filtering")
396
  # p = 0.8
397
  # top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
398
 
399
+
400
  ret_magnitude = magnitude.reshape(-1, h, w)
401
 
402
 
 
413
  right = F.normalize(right, dim=-1)
414
  heatmap = left @ right.T
415
  heatmap = F.normalize(heatmap, dim=-1)
416
+ num_samples = clusters + 20
417
  if num_samples > fps_idx.shape[0]:
418
  num_samples = fps_idx.shape[0]
419
  r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
 
473
 
474
  fig_images = []
475
  i_cluster = 0
476
+ num_plots = clusters // 5
477
  plot_step_float = (1.0 - progess_start) / num_plots
478
  for i_fig in range(num_plots):
479
+ progress(progess_start + i_fig * plot_step_float, desc=f"Plotting {title}")
480
  if not advanced:
481
  fig, axs = plt.subplots(3, 5, figsize=(15, 9))
482
  if advanced:
 
496
  _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
497
  axs[i, j].imshow(_heatmap)
498
  if i == 0:
499
+ axs[i, j].set_title(f"{title} {i_cluster+1}", fontsize=24)
500
  i_cluster += 1
501
  plt.tight_layout(h_pad=0.5, w_pad=0.3)
502
 
 
515
 
516
  return fig_images, ret_magnitude
517
 
518
+ def make_cluster_plot_advanced(eigvecs, images, h=64, w=64):
519
+ heatmap_fg, heatmap_bg = segment_fg_bg(images.clone())
520
+ heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
521
+ heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
522
+ heatmap_fg = F.interpolate(heatmap_fg, (h, w), mode="bilinear")
523
+ heatmap_bg = F.interpolate(heatmap_bg, (h, w), mode="bilinear")
524
+ heatmap_fg = heatmap_fg.flatten()
525
+ heatmap_bg = heatmap_bg.flatten()
526
+
527
+ fg_minus_bg = heatmap_fg - heatmap_bg
528
+ fg_mask = fg_minus_bg > fg_minus_bg.quantile(0.8)
529
+ bg_mask = fg_minus_bg < fg_minus_bg.quantile(0.2)
530
+
531
+ # fg_mask = heatmap_fg > heatmap_fg.quantile(0.8)
532
+ # bg_mask = heatmap_bg > heatmap_bg.quantile(0.8)
533
+ other_mask = ~(fg_mask | bg_mask)
534
+
535
+ fg_idx = torch.arange(heatmap_fg.shape[0])[fg_mask]
536
+ bg_idx = torch.arange(heatmap_bg.shape[0])[bg_mask]
537
+ other_idx = torch.arange(heatmap_fg.shape[0])[other_mask]
538
+
539
+ fg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=fg_idx, title="fg")
540
+ bg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=bg_idx, title="bg")
541
+ other_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=other_idx, title="other")
542
+
543
+ cluster_images = fg_images + bg_images + other_images
544
+
545
+ magitude = torch.norm(eigvecs, dim=-1)
546
+ magitude = magitude.reshape(-1, h, w)
547
+
548
+ # magitude = fg_minus_bg.reshape(-1, h, w) #TODO
549
+
550
+ return cluster_images, magitude
551
 
552
  def ncut_run(
553
  model,
 
709
  if torch.cuda.is_available():
710
  images = images.cuda()
711
  _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
712
+ cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
713
  logging_str += f"Recursion #{i+1} plot time: {time.time() - start:.2f}s\n"
714
 
715
  norm_images = []
 
824
  images = images.cuda()
825
  _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
826
  advanced = kwargs.get("advanced", False)
827
+ if advanced:
828
+ cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
829
+ else:
830
+ cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=False)
831
  logging_str += f"plot time: {time.time() - start:.2f}s\n"
832
 
833
  norm_images = None
 
847
  logging_str += "Eigenvector Magnitude\n"
848
  logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
849
  gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
850
+
851
  return to_pil_images(rgb), cluster_images, norm_images, logging_str
852
 
853
 
854
 
855
  def _ncut_run(*args, **kwargs):
856
  n_ret = kwargs.pop("n_ret", 1)
857
+ # try:
858
+ # if torch.cuda.is_available():
859
+ # torch.cuda.empty_cache()
860
 
861
+ # ret = ncut_run(*args, **kwargs)
862
 
863
+ # if torch.cuda.is_available():
864
+ # torch.cuda.empty_cache()
865
 
866
+ # ret = list(ret)[:n_ret] + [ret[-1]]
867
+ # return ret
868
+ # except Exception as e:
869
+ # gr.Error(str(e))
870
+ # if torch.cuda.is_available():
871
+ # torch.cuda.empty_cache()
872
+ # return *(None for _ in range(n_ret)), "Error: " + str(e)
873
+
874
+ ret = ncut_run(*args, **kwargs)
875
+ ret = list(ret)[:n_ret] + [ret[-1]]
876
+ return ret
877
 
878
  if USE_HUGGINGFACE_ZEROGPU:
879
  @spaces.GPU(duration=30)
 
1297
  images += [Image.open(new_image) for new_image in new_images]
1298
  if isinstance(new_images, str):
1299
  images.append(Image.open(new_images))
1300
+ gr.Info(f"Total images: {len(images)}")
1301
  return images
1302
  upload_button.upload(convert_to_pil_and_append, inputs=[input_gallery, upload_button], outputs=[input_gallery])
1303
 
 
1513
  if existing_images is None:
1514
  existing_images = []
1515
  existing_images += new_images
1516
+ gr.Info(f"Total images: {len(existing_images)}")
1517
  return existing_images
1518
 
1519
  load_images_button.click(load_and_append,
 
1528
 
1529
 
1530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1531
  # def random_rotate_rgb_gallery(images):
1532
  # if images is None or len(images) == 0:
1533
  # gr.Warning("No images selected.")
 
1922
  l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
1923
  add_output_images_buttons(l1_gallery)
1924
  l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1925
+ l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
1926
  with gr.Column(scale=5, min_width=200):
1927
  gr.Markdown('### Output (Recursion #2)')
1928
  l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
1929
  add_output_images_buttons(l2_gallery)
1930
  l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1931
+ l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
1932
  with gr.Column(scale=5, min_width=200):
1933
  gr.Markdown('### Output (Recursion #3)')
1934
  l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
1935
  add_output_images_buttons(l3_gallery)
1936
  l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1937
+ l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
1938
 
1939
  with gr.Row():
1940
  with gr.Column(scale=5, min_width=200):
 
2305
  submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
2306
  add_output_images_buttons(output_gallery)
2307
  norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2308
+ cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
2309
  [
2310
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
2311
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,