huzey commited on
Commit
734e8a4
1 Parent(s): 043dbc3

update advanced plots

Browse files
Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -257,15 +257,19 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
257
  blended = (1 - opacity1) * image + opacity2 * heatmap
258
  return blended.astype(np.uint8)
259
 
260
- def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
261
  progress = gr.Progress()
262
  progress(progess_start, desc="Finding Clusters by FPS")
263
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
264
  eigvecs = eigvecs.to(device)
265
  from ncut_pytorch.ncut_pytorch import farthest_point_sampling
266
  magnitude = torch.norm(eigvecs, dim=-1)
267
- p = 0.8
268
- top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
 
 
 
 
269
 
270
  ret_magnitude = magnitude.reshape(-1, h, w)
271
 
@@ -283,7 +287,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
283
  right = F.normalize(right, dim=-1)
284
  heatmap = left @ right.T
285
  heatmap = F.normalize(heatmap, dim=-1)
286
- num_samples = 80
287
  if num_samples > fps_idx.shape[0]:
288
  num_samples = fps_idx.shape[0]
289
  r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
@@ -328,13 +332,19 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
328
  # reverse the fps_idx
329
  # fps_idx = fps_idx.flip(0)
330
  # discard the big clusters
331
- fps_idx = fps_idx[10:]
332
- # shuffle the fps_idx
333
- fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
 
 
 
 
 
 
334
 
335
  fig_images = []
336
  i_cluster = 0
337
- num_plots = 10
338
  plot_step_float = (1.0 - progess_start) / num_plots
339
  for i_fig in range(num_plots):
340
  progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
@@ -609,7 +619,8 @@ def ncut_run(
609
  if torch.cuda.is_available():
610
  images = images.cuda()
611
  _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
612
- cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start)
 
613
  logging_str += f"plot time: {time.time() - start:.2f}s\n"
614
 
615
  norm_images = None
@@ -622,8 +633,10 @@ def ncut_run(
622
  colormap = matplotlib.colormaps['Reds']
623
  for i_image in range(eig_magnitude.shape[0]):
624
  norm_image = colormap(eig_magnitude[i_image])
625
- norm_image = (norm_image[..., :3] * 255).astype(np.uint8)
626
- norm_images.append(Image.fromarray(norm_image))
 
 
627
  logging_str += "Eigenvector Magnitude\n"
628
  logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
629
  gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=0)
@@ -855,6 +868,7 @@ def run_fn(
855
  n_ret=1,
856
  plot_clusters=False,
857
  alignedcut_eig_norm_plot=False,
 
858
  ):
859
 
860
  progress=gr.Progress()
@@ -987,6 +1001,7 @@ def run_fn(
987
  "n_ret": n_ret,
988
  "plot_clusters": plot_clusters,
989
  "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
 
990
  }
991
  # print(kwargs)
992
 
@@ -1416,7 +1431,7 @@ with demo:
1416
  no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
1417
 
1418
  submit_button.click(
1419
- partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True),
1420
  inputs=[
1421
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
1422
  positive_prompt, negative_prompt,
@@ -1802,7 +1817,7 @@ with demo:
1802
  no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
1803
 
1804
  submit_button.click(
1805
- partial(run_fn, n_ret=3),
1806
  inputs=[
1807
  input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
1808
  positive_prompt, negative_prompt,
 
257
  blended = (1 - opacity1) * image + opacity2 * heatmap
258
  return blended.astype(np.uint8)
259
 
260
+ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False):
261
  progress = gr.Progress()
262
  progress(progess_start, desc="Finding Clusters by FPS")
263
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
264
  eigvecs = eigvecs.to(device)
265
  from ncut_pytorch.ncut_pytorch import farthest_point_sampling
266
  magnitude = torch.norm(eigvecs, dim=-1)
267
+
268
+ # gr.Info("Finding Clusters by FPS, no magnitude filtering")
269
+ top_p_idx = torch.arange(eigvecs.shape[0])
270
+ # gr.Info("Finding Clusters by FPS, with magnitude filtering")
271
+ # p = 0.8
272
+ # top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
273
 
274
  ret_magnitude = magnitude.reshape(-1, h, w)
275
 
 
287
  right = F.normalize(right, dim=-1)
288
  heatmap = left @ right.T
289
  heatmap = F.normalize(heatmap, dim=-1)
290
+ num_samples = 80 if not advanced else 130
291
  if num_samples > fps_idx.shape[0]:
292
  num_samples = fps_idx.shape[0]
293
  r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
 
332
  # reverse the fps_idx
333
  # fps_idx = fps_idx.flip(0)
334
  # discard the big clusters
335
+
336
+ # gr.Info("Discarding the biggest 10 clusters")
337
+ # fps_idx = fps_idx[10:]
338
+ # gr.Info("Not discarding the biggest 10 clusters")
339
+ # gr.Info("Discarding the smallest 30 out of 80 sampled clusters")
340
+
341
+ if not advanced:
342
+ # shuffle the fps_idx
343
+ fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
344
 
345
  fig_images = []
346
  i_cluster = 0
347
+ num_plots = 10 if not advanced else 20
348
  plot_step_float = (1.0 - progess_start) / num_plots
349
  for i_fig in range(num_plots):
350
  progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
 
619
  if torch.cuda.is_available():
620
  images = images.cuda()
621
  _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
622
+ advanced = kwargs.get("advanced", False)
623
+ cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=advanced)
624
  logging_str += f"plot time: {time.time() - start:.2f}s\n"
625
 
626
  norm_images = None
 
633
  colormap = matplotlib.colormaps['Reds']
634
  for i_image in range(eig_magnitude.shape[0]):
635
  norm_image = colormap(eig_magnitude[i_image])
636
+ # norm_image = (norm_image[..., :3] * 255).astype(np.uint8)
637
+ # norm_images.append(Image.fromarray(norm_image))
638
+ norm_images.append(torch.tensor(norm_image[..., :3]))
639
+ norm_images = to_pil_images(norm_images)
640
  logging_str += "Eigenvector Magnitude\n"
641
  logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
642
  gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=0)
 
868
  n_ret=1,
869
  plot_clusters=False,
870
  alignedcut_eig_norm_plot=False,
871
+ advanced=False,
872
  ):
873
 
874
  progress=gr.Progress()
 
1001
  "n_ret": n_ret,
1002
  "plot_clusters": plot_clusters,
1003
  "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
1004
+ "advanced": advanced,
1005
  }
1006
  # print(kwargs)
1007
 
 
1431
  no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
1432
 
1433
  submit_button.click(
1434
+ partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True, advanced=True),
1435
  inputs=[
1436
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
1437
  positive_prompt, negative_prompt,
 
1817
  no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
1818
 
1819
  submit_button.click(
1820
+ partial(run_fn, n_ret=3, advanced=True),
1821
  inputs=[
1822
  input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
1823
  positive_prompt, negative_prompt,