huzey commited on
Commit
04cb121
1 Parent(s): f7efa0a

fix SD norm

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -461,15 +461,18 @@ def extract_video_frames(video_path, max_frames=100):
461
  # return as list of PIL images
462
  return [(Image.fromarray(frames[i]), "") for i in range(frames.shape[0])]
463
 
464
- def transform_image(image, resolution=(1024, 1024)):
465
  image = image.convert('RGB').resize(resolution, Image.LANCZOS)
466
  # Convert to torch tensor
467
  image = torch.tensor(np.array(image).transpose(2, 0, 1)).float()
468
  image = image / 255
469
  # Normalize
470
- mean = [0.485, 0.456, 0.406]
471
- std = [0.229, 0.224, 0.225]
472
- image = (image - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)
 
 
 
473
  return image
474
 
475
  def plot_one_image_36_grid(original_image, tsne_rgb_images):
@@ -617,7 +620,8 @@ def run_fn(
617
  else:
618
  resolution = RES_DICT[model_name]
619
  images = [tup[0] for tup in images]
620
- images = [transform_image(image, resolution=resolution) for image in images]
 
621
  images = torch.stack(images)
622
 
623
 
@@ -1401,9 +1405,9 @@ with demo:
1401
  with gr.Column():
1402
  gr.Markdown("###### Running out of GPU? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
1403
 
1404
- # # for local development
1405
- # if os.path.exists("/hf_token.txt"):
1406
- # os.environ["HF_ACCESS_TOKEN"] = open("/hf_token.txt").read().strip()
1407
 
1408
  if DOWNLOAD_ALL_MODELS_DATASETS:
1409
  from ncut_pytorch.backbone import download_all_models
 
461
  # return as list of PIL images
462
  return [(Image.fromarray(frames[i]), "") for i in range(frames.shape[0])]
463
 
464
+ def transform_image(image, resolution=(1024, 1024), stablediffusion=False):
465
  image = image.convert('RGB').resize(resolution, Image.LANCZOS)
466
  # Convert to torch tensor
467
  image = torch.tensor(np.array(image).transpose(2, 0, 1)).float()
468
  image = image / 255
469
  # Normalize
470
+ if not stablediffusion:
471
+ mean = [0.485, 0.456, 0.406]
472
+ std = [0.229, 0.224, 0.225]
473
+ image = (image - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)
474
+ if stablediffusion:
475
+ image = image * 2 - 1
476
  return image
477
 
478
  def plot_one_image_36_grid(original_image, tsne_rgb_images):
 
620
  else:
621
  resolution = RES_DICT[model_name]
622
  images = [tup[0] for tup in images]
623
+ stablediffusion = True if "Diffusion" in model_name else False
624
+ images = [transform_image(image, resolution=resolution, stablediffusion=stablediffusion) for image in images]
625
  images = torch.stack(images)
626
 
627
 
 
1405
  with gr.Column():
1406
  gr.Markdown("###### Running out of GPU? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
1407
 
1408
+ # for local development
1409
+ if os.path.exists("/hf_token.txt"):
1410
+ os.environ["HF_ACCESS_TOKEN"] = open("/hf_token.txt").read().strip()
1411
 
1412
  if DOWNLOAD_ALL_MODELS_DATASETS:
1413
  from ncut_pytorch.backbone import download_all_models