import sys
import os
import os.path as osp
import clip
import json
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
import numpy as np
import time
import shutil

from .aux import TrainingStatTracker, update_progress, update_stdout, sec2dhms
from .config import SEMANTIC_DIPOLES_CORPORA, STYLEGAN_LAYERS


class DataParallelPassthrough(nn.DataParallel):
    """nn.DataParallel wrapper that exposes the attributes of the wrapped module.

    Attribute look-ups that fail on the DataParallel wrapper itself are forwarded to the underlying module
    (e.g., `generator.dim_z` still works after wrapping).
    """
    def __getattr__(self, name):
        try:
            return super(DataParallelPassthrough, self).__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


class Trainer(object):
    def __init__(self, params=None, exp_dir=None, use_cuda=False, multi_gpu=False):
        if params is None:
            raise ValueError("Cannot build a Trainer instance with empty params: params={}".format(params))
        else:
            self.params = params
        self.use_cuda = use_cuda
        self.multi_gpu = multi_gpu

        # Set directories for the experiment in progress ("wip") and for the completed experiment
        self.wip_dir = osp.join("experiments", "wip", exp_dir)
        self.complete_dir = osp.join("experiments", "complete", exp_dir)

        # Create JSON file for keeping training statistics (the wip directory is created here if the calling
        # script has not already done so -- harmless otherwise)
        os.makedirs(self.wip_dir, exist_ok=True)
        self.stats_json = osp.join(self.wip_dir, 'stats.json')
        if not osp.isfile(self.stats_json):
            with open(self.stats_json, 'w') as out:
                json.dump({}, out)

        # Create directory for saving models and define the checkpoint file
        self.models_dir = osp.join(self.wip_dir, 'models')
        os.makedirs(self.models_dir, exist_ok=True)
        self.checkpoint = osp.join(self.models_dir, 'checkpoint.pt')

        # Array of per-iteration times (used for ETA estimation)
        self.iter_times = np.array([])

        # Training statistics tracker
        self.stat_tracker = TrainingStatTracker()

        # Loss functions
        self.cosine_embedding_loss = nn.CosineEmbeddingLoss()
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        # Image transform for feeding the generator's output to CLIP's image encoder
        self.clip_img_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))])
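
        # Note: the Normalize() mean/std above are OpenAI CLIP's published image-preprocessing constants.
        # ToTensor() is omitted because the inputs are already image tensors produced by the generator; only
        # resizing, cropping, and normalisation are needed before CLIP's image encoder.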

    def contrastive_loss(self, img_batch, txt_batch):
        n_img, d_img = img_batch.shape
        n_txt, d_txt = txt_batch.shape

        # L2-normalise the image and text feature batches
        img_batch_l2 = F.normalize(img_batch, p=2, dim=-1)
        txt_batch_l2 = F.normalize(txt_batch, p=2, dim=-1)

        # Pairwise cosine similarities; the i-th image should match the i-th text
        similarity_matrix = torch.matmul(img_batch_l2, txt_batch_l2.T)
        # Keep the labels on the same device as the similarities (otherwise the loss fails on CUDA)
        labels = torch.arange(n_img, device=img_batch.device)

        return self.cross_entropy_loss(similarity_matrix / self.params.temperature, labels)
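
    # For reference, `contrastive_loss` above is a standard InfoNCE-style objective: with cosine similarities
    # s_ij between the i-th image feature and the j-th text feature and temperature t = self.params.temperature,
    #   L = -(1/N) * sum_i log( exp(s_ii / t) / sum_j exp(s_ij / t) ),
    # i.e., a cross-entropy that treats the matching (i, i) pair as the positive class of each row.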

    def get_starting_iteration(self, latent_support_sets):
        """Check whether a checkpoint file exists (under `self.models_dir`); if so, set the starting iteration to
        the checkpoint's iteration and load the checkpoint weights into `latent_support_sets`. Otherwise, set the
        starting iteration to 1 so as to train from scratch.

        Returns:
            starting_iter (int): starting iteration
        """
        starting_iter = 1
        if osp.isfile(self.checkpoint):
            # Load the checkpoint onto the appropriate device
            checkpoint_dict = torch.load(self.checkpoint, map_location='cuda' if self.use_cuda else 'cpu')
            starting_iter = checkpoint_dict['iter']
            latent_support_sets.load_state_dict(checkpoint_dict['latent_support_sets'])
        return starting_iter
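
    # The checkpoint read above is the one written periodically by `train()` below, a dict of the form
    #   {'iter': <int>, 'latent_support_sets': <state_dict>},
    # so resuming restores only the trainable LSS weights and the iteration counter.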

    def log_progress(self, iteration, mean_iter_time, elapsed_time, eta):
        """Log progress (loss + ETA).

        Args:
            iteration (int)        : current iteration
            mean_iter_time (float) : mean iteration time
            elapsed_time (float)   : elapsed time until current iteration
            eta (float)            : estimated time of experiment completion
        """
        # Get mean statistics over the last logging period
        stats = self.stat_tracker.get_means()

        # Update the training statistics JSON file
        with open(self.stats_json) as f:
            stats_dict = json.load(f)
        stats_dict.update({iteration: stats})
        with open(self.stats_json, 'w') as out:
            json.dump(stats_dict, out)

        # Flush the training statistics tracker
        self.stat_tracker.flush()

        # Print progress (loss, mean iteration time, elapsed time, and ETA)
        update_progress(" \\__.Training [bs: {}] [iter: {:06d}/{:06d}] ".format(
            self.params.batch_size, iteration, self.params.max_iter), self.params.max_iter, iteration + 1)
        if iteration < self.params.max_iter - 1:
            print()
        print(" ===================================================================")
        print(" \\__Loss           : {:.08f}".format(stats['loss']))
        print(" ===================================================================")
        print(" \\__Mean iter time : {:.3f} sec".format(mean_iter_time))
        print(" \\__Elapsed time   : {}".format(sec2dhms(elapsed_time)))
        print(" \\__ETA            : {}".format(sec2dhms(eta)))
        print(" ===================================================================")
        update_stdout(8)
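
    # `update_stdout(8)` rewinds stdout by the eight lines printed above (assuming the helper in `.aux` moves the
    # cursor up that many lines), so each logging step overwrites the previous block in place instead of scrolling.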

    def train(self, generator, latent_support_sets, corpus_support_sets, clip_model):
        """GANxPlainer training function.

        Args:
            generator           : non-trainable (pre-trained) GAN generator
            latent_support_sets : trainable LSS model -- interpretable latent paths model
            corpus_support_sets : non-trainable CSS model -- non-linear paths in the CLIP space
            clip_model          : non-trainable (pre-trained) CLIP model
        """
        # Save the initial weights of the (trainable) latent support sets (LSS) and the (non-trainable) corpus
        # support sets (CSS) models
        torch.save(latent_support_sets.state_dict(), osp.join(self.models_dir, 'latent_support_sets_init.pt'))
        torch.save(corpus_support_sets.state_dict(), osp.join(self.models_dir, 'corpus_support_sets_init.pt'))

        # Save the corpus of semantic dipoles used for the current experiment
        with open(osp.join(self.models_dir, 'semantic_dipoles.json'), 'w') as json_f:
            json.dump(SEMANTIC_DIPOLES_CORPORA[self.params.corpus], json_f)

        # Set the generator and the CLIP model to evaluation mode and the LSS model to training mode; move all to
        # the GPU if CUDA is used
        if self.use_cuda:
            generator.cuda().eval()
            clip_model.cuda().eval()
            corpus_support_sets.cuda()
            latent_support_sets.cuda().train()
        else:
            generator.eval()
            clip_model.eval()
            latent_support_sets.train()

        # Set up optimizer for the trainable LSS parameters
        latent_support_sets_optim = torch.optim.Adam(latent_support_sets.parameters(), lr=self.params.lr)

        # Set up a learning rate scheduler that decays the learning rate once, late in training
        latent_support_sets_lr_scheduler = StepLR(optimizer=latent_support_sets_optim,
                                                  step_size=int(0.9 * self.params.max_iter),
                                                  gamma=0.1)
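
        # With step_size = 0.9 * max_iter and one scheduler step per training iteration, StepLR multiplies the
        # learning rate by gamma = 0.1 exactly once, after 90% of training; e.g., a (hypothetical) lr of 1e-3
        # would drop to 1e-4 for the final 10% of iterations.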

        # Get starting iteration (resume from a checkpoint, if one exists)
        starting_iter = self.get_starting_iteration(latent_support_sets)

        # Parallelize the generator and the CLIP model across multiple GPUs, if requested
        if self.multi_gpu:
            print("#. Parallelize G and CLIP over {} GPUs...".format(torch.cuda.device_count()))
            generator = DataParallelPassthrough(generator)
            clip_model = DataParallelPassthrough(clip_model)

        # Check whether the experiment has already been completed
        if starting_iter == self.params.max_iter:
            print("#. This experiment has already been completed and can be found @ {}".format(self.wip_dir))
            print("#. Copy {} to {}...".format(self.wip_dir, self.complete_dir))
            try:
                shutil.copytree(src=self.wip_dir, dst=self.complete_dir,
                                ignore=shutil.ignore_patterns('checkpoint.pt'))
                print(" \\__Done!")
            except IOError as e:
                print(" \\__Already exists -- {}".format(e))
            sys.exit()
        print("#. Start training from iteration {}".format(starting_iter))

        # Get experiment's start time
        t0 = time.time()

        # Start training
        for iteration in range(starting_iter, self.params.max_iter + 1):
            # Get iteration's start time
            iter_t0 = time.time()

            # Set gradients to zero
            generator.zero_grad()
            latent_support_sets.zero_grad()
            clip_model.zero_grad()

            # Sample a batch of latent codes from the standard Gaussian prior
            z = torch.randn(self.params.batch_size, generator.dim_z)
            if self.use_cuda:
                z = z.cuda()

            # Map z to the latent space where paths are learnt (Z, or W / W+ for StyleGAN) and generate images
            latent_code = z
            if 'stylegan' in self.params.gan:
                if self.params.stylegan_space == 'W':
                    latent_code = generator.get_w(z, truncation=self.params.truncation)[:, 0, :]
                elif self.params.stylegan_space == 'W+':
                    latent_code = generator.get_w(z, truncation=self.params.truncation)
            img = generator(latent_code)
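
            # Notes on latent spaces: for the StyleGAN generators, `get_w` is assumed to return a per-layer code
            # of shape (batch_size, STYLEGAN_LAYERS[gan], 512); training in W keeps only the first layer's row
            # (the rows are identical before any mixing), while W+ keeps the full per-layer code.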

            # Sample the indices of the target support sets (one dipole/path per sample in the batch)
            target_support_sets_indices = torch.randint(0, latent_support_sets.num_support_sets,
                                                        [self.params.batch_size])
            if self.use_cuda:
                target_support_sets_indices = target_support_sets_indices.cuda()

            # Sample shift magnitudes for the positive and the negative direction of each path
            shift_magnitudes_pos = (self.params.min_shift_magnitude - self.params.max_shift_magnitude) * \
                torch.rand(target_support_sets_indices.size()) + self.params.max_shift_magnitude
            shift_magnitudes_neg = (self.params.min_shift_magnitude - self.params.max_shift_magnitude) * \
                torch.rand(target_support_sets_indices.size()) - self.params.min_shift_magnitude
            shift_magnitudes_pool = torch.cat((shift_magnitudes_neg, shift_magnitudes_pos))

            # Draw `batch_size` magnitudes from the pool uniformly, without replacement. (A random permutation is
            # used rather than `torch.multinomial` with index weights, which would never pick index 0 and would be
            # biased towards larger indices.)
            target_shift_magnitudes = shift_magnitudes_pool[
                torch.randperm(len(shift_magnitudes_pool))[:self.params.batch_size]]
            if self.use_cuda:
                target_shift_magnitudes = target_shift_magnitudes.cuda()
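
            # The magnitudes above are drawn uniformly from [-max_shift_magnitude, -min_shift_magnitude] and
            # [min_shift_magnitude, max_shift_magnitude]; e.g., with (hypothetical) min/max of 0.1 / 0.5, each
            # sample gets a step length in +/-[0.1, 0.5], so every traversal moves by a non-negligible amount in
            # one of the two directions along its path.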

            # Create the one-hot support sets mask, along with the prompt mask/sign that select the pole (positive
            # or negative end of the dipole) the shift moves towards
            support_sets_mask = torch.zeros([self.params.batch_size, latent_support_sets.num_support_sets])
            prompt_mask = torch.zeros([self.params.batch_size, 2])
            prompt_sign = torch.zeros([self.params.batch_size, 1])
            if self.use_cuda:
                support_sets_mask = support_sets_mask.cuda()
                prompt_mask = prompt_mask.cuda()
                prompt_sign = prompt_sign.cuda()
            for i, (index, val) in enumerate(zip(target_support_sets_indices, target_shift_magnitudes)):
                support_sets_mask[i][index] += 1.0
                if val >= 0:
                    prompt_mask[i, 0] = 1.0
                    prompt_sign[i] = +1.0
                else:
                    prompt_mask[i, 1] = 1.0
                    prompt_sign[i] = -1.0
            prompt_mask = prompt_mask.unsqueeze(1)
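
            # Shapes at this point (with B = batch_size, K = num_support_sets):
            #   support_sets_mask : (B, K)     -- one-hot dipole selector
            #   prompt_mask       : (B, 1, 2)  -- after unsqueeze(1); selects a pole via the batched matmul below
            #   prompt_sign       : (B, 1)     -- +1 / -1 according to the sampled shift direction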

            # Compute the LSS shift vectors for the sampled latent codes. In the W+ case, paths are learnt only in
            # the first `stylegan_layer + 1` layers, so only those layers of the latent code are fed to the LSS.
            if ('stylegan' in self.params.gan) and (self.params.stylegan_space == 'W+'):
                shift = target_shift_magnitudes.reshape(-1, 1) * latent_support_sets(
                    support_sets_mask,
                    latent_code[:, :self.params.stylegan_layer + 1, :].reshape(latent_code.shape[0], -1))
            else:
                shift = target_shift_magnitudes.reshape(-1, 1) * latent_support_sets(support_sets_mask, latent_code)

            # Generate images for the shifted latent codes. In the W+ case, the shift is zero-padded over the
            # remaining (untouched) layers before being added to the flattened latent code.
            if ('stylegan' in self.params.gan) and (self.params.stylegan_space == 'W+'):
                latent_code_reshaped = latent_code.reshape(latent_code.shape[0], -1)
                shift = F.pad(input=shift,
                              pad=(0, (STYLEGAN_LAYERS[self.params.gan] - 1 - self.params.stylegan_layer) * 512),
                              mode='constant',
                              value=0)
                latent_code_shifted = latent_code_reshaped + shift
                latent_code_shifted_reshaped = latent_code_shifted.reshape_as(latent_code)
                img_shifted = generator(latent_code_shifted_reshaped)
            else:
                img_shifted = generator(latent_code + shift)

            # Encode the original and the shifted images with CLIP and compute their feature difference
            img_pairs = torch.cat([self.clip_img_transform(img), self.clip_img_transform(img_shifted)], dim=0)
            clip_img_pairs_features = clip_model.encode_image(img_pairs)
            clip_img_features, clip_img_shifted_features = torch.split(clip_img_pairs_features, img.shape[0], dim=0)
            clip_img_diff_features = clip_img_shifted_features - clip_img_features
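
            # `clip_img_diff_features` is the direction of change induced in the CLIP image space by the latent
            # shift; the loss variants below align it (or the shifted features themselves) with a text-derived
            # target direction.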

            # Compute the loss. Case 1 (StyleCLIP-style baseline): compare the shifted image features directly
            # with the text features of the selected pole of the dipole.
            if self.params.styleclip:
                corpus_text_features_batch = torch.matmul(support_sets_mask, corpus_support_sets.SUPPORT_SETS).reshape(
                    -1, 2 * corpus_support_sets.num_support_dipoles, corpus_support_sets.support_vectors_dim)
                corpus_text_features_batch = torch.matmul(prompt_mask, corpus_text_features_batch).squeeze(1)

                if self.params.loss == 'cossim':
                    loss = self.cosine_embedding_loss(clip_img_shifted_features, corpus_text_features_batch,
                                                      torch.ones(corpus_text_features_batch.shape[0]).to(
                                                          'cuda' if self.use_cuda else 'cpu'))
                elif self.params.loss == 'contrastive':
                    loss = self.contrastive_loss(clip_img_shifted_features.float(), corpus_text_features_batch)

            # Case 2 (linear text paths): compare the image feature difference with the signed linear direction
            # defined by the dipole's pole text features, relative to the original image features.
            elif self.params.linear:
                corpus_text_features_batch = torch.matmul(support_sets_mask, corpus_support_sets.SUPPORT_SETS).reshape(
                    -1, 2 * corpus_support_sets.num_support_dipoles, corpus_support_sets.support_vectors_dim)

                if self.params.loss == 'cossim':
                    loss = self.cosine_embedding_loss(clip_img_diff_features,
                                                      prompt_sign * (corpus_text_features_batch[:, 0, :] -
                                                                     corpus_text_features_batch[:, 1, :]) -
                                                      clip_img_features,
                                                      torch.ones(corpus_text_features_batch.shape[0]).to(
                                                          'cuda' if self.use_cuda else 'cpu'))
                elif self.params.loss == 'contrastive':
                    loss = self.contrastive_loss(clip_img_diff_features.float(),
                                                 prompt_sign * (corpus_text_features_batch[:, 0, :] -
                                                                corpus_text_features_batch[:, 1, :]) -
                                                 clip_img_features)

            # Case 3 (non-linear text paths): compare the image feature difference with a local text direction
            # given by the (non-trainable) CSS model, evaluated at the original image features.
            else:
                local_text_directions = target_shift_magnitudes.reshape(-1, 1) * corpus_support_sets(
                    support_sets_mask, clip_img_features)

                if self.params.loss == 'cossim':
                    loss = self.cosine_embedding_loss(clip_img_diff_features, local_text_directions,
                                                      torch.ones(local_text_directions.shape[0]).to(
                                                          'cuda' if self.use_cuda else 'cpu'))
                elif self.params.loss == 'contrastive':
                    loss = self.contrastive_loss(img_batch=clip_img_diff_features.float(),
                                                 txt_batch=local_text_directions)
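
            # Note: `loss` is only assigned when `self.params.loss` is 'cossim' or 'contrastive'; other values are
            # assumed to be rejected upstream when the arguments are parsed.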

            # Back-propagate and update the parameters of the LSS model
            loss.backward()

            # Cast the CLIP model to fp32 for the optimizer step and back to fp16 afterwards
            clip_model.float()
            latent_support_sets_optim.step()
            latent_support_sets_lr_scheduler.step()
            clip.model.convert_weights(clip_model)

            # Update the training statistics tracker
            self.stat_tracker.update(loss=loss.item())
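
            # The float()/convert_weights() pair follows the usual mixed-precision recipe for OpenAI's CLIP:
            # parameters are cast to fp32 so the update is numerically stable, and `clip.model.convert_weights`
            # casts the applicable weights back to fp16 for the next forward pass.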

            # Get the iteration's end time, iteration time, elapsed time, mean iteration time, and ETA
            iter_t = time.time()
            self.iter_times = np.append(self.iter_times, iter_t - iter_t0)
            elapsed_time = iter_t - t0
            mean_iter_time = self.iter_times.mean()
            eta = elapsed_time * ((self.params.max_iter - iteration) / (iteration - starting_iter + 1))

            # Log progress
            if iteration % self.params.log_freq == 0:
                self.log_progress(iteration, mean_iter_time, elapsed_time, eta)

            # Save checkpoint
            if iteration % self.params.ckp_freq == 0:
                checkpoint_dict = {
                    'iter': iteration,
                    'latent_support_sets': latent_support_sets.state_dict(),
                }
                torch.save(checkpoint_dict, self.checkpoint)

        # Get the experiment's total elapsed time
        elapsed_time = time.time() - t0

        # Save the final LSS model
        latent_support_sets_model_filename = osp.join(self.models_dir, 'latent_support_sets.pt')
        torch.save(latent_support_sets.state_dict(), latent_support_sets_model_filename)

        # Print blank lines to move past the in-place progress block before the final messages
        for _ in range(10):
            print()
        print("#. Training completed -- Total elapsed time: {}.".format(sec2dhms(elapsed_time)))

        # Copy the completed experiment's directory from `experiments/wip` to `experiments/complete`
        print("#. Copy {} to {}...".format(self.wip_dir, self.complete_dir))
        try:
            shutil.copytree(src=self.wip_dir, dst=self.complete_dir)
            print(" \\__Done!")
        except IOError as e:
            print(" \\__Already exists -- {}".format(e))