Spaces:
Running
on
Zero
Running
on
Zero
update cluster fg bg
Browse files
app.py
CHANGED
@@ -308,7 +308,79 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
|
|
308 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
309 |
return blended.astype(np.uint8)
|
310 |
|
311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
progress = gr.Progress()
|
313 |
progress(progess_start, desc="Finding Clusters by FPS")
|
314 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
@@ -318,10 +390,13 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
318 |
|
319 |
# gr.Info("Finding Clusters by FPS, no magnitude filtering")
|
320 |
top_p_idx = torch.arange(eigvecs.shape[0])
|
|
|
|
|
321 |
# gr.Info("Finding Clusters by FPS, with magnitude filtering")
|
322 |
# p = 0.8
|
323 |
# top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
|
324 |
|
|
|
325 |
ret_magnitude = magnitude.reshape(-1, h, w)
|
326 |
|
327 |
|
@@ -338,7 +413,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
338 |
right = F.normalize(right, dim=-1)
|
339 |
heatmap = left @ right.T
|
340 |
heatmap = F.normalize(heatmap, dim=-1)
|
341 |
-
num_samples =
|
342 |
if num_samples > fps_idx.shape[0]:
|
343 |
num_samples = fps_idx.shape[0]
|
344 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
@@ -398,10 +473,10 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
398 |
|
399 |
fig_images = []
|
400 |
i_cluster = 0
|
401 |
-
num_plots =
|
402 |
plot_step_float = (1.0 - progess_start) / num_plots
|
403 |
for i_fig in range(num_plots):
|
404 |
-
progress(progess_start + i_fig * plot_step_float, desc="Plotting
|
405 |
if not advanced:
|
406 |
fig, axs = plt.subplots(3, 5, figsize=(15, 9))
|
407 |
if advanced:
|
@@ -421,7 +496,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
421 |
_heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
|
422 |
axs[i, j].imshow(_heatmap)
|
423 |
if i == 0:
|
424 |
-
axs[i, j].set_title(f"
|
425 |
i_cluster += 1
|
426 |
plt.tight_layout(h_pad=0.5, w_pad=0.3)
|
427 |
|
@@ -440,6 +515,39 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
440 |
|
441 |
return fig_images, ret_magnitude
|
442 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
443 |
|
444 |
def ncut_run(
|
445 |
model,
|
@@ -601,7 +709,7 @@ def ncut_run(
|
|
601 |
if torch.cuda.is_available():
|
602 |
images = images.cuda()
|
603 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
604 |
-
cluster_images, eig_magnitude =
|
605 |
logging_str += f"Recursion #{i+1} plot time: {time.time() - start:.2f}s\n"
|
606 |
|
607 |
norm_images = []
|
@@ -716,7 +824,10 @@ def ncut_run(
|
|
716 |
images = images.cuda()
|
717 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
718 |
advanced = kwargs.get("advanced", False)
|
719 |
-
|
|
|
|
|
|
|
720 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
721 |
|
722 |
norm_images = None
|
@@ -736,33 +847,33 @@ def ncut_run(
|
|
736 |
logging_str += "Eigenvector Magnitude\n"
|
737 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
738 |
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
|
739 |
-
|
740 |
return to_pil_images(rgb), cluster_images, norm_images, logging_str
|
741 |
|
742 |
|
743 |
|
744 |
def _ncut_run(*args, **kwargs):
|
745 |
n_ret = kwargs.pop("n_ret", 1)
|
746 |
-
try:
|
747 |
-
|
748 |
-
|
749 |
|
750 |
-
|
751 |
|
752 |
-
|
753 |
-
|
754 |
|
755 |
-
|
756 |
-
|
757 |
-
except Exception as e:
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
|
767 |
if USE_HUGGINGFACE_ZEROGPU:
|
768 |
@spaces.GPU(duration=30)
|
@@ -1186,7 +1297,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
|
|
1186 |
images += [Image.open(new_image) for new_image in new_images]
|
1187 |
if isinstance(new_images, str):
|
1188 |
images.append(Image.open(new_images))
|
1189 |
-
|
1190 |
return images
|
1191 |
upload_button.upload(convert_to_pil_and_append, inputs=[input_gallery, upload_button], outputs=[input_gallery])
|
1192 |
|
@@ -1402,6 +1513,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
|
|
1402 |
if existing_images is None:
|
1403 |
existing_images = []
|
1404 |
existing_images += new_images
|
|
|
1405 |
return existing_images
|
1406 |
|
1407 |
load_images_button.click(load_and_append,
|
@@ -1416,165 +1528,6 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
|
|
1416 |
|
1417 |
|
1418 |
|
1419 |
-
# def make_input_images_section(rows=1, cols=3, height="auto"):
|
1420 |
-
# gr.Markdown('### Input Images')
|
1421 |
-
# input_gallery = gr.Gallery(value=None, label="Select images", show_label=True, elem_id="images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False)
|
1422 |
-
# submit_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary')
|
1423 |
-
# clear_images_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
|
1424 |
-
# return input_gallery, submit_button, clear_images_button
|
1425 |
-
|
1426 |
-
|
1427 |
-
# def make_dataset_images_section(advanced=False, is_random=False):
|
1428 |
-
|
1429 |
-
# gr.Markdown('### Load Datasets')
|
1430 |
-
# load_images_button = gr.Button("🔴 Load Images", elem_id="load-images-button", variant='primary')
|
1431 |
-
# advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio", show_label=True)
|
1432 |
-
# with gr.Column() as basic_block:
|
1433 |
-
# example_gallery = gr.Gallery(value=example_items, label="Example Images", show_label=True, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
|
1434 |
-
# with gr.Column() as advanced_block:
|
1435 |
-
# # dataset_names = DATASET_NAMES
|
1436 |
-
# # dataset_classes = DATASET_CLASSES
|
1437 |
-
# dataset_categories = list(DATASETS.keys())
|
1438 |
-
# defualt_cat = dataset_categories[0]
|
1439 |
-
# def get_choices(cat):
|
1440 |
-
# return [tup[0] for tup in DATASETS[cat]]
|
1441 |
-
# defualt_choices = get_choices(defualt_cat)
|
1442 |
-
# with gr.Row():
|
1443 |
-
# dataset_radio = gr.Radio(dataset_categories, label="Dataset Category", value=defualt_cat, elem_id="dataset-radio", show_label=True, min_width=600)
|
1444 |
-
# # dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset", min_width=300)
|
1445 |
-
# dataset_dropdown = gr.Dropdown(defualt_choices, label="Dataset name", value=defualt_choices[0], elem_id="dataset", min_width=400)
|
1446 |
-
# dataset_radio.change(fn=lambda x: gr.update(choices=get_choices(x), value=get_choices(x)[0]), inputs=dataset_radio, outputs=dataset_dropdown)
|
1447 |
-
# # num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
|
1448 |
-
# num_images_slider = gr.Slider(1, 1000, step=1, label="Number of images", value=10, elem_id="num_images", min_width=200)
|
1449 |
-
# if not is_random:
|
1450 |
-
# filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
|
1451 |
-
# filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
|
1452 |
-
# # is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
|
1453 |
-
# # random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
|
1454 |
-
# is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
|
1455 |
-
# random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=True)
|
1456 |
-
# if is_random:
|
1457 |
-
# filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
|
1458 |
-
# filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
|
1459 |
-
# is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
|
1460 |
-
# random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=42, elem_id="random_seed", visible=True)
|
1461 |
-
|
1462 |
-
|
1463 |
-
# if advanced:
|
1464 |
-
# advanced_block.visible = True
|
1465 |
-
# basic_block.visible = False
|
1466 |
-
# else:
|
1467 |
-
# advanced_block.visible = False
|
1468 |
-
# basic_block.visible = True
|
1469 |
-
|
1470 |
-
# # change visibility
|
1471 |
-
# advanced_radio.change(fn=lambda x: gr.update(visible=x=="Advanced"), inputs=advanced_radio, outputs=[advanced_block])
|
1472 |
-
# advanced_radio.change(fn=lambda x: gr.update(visible=x=="Basic"), inputs=advanced_radio, outputs=[basic_block])
|
1473 |
-
|
1474 |
-
# def find_num_classes(dataset_name):
|
1475 |
-
# num_classes = None
|
1476 |
-
# for cat, datasets in DATASETS.items():
|
1477 |
-
# datasets = [tup[0] for tup in datasets]
|
1478 |
-
# if dataset_name in datasets:
|
1479 |
-
# num_classes = DATASETS[cat][datasets.index(dataset_name)][1]
|
1480 |
-
# break
|
1481 |
-
# return num_classes
|
1482 |
-
|
1483 |
-
# def change_filter_options(dataset_name):
|
1484 |
-
# num_classes = find_num_classes(dataset_name)
|
1485 |
-
# if num_classes is None:
|
1486 |
-
# return (gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox", visible=False),
|
1487 |
-
# gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info="e.g. `0,1,2`. This dataset has no class label", visible=False))
|
1488 |
-
# return (gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox", visible=True),
|
1489 |
-
# gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=True))
|
1490 |
-
# dataset_dropdown.change(fn=change_filter_options, inputs=dataset_dropdown, outputs=[filter_by_class_checkbox, filter_by_class_text])
|
1491 |
-
|
1492 |
-
# def change_filter_by_class(is_filter, dataset_name):
|
1493 |
-
# num_classes = find_num_classes(dataset_name)
|
1494 |
-
# return gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=is_filter)
|
1495 |
-
# filter_by_class_checkbox.change(fn=change_filter_by_class, inputs=[filter_by_class_checkbox, dataset_dropdown], outputs=filter_by_class_text)
|
1496 |
-
|
1497 |
-
# def change_random_seed(is_random):
|
1498 |
-
# return gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=is_random)
|
1499 |
-
# is_random_checkbox.change(fn=change_random_seed, inputs=is_random_checkbox, outputs=random_seed_slider)
|
1500 |
-
|
1501 |
-
|
1502 |
-
# def load_dataset_images(is_advanced, dataset_name, num_images=10,
|
1503 |
-
# is_filter=True, filter_by_class_text="0,1,2",
|
1504 |
-
# is_random=False, seed=1):
|
1505 |
-
# progress = gr.Progress()
|
1506 |
-
# progress(0, desc="Loading Images")
|
1507 |
-
# if is_advanced == "Basic":
|
1508 |
-
# gr.Info("Loaded images from Ego-Exo4D")
|
1509 |
-
# return default_images
|
1510 |
-
# try:
|
1511 |
-
# progress(0.5, desc="Downloading Dataset")
|
1512 |
-
# dataset = load_dataset(dataset_name, trust_remote_code=True)
|
1513 |
-
# key = list(dataset.keys())[0]
|
1514 |
-
# dataset = dataset[key]
|
1515 |
-
# except Exception as e:
|
1516 |
-
# gr.Error(f"Error loading dataset {dataset_name}: {e}")
|
1517 |
-
# return None
|
1518 |
-
# if num_images > len(dataset):
|
1519 |
-
# num_images = len(dataset)
|
1520 |
-
|
1521 |
-
# if is_filter:
|
1522 |
-
# progress(0.8, desc="Filtering Images")
|
1523 |
-
# classes = [int(i) for i in filter_by_class_text.split(",")]
|
1524 |
-
# labels = np.array(dataset['label'])
|
1525 |
-
# unique_labels = np.unique(labels)
|
1526 |
-
# valid_classes = [i for i in classes if i in unique_labels]
|
1527 |
-
# invalid_classes = [i for i in classes if i not in unique_labels]
|
1528 |
-
# if len(invalid_classes) > 0:
|
1529 |
-
# gr.Warning(f"Classes {invalid_classes} not found in the dataset.")
|
1530 |
-
# if len(valid_classes) == 0:
|
1531 |
-
# gr.Error(f"Classes {classes} not found in the dataset.")
|
1532 |
-
# return None
|
1533 |
-
# # shuffle each class
|
1534 |
-
# chunk_size = num_images // len(valid_classes)
|
1535 |
-
# image_idx = []
|
1536 |
-
# for i in valid_classes:
|
1537 |
-
# idx = np.where(labels == i)[0]
|
1538 |
-
# if is_random:
|
1539 |
-
# idx = np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
|
1540 |
-
# else:
|
1541 |
-
# idx = idx[:chunk_size]
|
1542 |
-
# image_idx.extend(idx.tolist())
|
1543 |
-
# if not is_filter:
|
1544 |
-
# if is_random:
|
1545 |
-
# image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
|
1546 |
-
# else:
|
1547 |
-
# image_idx = list(range(num_images))
|
1548 |
-
# key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
|
1549 |
-
# images = [dataset[i][key] for i in image_idx]
|
1550 |
-
# gr.Info(f"Loaded {len(images)} images from {dataset_name}")
|
1551 |
-
# del dataset
|
1552 |
-
|
1553 |
-
# if dataset_name in CENTER_CROP_DATASETS:
|
1554 |
-
# def center_crop_image(img):
|
1555 |
-
# # image: PIL image
|
1556 |
-
# w, h = img.size
|
1557 |
-
# min_hw = min(h, w)
|
1558 |
-
# # center crop
|
1559 |
-
# left = (w - min_hw) // 2
|
1560 |
-
# top = (h - min_hw) // 2
|
1561 |
-
# right = left + min_hw
|
1562 |
-
# bottom = top + min_hw
|
1563 |
-
# img = img.crop((left, top, right, bottom))
|
1564 |
-
# return img
|
1565 |
-
# images = [center_crop_image(image) for image in images]
|
1566 |
-
|
1567 |
-
# return images
|
1568 |
-
|
1569 |
-
# load_images_button.click(load_dataset_images,
|
1570 |
-
# inputs=[advanced_radio, dataset_dropdown, num_images_slider,
|
1571 |
-
# filter_by_class_checkbox, filter_by_class_text,
|
1572 |
-
# is_random_checkbox, random_seed_slider],
|
1573 |
-
# outputs=[input_gallery])
|
1574 |
-
|
1575 |
-
# return dataset_dropdown, num_images_slider, random_seed_slider, load_images_button
|
1576 |
-
|
1577 |
-
|
1578 |
# def random_rotate_rgb_gallery(images):
|
1579 |
# if images is None or len(images) == 0:
|
1580 |
# gr.Warning("No images selected.")
|
@@ -1969,19 +1922,19 @@ with demo:
|
|
1969 |
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
1970 |
add_output_images_buttons(l1_gallery)
|
1971 |
l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1972 |
-
l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=
|
1973 |
with gr.Column(scale=5, min_width=200):
|
1974 |
gr.Markdown('### Output (Recursion #2)')
|
1975 |
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
1976 |
add_output_images_buttons(l2_gallery)
|
1977 |
l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1978 |
-
l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=
|
1979 |
with gr.Column(scale=5, min_width=200):
|
1980 |
gr.Markdown('### Output (Recursion #3)')
|
1981 |
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
1982 |
add_output_images_buttons(l3_gallery)
|
1983 |
l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1984 |
-
l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=
|
1985 |
|
1986 |
with gr.Row():
|
1987 |
with gr.Column(scale=5, min_width=200):
|
@@ -2352,7 +2305,7 @@ with demo:
|
|
2352 |
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
2353 |
add_output_images_buttons(output_gallery)
|
2354 |
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
2355 |
-
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height=
|
2356 |
[
|
2357 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
2358 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
|
|
308 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
309 |
return blended.astype(np.uint8)
|
310 |
|
311 |
+
|
312 |
+
def segment_fg_bg(images):
    """Score every image patch against a foreground and a background reference.

    The images are resized to 224x224 and passed, in batches of 4, through the
    ``CLIP(ViT-B-16/openai)`` model returned by ``load_model``. Two reference
    vectors are built from the layer-6 ``'attn'`` features: the feature at
    grid position (6, 6) averaged over all images ("foreground") and the
    feature at grid position (0, 0) averaged over all images ("background").
    A second pass L2-normalizes every patch feature and takes its dot product
    with each (L2-normalized) reference vector.

    Args:
        images: batch of images as a tensor indexable along dim 0 and
            accepted by ``F.interpolate``.
            NOTE(review): assumed to be (B, 3, H, W) in the value range the
            mean/std normalization below expects — confirm against caller.

    Returns:
        Tuple ``(heatmap_fg, heatmap_bg)`` of CPU tensors, concatenated over
        the batch dimension.
        NOTE(review): presumably shaped (B, 14, 14, 1), since the caller
        rearranges them as 'b h w c' — confirm.
    """
    from ncut_pytorch.backbone import resample_position_embeddings

    images = F.interpolate(images, (224, 224), mode="bilinear")

    model = load_model("CLIP(ViT-B-16/openai)")
    # Resample the positional embedding to a 14x14 grid to match 224x224 input.
    # NOTE(review): this mutates the object returned by load_model; if
    # load_model shares/caches model instances the change persists — confirm.
    pos_embed = model.model.visual.positional_embedding
    pos_embed = resample_position_embeddings(pos_embed, 14, 14)
    model.model.visual.positional_embedding = torch.nn.Parameter(pos_embed)

    batch_size = 4
    num_images = images.shape[0]
    chunk_idxs = torch.split(torch.arange(num_images), batch_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    # NOTE(review): these are the ImageNet statistics; OpenAI CLIP checkpoints
    # are usually paired with their own mean/std — confirm this is intended.
    means = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

    def _layer_features(chunk_idx):
        """Normalize one chunk of images and return its layer-6 'attn' output."""
        input_images = images[chunk_idx].to(device)
        input_images = (input_images - means) / stds
        return model(input_images)['attn'][6]

    # Pass 1: build the reference vectors. Sum over every image and divide by
    # the total count, so that a short final chunk is not over-weighted (a
    # mean of per-chunk means would give its images a larger weight).
    fg_sum, bg_sum = 0, 0
    with torch.no_grad():
        for chunk_idx in chunk_idxs:
            output = _layer_features(chunk_idx)
            fg_sum = fg_sum + output[:, 6, 6].sum(0)
            bg_sum = bg_sum + output[:, 0, 0].sum(0)
    fg_act = F.normalize(fg_sum / num_images, dim=-1)
    bg_act = F.normalize(bg_sum / num_images, dim=-1)
    # No-op when the model output already lives on `device`.
    fg_act, bg_act = fg_act.to(device), bg_act.to(device)

    # Pass 2: similarity of every normalized patch feature to each reference.
    heatmap_fgs, heatmap_bgs = [], []
    with torch.no_grad():
        for chunk_idx in chunk_idxs:
            output = F.normalize(_layer_features(chunk_idx), dim=-1)
            heatmap_fg = output @ fg_act[:, None]
            heatmap_bg = output @ bg_act[:, None]
            # Move results off the accelerator so memory does not grow with
            # the number of images.
            heatmap_fgs.append(heatmap_fg.cpu())
            heatmap_bgs.append(heatmap_bg.cpu())
    heatmap_fg = torch.cat(heatmap_fgs, dim=0)
    heatmap_bg = torch.cat(heatmap_bgs, dim=0)
    return heatmap_fg, heatmap_bg
|
381 |
+
|
382 |
+
|
383 |
+
def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False, clusters=50, eig_idx=None, title='cluster'):
|
384 |
progress = gr.Progress()
|
385 |
progress(progess_start, desc="Finding Clusters by FPS")
|
386 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
390 |
|
391 |
# gr.Info("Finding Clusters by FPS, no magnitude filtering")
|
392 |
top_p_idx = torch.arange(eigvecs.shape[0])
|
393 |
+
if eig_idx is not None:
|
394 |
+
top_p_idx = eig_idx
|
395 |
# gr.Info("Finding Clusters by FPS, with magnitude filtering")
|
396 |
# p = 0.8
|
397 |
# top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
|
398 |
|
399 |
+
|
400 |
ret_magnitude = magnitude.reshape(-1, h, w)
|
401 |
|
402 |
|
|
|
413 |
right = F.normalize(right, dim=-1)
|
414 |
heatmap = left @ right.T
|
415 |
heatmap = F.normalize(heatmap, dim=-1)
|
416 |
+
num_samples = clusters + 20
|
417 |
if num_samples > fps_idx.shape[0]:
|
418 |
num_samples = fps_idx.shape[0]
|
419 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
|
|
473 |
|
474 |
fig_images = []
|
475 |
i_cluster = 0
|
476 |
+
num_plots = clusters // 5
|
477 |
plot_step_float = (1.0 - progess_start) / num_plots
|
478 |
for i_fig in range(num_plots):
|
479 |
+
progress(progess_start + i_fig * plot_step_float, desc=f"Plotting {title}")
|
480 |
if not advanced:
|
481 |
fig, axs = plt.subplots(3, 5, figsize=(15, 9))
|
482 |
if advanced:
|
|
|
496 |
_heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
|
497 |
axs[i, j].imshow(_heatmap)
|
498 |
if i == 0:
|
499 |
+
axs[i, j].set_title(f"{title} {i_cluster+1}", fontsize=24)
|
500 |
i_cluster += 1
|
501 |
plt.tight_layout(h_pad=0.5, w_pad=0.3)
|
502 |
|
|
|
515 |
|
516 |
return fig_images, ret_magnitude
|
517 |
|
518 |
+
def make_cluster_plot_advanced(eigvecs, images, h=64, w=64):
    """Build cluster plots separately for foreground, background and the rest.

    Foreground/background score maps from ``segment_fg_bg`` are resized to
    ``(h, w)`` and flattened. Positions whose (fg - bg) score is above the
    0.8 quantile form the "fg" group, those below the 0.2 quantile form the
    "bg" group, and everything else forms the "other" group.
    ``make_cluster_plot`` is then called once per group with ``clusters=100``
    and the group's flat indices as ``eig_idx``.

    Args:
        eigvecs: eigenvector tensor; its norm along the last dim is reshaped
            to ``(-1, h, w)`` for the returned magnitude.
        images: image batch; a clone is handed to ``segment_fg_bg`` and the
            original is forwarded to ``make_cluster_plot``.
        h: height of the grid the score maps are resized to.
        w: width of the grid the score maps are resized to.

    Returns:
        Tuple ``(cluster_images, magnitude)``: the fg, bg and other plots
        concatenated in that order, and the eigenvector-norm map.
    """
    fg_map, bg_map = segment_fg_bg(images.clone())

    def _flat_at_resolution(score_map):
        # channels-last -> channels-first, resize to (h, w), then flatten.
        score_map = rearrange(score_map, 'b h w c -> b c h w')
        return F.interpolate(score_map, (h, w), mode="bilinear").flatten()

    flat_fg = _flat_at_resolution(fg_map)
    flat_bg = _flat_at_resolution(bg_map)

    # Split positions by how much more "foreground" than "background" they are.
    contrast = flat_fg - flat_bg
    fg_mask = contrast > contrast.quantile(0.8)
    bg_mask = contrast < contrast.quantile(0.2)
    other_mask = ~(fg_mask | bg_mask)

    groups = (
        ("fg", torch.arange(flat_fg.shape[0])[fg_mask]),
        ("bg", torch.arange(flat_bg.shape[0])[bg_mask]),
        ("other", torch.arange(flat_fg.shape[0])[other_mask]),
    )

    cluster_images = []
    for label, idx in groups:
        figs, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True,
                                    clusters=100, eig_idx=idx, title=label)
        cluster_images.extend(figs)

    # TODO: consider returning the fg-minus-bg contrast map instead.
    magnitude = torch.norm(eigvecs, dim=-1).reshape(-1, h, w)

    return cluster_images, magnitude
|
551 |
|
552 |
def ncut_run(
|
553 |
model,
|
|
|
709 |
if torch.cuda.is_available():
|
710 |
images = images.cuda()
|
711 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
712 |
+
cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
|
713 |
logging_str += f"Recursion #{i+1} plot time: {time.time() - start:.2f}s\n"
|
714 |
|
715 |
norm_images = []
|
|
|
824 |
images = images.cuda()
|
825 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
826 |
advanced = kwargs.get("advanced", False)
|
827 |
+
if advanced:
|
828 |
+
cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
|
829 |
+
else:
|
830 |
+
cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=False)
|
831 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
832 |
|
833 |
norm_images = None
|
|
|
847 |
logging_str += "Eigenvector Magnitude\n"
|
848 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
849 |
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
|
850 |
+
|
851 |
return to_pil_images(rgb), cluster_images, norm_images, logging_str
|
852 |
|
853 |
|
854 |
|
855 |
def _ncut_run(*args, **kwargs):
    """Call ``ncut_run`` and trim its outputs to what the caller asked for.

    Args:
        *args: forwarded unchanged to ``ncut_run``.
        **kwargs: forwarded to ``ncut_run`` after ``n_ret`` is removed.
            ``n_ret`` (default 1) is the number of leading outputs to keep.

    Returns:
        A list holding the first ``n_ret`` outputs of ``ncut_run`` followed
        by its last output.

    NOTE: an earlier version wrapped this call in try/except and cleared the
    CUDA cache around it; that handling is disabled, so any exception raised
    by ``ncut_run`` propagates to the caller.
    """
    n_ret = kwargs.pop("n_ret", 1)
    raw = ncut_run(*args, **kwargs)
    kept = list(raw)[:n_ret]
    kept.append(raw[-1])
    return kept
|
877 |
|
878 |
if USE_HUGGINGFACE_ZEROGPU:
|
879 |
@spaces.GPU(duration=30)
|
|
|
1297 |
images += [Image.open(new_image) for new_image in new_images]
|
1298 |
if isinstance(new_images, str):
|
1299 |
images.append(Image.open(new_images))
|
1300 |
+
gr.Info(f"Total images: {len(images)}")
|
1301 |
return images
|
1302 |
upload_button.upload(convert_to_pil_and_append, inputs=[input_gallery, upload_button], outputs=[input_gallery])
|
1303 |
|
|
|
1513 |
if existing_images is None:
|
1514 |
existing_images = []
|
1515 |
existing_images += new_images
|
1516 |
+
gr.Info(f"Total images: {len(existing_images)}")
|
1517 |
return existing_images
|
1518 |
|
1519 |
load_images_button.click(load_and_append,
|
|
|
1528 |
|
1529 |
|
1530 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1531 |
# def random_rotate_rgb_gallery(images):
|
1532 |
# if images is None or len(images) == 0:
|
1533 |
# gr.Warning("No images selected.")
|
|
|
1922 |
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
1923 |
add_output_images_buttons(l1_gallery)
|
1924 |
l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1925 |
+
l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
|
1926 |
with gr.Column(scale=5, min_width=200):
|
1927 |
gr.Markdown('### Output (Recursion #2)')
|
1928 |
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
1929 |
add_output_images_buttons(l2_gallery)
|
1930 |
l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1931 |
+
l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
|
1932 |
with gr.Column(scale=5, min_width=200):
|
1933 |
gr.Markdown('### Output (Recursion #3)')
|
1934 |
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
1935 |
add_output_images_buttons(l3_gallery)
|
1936 |
l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1937 |
+
l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
|
1938 |
|
1939 |
with gr.Row():
|
1940 |
with gr.Column(scale=5, min_width=200):
|
|
|
2305 |
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
2306 |
add_output_images_buttons(output_gallery)
|
2307 |
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
2308 |
+
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
|
2309 |
[
|
2310 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
2311 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|