import functools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import init
from torch.optim import lr_scheduler

from .transformer_ops.transformer_function import TransformerEncoderLayer


######################################################################################
# Attention-Aware Layer
######################################################################################
class AttnAware(nn.Module):
    """Multi-head attention layer that can mix a "visible" and an "invisible" branch.

    When ``pre`` is given, attention scores are re-weighted by ``mask`` to build
    two attention maps (one applied to ``pre``, one applied to ``x``) which are
    blended with a learned per-pixel softmax weight. Without ``pre`` the layer is
    plain multi-head self-attention over ``x``. In both cases the result is
    concatenated with ``x`` and fused by a 1x1 ``ResnetBlock``.

    :param input_nc: number of input (and output) channels; split across heads
    :param activation: activation name passed to ``get_nonlinearity_layer``
    :param norm: normalization name passed to ``get_norm_layer``
    :param num_heads: number of attention heads
    """

    def __init__(self, input_nc, activation='gelu', norm='pixel', num_heads=2):
        super(AttnAware, self).__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)
        head_dim = input_nc // num_heads
        self.num_heads = num_heads
        self.input_nc = input_nc
        self.scale = head_dim ** -0.5

        self.query_conv = nn.Sequential(
            norm_layer(input_nc),
            activation_layer,
            nn.Conv2d(input_nc, input_nc, kernel_size=1)
        )
        self.key_conv = nn.Sequential(
            norm_layer(input_nc),
            activation_layer,
            nn.Conv2d(input_nc, input_nc, kernel_size=1)
        )

        # maps the (visible, invisible) max-scores of every head to two blend weights
        self.weight = nn.Conv2d(self.num_heads * 2, 2, kernel_size=1, stride=1)
        self.to_out = ResnetBlock(input_nc * 2, input_nc, 1, 0, activation, norm)

    def forward(self, x, pre=None, mask=None):
        """Apply attention to ``x``.

        :param x: 4D feature tensor
        :param pre: optional 4D feature tensor used as values of the visible branch
        :param mask: mask tensor, required when ``pre`` is given; it is viewed as
            ``(B, -1, 1, W*H)`` and multiplied with the attention scores
        :return: tensor produced by ``self.to_out`` on ``cat([attended, x])``
        :raises ValueError: if ``pre`` is given without ``mask``
        """
        B, C, W, H = x.size()
        q = self.query_conv(x).view(B, -1, W * H)
        k = self.key_conv(x).view(B, -1, W * H)
        v = x.view(B, -1, W * H)

        q = rearrange(q, 'b (h d) n -> b h n d', h=self.num_heads)
        k = rearrange(k, 'b (h d) n -> b h n d', h=self.num_heads)
        v = rearrange(v, 'b (h d) n -> b h n d', h=self.num_heads)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if pre is not None:
            if mask is None:
                raise ValueError('AttnAware.forward requires `mask` when `pre` is provided')
            # attention-aware weight
            head = dots.size(1)
            mask_n = mask.view(B, -1, 1, W * H).expand_as(dots)
            # scores are detached: the blend weights do not back-propagate into q/k
            w_visible = (dots.detach() * mask_n).max(dim=-1, keepdim=True)[0]
            w_invisible = (dots.detach() * (1 - mask_n)).max(dim=-1, keepdim=True)[0]
            weight = torch.cat([w_visible.view(B, head, W, H), w_invisible.view(B, head, W, H)], dim=1)
            weight = self.weight(weight)
            weight = F.softmax(weight, dim=1)
            # visible attention score
            pre_v = pre.view(B, -1, W * H)
            pre_v = rearrange(pre_v, 'b (h d) n -> b h n d', h=self.num_heads)
            # positive scores are scaled down by the mask, negative ones are pushed
            # further negative by dividing through it, so masked-out keys vanish
            # after the softmax
            dots_visible = torch.where(dots > 0, dots * mask_n, dots / (mask_n + 1e-8))
            attn_visible = dots_visible.softmax(dim=-1)
            context_flow = torch.einsum('bhij, bhjd->bhid', attn_visible, pre_v)
            context_flow = rearrange(context_flow, 'b h n d -> b (h d) n').view(B, -1, W, H)
            # invisible attention score
            dots_invisible = torch.where(dots > 0, dots * (1 - mask_n), dots / ((1 - mask_n) + 1e-8))
            attn_invisible = dots_invisible.softmax(dim=-1)
            self_attention = torch.einsum('bhij, bhjd->bhid', attn_invisible, v)
            self_attention = rearrange(self_attention, 'b h n d -> b (h d) n').view(B, -1, W, H)
            # out
            out = weight[:, :1, :, :] * context_flow + weight[:, 1:, :, :] * self_attention
        else:
            attn = dots.softmax(dim=-1)
            out = torch.einsum('bhij, bhjd->bhid', attn, v)
            out = rearrange(out, 'b h n d -> b (h d) n').view(B, -1, W, H)

        out = self.to_out(torch.cat([out, x], dim=1))
        return out


