import sys
import os
import os.path as osp
import clip
import json
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
import numpy as np
import time
import shutil
from .aux import TrainingStatTracker, update_progress, update_stdout, sec2dhms
from .config import SEMANTIC_DIPOLES_CORPORA, STYLEGAN_LAYERS
class DataParallelPassthrough(nn.DataParallel):
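"""nn.DataParallel wrapper that falls back to the wrapped module's attributes (e.g., `generator.dim_z`) when
they are not found on the DataParallel object itself."""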
def __getattr__(self, name):
try:
return super(DataParallelPassthrough, self).__getattr__(name)
except AttributeError:
return getattr(self.module, name)
class Trainer(object):
def __init__(self, params=None, exp_dir=None, use_cuda=False, multi_gpu=False):
if params is None:
raise ValueError("Cannot build a Trainer instance with empty params: params={}".format(params))
else:
self.params = params
self.use_cuda = use_cuda
self.multi_gpu = multi_gpu
# Set output directory for current experiment (wip)
self.wip_dir = osp.join("experiments", "wip", exp_dir)
os.makedirs(self.wip_dir, exist_ok=True)
# Set directory for completed experiment
self.complete_dir = osp.join("experiments", "complete", exp_dir)
# Define the training statistics file (stats.json) and initialise it if it does not exist
self.stats_json = osp.join(self.wip_dir, 'stats.json')
if not osp.isfile(self.stats_json):
with open(self.stats_json, 'w') as out:
json.dump({}, out)
# Create models sub-directory
self.models_dir = osp.join(self.wip_dir, 'models')
os.makedirs(self.models_dir, exist_ok=True)
# Define checkpoint model file
self.checkpoint = osp.join(self.models_dir, 'checkpoint.pt')
# Array of iteration times
self.iter_times = np.array([])
# Set up training statistics tracker
self.stat_tracker = TrainingStatTracker()
# Define cosine similarity loss
self.cosine_embedding_loss = nn.CosineEmbeddingLoss()
# Define cross entropy loss
self.cross_entropy_loss = nn.CrossEntropyLoss()
# Define transform of CLIP image encoder
self.clip_img_transform = transforms.Compose([transforms.Resize(224),
transforms.CenterCrop(224),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711))])
def contrastive_loss(self, img_batch, txt_batch):
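"""Compute a contrastive loss between a batch of image features and a batch of text features: cross-entropy
over the temperature-scaled similarity matrix of the L2-normalised features, where the i-th image feature is
matched to the i-th text feature.
Args:
img_batch (torch.Tensor) : batch of image features
txt_batch (torch.Tensor) : batch of text features
Returns:
loss (torch.Tensor) : contrastive loss
"""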
n_img, d_img = img_batch.shape
n_txt, d_txt = txt_batch.shape
# Image and text feature batches must have the same shape
assert n_img == n_txt and d_img == d_txt, "image/text feature batches must have the same shape"
# Normalise image and text batches
img_batch_l2 = F.normalize(img_batch, p=2, dim=-1)
txt_batch_l2 = F.normalize(txt_batch, p=2, dim=-1)
# Calculate inner product similarity matrix
similarity_matrix = torch.matmul(img_batch_l2, txt_batch_l2.T)
labels = torch.arange(n_img, device=similarity_matrix.device)
return self.cross_entropy_loss(similarity_matrix / self.params.temperature, labels)
def get_starting_iteration(self, latent_support_sets):
"""Check if checkpoint file exists (under `self.models_dir`) and set starting iteration at the checkpoint
iteration; also load checkpoint weights to `latent_support_sets`. Otherwise, set starting iteration to 1 in
order to train from scratch.
Returns:
starting_iter (int): starting iteration
"""
starting_iter = 1
if osp.isfile(self.checkpoint):
checkpoint_dict = torch.load(self.checkpoint)
starting_iter = checkpoint_dict['iter']
latent_support_sets.load_state_dict(checkpoint_dict['latent_support_sets'])
return starting_iter
def log_progress(self, iteration, mean_iter_time, elapsed_time, eta):
"""Log progress (loss + ETA).
Args:
iteration (int) : current iteration
mean_iter_time (float) : mean iteration time
elapsed_time (float) : elapsed time until current iteration
eta (float) : estimated time of experiment completion
"""
# Get current training stats (for the previous `self.params.log_freq` steps) and flush them
stats = self.stat_tracker.get_means()
# Update training statistics json file
with open(self.stats_json) as f:
stats_dict = json.load(f)
stats_dict.update({iteration: stats})
with open(self.stats_json, 'w') as out:
json.dump(stats_dict, out)
# Flush training statistics tracker
self.stat_tracker.flush()
update_progress(" \\__.Training [bs: {}] [iter: {:06d}/{:06d}] ".format(
self.params.batch_size, iteration, self.params.max_iter), self.params.max_iter, iteration + 1)
if iteration < self.params.max_iter - 1:
print()
print(" ===================================================================")
print(" \\__Loss : {:.08f}".format(stats['loss']))
print(" ===================================================================")
print(" \\__Mean iter time : {:.3f} sec".format(mean_iter_time))
print(" \\__Elapsed time : {}".format(sec2dhms(elapsed_time)))
print(" \\__ETA : {}".format(sec2dhms(eta)))
print(" ===================================================================")
update_stdout(8)
def train(self, generator, latent_support_sets, corpus_support_sets, clip_model):
"""GANxPlainer training function.
Args:
generator : non-trainable (pre-trained) GAN generator
latent_support_sets : trainable LSS model -- interpretable latent paths model
corpus_support_sets : non-trainable CSS model -- non-linear paths in the CLIP space
clip_model : non-trainable (pre-trained) CLIP model
"""
# Save initial `latent_support_sets` model as `latent_support_sets_init.pt`
torch.save(latent_support_sets.state_dict(), osp.join(self.models_dir, 'latent_support_sets_init.pt'))
# Save initial `corpus_support_sets` model as `corpus_support_sets_init.pt`
torch.save(corpus_support_sets.state_dict(), osp.join(self.models_dir, 'corpus_support_sets_init.pt'))
# Save prompt corpus list to json
with open(osp.join(self.models_dir, 'semantic_dipoles.json'), 'w') as json_f:
json.dump(SEMANTIC_DIPOLES_CORPORA[self.params.corpus], json_f)
# Move models to the GPU if `self.use_cuda` is set (i.e., if `args.cuda` and `torch.cuda.is_available()` are True).
if self.use_cuda:
generator.cuda().eval()
clip_model.cuda().eval()
corpus_support_sets.cuda()
latent_support_sets.cuda().train()
else:
generator.eval()
clip_model.eval()
latent_support_sets.train()
# Set latent support sets (LSS) optimizer
latent_support_sets_optim = torch.optim.Adam(latent_support_sets.parameters(), lr=self.params.lr)
# Set learning rate scheduler -- reduce lr after 90% of the total number of training iterations
latent_support_sets_lr_scheduler = StepLR(optimizer=latent_support_sets_optim,
step_size=int(0.9 * self.params.max_iter),
gamma=0.1)
# Get starting iteration
starting_iter = self.get_starting_iteration(latent_support_sets)
# Parallelize models into multiple GPUs, if available and `multi_gpu=True`.
if self.multi_gpu:
print("#. Parallelize G and CLIP over {} GPUs...".format(torch.cuda.device_count()))
# Parallelize generator G
generator = DataParallelPassthrough(generator)
# Parallelize CLIP model
clip_model = DataParallelPassthrough(clip_model)
# Check starting iteration
if starting_iter == self.params.max_iter:
print("#. This experiment has already been completed and can be found @ {}".format(self.wip_dir))
print("#. Copy {} to {}...".format(self.wip_dir, self.complete_dir))
try:
shutil.copytree(src=self.wip_dir, dst=self.complete_dir, ignore=shutil.ignore_patterns('checkpoint.pt'))
print(" \\__Done!")
except IOError as e:
print(" \\__Already exists -- {}".format(e))
sys.exit()
print("#. Start training from iteration {}".format(starting_iter))
# Get experiment's start time
t0 = time.time()
# Start training
for iteration in range(starting_iter, self.params.max_iter + 1):
# Get current iteration's start time
iter_t0 = time.time()
# Set gradients to zero
generator.zero_grad()
latent_support_sets.zero_grad()
clip_model.zero_grad()
# Sample latent codes from standard Gaussian
z = torch.randn(self.params.batch_size, generator.dim_z)
if self.use_cuda:
z = z.cuda()
# Generate images for the given latent codes
latent_code = z
if 'stylegan' in self.params.gan:
if self.params.stylegan_space == 'W':
latent_code = generator.get_w(z, truncation=self.params.truncation)[:, 0, :]
elif self.params.stylegan_space == 'W+':
latent_code = generator.get_w(z, truncation=self.params.truncation)
img = generator(latent_code)
# Sample indices of shift vectors (`self.params.batch_size` out of `self.params.num_support_sets`)
# target_support_sets_indices = torch.randint(0, self.params.num_support_sets, [self.params.batch_size])
target_support_sets_indices = torch.randint(0, latent_support_sets.num_support_sets,
[self.params.batch_size])
if self.use_cuda:
target_support_sets_indices = target_support_sets_indices.cuda()
# Sample shift magnitudes from uniform distributions
# U[self.params.min_shift_magnitude, self.params.max_shift_magnitude], and
# U[-self.params.max_shift_magnitude, self.params.min_shift_magnitude]
# Create a pool of shift magnitudes of 2 * `self.params.batch_size` shifts (half negative, half positive)
# and sample `self.params.batch_size` of them
shift_magnitudes_pos = (self.params.min_shift_magnitude - self.params.max_shift_magnitude) * \
torch.rand(target_support_sets_indices.size()) + self.params.max_shift_magnitude
shift_magnitudes_neg = (self.params.min_shift_magnitude - self.params.max_shift_magnitude) * \
torch.rand(target_support_sets_indices.size()) - self.params.min_shift_magnitude
shift_magnitudes_pool = torch.cat((shift_magnitudes_neg, shift_magnitudes_pos))
# Sample `self.params.batch_size` shift magnitudes uniformly at random (without replacement) from the pool
shift_magnitudes_weights = torch.ones(len(shift_magnitudes_pool))
target_shift_magnitudes = shift_magnitudes_pool[torch.multinomial(input=shift_magnitudes_weights,
num_samples=self.params.batch_size,
replacement=False)]
if self.use_cuda:
target_shift_magnitudes = target_shift_magnitudes.cuda()
# Create support sets mask of size (batch_size, num_support_sets) in the form:
# support_sets_mask[i] = [0, ..., 0, 1, 0, ..., 0]
support_sets_mask = torch.zeros([self.params.batch_size, latent_support_sets.num_support_sets])
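# Create a prompt mask of size (batch_size, 2) and a prompt sign vector of size (batch_size, 1): a positive
# shift magnitude selects the first prompt of the sampled semantic dipole (sign +1.0), while a negative one
# selects the second prompt (sign -1.0)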
prompt_mask = torch.zeros([self.params.batch_size, 2])
prompt_sign = torch.zeros([self.params.batch_size, 1])
if self.use_cuda:
support_sets_mask = support_sets_mask.cuda()
prompt_mask = prompt_mask.cuda()
prompt_sign = prompt_sign.cuda()
for i, (index, val) in enumerate(zip(target_support_sets_indices, target_shift_magnitudes)):
support_sets_mask[i][index] += 1.0
if val >= 0:
prompt_mask[i, 0] = 1.0
prompt_sign[i] = +1.0
else:
prompt_mask[i, 1] = 1.0
prompt_sign[i] = -1.0
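# Unsqueeze the prompt mask to size (batch_size, 1, 2) so that a single prompt's text features can be
# selected per sample via batched matrix multiplication (see the StyleCLIP branch below)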
prompt_mask = prompt_mask.unsqueeze(1)
# Calculate shift vectors for the given latent codes -- in the case of StyleGAN, shifts live in the
# self.params.stylegan_space, i.e., in Z-, W-, or W+-space. In the Z-/W-space the dimensionality of the
# latent space is 512. In the case of W+-space, the dimensionality is 512 * (self.params.stylegan_layer + 1)
if ('stylegan' in self.params.gan) and (self.params.stylegan_space == 'W+'):
shift = target_shift_magnitudes.reshape(-1, 1) * latent_support_sets(
support_sets_mask, latent_code[:, :self.params.stylegan_layer + 1, :].reshape(latent_code.shape[0],
-1))
else:
shift = target_shift_magnitudes.reshape(-1, 1) * latent_support_sets(support_sets_mask, latent_code)
# Generate images for the shifted latent codes
if ('stylegan' in self.params.gan) and (self.params.stylegan_space == 'W+'):
latent_code_reshaped = latent_code.reshape(latent_code.shape[0], -1)
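# Zero-pad the shift vector so that it covers all STYLEGAN_LAYERS[self.params.gan] layers of the W+ code;
# only the first `self.params.stylegan_layer + 1` layers are actually shifted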
shift = F.pad(input=shift,
pad=(0, (STYLEGAN_LAYERS[self.params.gan] - 1 - self.params.stylegan_layer) * 512),
mode='constant',
value=0)
latent_code_shifted = latent_code_reshaped + shift
latent_code_shifted_reshaped = latent_code_shifted.reshape_as(latent_code)
img_shifted = generator(latent_code_shifted_reshaped)
else:
img_shifted = generator(latent_code + shift)
# Encode the original and the shifted images with the CLIP image encoder (as a single batch) and compute
# the difference of their features, i.e., the image transition direction in the CLIP space
img_pairs = torch.cat([self.clip_img_transform(img), self.clip_img_transform(img_shifted)], dim=0)
clip_img_pairs_features = clip_model.encode_image(img_pairs)
clip_img_features, clip_img_shifted_features = torch.split(clip_img_pairs_features, img.shape[0], dim=0)
clip_img_diff_features = clip_img_shifted_features - clip_img_features
############################################################################################################
## ##
## Linear Text Paths (StyleCLIP approach) ##
## ##
############################################################################################################
if self.params.styleclip:
corpus_text_features_batch = torch.matmul(support_sets_mask, corpus_support_sets.SUPPORT_SETS).reshape(
-1, 2 * corpus_support_sets.num_support_dipoles, corpus_support_sets.support_vectors_dim)
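# Select, for each sample, the text features of the dipole prompt indicated by `prompt_mask`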
corpus_text_features_batch = torch.matmul(prompt_mask, corpus_text_features_batch).squeeze(1)
# Calculate cosine similarity loss
if self.params.loss == 'cossim':
loss = self.cosine_embedding_loss(clip_img_shifted_features, corpus_text_features_batch,
torch.ones(corpus_text_features_batch.shape[0]).to(
'cuda' if self.use_cuda else 'cpu'))
# Calculate contrastive loss
elif self.params.loss == 'contrastive':
loss = self.contrastive_loss(clip_img_shifted_features.float(), corpus_text_features_batch)
############################################################################################################
## ##
## Linear Text Paths ##
## ##
############################################################################################################
elif self.params.linear:
corpus_text_features_batch = torch.matmul(support_sets_mask, corpus_support_sets.SUPPORT_SETS).reshape(
-1, 2 * corpus_support_sets.num_support_dipoles, corpus_support_sets.support_vectors_dim)
# Calculate cosine similarity loss
if self.params.loss == 'cossim':
loss = self.cosine_embedding_loss(clip_img_diff_features, prompt_sign * (
corpus_text_features_batch[:, 0, :] - corpus_text_features_batch[:, 1, :]) -
clip_img_features,
torch.ones(corpus_text_features_batch.shape[0]).to(
'cuda' if self.use_cuda else 'cpu'))
# Calculate contrastive loss
elif self.params.loss == 'contrastive':
loss = self.contrastive_loss(clip_img_diff_features.float(), prompt_sign * (
corpus_text_features_batch[:, 0, :] - corpus_text_features_batch[:, 1, :]) -
clip_img_features)
############################################################################################################
## ##
## Non-linear Text Paths ##
## ##
############################################################################################################
else:
# Calculate local text direction using CSS
local_text_directions = target_shift_magnitudes.reshape(-1, 1) * corpus_support_sets(support_sets_mask,
clip_img_features)
# Calculate cosine similarity loss
if self.params.loss == 'cossim':
loss = self.cosine_embedding_loss(clip_img_diff_features, local_text_directions,
torch.ones(local_text_directions.shape[0]).to(
'cuda' if self.use_cuda else 'cpu'))
# Calculate contrastive loss
elif self.params.loss == 'contrastive':
loss = self.contrastive_loss(img_batch=clip_img_diff_features.float(),
txt_batch=local_text_directions)
# Back-propagate!
loss.backward()
# Update weights
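# Cast the CLIP model to full precision for the update step and convert its weights back to half precision
# afterwards (CLIP is loaded in fp16 when running on GPU)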
clip_model.float()
latent_support_sets_optim.step()
latent_support_sets_lr_scheduler.step()
clip.model.convert_weights(clip_model)
# Update statistics tracker
self.stat_tracker.update(loss=loss.item())
# Get time of completion of current iteration
iter_t = time.time()
# Compute elapsed time for current iteration and append to `iter_times`
self.iter_times = np.append(self.iter_times, iter_t - iter_t0)
# Compute elapsed time so far
elapsed_time = iter_t - t0
# Compute rolling mean iteration time
mean_iter_time = self.iter_times.mean()
# Compute estimated time of experiment completion
eta = elapsed_time * ((self.params.max_iter - iteration) / (iteration - starting_iter + 1))
# Log progress in stdout
if iteration % self.params.log_freq == 0:
self.log_progress(iteration, mean_iter_time, elapsed_time, eta)
# Save checkpoint model file and latent support_sets model state dicts after current iteration
if iteration % self.params.ckp_freq == 0:
# Build checkpoint dict
checkpoint_dict = {
'iter': iteration,
'latent_support_sets': latent_support_sets.state_dict(),
}
torch.save(checkpoint_dict, self.checkpoint)
# === End of training loop ===
# Get experiment's total elapsed time
elapsed_time = time.time() - t0
# Save final latent support sets (LSS) model
latent_support_sets_model_filename = osp.join(self.models_dir, 'latent_support_sets.pt')
torch.save(latent_support_sets.state_dict(), latent_support_sets_model_filename)
for _ in range(10):
print()
print("#.Training completed -- Total elapsed time: {}.".format(sec2dhms(elapsed_time)))
print("#. Copy {} to {}...".format(self.wip_dir, self.complete_dir))
try:
shutil.copytree(src=self.wip_dir, dst=self.complete_dir)
print(" \\__Done!")
except IOError as e:
print(" \\__Already exists -- {}".format(e))