Spaces:
Running
on
Zero
Running
on
Zero
update advanced plots
Browse files
app.py
CHANGED
@@ -257,15 +257,19 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
|
|
257 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
258 |
return blended.astype(np.uint8)
|
259 |
|
260 |
-
def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
|
261 |
progress = gr.Progress()
|
262 |
progress(progess_start, desc="Finding Clusters by FPS")
|
263 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
264 |
eigvecs = eigvecs.to(device)
|
265 |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
|
266 |
magnitude = torch.norm(eigvecs, dim=-1)
|
267 |
-
|
268 |
-
|
|
|
|
|
|
|
|
|
269 |
|
270 |
ret_magnitude = magnitude.reshape(-1, h, w)
|
271 |
|
@@ -283,7 +287,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
|
|
283 |
right = F.normalize(right, dim=-1)
|
284 |
heatmap = left @ right.T
|
285 |
heatmap = F.normalize(heatmap, dim=-1)
|
286 |
-
num_samples = 80
|
287 |
if num_samples > fps_idx.shape[0]:
|
288 |
num_samples = fps_idx.shape[0]
|
289 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
@@ -328,13 +332,19 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
|
|
328 |
# reverse the fps_idx
|
329 |
# fps_idx = fps_idx.flip(0)
|
330 |
# discard the big clusters
|
331 |
-
|
332 |
-
#
|
333 |
-
fps_idx = fps_idx[
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
|
335 |
fig_images = []
|
336 |
i_cluster = 0
|
337 |
-
num_plots = 10
|
338 |
plot_step_float = (1.0 - progess_start) / num_plots
|
339 |
for i_fig in range(num_plots):
|
340 |
progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
|
@@ -609,7 +619,8 @@ def ncut_run(
|
|
609 |
if torch.cuda.is_available():
|
610 |
images = images.cuda()
|
611 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
612 |
-
|
|
|
613 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
614 |
|
615 |
norm_images = None
|
@@ -622,8 +633,10 @@ def ncut_run(
|
|
622 |
colormap = matplotlib.colormaps['Reds']
|
623 |
for i_image in range(eig_magnitude.shape[0]):
|
624 |
norm_image = colormap(eig_magnitude[i_image])
|
625 |
-
norm_image = (norm_image[..., :3] * 255).astype(np.uint8)
|
626 |
-
norm_images.append(Image.fromarray(norm_image))
|
|
|
|
|
627 |
logging_str += "Eigenvector Magnitude\n"
|
628 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
629 |
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=0)
|
@@ -855,6 +868,7 @@ def run_fn(
|
|
855 |
n_ret=1,
|
856 |
plot_clusters=False,
|
857 |
alignedcut_eig_norm_plot=False,
|
|
|
858 |
):
|
859 |
|
860 |
progress=gr.Progress()
|
@@ -987,6 +1001,7 @@ def run_fn(
|
|
987 |
"n_ret": n_ret,
|
988 |
"plot_clusters": plot_clusters,
|
989 |
"alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
|
|
|
990 |
}
|
991 |
# print(kwargs)
|
992 |
|
@@ -1416,7 +1431,7 @@ with demo:
|
|
1416 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
1417 |
|
1418 |
submit_button.click(
|
1419 |
-
partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True),
|
1420 |
inputs=[
|
1421 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
1422 |
positive_prompt, negative_prompt,
|
@@ -1802,7 +1817,7 @@ with demo:
|
|
1802 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
1803 |
|
1804 |
submit_button.click(
|
1805 |
-
partial(run_fn, n_ret=3),
|
1806 |
inputs=[
|
1807 |
input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
|
1808 |
positive_prompt, negative_prompt,
|
|
|
257 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
258 |
return blended.astype(np.uint8)
|
259 |
|
260 |
+
def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False):
|
261 |
progress = gr.Progress()
|
262 |
progress(progess_start, desc="Finding Clusters by FPS")
|
263 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
264 |
eigvecs = eigvecs.to(device)
|
265 |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
|
266 |
magnitude = torch.norm(eigvecs, dim=-1)
|
267 |
+
|
268 |
+
# gr.Info("Finding Clusters by FPS, no magnitude filtering")
|
269 |
+
top_p_idx = torch.arange(eigvecs.shape[0])
|
270 |
+
# gr.Info("Finding Clusters by FPS, with magnitude filtering")
|
271 |
+
# p = 0.8
|
272 |
+
# top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
|
273 |
|
274 |
ret_magnitude = magnitude.reshape(-1, h, w)
|
275 |
|
|
|
287 |
right = F.normalize(right, dim=-1)
|
288 |
heatmap = left @ right.T
|
289 |
heatmap = F.normalize(heatmap, dim=-1)
|
290 |
+
num_samples = 80 if not advanced else 130
|
291 |
if num_samples > fps_idx.shape[0]:
|
292 |
num_samples = fps_idx.shape[0]
|
293 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
|
|
332 |
# reverse the fps_idx
|
333 |
# fps_idx = fps_idx.flip(0)
|
334 |
# discard the big clusters
|
335 |
+
|
336 |
+
# gr.Info("Discarding the biggest 10 clusters")
|
337 |
+
# fps_idx = fps_idx[10:]
|
338 |
+
# gr.Info("Not discarding the biggest 10 clusters")
|
339 |
+
# gr.Info("Discarding the smallest 30 out of 80 sampled clusters")
|
340 |
+
|
341 |
+
if not advanced:
|
342 |
+
# shuffle the fps_idx
|
343 |
+
fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
|
344 |
|
345 |
fig_images = []
|
346 |
i_cluster = 0
|
347 |
+
num_plots = 10 if not advanced else 20
|
348 |
plot_step_float = (1.0 - progess_start) / num_plots
|
349 |
for i_fig in range(num_plots):
|
350 |
progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
|
|
|
619 |
if torch.cuda.is_available():
|
620 |
images = images.cuda()
|
621 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
622 |
+
advanced = kwargs.get("advanced", False)
|
623 |
+
cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=advanced)
|
624 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
625 |
|
626 |
norm_images = None
|
|
|
633 |
colormap = matplotlib.colormaps['Reds']
|
634 |
for i_image in range(eig_magnitude.shape[0]):
|
635 |
norm_image = colormap(eig_magnitude[i_image])
|
636 |
+
# norm_image = (norm_image[..., :3] * 255).astype(np.uint8)
|
637 |
+
# norm_images.append(Image.fromarray(norm_image))
|
638 |
+
norm_images.append(torch.tensor(norm_image[..., :3]))
|
639 |
+
norm_images = to_pil_images(norm_images)
|
640 |
logging_str += "Eigenvector Magnitude\n"
|
641 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
642 |
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=0)
|
|
|
868 |
n_ret=1,
|
869 |
plot_clusters=False,
|
870 |
alignedcut_eig_norm_plot=False,
|
871 |
+
advanced=False,
|
872 |
):
|
873 |
|
874 |
progress=gr.Progress()
|
|
|
1001 |
"n_ret": n_ret,
|
1002 |
"plot_clusters": plot_clusters,
|
1003 |
"alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
|
1004 |
+
"advanced": advanced,
|
1005 |
}
|
1006 |
# print(kwargs)
|
1007 |
|
|
|
1431 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
1432 |
|
1433 |
submit_button.click(
|
1434 |
+
partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True, advanced=True),
|
1435 |
inputs=[
|
1436 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
1437 |
positive_prompt, negative_prompt,
|
|
|
1817 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
1818 |
|
1819 |
submit_button.click(
|
1820 |
+
partial(run_fn, n_ret=3, advanced=True),
|
1821 |
inputs=[
|
1822 |
input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
|
1823 |
positive_prompt, negative_prompt,
|