######################################################################################
# base modules
######################################################################################
class NoiseInjection(nn.Module):
    """Add noise scaled by a learned scalar (initialised to zero) to a feature map."""

    def __init__(self):
        super(NoiseInjection, self).__init__()

        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None, mask=None):
        """Return ``x`` plus scaled noise.

        :param x: 4D feature tensor
        :param noise: optional noise tensor; sampled from a standard normal with
            one channel and the spatial size of ``x`` when omitted
        :param mask: optional mask; it is bilinearly resized to ``x`` and noise is
            multiplied by ``1 - mask``
        """
        if noise is None:
            b, _, h, w = x.size()
            noise = x.new_empty(b, 1, h, w).normal_()
        if mask is not None:
            mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=True)
            return x + self.alpha * noise * (1 - mask)  # add noise only to the invisible part
        return x + self.alpha * noise


class ConstantInput(nn.Module):
    """
    add position embedding for each learned VQ word
    """

    def __init__(self, channel, size=16):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        """Return the learned embedding repeated along the batch dimension of ``input``."""
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out


class UpSample(nn.Module):
    """
    2x bilinear up-sampling, optionally refined by a partial convolution
    :param input_nc: input channel
    :param with_conv: use convolution to refine the feature
    :param kernel_size: kernel size of the refining convolution
    :param return_mask: also return the (up-sampled / updated) mask, i.e. the confidence score
    """

    def __init__(self, input_nc, with_conv=False, kernel_size=3, return_mask=False):
        super(UpSample, self).__init__()
        self.with_conv = with_conv
        self.return_mask = return_mask
        if self.with_conv:
            self.conv = PartialConv2d(input_nc, input_nc, kernel_size=kernel_size, stride=1,
                                      padding=int(int(kernel_size - 1) / 2), return_mask=True)

    def forward(self, x, mask=None):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        if mask is not None:
            mask = F.interpolate(mask, scale_factor=2, mode='bilinear', align_corners=True)
        if self.with_conv:
            x, mask = self.conv(x, mask)
        if self.return_mask:
            return x, mask
        else:
            return x


class DownSample(nn.Module):
    """
    2x down-sampling, either by a stride-2 partial convolution or by average pooling
    :param input_nc: input channel
    :param with_conv: use a stride-2 convolution instead of average pooling
    :param kernel_size: kernel size of the convolution
    :param return_mask: also return the (down-sampled / updated) mask, i.e. the confidence score
    """

    def __init__(self, input_nc, with_conv=False, kernel_size=3, return_mask=False):
        super(DownSample, self).__init__()
        self.with_conv = with_conv
        self.return_mask = return_mask
        if self.with_conv:
            self.conv = PartialConv2d(input_nc, input_nc, kernel_size=kernel_size, stride=2,
                                      padding=int(int(kernel_size - 1) / 2), return_mask=True)

    def forward(self, x, mask=None):
        if self.with_conv:
            x, mask = self.conv(x, mask)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
            if mask is not None:
                mask = F.avg_pool2d(mask, kernel_size=2, stride=2)
        if self.return_mask:
            return x, mask
        else:
            return x


