Spaces:
Build error
Build error
# Modified partialconv source code based on implementation from | |
# https://github.com/NVIDIA/partialconv/blob/master/models/partialconv2d.py | |
############################################################################### | |
# BSD 3-Clause License | |
# | |
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |
# | |
# Author & Contact: Guilin Liu ([email protected]) | |
############################################################################### | |
# Original Author & Contact: Guilin Liu ([email protected]) | |
# Modified by Kevin Shih ([email protected]) | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
class PartialConv1d(nn.Conv1d):
    """1D partial convolution: an ``nn.Conv1d`` that renormalizes by valid inputs.

    Each output position is rescaled by roughly ``kernel_size / n_valid``,
    where ``n_valid`` is the number of valid (mask == 1) input steps under the
    kernel window. As a result, zero padding at the edges and masked-out steps
    do not shrink the response. Output positions whose window holds no valid
    input are set to zero (bias included).

    Attributes:
        multi_channel: Always ``False``; not read anywhere in this class. The
            mask is single-channel and is broadcast over the input channels.
        return_mask: When ``True``, ``forward`` returns ``(output, update_mask)``
            instead of just ``output``.
        last_size: Input shape that the cached ``mask_ratio``/``update_mask``
            were computed for by an *unmasked* call. ``(None, None, None)``
            means there is no reusable cache.
        update_mask: Per-output-position validity clamped to ``[0, 1]``, shape
            ``(batch or 1, 1, out_time)``.
        mask_ratio: Per-output-position rescaling factor (same shape); zero
            where ``update_mask`` is zero.
    """

    def __init__(self, *args, **kwargs):
        self.multi_channel = False
        self.return_mask = False
        super(PartialConv1d, self).__init__(*args, **kwargs)

        # All-ones kernel used to count valid mask entries in each window.
        # It is a plain attribute (not a registered buffer), so forward()
        # moves it to the input's device/dtype lazily.
        self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
        # Number of mask entries in one full window (== kernel_size[0]).
        self.slide_winsize = (
            self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]
        )

        self.last_size = (None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def _mask_cache_valid(self, input: torch.Tensor) -> bool:
        """Return True if the cached unmasked ratio can be reused for ``input``."""
        ratio = self.mask_ratio
        return (
            ratio is not None
            and self.last_size == tuple(input.shape)
            and ratio.dtype == input.dtype
            and ratio.device == input.device
        )

    def _update_mask_ratio(self, input: torch.Tensor, mask: torch.Tensor = None):
        """Recompute ``self.update_mask`` and ``self.mask_ratio`` for ``input``.

        ``mask`` must already be on the input's device/dtype. ``None`` means
        every timestep is valid, so only the conv padding is treated as invalid.
        """
        with torch.no_grad():
            updater = self.weight_maskUpdater
            if updater.dtype != input.dtype or updater.device != input.device:
                self.weight_maskUpdater = updater.to(input)
            if mask is None:
                mask = input.new_ones(1, 1, input.shape[2])
            # Count of valid input steps under each output window; uses the
            # same geometry (stride/padding/dilation) as the real conv.
            counts = F.conv1d(
                mask,
                self.weight_maskUpdater,
                bias=None,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=1,
            )
            # The 1e-6 guards the division. Empty windows (counts == 0) are
            # zeroed explicitly: in fp16, slide_winsize / 1e-6 overflows to inf,
            # and inf * 0 would give NaN.
            ratio = torch.where(
                counts > 0,
                self.slide_winsize / (counts + 1e-6),
                torch.zeros_like(counts),
            )
            self.update_mask = torch.clamp(counts, 0, 1)
            self.mask_ratio = torch.mul(ratio, self.update_mask)

    def forward(self, input: torch.Tensor, mask_in: torch.Tensor = None):
        """Apply the partial convolution.

        Args:
            input: standard input to a 1D conv, shape ``(batch, channels, time)``.
            mask_in: optional mask of valid timesteps (1 = valid, 0 = invalid).
                Its shape is ``(batch, 1, time)`` or ``(1, 1, time)``, and it is
                broadcast over the input channels. Any dtype (including bool)
                is accepted; it is cast to the input's dtype/device. If None,
                all timesteps are valid and only the conv's zero padding is
                compensated for.

        Returns:
            The renormalized conv output, or ``(output, update_mask)`` when
            ``self.return_mask`` is True.
        """
        assert len(input.shape) == 3

        mask = None
        if mask_in is not None:
            mask = mask_in.to(device=input.device, dtype=input.dtype)

        if mask is not None or not self._mask_cache_valid(input):
            self._update_mask_ratio(input, mask)
            # Only cache ratios derived from the implicit all-ones mask; a
            # ratio from a custom mask must not be reused by a later unmasked
            # call with the same input shape.
            if mask is None:
                self.last_size = tuple(input.shape)
            else:
                self.last_size = (None, None, None)

        raw_out = super(PartialConv1d, self).forward(
            input if mask is None else torch.mul(input, mask)
        )

        if self.bias is not None:
            # Rescale only the weighted sum, not the bias, then zero the
            # positions that saw no valid input at all.
            bias_view = self.bias.view(1, self.out_channels, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        return output