Spaces:
Running
on
Zero
Running
on
Zero
update plotting fg
Browse files
app.py
CHANGED
@@ -312,7 +312,8 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
|
|
312 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
313 |
return blended.astype(np.uint8)
|
314 |
|
315 |
-
|
|
|
316 |
def segment_fg_bg(images):
|
317 |
|
318 |
images = F.interpolate(images, (224, 224), mode="bilinear")
|
@@ -459,7 +460,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
459 |
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
460 |
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu()
|
461 |
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
462 |
-
top10_image_idx[idx.item()] = mask_sort_idx[:
|
463 |
# do the sorting
|
464 |
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
465 |
fps_idx = fps_idx[_sort_idx]
|
@@ -486,7 +487,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
486 |
if not advanced:
|
487 |
fig, axs = plt.subplots(3, 5, figsize=(15, 9))
|
488 |
if advanced:
|
489 |
-
fig, axs = plt.subplots(
|
490 |
for ax in axs.flatten():
|
491 |
ax.axis("off")
|
492 |
for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
|
@@ -521,7 +522,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
521 |
|
522 |
return fig_images, ret_magnitude
|
523 |
|
524 |
-
def make_cluster_plot_advanced(eigvecs, images, h=64, w=64):
|
525 |
heatmap_fg, heatmap_bg = segment_fg_bg(images.clone())
|
526 |
heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
|
527 |
heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
|
@@ -542,9 +543,15 @@ def make_cluster_plot_advanced(eigvecs, images, h=64, w=64):
|
|
542 |
bg_idx = torch.arange(heatmap_bg.shape[0])[bg_mask]
|
543 |
other_idx = torch.arange(heatmap_fg.shape[0])[other_mask]
|
544 |
|
545 |
-
fg_images
|
546 |
-
|
547 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
548 |
|
549 |
cluster_images = fg_images + bg_images + other_images
|
550 |
|
@@ -833,7 +840,7 @@ def ncut_run(
|
|
833 |
if advanced:
|
834 |
cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
|
835 |
else:
|
836 |
-
cluster_images, eig_magnitude =
|
837 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
838 |
|
839 |
norm_images = None
|
@@ -1859,7 +1866,7 @@ with demo:
|
|
1859 |
|
1860 |
with gr.Column(scale=5, min_width=200):
|
1861 |
output_gallery = make_output_images_section()
|
1862 |
-
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=
|
1863 |
[
|
1864 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
1865 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
|
|
312 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
313 |
return blended.astype(np.uint8)
|
314 |
|
315 |
+
# preload the model
|
316 |
+
load_model("CLIP(ViT-B-16/openai)")
|
317 |
def segment_fg_bg(images):
|
318 |
|
319 |
images = F.interpolate(images, (224, 224), mode="bilinear")
|
|
|
460 |
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
461 |
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu()
|
462 |
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
463 |
+
top10_image_idx[idx.item()] = mask_sort_idx[:5]
|
464 |
# do the sorting
|
465 |
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
466 |
fps_idx = fps_idx[_sort_idx]
|
|
|
487 |
if not advanced:
|
488 |
fig, axs = plt.subplots(3, 5, figsize=(15, 9))
|
489 |
if advanced:
|
490 |
+
fig, axs = plt.subplots(5, 5, figsize=(15, 15))
|
491 |
for ax in axs.flatten():
|
492 |
ax.axis("off")
|
493 |
for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
|
|
|
522 |
|
523 |
return fig_images, ret_magnitude
|
524 |
|
525 |
+
def make_cluster_plot_advanced(eigvecs, images, h=64, w=64, num_fg=100, num_bg=10, num_other=0, small_title=True):
|
526 |
heatmap_fg, heatmap_bg = segment_fg_bg(images.clone())
|
527 |
heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
|
528 |
heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
|
|
|
543 |
bg_idx = torch.arange(heatmap_bg.shape[0])[bg_mask]
|
544 |
other_idx = torch.arange(heatmap_fg.shape[0])[other_mask]
|
545 |
|
546 |
+
fg_images = []
|
547 |
+
if num_fg > 0:
|
548 |
+
fg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_fg, eig_idx=fg_idx, title="fg" if small_title else "cluster")
|
549 |
+
bg_images = []
|
550 |
+
if num_bg > 0:
|
551 |
+
bg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_bg, eig_idx=bg_idx, title="bg" if small_title else "cluster")
|
552 |
+
other_images = []
|
553 |
+
if num_other > 0:
|
554 |
+
other_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_other, eig_idx=other_idx, title="other" if small_title else "cluster")
|
555 |
|
556 |
cluster_images = fg_images + bg_images + other_images
|
557 |
|
|
|
840 |
if advanced:
|
841 |
cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
|
842 |
else:
|
843 |
+
cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w, num_fg=20, num_bg=0, num_other=0, small_title=False)
|
844 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
845 |
|
846 |
norm_images = None
|
|
|
1866 |
|
1867 |
with gr.Column(scale=5, min_width=200):
|
1868 |
output_gallery = make_output_images_section()
|
1869 |
+
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=True, interactive=True)
|
1870 |
[
|
1871 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
1872 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|