class ResnetBlock(nn.Module):
    """Pre-activation residual block built from two partial convolutions.

    The shortcut is a 1x1 partial convolution when the channel count changes and
    an identity otherwise. The sum of both paths is divided by ``sqrt(2)``.

    :param input_nc: input channels
    :param output_nc: output channels (defaults to ``input_nc``)
    :param kernel: convolution kernel size
    :param dropout: dropout probability applied before the second convolution
    :param activation: activation name passed to ``get_nonlinearity_layer``
    :param norm: normalization name passed to ``get_norm_layer``
    :param return_mask: also return the mask produced by the second convolution
    """

    def __init__(self, input_nc, output_nc=None, kernel=3, dropout=0.0, activation='gelu', norm='pixel',
                 return_mask=False):
        super(ResnetBlock, self).__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)
        self.return_mask = return_mask

        output_nc = input_nc if output_nc is None else output_nc

        self.norm1 = norm_layer(input_nc)
        self.conv1 = PartialConv2d(input_nc, output_nc, kernel_size=kernel, padding=int((kernel - 1) / 2),
                                   return_mask=True)
        self.norm2 = norm_layer(output_nc)
        self.conv2 = PartialConv2d(output_nc, output_nc, kernel_size=kernel, padding=int((kernel - 1) / 2),
                                   return_mask=True)
        self.dropout = nn.Dropout(dropout)
        self.act = activation_layer

        if input_nc != output_nc:
            self.short = PartialConv2d(input_nc, output_nc, kernel_size=1, stride=1, padding=0)
        else:
            self.short = Identity()

    def forward(self, x, mask=None):
        x_short = self.short(x)
        x, mask = self.conv1(self.act(self.norm1(x)), mask)
        x, mask = self.conv2(self.dropout(self.act(self.norm2(x))), mask)
        if self.return_mask:
            return (x + x_short) / math.sqrt(2), mask
        else:
            return (x + x_short) / math.sqrt(2)


class DiffEncoder(nn.Module):
    """Convolutional encoder: stem conv, ``down_scale`` down-sampling stages, middle blocks, output block.

    Every stage doubles the channel count and threads the mask through its
    partial convolutions.
    """

    def __init__(self, input_nc, ngf=64, kernel_size=2, embed_dim=512, down_scale=4, num_res_blocks=2, dropout=0.0,
                 rample_with_conv=True, activation='gelu', norm='pixel', use_attn=False):
        super(DiffEncoder, self).__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)

        # start
        self.encode = PartialConv2d(input_nc, ngf, kernel_size=kernel_size, stride=1,
                                    padding=int((kernel_size - 1) / 2), return_mask=True)
        # down
        self.use_attn = use_attn
        self.down_scale = down_scale
        self.num_res_blocks = num_res_blocks
        self.down = nn.ModuleList()
        out_dim = ngf
        for i in range(down_scale):
            block = nn.ModuleList()
            down = nn.Module()
            in_dim = out_dim
            out_dim = int(in_dim * 2)
            down.downsample = DownSample(in_dim, rample_with_conv, kernel_size=2, return_mask=True)
            for i_block in range(num_res_blocks):
                block.append(ResnetBlock(in_dim, out_dim, kernel_size, dropout, activation, norm, return_mask=True))
                in_dim = out_dim
            down.block = block
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block1 = ResnetBlock(out_dim, out_dim, kernel_size, dropout, activation, norm, return_mask=True)
        if self.use_attn:
            self.mid.attn = TransformerEncoderLayer(out_dim, kernel=1)
        self.mid.block2 = ResnetBlock(out_dim, out_dim, kernel_size, dropout, activation, norm, return_mask=True)
        # end
        self.conv_out = ResnetBlock(out_dim, embed_dim, kernel_size, dropout, activation, norm, return_mask=True)

    def forward(self, x, mask=None, return_mask=False):
        """Encode ``x``; returns ``(features, mask)`` when ``return_mask`` is true, else the features."""
        x, mask = self.encode(x, mask)
        # down sampling
        for i in range(self.down_scale):
            x, mask = self.down[i].downsample(x, mask)
            for i_block in range(self.num_res_blocks):
                x, mask = self.down[i].block[i_block](x, mask)
        # middle
        x, mask = self.mid.block1(x, mask)
        if self.use_attn:
            x = self.mid.attn(x)
        x, mask = self.mid.block2(x, mask)
        # end
        x, mask = self.conv_out(x, mask)
        if return_mask:
            return x, mask
        return x


