huzey commited on
Commit
22610e0
1 Parent(s): 4e7b524

update cluster plot cuda

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -497,6 +497,8 @@ def ncut_run(
497
  if not video_output:
498
  start = time.time()
499
  h, w = features.shape[1], features.shape[2]
 
 
500
  _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
501
  cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w)
502
  logging_str += f"Plot time: {time.time() - start:.2f}s\n"
@@ -602,9 +604,9 @@ def reverse_transform_image(image, stablediffusion=False):
602
  if stablediffusion:
603
  image = (image + 1) / 2
604
  else:
605
- mean = [0.485, 0.456, 0.406]
606
- std = [0.229, 0.224, 0.225]
607
- image = image * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)
608
  image = torch.clamp(image, 0, 1)
609
  return image
610
 
@@ -1168,7 +1170,7 @@ with demo:
1168
 
1169
  with gr.Column(scale=5, min_width=200):
1170
  output_gallery = make_output_images_section()
1171
- cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=False, elem_id="clusters", columns=[2], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=True)
1172
  [
1173
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
1174
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
 
497
  if not video_output:
498
  start = time.time()
499
  h, w = features.shape[1], features.shape[2]
500
+ if torch.cuda.is_available():
501
+ images = images.cuda()
502
  _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
503
  cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w)
504
  logging_str += f"Plot time: {time.time() - start:.2f}s\n"
 
604
  if stablediffusion:
605
  image = (image + 1) / 2
606
  else:
607
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(image.device)
608
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(image.device)
609
+ image = image * std + mean
610
  image = torch.clamp(image, 0, 1)
611
  return image
612
 
 
1170
 
1171
  with gr.Column(scale=5, min_width=200):
1172
  output_gallery = make_output_images_section()
1173
+ cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=False, elem_id="clusters", columns=[2], rows=[1], object_fit="contain", height=450, show_share_button=True, preview=True)
1174
  [
1175
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
1176
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,