import argparse import os import os.path as osp import torch from torch import nn import torch.nn.functional as F from PIL import Image, ImageDraw import json from torchvision.transforms import ToPILImage from lib import SupportSets, GENFORCE_MODELS, update_progress, update_stdout, STYLEGAN_LAYERS from models.load_generator import load_generator class DataParallelPassthrough(nn.DataParallel): def __getattr__(self, name): try: return super(DataParallelPassthrough, self).__getattr__(name) except AttributeError: return getattr(self.module, name) class ModelArgs: def __init__(self, **kwargs): self.__dict__.update(kwargs) def tensor2image(tensor, img_size=None, adaptive=False): # Squeeze tensor image tensor = tensor.squeeze(dim=0) if adaptive: tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) if img_size: return ToPILImage()((255 * tensor.cpu().detach()).to(torch.uint8)).resize((img_size, img_size)) else: return ToPILImage()((255 * tensor.cpu().detach()).to(torch.uint8)) else: tensor = (tensor + 1) / 2 tensor.clamp(0, 1) if img_size: return ToPILImage()((255 * tensor.cpu().detach()).to(torch.uint8)).resize((img_size, img_size)) else: return ToPILImage()((255 * tensor.cpu().detach()).to(torch.uint8)) def one_hot(dims, value, idx): vec = torch.zeros(dims) vec[idx] = value return vec def create_strip(image_list, N=5, strip_height=256): """Create strip of images across a given latent path. Args: image_list (list) : list of images (PIL.Image.Image) across a given latent path N (int) : number of images in strip strip_height (int) : strip height in pixels -- its width will be N * strip_height Returns: transformed_images_strip (PIL.Image.Image) : strip PIL image """ step = len(image_list) // N + 1 transformed_images_strip = Image.new('RGB', (N * strip_height, strip_height)) for i in range(N): j = i * step if i * step < len(image_list) else len(image_list) - 1 transformed_images_strip.paste(image_list[j].resize((strip_height, strip_height)), (i * strip_height, 0)) return transformed_images_strip def create_gif(image_list, gif_height=256): """Create gif frames for images across a given latent path. Args: image_list (list) : list of images (PIL.Image.Image) across a given latent path gif_height (int) : gif height in pixels -- its width will be N * gif_height Returns: transformed_images_gif_frames (list): list of gif frames in PIL (PIL.Image.Image) """ transformed_images_gif_frames = [] for i in range(len(image_list)): # Create gif frame gif_frame = Image.new('RGB', (2 * gif_height, gif_height)) gif_frame.paste(image_list[len(image_list) // 2].resize((gif_height, gif_height)), (0, 0)) gif_frame.paste(image_list[i].resize((gif_height, gif_height)), (gif_height, 0)) # Draw progress bar draw_bar = ImageDraw.Draw(gif_frame) bar_h = 12 bar_colour = (252, 186, 3) draw_bar.rectangle(xy=((gif_height, gif_height - bar_h), ((1 + (i / len(image_list))) * gif_height, gif_height)), fill=bar_colour) transformed_images_gif_frames.append(gif_frame) return transformed_images_gif_frames def main(): """ContraCLIP -- Latent space traversal script. A script for traversing the latent space of a pre-trained GAN generator through paths defined by the warpings of a set of pre-trained support vectors. Latent codes are drawn from a pre-defined collection via the `--pool` argument. The generated images are stored under `results/` directory. Options: ================================================================================================================ -v, --verbose : set verbose mode on ================================================================================================================ --exp : set experiment's model dir, as created by `train.py`, i.e., it should contain a subdirectory `models/` with two files, namely `reconstructor.pt` and `support_sets.pt`, which contain the weights for the reconstructor and the support sets, respectively, and an `args.json` file that contains the arguments the model has been trained with. --pool : directory of pre-defined pool of latent codes (created by `sample_gan.py`) --w-space : latent codes in the pool are in W/W+ space (typically as inverted codes of real images) ================================================================================================================ --shift-steps : set number of shifts to be applied to each latent code at each direction (positive/negative). That is, the total number of shifts applied to each latent code will be equal to 2 * args.shift_steps. --eps : set shift step magnitude for generating G(z'), where z' = z +/- eps * direction. --shift-leap : set path shift leap (after how many steps to generate images) --batch-size : set generator batch size (if not set, use the total number of images per path) --img-size : set size of saved generated images (if not set, use the output size of the respective GAN generator) --img-quality : JPEG image quality (max 95) --gif : generate collated GIF images for all paths and all latent codes --gif-height : set GIF image height -- width will be 2 * args.gif_height --gif-fps : set number of frames per second for the generated GIF images --strip : create traversal strip images --strip-number : set number of images per strip --strip-height : set strip height -- width will be 2 * args.strip_height ================================================================================================================ --cuda : use CUDA (default) --no-cuda : do not use CUDA ================================================================================================================ """ parser = argparse.ArgumentParser(description="ContraCLIP latent space traversal script") parser.add_argument('-v', '--verbose', action='store_true', help="set verbose mode on") # ================================================================================================================ # parser.add_argument('--w-space', action='store_true', help="latent codes are given in the W-space") parser.add_argument('--exp', type=str, required=True, help="set experiment's model dir (created by `train.py`)") parser.add_argument('--pool', type=str, required=True, help="directory of pre-defined pool of latent codes" "(created by `sample_gan.py`)") parser.add_argument('--shift-steps', type=int, default=16, help="set number of shifts per positive/negative path " "direction") parser.add_argument('--eps', type=float, default=0.2, help="set shift step magnitude") parser.add_argument('--shift-leap', type=int, default=1, help="set path shift leap (after how many steps to generate images)") parser.add_argument('--batch-size', type=int, help="set generator batch size (if not set, use the total number of " "images per path)") parser.add_argument('--img-size', type=int, help="set size of saved generated images (if not set, use the output " "size of the respective GAN generator)") parser.add_argument('--img-quality', type=int, default=50, help="set JPEG image quality") parser.add_argument('--strip', action='store_true', help="create traversal strip images") parser.add_argument('--strip-number', type=int, default=9, help="set number of images per strip") parser.add_argument('--strip-height', type=int, default=256, help="set strip height") parser.add_argument('--gif', action='store_true', help="create GIF traversals") parser.add_argument('--gif-height', type=int, default=256, help="set gif height") parser.add_argument('--gif-fps', type=int, default=30, help="set gif frame rate") # ================================================================================================================ # parser.add_argument('--cuda', dest='cuda', action='store_true', help="use CUDA during training") parser.add_argument('--no-cuda', dest='cuda', action='store_false', help="do NOT use CUDA during training") parser.set_defaults(cuda=True) # ================================================================================================================ # # Parse given arguments args = parser.parse_args() # Check structure of `args.exp` if not osp.isdir(args.exp): raise NotADirectoryError("Invalid given directory: {}".format(args.exp)) # -- args.json file (pre-trained model arguments) args_json_file = osp.join(args.exp, 'args.json') if not osp.isfile(args_json_file): raise FileNotFoundError("File not found: {}".format(args_json_file)) args_json = ModelArgs(**json.load(open(args_json_file))) gan = args_json.__dict__["gan"] stylegan_space = args_json.__dict__["stylegan_space"] stylegan_layer = args_json.__dict__["stylegan_layer"] if "stylegan_layer" in args_json.__dict__ else None truncation = args_json.__dict__["truncation"] # TODO: Check if `--w-space` is valid if args.w_space and (('stylegan' not in gan) or ('W' not in stylegan_space)): raise NotImplementedError # -- models directory (support sets and reconstructor, final or checkpoint files) models_dir = osp.join(args.exp, 'models') if not osp.isdir(models_dir): raise NotADirectoryError("Invalid models directory: {}".format(models_dir)) # ---- Get all files of models directory models_dir_files = [f for f in os.listdir(models_dir) if osp.isfile(osp.join(models_dir, f))] # ---- Check for latent support sets (LSS) model file (final or checkpoint) latent_support_sets_model = osp.join(models_dir, 'latent_support_sets.pt') model_iter = '' if not osp.isfile(latent_support_sets_model): latent_support_sets_checkpoint_files = [] for f in models_dir_files: if 'latent_support_sets-' in f: latent_support_sets_checkpoint_files.append(f) latent_support_sets_checkpoint_files.sort() latent_support_sets_model = osp.join(models_dir, latent_support_sets_checkpoint_files[-1]) model_iter = '-{}'.format(latent_support_sets_checkpoint_files[-1].split('.')[0].split('-')[-1]) # -- Get prompt corpus list with open(osp.join(models_dir, 'semantic_dipoles.json'), 'r') as f: semantic_dipoles = json.load(f) # Check given pool directory pool = osp.join('experiments', 'latent_codes', gan, args.pool) if not osp.isdir(pool): raise NotADirectoryError("Invalid pool directory: {} -- Please run sample_gan.py to create it.".format(pool)) # CUDA use_cuda = False multi_gpu = False if torch.cuda.is_available(): if args.cuda: use_cuda = True torch.set_default_tensor_type('torch.cuda.FloatTensor') if torch.cuda.device_count() > 1: multi_gpu = True else: print("*** WARNING ***: It looks like you have a CUDA device, but aren't using CUDA.\n" " Run with --cuda for optimal training speed.") torch.set_default_tensor_type('torch.FloatTensor') else: torch.set_default_tensor_type('torch.FloatTensor') # Build GAN generator model and load with pre-trained weights if args.verbose: print("#. Build GAN generator model G and load with pre-trained weights...") print(" \\__GAN generator : {} (res: {})".format(gan, GENFORCE_MODELS[gan][1])) print(" \\__Pre-trained weights: {}".format(GENFORCE_MODELS[gan][0])) G = load_generator(model_name=gan, latent_is_w=('stylegan' in gan) and ('W' in args_json.__dict__["stylegan_space"]), verbose=args.verbose).eval() # Upload GAN generator model to GPU if use_cuda: G = G.cuda() # Parallelize GAN generator model into multiple GPUs if available if multi_gpu: G = DataParallelPassthrough(G) # Build latent support sets model LSS if args.verbose: print("#. Build Latent Support Sets model LSS...") # Get support vector dimensionality support_vectors_dim = G.dim_z if ('stylegan' in gan) and (stylegan_space == 'W+'): support_vectors_dim *= (stylegan_layer + 1) LSS = SupportSets(num_support_sets=len(semantic_dipoles), num_support_dipoles=args_json.__dict__["num_latent_support_dipoles"], support_vectors_dim=support_vectors_dim, jung_radius=1) # Load pre-trained weights and set to evaluation mode if args.verbose: print(" \\__Pre-trained weights: {}".format(latent_support_sets_model)) LSS.load_state_dict(torch.load(latent_support_sets_model, map_location=lambda storage, loc: storage)) if args.verbose: print(" \\__Set to evaluation mode") LSS.eval() # Upload support sets model to GPU if use_cuda: LSS = LSS.cuda() # Set number of generative paths num_gen_paths = LSS.num_support_sets # Create output dir for generated images out_dir = osp.join(args.exp, 'results', args.pool + model_iter, '{}_{}_{}'.format(2 * args.shift_steps, args.eps, round(2 * args.shift_steps * args.eps, 3))) os.makedirs(out_dir, exist_ok=True) # Set default batch size if args.batch_size is None: args.batch_size = 2 * args.shift_steps + 1 ## ============================================================================================================== ## ## ## ## [Latent Codes Pool] ## ## ## ## ============================================================================================================== ## # Get latent codes from the given pool if args.verbose: print("#. Use latent codes from pool {}...".format(args.pool)) latent_codes_dirs = [dI for dI in os.listdir(pool) if os.path.isdir(os.path.join(pool, dI))] latent_codes_dirs.sort() latent_codes_list = [torch.load(osp.join(pool, subdir, 'latent_code_{}.pt'.format('w+' if args.w_space else 'z')), map_location=lambda storage, loc: storage) for subdir in latent_codes_dirs] # Get latent codes in torch Tensor format -- xs refers to z or w+ codes xs = torch.cat(latent_codes_list) if use_cuda: xs = xs.cuda() num_of_latent_codes = xs.size()[0] ## ============================================================================================================== ## ## ## ## [Latent space traversal] ## ## ## ## ============================================================================================================== ## if args.verbose: print("#. Traverse latent space...") print(" \\__Experiment : {}".format(osp.basename(osp.abspath(args.exp)))) print(" \\__Number of test latent codes : {}".format(num_of_latent_codes)) print(" \\__Test latent codes shape : {}".format(xs.shape)) print(" \\__Shift magnitude : {}".format(args.eps)) print(" \\__Shift steps : {}".format(2 * args.shift_steps)) print(" \\__Traversal length : {}".format(round(2 * args.shift_steps * args.eps, 3))) # Iterate over given latent codes for i in range(num_of_latent_codes): # Get latent code x_ = xs[i, :].unsqueeze(0) latent_code_hash = latent_codes_dirs[i] if args.verbose: update_progress(" \\__.Latent code hash: {} [{:03d}/{:03d}] ".format(latent_code_hash, i+1, num_of_latent_codes), num_of_latent_codes, i) # Create directory for current latent code latent_code_dir = osp.join(out_dir, '{}'.format(latent_code_hash)) os.makedirs(latent_code_dir, exist_ok=True) # Create directory for storing path images transformed_images_root_dir = osp.join(latent_code_dir, 'paths_images') os.makedirs(transformed_images_root_dir, exist_ok=True) transformed_images_strips_root_dir = osp.join(latent_code_dir, 'paths_strips') os.makedirs(transformed_images_strips_root_dir, exist_ok=True) # Keep all latent paths the current latent code (sample) paths_latent_codes = [] # Keep phi coefficients phi_coeffs = dict() ## ========================================================================================================== ## ## ## ## [ Path Traversal ] ## ## ## ## ========================================================================================================== ## # Iterate over (interpretable) directions for dim in range(num_gen_paths): if args.verbose: print() update_progress(" \\__path: {:03d}/{:03d} ".format(dim + 1, num_gen_paths), num_gen_paths, dim + 1) # Create shifted latent codes (for the given latent code z) and generate transformed images transformed_images = [] # Current path's latent codes and shifts lists latent_code = x_ if (not args.w_space) and ('stylegan' in gan) and ('W' in stylegan_space): latent_code = G.get_w(x_, truncation=truncation) if stylegan_space == 'W': latent_code = latent_code[:, 0, :] current_path_latent_codes = [latent_code] current_path_latent_shifts = [torch.zeros_like(latent_code).cuda() if use_cuda else torch.zeros_like(latent_code)] ## ====================================================================================================== ## ## ## ## [ Traverse through current path (positive/negative directions) ] ## ## ## ## ====================================================================================================== ## # == Positive direction == latent_code = x_.clone() if (not args.w_space) and ('stylegan' in gan) and ('W' in stylegan_space): latent_code = G.get_w(x_, truncation=truncation).clone() if stylegan_space == 'W': latent_code = latent_code[:, 0, :] cnt = 0 for _ in range(args.shift_steps): cnt += 1 # Calculate shift vector based on current z support_sets_mask = torch.zeros(1, LSS.num_support_sets) support_sets_mask[0, dim] = 1.0 if use_cuda: support_sets_mask.cuda() # Get latent space shift vector and shifted latent code if ('stylegan' in gan) and (stylegan_space == 'W+'): with torch.no_grad(): shift = args.eps * LSS(support_sets_mask, latent_code[:, :stylegan_layer + 1, :].reshape(latent_code.shape[0], -1)) latent_code = latent_code + \ F.pad(input=shift, pad=(0, (STYLEGAN_LAYERS[gan] - 1 - stylegan_layer) * 512), mode='constant', value=0).reshape_as(latent_code) current_path_latent_code = latent_code else: with torch.no_grad(): shift = args.eps * LSS(support_sets_mask, latent_code) latent_code = latent_code + shift current_path_latent_code = latent_code # Store latent codes and shifts if cnt == args.shift_leap: if ('stylegan' in gan) and (stylegan_space == 'W+'): current_path_latent_shifts.append( F.pad(input=shift, pad=(0, (STYLEGAN_LAYERS[gan] - 1 - stylegan_layer) * 512), mode='constant', value=0).reshape_as(latent_code)) else: current_path_latent_shifts.append(shift) current_path_latent_codes.append(current_path_latent_code) cnt = 0 positive_endpoint = latent_code.clone().reshape(1, -1) # ======================== # == Negative direction == latent_code = x_.clone() if (not args.w_space) and ('stylegan' in gan) and ('W' in stylegan_space): latent_code = G.get_w(x_, truncation=truncation).clone() if stylegan_space == 'W': latent_code = latent_code[:, 0, :] cnt = 0 for _ in range(args.shift_steps): cnt += 1 # Calculate shift vector based on current z support_sets_mask = torch.zeros(1, LSS.num_support_sets) support_sets_mask[0, dim] = 1.0 if use_cuda: support_sets_mask.cuda() # Get latent space shift vector and shifted latent code if ('stylegan' in gan) and (stylegan_space == 'W+'): with torch.no_grad(): shift = -args.eps * LSS( support_sets_mask, latent_code[:, :stylegan_layer + 1, :].reshape(latent_code.shape[0], -1)) latent_code = latent_code + \ F.pad(input=shift, pad=(0, (STYLEGAN_LAYERS[gan] - 1 - stylegan_layer) * 512), mode='constant', value=0).reshape_as(latent_code) current_path_latent_code = latent_code else: with torch.no_grad(): shift = -args.eps * LSS(support_sets_mask, latent_code) latent_code = latent_code + shift current_path_latent_code = latent_code # Store latent codes and shifts if cnt == args.shift_leap: if ('stylegan' in gan) and (stylegan_space == 'W+'): current_path_latent_shifts = \ [F.pad(input=shift, pad=(0, (STYLEGAN_LAYERS[gan] - 1 - stylegan_layer) * 512), mode='constant', value=0).reshape_as(latent_code)] + current_path_latent_shifts else: current_path_latent_shifts = [shift] + current_path_latent_shifts current_path_latent_codes = [current_path_latent_code] + current_path_latent_codes cnt = 0 negative_endpoint = latent_code.clone().reshape(1, -1) # ======================== # Calculate latent path phi coefficient (end-to-end distance / latent path length) phi = torch.norm(negative_endpoint - positive_endpoint, dim=1).item() / (2 * args.shift_steps * args.eps) phi_coeffs.update({dim: phi}) # Generate transformed images # Split latent codes and shifts in batches current_path_latent_codes = torch.cat(current_path_latent_codes) current_path_latent_codes_batches = torch.split(current_path_latent_codes, args.batch_size) current_path_latent_shifts = torch.cat(current_path_latent_shifts) current_path_latent_shifts_batches = torch.split(current_path_latent_shifts, args.batch_size) if len(current_path_latent_codes_batches) != len(current_path_latent_shifts_batches): raise AssertionError() else: num_batches = len(current_path_latent_codes_batches) transformed_img = [] for t in range(num_batches): with torch.no_grad(): transformed_img.append(G(current_path_latent_codes_batches[t] + current_path_latent_shifts_batches[t])) transformed_img = torch.cat(transformed_img) # Convert tensors (transformed images) into PIL images for t in range(transformed_img.shape[0]): transformed_images.append(tensor2image(transformed_img[t, :].cpu(), img_size=args.img_size, adaptive=True)) # Save all images in `transformed_images` list under `transformed_images_root_dir//` transformed_images_dir = osp.join(transformed_images_root_dir, 'path_{:03d}'.format(dim)) os.makedirs(transformed_images_dir, exist_ok=True) for t in range(len(transformed_images)): transformed_images[t].save(osp.join(transformed_images_dir, '{:06d}.jpg'.format(t)), "JPEG", quality=args.img_quality, optimize=True, progressive=True) # Save original image if (t == len(transformed_images) // 2) and (dim == 0): transformed_images[t].save(osp.join(latent_code_dir, 'original_image.jpg'), "JPEG", quality=95, optimize=True, progressive=True) # Create strip of images transformed_images_strip = create_strip(image_list=transformed_images, N=args.strip_number, strip_height=args.strip_height) transformed_images_strip.save(osp.join(transformed_images_strips_root_dir, 'path_{:03d}_strip.jpg'.format(dim)), "JPEG", quality=args.img_quality, optimize=True, progressive=True) # Save gif (static original image + traversal gif) transformed_images_gif_frames = create_gif(transformed_images, gif_height=args.gif_height) im = Image.new(mode='RGB', size=(2 * args.gif_height, args.gif_height)) im.save(fp=osp.join(transformed_images_strips_root_dir, 'path_{:03d}.gif'.format(dim)), append_images=transformed_images_gif_frames, save_all=True, optimize=True, loop=0, duration=1000 // args.gif_fps) # Append latent paths paths_latent_codes.append(current_path_latent_codes.unsqueeze(0)) if args.verbose: update_stdout(1) # ============================================================================================================ # # Save all latent paths and shifts for the current latent code (sample) in a tensor of size: # paths_latent_codes : torch.Size([num_gen_paths, 2 * args.shift_steps + 1, G.dim_z]) torch.save(torch.cat(paths_latent_codes), osp.join(latent_code_dir, 'paths_latent_codes.pt')) if args.verbose: update_stdout(1) print() print() # Create summarizing MD files if args.gif or args.strip: # For each interpretable path (warping function), collect the generated image sequences for each original latent # code and collate them into a GIF file print("#. Write summarizing MD files...") # Write .md summary files if args.gif: md_summary_file = osp.join(out_dir, 'results.md') md_summary_file_f = open(md_summary_file, "w") md_summary_file_f.write("# Experiment: {}\n".format(args.exp)) if args.strip: md_summary_strips_file = osp.join(out_dir, 'results_strips.md') md_summary_strips_file_f = open(md_summary_strips_file, "w") md_summary_strips_file_f.write("# Experiment: {}\n".format(args.exp)) if args.gif or args.strip: for dim in range(num_gen_paths): # Append to .md summary files if args.gif: md_summary_file_f.write("### \"{}\" → \"{}\"\n".format(semantic_dipoles[dim][1], semantic_dipoles[dim][0])) md_summary_file_f.write("

\n") if args.strip: md_summary_strips_file_f.write("## \"{}\" → \"{}\"\n".format(semantic_dipoles[dim][1], semantic_dipoles[dim][0])) md_summary_strips_file_f.write("

\n") for lc in latent_codes_dirs: if args.gif: md_summary_file_f.write("\n".format( osp.join(lc, 'paths_strips', 'path_{:03d}.gif'.format(dim)))) if args.strip: md_summary_strips_file_f.write("\n".format( osp.join(lc, 'paths_strips', 'path_{:03d}_strip.jpg'.format(dim)))) if args.gif: md_summary_file_f.write("phi={}\n".format(phi_coeffs[dim])) md_summary_file_f.write("

\n") if args.strip: md_summary_strips_file_f.write("phi={}\n".format(phi_coeffs[dim])) md_summary_strips_file_f.write("

\n") if args.gif: md_summary_file_f.close() if args.strip: md_summary_strips_file_f.close() if __name__ == '__main__': main()