class DiffDecoder(nn.Module):
    """Convolutional decoder: input conv, middle blocks, ``up_scale`` up-sampling stages with RGB skips.

    Every stage halves the channel count, emits an RGB image through ``ToRGB``
    (accumulated as ``skip``) and up-samples the features by 2.
    """

    def __init__(self, output_nc, ngf=64, kernel_size=3, embed_dim=512, up_scale=4, num_res_blocks=2, dropout=0.0,
                 word_size=16, rample_with_conv=True, activation='gelu', norm='pixel', add_noise=False, use_attn=True,
                 use_pos=True):
        super(DiffDecoder, self).__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)
        self.up_scale = up_scale
        self.num_res_blocks = num_res_blocks
        self.add_noise = add_noise
        self.use_attn = use_attn
        self.use_pos = use_pos
        in_dim = ngf * (2 ** self.up_scale)
        # start
        if use_pos:
            self.pos_embed = ConstantInput(embed_dim, size=word_size)
        self.conv_in = PartialConv2d(embed_dim, in_dim, kernel_size=kernel_size, stride=1,
                                     padding=int((kernel_size - 1) / 2))
        # middle
        self.mid = nn.Module()
        self.mid.block1 = ResnetBlock(in_dim, in_dim, kernel_size, dropout, activation, norm)
        if self.use_attn:
            self.mid.attn = TransformerEncoderLayer(in_dim, kernel=1)
        self.mid.block2 = ResnetBlock(in_dim, in_dim, kernel_size, dropout, activation, norm)
        # up
        self.up = nn.ModuleList()
        out_dim = in_dim
        for i in range(up_scale):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            noise = nn.ModuleList()
            up = nn.Module()
            in_dim = out_dim
            out_dim = int(in_dim / 2)
            for i_block in range(num_res_blocks):
                if add_noise:
                    noise.append(NoiseInjection())
                block.append(ResnetBlock(in_dim, out_dim, kernel_size, dropout, activation, norm))
                in_dim = out_dim
                # attention is only used in the first (lowest-resolution) stage
                if i == 0 and self.use_attn:
                    attn.append(TransformerEncoderLayer(in_dim, kernel=1))
            up.block = block
            up.attn = attn
            up.noise = noise
            # the first stage has no previous RGB output to up-sample and concatenate
            upsample = i != 0
            up.out = ToRGB(in_dim, output_nc, upsample, activation, norm)
            up.upsample = UpSample(in_dim, rample_with_conv, kernel_size=3)
            self.up.append(up)
        # end
        self.decode = ToRGB(in_dim, output_nc, True, activation, norm)

    def forward(self, x, mask=None):
        """Decode ``x``; ``mask`` is only forwarded to the noise-injection layers."""
        x = x + self.pos_embed(x) if self.use_pos else x
        x = self.conv_in(x)
        # middle
        x = self.mid.block1(x)
        if self.use_attn:
            x = self.mid.attn(x)
        x = self.mid.block2(x)
        # up
        skip = None
        for i in range(self.up_scale):
            for i_block in range(self.num_res_blocks):
                if self.add_noise:
                    x = self.up[i].noise[i_block](x, mask=mask)
                x = self.up[i].block[i_block](x)
                if len(self.up[i].attn) > 0:
                    x = self.up[i].attn[i_block](x)
            skip = self.up[i].out(x, skip)
            x = self.up[i].upsample(x)
        # end
        x = self.decode(x, skip)

        return x


