Spaces:
Running
on
Zero
Running
on
Zero
remove mask size filter
Browse files
app.py
CHANGED
@@ -287,7 +287,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
287 |
right = F.normalize(right, dim=-1)
|
288 |
heatmap = left @ right.T
|
289 |
heatmap = F.normalize(heatmap, dim=-1)
|
290 |
-
num_samples =
|
291 |
if num_samples > fps_idx.shape[0]:
|
292 |
num_samples = fps_idx.shape[0]
|
293 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
@@ -305,6 +305,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
305 |
fps_heatmaps = {}
|
306 |
sort_values = []
|
307 |
top3_image_idx = {}
|
|
|
308 |
for _, idx in enumerate(fps_idx):
|
309 |
heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
|
310 |
|
@@ -314,7 +315,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
314 |
# tensor = tensor[torch.randperm(tensor.shape[0])[:max_size]]
|
315 |
# return tensor.quantile(p)
|
316 |
# top_p = top_percentile(heatmap, p=0.5)
|
317 |
-
top_p = 0.
|
318 |
|
319 |
heatmap = heatmap.reshape(-1, h, w)
|
320 |
mask = (heatmap > top_p).float()
|
@@ -324,8 +325,9 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
324 |
mask = mask[mask_sort_idx[:3]]
|
325 |
sort_values.append(mask.mean().item())
|
326 |
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
327 |
-
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:
|
328 |
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
|
|
329 |
# do the sorting
|
330 |
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
331 |
fps_idx = fps_idx[_sort_idx]
|
@@ -342,13 +344,17 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
342 |
# shuffle the fps_idx
|
343 |
fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
|
344 |
|
|
|
345 |
fig_images = []
|
346 |
i_cluster = 0
|
347 |
num_plots = 10 if not advanced else 20
|
348 |
plot_step_float = (1.0 - progess_start) / num_plots
|
349 |
for i_fig in range(num_plots):
|
350 |
progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
|
351 |
-
|
|
|
|
|
|
|
352 |
for ax in axs.flatten():
|
353 |
ax.axis("off")
|
354 |
for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
|
@@ -358,7 +364,8 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
358 |
size = (images.shape[1], images.shape[2])
|
359 |
heatmap = apply_reds_colormap(heatmap, size)
|
360 |
# for i, image_idx in enumerate(sorted_image_idxs[:3]):
|
361 |
-
|
|
|
362 |
# _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
|
363 |
_heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
|
364 |
axs[i, j].imshow(_heatmap)
|
@@ -378,10 +385,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
378 |
|
379 |
fig_images.append(img)
|
380 |
plt.close()
|
381 |
-
|
382 |
-
# plt.imshow(img)
|
383 |
-
# plt.axis("off")
|
384 |
-
# plt.show()
|
385 |
return fig_images, ret_magnitude
|
386 |
|
387 |
|
@@ -647,26 +651,26 @@ def ncut_run(
|
|
647 |
|
648 |
def _ncut_run(*args, **kwargs):
|
649 |
n_ret = kwargs.pop("n_ret", 1)
|
650 |
-
try:
|
651 |
-
|
652 |
-
|
653 |
|
654 |
-
|
655 |
|
656 |
-
|
657 |
-
|
658 |
|
659 |
-
|
660 |
-
|
661 |
-
except Exception as e:
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
|
671 |
if USE_HUGGINGFACE_ZEROGPU:
|
672 |
@spaces.GPU(duration=30)
|
@@ -1415,7 +1419,7 @@ with demo:
|
|
1415 |
with gr.Column(scale=5, min_width=200):
|
1416 |
output_gallery = make_output_images_section()
|
1417 |
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1418 |
-
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[
|
1419 |
[
|
1420 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
1421 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
|
|
287 |
right = F.normalize(right, dim=-1)
|
288 |
heatmap = left @ right.T
|
289 |
heatmap = F.normalize(heatmap, dim=-1)
|
290 |
+
num_samples = 50 if not advanced else 100
|
291 |
if num_samples > fps_idx.shape[0]:
|
292 |
num_samples = fps_idx.shape[0]
|
293 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
|
|
305 |
fps_heatmaps = {}
|
306 |
sort_values = []
|
307 |
top3_image_idx = {}
|
308 |
+
top10_image_idx = {}
|
309 |
for _, idx in enumerate(fps_idx):
|
310 |
heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
|
311 |
|
|
|
315 |
# tensor = tensor[torch.randperm(tensor.shape[0])[:max_size]]
|
316 |
# return tensor.quantile(p)
|
317 |
# top_p = top_percentile(heatmap, p=0.5)
|
318 |
+
top_p = 0.8
|
319 |
|
320 |
heatmap = heatmap.reshape(-1, h, w)
|
321 |
mask = (heatmap > top_p).float()
|
|
|
325 |
mask = mask[mask_sort_idx[:3]]
|
326 |
sort_values.append(mask.mean().item())
|
327 |
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
328 |
+
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:10]].cpu()
|
329 |
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
330 |
+
top10_image_idx[idx.item()] = mask_sort_idx[:10]
|
331 |
# do the sorting
|
332 |
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
333 |
fps_idx = fps_idx[_sort_idx]
|
|
|
344 |
# shuffle the fps_idx
|
345 |
fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
|
346 |
|
347 |
+
|
348 |
fig_images = []
|
349 |
i_cluster = 0
|
350 |
num_plots = 10 if not advanced else 20
|
351 |
plot_step_float = (1.0 - progess_start) / num_plots
|
352 |
for i_fig in range(num_plots):
|
353 |
progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
|
354 |
+
if not advanced:
|
355 |
+
fig, axs = plt.subplots(3, 5, figsize=(15, 9))
|
356 |
+
if advanced:
|
357 |
+
fig, axs = plt.subplots(6, 5, figsize=(15, 18))
|
358 |
for ax in axs.flatten():
|
359 |
ax.axis("off")
|
360 |
for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
|
|
|
364 |
size = (images.shape[1], images.shape[2])
|
365 |
heatmap = apply_reds_colormap(heatmap, size)
|
366 |
# for i, image_idx in enumerate(sorted_image_idxs[:3]):
|
367 |
+
image_idxs = top3_image_idx[idx.item()] if not advanced else top10_image_idx[idx.item()]
|
368 |
+
for i, image_idx in enumerate(image_idxs):
|
369 |
# _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
|
370 |
_heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
|
371 |
axs[i, j].imshow(_heatmap)
|
|
|
385 |
|
386 |
fig_images.append(img)
|
387 |
plt.close()
|
388 |
+
|
|
|
|
|
|
|
389 |
return fig_images, ret_magnitude
|
390 |
|
391 |
|
|
|
651 |
|
652 |
def _ncut_run(*args, **kwargs):
|
653 |
n_ret = kwargs.pop("n_ret", 1)
|
654 |
+
# try:
|
655 |
+
# if torch.cuda.is_available():
|
656 |
+
# torch.cuda.empty_cache()
|
657 |
|
658 |
+
# ret = ncut_run(*args, **kwargs)
|
659 |
|
660 |
+
# if torch.cuda.is_available():
|
661 |
+
# torch.cuda.empty_cache()
|
662 |
|
663 |
+
# ret = list(ret)[:n_ret] + [ret[-1]]
|
664 |
+
# return ret
|
665 |
+
# except Exception as e:
|
666 |
+
# gr.Error(str(e))
|
667 |
+
# if torch.cuda.is_available():
|
668 |
+
# torch.cuda.empty_cache()
|
669 |
+
# return *(None for _ in range(n_ret)), "Error: " + str(e)
|
670 |
+
|
671 |
+
ret = ncut_run(*args, **kwargs)
|
672 |
+
ret = list(ret)[:n_ret] + [ret[-1]]
|
673 |
+
return ret
|
674 |
|
675 |
if USE_HUGGINGFACE_ZEROGPU:
|
676 |
@spaces.GPU(duration=30)
|
|
|
1419 |
with gr.Column(scale=5, min_width=200):
|
1420 |
output_gallery = make_output_images_section()
|
1421 |
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1422 |
+
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
|
1423 |
[
|
1424 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
1425 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|