from typing import List

import numpy as np
import torch
import torch.nn as nn
from easydict import EasyDict

from .base import BaseGenerator
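
# Every layer below operates on a tuple -- (x, mask) or (x, mask, batch) -- so the
# blocks can be chained with nn.Sequential / nn.ModuleList: `x` is the feature map,
# `mask` the correspondingly resized mask, and `batch` a dict carrying side inputs
# such as the latent code "z", pose landmarks and U-Net skip features.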


class LatentVariableConcat(nn.Module):

    def __init__(self, conv2d_config):
        super().__init__()

    def forward(self, _inp):
        x, mask, batch = _inp
        z = batch["z"]
        x = torch.cat((x, z), dim=1)
        return (x, mask, batch)


def get_padding(kernel_size: int, dilation: int, stride: int):
    out = (dilation * (kernel_size - 1) - 1) / 2 + 1
    return int(np.floor(out))
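
# Example: a 3x3 convolution with dilation=1 preserves spatial size at stride 1,
# since get_padding(3, 1, 1) == int(np.floor((1 * 2 - 1) / 2 + 1)) == 1.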


class Conv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=None, dilation=1, groups=1,
                 bias=True, padding_mode='zeros',
                 demodulation=False, wsconv=False, gain=1,
                 *args, **kwargs):
        if padding is None:
            padding = get_padding(kernel_size, dilation, stride)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode)
        self.demodulation = demodulation
        self.wsconv = wsconv
        if self.wsconv:
            fan_in = np.prod(self.weight.shape[1:]) / self.groups
            self.ws_scale = gain / np.sqrt(fan_in)
            nn.init.normal_(self.weight)
            if bias:
                nn.init.constant_(self.bias, val=0)
        assert not self.padding_mode == "circular",\
            "conv2d_forward does not support circular padding. Look at original pytorch code"

    def _get_weight(self):
        weight = self.weight
        if self.wsconv:
            weight = self.ws_scale * weight
        if self.demodulation:
            demod = torch.rsqrt(weight.pow(2).sum([1, 2, 3]) + 1e-7)
            weight = weight * demod.view(self.out_channels, 1, 1, 1)
        return weight

    def conv2d_forward(self, x, weight, bias=True):
        bias_ = None
        if bias:
            bias_ = self.bias
        return nn.functional.conv2d(x, weight, bias_, self.stride,
                                    self.padding, self.dilation, self.groups)

    def forward(self, _inp):
        x, mask = _inp
        weight = self._get_weight()
        return self.conv2d_forward(x, weight), mask

    def __repr__(self):
        return ", ".join([
            super().__repr__(),
            f"Demodulation={self.demodulation}",
            f"Weight Scale={self.wsconv}",
            f"Bias={self.bias is not None}"
        ])


class LeakyReLU(nn.LeakyReLU):

    def forward(self, _inp):
        x, mask = _inp
        return super().forward(x), mask


class AvgPool2d(nn.AvgPool2d):

    def forward(self, _inp):
        x, mask, *args = _inp
        x = super().forward(x)
        mask = super().forward(mask)
        if len(args) > 0:
            return (x, mask, *args)
        return x, mask


def up(x):
    if x.shape[0] == 1 and x.shape[2] == 1 and x.shape[3] == 1:
        # Analytical normalization
        return x
    return nn.functional.interpolate(
        x, scale_factor=2, mode="nearest")


class NearestUpsample(nn.Module):

    def forward(self, _inp):
        x, mask, *args = _inp
        x = up(x)
        mask = up(mask)
        if len(args) > 0:
            return (x, mask, *args)
        return x, mask


class PixelwiseNormalization(nn.Module):

    def forward(self, _inp):
        x, mask = _inp
        norm = torch.rsqrt((x**2).mean(dim=1, keepdim=True) + 1e-7)
        return x * norm, mask


class Linear(nn.Linear):

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # Note: this sub-module is never used in forward(); it is kept as-is so
        # that existing state_dict keys do not change.
        self.linear = nn.Linear(in_features, out_features)
        fan_in = in_features
        self.wtScale = 1 / np.sqrt(fan_in)
        nn.init.normal_(self.weight)
        nn.init.constant_(self.bias, val=0)

    def _get_weight(self):
        return self.weight * self.wtScale

    def forward_linear(self, x, weight):
        return nn.functional.linear(x, weight, self.bias)

    def forward(self, x):
        return self.forward_linear(x, self._get_weight())


class OneHotPoseConcat(nn.Module):

    def forward(self, _inp):
        x, mask, batch = _inp
        landmarks = batch["landmarks_oh"]
        res = x.shape[-1]
        landmark = landmarks[res]
        x = torch.cat((x, landmark), dim=1)
        del batch["landmarks_oh"][res]
        return x, mask, batch


def transition_features(x_old, x_new, transition_variable):
    assert x_old.shape == x_new.shape,\
        "Old shape: {}, New: {}".format(x_old.shape, x_new.shape)
    return torch.lerp(x_old.float(), x_new.float(), transition_variable)


class TransitionBlock(nn.Module):

    def forward(self, _inp):
        x, mask, batch = _inp
        x = transition_features(
            batch["x_old"], x, batch["transition_value"])
        mask = transition_features(
            batch["mask_old"], mask, batch["transition_value"])
        del batch["x_old"]
        del batch["mask_old"]
        return x, mask, batch


class UnetSkipConnection(nn.Module):

    def __init__(self, conv2d_config: dict, in_channels: int,
                 out_channels: int, resolution: int,
                 residual: bool, enabled: bool):
        super().__init__()
        self.use_iconv = conv2d_config.conv.type == "iconv"
        self._in_channels = in_channels
        self._out_channels = out_channels
        self._resolution = resolution
        self._enabled = enabled
        self._residual = residual
        if self.use_iconv:
            self.beta0 = torch.nn.Parameter(torch.tensor(1.))
            self.beta1 = torch.nn.Parameter(torch.tensor(1.))
        else:
            if self._residual:
                self.conv = build_base_conv(
                    conv2d_config, False, in_channels // 2,
                    out_channels, kernel_size=1, padding=0)
            else:
                self.conv = ConvAct(
                    conv2d_config, in_channels, out_channels,
                    kernel_size=1, padding=0)

    def forward(self, _inp):
        if not self._enabled:
            return _inp
        x, mask, batch = _inp
        skip_x, skip_mask = batch["unet_features"][self._resolution]
        assert x.shape == skip_x.shape, (x.shape, skip_x.shape)
        del batch["unet_features"][self._resolution]
        if self.use_iconv:
            denom = skip_mask * self.beta0.relu() + mask * self.beta1.relu() + 1e-8
            gamma = skip_mask * self.beta0.relu() / denom
            x = skip_x * gamma + (1 - gamma) * x
            mask = skip_mask * gamma + (1 - gamma) * mask
        else:
            if self._residual:
                skip_x, skip_mask = self.conv((skip_x, skip_mask))
                x = (x + skip_x) / np.sqrt(2)
                # `_probabilistic` is never set in __init__; default to False so the
                # residual path does not raise an AttributeError.
                if getattr(self, "_probabilistic", False):
                    mask = (mask + skip_mask) / np.sqrt(2)
            else:
                x = torch.cat((x, skip_x), dim=1)
                x, mask = self.conv((x, mask))
        return x, mask, batch

    def __repr__(self):
        return " ".join([
            self.__class__.__name__,
            f"In channels={self._in_channels}",
            f"Out channels={self._out_channels}",
            f"Residual: {self._residual}",
            f"Enabled: {self._enabled}",
            f"IConv: {self.use_iconv}"
        ])


def get_conv(ctype, post_act):
    type2conv = {
        "conv": Conv2d,
        "gconv": GatedConv
    }
    # Do not apply for output layer
    if not post_act and ctype in ["gconv", "iconv"]:
        return type2conv["conv"]
    assert ctype in type2conv
    return type2conv[ctype]


def build_base_conv(
        conv2d_config, post_act: bool, *args, **kwargs) -> nn.Conv2d:
    for k, v in conv2d_config.conv.items():
        assert k not in kwargs
        kwargs[k] = v
    # Demodulation should not be used for output layers.
    demodulation = conv2d_config.normalization == "demodulation" and post_act
    kwargs["demodulation"] = demodulation
    conv = get_conv(conv2d_config.conv.type, post_act)
    return conv(*args, **kwargs)


def build_post_activation(in_channels, conv2d_config) -> List[nn.Module]:
    _layers = []
    negative_slope = conv2d_config.leaky_relu_nslope
    _layers.append(LeakyReLU(negative_slope, inplace=True))
    if conv2d_config.normalization == "pixel_wise":
        _layers.append(PixelwiseNormalization())
    return _layers


def build_avgpool(conv2d_config, kernel_size) -> nn.AvgPool2d:
    return AvgPool2d(kernel_size)


def build_convact(conv2d_config, *args, **kwargs):
    conv = build_base_conv(conv2d_config, True, *args, **kwargs)
    out_channels = conv.out_channels
    post_act = build_post_activation(out_channels, conv2d_config)
    return nn.Sequential(conv, *post_act)


class ConvAct(nn.Module):

    def __init__(self, conv2d_config, *args, **kwargs):
        super().__init__()
        self._conv2d_config = conv2d_config
        conv = build_base_conv(conv2d_config, True, *args, **kwargs)
        self.in_channels = conv.in_channels
        self.out_channels = conv.out_channels
        _layers = [conv]
        _layers.extend(build_post_activation(self.out_channels, conv2d_config))
        self.layers = nn.Sequential(*_layers)

    def forward(self, _inp):
        return self.layers(_inp)


class GatedConv(Conv2d):

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        # The convolution outputs twice the channels: the first half is the
        # feature map, the second half a sigmoid gate that scales it.
        out_channels *= 2
        super().__init__(in_channels, out_channels, *args, **kwargs)
        assert self.out_channels % 2 == 0
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def conv2d_forward(self, x, weight, bias=True):
        x_ = super().conv2d_forward(x, weight, bias)
        x = x_[:, :self.out_channels // 2]
        y = x_[:, self.out_channels // 2:]
        x = self.lrelu(x)
        y = y.sigmoid()
        assert x.shape == y.shape, f"{x.shape}, {y.shape}"
        return x * y


class BasicBlock(nn.Module):

    def __init__(
            self, conv2d_config, resolution: int, in_channels: int,
            out_channels: List[int], residual: bool):
        super().__init__()
        assert len(out_channels) == 2
        self._resolution = resolution
        self._residual = residual
        self.out_channels = out_channels
        _layers = []
        _in_channels = in_channels
        for out_ch in out_channels:
            conv = build_base_conv(
                conv2d_config, True, _in_channels, out_ch, kernel_size=3,
                resolution=resolution)
            _layers.append(conv)
            _layers.extend(build_post_activation(out_ch, conv2d_config))
            _in_channels = out_ch
        self.layers = nn.Sequential(*_layers)
        if self._residual:
            self.residual_conv = build_base_conv(
                conv2d_config, post_act=False, in_channels=in_channels,
                out_channels=out_channels[-1],
                kernel_size=1, padding=0)
            # Scale the sum by 1/sqrt(2) to keep its variance roughly constant.
            self.const = 1 / np.sqrt(2)

    def forward(self, _inp):
        x, mask, batch = _inp
        y = x
        mask_ = mask
        assert y.shape[-1] == self._resolution or y.shape[-1] == 1
        y, mask = self.layers((x, mask))
        if self._residual:
            residual, mask_ = self.residual_conv((x, mask_))
            y = (y + residual) * self.const
            mask = (mask + mask_) * self.const
        return y, mask, batch

    def extra_repr(self):
        return f"Residual={self._residual}, Resolution={self._resolution}"


class PoseNormalize(nn.Module):

    def forward(self, x):
        return x * 2 - 1


class ScalarPoseFCNN(nn.Module):

    def __init__(self, pose_size, hidden_size,
                 output_shape):
        super().__init__()
        self._hidden_size = hidden_size
        output_size = np.prod(output_shape)
        self.output_shape = output_shape
        self.pose_preprocessor = nn.Sequential(
            PoseNormalize(),
            Linear(pose_size, hidden_size),
            nn.LeakyReLU(.2),
            Linear(hidden_size, output_size),
            nn.LeakyReLU(.2)
        )

    def forward(self, _inp):
        x, mask, batch = _inp
        pose_info = batch["landmarks"]
        del batch["landmarks"]
        pose = self.pose_preprocessor(pose_info)
        pose = pose.view(-1, *self.output_shape)
        if x.shape[0] == 1 and x.shape[2] == 1 and x.shape[3] == 1:
            # Analytical normalization propagation
            pose = pose.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        x = torch.cat((x, pose), dim=1)
        return x, mask, batch

    def __repr__(self):
        return " ".join([
            self.__class__.__name__,
            f"hidden_size={self._hidden_size}",
            f"output shape={self.output_shape}"
        ])


class Attention(nn.Module):
    """Self-attention over the feature map (SAGAN-style theta/phi/g/o projections)."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # 1x1 projections for query (theta), key (phi), value (g) and output (o).
        self.theta = Conv2d(
            self.in_channels, self.in_channels // 8, kernel_size=1, padding=0,
            bias=False)
        self.phi = Conv2d(
            self.in_channels, self.in_channels // 8, kernel_size=1, padding=0,
            bias=False)
        self.g = Conv2d(
            self.in_channels, self.in_channels // 2, kernel_size=1, padding=0,
            bias=False)
        self.o = Conv2d(
            self.in_channels // 2, self.in_channels, kernel_size=1, padding=0,
            bias=False)
        # Learnable gain parameter
        self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self, _inp):
        x, mask, batch = _inp
        # Apply convs
        theta, _ = self.theta((x, None))
        phi = nn.functional.max_pool2d(self.phi((x, None))[0], [2, 2])
        g = nn.functional.max_pool2d(self.g((x, None))[0], [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.in_channels // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.in_channels // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.in_channels // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = nn.functional.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o((torch.bmm(g, beta.transpose(1, 2)).view(-1,
                    self.in_channels // 2, x.shape[2], x.shape[3]), None))[0]
        return self.gamma * o + x, mask, batch


class MSGGenerator(BaseGenerator):
    """U-Net style multi-scale gradient (MSG) generator that inpaints masked pixels,
    conditioned on a latent code and scalar pose keypoints."""

    def __init__(self):
        super().__init__(512)
        max_imsize = 128
        unet = dict(enabled=True, residual=False)
        min_fmap_resolution = 4
        model_size = 512
        image_channels = 3
        pose_size = 14
        residual = False
        conv_size = {
            4: model_size,
            8: model_size,
            16: model_size,
            32: model_size,
            64: model_size // 2,
            128: model_size // 4,
            256: model_size // 8,
            512: model_size // 16
        }
        self.removable_hooks = []
        self.rgb_convolutions = nn.ModuleDict()
        self.max_imsize = max_imsize
        self._image_channels = image_channels
        self._min_fmap_resolution = min_fmap_resolution
        self._residual = residual
        self._pose_size = pose_size
        self.current_imsize = max_imsize
        self._unet_cfg = unet
        self.concat_input_mask = True
        self.res2channels = {int(k): v for k, v in conv_size.items()}
        self.conv2d_config = EasyDict(
            pixel_normalization=True,
            leaky_relu_nslope=.2,
            normalization="pixel_wise",
            conv=dict(
                type="conv",
                wsconv=True,
                gain=1,
            )
        )
        self._init_decoder()
        self._init_encoder()

    def _init_encoder(self):
        self.encoder = nn.ModuleList()
        imsize = self.max_imsize
        self.from_rgb = build_convact(
            self.conv2d_config,
            in_channels=self._image_channels + self.concat_input_mask * 2,
            out_channels=self.res2channels[imsize],
            kernel_size=1)
        while imsize >= self._min_fmap_resolution:
            current_size = self.res2channels[imsize]
            next_size = self.res2channels[max(imsize // 2, self._min_fmap_resolution)]
            block = BasicBlock(
                self.conv2d_config, imsize, current_size,
                [current_size, next_size], self._residual)
            self.encoder.add_module(f"basic_block{imsize}", block)
            if imsize != self._min_fmap_resolution:
                self.encoder.add_module(
                    f"downsample{imsize}", AvgPool2d(2))
            imsize //= 2

    def _init_decoder(self):
        self.decoder = nn.ModuleList()
        self.decoder.add_module(
            "latent_concat", LatentVariableConcat(self.conv2d_config))
        if self._pose_size > 0:
            m = self._min_fmap_resolution
            pose_shape = (16, m, m)
            pose_fcnn = ScalarPoseFCNN(self._pose_size, 128, pose_shape)
            self.decoder.add_module("pose_fcnn", pose_fcnn)
        imsize = self._min_fmap_resolution
        self.rgb_convolutions = nn.ModuleDict()
        while imsize <= self.max_imsize:
            current_size = self.res2channels[max(imsize // 2, self._min_fmap_resolution)]
            start_size = current_size
            if imsize == self._min_fmap_resolution:
                start_size += 32
                if self._pose_size > 0:
                    start_size += 16
            else:
                self.decoder.add_module(f"upsample{imsize}", NearestUpsample())
                skip = UnetSkipConnection(
                    self.conv2d_config, current_size * 2, current_size, imsize,
                    **self._unet_cfg)
                self.decoder.add_module(f"skip_connection{imsize}", skip)
            next_size = self.res2channels[imsize]
            block = BasicBlock(
                self.conv2d_config, imsize, start_size, [start_size, next_size],
                residual=self._residual)
            self.decoder.add_module(f"basic_block{imsize}", block)
            to_rgb = build_base_conv(
                self.conv2d_config, False, in_channels=next_size,
                out_channels=self._image_channels, kernel_size=1)
            self.rgb_convolutions[str(imsize)] = to_rgb
            imsize *= 2
        self.norm_constant = len(self.rgb_convolutions)

    def forward_decoder(self, x, mask, batch):
        # RGB outputs are produced at every resolution, upsampled and summed
        # (MSG-style), then averaged over the number of outputs.
        imsize_start = max(x.shape[-1] // 2, 1)
        rgb = torch.zeros(
            (x.shape[0], self._image_channels,
             imsize_start, imsize_start),
            dtype=x.dtype, device=x.device)
        mask_size = 1
        mask_out = torch.zeros(
            (x.shape[0], mask_size,
             imsize_start, imsize_start),
            dtype=x.dtype, device=x.device)
        imsize = self._min_fmap_resolution // 2
        for module in self.decoder:
            x, mask, batch = module((x, mask, batch))
            if isinstance(module, BasicBlock):
                imsize *= 2
                rgb = up(rgb)
                mask_out = up(mask_out)
                conv = self.rgb_convolutions[str(imsize)]
                rgb_, mask_ = conv((x, mask))
                assert rgb_.shape == rgb.shape,\
                    f"rgb_ {rgb_.shape}, rgb: {rgb.shape}"
                rgb = rgb + rgb_
        return rgb / self.norm_constant, mask_out

    def forward_encoder(self, x, mask, batch):
        if self.concat_input_mask:
            x = torch.cat((x, mask, 1 - mask), dim=1)
        unet_features = {}
        x, mask = self.from_rgb((x, mask))
        for module in self.encoder:
            x, mask, batch = module((x, mask, batch))
            if isinstance(module, BasicBlock):
                unet_features[module._resolution] = (x, mask)
        return x, mask, unet_features

    def forward(
            self,
            condition,
            mask, keypoints=None, z=None,
            **kwargs):
        # `mask` is 1 for pixels to keep from `condition` and 0 for pixels to generate.
        keypoints = keypoints.flatten(start_dim=1).clip(-1, 1)
        if z is None:
            z = self.get_z(condition)
        z = z.view(-1, 32, 4, 4)
        batch = dict(
            landmarks=keypoints,
            z=z)
        orig_mask = mask
        x, mask, unet_features = self.forward_encoder(condition, mask, batch)
        batch = dict(
            landmarks=keypoints,
            z=z,
            unet_features=unet_features)
        x, mask = self.forward_decoder(x, mask, batch)
        # Keep known pixels from the condition image, fill the rest with the output.
        x = condition * orig_mask + (1 - orig_mask) * x
        return dict(img=x)

    def load_state_dict(self, state_dict, strict=True):
        if "parameters" in state_dict:
            state_dict = state_dict["parameters"]
        old_checkpoint = any("basic_block0" in key for key in state_dict)
        if not old_checkpoint:
            return super().load_state_dict(state_dict, strict=strict)
        # Remap keys from old checkpoints (blocks indexed by position) to the
        # current naming scheme (blocks indexed by resolution).
        mapping = {}
        imsize = self._min_fmap_resolution
        i = 0
        while imsize <= self.max_imsize:
            old_key = f"decoder.basic_block{i}."
            new_key = f"decoder.basic_block{imsize}."
            mapping[old_key] = new_key
            if i >= 1:
                old_key = old_key.replace("basic_block", "skip_connection")
                new_key = new_key.replace("basic_block", "skip_connection")
                mapping[old_key] = new_key
            old_key = f"encoder.basic_block{i}."
            new_key = f"encoder.basic_block{imsize}."
            mapping[old_key] = new_key
            old_key = "from_rgb.conv.layers.0."
            new_key = "from_rgb.0."
            mapping[old_key] = new_key
            i += 1
            imsize *= 2
        new_sd = {}
        for key, value in state_dict.items():
            old_key = key
            if "from_rgb" in key:
                new_sd[key.replace("encoder.", "").replace(".conv.layers", "")] = value
                continue
            for subkey, new_subkey in mapping.items():
                if subkey in key:
                    old_key = key
                    key = key.replace(subkey, new_subkey)
                    break
            if "decoder.to_rgb" in key:
                continue
            new_sd[key] = value
        return super().load_state_dict(new_sd, strict=strict)

    def update_w(self, *args, **kwargs):
        return
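

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumptions: the module is run as part of its
    # package (so `from .base import BaseGenerator` resolves) and the default
    # hyperparameters in MSGGenerator.__init__ (max_imsize=128, pose_size=14,
    # image_channels=3). `z` is passed explicitly to avoid relying on
    # BaseGenerator.get_z.
    generator = MSGGenerator()
    condition = torch.zeros(1, 3, 128, 128)  # masked input image
    mask = torch.ones(1, 1, 128, 128)        # 1 = keep pixel, 0 = generate
    keypoints = torch.zeros(1, 7, 2)         # flattened internally to pose_size=14
    z = torch.randn(1, 512)                  # reshaped internally to (N, 32, 4, 4)
    with torch.no_grad():
        out = generator(condition, mask, keypoints=keypoints, z=z)
    print(out["img"].shape)  # expected: torch.Size([1, 3, 128, 128])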