class LinearEncoder(nn.Module):
    """Patch embedding: one partial convolution whose stride equals its kernel size."""

    def __init__(self, input_nc, kernel_size=16, embed_dim=512):
        super(LinearEncoder, self).__init__()

        self.encode = PartialConv2d(input_nc, embed_dim, kernel_size=kernel_size, stride=kernel_size,
                                    return_mask=True)

    def forward(self, x, mask=None, return_mask=False):
        x, mask = self.encode(x, mask)
        if return_mask:
            return x, mask
        return x


class LinearDecoder(nn.Module):
    """Decode embeddings with a conv + ``PixelShuffle(kernel_size)`` + conv stack followed by ``tanh``."""

    def __init__(self, output_nc, ngf=64, kernel_size=16, embed_dim=512, activation='gelu', norm='pixel'):
        super(LinearDecoder, self).__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)

        self.decode = nn.Sequential(
            norm_layer(embed_dim),
            activation_layer,
            PartialConv2d(embed_dim, ngf * kernel_size * kernel_size, kernel_size=3, padding=1),
            nn.PixelShuffle(kernel_size),
            norm_layer(ngf),
            activation_layer,
            PartialConv2d(ngf, output_nc, kernel_size=3, padding=1)
        )

    def forward(self, x, mask=None):
        # `mask` is accepted for interface parity with the other decoders but is not used
        x = self.decode(x)

        return torch.tanh(x)


class ToRGB(nn.Module):
    """Project features to ``output_nc`` channels with ``tanh``.

    When ``upsample`` is true the previous output (``skip``) is up-sampled by 2 and
    concatenated with the features before the convolution.
    """

    def __init__(self, input_nc, output_nc, upsample=True, activation='gelu', norm='pixel'):
        super().__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)

        if upsample:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # the convolution also sees the up-sampled previous output
            input_nc = input_nc + output_nc

        self.conv = nn.Sequential(
            norm_layer(input_nc),
            activation_layer,
            PartialConv2d(input_nc, output_nc, kernel_size=3, padding=1)
        )

    def forward(self, input, skip=None):
        if skip is not None:
            skip = self.upsample(skip)
            input = torch.cat([input, skip], dim=1)

        out = self.conv(input)

        return torch.tanh(out)


######################################################################################
# base function for network structure
######################################################################################
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags.
                              opt.lr_policy is the name of learning rate policy: linear | plateau | cosine
                              'linear' additionally reads opt.iter_count, opt.n_iter and opt.n_iter_decay;
                              'cosine' reads opt.n_epochs.

    Raises:
        NotImplementedError -- if opt.lr_policy is not one of the supported policies
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(step):
            # factor stays at 1.0 until `n_iter` iterations, then decays linearly
            lr_l = 1.0 - max(0, step + opt.iter_count - opt.n_iter) / float(opt.n_iter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        # the original returned the exception object, which callers would then use as a scheduler
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights.

    Parameters:
        net (network)   -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)    -- scaling factor for normal, xavier and orthogonal.
        debug (bool)    -- print the class name of every layer that gets initialized

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if debug:
                print(classname)
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function


def init_net(net, init_type='normal', init_gain=0.02, debug=False, initialize_weights=True):
    """Optionally initialize the weights of a network and return it.

    Parameters:
        net (network)             -- the network to be initialized
        init_type (str)           -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)         -- scaling factor for normal, xavier and orthogonal.
        debug (bool)              -- forwarded to init_weights
        initialize_weights (bool) -- skip the initialization entirely when False

    Return the (initialized) network. No device placement is performed here.
    """
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net


class Identity(nn.Module):
    """Module that returns its input unchanged."""

    def forward(self, x):
        return x


