# deep_privacy2_face / dp2 / generator /
# (web-page extraction residue: "haakohu's picture")
import torch
import torch.nn as nn
from easydict import EasyDict
from .base import BaseGenerator
import numpy as np
from typing import List
class LatentVariableConcat(nn.Module):
    """Concatenate the latent tensor ``batch["z"]`` onto the features."""

    def __init__(self, conv2d_config):
        # ``conv2d_config`` is accepted but not read by this layer.
        # nn.Module must be initialised before the module can be used.
        super().__init__()

    def forward(self, _inp):
        """Append ``batch["z"]`` to ``x`` along the channel axis (dim 1).

        Args:
            _inp: ``(x, mask, batch)`` tuple; ``batch`` must contain ``"z"``.

        Returns:
            ``(x, mask, batch)`` with ``x`` widened by the channels of ``z``;
            ``mask`` and ``batch`` are passed through unchanged.
        """
        x, mask, batch = _inp
        z = batch["z"]
        x = torch.cat((x, z), dim=1)
        return (x, mask, batch)
def get_padding(kernel_size: int, dilation: int, stride: int):
    """Return the padding derived from kernel size and dilation.

    Evaluates ``floor((dilation * (kernel_size - 1) - 1) / 2 + 1)``.
    ``stride`` is part of the signature but does not affect the result.
    """
    dilated_extent = dilation * (kernel_size - 1)
    half_extent = (dilated_extent - 1) / 2 + 1
    return int(np.floor(half_extent))
