import torch
import torch.nn.functional as F
from torch import nn


def _make_activation(activation):
    """Build the activation module selected by name.

    Args:
        activation: One of 'relu', 'lrelu', 'tanh' or 'none'.

    Returns:
        The matching ``nn.Module``, or ``None`` when ``activation`` is 'none'.

    Raises:
        AssertionError: If ``activation`` is not a supported name.
    """
    if activation == 'relu':
        return nn.ReLU(inplace=False)
    if activation == 'lrelu':
        return nn.LeakyReLU(0.2, inplace=False)
    if activation == 'tanh':
        return nn.Tanh()
    if activation == 'none':
        return None
    # Raised explicitly (not via ``assert``) so the check survives ``python -O``
    # while keeping the exception type existing callers may rely on.
    raise AssertionError("Unsupported activation: {}".format(activation))


class ResBlocks(nn.Module):
    """A sequential stack of ``num_blocks`` identical :class:`ResBlock` modules."""

    def __init__(self, num_blocks, dim, norm, activation, pad_type):
        """Create the stack.

        Args:
            num_blocks: Number of residual blocks to chain.
            dim: Channel count used for both input and output of every block.
            norm: Normalization name forwarded to each block.
            activation: Activation name forwarded to each block.
            pad_type: Padding type name forwarded to each block.
        """
        super(ResBlocks, self).__init__()
        blocks = [
            ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)
            for _ in range(num_blocks)
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through all residual blocks in order."""
        return self.model(x)


class ResBlock(nn.Module):
    """Two 3x3 :class:`Conv2dBlock` layers with an identity skip connection.

    The first conv block applies the requested activation; the second uses
    ``activation='none'``. The input is added to the result.
    """

    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        """Create the block.

        Args:
            dim: Channel count for input and output.
            norm: Normalization name passed to both conv blocks.
            activation: Activation name for the first conv block.
            pad_type: Padding type name passed to both conv blocks.
        """
        super(ResBlock, self).__init__()
        self.model = nn.Sequential(
            Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation,
                        pad_type=pad_type),
            Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation='none',
                        pad_type=pad_type),
        )

    def forward(self, x):
        """Return ``self.model(x) + x``."""
        residual = x
        out = self.model(x)
        # In-place add onto the freshly computed branch output (not onto x).
        out += residual
        return out


class ActFirstResBlock(nn.Module):
    """Residual block whose conv layers apply the activation before the conv.

    When ``fin != fout`` the shortcut is a bias-free 1x1 :class:`Conv2dBlock`
    without activation; otherwise the shortcut is the identity.
    """

    def __init__(self, fin, fout, fhid=None, activation='lrelu', norm='none'):
        """Create the block.

        Args:
            fin: Number of input channels.
            fout: Number of output channels.
            fhid: Number of hidden channels; defaults to ``min(fin, fout)``.
            activation: Activation name for the two 3x3 conv blocks.
            norm: Normalization name for the two 3x3 conv blocks.
        """
        super().__init__()
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        self.fhid = min(fin, fout) if fhid is None else fhid
        self.conv_0 = Conv2dBlock(self.fin, self.fhid, 3, 1,
                                  padding=1, pad_type='reflect', norm=norm,
                                  activation=activation, activation_first=True)
        self.conv_1 = Conv2dBlock(self.fhid, self.fout, 3, 1,
                                  padding=1, pad_type='reflect', norm=norm,
                                  activation=activation, activation_first=True)
        if self.learned_shortcut:
            self.conv_s = Conv2dBlock(self.fin, self.fout, 1, 1,
                                      activation='none', use_bias=False)

    def forward(self, x):
        """Return shortcut(x) + conv_1(conv_0(x))."""
        x_s = self.conv_s(x) if self.learned_shortcut else x
        dx = self.conv_0(x)
        dx = self.conv_1(dx)
        return x_s + dx


class LinearBlock(nn.Module):
    """Fully connected layer followed by optional normalization and activation."""

    def __init__(self, in_dim, out_dim, norm='none', activation='relu'):
        """Create the block.

        Args:
            in_dim: Input feature size of the linear layer.
            out_dim: Output feature size of the linear layer.
            norm: 'bn' (BatchNorm1d), 'in' (InstanceNorm1d) or 'none'.
            activation: 'relu', 'lrelu', 'tanh' or 'none'.

        Raises:
            AssertionError: If ``norm`` or ``activation`` is not supported.
        """
        super(LinearBlock, self).__init__()
        use_bias = True
        self.fc = nn.Linear(in_dim, out_dim, bias=use_bias)

        # initialize normalization
        norm_dim = out_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'none':
            self.norm = None
        else:
            # Explicit raise so the check is not stripped under ``python -O``.
            raise AssertionError(
                "Unsupported normalization: {}".format(norm))

        # initialize activation
        self.activation = _make_activation(activation)

    def forward(self, x):
        """Apply the linear layer, then norm and activation when configured."""
        out = self.fc(x)
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out


class Conv2dBlock(nn.Module):
    """Explicit padding + Conv2d with optional normalization and activation.

    With ``activation_first=False`` the order is pad -> conv -> norm ->
    activation. With ``activation_first=True`` it is activation -> pad ->
    conv -> norm.
    """

    def __init__(self, in_dim, out_dim, ks, st, padding=0,
                 norm='none', activation='relu', pad_type='zero',
                 use_bias=True, activation_first=False):
        """Create the block.

        Args:
            in_dim: Number of input channels.
            out_dim: Number of output channels.
            ks: Convolution kernel size.
            st: Convolution stride.
            padding: Amount of padding applied by the separate padding layer.
            norm: 'bn', 'in', 'adain' or 'none'.
            activation: 'relu', 'lrelu', 'tanh' or 'none'.
            pad_type: 'reflect', 'replicate' or 'zero'.
            use_bias: Whether the convolution has a bias term.
            activation_first: Apply the activation before the convolution.

        Raises:
            AssertionError: If ``pad_type``, ``norm`` or ``activation`` is not
                supported.
        """
        super(Conv2dBlock, self).__init__()
        self.use_bias = use_bias
        self.activation_first = activation_first

        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            # Explicit raise so the check is not stripped under ``python -O``.
            raise AssertionError(
                "Unsupported padding type: {}".format(pad_type))

        # initialize normalization
        norm_dim = out_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none':
            self.norm = None
        else:
            raise AssertionError(
                "Unsupported normalization: {}".format(norm))

        # initialize activation
        self.activation = _make_activation(activation)

        # Padding is handled by self.pad, so the conv itself uses none.
        self.conv = nn.Conv2d(in_dim, out_dim, ks, st, bias=self.use_bias)

    def forward(self, x):
        """Apply the block in the order selected by ``activation_first``."""
        if self.activation_first:
            if self.activation is not None:
                x = self.activation(x)
            x = self.conv(self.pad(x))
            if self.norm is not None:
                x = self.norm(x)
        else:
            x = self.conv(self.pad(x))
            if self.norm is not None:
                x = self.norm(x)
            if self.activation is not None:
                x = self.activation(x)
        return x


class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalization whose affine parameters are assigned externally.

    ``weight`` and ``bias`` start as ``None`` and must be set on the module
    before calling :meth:`forward`. The input is viewed as a single sample
    with ``b * c`` channels and passed to ``F.batch_norm``, so ``weight`` and
    ``bias`` are used as per-(sample, channel) scale and shift.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        """Create the layer.

        Args:
            num_features: Number of channels ``c`` of the expected input.
            eps: Epsilon forwarded to ``F.batch_norm``.
            momentum: Momentum forwarded to ``F.batch_norm``.
        """
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Assigned from outside before the forward pass.
        self.weight = None
        self.bias = None
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        """Normalize ``x`` and apply the assigned ``weight`` and ``bias``.

        Raises:
            AssertionError: If ``weight`` or ``bias`` has not been assigned.
        """
        if self.weight is None or self.bias is None:
            # Explicit raise: under ``python -O`` an ``assert`` would vanish and
            # batch_norm would silently run without the affine parameters.
            raise AssertionError("Please assign AdaIN weight first")
        b, c = x.size(0), x.size(1)
        # repeat() yields new tensors sized for the b * c reshaped channels;
        # the registered buffers themselves are not the ones passed below.
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'