def get_norm_layer(norm_type='instance'):
    """Return a normalization layer constructor

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | pixel | layer | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we use learnable affine parameters (running statistics keep the PyTorch default).
    'none' yields a constructor that ignores its argument and returns an Identity module.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=True)
    elif norm_type == 'pixel':
        norm_layer = functools.partial(PixelwiseNorm)
    elif norm_type == 'layer':
        norm_layer = functools.partial(nn.LayerNorm)
    elif norm_type == 'none':
        def norm_layer(x):
            return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_nonlinearity_layer(activation_type='PReLU'):
    """Get the activation layer for the networks

    Parameters:
        activation_type (str) -- relu | gelu | leakyrelu | prelu (matched case-insensitively,
                                 so that the default 'PReLU' is accepted)
    """
    # the names below are lowercase while the default is 'PReLU'; normalise before matching
    name = activation_type.lower()
    if name == 'relu':
        nonlinearity_layer = nn.ReLU()
    elif name == 'gelu':
        nonlinearity_layer = nn.GELU()
    elif name == 'leakyrelu':
        nonlinearity_layer = nn.LeakyReLU(0.2)
    elif name == 'prelu':
        nonlinearity_layer = nn.PReLU()
    else:
        raise NotImplementedError('activation layer [%s] is not found' % activation_type)
    return nonlinearity_layer


class PixelwiseNorm(nn.Module):
    """Normalize every pixel's feature vector to unit RMS and scale it by a learned per-channel gain."""

    def __init__(self, input_nc):
        super(PixelwiseNorm, self).__init__()
        self.init = False
        self.alpha = nn.Parameter(torch.ones(1, input_nc, 1, 1))

    def forward(self, x, alpha=1e-8):
        """
        forward pass of the module
        :param x: input activations volume
        :param alpha: small number for numerical stability
        :return: y => pixel normalized activations
        """
        y = x.pow(2.).mean(dim=1, keepdim=True).add(alpha).rsqrt()  # [N1HW]
        y = x * y  # normalize the input x volume
        return self.alpha * y


###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu (guilinl@nvidia.com)
###############################################################################
class PartialConv2d(nn.Conv2d):
    """Partial convolution: a ``Conv2d`` whose output is re-normalised by the mask coverage of each window.

    Accepts all ``nn.Conv2d`` arguments plus two keyword-only extras:

    :param multi_channel: use a per-channel mask instead of a single-channel one (default False)
    :param return_mask: make ``forward`` return ``(output, confidence)`` instead of ``output`` (default False)
    """

    def __init__(self, *args, **kwargs):
        # strip the extra options before handing the rest to nn.Conv2d
        # whether the mask is multi-channel or not
        multi_channel = kwargs.pop('multi_channel', False)
        return_mask = kwargs.pop('return_mask', False)

        super(PartialConv2d, self).__init__(*args, **kwargs)

        self.multi_channel = multi_channel
        self.return_mask = return_mask

        # all-ones kernel used to count the valid mask entries under every window;
        # kept as a plain attribute, so it is not part of the state_dict
        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0],
                                                 self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \
                             self.weight_maskUpdater.shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.update_mask1 = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        """Apply the partial convolution.

        :param input: 4D input tensor
        :param mask_in: optional mask; when omitted an all-ones mask is used and the
            cached mask statistics are reused as long as the input shape is unchanged
        :return: ``output`` or, if ``return_mask`` was set, ``(output, update_mask / slide_winsize)``
        """
        assert len(input.shape) == 4
        # recompute the mask statistics when a mask is supplied or the input shape changed
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2],
                                          input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                else:
                    mask = mask_in

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
                                            padding=self.padding, dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                self.update_mask1 = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask1)

        # `mask` is always bound here when `mask_in` is given, because the branch above ran
        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            # re-normalise only the data term, then add the bias back
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask1)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask / self.slide_winsize  # replace the valid value to confident score
        else:
            return output