|
import argparse |
|
import os |
|
import os.path as osp |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from PIL import Image, ImageDraw |
|
import json |
|
from torchvision.transforms import ToPILImage |
|
from lib import SupportSets, GENFORCE_MODELS, update_progress, update_stdout, STYLEGAN_LAYERS |
|
from models.load_generator import load_generator |
|
|
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection
|
from sklearn.manifold import TSNE |
|
|
|
class DataParallelPassthrough(nn.DataParallel):
    """nn.DataParallel wrapper that forwards attribute lookups to the wrapped
    module, so the generator's attributes (e.g., `dim_z`) remain accessible."""
    def __getattr__(self, name):
        try:
            return super(DataParallelPassthrough, self).__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)
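
# Illustrative note: with `G = DataParallelPassthrough(G)`, attributes of the
# wrapped generator (e.g., `G.dim_z`, `G.get_w`) can be accessed exactly as on
# the unwrapped module.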
|
|
|
|
|
class ModelArgs:
    """Minimal namespace built from keyword arguments (used to wrap the
    training arguments loaded from `args.json`)."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
|
|
|
|
|
def tensor2image(tensor, img_size=None, adaptive=False):
    """Convert a (possibly batched) image tensor into a PIL image.

    If `adaptive` is set, the tensor is min-max normalised to [0, 1]; otherwise
    it is assumed to lie in [-1, 1] and is rescaled to [0, 1]. If `img_size` is
    given, the output image is resized to img_size x img_size.
    """
    tensor = tensor.squeeze(dim=0)
    if adaptive:
        tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
    else:
        # `Tensor.clamp` is not in-place -- the result must be re-assigned.
        tensor = ((tensor + 1) / 2).clamp(0, 1)
    img = ToPILImage()((255 * tensor.cpu().detach()).to(torch.uint8))
    if img_size:
        img = img.resize((img_size, img_size))
    return img
|
|
|
|
|
def one_hot(dims, value, idx):
    """Return a `dims`-dimensional vector of zeros with `value` at index `idx`."""
    vec = torch.zeros(dims)
    vec[idx] = value
    return vec
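
# Example (illustrative): one_hot(dims=5, value=1.0, idx=2) -> tensor([0., 0., 1., 0., 0.])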
|
|
|
|
|
def create_strip(image_list, N=5, strip_height=256): |
|
"""Create strip of images across a given latent path. |
|
|
|
Args: |
|
image_list (list) : list of images (PIL.Image.Image) across a given latent path |
|
N (int) : number of images in strip |
|
strip_height (int) : strip height in pixels -- its width will be N * strip_height |
|
|
|
|
|
Returns: |
|
transformed_images_strip (PIL.Image.Image) : strip PIL image |
|
""" |
|
    # Keep (roughly) N evenly spaced images; clamp the index so the last strip
    # cells never run past the end of the list.
    step = len(image_list) // N + 1
    transformed_images_strip = Image.new('RGB', (N * strip_height, strip_height))
    for i in range(N):
        j = min(i * step, len(image_list) - 1)
        transformed_images_strip.paste(image_list[j].resize((strip_height, strip_height)), (i * strip_height, 0))
    return transformed_images_strip
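
# Example (illustrative): for a list of 33 traversal images,
# create_strip(images, N=5, strip_height=256) returns a 1280x256 strip made of
# every 7th image along the path.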
|
|
|
|
|
def create_gif(image_list, gif_height=256): |
|
"""Create gif frames for images across a given latent path. |
|
|
|
Args: |
|
image_list (list) : list of images (PIL.Image.Image) across a given latent path |
|
        gif_height (int) : gif frame height in pixels -- its width will be 2 * gif_height
|
|
|
Returns: |
|
transformed_images_gif_frames (list): list of gif frames in PIL (PIL.Image.Image) |
|
""" |
|
transformed_images_gif_frames = [] |
|
for i in range(len(image_list)): |
|
|
|
        # Left half: the original (mid-path) image; right half: the current
        # frame along the path.
        gif_frame = Image.new('RGB', (2 * gif_height, gif_height))
        gif_frame.paste(image_list[len(image_list) // 2].resize((gif_height, gif_height)), (0, 0))
        gif_frame.paste(image_list[i].resize((gif_height, gif_height)), (gif_height, 0))

        # Draw a progress bar along the bottom of the right half.
        draw_bar = ImageDraw.Draw(gif_frame)
        bar_h = 12
        bar_colour = (252, 186, 3)
        draw_bar.rectangle(xy=((gif_height, gif_height - bar_h),
                               ((1 + (i / len(image_list))) * gif_height, gif_height)),
                           fill=bar_colour)
|
|
|
transformed_images_gif_frames.append(gif_frame) |
|
|
|
return transformed_images_gif_frames |
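
# Example (illustrative): frames = create_gif(images, gif_height=256) yields
# len(images) frames of size 512x256 (original image left, traversal right).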
|
|
|
def visualize_latent_space(tsne_latent_codes, semantic_dipoles, output_dir, save_filename="latent_space_tsne.png", shift_steps=16):
    """Visualize the t-SNE-reduced latent space with minimal annotations.

    Assumes the traversal was generated with `--shift-leap=1`, so that each
    path contributes 2 * shift_steps + 1 codes, ordered from the negative
    endpoint, through the original code, to the positive endpoint.

    Args:
        tsne_latent_codes (np.ndarray): 3D latent codes after t-SNE reduction.
        semantic_dipoles (list): list of semantic directions (labels) for paths.
        output_dir (str): directory under which the plot is saved.
        save_filename (str): filename of the saved plot.
        shift_steps (int): number of positive/negative steps along each path.
    """
|
fig = plt.figure(figsize=(16, 12)) |
|
ax = fig.add_subplot(111, projection='3d') |
|
|
|
num_paths = len(semantic_dipoles) |
|
cmap = plt.cm.get_cmap('tab10', num_paths) |
|
|
|
    for i in range(num_paths):
        # Each path occupies a contiguous block of 2 * shift_steps + 1 rows.
        start_idx = i * (2 * shift_steps + 1)
        end_idx = start_idx + 2 * shift_steps

        path_coords = tsne_latent_codes[start_idx:end_idx + 1]
        ax.plot(
            path_coords[:, 0], path_coords[:, 1], path_coords[:, 2],
            color=cmap(i),
            linewidth=2
        )

        # Mark the two path endpoints; label only one of them so the legend
        # contains a single entry per semantic dipole.
        ax.scatter(*tsne_latent_codes[end_idx], color=cmap(i), s=100,
                   label=f"{semantic_dipoles[i][0]} → {semantic_dipoles[i][1]}")
        ax.scatter(*tsne_latent_codes[start_idx], color=cmap(i), s=100)
|
|
|
|
|
ax.legend(loc='best', fontsize=10) |
|
|
|
|
|
ax.set_title("t-SNE Latent Space Visualization") |
|
ax.set_xlabel("t-SNE Dimension 1") |
|
ax.set_ylabel("t-SNE Dimension 2") |
|
ax.set_zlabel("t-SNE Dimension 3") |
|
|
|
|
|
    os.makedirs(output_dir, exist_ok=True)
    save_path = osp.join(output_dir, save_filename)
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Visualization saved to {save_path}")
|
|
|
|
|
def main(): |
|
"""ContraCLIP -- Latent space traversal script. |
|
|
|
    A script for traversing the latent space of a pre-trained GAN generator through paths defined by the warpings of
    a set of pre-trained support vectors. Latent codes are drawn from a pre-defined collection via the `--pool`
    argument. The generated images are stored under the `results/` directory.
|
|
|
Options: |
|
================================================================================================================ |
|
-v, --verbose : set verbose mode on |
|
================================================================================================================ |
|
        --exp : set experiment's model dir, as created by `train.py`, i.e., it should contain an `args.json` file
                with the arguments the model has been trained with, and a subdirectory `models/` containing the
                latent support sets weights (`latent_support_sets.pt`, or intermediate checkpoints thereof) and the
                `semantic_dipoles.json` file.
|
--pool : directory of pre-defined pool of latent codes (created by `sample_gan.py`) |
|
--w-space : latent codes in the pool are in W/W+ space (typically as inverted codes of real images) |
|
================================================================================================================ |
|
--shift-steps : set number of shifts to be applied to each latent code at each direction (positive/negative). |
|
That is, the total number of shifts applied to each latent code will be equal to |
|
2 * args.shift_steps. |
|
--eps : set shift step magnitude for generating G(z'), where z' = z +/- eps * direction. |
|
--shift-leap : set path shift leap (after how many steps to generate images) |
|
--batch-size : set generator batch size (if not set, use the total number of images per path) |
|
--img-size : set size of saved generated images (if not set, use the output size of the respective GAN |
|
generator) |
|
--img-quality : JPEG image quality (max 95) |
|
--gif : generate collated GIF images for all paths and all latent codes |
|
--gif-height : set GIF image height -- width will be 2 * args.gif_height |
|
--gif-fps : set number of frames per second for the generated GIF images |
|
--strip : create traversal strip images |
|
--strip-number : set number of images per strip |
|
        --strip-height : set strip height -- its width will be args.strip_number * args.strip_height
|
================================================================================================================ |
|
--cuda : use CUDA (default) |
|
--no-cuda : do not use CUDA |
|
================================================================================================================ |
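
    Example (assuming this script is saved as `traverse_latent_space.py`; directory names are placeholders):
        python traverse_latent_space.py -v --exp=experiments/complete/EXP_DIR --pool=POOL_DIR --gif --strip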
|
""" |
|
parser = argparse.ArgumentParser(description="ContraCLIP latent space traversal script") |
|
parser.add_argument('-v', '--verbose', action='store_true', help="set verbose mode on") |
|
|
|
    parser.add_argument('--w-space', action='store_true', help="latent codes in the pool are in W/W+ space")
|
parser.add_argument('--exp', type=str, required=True, help="set experiment's model dir (created by `train.py`)") |
|
    parser.add_argument('--pool', type=str, required=True, help="directory of pre-defined pool of latent codes "
                                                                "(created by `sample_gan.py`)")
|
parser.add_argument('--shift-steps', type=int, default=16, help="set number of shifts per positive/negative path " |
|
"direction") |
|
parser.add_argument('--eps', type=float, default=0.2, help="set shift step magnitude") |
|
parser.add_argument('--shift-leap', type=int, default=1, |
|
help="set path shift leap (after how many steps to generate images)") |
|
parser.add_argument('--batch-size', type=int, help="set generator batch size (if not set, use the total number of " |
|
"images per path)") |
|
parser.add_argument('--img-size', type=int, help="set size of saved generated images (if not set, use the output " |
|
"size of the respective GAN generator)") |
|
    parser.add_argument('--img-quality', type=int, default=50, help="set JPEG image quality (max 95)")
|
|
|
parser.add_argument('--strip', action='store_true', help="create traversal strip images") |
|
parser.add_argument('--strip-number', type=int, default=9, help="set number of images per strip") |
|
parser.add_argument('--strip-height', type=int, default=256, help="set strip height") |
|
parser.add_argument('--gif', action='store_true', help="create GIF traversals") |
|
parser.add_argument('--gif-height', type=int, default=256, help="set gif height") |
|
parser.add_argument('--gif-fps', type=int, default=30, help="set gif frame rate") |
|
|
|
    parser.add_argument('--cuda', dest='cuda', action='store_true', help="use CUDA (default)")
    parser.add_argument('--no-cuda', dest='cuda', action='store_false', help="do NOT use CUDA")
|
parser.set_defaults(cuda=True) |
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
if not osp.isdir(args.exp): |
|
raise NotADirectoryError("Invalid given directory: {}".format(args.exp)) |
|
|
|
|
|
args_json_file = osp.join(args.exp, 'args.json') |
|
if not osp.isfile(args_json_file): |
|
raise FileNotFoundError("File not found: {}".format(args_json_file)) |
|
    with open(args_json_file) as f:
        args_json = ModelArgs(**json.load(f))
    gan = args_json.__dict__["gan"]
    stylegan_space = args_json.__dict__["stylegan_space"]
    stylegan_layer = args_json.__dict__.get("stylegan_layer")
    truncation = args_json.__dict__["truncation"]
|
|
|
|
|
    if args.w_space and (('stylegan' not in gan) or ('W' not in stylegan_space)):
        raise NotImplementedError("W/W+ pools are only supported for StyleGAN generators trained in the W/W+ space.")
|
|
|
|
|
models_dir = osp.join(args.exp, 'models') |
|
if not osp.isdir(models_dir): |
|
raise NotADirectoryError("Invalid models directory: {}".format(models_dir)) |
|
|
|
|
|
models_dir_files = [f for f in os.listdir(models_dir) if osp.isfile(osp.join(models_dir, f))] |
|
|
|
|
|
    # Use the final latent support sets model, if available; otherwise fall
    # back to the latest intermediate checkpoint (`latent_support_sets-<iter>.pt`).
    latent_support_sets_model = osp.join(models_dir, 'latent_support_sets.pt')
    model_iter = ''
    if not osp.isfile(latent_support_sets_model):
        latent_support_sets_checkpoint_files = []
        for f in models_dir_files:
            if 'latent_support_sets-' in f:
                latent_support_sets_checkpoint_files.append(f)
        if not latent_support_sets_checkpoint_files:
            raise FileNotFoundError("No latent support sets checkpoint found under {}".format(models_dir))
        latent_support_sets_checkpoint_files.sort()
        latent_support_sets_model = osp.join(models_dir, latent_support_sets_checkpoint_files[-1])
        model_iter = '-{}'.format(latent_support_sets_checkpoint_files[-1].split('.')[0].split('-')[-1])
|
|
|
|
|
with open(osp.join(models_dir, 'semantic_dipoles.json'), 'r') as f: |
|
semantic_dipoles = json.load(f) |
|
|
|
|
|
|
|
|
|
pool = osp.join('experiments', 'latent_codes', gan, args.pool) |
|
if not osp.isdir(pool): |
|
raise NotADirectoryError("Invalid pool directory: {} -- Please run sample_gan.py to create it.".format(pool)) |
|
|
|
|
|
use_cuda = False |
|
multi_gpu = False |
|
if torch.cuda.is_available(): |
|
if args.cuda: |
|
use_cuda = True |
|
torch.set_default_tensor_type('torch.cuda.FloatTensor') |
|
if torch.cuda.device_count() > 1: |
|
multi_gpu = True |
|
else: |
|
print("*** WARNING ***: It looks like you have a CUDA device, but aren't using CUDA.\n" |
|
" Run with --cuda for optimal training speed.") |
|
torch.set_default_tensor_type('torch.FloatTensor') |
|
else: |
|
torch.set_default_tensor_type('torch.FloatTensor') |
|
|
|
|
|
if args.verbose: |
|
print("#. Build GAN generator model G and load with pre-trained weights...") |
|
print(" \\__GAN generator : {} (res: {})".format(gan, GENFORCE_MODELS[gan][1])) |
|
print(" \\__Pre-trained weights: {}".format(GENFORCE_MODELS[gan][0])) |
|
|
|
G = load_generator(model_name=gan, |
|
latent_is_w=('stylegan' in gan) and ('W' in args_json.__dict__["stylegan_space"]), |
|
verbose=args.verbose).eval() |
|
|
|
|
|
if use_cuda: |
|
G = G.cuda() |
|
|
|
|
|
if multi_gpu: |
|
G = DataParallelPassthrough(G) |
|
|
|
|
|
if args.verbose: |
|
print("#. Build Latent Support Sets model LSS...") |
|
|
|
|
|
    # Dimensionality of the latent support vectors: equal to the generator's
    # latent dimension for the Z/W spaces; for W+ it spans the first
    # (stylegan_layer + 1) StyleGAN layers.
    support_vectors_dim = G.dim_z
    if ('stylegan' in gan) and (stylegan_space == 'W+'):
        support_vectors_dim *= (stylegan_layer + 1)
|
|
|
LSS = SupportSets(num_support_sets=len(semantic_dipoles), |
|
num_support_dipoles=args_json.__dict__["num_latent_support_dipoles"], |
|
support_vectors_dim=support_vectors_dim, |
|
jung_radius=1) |
|
|
|
|
|
if args.verbose: |
|
print(" \\__Pre-trained weights: {}".format(latent_support_sets_model)) |
|
LSS.load_state_dict(torch.load(latent_support_sets_model, map_location=lambda storage, loc: storage)) |
|
if args.verbose: |
|
print(" \\__Set to evaluation mode") |
|
LSS.eval() |
|
|
|
|
|
if use_cuda: |
|
LSS = LSS.cuda() |
|
|
|
|
|
num_gen_paths = LSS.num_support_sets |
|
|
|
|
|
out_dir = osp.join(args.exp, 'results', args.pool + model_iter, |
|
'{}_{}_{}'.format(2 * args.shift_steps, args.eps, round(2 * args.shift_steps * args.eps, 3))) |
|
os.makedirs(out_dir, exist_ok=True) |
|
|
|
|
|
    # With --shift-leap=1, a full path holds 2 * args.shift_steps + 1 images
    # (args.shift_steps per direction, plus the original). If no batch size is
    # given, generate a whole path per batch.
    if args.batch_size is None:
        args.batch_size = 2 * args.shift_steps + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if args.verbose: |
|
print("#. Use latent codes from pool {}...".format(args.pool)) |
|
latent_codes_dirs = [dI for dI in os.listdir(pool) if os.path.isdir(os.path.join(pool, dI))] |
|
latent_codes_dirs.sort() |
|
latent_codes_list = [torch.load(osp.join(pool, subdir, 'latent_code_{}.pt'.format('w+' if args.w_space else 'z')), |
|
map_location=lambda storage, loc: storage) for subdir in latent_codes_dirs] |
|
|
|
|
|
xs = torch.cat(latent_codes_list) |
|
if use_cuda: |
|
xs = xs.cuda() |
|
num_of_latent_codes = xs.size()[0] |
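
    # Note (typical shapes, assuming pools created by `sample_gan.py`): xs is
    # (num_codes, G.dim_z) for Z-space pools and (num_codes, num_layers, 512)
    # for W+-space pools.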
|
|
|
|
|
|
|
|
|
|
|
|
|
if args.verbose: |
|
print("#. Traverse latent space...") |
|
print(" \\__Experiment : {}".format(osp.basename(osp.abspath(args.exp)))) |
|
print(" \\__Number of test latent codes : {}".format(num_of_latent_codes)) |
|
print(" \\__Test latent codes shape : {}".format(xs.shape)) |
|
print(" \\__Shift magnitude : {}".format(args.eps)) |
|
print(" \\__Shift steps : {}".format(2 * args.shift_steps)) |
|
print(" \\__Traversal length : {}".format(round(2 * args.shift_steps * args.eps, 3))) |
|
|
|
|
|
all_paths_latent_codes = [] |
|
|
|
|
|
for i in range(num_of_latent_codes): |
|
|
|
x_ = xs[i, :].unsqueeze(0) |
|
|
|
latent_code_hash = latent_codes_dirs[i] |
|
if args.verbose: |
|
update_progress(" \\__.Latent code hash: {} [{:03d}/{:03d}] ".format(latent_code_hash, |
|
i+1, |
|
num_of_latent_codes), |
|
num_of_latent_codes, i) |
|
|
|
|
|
|
|
|
|
|
|
latent_code_dir = osp.join(out_dir, '{}'.format(latent_code_hash)) |
|
os.makedirs(latent_code_dir, exist_ok=True) |
|
|
|
|
|
transformed_images_root_dir = osp.join(latent_code_dir, 'paths_images') |
|
os.makedirs(transformed_images_root_dir, exist_ok=True) |
|
transformed_images_strips_root_dir = osp.join(latent_code_dir, 'paths_strips') |
|
os.makedirs(transformed_images_strips_root_dir, exist_ok=True) |
|
|
|
|
|
paths_latent_codes = [] |
|
|
|
|
|
phi_coeffs = dict() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for dim in range(num_gen_paths): |
|
if args.verbose: |
|
print() |
|
update_progress(" \\__path: {:03d}/{:03d} ".format(dim + 1, num_gen_paths), num_gen_paths, dim + 1) |
|
|
|
|
|
transformed_images = [] |
|
|
|
|
|
            # Start the current path at the (possibly W/W+-mapped) latent code,
            # paired with a zero shift for the original image.
            latent_code = x_
            if (not args.w_space) and ('stylegan' in gan) and ('W' in stylegan_space):
                latent_code = G.get_w(x_, truncation=truncation)
                if stylegan_space == 'W':
                    latent_code = latent_code[:, 0, :]
            current_path_latent_codes = [latent_code]
            current_path_latent_shifts = [torch.zeros_like(latent_code).cuda() if use_cuda
                                          else torch.zeros_like(latent_code)]
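
            # ----------------------------------------------------------------
            # Traverse the path in the POSITIVE direction: starting from the
            # original code, repeatedly apply a shift of magnitude +eps along
            # the dim-th warped direction.
            # ----------------------------------------------------------------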
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
latent_code = x_.clone() |
|
if (not args.w_space) and ('stylegan' in gan) and ('W' in stylegan_space): |
|
latent_code = G.get_w(x_, truncation=truncation).clone() |
|
if stylegan_space == 'W': |
|
latent_code = latent_code[:, 0, :] |
|
|
|
cnt = 0 |
|
for k in range(args.shift_steps): |
|
cnt += 1 |
|
|
|
|
|
                # One-hot mask selecting the dim-th support set (path).
                support_sets_mask = torch.zeros(1, LSS.num_support_sets)
                support_sets_mask[0, dim] = 1.0
                if use_cuda:
                    # `.cuda()` is not in-place -- re-assign the returned tensor.
                    support_sets_mask = support_sets_mask.cuda()
|
|
|
|
|
if ('stylegan' in gan) and (stylegan_space == 'W+'): |
|
with torch.no_grad(): |
|
shift = args.eps * LSS(support_sets_mask, |
|
latent_code[:, :stylegan_layer + 1, :].reshape(latent_code.shape[0], -1)) |
|
latent_code = latent_code + \ |
|
F.pad(input=shift, pad=(0, (STYLEGAN_LAYERS[gan] - 1 - stylegan_layer) * 512), |
|
mode='constant', value=0).reshape_as(latent_code) |
|
current_path_latent_code = latent_code |
|
else: |
|
with torch.no_grad(): |
|
shift = args.eps * LSS(support_sets_mask, latent_code) |
|
latent_code = latent_code + shift |
|
current_path_latent_code = latent_code |
|
|
|
|
|
|
|
|
|
|
|
|
|
if cnt == args.shift_leap: |
|
if ('stylegan' in gan) and (stylegan_space == 'W+'): |
|
current_path_latent_shifts.append( |
|
F.pad(input=shift, pad=(0, (STYLEGAN_LAYERS[gan] - 1 - stylegan_layer) * 512), |
|
mode='constant', value=0).reshape_as(latent_code)) |
|
else: |
|
current_path_latent_shifts.append(shift) |
|
current_path_latent_codes.append(current_path_latent_code) |
|
cnt = 0 |
|
positive_endpoint = latent_code.clone().reshape(1, -1) |
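
            # ----------------------------------------------------------------
            # Traverse the path in the NEGATIVE direction: start again from
            # the original code and apply shifts of magnitude -eps, prepending
            # results so codes stay ordered from negative to positive.
            # ----------------------------------------------------------------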
|
|
|
|
|
|
|
|
|
|
|
latent_code = x_.clone() |
|
if (not args.w_space) and ('stylegan' in gan) and ('W' in stylegan_space): |
|
latent_code = G.get_w(x_, truncation=truncation).clone() |
|
if stylegan_space == 'W': |
|
latent_code = latent_code[:, 0, :] |
|
cnt = 0 |
|
for k in range(args.shift_steps): |
|
cnt += 1 |
|
|
|
                support_sets_mask = torch.zeros(1, LSS.num_support_sets)
                support_sets_mask[0, dim] = 1.0
                if use_cuda:
                    # `.cuda()` is not in-place -- re-assign the returned tensor.
                    support_sets_mask = support_sets_mask.cuda()
|
|
|
|
|
if ('stylegan' in gan) and (stylegan_space == 'W+'): |
|
with torch.no_grad(): |
|
shift = -args.eps * LSS( |
|
support_sets_mask, latent_code[:, :stylegan_layer + 1, :].reshape(latent_code.shape[0], -1)) |
|
latent_code = latent_code + \ |
|
F.pad(input=shift, pad=(0, (STYLEGAN_LAYERS[gan] - 1 - stylegan_layer) * 512), |
|
mode='constant', value=0).reshape_as(latent_code) |
|
current_path_latent_code = latent_code |
|
else: |
|
with torch.no_grad(): |
|
shift = -args.eps * LSS(support_sets_mask, latent_code) |
|
latent_code = latent_code + shift |
|
current_path_latent_code = latent_code |
|
|
|
|
|
|
|
|
|
|
|
|
|
if cnt == args.shift_leap: |
|
if ('stylegan' in gan) and (stylegan_space == 'W+'): |
|
current_path_latent_shifts = \ |
|
[F.pad(input=shift, pad=(0, (STYLEGAN_LAYERS[gan] - 1 - stylegan_layer) * 512), |
|
mode='constant', value=0).reshape_as(latent_code)] + current_path_latent_shifts |
|
else: |
|
current_path_latent_shifts = [shift] + current_path_latent_shifts |
|
current_path_latent_codes = [current_path_latent_code] + current_path_latent_codes |
|
cnt = 0 |
|
negative_endpoint = latent_code.clone().reshape(1, -1) |
|
|
|
|
|
|
|
|
|
|
|
            # phi measures the "straightness" of the warped path: the distance
            # between its two endpoints divided by its nominal length
            # (2 * shift_steps * eps); phi == 1 for a perfectly straight path.
            phi = torch.norm(negative_endpoint - positive_endpoint, dim=1).item() / (2 * args.shift_steps * args.eps)
            phi_coeffs.update({dim: phi})
|
|
|
|
|
|
|
            # Generate images along the current path in batches; the image at
            # step t is produced from the stored code plus its paired shift.
            current_path_latent_codes = torch.cat(current_path_latent_codes)
            current_path_latent_codes_batches = torch.split(current_path_latent_codes, args.batch_size)
            current_path_latent_shifts = torch.cat(current_path_latent_shifts)
            current_path_latent_shifts_batches = torch.split(current_path_latent_shifts, args.batch_size)
            if len(current_path_latent_codes_batches) != len(current_path_latent_shifts_batches):
                raise AssertionError("Mismatched number of latent code and latent shift batches.")
            num_batches = len(current_path_latent_codes_batches)
|
|
|
transformed_img = [] |
|
for t in range(num_batches): |
|
with torch.no_grad(): |
|
transformed_img.append(G(current_path_latent_codes_batches[t] + |
|
current_path_latent_shifts_batches[t])) |
|
transformed_img = torch.cat(transformed_img) |
|
|
|
|
|
for t in range(transformed_img.shape[0]): |
|
transformed_images.append(tensor2image(transformed_img[t, :].cpu(), |
|
img_size=args.img_size, |
|
adaptive=True)) |
|
|
|
transformed_images_dir = osp.join(transformed_images_root_dir, 'path_{:03d}'.format(dim)) |
|
os.makedirs(transformed_images_dir, exist_ok=True) |
|
|
|
for t in range(len(transformed_images)): |
|
transformed_images[t].save(osp.join(transformed_images_dir, '{:06d}.jpg'.format(t)), |
|
"JPEG", quality=args.img_quality, optimize=True, progressive=True) |
|
|
|
if (t == len(transformed_images) // 2) and (dim == 0): |
|
transformed_images[t].save(osp.join(latent_code_dir, 'original_image.jpg'), |
|
"JPEG", quality=95, optimize=True, progressive=True) |
|
|
|
|
|
transformed_images_strip = create_strip(image_list=transformed_images, N=args.strip_number, |
|
strip_height=args.strip_height) |
|
transformed_images_strip.save(osp.join(transformed_images_strips_root_dir, |
|
'path_{:03d}_strip.jpg'.format(dim)), |
|
"JPEG", quality=args.img_quality, optimize=True, progressive=True) |
|
|
|
|
|
            transformed_images_gif_frames = create_gif(transformed_images, gif_height=args.gif_height)
            # Base (first) frame of the GIF; the traversal frames are appended after it.
            im = Image.new(mode='RGB', size=(2 * args.gif_height, args.gif_height))
|
im.save(fp=osp.join(transformed_images_strips_root_dir, 'path_{:03d}.gif'.format(dim)), |
|
append_images=transformed_images_gif_frames, |
|
save_all=True, |
|
optimize=True, |
|
loop=0, |
|
duration=1000 // args.gif_fps) |
|
|
|
|
|
paths_latent_codes.append(current_path_latent_codes.unsqueeze(0)) |
|
|
|
if args.verbose: |
|
update_stdout(1) |
|
|
|
|
|
|
|
|
|
paths_latent_codes_tensor = torch.cat(paths_latent_codes) |
|
torch.save(paths_latent_codes_tensor, osp.join(latent_code_dir, 'paths_latent_codes.pt')) |
|
all_paths_latent_codes.append(paths_latent_codes_tensor.cpu().numpy()) |
|
|
|
if args.verbose: |
|
update_stdout(1) |
|
print() |
|
print() |
|
|
|
|
|
if args.verbose: |
|
print("Performing t-SNE on latent codes for visualization...") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Visualise the traversal paths of the first latent code only. The stacked
    # tensor has shape (num_paths, num_steps_per_path, *code_dims); flatten
    # each latent code to 1-D so t-SNE sees one row per traversal step.
    all_paths_latent_code_0 = all_paths_latent_codes[0]
    num_paths, num_steps = all_paths_latent_code_0.shape[0], all_paths_latent_code_0.shape[1]
    tsne_latent_codes = all_paths_latent_code_0.reshape(num_paths * num_steps, -1)

    # Reduce to 3 dimensions with t-SNE. Note that perplexity must be smaller
    # than the number of samples, i.e., num_paths * (2 * args.shift_steps + 1).
    tsne_model = TSNE(n_components=3, perplexity=30, learning_rate=200, random_state=42)
    tsne_transformed = tsne_model.fit_transform(tsne_latent_codes)
|
|
|
|
|
    # `visualize_latent_space` recovers the per-path indices from
    # `shift_steps`, so the t-SNE codes can be passed directly.
    tsne_vis_dir = osp.join(out_dir, 'tsne_visualizations')
    visualize_latent_space(
        tsne_latent_codes=tsne_transformed,
        semantic_dipoles=semantic_dipoles,
        output_dir=tsne_vis_dir,
        save_filename="latent_space_tsne.png",
        shift_steps=args.shift_steps
    )
|
|
|
|
|
if args.gif or args.strip: |
|
|
|
|
|
print("#. Write summarizing MD files...") |
|
|
|
|
|
if args.gif: |
|
md_summary_file = osp.join(out_dir, 'results.md') |
|
md_summary_file_f = open(md_summary_file, "w") |
|
md_summary_file_f.write("# Experiment: {}\n".format(args.exp)) |
|
|
|
if args.strip: |
|
md_summary_strips_file = osp.join(out_dir, 'results_strips.md') |
|
md_summary_strips_file_f = open(md_summary_strips_file, "w") |
|
md_summary_strips_file_f.write("# Experiment: {}\n".format(args.exp)) |
|
|
|
if args.gif or args.strip: |
|
for dim in range(num_gen_paths): |
|
|
|
if args.gif: |
|
md_summary_file_f.write("### \"{}\" → \"{}\"\n".format(semantic_dipoles[dim][1], |
|
semantic_dipoles[dim][0])) |
|
md_summary_file_f.write("<p align=\"center\">\n") |
|
if args.strip: |
|
md_summary_strips_file_f.write("## \"{}\" → \"{}\"\n".format(semantic_dipoles[dim][1], |
|
semantic_dipoles[dim][0])) |
|
md_summary_strips_file_f.write("<p align=\"center\">\n") |
|
|
|
for lc in latent_codes_dirs: |
|
if args.gif: |
|
md_summary_file_f.write("<img src=\"{}\" width=\"450\" class=\"center\"/>\n".format( |
|
osp.join(lc, 'paths_strips', 'path_{:03d}.gif'.format(dim)))) |
|
if args.strip: |
|
md_summary_strips_file_f.write("<img src=\"{}\" style=\"width: 75vw\"/>\n".format( |
|
osp.join(lc, 'paths_strips', 'path_{:03d}_strip.jpg'.format(dim)))) |
|
if args.gif: |
|
md_summary_file_f.write("phi={}\n".format(phi_coeffs[dim])) |
|
md_summary_file_f.write("</p>\n") |
|
if args.strip: |
|
md_summary_strips_file_f.write("phi={}\n".format(phi_coeffs[dim])) |
|
md_summary_strips_file_f.write("</p>\n") |
|
|
|
if args.gif: |
|
md_summary_file_f.close() |
|
if args.strip: |
|
md_summary_strips_file_f.close() |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|