import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import torch.nn.functional as F
import math
from einops import rearrange

from .transformer_ops.transformer_function import TransformerEncoderLayer


class AttnAware(nn.Module):
    def __init__(self, input_nc, activation='gelu', norm='pixel', num_heads=2):
        super(AttnAware, self).__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)
        head_dim = input_nc // num_heads
        self.num_heads = num_heads
        self.input_nc = input_nc
        self.scale = head_dim ** -0.5

        self.query_conv = nn.Sequential(
            norm_layer(input_nc),
            activation_layer,
            nn.Conv2d(input_nc, input_nc, kernel_size=1)
        )
        self.key_conv = nn.Sequential(
            norm_layer(input_nc),
            activation_layer,
            nn.Conv2d(input_nc, input_nc, kernel_size=1)
        )

        self.weight = nn.Conv2d(self.num_heads * 2, 2, kernel_size=1, stride=1)
        self.to_out = ResnetBlock(input_nc * 2, input_nc, 1, 0, activation, norm)

    def forward(self, x, pre=None, mask=None):
        B, C, W, H = x.size()
        q = self.query_conv(x).view(B, -1, W * H)
        k = self.key_conv(x).view(B, -1, W * H)
        v = x.view(B, -1, W * H)

        q = rearrange(q, 'b (h d) n -> b h n d', h=self.num_heads)
        k = rearrange(k, 'b (h d) n -> b h n d', h=self.num_heads)
        v = rearrange(v, 'b (h d) n -> b h n d', h=self.num_heads)
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if pre is not None:
            B, head, N, _ = dots.size()
            mask_n = mask.view(B, -1, 1, W * H).expand_as(dots)
            w_visible = (dots.detach() * mask_n).max(dim=-1, keepdim=True)[0]
            w_invisible = (dots.detach() * (1 - mask_n)).max(dim=-1, keepdim=True)[0]
            weight = torch.cat([w_visible.view(B, head, W, H), w_invisible.view(B, head, W, H)], dim=1)
            weight = self.weight(weight)
            weight = F.softmax(weight, dim=1)

            pre_v = pre.view(B, -1, W * H)
            pre_v = rearrange(pre_v, 'b (h d) n -> b h n d', h=self.num_heads)
            dots_visible = torch.where(dots > 0, dots * mask_n, dots / (mask_n + 1e-8))
            attn_visible = dots_visible.softmax(dim=-1)
            context_flow = torch.einsum('bhij, bhjd->bhid', attn_visible, pre_v)
            context_flow = rearrange(context_flow, 'b h n d -> b (h d) n').view(B, -1, W, H)

            dots_invisible = torch.where(dots > 0, dots * (1 - mask_n), dots / ((1 - mask_n) + 1e-8))
            attn_invisible = dots_invisible.softmax(dim=-1)
            self_attention = torch.einsum('bhij, bhjd->bhid', attn_invisible, v)
            self_attention = rearrange(self_attention, 'b h n d -> b (h d) n').view(B, -1, W, H)

            out = weight[:, :1, :, :] * context_flow + weight[:, 1:, :, :] * self_attention
        else:
            attn = dots.softmax(dim=-1)
            out = torch.einsum('bhij, bhjd->bhid', attn, v)
            out = rearrange(out, 'b h n d -> b (h d) n').view(B, -1, W, H)

        out = self.to_out(torch.cat([out, x], dim=1))
        return out
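

# A minimal usage sketch for AttnAware: the sizes below are arbitrary assumptions for
# illustration only (4 channels split over 2 heads, an 8x8 feature map, and a binary
# visibility mask where 1 marks known positions).
def _attn_aware_shape_example():
    attn = AttnAware(input_nc=4, num_heads=2)
    x = torch.randn(2, 4, 8, 8)                         # current features
    pre = torch.randn(2, 4, 8, 8)                       # features from an earlier stage
    mask = torch.randint(0, 2, (2, 1, 8, 8)).float()    # 1 = visible, 0 = missing
    out = attn(x, pre=pre, mask=mask)                   # blends context flow and self-attention
    assert out.shape == x.shape                         # resolution and channels are preserved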


class NoiseInjection(nn.Module):
    def __init__(self):
        super(NoiseInjection, self).__init__()

        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None, mask=None):
        if noise is None:
            b, _, h, w = x.size()
            noise = x.new_empty(b, 1, h, w).normal_()
        if mask is not None:
            mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=True)
            return x + self.alpha * noise * (1 - mask)
        return x + self.alpha * noise


class ConstantInput(nn.Module):
    """
    add position embedding for each learned VQ word
    """
    def __init__(self, channel, size=16):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out


class UpSample(nn.Module):
    """ Upsample the feature map by a factor of 2, optionally refined with a convolution
    :param input_nc: number of input channels
    :param with_conv: use a partial convolution to refine the upsampled feature
    :param kernel_size: kernel size of the refining convolution
    :param return_mask: also return the updated mask (confidence score)
    """
    def __init__(self, input_nc, with_conv=False, kernel_size=3, return_mask=False):
        super(UpSample, self).__init__()
        self.with_conv = with_conv
        self.return_mask = return_mask
        if self.with_conv:
            self.conv = PartialConv2d(input_nc, input_nc, kernel_size=kernel_size, stride=1,
                                      padding=int((kernel_size - 1) / 2), return_mask=True)

    def forward(self, x, mask=None):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        mask = F.interpolate(mask, scale_factor=2, mode='bilinear', align_corners=True) if mask is not None else mask
        if self.with_conv:
            x, mask = self.conv(x, mask)
        if self.return_mask:
            return x, mask
        else:
            return x


class DownSample(nn.Module):
    """ Downsample the feature map by a factor of 2, with a strided partial convolution or average pooling
    :param input_nc: number of input channels
    :param with_conv: use a strided partial convolution instead of average pooling
    :param kernel_size: kernel size of the strided convolution
    :param return_mask: also return the updated mask (confidence score)
    """
    def __init__(self, input_nc, with_conv=False, kernel_size=3, return_mask=False):
        super(DownSample, self).__init__()
        self.with_conv = with_conv
        self.return_mask = return_mask
        if self.with_conv:
            self.conv = PartialConv2d(input_nc, input_nc, kernel_size=kernel_size, stride=2,
                                      padding=int((kernel_size - 1) / 2), return_mask=True)

    def forward(self, x, mask=None):
        if self.with_conv:
            x, mask = self.conv(x, mask)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
            mask = F.avg_pool2d(mask, kernel_size=2, stride=2) if mask is not None else mask
        if self.return_mask:
            return x, mask
        else:
            return x


class ResnetBlock(nn.Module):
    def __init__(self, input_nc, output_nc=None, kernel=3, dropout=0.0, activation='gelu', norm='pixel', return_mask=False):
        super(ResnetBlock, self).__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)
        self.return_mask = return_mask

        output_nc = input_nc if output_nc is None else output_nc

        self.norm1 = norm_layer(input_nc)
        self.conv1 = PartialConv2d(input_nc, output_nc, kernel_size=kernel, padding=int((kernel - 1) / 2), return_mask=True)
        self.norm2 = norm_layer(output_nc)
        self.conv2 = PartialConv2d(output_nc, output_nc, kernel_size=kernel, padding=int((kernel - 1) / 2), return_mask=True)
        self.dropout = nn.Dropout(dropout)
        self.act = activation_layer

        if input_nc != output_nc:
            self.short = PartialConv2d(input_nc, output_nc, kernel_size=1, stride=1, padding=0)
        else:
            self.short = Identity()

    def forward(self, x, mask=None):
        x_short = self.short(x)
        x, mask = self.conv1(self.act(self.norm1(x)), mask)
        x, mask = self.conv2(self.dropout(self.act(self.norm2(x))), mask)
        if self.return_mask:
            return (x + x_short) / math.sqrt(2), mask
        else:
            return (x + x_short) / math.sqrt(2)
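

# A minimal usage sketch for ResnetBlock: the shapes are arbitrary assumptions for
# illustration. The block keeps the spatial size, maps input_nc to output_nc, and
# scales the residual sum by 1/sqrt(2).
def _resnet_block_example():
    block = ResnetBlock(input_nc=32, output_nc=64, kernel=3, return_mask=True)
    x = torch.randn(1, 32, 16, 16)
    mask = torch.ones(1, 1, 16, 16)          # fully visible mask
    out, new_mask = block(x, mask)
    assert out.shape == (1, 64, 16, 16)
    assert new_mask.shape == (1, 1, 16, 16)  # updated confidence mask for the next layer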


class DiffEncoder(nn.Module):
    def __init__(self, input_nc, ngf=64, kernel_size=2, embed_dim=512, down_scale=4, num_res_blocks=2, dropout=0.0,
                 rample_with_conv=True, activation='gelu', norm='pixel', use_attn=False):
        super(DiffEncoder, self).__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)

        self.encode = PartialConv2d(input_nc, ngf, kernel_size=kernel_size, stride=1, padding=int((kernel_size - 1) / 2), return_mask=True)

        self.use_attn = use_attn
        self.down_scale = down_scale
        self.num_res_blocks = num_res_blocks
        self.down = nn.ModuleList()
        out_dim = ngf
        for i in range(down_scale):
            block = nn.ModuleList()
            down = nn.Module()
            in_dim = out_dim
            out_dim = int(in_dim * 2)
            down.downsample = DownSample(in_dim, rample_with_conv, kernel_size=2, return_mask=True)
            for i_block in range(num_res_blocks):
                block.append(ResnetBlock(in_dim, out_dim, kernel_size, dropout, activation, norm, return_mask=True))
                in_dim = out_dim
            down.block = block
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block1 = ResnetBlock(out_dim, out_dim, kernel_size, dropout, activation, norm, return_mask=True)
        if self.use_attn:
            self.mid.attn = TransformerEncoderLayer(out_dim, kernel=1)
        self.mid.block2 = ResnetBlock(out_dim, out_dim, kernel_size, dropout, activation, norm, return_mask=True)

        self.conv_out = ResnetBlock(out_dim, embed_dim, kernel_size, dropout, activation, norm, return_mask=True)

    def forward(self, x, mask=None, return_mask=False):
        x, mask = self.encode(x, mask)

        for i in range(self.down_scale):
            x, mask = self.down[i].downsample(x, mask)
            for i_block in range(self.num_res_blocks):
                x, mask = self.down[i].block[i_block](x, mask)

        x, mask = self.mid.block1(x, mask)
        if self.use_attn:
            x = self.mid.attn(x)
        x, mask = self.mid.block2(x, mask)

        x, mask = self.conv_out(x, mask)
        if return_mask:
            return x, mask
        return x


class DiffDecoder(nn.Module):
    def __init__(self, output_nc, ngf=64, kernel_size=3, embed_dim=512, up_scale=4, num_res_blocks=2, dropout=0.0, word_size=16,
                 rample_with_conv=True, activation='gelu', norm='pixel', add_noise=False, use_attn=True, use_pos=True):
        super(DiffDecoder, self).__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)
        self.up_scale = up_scale
        self.num_res_blocks = num_res_blocks
        self.add_noise = add_noise
        self.use_attn = use_attn
        self.use_pos = use_pos
        in_dim = ngf * (2 ** self.up_scale)

        if use_pos:
            self.pos_embed = ConstantInput(embed_dim, size=word_size)
        self.conv_in = PartialConv2d(embed_dim, in_dim, kernel_size=kernel_size, stride=1, padding=int((kernel_size - 1) / 2))

        self.mid = nn.Module()
        self.mid.block1 = ResnetBlock(in_dim, in_dim, kernel_size, dropout, activation, norm)
        if self.use_attn:
            self.mid.attn = TransformerEncoderLayer(in_dim, kernel=1)
        self.mid.block2 = ResnetBlock(in_dim, in_dim, kernel_size, dropout, activation, norm)

        self.up = nn.ModuleList()
        out_dim = in_dim
        for i in range(up_scale):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            noise = nn.ModuleList()
            up = nn.Module()
            in_dim = out_dim
            out_dim = int(in_dim / 2)
            for i_block in range(num_res_blocks):
                if add_noise:
                    noise.append(NoiseInjection())
                block.append(ResnetBlock(in_dim, out_dim, kernel_size, dropout, activation, norm))
                in_dim = out_dim
                if i == 0 and self.use_attn:
                    attn.append(TransformerEncoderLayer(in_dim, kernel=1))
            up.block = block
            up.attn = attn
            up.noise = noise
            upsample = True if (i != 0) else False
            up.out = ToRGB(in_dim, output_nc, upsample, activation, norm)
            up.upsample = UpSample(in_dim, rample_with_conv, kernel_size=3)
            self.up.append(up)

        self.decode = ToRGB(in_dim, output_nc, True, activation, norm)

    def forward(self, x, mask=None):
        x = x + self.pos_embed(x) if self.use_pos else x
        x = self.conv_in(x)

        x = self.mid.block1(x)
        if self.use_attn:
            x = self.mid.attn(x)
        x = self.mid.block2(x)

        skip = None
        for i in range(self.up_scale):
            for i_block in range(self.num_res_blocks):
                if self.add_noise:
                    x = self.up[i].noise[i_block](x, mask=mask)
                x = self.up[i].block[i_block](x)
                if len(self.up[i].attn) > 0:
                    x = self.up[i].attn[i_block](x)
            skip = self.up[i].out(x, skip)
            x = self.up[i].upsample(x)

        x = self.decode(x, skip)
        return x
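

# A rough end-to-end sketch of the encoder/decoder pair. All settings here are
# assumptions for illustration: a 256x256 input with down_scale=4 and word_size=16
# gives a 16x16 feature grid, and kernel_size=3 is chosen so the spatial size halves
# exactly at every scale. The decoder path relies on the repo's TransformerEncoderLayer.
def _encoder_decoder_example():
    enc = DiffEncoder(input_nc=3, ngf=64, kernel_size=3, embed_dim=512, down_scale=4)
    dec = DiffDecoder(output_nc=3, ngf=64, kernel_size=3, embed_dim=512, up_scale=4, word_size=16)
    img = torch.randn(1, 3, 256, 256)
    mask = torch.ones(1, 1, 256, 256)        # 1 = known pixel, 0 = hole
    feat = enc(img, mask)                    # (1, 512, 16, 16) feature grid
    rec = dec(feat)                          # (1, 3, 256, 256), in [-1, 1] via tanh
    assert rec.shape == img.shape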


class LinearEncoder(nn.Module):
    def __init__(self, input_nc, kernel_size=16, embed_dim=512):
        super(LinearEncoder, self).__init__()

        self.encode = PartialConv2d(input_nc, embed_dim, kernel_size=kernel_size, stride=kernel_size, return_mask=True)

    def forward(self, x, mask=None, return_mask=False):
        x, mask = self.encode(x, mask)
        if return_mask:
            return x, mask
        return x


class LinearDecoder(nn.Module):
    def __init__(self, output_nc, ngf=64, kernel_size=16, embed_dim=512, activation='gelu', norm='pixel'):
        super(LinearDecoder, self).__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)

        self.decode = nn.Sequential(
            norm_layer(embed_dim),
            activation_layer,
            PartialConv2d(embed_dim, ngf * kernel_size * kernel_size, kernel_size=3, padding=1),
            nn.PixelShuffle(kernel_size),
            norm_layer(ngf),
            activation_layer,
            PartialConv2d(ngf, output_nc, kernel_size=3, padding=1)
        )

    def forward(self, x, mask=None):
        x = self.decode(x)

        return torch.tanh(x)


class ToRGB(nn.Module):
    def __init__(self, input_nc, output_nc, upsample=True, activation='gelu', norm='pixel'):
        super().__init__()

        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)

        if upsample:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            input_nc = input_nc + output_nc

        self.conv = nn.Sequential(
            norm_layer(input_nc),
            activation_layer,
            PartialConv2d(input_nc, output_nc, kernel_size=3, padding=1)
        )

    def forward(self, input, skip=None):
        if skip is not None:
            skip = self.upsample(skip)
            input = torch.cat([input, skip], dim=1)

        out = self.conv(input)

        return torch.tanh(out)


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of the learning rate policy: linear | plateau | cosine
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(iter):
            lr_l = 1.0 - max(0, iter + opt.iter_count - opt.n_iter) / float(opt.n_iter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
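

# A minimal usage sketch for get_scheduler: the option object below only carries the
# fields this function reads (lr_policy, iter_count, n_iter, n_iter_decay) and its
# values are assumptions for illustration.
def _scheduler_example():
    from argparse import Namespace
    net = nn.Conv2d(3, 3, 3)
    optimizer = torch.optim.Adam(net.parameters(), lr=2e-4)
    opt = Namespace(lr_policy='linear', iter_count=0, n_iter=100, n_iter_decay=100)
    scheduler = get_scheduler(optimizer, opt)
    for _ in range(10):
        optimizer.step()
        scheduler.step()    # lr stays constant until n_iter, then decays linearly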


def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if debug:
                print(classname)
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, debug=False, initialize_weights=True):
    """Initialize a network by initializing its weights

    Parameters:
        net (network)     -- the network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Return an initialized network.
    """
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net


class Identity(nn.Module):
    def forward(self, x):
        return x


def get_norm_layer(norm_type='instance'):
    """Return a normalization layer

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | pixel | layer | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we use learnable affine parameters but do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=True)
    elif norm_type == 'pixel':
        norm_layer = functools.partial(PixelwiseNorm)
    elif norm_type == 'layer':
        norm_layer = functools.partial(nn.LayerNorm)
    elif norm_type == 'none':
        def norm_layer(x): return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_nonlinearity_layer(activation_type='prelu'):
    """Get the activation layer for the networks: relu | gelu | leakyrelu | prelu"""
    if activation_type == 'relu':
        nonlinearity_layer = nn.ReLU()
    elif activation_type == 'gelu':
        nonlinearity_layer = nn.GELU()
    elif activation_type == 'leakyrelu':
        nonlinearity_layer = nn.LeakyReLU(0.2)
    elif activation_type == 'prelu':
        nonlinearity_layer = nn.PReLU()
    else:
        raise NotImplementedError('activation layer [%s] is not found' % activation_type)
    return nonlinearity_layer


class PixelwiseNorm(nn.Module):
    def __init__(self, input_nc):
        super(PixelwiseNorm, self).__init__()
        self.init = False
        self.alpha = nn.Parameter(torch.ones(1, input_nc, 1, 1))

    def forward(self, x, alpha=1e-8):
        """
        forward pass of the module
        :param x: input activations volume
        :param alpha: small number for numerical stability
        :return: y => pixel normalized activations
        """
        y = x.pow(2.).mean(dim=1, keepdim=True).add(alpha).rsqrt()
        y = x * y
        return self.alpha * y
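

# A small numeric sketch for PixelwiseNorm: each pixel's feature vector is divided by
# its RMS over the channel dimension, y = x / sqrt(mean_c(x^2) + eps), then rescaled by
# the learned per-channel alpha (initialised to 1). The shapes are arbitrary assumptions.
def _pixelwise_norm_example():
    norm = PixelwiseNorm(input_nc=4)
    x = torch.randn(1, 4, 2, 2)
    y = norm(x)
    expected = x * x.pow(2).mean(dim=1, keepdim=True).add(1e-8).rsqrt()
    assert torch.allclose(y, expected)      # alpha is all ones at initialisation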


class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):

        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False

        if 'return_mask' in kwargs:
            self.return_mask = kwargs['return_mask']
            kwargs.pop('return_mask')
        else:
            self.return_mask = False

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0],
                                                 self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \
                             self.weight_maskUpdater.shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask_in is None:
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2],
                                          input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                else:
                    mask = mask_in

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
                                            padding=self.padding, dilation=self.dilation, groups=1)

                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                self.update_mask1 = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask1)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask1)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask / self.slide_winsize
        else:
            return output
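

# A minimal usage sketch for PartialConv2d: the convolution is renormalised by the
# fraction of valid (mask == 1) pixels under each kernel window, and the updated mask
# acts as a per-pixel coverage score for the next layer. Shapes and the hole region
# below are arbitrary assumptions for illustration.
def _partial_conv_example():
    conv = PartialConv2d(3, 8, kernel_size=3, padding=1, return_mask=True)
    x = torch.randn(1, 3, 16, 16)
    mask = torch.ones(1, 1, 16, 16)
    mask[:, :, 4:12, 4:12] = 0               # simulate a missing (hole) region
    out, new_mask = conv(x, mask)            # new_mask values lie in [0, 1]
    assert out.shape == (1, 8, 16, 16) and new_mask.shape == (1, 1, 16, 16)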