class Conv2d(nn.Conv2d):
    """2D convolution over ``(x, mask)`` tuples with optional weight tricks.

    * ``wsconv``: the weight is multiplied by ``gain / sqrt(fan_in)`` at
      every forward pass.
    * ``demodulation``: each output filter is rescaled to (approximately)
      unit L2 norm at every forward pass.

    The mask is passed through untouched.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=None, dilation=1, groups=1,
                 bias=True, padding_mode='zeros',
                 demodulation=False, wsconv=False, gain=1,
                 *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored so that
        # config dictionaries with additional keys can be splatted in.
        if padding is None:
            padding = get_padding(kernel_size, dilation, stride)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode)
        self.demodulation = demodulation
        self.wsconv = wsconv
        if self.wsconv:
            fan_in = np.prod(self.weight.shape[1:]) / self.groups
            self.ws_scale = gain / np.sqrt(fan_in)
        if bias:
            nn.init.constant_(self.bias, val=0)
        assert not self.padding_mode == "circular",\
            "conv2d_forward does not support circular padding. Look at original pytorch code"

    def _get_weight(self):
        """Return the weight after optional scaling and demodulation."""
        weight = self.weight
        if self.wsconv:
            weight = self.ws_scale * weight
        if self.demodulation:
            # Per-output-filter inverse L2 norm; 1e-7 avoids division by zero.
            demod = torch.rsqrt(weight.pow(2).sum([1, 2, 3]) + 1e-7)
            weight = weight * demod.view(self.out_channels, 1, 1, 1)
        return weight

    def conv2d_forward(self, x, weight, bias=True):
        """Convolve ``x`` with ``weight``; ``bias=False`` skips the bias."""
        bias_ = None
        if bias:
            bias_ = self.bias
        return nn.functional.conv2d(x, weight, bias_, self.stride,
                                    self.padding, self.dilation, self.groups)

    def forward(self, _inp):
        """Apply the convolution to ``x`` of ``(x, mask)``; return both."""
        x, mask = _inp
        weight = self._get_weight()
        return self.conv2d_forward(x, weight), mask

    def __repr__(self):
        return ", ".join([
            f"Weight Scale={self.wsconv}",
            f"Bias={self.bias is not None}"
        ])
class LeakyReLU(nn.LeakyReLU):
    """Leaky ReLU for ``(x, mask)`` pairs; the mask is returned unchanged."""

    def forward(self, _inp):
        features, mask = _inp
        activated = nn.LeakyReLU.forward(self, features)
        return activated, mask
class AvgPool2d(nn.AvgPool2d):
    """Average pooling applied to both features and mask.

    Accepts ``(x, mask)`` or ``(x, mask, *extras)``; any extra entries are
    appended to the returned tuple unchanged.
    """

    def forward(self, _inp):
        features, mask, *extras = _inp
        pool = super().forward
        return (pool(features), pool(mask), *extras)
def up(x):
    """Upsample ``x`` by a factor of 2 with nearest-neighbour interpolation.

    A tensor with batch size 1 and 1x1 spatial size is returned as-is.
    """
    shape = x.shape
    if (shape[0], shape[2], shape[3]) == (1, 1, 1):
        # Analytical normalization
        return x
    return nn.functional.interpolate(x, scale_factor=2, mode="nearest")
class NearestUpsample(nn.Module):
    """Run ``up`` on both features and mask of the incoming tuple.

    Accepts ``(x, mask)`` or ``(x, mask, *extras)``; extras are forwarded
    unchanged at the end of the returned tuple.
    """

    def forward(self, _inp):
        features, mask, *extras = _inp
        return (up(features), up(mask), *extras)
class PixelwiseNormalization(nn.Module):
    """Divide every pixel by the RMS of its values across channels (dim 1)."""

    def forward(self, _inp):
        features, mask = _inp
        mean_square = features.pow(2).mean(dim=1, keepdim=True)
        # 1e-7 keeps rsqrt finite for all-zero pixels.
        inv_rms = torch.rsqrt(mean_square + 1e-7)
        return features * inv_rms, mask
class Linear(nn.Linear):
    """Linear layer whose weight is scaled by ``1 / sqrt(in_features)``.

    The scaling is applied on every forward pass; the bias starts at zero.
    """

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # NOTE(review): this nested layer is never used by forward(). It is
        # kept because it registers ``linear.weight`` / ``linear.bias``,
        # which are part of this module's state dict.
        self.linear = nn.Linear(in_features, out_features)
        self.wtScale = 1 / np.sqrt(in_features)
        nn.init.constant_(self.bias, val=0)

    def _get_weight(self):
        """Return the weight multiplied by the fan-in scale."""
        return self.weight * self.wtScale

    def forward_linear(self, x, weight):
        """Apply ``x @ weight.T + bias`` using the supplied weight."""
        return nn.functional.linear(x, weight, self.bias)

    def forward(self, x):
        scaled_weight = self._get_weight()
        return self.forward_linear(x, scaled_weight)
class OneHotPoseConcat(nn.Module):
    """Concatenate the landmark map for the current resolution onto ``x``."""

    def forward(self, _inp):
        """Append ``batch["landmarks_oh"][res]`` to ``x`` along dim 1.

        ``res`` is the size of the last dimension of ``x``. The consumed
        entry is removed from ``batch["landmarks_oh"]`` afterwards.

        Returns:
            ``(x, mask, batch)`` with the widened ``x``.
        """
        x, mask, batch = _inp
        landmarks = batch["landmarks_oh"]
        res = x.shape[-1]
        landmark = landmarks[res]
        x = torch.cat((x, landmark), dim=1)
        # Drop the entry once used so each resolution is consumed only once.
        del batch["landmarks_oh"][res]
        return x, mask, batch
def transition_features(x_old, x_new, transition_variable):
    """Linearly interpolate from ``x_old`` to ``x_new``.

    Both tensors are cast to float; ``transition_variable`` is the lerp
    weight (0 gives ``x_old``, 1 gives ``x_new``). Shapes must match.
    """
    same_shape = x_old.shape == x_new.shape
    assert same_shape,\
        "Old shape: {}, New: {}".format(x_old.shape, x_new.shape)
    start, end = x_old.float(), x_new.float()
    return torch.lerp(start, end, transition_variable)
class TransitionBlock(nn.Module):
    """Blend the stored old features/mask with the current ones.

    Reads ``batch["x_old"]``, ``batch["mask_old"]`` and
    ``batch["transition_value"]``; the two ``*_old`` entries are removed
    from ``batch`` once both blends have been computed.
    """

    def forward(self, _inp):
        x, mask, batch = _inp
        x = transition_features(batch["x_old"], x, batch["transition_value"])
        mask = transition_features(
            batch["mask_old"], mask, batch["transition_value"])
        for consumed_key in ("x_old", "mask_old"):
            del batch[consumed_key]
        return x, mask, batch
class UnetSkipConnection(nn.Module):
    """Merge decoder features with the encoder features of one resolution.

    The encoder features are read from
    ``batch["unet_features"][resolution]`` and removed after use. Three
    merge modes exist:

    * ``conv2d_config.conv.type == "iconv"``: mask-weighted blend with two
      learnable scalars ``beta0`` / ``beta1``.
    * ``residual``: the skip features go through a 1x1 conv and are added,
      scaled by ``1 / sqrt(2)``.
    * otherwise: channel concatenation followed by a 1x1 ``ConvAct``.

    When ``enabled`` is false the input tuple is returned untouched.
    """

    def __init__(self, conv2d_config: dict, in_channels: int,
                 out_channels: int, resolution: int,
                 residual: bool, enabled: bool):
        # Required before parameters/submodules can be registered.
        super().__init__()
        self.use_iconv = conv2d_config.conv.type == "iconv"
        self._in_channels = in_channels
        self._out_channels = out_channels
        self._resolution = resolution
        self._enabled = enabled
        self._residual = residual
        if self.use_iconv:
            self.beta0 = torch.nn.Parameter(torch.tensor(1.))
            self.beta1 = torch.nn.Parameter(torch.tensor(1.))
        elif self._residual:
            self.conv = build_base_conv(
                conv2d_config, False, in_channels // 2,
                out_channels, kernel_size=1, padding=0)
        else:
            self.conv = ConvAct(
                conv2d_config, in_channels, out_channels,
                kernel_size=1, padding=0)

    def forward(self, _inp):
        """Merge ``x``/``mask`` with the stored skip features.

        Args:
            _inp: ``(x, mask, batch)``; ``batch["unet_features"]`` must hold
                a ``(skip_x, skip_mask)`` pair under ``self._resolution``.

        Returns:
            ``(x, mask, batch)`` after merging, or ``_inp`` itself when the
            connection is disabled.
        """
        if not self._enabled:
            return _inp
        x, mask, batch = _inp
        skip_x, skip_mask = batch["unet_features"][self._resolution]
        assert x.shape == skip_x.shape, (x.shape, skip_x.shape)
        del batch["unet_features"][self._resolution]
        if self.use_iconv:
            # 1e-8 guards against a zero denominator where both masks are 0.
            denom = skip_mask * self.beta0.relu() + mask * self.beta1.relu() + 1e-8
            gamma = skip_mask * self.beta0.relu() / denom
            x = skip_x * gamma + (1 - gamma) * x
            mask = skip_mask * gamma + (1 - gamma) * mask
        elif self._residual:
            skip_x, skip_mask = self.conv((skip_x, skip_mask))
            x = (x + skip_x) / np.sqrt(2)
            # NOTE(review): ``_probabilistic`` is never assigned in this
            # class; default to False instead of raising AttributeError.
            if getattr(self, "_probabilistic", False):
                mask = (mask + skip_mask) / np.sqrt(2)
        else:
            x = torch.cat((x, skip_x), dim=1)
            x, mask = self.conv((x, mask))
        return x, mask, batch

    def __repr__(self):
        return " ".join([
            f"In channels={self._in_channels}",
            f"Out channels={self._out_channels}",
            f"Residual: {self._residual}",
            f"Enabled: {self._enabled}",
            f"IConv: {self.use_iconv}"
        ])
def get_conv(ctype, post_act):
    """Return the convolution class registered for ``ctype``.

    Layers without a post-activation (``post_act`` false) always get the
    plain ``Conv2d`` when ``ctype`` is ``"gconv"`` or ``"iconv"``.

    Raises:
        AssertionError: if ``ctype`` has no registered class.
    """
    type2conv = {
        "conv": Conv2d,
        "gconv": GatedConv,
    }
    # Do not apply for output layer
    if not post_act and ctype in ["gconv", "iconv"]:
        return type2conv["conv"]
    assert ctype in type2conv, f"Unknown conv type: {ctype}"
    return type2conv[ctype]
def build_base_conv(
        conv2d_config, post_act: bool, *args, **kwargs) -> nn.Conv2d:
    """Instantiate the convolution described by ``conv2d_config``.

    Every entry of ``conv2d_config.conv`` is forwarded as a keyword
    argument (it must not collide with an explicit keyword), and
    ``demodulation`` is set from the config. Remaining ``*args`` /
    ``**kwargs`` go to the conv class chosen by ``get_conv``.
    """
    for option, value in conv2d_config.conv.items():
        assert option not in kwargs
        kwargs[option] = value
    # Demodulation should not be used for output layers.
    wants_demodulation = conv2d_config.normalization == "demodulation"
    kwargs["demodulation"] = wants_demodulation and post_act
    conv_cls = get_conv(conv2d_config.conv.type, post_act)
    return conv_cls(*args, **kwargs)
def build_post_activation(in_channels, conv2d_config) -> List[nn.Module]:
    """Build the layers that follow a convolution.

    Always a ``LeakyReLU`` with slope ``conv2d_config.leaky_relu_nslope``;
    a ``PixelwiseNormalization`` is added when
    ``conv2d_config.normalization == "pixel_wise"``. ``in_channels`` is
    not used.

    Returns:
        The list of layers, for every normalization setting.
    """
    _layers = []
    negative_slope = conv2d_config.leaky_relu_nslope
    _layers.append(LeakyReLU(negative_slope, inplace=True))
    if conv2d_config.normalization == "pixel_wise":
        # NOTE(review): the body of this branch is absent from the available
        # source; appending PixelwiseNormalization is the reconstruction
        # implied by the branch condition -- TODO confirm.
        _layers.append(PixelwiseNormalization())
    return _layers
def build_avgpool(conv2d_config, kernel_size) -> nn.AvgPool2d:
    """Create the mask-aware ``AvgPool2d``; ``conv2d_config`` is not used."""
    pooling_layer = AvgPool2d(kernel_size)
    return pooling_layer
def build_convact(conv2d_config, *args, **kwargs):
    """Return ``nn.Sequential(conv, *post_activation_layers)``.

    The conv is built with ``post_act=True``; ``*args`` / ``**kwargs`` are
    forwarded to ``build_base_conv``.
    """
    base_conv = build_base_conv(conv2d_config, True, *args, **kwargs)
    activation_layers = build_post_activation(
        base_conv.out_channels, conv2d_config)
    return nn.Sequential(base_conv, *activation_layers)
class ConvAct(nn.Module):
    """Convolution followed by its post-activation layers.

    Exposes ``in_channels`` / ``out_channels`` of the wrapped convolution.
    """

    def __init__(self, conv2d_config, *args, **kwargs):
        # Must run before ``self.layers`` is assigned; nn.Module refuses to
        # register submodules otherwise.
        super().__init__()
        self._conv2d_config = conv2d_config
        conv = build_base_conv(conv2d_config, True, *args, **kwargs)
        self.in_channels = conv.in_channels
        self.out_channels = conv.out_channels
        _layers = [conv]
        _layers.extend(build_post_activation(self.out_channels, conv2d_config))
        self.layers = nn.Sequential(*_layers)

    def forward(self, _inp):
        """Run the input tuple through conv and post-activation layers."""
        return self.layers(_inp)
class GatedConv(Conv2d):
    """Convolution whose output is ``lrelu(features) * sigmoid(gate)``.

    The underlying convolution is created with twice the requested number
    of output channels; the first half are the features and the second
    half the gate. ``self.out_channels`` therefore holds the doubled value.
    """

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__(in_channels, out_channels * 2, *args, **kwargs)
        assert self.out_channels % 2 == 0
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def conv2d_forward(self, x, weight, bias=True):
        raw = super().conv2d_forward(x, weight, bias)
        half = self.out_channels // 2
        features, gate = raw[:, :half], raw[:, half:]
        features = self.lrelu(features)
        gate = gate.sigmoid()
        assert features.shape == gate.shape, f"{features.shape}, {gate.shape}"
        return features * gate
class BasicBlock(nn.Module):
    """Two 3x3 convolutions with post-activations and an optional residual.

    Args:
        conv2d_config: configuration forwarded to ``build_base_conv`` and
            ``build_post_activation``.
        resolution: spatial size this block expects on the last axis of
            its input (a size of 1 is also accepted).
        in_channels: number of input channels.
        out_channels: output channels of the first and second convolution.
        residual: add a 1x1-conv shortcut, scaled by ``1 / sqrt(2)``.
    """

    def __init__(
            self, conv2d_config, resolution: int, in_channels: int,
            out_channels: List[int], residual: bool):
        # Required before any submodule can be assigned.
        super().__init__()
        assert len(out_channels) == 2
        self._resolution = resolution
        self._residual = residual
        self.out_channels = out_channels
        _layers = []
        _in_channels = in_channels
        for out_ch in out_channels:
            # NOTE(review): the end of this call is truncated in the
            # available source; a further keyword argument may be missing.
            conv = build_base_conv(
                conv2d_config, True, _in_channels, out_ch, kernel_size=3)
            _layers.append(conv)
            _layers.extend(build_post_activation(_in_channels, conv2d_config))
            _in_channels = out_ch
        self.layers = nn.Sequential(*_layers)
        if self._residual:
            # The shortcut must produce as many channels as the main path so
            # that the two can be summed in forward().
            self.residual_conv = build_base_conv(
                conv2d_config, post_act=False, in_channels=in_channels,
                out_channels=out_channels[-1],
                kernel_size=1, padding=0)
            self.const = 1 / np.sqrt(2)

    def forward(self, _inp):
        """Apply the block to ``(x, mask, batch)``; ``batch`` is untouched."""
        x, mask, batch = _inp
        y = x
        mask_ = mask
        assert y.shape[-1] == self._resolution or y.shape[-1] == 1
        y, mask = self.layers((x, mask))
        if self._residual:
            residual, mask_ = self.residual_conv((x, mask_))
            y = (y + residual) * self.const
            mask = (mask + mask_) * self.const
        return y, mask, batch

    def extra_repr(self):
        return f"Residual={self._residual}, Resolution={self._resolution}"
class PoseNormalize(nn.Module):
    """Apply the affine map ``x -> 2 * x - 1``."""

    def forward(self, x):
        doubled = x * 2
        return doubled - 1
class ScalarPoseFCNN(nn.Module):
    """Map ``batch["landmarks"]`` to a feature map and append it to ``x``.

    Args:
        pose_size: length of the flattened landmark vector.
        hidden_size: width of the hidden linear layer.
        output_shape: shape (without batch axis) that the network output is
            reshaped to before concatenation.
    """

    def __init__(self, pose_size, hidden_size, output_shape):
        super().__init__()
        self._hidden_size = hidden_size
        # Number of scalars needed to fill ``output_shape`` (see the
        # ``view`` in forward()).
        output_size = int(np.prod(output_shape))
        self.output_shape = output_shape
        # NOTE(review): only these two Linear layers are present in the
        # available source; the Sequential is truncated there, so further
        # layers (e.g. activations) may be missing -- TODO confirm.
        self.pose_preprocessor = nn.Sequential(
            Linear(pose_size, hidden_size),
            Linear(hidden_size, output_size),
        )

    def forward(self, _inp):
        """Consume ``batch["landmarks"]`` and concatenate the result to ``x``.

        The ``"landmarks"`` entry is removed from ``batch``.
        """
        x, mask, batch = _inp
        pose_info = batch["landmarks"]
        del batch["landmarks"]
        pose = self.pose_preprocessor(pose_info)
        pose = pose.view(-1, *self.output_shape)
        if x.shape[0] == 1 and x.shape[2] == 1 and x.shape[3] == 1:
            # Analytical normalization propagation
            pose = pose.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        x = torch.cat((x, pose), dim=1)
        return x, mask, batch

    def __repr__(self):
        return " ".join([
            f"output shape={self.output_shape}"
        ])
class Attention(nn.Module):
    """Spatial self-attention block with a learnable residual gain.

    ``theta`` and ``phi`` project to ``in_channels // 8`` channels, ``g``
    to ``in_channels // 2``; ``phi`` and ``g`` are 2x2 max-pooled. The
    output is ``gamma * o + x`` with ``gamma`` initialised to 0, so the
    block starts as an identity mapping on ``x``.
    """

    def __init__(self, in_channels):
        super(Attention, self).__init__()
        # Channel multiplier
        self.in_channels = in_channels
        # NOTE(review): each of the four calls below ends with a trailing
        # comma and no closing parenthesis in the available source, so a
        # final keyword argument may have been lost -- TODO confirm.
        self.theta = Conv2d(
            self.in_channels, self.in_channels // 8, kernel_size=1, padding=0)
        self.phi = Conv2d(
            self.in_channels, self.in_channels // 8, kernel_size=1, padding=0)
        self.g = Conv2d(
            self.in_channels, self.in_channels // 2, kernel_size=1, padding=0)
        self.o = Conv2d(
            self.in_channels // 2, self.in_channels, kernel_size=1, padding=0)
        # Learnable gain parameter
        self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self, _inp):
        """Apply attention to ``x``; ``mask`` and ``batch`` pass through.

        The reshapes rely on ``H * W // 4`` positions remaining after the
        2x2 max-pooling.
        """
        x, mask, batch = _inp
        # Apply convs
        theta, _ = self.theta((x, None))
        phi = nn.functional.max_pool2d(self.phi((x, None))[0], [2, 2])
        g = nn.functional.max_pool2d(self.g((x, None))[0], [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.in_channels // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.in_channels // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.in_channels // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = nn.functional.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        attended = torch.bmm(g, beta.transpose(1, 2)).view(
            -1, self.in_channels // 2, x.shape[2], x.shape[3])
        o = self.o((attended, None))[0]
        return self.gamma * o + x, mask, batch
class MSGGenerator(BaseGenerator):
def __init__(self):
# Set up the fixed hyper-parameters of the generator. The constructor takes
# no arguments, so every value below is hard-coded.
# NOTE(review): no base-class initialisation call is visible in this method
# as extracted; one is needed before nn.ModuleDict attributes are assigned.
max_imsize = 128
unet = dict(enabled=True, residual=False)
min_fmap_resolution = 4
model_size = 512
image_channels = 3
pose_size = 14
residual = False
# Feature-map resolution -> channel count (copied into self.res2channels).
conv_size = {
4: model_size,
8: model_size,
16: model_size,
32: model_size,
64: model_size//2,
128: model_size//4,
256: model_size//8,
512: model_size//16
# NOTE(review): the closing "}" of conv_size is missing at this point.
self.removable_hooks = []
self.rgb_convolutions = nn.ModuleDict()
self.max_imsize = max_imsize
self._image_channels = image_channels
self._min_fmap_resolution = min_fmap_resolution
self._residual = residual
self._pose_size = pose_size
self.current_imsize = max_imsize
self._unet_cfg = unet
self.concat_input_mask = True
self.res2channels = {int(k): v for k, v in conv_size.items()}
# NOTE(review): the arguments and closing parenthesis of the EasyDict(...)
# call below are missing. Other code in this file reads
# conv2d_config.conv.type, conv2d_config.conv.items(),
# conv2d_config.normalization and conv2d_config.leaky_relu_nslope, so at
# least those keys are required. Calls to _init_encoder/_init_decoder are
# not visible here either -- TODO confirm against the upstream file.
self.conv2d_config = EasyDict(
def _init_encoder(self):
    """Build ``self.from_rgb`` and the encoder module list.

    One ``BasicBlock`` is added per resolution from ``max_imsize`` down to
    ``_min_fmap_resolution``; every block except the last is followed by a
    2x average pooling.
    """
    self.encoder = nn.ModuleList()
    imsize = self.max_imsize
    # NOTE(review): only the ``in_channels`` argument of this call survives
    # in the available source. ``out_channels`` is set to the input width of
    # the first BasicBlock; ``kernel_size=1`` is an assumption -- TODO confirm.
    self.from_rgb = build_convact(
        self.conv2d_config,
        in_channels=self._image_channels + self.concat_input_mask*2,
        out_channels=self.res2channels[imsize],
        kernel_size=1)
    while imsize >= self._min_fmap_resolution:
        current_size = self.res2channels[imsize]
        next_size = self.res2channels[max(imsize//2, self._min_fmap_resolution)]
        block = BasicBlock(
            self.conv2d_config, imsize, current_size,
            [current_size, next_size], self._residual)
        self.encoder.add_module(f"basic_block{imsize}", block)
        if imsize != self._min_fmap_resolution:
            self.encoder.add_module(
                f"downsample{imsize}", AvgPool2d(2))
        imsize //= 2
def _init_decoder(self):
    """Build the decoder module list and the per-resolution RGB convs.

    The decoder starts with the latent concatenation (and the pose network
    when ``_pose_size > 0``). For each resolution from
    ``_min_fmap_resolution`` up to ``max_imsize`` it adds a ``BasicBlock``;
    every resolution above the minimum is preceded by an upsample and a
    U-Net skip connection. ``self.norm_constant`` is the number of RGB
    convolutions.
    """
    self.decoder = nn.ModuleList()
    self.decoder.add_module(
        "latent_concat", LatentVariableConcat(self.conv2d_config))
    if self._pose_size > 0:
        m = self._min_fmap_resolution
        pose_shape = (16, m, m)
        pose_fcnn = ScalarPoseFCNN(self._pose_size, 128, pose_shape)
        self.decoder.add_module("pose_fcnn", pose_fcnn)
    imsize = self._min_fmap_resolution
    self.rgb_convolutions = nn.ModuleDict()
    while imsize <= self.max_imsize:
        current_size = self.res2channels[max(imsize//2, self._min_fmap_resolution)]
        start_size = current_size
        if imsize == self._min_fmap_resolution:
            # The lowest resolution receives the concatenated latent
            # (32 channels) and, if enabled, the pose map (16 channels).
            start_size += 32
            if self._pose_size > 0:
                start_size += 16
        else:
            self.decoder.add_module(f"upsample{imsize}", NearestUpsample())
            # ``_unet_cfg`` supplies the ``residual`` and ``enabled``
            # arguments of UnetSkipConnection.
            skip = UnetSkipConnection(
                self.conv2d_config, current_size*2, current_size, imsize,
                **self._unet_cfg)
            self.decoder.add_module(f"skip_connection{imsize}", skip)
        next_size = self.res2channels[imsize]
        block = BasicBlock(
            self.conv2d_config, imsize, start_size, [start_size, next_size],
            self._residual)
        self.decoder.add_module(f"basic_block{imsize}", block)
        to_rgb = build_base_conv(
            self.conv2d_config, False, in_channels=next_size,
            out_channels=self._image_channels, kernel_size=1)
        self.rgb_convolutions[str(imsize)] = to_rgb
        imsize *= 2
    self.norm_constant = len(self.rgb_convolutions)
def forward_decoder(self, x, mask, batch):
    """Run the decoder and accumulate RGB outputs over all resolutions.

    After every ``BasicBlock`` the running RGB image is upsampled and the
    output of that resolution's RGB convolution is added. The sum is
    divided by ``self.norm_constant``.

    Returns:
        ``(rgb, mask_out)``. NOTE(review): ``mask_out`` is only ever
        upsampled from zeros here, so it is returned as an all-zero tensor.
    """
    n_samples = x.shape[0]
    imsize_start = max(x.shape[-1] // 2, 1)
    placement = dict(dtype=x.dtype, device=x.device)
    rgb = torch.zeros(
        (n_samples, self._image_channels, imsize_start, imsize_start),
        **placement)
    mask_size = 1
    mask_out = torch.zeros(
        (n_samples, mask_size, imsize_start, imsize_start),
        **placement)
    imsize = self._min_fmap_resolution // 2
    for module in self.decoder:
        x, mask, batch = module((x, mask, batch))
        if not isinstance(module, BasicBlock):
            continue
        imsize *= 2
        rgb, mask_out = up(rgb), up(mask_out)
        to_rgb = self.rgb_convolutions[str(imsize)]
        rgb_, mask_ = to_rgb((x, mask))
        assert rgb_.shape == rgb.shape,\
            f"rgb_ {rgb_.shape}, rgb: {rgb.shape}"
        rgb = rgb + rgb_
    return rgb / self.norm_constant, mask_out
def forward_encoder(self, x, mask, batch):
    """Run the encoder and collect the features for the skip connections.

    When ``concat_input_mask`` is set, ``mask`` and ``1 - mask`` are
    appended to ``x`` along the channel axis before ``from_rgb``.

    Returns:
        ``(x, mask, unet_features)`` where ``unet_features`` maps each
        ``BasicBlock`` resolution to the ``(x, mask)`` pair it produced.
    """
    if self.concat_input_mask:
        x = torch.cat((x, mask, 1 - mask), dim=1)
    unet_features = {}
    x, mask = self.from_rgb((x, mask))
    for module in self.encoder:
        x, mask, batch = module((x, mask, batch))
        if isinstance(module, BasicBlock):
            unet_features[module._resolution] = (x, mask)
    return x, mask, unet_features
def forward(self, condition, mask, keypoints=None, z=None, **kwargs):
    """Generate an image and paste it into the masked-out region.

    Args:
        condition: input image; kept wherever ``mask`` is 1.
        mask: blending mask, same spatial size as ``condition``.
        keypoints: landmark tensor; flattened per sample and clipped to
            ``[-1, 1]`` before use.
        z: latent tensor; obtained from ``self.get_z(condition)`` when
            ``None``. It is reshaped to ``(-1, 32, 4, 4)``.
        **kwargs: ignored. NOTE(review): the parameter list is truncated in
            the available source; ``**kwargs`` is a permissive stand-in.

    Returns:
        ``dict(img=...)`` with
        ``condition * mask + (1 - mask) * generated``.
    """
    keypoints = keypoints.flatten(start_dim=1).clip(-1, 1)
    if z is None:
        z = self.get_z(condition)
    z = z.view(-1, 32, 4, 4)
    # NOTE(review): the contents of both dict(...) calls are missing from
    # the available source. The keys below are the ones the decoder modules
    # read: "z" (LatentVariableConcat), "landmarks" (ScalarPoseFCNN) and
    # "unet_features" (UnetSkipConnection).
    batch = dict(
        landmarks=keypoints,
        z=z)
    orig_mask = mask
    x, mask, unet_features = self.forward_encoder(condition, mask, batch)
    batch = dict(
        landmarks=keypoints,
        z=z,
        unet_features=unet_features)
    x, mask = self.forward_decoder(x, mask, batch)
    x = condition * orig_mask + (1 - orig_mask) * x
    return dict(img=x)
def load_state_dict(self, state_dict, strict=True):
    """Load weights, translating the key layout of old checkpoints.

    A checkpoint may wrap its weights under a ``"parameters"`` key. Old
    checkpoints (any key containing ``"basic_block0"``) index blocks by
    position; their keys are rewritten to the resolution-based names used
    by this class before loading.
    """
    if "parameters" in state_dict:
        state_dict = state_dict["parameters"]
    old_checkpoint = any("basic_block0" in key for key in state_dict)
    if not old_checkpoint:
        return super().load_state_dict(state_dict, strict=strict)
    mapping = {"from_rgb.conv.layers.0.": "from_rgb.0."}
    imsize = self._min_fmap_resolution
    i = 0
    while imsize <= self.max_imsize:
        old_key = f"decoder.basic_block{i}."
        new_key = f"decoder.basic_block{imsize}."
        mapping[old_key] = new_key
        if i >= 1:
            # Skip connections exist only above the lowest resolution.
            old_key = old_key.replace("basic_block", "skip_connection")
            new_key = new_key.replace("basic_block", "skip_connection")
            mapping[old_key] = new_key
        old_key = f"encoder.basic_block{i}."
        new_key = f"encoder.basic_block{imsize}."
        mapping[old_key] = new_key
        i += 1
        imsize *= 2
    new_sd = {}
    for key, value in state_dict.items():
        if "from_rgb" in key:
            new_sd[key.replace("encoder.", "").replace(".conv.layers", "")] = value
            continue
        for subkey, new_subkey in mapping.items():
            if subkey in key:
                key = key.replace(subkey, new_subkey)
                # Stop after the first rewrite: a renamed key such as
                # "decoder.basic_block4." would otherwise match the old
                # index 4 and be rewritten a second time.
                break
        # NOTE(review): the control flow around this test is missing from
        # the available source; skipping "decoder.to_rgb" keys and copying
        # all others is the reconstruction -- TODO confirm.
        if "decoder.to_rgb" in key:
            continue
        new_sd[key] = value
    return super().load_state_dict(new_sd, strict=strict)
def update_w(self, *args, **kwargs):