Spaces:
Configuration error
Configuration error
import torch | |
from .cut_model import CUTModel | |
class SinCUTModel(CUTModel):
    """ This class implements the single image translation model (Fig 9) of
    Contrastive Learning for Unpaired Image-to-Image Translation
    Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
    ECCV, 2020

    On top of CUTModel it adds two optional loss terms:
      * an R1 gradient penalty on the discriminator's prediction for real_B
        (weight: --lambda_R1), and
      * an L1 identity-preservation loss between idt_B and real_B
        (weight: --lambda_identity).
    Each term is computed (and registered in loss_names) only when its
    weight is positive.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add SinCUT-specific options and override CUTModel's defaults.

        Parameters:
            parser          -- the option parser; it is first passed through
                               CUTModel.modify_commandline_options
            is_train (bool) -- whether to apply training-phase or test-phase defaults

        Returns:
            the modified parser.
        """
        parser = CUTModel.modify_commandline_options(parser, is_train)
        parser.add_argument('--lambda_R1', type=float, default=1.0,
                            help='weight for the R1 gradient penalty')
        parser.add_argument('--lambda_identity', type=float, default=1.0,
                            help='the "identity preservation loss"')

        # Defaults shared by training and testing.
        parser.set_defaults(nce_includes_all_negatives_from_minibatch=True,
                            dataset_mode="singleimage",
                            netG="stylegan2",
                            stylegan2_G_num_downsampling=1,
                            netD="stylegan2",
                            gan_mode="nonsaturating",
                            num_patches=1,
                            nce_layers="0,2,4",
                            lambda_NCE=4.0,
                            ngf=10,
                            ndf=8,
                            lr=0.002,
                            beta1=0.0,
                            beta2=0.99,
                            load_size=1024,
                            crop_size=64,
                            preprocess="zoom_and_patch",
                            )

        if is_train:
            parser.set_defaults(preprocess="zoom_and_patch",
                                batch_size=16,
                                save_epoch_freq=1,
                                save_latest_freq=20000,
                                n_epochs=8,
                                n_epochs_decay=8,
                                )
        else:
            parser.set_defaults(preprocess="none",  # load the whole image as it is
                                batch_size=1,
                                num_test=1,
                                )

        return parser

    def __init__(self, opt):
        """Build the underlying CUT model, then register the extra SinCUT loss names.

        Parameters:
            opt -- the parsed options; lambda_R1 and lambda_identity are read here
        """
        super().__init__(opt)
        if self.isTrain:
            # isTrain / loss_names presumably come from the base class and drive
            # loss reporting; only positive-weight terms are listed.
            if opt.lambda_R1 > 0.0:
                self.loss_names += ['D_R1']
            if opt.lambda_identity > 0.0:
                self.loss_names += ['idt']

    def compute_D_loss(self):
        """Return the discriminator loss: the parent GAN loss plus the optional R1 penalty.

        Sets self.loss_D_R1 (0.0 when lambda_R1 <= 0) and self.loss_D.
        Relies on the parent's compute_D_loss storing the discriminator output
        for self.real_B in self.pred_real (assumption inherited from CUTModel).
        """
        use_r1 = self.opt.lambda_R1 > 0.0
        if use_r1:
            # Must happen before the parent runs the discriminator on real_B,
            # otherwise autograd cannot differentiate pred_real w.r.t. real_B.
            self.real_B.requires_grad_()
        GAN_loss_D = super().compute_D_loss()
        if use_r1:
            self.loss_D_R1 = self.R1_loss(self.pred_real, self.real_B)
            self.loss_D = GAN_loss_D + self.loss_D_R1
        else:
            # Skip the costly double-backprop when the term would be unused.
            self.loss_D_R1 = 0.0
            self.loss_D = GAN_loss_D
        return self.loss_D

    def compute_G_loss(self):
        """Return the generator loss: the parent CUT loss plus the optional identity L1 loss.

        Sets self.loss_idt (0.0 when lambda_identity <= 0). The identity term
        reads self.idt_B, which is expected to be produced by the parent's forward pass.
        """
        CUT_loss_G = super().compute_G_loss()
        if self.opt.lambda_identity > 0.0:
            # Detach the target: compute_D_loss may have enabled requires_grad on
            # real_B for the R1 penalty, and no gradient should flow into the input.
            self.loss_idt = torch.nn.functional.l1_loss(
                self.idt_B, self.real_B.detach()) * self.opt.lambda_identity
        else:
            self.loss_idt = 0.0
        return CUT_loss_G + self.loss_idt

    def R1_loss(self, real_pred, real_img):
        """Return the weighted R1 penalty 0.5 * lambda_R1 * mean_i ||grad_{x_i} D||^2.

        Parameters:
            real_pred -- discriminator output computed from real_img
            real_img  -- the real input batch; must require grad and be part of
                         real_pred's autograd graph

        create_graph=True makes the gradient itself differentiable so the penalty
        can be backpropagated (double backprop); retain_graph=True keeps the
        discriminator graph alive for the later backward of the combined D loss.
        Summing real_pred gives per-sample input gradients provided samples do not
        interact inside the discriminator (assumption, not checked here).
        """
        grad_real, = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img,
                                         create_graph=True, retain_graph=True)
        # reshape (not view): autograd gradients may be non-contiguous, e.g. channels_last.
        grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
        return grad_penalty * (self.opt.lambda_R1 * 0.5)