from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import (
    Dict,
    Iterable,
    Optional,
    Tuple,
)

import numpy as np
import torch
from torch import nn

from utils.misc import (
    optional_string,
    iterable_to_str,
)

from .contextual_loss import ContextualLoss
from .color_transfer_loss import ColorTransferLoss
from .regularize_noise import NoiseRegularizer
from .reconstruction import (
    EyeLoss,
    FaceLoss,
    create_perceptual_loss,
    ReconstructionArguments,
)

class LossArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        ReconstructionArguments.add_arguments(parser)

        parser.add_argument("--color_transfer", type=float, default=1e10, help="color transfer loss weight")
        parser.add_argument("--eye", type=float, default=0.1, help="eye loss weight")
        parser.add_argument("--noise_regularize", type=float, default=5e4, help="noise regularization weight")
        # contextual loss
        parser.add_argument("--contextual", type=float, default=0.1, help="contextual loss weight")
        parser.add_argument("--cx_layers", nargs='*', help="contextual loss layers",
                            choices=['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4'],
                            default=['relu3_4', 'relu2_2', 'relu1_2'])

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            ReconstructionArguments.to_string(args)
            + optional_string(args.eye > 0, f"-eye{args.eye}")
            + optional_string(args.color_transfer, f"-color{args.color_transfer:.1e}")
            + optional_string(
                args.contextual,
                f"-cx{args.contextual}({iterable_to_str(args.cx_layers)})"
            )
            #+ optional_string(args.mse, f"-mse{args.mse}")
            + optional_string(args.noise_regularize, f"-NR{args.noise_regularize:.1e}")
        )
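
# Illustrative usage sketch (an assumption, not part of the original module):
# wiring the flags into a parser and deriving a run-name suffix. The exact
# flags contributed by ReconstructionArguments (e.g. the vgg/vggface weights)
# live in reconstruction.py and are only assumed here.
#
#     parser = ArgumentParser()
#     LossArguments.add_arguments(parser)
#     args = parser.parse_args(["--eye", "0.2", "--contextual", "0.1"])
#     suffix = LossArguments.to_string(args)
#     # e.g. "...-eye0.2-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04"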

class BakedMultiContextualLoss(nn.Module):
    """Sample a different random image patch for each VGG layer."""

    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        super().__init__()

        # One single-layer contextual loss per requested VGG layer.
        self.cxs = nn.ModuleList([
            ContextualLoss(use_vgg=True, vgg_layers=[layer])
            for layer in args.cx_layers
        ])
        self.size = size
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        cx_loss = 0
        for cx in self.cxs:
            # Crop the same random size x size window from the sibling and the
            # generated image, resampling the window for every layer.
            h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
            cx_loss = cx(self.sibling[..., h:h + self.size, w:w + self.size],
                         img[..., h:h + self.size, w:w + self.size]) + cx_loss
        return cx_loss
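
# Minimal sketch of the random-crop indexing used above, under the assumption
# that inputs are square NCHW tensors whose side strictly exceeds `size`
# (np.random.randint requires high > 0, so img.shape[-1] == size would fail):
#
#     img = torch.randn(1, 3, 512, 512)
#     sibling = torch.randn(1, 3, 512, 512)
#     size = 256
#     h, w = np.random.randint(0, high=img.shape[-1] - size, size=2)
#     patch_pair = (sibling[..., h:h + size, w:w + size],   # (1, 3, 256, 256)
#                   img[..., h:h + size, w:w + size])       # same window in both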

class BakedContextualLoss(ContextualLoss):
    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        super().__init__(use_vgg=True, vgg_layers=args.cx_layers)
        self.size = size
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        # A single random crop shared across all VGG layers.
        h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
        return super().forward(self.sibling[..., h:h + self.size, w:w + self.size],
                               img[..., h:h + self.size, w:w + self.size])

class JointLoss(nn.Module):
    def __init__(
        self,
        args: Namespace,
        target: torch.Tensor,
        sibling: Optional[torch.Tensor],
        sibling_rgbs: Optional[Iterable[torch.Tensor]] = None,
    ):
        super().__init__()

        self.weights = {
            "face": 1., "eye": args.eye,
            "contextual": args.contextual, "color_transfer": args.color_transfer,
            "noise": args.noise_regularize,
        }

        # Reconstruction terms share one perceptual (VGG/VGGFace) backbone, so
        # the eye loss is only available when that backbone is created.
        reconstruction = {}
        if args.vgg > 0 or args.vggface > 0:
            percept = create_perceptual_loss(args)
            reconstruction.update(
                {"face": FaceLoss(target, input_size=args.generator_size, size=args.recon_size, percept=percept)}
            )
            if args.eye > 0:
                reconstruction.update(
                    {"eye": EyeLoss(target, input_size=args.generator_size, percept=percept)}
                )
        self.reconstruction = nn.ModuleDict(reconstruction)

        # Exemplar terms compare the result against the sibling prediction.
        exemplar = {}
        if args.contextual > 0 and len(args.cx_layers) > 0:
            assert sibling is not None
            exemplar.update(
                {"contextual": BakedContextualLoss(sibling, args)}
            )
        if args.color_transfer > 0:
            assert sibling_rgbs is not None
            self.sibling_rgbs = sibling_rgbs
            exemplar.update(
                {"color_transfer": ColorTransferLoss(init_rgbs=sibling_rgbs)}
            )
        self.exemplar = nn.ModuleDict(exemplar)

        if args.noise_regularize > 0:
            self.noise_criterion = NoiseRegularizer()
    def forward(
        self, img, degrade=None, noises=None, rgbs=None, rgb_level: Optional[int] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            rgbs: results from the ToRGB layers
        """
        # TODO: add current optimization resolution for noises
        losses = {}

        # reconstruction losses
        for name, criterion in self.reconstruction.items():
            losses[name] = criterion(img, degrade=degrade)

        # exemplar losses
        if "contextual" in self.exemplar:
            losses["contextual"] = self.exemplar["contextual"](img)
        if "color_transfer" in self.exemplar:
            assert rgbs is not None
            losses["color_transfer"] = self.exemplar["color_transfer"](rgbs, level=rgb_level)

        # noise regularizer
        if self.weights["noise"] > 0:
            losses["noise"] = self.noise_criterion(noises)

        # Weighted sum of every active term; the per-term values are also
        # returned so callers can log them individually.
        total_loss = 0
        for name, loss in losses.items():
            total_loss = total_loss + self.weights[name] * loss
        return total_loss, losses

    def update_sibling(self, sibling: torch.Tensor):
        assert "contextual" in self.exemplar
        self.exemplar["contextual"].sibling = sibling.detach()
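
# Illustrative end-to-end sketch (assumptions, not part of the original module):
# `generator` stands in for a StyleGAN-style model returning an image plus its
# per-resolution ToRGB outputs, and `target`, `sibling`, `sibling_rgbs`,
# `degrade`, `latent`, and `noises` come from the surrounding projection code;
# none of these names, nor the generator call signature, are defined here.
#
#     parser = ArgumentParser()
#     LossArguments.add_arguments(parser)
#     args = parser.parse_args()
#
#     criterion = JointLoss(args, target, sibling=sibling, sibling_rgbs=sibling_rgbs)
#     optimizer = torch.optim.Adam([latent] + noises, lr=0.1)
#     for _ in range(num_steps):
#         img, rgbs = generator(latent, noises)   # hypothetical call signature
#         total, terms = criterion(img, degrade=degrade, noises=noises, rgbs=rgbs)
#         optimizer.zero_grad()
#         total.backward()
#         optimizer.step()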