# Copyright 2023 Adobe Research. All rights reserved. # To view a copy of the license, visit LICENSE.md. import sys sys.path.append('..') import argparse from pathlib import Path import torch import torch.nn.functional as F import torchvision.transforms as T import dnnlib import legacy from expansion_utils import io_utils, latent_operations def generate_images( ckpt, num_samples, truncation_psi ): device = torch.device('cuda') with dnnlib.util.open_url(ckpt) as f: snapshot_dict = legacy.load_network_pkl(f) G = snapshot_dict['G_ema'].to(device) latent_basis = snapshot_dict['latent_basis'].to(device) subspace_distance = snapshot_dict['subspace_distance'] repurposed_dims = snapshot_dict['repurposed_dims'].cpu() # out_dir = Path(out_dir) def norm_fn(tensor): minFrom= tensor.min() maxFrom= tensor.max() minTo = 0 maxTo=1 return minTo + (maxTo - minTo) * ((tensor - minFrom) / (maxFrom - minFrom)) topil = T.ToPILImage(mode='RGB') # norm_fn = T.Normalize( # # mean=[0.485, 0.456, 0.406], # # std=[0.229, 0.224, 0.225] # mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], # std=[1/0.229, 1/0.224, 1/0.225] # ) all_imgs = [] for i in range(num_samples): per_sample_imgs = [] z = torch.randn((1, G.z_dim), ).to(device) w = G.mapping(z, None, truncation_psi=truncation_psi) base_w, edit_ws = latent_operations.project_to_subspaces(w, latent_basis, repurposed_dims, step_size=subspace_distance, mean=G.mapping.w_avg) edit_ws = edit_ws[0] # Single step base_img = G.synthesis(base_w, noise_mode='const') per_sample_imgs.append(topil(norm_fn(base_img.squeeze()))) # io_utils.save_images(base_img, out_dir.joinpath('base', f'{i:05d}')) for idx, (dim_num, edit_w) in enumerate(zip(repurposed_dims, edit_ws)): # dim_out_dir = out_dir.joinpath(f'dim_{dim_num}') if idx % 4 == 0: edit_img = G.synthesis(edit_w, noise_mode='const') # mean, std = edit_img.mean((0,2)), edit_img.std((0,2)) # norm_fn = T.Normalize(mean, std) edited_img_pil = topil( norm_fn( edit_img.squeeze(), ) ) per_sample_imgs.append(edited_img_pil) # io_utils.save_images(edit_img, dim_out_dir.joinpath(f'{i:05d}')) all_imgs.append(per_sample_imgs) return all_imgs if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--ckpt', help='Network pickle filename', required=True) parser.add_argument('--out_dir', help='Where to save the output images', type=str, required=True, metavar='DIR') parser.add_argument('--num', help='Number of independant samples', type=int) parser.add_argument('--truncation_psi', help='Coefficient for truncation', type=float, default=1) args = parser.parse_args() with torch.no_grad(): generate_images(args.ckpt, args.out_dir, args.num, args.truncation_psi)