import argparse
import json
import os
import os.path as osp

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- kept from the original import list
from PIL import Image, ImageDraw
from sklearn.manifold import TSNE
from torch import nn
from torchvision.transforms import ToPILImage

from lib import SupportSets, GENFORCE_MODELS, update_progress, update_stdout, STYLEGAN_LAYERS
from models.load_generator import load_generator


class DataParallelPassthrough(nn.DataParallel):
    """`nn.DataParallel` that forwards unknown attribute look-ups to the wrapped module."""

    def __getattr__(self, name):
        try:
            return super(DataParallelPassthrough, self).__getattr__(name)
        except AttributeError:
            # Not an attribute of the wrapper itself -- look it up on the wrapped module.
            return getattr(self.module, name)


class ModelArgs:
    """Plain attribute container built from keyword arguments (here: the contents of `args.json`)."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


def tensor2image(tensor, img_size=None, adaptive=False):
    """Convert an image tensor into a PIL image.

    Args:
        tensor (torch.Tensor) : image tensor; a leading dimension of size 1 is squeezed away
        img_size (int)        : if given (and non-zero), resize the result to (img_size, img_size)
        adaptive (bool)       : if True, min-max normalise the tensor into [0, 1]; otherwise map it with
                                (tensor + 1) / 2 and clamp the result into [0, 1]

    Returns:
        (PIL.Image.Image) : the converted (and optionally resized) image
    """
    # Squeeze tensor image
    tensor = tensor.squeeze(dim=0)
    if adaptive:
        tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
    else:
        # `clamp` is out-of-place: its result has to be kept, otherwise out-of-range values wrap around in the
        # uint8 conversion below.
        tensor = ((tensor + 1) / 2).clamp(0, 1)

    image = ToPILImage()((255 * tensor.cpu().detach()).to(torch.uint8))
    if img_size:
        image = image.resize((img_size, img_size))
    return image


def one_hot(dims, value, idx):
    """Return a zero vector of size `dims` whose entry at position `idx` is set to `value`."""
    vec = torch.zeros(dims)
    vec[idx] = value
    return vec


def create_strip(image_list, N=5, strip_height=256):
    """Create strip of images across a given latent path.

    Args:
        image_list (list)  : list of images (PIL.Image.Image) across a given latent path
        N (int)            : number of images in strip
        strip_height (int) : strip height in pixels -- its width will be N * strip_height

    Returns:
        transformed_images_strip (PIL.Image.Image) : strip PIL image
    """
    step = len(image_list) // N + 1
    last = len(image_list) - 1
    transformed_images_strip = Image.new('RGB', (N * strip_height, strip_height))
    for i in range(N):
        # Sample every `step`-th image; once past the end of the list, repeat the last image.
        j = i * step if i * step < len(image_list) else last
        transformed_images_strip.paste(image_list[j].resize((strip_height, strip_height)), (i * strip_height, 0))
    return transformed_images_strip


def create_gif(image_list, gif_height=256):
    """Create gif frames for images across a given latent path.

    Each frame shows the middle image of the list on the left and the i-th image on the right, with a progress bar
    drawn along the bottom of the right half.

    Args:
        image_list (list) : list of images (PIL.Image.Image) across a given latent path
        gif_height (int)  : gif height in pixels -- its width will be 2 * gif_height

    Returns:
        transformed_images_gif_frames (list): list of gif frames in PIL (PIL.Image.Image)
    """
    bar_h = 12
    bar_colour = (252, 186, 3)
    num_images = len(image_list)

    transformed_images_gif_frames = []
    for i, image in enumerate(image_list):
        # Create gif frame
        gif_frame = Image.new('RGB', (2 * gif_height, gif_height))
        gif_frame.paste(image_list[num_images // 2].resize((gif_height, gif_height)), (0, 0))
        gif_frame.paste(image.resize((gif_height, gif_height)), (gif_height, 0))

        # Draw progress bar
        draw_bar = ImageDraw.Draw(gif_frame)
        draw_bar.rectangle(xy=((gif_height, gif_height - bar_h),
                               ((1 + (i / num_images)) * gif_height, gif_height)),
                           fill=bar_colour)
        transformed_images_gif_frames.append(gif_frame)

    return transformed_images_gif_frames


def visualize_latent_space(tsne_latent_codes, semantic_dipoles, output_dir, save_filename="latent_space_tsne.png",
                           shift_steps=16, paths=None):
    """Plot the t-SNE reduced latent paths in 3D and save the figure.

    Every path is drawn as a line, and its first and last points (the two ends of the traversal) are marked.

    Args:
        tsne_latent_codes (np.ndarray) : the 3D latent codes after t-SNE transformation, one row per point
        semantic_dipoles (list)        : list of semantic dipoles; entry i provides the legend label of path i
        output_dir (str)               : directory to save the generated plot (created if missing)
        save_filename (str)            : name of the file to save the plot
        shift_steps (int)              : number of steps per direction; only used when `paths` is None, in which
                                         case path i is assumed to occupy rows
                                         [i * (2 * shift_steps + 1), (i + 1) * (2 * shift_steps + 1))
        paths (list)                   : optional list with, for each path, the row indices of its points in
                                         `tsne_latent_codes` (in traversal order)
    """
    if paths is None:
        points_per_path = 2 * shift_steps + 1
        paths = [list(range(i * points_per_path, (i + 1) * points_per_path))
                 for i in range(len(semantic_dipoles))]

    fig = plt.figure(figsize=(16, 12))  # Larger figure for clarity
    ax = fig.add_subplot(111, projection='3d')
    # `plt.get_cmap` is used instead of `plt.cm.get_cmap`, which is no longer available in recent Matplotlib.
    cmap = plt.get_cmap('tab10', len(paths))

    for i, path_indices in enumerate(paths):
        if len(path_indices) == 0:
            continue
        path_coords = tsne_latent_codes[path_indices]
        colour = cmap(i)

        # Plot the entire path (all intermediate points in a single colour)
        ax.plot(path_coords[:, 0], path_coords[:, 1], path_coords[:, 2], color=colour, linewidth=2)

        # Mark both ends of the path; only one of them carries the legend entry
        ax.scatter(*path_coords[-1], color=colour, s=100,
                   label=f"{semantic_dipoles[i][0]} → {semantic_dipoles[i][1]}")
        ax.scatter(*path_coords[0], color=colour, s=100)

    # Add legend
    ax.legend(loc='best', fontsize=10)

    # Set titles and labels
    ax.set_title("t-SNE Latent Space Visualization")
    ax.set_xlabel("t-SNE Dimension 1")
    ax.set_ylabel("t-SNE Dimension 2")
    ax.set_zlabel("t-SNE Dimension 3")

    # Save the plot
    os.makedirs(output_dir, exist_ok=True)
    save_path = osp.join(output_dir, save_filename)
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)  # release the figure once it has been written
    print(f"Visualization saved to {save_path}")


def _initial_latent_code(G, x, gan, stylegan_space, w_space, truncation):
    """Return the latent code a path starts from: `x` itself, or its w/w+ code for StyleGAN W/W+ spaces."""
    if w_space or ('stylegan' not in gan) or ('W' not in stylegan_space):
        return x.clone()
    with torch.no_grad():
        latent_code = G.get_w(x, truncation=truncation).clone()
    if stylegan_space == 'W':
        latent_code = latent_code[:, 0, :]
    return latent_code


def _traverse_direction(LSS, start_code, support_sets_mask, signed_eps, shift_steps, shift_leap,
                        wplus_layer=None, wplus_pad=0):
    """Walk `shift_steps` steps from `start_code` along the path selected by `support_sets_mask`.

    Args:
        LSS                            : latent support sets model, called as LSS(mask, latent_code)
        start_code (torch.Tensor)      : latent code the walk starts from
        support_sets_mask (torch.Tensor): one-hot mask selecting the path
        signed_eps (float)             : step magnitude; its sign selects the direction
        shift_steps (int)              : number of steps to take
        shift_leap (int)               : keep the latent code/shift of every `shift_leap`-th step
        wplus_layer (int)              : for W+ codes, only layers [0, wplus_layer] are fed to LSS and shifted;
                                         None for any other space
        wplus_pad (int)                : for W+ codes, number of zeros appended to the flat shift before it is
                                         reshaped like the latent code

    Returns:
        (end_code, kept_codes, kept_shifts) : the final latent code, and the kept codes/shifts in walking order
    """
    latent_code = start_code
    kept_codes, kept_shifts = [], []
    cnt = 0
    with torch.no_grad():
        for _ in range(shift_steps):
            cnt += 1
            if wplus_layer is not None:
                flat_code = latent_code[:, :wplus_layer + 1, :].reshape(latent_code.shape[0], -1)
                shift = F.pad(input=signed_eps * LSS(support_sets_mask, flat_code), pad=(0, wplus_pad),
                              mode='constant', value=0).reshape_as(latent_code)
            else:
                shift = signed_eps * LSS(support_sets_mask, latent_code)
            latent_code = latent_code + shift

            # Store latent codes and shifts
            if cnt == shift_leap:
                kept_shifts.append(shift)
                kept_codes.append(latent_code)
                cnt = 0
    return latent_code, kept_codes, kept_shifts


def _write_md_summary(md_path, exp, heading_marker, image_name_format, semantic_dipoles, latent_codes_dirs,
                      phi_coeffs, num_gen_paths):
    """Write a Markdown summary that lists, per path, the image of every latent code and the path's phi value.

    Args:
        md_path (str)            : output .md file
        exp (str)                : experiment name written in the title
        heading_marker (str)     : Markdown heading marker used for each path (e.g. "##")
        image_name_format (str)  : format string producing the image file name from the path index
        semantic_dipoles (list)  : semantic dipoles; entry `dim` gives the heading of path `dim`
        latent_codes_dirs (list) : latent code directory names (one image per latent code and path)
        phi_coeffs (dict)        : phi coefficient per path index
        num_gen_paths (int)      : number of paths
    """
    # UTF-8 is requested explicitly because the headings contain a non-ASCII arrow.
    with open(md_path, "w", encoding="utf-8") as md_file:
        md_file.write("# Experiment: {}\n".format(exp))
        for dim in range(num_gen_paths):
            md_file.write("{} \"{}\" → \"{}\"\n".format(heading_marker,
                                                        semantic_dipoles[dim][1],
                                                        semantic_dipoles[dim][0]))
            md_file.write("\n")
            for lc in latent_codes_dirs:
                # NOTE(review): the format string previously had no placeholder, so the image path was never
                # written; a plain Markdown image link is emitted here -- adjust if other markup was intended.
                md_file.write("![]({})\n".format(osp.join(lc, 'paths_strips', image_name_format.format(dim))))
            md_file.write("phi={}\n".format(phi_coeffs[dim]))
            md_file.write("\n")


def main():
    """ContraCLIP -- Latent space traversal script.

    A script for traversing the latent space of a pre-trained GAN generator through paths defined by the warpings of
    a set of pre-trained support vectors. Latent codes are drawn from a pre-defined collection via the `--pool`
    argument. The generated images are stored under the `results/` directory of the experiment. For every latent
    code and path, the path images, a strip image and a GIF are written; a 3D t-SNE plot of the paths of the first
    latent code is saved under `tsne_visualizations/`.

    Options:
        ================================================================================================================
        -v, --verbose  : set verbose mode on
        ================================================================================================================
        --exp          : set experiment's model dir, as created by `train.py`, i.e., it should contain an `args.json`
                         file with the arguments the model has been trained with, and a subdirectory `models/` with
                         the latent support sets weights (`latent_support_sets.pt`, or checkpoint files named
                         `latent_support_sets-*`) and a `semantic_dipoles.json` file.
        --pool         : directory of pre-defined pool of latent codes (created by `sample_gan.py`)
        --w-space      : latent codes in the pool are in W/W+ space (typically as inverted codes of real images)
        ================================================================================================================
        --shift-steps  : set number of shifts to be applied to each latent code at each direction (positive/negative).
                         That is, the total number of shifts applied to each latent code will be equal to
                         2 * args.shift_steps.
        --eps          : set shift step magnitude for generating G(z'), where z' = z +/- eps * direction.
        --shift-leap   : set path shift leap (after how many steps to generate images)
        --batch-size   : set generator batch size (if not set, use the total number of images per path)
        --img-size     : set size of saved generated images (if not set, use the output size of the respective GAN
                         generator)
        --img-quality  : JPEG image quality (max 95)
        --gif          : write the summarizing MD file that collates the GIF images of all paths and latent codes
        --gif-height   : set GIF image height -- width will be 2 * args.gif_height
        --gif-fps      : set number of frames per second for the generated GIF images
        --strip        : write the summarizing MD file that collates the traversal strip images
        --strip-number : set number of images per strip
        --strip-height : set strip height -- width will be args.strip_number * args.strip_height
        ================================================================================================================
        --cuda         : use CUDA (default)
        --no-cuda      : do not use CUDA
        ================================================================================================================

    Raises:
        NotADirectoryError  : if the experiment, models, or pool directory does not exist
        FileNotFoundError   : if `args.json` or the latent support sets weights cannot be found
        NotImplementedError : if `--w-space` is given for a model that does not operate in a StyleGAN W/W+ space
        RuntimeError        : if the pool contains no latent codes
    """
    parser = argparse.ArgumentParser(description="ContraCLIP latent space traversal script")
    parser.add_argument('-v', '--verbose', action='store_true', help="set verbose mode on")
    # ================================================================================================================ #
    parser.add_argument('--w-space', action='store_true', help="latent codes are given in the W-space")
    parser.add_argument('--exp', type=str, required=True, help="set experiment's model dir (created by `train.py`)")
    parser.add_argument('--pool', type=str, required=True, help="directory of pre-defined pool of latent codes"
                                                                "(created by `sample_gan.py`)")
    parser.add_argument('--shift-steps', type=int, default=16, help="set number of shifts per positive/negative path "
                                                                    "direction")
    parser.add_argument('--eps', type=float, default=0.2, help="set shift step magnitude")
    parser.add_argument('--shift-leap', type=int, default=1,
                        help="set path shift leap (after how many steps to generate images)")
    parser.add_argument('--batch-size', type=int, help="set generator batch size (if not set, use the total number of "
                                                       "images per path)")
    parser.add_argument('--img-size', type=int, help="set size of saved generated images (if not set, use the output "
                                                     "size of the respective GAN generator)")
    parser.add_argument('--img-quality', type=int, default=50, help="set JPEG image quality")
    parser.add_argument('--strip', action='store_true', help="create traversal strip images")
    parser.add_argument('--strip-number', type=int, default=9, help="set number of images per strip")
    parser.add_argument('--strip-height', type=int, default=256, help="set strip height")
    parser.add_argument('--gif', action='store_true', help="create GIF traversals")
    parser.add_argument('--gif-height', type=int, default=256, help="set gif height")
    parser.add_argument('--gif-fps', type=int, default=30, help="set gif frame rate")
    # ================================================================================================================ #
    parser.add_argument('--cuda', dest='cuda', action='store_true', help="use CUDA during training")
    parser.add_argument('--no-cuda', dest='cuda', action='store_false', help="do NOT use CUDA during training")
    parser.set_defaults(cuda=True)
    # ================================================================================================================ #

    # Parse given arguments
    args = parser.parse_args()

    # Check structure of `args.exp`
    if not osp.isdir(args.exp):
        raise NotADirectoryError("Invalid given directory: {}".format(args.exp))

    # -- args.json file (pre-trained model arguments)
    args_json_file = osp.join(args.exp, 'args.json')
    if not osp.isfile(args_json_file):
        raise FileNotFoundError("File not found: {}".format(args_json_file))
    with open(args_json_file) as f:
        args_json = ModelArgs(**json.load(f))
    gan = args_json.__dict__["gan"]
    stylegan_space = args_json.__dict__["stylegan_space"]
    stylegan_layer = args_json.__dict__["stylegan_layer"] if "stylegan_layer" in args_json.__dict__ else None
    truncation = args_json.__dict__["truncation"]

    # TODO: Check if `--w-space` is valid
    if args.w_space and (('stylegan' not in gan) or ('W' not in stylegan_space)):
        raise NotImplementedError

    # -- models directory (support sets and reconstructor, final or checkpoint files)
    models_dir = osp.join(args.exp, 'models')
    if not osp.isdir(models_dir):
        raise NotADirectoryError("Invalid models directory: {}".format(models_dir))

    # ---- Get all files of models directory
    models_dir_files = [f for f in os.listdir(models_dir) if osp.isfile(osp.join(models_dir, f))]

    # ---- Check for latent support sets (LSS) model file (final or checkpoint)
    latent_support_sets_model = osp.join(models_dir, 'latent_support_sets.pt')
    model_iter = ''
    if not osp.isfile(latent_support_sets_model):
        # NOTE(review): checkpoints are ordered lexicographically, which picks the latest one only if the iteration
        # numbers in the file names are zero-padded -- confirm against `train.py`.
        latent_support_sets_checkpoint_files = sorted(f for f in models_dir_files if 'latent_support_sets-' in f)
        if not latent_support_sets_checkpoint_files:
            raise FileNotFoundError(
                "No latent support sets model or checkpoint files found under: {}".format(models_dir))
        latest_checkpoint = latent_support_sets_checkpoint_files[-1]
        latent_support_sets_model = osp.join(models_dir, latest_checkpoint)
        model_iter = '-{}'.format(latest_checkpoint.split('.')[0].split('-')[-1])

    # -- Get prompt corpus list
    with open(osp.join(models_dir, 'semantic_dipoles.json'), 'r') as f:
        semantic_dipoles = json.load(f)

    # Check given pool directory
    pool = osp.join('experiments', 'latent_codes', gan, args.pool)
    if not osp.isdir(pool):
        raise NotADirectoryError("Invalid pool directory: {} -- Please run sample_gan.py to create it.".format(pool))

    # CUDA
    use_cuda = False
    multi_gpu = False
    if torch.cuda.is_available():
        if args.cuda:
            use_cuda = True
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            if torch.cuda.device_count() > 1:
                multi_gpu = True
        else:
            print("*** WARNING ***: It looks like you have a CUDA device, but aren't using CUDA.\n"
                  " Run with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    # Build GAN generator model and load with pre-trained weights
    if args.verbose:
        print("#. Build GAN generator model G and load with pre-trained weights...")
        print(" \\__GAN generator : {} (res: {})".format(gan, GENFORCE_MODELS[gan][1]))
        print(" \\__Pre-trained weights: {}".format(GENFORCE_MODELS[gan][0]))

    G = load_generator(model_name=gan,
                       latent_is_w=('stylegan' in gan) and ('W' in args_json.__dict__["stylegan_space"]),
                       verbose=args.verbose).eval()

    # Upload GAN generator model to GPU
    if use_cuda:
        G = G.cuda()

    # Parallelize GAN generator model into multiple GPUs if available
    if multi_gpu:
        G = DataParallelPassthrough(G)

    # Build latent support sets model LSS
    if args.verbose:
        print("#. Build Latent Support Sets model LSS...")

    # Get support vector dimensionality
    is_wplus = ('stylegan' in gan) and (stylegan_space == 'W+')
    support_vectors_dim = G.dim_z
    if is_wplus:
        support_vectors_dim *= (stylegan_layer + 1)

    LSS = SupportSets(num_support_sets=len(semantic_dipoles),
                      num_support_dipoles=args_json.__dict__["num_latent_support_dipoles"],
                      support_vectors_dim=support_vectors_dim,
                      jung_radius=1)

    # Load pre-trained weights and set to evaluation mode
    if args.verbose:
        print(" \\__Pre-trained weights: {}".format(latent_support_sets_model))
    LSS.load_state_dict(torch.load(latent_support_sets_model, map_location=lambda storage, loc: storage))
    if args.verbose:
        print(" \\__Set to evaluation mode")
    LSS.eval()

    # Upload support sets model to GPU
    if use_cuda:
        LSS = LSS.cuda()

    # Set number of generative paths
    num_gen_paths = LSS.num_support_sets

    # Create output dir for generated images
    out_dir = osp.join(args.exp, 'results', args.pool + model_iter,
                       '{}_{}_{}'.format(2 * args.shift_steps, args.eps, round(2 * args.shift_steps * args.eps, 3)))
    os.makedirs(out_dir, exist_ok=True)

    # Set default batch size
    if args.batch_size is None:
        args.batch_size = 2 * args.shift_steps + 1

    # For W+ codes only layers [0, stylegan_layer] are warped; the shift of the remaining layers is zero-padded.
    wplus_layer = stylegan_layer if is_wplus else None
    wplus_pad = (STYLEGAN_LAYERS[gan] - 1 - stylegan_layer) * 512 if is_wplus else 0

    ## ============================================================================================================== ##
    ##                                                                                                                ##
    ##                                              [Latent Codes Pool]                                               ##
    ##                                                                                                                ##
    ## ============================================================================================================== ##
    # Get latent codes from the given pool
    if args.verbose:
        print("#. Use latent codes from pool {}...".format(args.pool))
    latent_codes_dirs = sorted(dI for dI in os.listdir(pool) if os.path.isdir(os.path.join(pool, dI)))
    if not latent_codes_dirs:
        raise RuntimeError("No latent codes found in pool directory: {}".format(pool))
    latent_code_filename = 'latent_code_{}.pt'.format('w+' if args.w_space else 'z')
    latent_codes_list = [torch.load(osp.join(pool, subdir, latent_code_filename),
                                    map_location=lambda storage, loc: storage)
                         for subdir in latent_codes_dirs]

    # Get latent codes in torch Tensor format -- xs refers to z or w+ codes
    xs = torch.cat(latent_codes_list)
    if use_cuda:
        xs = xs.cuda()
    num_of_latent_codes = xs.size()[0]

    ## ============================================================================================================== ##
    ##                                                                                                                ##
    ##                                            [Latent space traversal]                                            ##
    ##                                                                                                                ##
    ## ============================================================================================================== ##
    if args.verbose:
        print("#. Traverse latent space...")
        print(" \\__Experiment : {}".format(osp.basename(osp.abspath(args.exp))))
        print(" \\__Number of test latent codes : {}".format(num_of_latent_codes))
        print(" \\__Test latent codes shape : {}".format(xs.shape))
        print(" \\__Shift magnitude : {}".format(args.eps))
        print(" \\__Shift steps : {}".format(2 * args.shift_steps))
        print(" \\__Traversal length : {}".format(round(2 * args.shift_steps * args.eps, 3)))

    # Paths of the first latent code (the only ones used by the t-SNE visualization below)
    first_code_paths = None

    # Phi coefficients per path; after the loop this holds the values of the last processed latent code
    phi_coeffs = dict()

    # Iterate over given latent codes
    for i in range(num_of_latent_codes):
        # Get latent code
        x_ = xs[i, :].unsqueeze(0)

        latent_code_hash = latent_codes_dirs[i]
        if args.verbose:
            update_progress(" \\__.Latent code hash: {} [{:03d}/{:03d}] ".format(latent_code_hash,
                                                                                 i+1,
                                                                                 num_of_latent_codes),
                            num_of_latent_codes, i)

        # Create directory for current latent code
        latent_code_dir = osp.join(out_dir, '{}'.format(latent_code_hash))
        os.makedirs(latent_code_dir, exist_ok=True)

        # Create directory for storing path images
        transformed_images_root_dir = osp.join(latent_code_dir, 'paths_images')
        os.makedirs(transformed_images_root_dir, exist_ok=True)
        transformed_images_strips_root_dir = osp.join(latent_code_dir, 'paths_strips')
        os.makedirs(transformed_images_strips_root_dir, exist_ok=True)

        # Keep all latent paths the current latent code (sample)
        paths_latent_codes = []

        # Keep phi coefficients
        phi_coeffs = dict()

        ## ========================================================================================================== ##
        ##                                                                                                            ##
        ##                                             [ Path Traversal ]                                             ##
        ##                                                                                                            ##
        ## ========================================================================================================== ##
        # Iterate over (interpretable) directions
        for dim in range(num_gen_paths):
            if args.verbose:
                print()
                update_progress(" \\__path: {:03d}/{:03d} ".format(dim + 1, num_gen_paths), num_gen_paths, dim + 1)

            # Starting point of the path (z, or w/w+ for StyleGAN W/W+ spaces) and its (zero) shift
            origin_code = _initial_latent_code(G, x_, gan, stylegan_space, args.w_space, truncation)
            zero_shift = torch.zeros_like(origin_code)
            if use_cuda:
                zero_shift = zero_shift.cuda()

            # One-hot mask selecting the current path; it does not change along the path, so it is built once
            support_sets_mask = torch.zeros(1, LSS.num_support_sets)
            support_sets_mask[0, dim] = 1.0
            if use_cuda:
                support_sets_mask = support_sets_mask.cuda()

            ## ====================================================================================================== ##
            ##                                                                                                        ##
            ##                    [ Traverse through current path (positive/negative directions) ]                    ##
            ##                                                                                                        ##
            ## ====================================================================================================== ##
            positive_end, positive_codes, positive_shifts = _traverse_direction(
                LSS, origin_code, support_sets_mask, args.eps, args.shift_steps, args.shift_leap,
                wplus_layer=wplus_layer, wplus_pad=wplus_pad)
            negative_end, negative_codes, negative_shifts = _traverse_direction(
                LSS, origin_code, support_sets_mask, -args.eps, args.shift_steps, args.shift_leap,
                wplus_layer=wplus_layer, wplus_pad=wplus_pad)

            # Order along the path: most negative code first, then the origin, then the positive codes
            current_path_latent_codes = negative_codes[::-1] + [origin_code] + positive_codes
            current_path_latent_shifts = negative_shifts[::-1] + [zero_shift] + positive_shifts

            positive_endpoint = positive_end.clone().reshape(1, -1)
            negative_endpoint = negative_end.clone().reshape(1, -1)

            # Calculate latent path phi coefficient (end-to-end distance / latent path length)
            phi = torch.norm(negative_endpoint - positive_endpoint, dim=1).item() / (2 * args.shift_steps * args.eps)
            phi_coeffs.update({dim: phi})

            # Generate transformed images
            # Split latent codes and shifts in batches
            current_path_latent_codes = torch.cat(current_path_latent_codes)
            current_path_latent_codes_batches = torch.split(current_path_latent_codes, args.batch_size)
            current_path_latent_shifts = torch.cat(current_path_latent_shifts)
            current_path_latent_shifts_batches = torch.split(current_path_latent_shifts, args.batch_size)
            if len(current_path_latent_codes_batches) != len(current_path_latent_shifts_batches):
                raise AssertionError()

            # NOTE(review): the stored latent codes already include their shift, yet the shift is added once more
            # before generation -- kept as is; confirm this is intended.
            transformed_img = []
            for codes_batch, shifts_batch in zip(current_path_latent_codes_batches,
                                                 current_path_latent_shifts_batches):
                with torch.no_grad():
                    transformed_img.append(G(codes_batch + shifts_batch))
            transformed_img = torch.cat(transformed_img)

            # Convert tensors (transformed images) into PIL images
            transformed_images = [tensor2image(transformed_img[t, :].cpu(), img_size=args.img_size, adaptive=True)
                                  for t in range(transformed_img.shape[0])]

            # Save all images in `transformed_images` list under a per-path sub-directory of
            # `transformed_images_root_dir`
            transformed_images_dir = osp.join(transformed_images_root_dir, 'path_{:03d}'.format(dim))
            os.makedirs(transformed_images_dir, exist_ok=True)
            for t, image in enumerate(transformed_images):
                image.save(osp.join(transformed_images_dir, '{:06d}.jpg'.format(t)),
                           "JPEG", quality=args.img_quality, optimize=True, progressive=True)
                # Save original image (the middle image of the first path)
                if (t == len(transformed_images) // 2) and (dim == 0):
                    image.save(osp.join(latent_code_dir, 'original_image.jpg'),
                               "JPEG", quality=95, optimize=True, progressive=True)

            # Create strip of images
            transformed_images_strip = create_strip(image_list=transformed_images, N=args.strip_number,
                                                    strip_height=args.strip_height)
            transformed_images_strip.save(osp.join(transformed_images_strips_root_dir,
                                                   'path_{:03d}_strip.jpg'.format(dim)),
                                          "JPEG", quality=args.img_quality, optimize=True, progressive=True)

            # Save gif (static original image + traversal gif)
            transformed_images_gif_frames = create_gif(transformed_images, gif_height=args.gif_height)
            im = Image.new(mode='RGB', size=(2 * args.gif_height, args.gif_height))
            im.save(fp=osp.join(transformed_images_strips_root_dir, 'path_{:03d}.gif'.format(dim)),
                    append_images=transformed_images_gif_frames,
                    save_all=True,
                    optimize=True,
                    loop=0,
                    duration=1000 // args.gif_fps)

            # Append latent paths
            paths_latent_codes.append(current_path_latent_codes.unsqueeze(0))

            if args.verbose:
                update_stdout(1)
        # ============================================================================================================ #

        # Save all latent paths for the current latent code (sample) in a tensor whose first two dimensions are
        # [num_gen_paths, number of stored points per path]
        paths_latent_codes_tensor = torch.cat(paths_latent_codes)
        torch.save(paths_latent_codes_tensor, osp.join(latent_code_dir, 'paths_latent_codes.pt'))
        if first_code_paths is None:
            # `detach` guards against tensors that still track gradients, which cannot be converted to numpy
            first_code_paths = paths_latent_codes_tensor.detach().cpu().numpy()

        if args.verbose:
            update_stdout(1)
            print()
            print()

    # After processing all latent codes and paths
    if args.verbose:
        print("Performing t-SNE on latent codes for visualization...")

    # Flatten the paths of the first latent code into a 2D array [num_paths * num_steps, features]; any trailing
    # dimensions (e.g. the layers of W+ codes) are merged into the feature dimension
    num_paths, num_steps = first_code_paths.shape[:2]
    num_points = num_paths * num_steps
    tsne_latent_codes = first_code_paths.reshape(num_points, -1)

    # Apply 3D T-SNE -- the perplexity has to stay below the number of points
    tsne_model = TSNE(n_components=3, perplexity=min(30, max(1, num_points - 1)), learning_rate=200,
                      random_state=42)
    tsne_transformed = tsne_model.fit_transform(tsne_latent_codes)  # Shape: [num_paths * num_steps, 3]

    # Row indices of each path in the flattened array
    path_indices = [list(range(p * num_steps, (p + 1) * num_steps)) for p in range(num_paths)]

    tsne_vis_dir = osp.join(out_dir, 'tsne_visualizations')
    visualize_latent_space(
        tsne_latent_codes=tsne_transformed,  # T-SNE-reduced latent codes
        semantic_dipoles=semantic_dipoles,  # Semantic labels for paths
        paths=path_indices,  # Indices of paths (for a single latent code)
        output_dir=tsne_vis_dir,
        save_filename="latent_space_tsne.png"
    )

    # Create summarizing MD files: for each path, list the generated GIF/strip of every original latent code
    if args.gif or args.strip:
        print("#. Write summarizing MD files...")

    if args.gif:
        _write_md_summary(osp.join(out_dir, 'results.md'), args.exp, "###", 'path_{:03d}.gif',
                          semantic_dipoles, latent_codes_dirs, phi_coeffs, num_gen_paths)
    if args.strip:
        _write_md_summary(osp.join(out_dir, 'results_strips.md'), args.exp, "##", 'path_{:03d}_strip.jpg',
                          semantic_dipoles, latent_codes_dirs, phi_coeffs, num_gen_paths)


if __name__ == '__main__':
    main()