File size: 22,287 Bytes
8c212a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 |
import sys
import os
import os.path as osp
import clip
import json
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
import numpy as np
import time
import shutil
from .aux import TrainingStatTracker, update_progress, update_stdout, sec2dhms
from .config import SEMANTIC_DIPOLES_CORPORA, STYLEGAN_LAYERS
class DataParallelPassthrough(nn.DataParallel):
    """`nn.DataParallel` wrapper that forwards unknown attributes to the wrapped module.

    Attribute lookups that `nn.DataParallel` itself cannot resolve are retried on `self.module`, so custom
    attributes and methods of the wrapped model stay reachable through the wrapper.
    """

    def __getattr__(self, name):
        # First give the regular `nn.Module` lookup (parameters, buffers, sub-modules) a chance...
        parent_lookup = super().__getattr__
        try:
            attr = parent_lookup(name)
        except AttributeError:
            # ...and only then fall back to the wrapped module.
            attr = getattr(self.module, name)
        return attr
class Trainer(object):
    """Training driver for the latent support sets (LSS) model.

    Owns the experiment directories (`experiments/wip/<exp_dir>` while training and
    `experiments/complete/<exp_dir>` once finished), the training statistics file (`stats.json`), the
    checkpoint file (`models/checkpoint.pt`) and the loss functions used by `train`.
    """

    def __init__(self, params=None, exp_dir=None, use_cuda=False, multi_gpu=False):
        """Set up experiment directories, bookkeeping files and loss functions.

        Args:
            params    : object exposing the training hyper-parameters as attributes (e.g. `batch_size`,
                        `max_iter`, `lr`, `temperature`, `loss`, `log_freq`, `ckp_freq`, ...)
            exp_dir   : experiment directory name (joined under `experiments/wip` and `experiments/complete`)
            use_cuda  : move models and tensors to the GPU
            multi_gpu : wrap the generator and the CLIP model in `DataParallelPassthrough`

        Raises:
            ValueError: if `params` is None
        """
        if params is None:
            raise ValueError("Cannot build a Trainer instance with empty params: params={}".format(params))

        self.params = params
        self.use_cuda = use_cuda
        self.multi_gpu = multi_gpu

        # Set output directory for current experiment (wip)
        self.wip_dir = osp.join("experiments", "wip", exp_dir)

        # Set directory for completed experiment
        self.complete_dir = osp.join("experiments", "complete", exp_dir)

        # Create models sub-directory. This also creates `self.wip_dir`, so it must happen before `stats.json`
        # is written below.
        self.models_dir = osp.join(self.wip_dir, 'models')
        os.makedirs(self.models_dir, exist_ok=True)

        # Define stats.json file (initialised to an empty dict on first use)
        self.stats_json = osp.join(self.wip_dir, 'stats.json')
        if not osp.isfile(self.stats_json):
            with open(self.stats_json, 'w') as out:
                json.dump({}, out)

        # Define checkpoint model file
        self.checkpoint = osp.join(self.models_dir, 'checkpoint.pt')

        # Array of iteration times
        self.iter_times = np.array([])

        # Set up training statistics tracker
        self.stat_tracker = TrainingStatTracker()

        # Define cosine similarity loss
        self.cosine_embedding_loss = nn.CosineEmbeddingLoss()

        # Define cross entropy loss
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        # Define transform of CLIP image encoder
        self.clip_img_transform = transforms.Compose([transforms.Resize(224),
                                                      transforms.CenterCrop(224),
                                                      transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                                                           (0.26862954, 0.26130258, 0.27577711))])

    def contrastive_loss(self, img_batch, txt_batch):
        """Cross-entropy over the temperature-scaled cosine similarities between image and text features.

        The i-th row of `img_batch` is treated as matching the i-th row of `txt_batch`; every other row of
        `txt_batch` acts as a negative.

        Args:
            img_batch (torch.Tensor): 2-D tensor of image features, one row per sample
            txt_batch (torch.Tensor): 2-D tensor of text features, one row per sample

        Returns:
            torch.Tensor: scalar loss

        Raises:
            ValueError: if the feature dimensions differ, or if there are fewer text rows than image rows
        """
        n_img, d_img = img_batch.shape
        n_txt, d_txt = txt_batch.shape
        if d_img != d_txt or n_txt < n_img:
            raise ValueError("Incompatible batches for contrastive loss: img_batch={}, txt_batch={}".format(
                tuple(img_batch.shape), tuple(txt_batch.shape)))

        # Normalise image and text batches
        img_batch_l2 = F.normalize(img_batch, p=2, dim=-1)
        txt_batch_l2 = F.normalize(txt_batch, p=2, dim=-1)

        # Calculate inner product similarity matrix
        similarity_matrix = torch.matmul(img_batch_l2, txt_batch_l2.T)

        # Targets must live on the same device as the logits
        labels = torch.arange(n_img, device=similarity_matrix.device)

        return self.cross_entropy_loss(similarity_matrix / self.params.temperature, labels)

    def get_starting_iteration(self, latent_support_sets):
        """Check if checkpoint file exists (under `self.models_dir`) and set starting iteration at the checkpoint
        iteration; also load checkpoint weights to `latent_support_sets`. Otherwise, set starting iteration to 1 in
        order to train from scratch.

        Returns:
            starting_iter (int): starting iteration
        """
        starting_iter = 1
        if osp.isfile(self.checkpoint):
            # Load to CPU so that a checkpoint written on a GPU machine can be resumed anywhere;
            # `load_state_dict` copies the values into the model's existing parameters.
            checkpoint_dict = torch.load(self.checkpoint, map_location='cpu')
            starting_iter = checkpoint_dict['iter']
            latent_support_sets.load_state_dict(checkpoint_dict['latent_support_sets'])

        return starting_iter

    def log_progress(self, iteration, mean_iter_time, elapsed_time, eta):
        """Log progress (loss + ETA).

        Args:
            iteration (int)        : current iteration
            mean_iter_time (float) : mean iteration time
            elapsed_time (float)   : elapsed time until current iteration
            eta (float)            : estimated time of experiment completion
        """
        # Get current training stats (for the previous `self.params.log_freq` steps) and flush them
        stats = self.stat_tracker.get_means()

        # Update training statistics json file. JSON object keys are always strings, so use a string key here as
        # well: after a resume this overwrites the existing entry instead of emitting a duplicate key.
        with open(self.stats_json) as f:
            stats_dict = json.load(f)
        stats_dict[str(iteration)] = stats
        with open(self.stats_json, 'w') as out:
            json.dump(stats_dict, out)

        # Flush training statistics tracker
        self.stat_tracker.flush()

        update_progress(" \\__.Training [bs: {}] [iter: {:06d}/{:06d}] ".format(
            self.params.batch_size, iteration, self.params.max_iter), self.params.max_iter, iteration + 1)
        if iteration < self.params.max_iter - 1:
            print()
        print(" ===================================================================")
        print(" \\__Loss : {:.08f}".format(stats['loss']))
        print(" ===================================================================")
        print(" \\__Mean iter time : {:.3f} sec".format(mean_iter_time))
        print(" \\__Elapsed time : {}".format(sec2dhms(elapsed_time)))
        print(" \\__ETA : {}".format(sec2dhms(eta)))
        print(" ===================================================================")
        update_stdout(8)

    def _sample_shift_magnitudes(self, num_samples):
        """Sample `num_samples` signed shift magnitudes (returned as a CPU tensor).

        A pool of 2 * `num_samples` magnitudes is built -- half with absolute value between
        `self.params.min_shift_magnitude` and `self.params.max_shift_magnitude` and negative sign, half
        positive -- and `num_samples` of them are drawn uniformly without replacement.
        """
        min_magnitude = self.params.min_shift_magnitude
        max_magnitude = self.params.max_shift_magnitude

        shift_magnitudes_pos = (min_magnitude - max_magnitude) * torch.rand(num_samples) + max_magnitude
        shift_magnitudes_neg = (min_magnitude - max_magnitude) * torch.rand(num_samples) - min_magnitude
        shift_magnitudes_pool = torch.cat((shift_magnitudes_neg, shift_magnitudes_pos))

        # Uniform draw via a random permutation. Weighting the draw by the pool index instead would never pick
        # the first entry and would favour the positive half of the pool.
        return shift_magnitudes_pool[torch.randperm(len(shift_magnitudes_pool))[:num_samples]]

    @staticmethod
    def _build_masks(target_support_sets_indices, target_shift_magnitudes, num_support_sets):
        """Build the one-hot support set mask and the prompt mask/sign for a batch.

        Args:
            target_support_sets_indices (torch.Tensor): (batch_size,) integer indices of the sampled support sets
            target_shift_magnitudes (torch.Tensor)    : (batch_size,) signed shift magnitudes
            num_support_sets (int)                    : total number of support sets

        Returns:
            support_sets_mask : (batch_size, num_support_sets), support_sets_mask[i] = [0, ..., 0, 1, 0, ..., 0]
            prompt_mask       : (batch_size, 1, 2), [1, 0] for non-negative shifts and [0, 1] for negative ones
            prompt_sign       : (batch_size, 1), +1 for non-negative shifts and -1 for negative ones
        """
        batch_size = target_support_sets_indices.shape[0]
        device = target_support_sets_indices.device

        support_sets_mask = torch.zeros([batch_size, num_support_sets], device=device)
        support_sets_mask[torch.arange(batch_size, device=device), target_support_sets_indices] = 1.0

        is_positive = target_shift_magnitudes >= 0

        prompt_mask = torch.zeros([batch_size, 2], device=device)
        prompt_mask[is_positive, 0] = 1.0
        prompt_mask[~is_positive, 1] = 1.0

        prompt_sign = torch.ones([batch_size, 1], device=device)
        prompt_sign[~is_positive] = -1.0

        return support_sets_mask, prompt_mask.unsqueeze(1), prompt_sign

    def _paired_loss(self, img_features, txt_features):
        """Compute the loss selected by `self.params.loss` between paired image and text features.

        Raises:
            ValueError: if `self.params.loss` is neither 'cossim' nor 'contrastive'
        """
        # Calculate cosine similarity loss (all pairs are positive pairs, hence a target of ones)
        if self.params.loss == 'cossim':
            targets = torch.ones(txt_features.shape[0], device='cuda' if self.use_cuda else 'cpu')
            return self.cosine_embedding_loss(img_features, txt_features, targets)

        # Calculate contrastive loss
        if self.params.loss == 'contrastive':
            return self.contrastive_loss(img_batch=img_features.float(), txt_batch=txt_features)

        raise ValueError("Unknown loss: params.loss={}".format(self.params.loss))

    def _copy_to_complete(self, ignore=None):
        """Copy the wip experiment directory to the complete directory (best effort, reports if it exists)."""
        print("#. Copy {} to {}...".format(self.wip_dir, self.complete_dir))
        try:
            shutil.copytree(src=self.wip_dir, dst=self.complete_dir, ignore=ignore)
            print(" \\__Done!")
        except IOError as e:
            print(" \\__Already exists -- {}".format(e))

    def train(self, generator, latent_support_sets, corpus_support_sets, clip_model):
        """GANxPlainer training function.

        Args:
            generator           : non-trainable (pre-trained) GAN generator
            latent_support_sets : trainable LSS model -- interpretable latent paths model
            corpus_support_sets : non-trainable CSS model -- non-linear paths in the CLIP space
            clip_model          : non-trainable (pre-trained) CLIP model

        Raises:
            ValueError: if `self.params.loss` is neither 'cossim' nor 'contrastive'
        """
        # Save initial `latent_support_sets` model as `latent_support_sets_init.pt`
        torch.save(latent_support_sets.state_dict(), osp.join(self.models_dir, 'latent_support_sets_init.pt'))

        # Save initial `corpus_support_sets` model as `corpus_support_sets_init.pt`
        torch.save(corpus_support_sets.state_dict(), osp.join(self.models_dir, 'corpus_support_sets_init.pt'))

        # Save prompt corpus list to json
        with open(osp.join(self.models_dir, 'semantic_dipoles.json'), 'w') as json_f:
            json.dump(SEMANTIC_DIPOLES_CORPORA[self.params.corpus], json_f)

        # Upload models to GPU if `self.use_cuda` is set (i.e., if args.cuda and torch.cuda.is_available is True).
        if self.use_cuda:
            generator.cuda().eval()
            clip_model.cuda().eval()
            corpus_support_sets.cuda()
            latent_support_sets.cuda().train()
        else:
            generator.eval()
            clip_model.eval()
            latent_support_sets.train()

        # Set latent support sets (LSS) optimizer
        latent_support_sets_optim = torch.optim.Adam(latent_support_sets.parameters(), lr=self.params.lr)

        # Set learning rate scheduler -- reduce lr after 90% of the total number of training iterations
        latent_support_sets_lr_scheduler = StepLR(optimizer=latent_support_sets_optim,
                                                  step_size=int(0.9 * self.params.max_iter),
                                                  gamma=0.1)

        # Get starting iteration
        # NOTE(review): the checkpoint stores only the iteration and the LSS weights, so the optimizer and
        # scheduler above start fresh on resume, and the checkpointed iteration itself is run again.
        starting_iter = self.get_starting_iteration(latent_support_sets)

        # Parallelize models into multiple GPUs, if available and `multi_gpu=True`.
        if self.multi_gpu:
            print("#. Parallelize G and CLIP over {} GPUs...".format(torch.cuda.device_count()))
            # Parallelize generator G
            generator = DataParallelPassthrough(generator)
            # Parallelize CLIP model
            clip_model = DataParallelPassthrough(clip_model)

        # Check starting iteration
        if starting_iter == self.params.max_iter:
            print("#. This experiment has already been completed and can be found @ {}".format(self.wip_dir))
            self._copy_to_complete(ignore=shutil.ignore_patterns('checkpoint.pt'))
            sys.exit()
        print("#. Start training from iteration {}".format(starting_iter))

        # Shifts live in the W+-space only for StyleGAN generators with `stylegan_space == 'W+'`
        in_w_plus = ('stylegan' in self.params.gan) and (self.params.stylegan_space == 'W+')

        # Get experiment's start time
        t0 = time.time()

        # Start training
        for iteration in range(starting_iter, self.params.max_iter + 1):

            # Get current iteration's start time
            iter_t0 = time.time()

            # Set gradients to zero
            generator.zero_grad()
            latent_support_sets.zero_grad()
            clip_model.zero_grad()

            # Sample latent codes from standard Gaussian
            z = torch.randn(self.params.batch_size, generator.dim_z)
            if self.use_cuda:
                z = z.cuda()

            # Generate images for the given latent codes
            latent_code = z
            if 'stylegan' in self.params.gan:
                if self.params.stylegan_space == 'W':
                    latent_code = generator.get_w(z, truncation=self.params.truncation)[:, 0, :]
                elif self.params.stylegan_space == 'W+':
                    latent_code = generator.get_w(z, truncation=self.params.truncation)
            img = generator(latent_code)

            # Sample indices of shift vectors (`self.params.batch_size` out of
            # `latent_support_sets.num_support_sets`)
            target_support_sets_indices = torch.randint(0, latent_support_sets.num_support_sets,
                                                        [self.params.batch_size])

            # Sample shift magnitudes from uniform distributions
            #   U[self.params.min_shift_magnitude, self.params.max_shift_magnitude], and
            #   U[-self.params.max_shift_magnitude, -self.params.min_shift_magnitude]
            target_shift_magnitudes = self._sample_shift_magnitudes(self.params.batch_size)

            if self.use_cuda:
                target_support_sets_indices = target_support_sets_indices.cuda()
                target_shift_magnitudes = target_shift_magnitudes.cuda()

            # Create support sets mask of size (batch_size, num_support_sets), along with the prompt mask and sign
            support_sets_mask, prompt_mask, prompt_sign = self._build_masks(
                target_support_sets_indices, target_shift_magnitudes, latent_support_sets.num_support_sets)

            # Calculate shift vectors for the given latent codes -- in the case of StyleGAN, shifts live in the
            # self.params.stylegan_space, i.e., in Z-, W-, or W+-space. In the Z-/W-space the dimensionality of the
            # latent space is 512. In the case of W+-space, the dimensionality is 512 * (self.params.stylegan_layer + 1)
            # Then generate images for the shifted latent codes.
            if in_w_plus:
                shift = target_shift_magnitudes.reshape(-1, 1) * latent_support_sets(
                    support_sets_mask,
                    latent_code[:, :self.params.stylegan_layer + 1, :].reshape(latent_code.shape[0], -1))
                latent_code_reshaped = latent_code.reshape(latent_code.shape[0], -1)
                # Zero-pad the shift so that the remaining layers are left unchanged
                shift = F.pad(input=shift,
                              pad=(0, (STYLEGAN_LAYERS[self.params.gan] - 1 - self.params.stylegan_layer) * 512),
                              mode='constant',
                              value=0)
                latent_code_shifted = latent_code_reshaped + shift
                img_shifted = generator(latent_code_shifted.reshape_as(latent_code))
            else:
                shift = target_shift_magnitudes.reshape(-1, 1) * latent_support_sets(support_sets_mask, latent_code)
                img_shifted = generator(latent_code + shift)

            # Encode original and shifted images with a single CLIP call, then split the features back into the
            # two halves and take their difference
            img_pairs = torch.cat([self.clip_img_transform(img), self.clip_img_transform(img_shifted)], dim=0)
            clip_img_pairs_features = clip_model.encode_image(img_pairs)
            clip_img_features, clip_img_shifted_features = torch.split(clip_img_pairs_features, img.shape[0], dim=0)
            clip_img_diff_features = clip_img_shifted_features - clip_img_features

            ############################################################################################################
            ##                                                                                                        ##
            ##                               Linear Text Paths (StyleCLIP approach)                                   ##
            ##                                                                                                        ##
            ############################################################################################################
            if self.params.styleclip:
                corpus_text_features_batch = torch.matmul(support_sets_mask, corpus_support_sets.SUPPORT_SETS).reshape(
                    -1, 2 * corpus_support_sets.num_support_dipoles, corpus_support_sets.support_vectors_dim)
                corpus_text_features_batch = torch.matmul(prompt_mask, corpus_text_features_batch).squeeze(1)
                loss = self._paired_loss(clip_img_shifted_features, corpus_text_features_batch)

            ############################################################################################################
            ##                                                                                                        ##
            ##                                          Linear Text Paths                                             ##
            ##                                                                                                        ##
            ############################################################################################################
            elif self.params.linear:
                corpus_text_features_batch = torch.matmul(support_sets_mask, corpus_support_sets.SUPPORT_SETS).reshape(
                    -1, 2 * corpus_support_sets.num_support_dipoles, corpus_support_sets.support_vectors_dim)
                text_directions = prompt_sign * (
                    corpus_text_features_batch[:, 0, :] - corpus_text_features_batch[:, 1, :]) - clip_img_features
                loss = self._paired_loss(clip_img_diff_features, text_directions)

            ############################################################################################################
            ##                                                                                                        ##
            ##                                        Non-linear Text Paths                                           ##
            ##                                                                                                        ##
            ############################################################################################################
            else:
                # Calculate local text direction using CSS
                local_text_directions = target_shift_magnitudes.reshape(-1, 1) * corpus_support_sets(
                    support_sets_mask, clip_img_features)
                loss = self._paired_loss(clip_img_diff_features, local_text_directions)

            # Back-propagate!
            loss.backward()

            # Update weights
            clip_model.float()
            latent_support_sets_optim.step()
            latent_support_sets_lr_scheduler.step()
            clip.model.convert_weights(clip_model)

            # Update statistics tracker
            self.stat_tracker.update(loss=loss.item())

            # Get time of completion of current iteration
            iter_t = time.time()

            # Compute elapsed time for current iteration and append to `iter_times`
            self.iter_times = np.append(self.iter_times, iter_t - iter_t0)

            # Compute elapsed time so far
            elapsed_time = iter_t - t0

            # Compute rolling mean iteration time
            mean_iter_time = self.iter_times.mean()

            # Compute estimated time of experiment completion
            eta = elapsed_time * ((self.params.max_iter - iteration) / (iteration - starting_iter + 1))

            # Log progress in stdout
            if iteration % self.params.log_freq == 0:
                self.log_progress(iteration, mean_iter_time, elapsed_time, eta)

            # Save checkpoint model file and latent support_sets model state dicts after current iteration
            if iteration % self.params.ckp_freq == 0:
                # Build checkpoint dict
                checkpoint_dict = {
                    'iter': iteration,
                    'latent_support_sets': latent_support_sets.state_dict(),
                }
                torch.save(checkpoint_dict, self.checkpoint)
        # === End of training loop ===

        # Get experiment's total elapsed time
        elapsed_time = time.time() - t0

        # Save final latent support sets (LSS) model
        latent_support_sets_model_filename = osp.join(self.models_dir, 'latent_support_sets.pt')
        torch.save(latent_support_sets.state_dict(), latent_support_sets_model_filename)

        for _ in range(10):
            print()
        print("#.Training completed -- Total elapsed time: {}.".format(sec2dhms(elapsed_time)))

        self._copy_to_complete()
|