pengdadaaa's picture
Upload 741 files
786f6a6 verified
raw
history blame
3.22 kB
""" Conv2d w/ Same Padding
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
from .config import is_exportable, is_scriptable
from .padding import pad_same, pad_same_arg, get_padding_value
_USE_EXPORT_CONV = False
def conv2d_same(
x,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
dilation: Tuple[int, int] = (1, 1),
groups: int = 1,
):
x = pad_same(x, weight.shape[-2:], stride, dilation)
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size,
stride, 0, dilation, groups, bias,
)
def forward(self, x):
return conv2d_same(
x, self.weight, self.bias,
self.stride, self.padding, self.dilation, self.groups,
)
class Conv2dSameExport(nn.Conv2d):
""" ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
NOTE: This does not currently work with torch.jit.script
"""
# pylint: disable=unused-argument
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
):
super(Conv2dSameExport, self).__init__(
in_channels, out_channels, kernel_size,
stride, 0, dilation, groups, bias,
)
self.pad = None
self.pad_input_size = (0, 0)
def forward(self, x):
input_size = x.size()[-2:]
if self.pad is None:
pad_arg = pad_same_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
self.pad = nn.ZeroPad2d(pad_arg)
self.pad_input_size = input_size
x = self.pad(x)
return F.conv2d(
x, self.weight, self.bias,
self.stride, self.padding, self.dilation, self.groups,
)
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', '')
kwargs.setdefault('bias', False)
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
if is_dynamic:
if _USE_EXPORT_CONV and is_exportable():
# older PyTorch ver needed this to export same padding reasonably
assert not is_scriptable() # Conv2DSameExport does not work with jit
return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
else:
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)