from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import (
    Dict,
    Iterable,
    Optional,
    Tuple,
)

import numpy as np
import torch
from torch import nn

from utils.misc import (
    optional_string,
    iterable_to_str,
)

from .contextual_loss import ContextualLoss
from .color_transfer_loss import ColorTransferLoss
from .regularize_noise import NoiseRegularizer
from .reconstruction import (
    EyeLoss,
    FaceLoss,
    create_perceptual_loss,
    ReconstructionArguments,
)


class LossArguments:
    """Command-line options for the loss weights consumed by JointLoss."""

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register loss-weight options (plus reconstruction options) on `parser`."""
        ReconstructionArguments.add_arguments(parser)

        parser.add_argument("--color_transfer", type=float, default=1e10,
                            help="color transfer loss weight")
        parser.add_argument("--eye", type=float, default=0.1,
                            help="eye loss weight")
        parser.add_argument('--noise_regularize', type=float, default=5e4)
        # contextual loss
        parser.add_argument("--contextual", type=float, default=0.1,
                            help="contextual loss weight")
        parser.add_argument("--cx_layers", nargs='*', help="contextual loss layers",
                            choices=['relu1_2', 'relu2_2', 'relu3_4',
                                     'relu4_4', 'relu5_4'],
                            default=['relu3_4', 'relu2_2', 'relu1_2'])

    @staticmethod
    def to_string(args: Namespace) -> str:
        """Encode the active loss weights as a compact tag (e.g. for run names).

        Weights of 0 are omitted via `optional_string` (falsy condition).
        """
        return (
            ReconstructionArguments.to_string(args)
            + optional_string(args.eye > 0, f"-eye{args.eye}")
            + optional_string(args.color_transfer,
                              f"-color{args.color_transfer:.1e}")
            + optional_string(
                args.contextual,
                f"-cx{args.contextual}({iterable_to_str(args.cx_layers)})"
            )
            #+ optional_string(args.mse, f"-mse{args.mse}")
            + optional_string(args.noise_regularize,
                              f"-NR{args.noise_regularize:.1e}")
        )


def _random_crop_origin(extent: int, size: int) -> Tuple[int, int]:
    """Pick a random (h, w) top-left corner for a `size`x`size` crop.

    Fix: the previous inline `np.random.randint(0, high=extent - size)` raised
    ValueError when the image was exactly crop-sized (high=0); fall back to the
    (0, 0) origin in that case. NOTE: assumes square images — `extent` (taken
    from shape[-1]) bounds both axes, matching the original behavior.
    """
    if extent <= size:
        return 0, 0
    h, w = np.random.randint(0, high=extent - size, size=2)
    return int(h), int(w)


class BakedMultiContextualLoss(nn.Module):
    """Random sample different image patches for different vgg layers."""

    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        super().__init__()

        # one single-layer contextual loss per requested vgg layer
        self.cxs = nn.ModuleList([ContextualLoss(use_vgg=True, vgg_layers=[layer])
                                  for layer in args.cx_layers])
        self.size = size
        # detach: the sibling is a fixed exemplar, never optimized
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        """Sum of per-layer contextual losses, each on its own random crop."""
        cx_loss = 0
        for cx in self.cxs:
            h, w = _random_crop_origin(img.shape[-1], self.size)
            cx_loss = cx(self.sibling[..., h:h + self.size, w:w + self.size],
                         img[..., h:h + self.size, w:w + self.size]) + cx_loss
        return cx_loss


class BakedContextualLoss(ContextualLoss):
    """Contextual loss against a fixed sibling exemplar on one random crop."""

    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        super().__init__(use_vgg=True, vgg_layers=args.cx_layers)

        self.size = size
        # detach: the sibling is a fixed exemplar, never optimized
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        """Contextual loss between matching random crops of sibling and img."""
        h, w = _random_crop_origin(img.shape[-1], self.size)
        return super().forward(self.sibling[..., h:h + self.size, w:w + self.size],
                               img[..., h:h + self.size, w:w + self.size])


class JointLoss(nn.Module):
    """Weighted sum of reconstruction (vs. target), exemplar (vs. sibling),
    and noise-regularization losses.

    Args:
        args: parsed options (see LossArguments / ReconstructionArguments).
        target: the degraded photo being reconstructed.
        sibling: clean exemplar image; required when contextual loss is on.
        sibling_rgbs: ToRGB outputs for the sibling; required when the
            color-transfer loss is on.
    """

    def __init__(
            self,
            args: Namespace,
            target: torch.Tensor,
            sibling: Optional[torch.Tensor],
            sibling_rgbs: Optional[Iterable[torch.Tensor]] = None,
    ):
        super().__init__()

        self.weights = {
            "face": 1., "eye": args.eye,
            "contextual": args.contextual, "color_transfer": args.color_transfer,
            "noise": args.noise_regularize,
        }

        # losses against the degraded target photo
        # NOTE(review): eye loss is created inside the vgg/vggface branch because
        # it needs `percept` — confirm no eye-only (vgg==vggface==0) config exists.
        reconstruction = {}
        if args.vgg > 0 or args.vggface > 0:
            percept = create_perceptual_loss(args)
            reconstruction.update(
                {"face": FaceLoss(target, input_size=args.generator_size,
                                  size=args.recon_size, percept=percept)}
            )
            if args.eye > 0:
                reconstruction.update(
                    {"eye": EyeLoss(target, input_size=args.generator_size,
                                    percept=percept)}
                )
        self.reconstruction = nn.ModuleDict(reconstruction)

        # losses against the clean sibling exemplar
        exemplar = {}
        if args.contextual > 0 and len(args.cx_layers) > 0:
            assert sibling is not None
            exemplar.update(
                {"contextual": BakedContextualLoss(sibling, args)}
            )
        if args.color_transfer > 0:
            assert sibling_rgbs is not None
            self.sibling_rgbs = sibling_rgbs
            exemplar.update(
                {"color_transfer": ColorTransferLoss(init_rgbs=sibling_rgbs)}
            )
        self.exemplar = nn.ModuleDict(exemplar)

        if args.noise_regularize > 0:
            self.noise_criterion = NoiseRegularizer()

    def forward(
            self, img, degrade=None, noises=None,
            rgbs=None, rgb_level: Optional[int] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute all active losses.

        Args:
            img: generated image to score.
            degrade: degradation operator forwarded to the reconstruction losses.
            noises: generator noise maps, consumed by the noise regularizer.
            rgbs: results from the ToRGB layers.
            rgb_level: resolution level forwarded to the color-transfer loss.

        Returns:
            (total weighted loss, dict of individual unweighted losses).
        """
        # TODO: add current optimization resolution for noises

        losses = {}

        # reconstruction losses
        for name, criterion in self.reconstruction.items():
            losses[name] = criterion(img, degrade=degrade)

        # exemplar losses
        if 'contextual' in self.exemplar:
            losses["contextual"] = self.exemplar["contextual"](img)
        if "color_transfer" in self.exemplar:
            assert rgbs is not None
            losses["color_transfer"] = self.exemplar["color_transfer"](
                rgbs, level=rgb_level)

        # noise regularizer
        if self.weights["noise"] > 0:
            losses["noise"] = self.noise_criterion(noises)

        total_loss = 0
        for name, loss in losses.items():
            total_loss = total_loss + self.weights[name] * loss
        return total_loss, losses

    def update_sibling(self, sibling: torch.Tensor):
        """Swap in a new (detached) sibling exemplar for the contextual loss."""
        assert "contextual" in self.exemplar
        self.exemplar["contextual"].sibling = sibling.detach()