|
from collections import namedtuple
from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F
|
|
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is applied before the activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation
|
    def forward(self, x):
        # Skip the empty-input check when scripting with torchscript:
        # scripted modules are only used in evaluation mode, and recent
        # PyTorch versions of `F.conv2d` handle empty inputs on their own.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # Convolution itself can train on empty inputs, but
                # SyncBatchNorm cannot compute statistics from zero elements.
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
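

# Example usage (a sketch, not part of the original code): `norm` and
# `activation` are fused into the layer, so a conv -> norm -> activation
# block becomes a single module (`get_norm` is defined later in this file):
#
#   conv = Conv2d(
#       3, 64, kernel_size=3, padding=1,
#       norm=get_norm("GN", 64), activation=F.relu,
#   )
#   y = conv(torch.randn(2, 3, 32, 32))  # y.shape == (2, 64, 32, 32)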
|
|
|
|
|
class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
    """
    A simple structure that contains basic shape specification about a tensor.
    It is often used as the auxiliary inputs/outputs of models,
    to complement the lack of shape inference ability among pytorch modules.

    Attributes:
        channels: number of channels
        height: spatial height
        width: spatial width
        stride: total stride of the tensor w.r.t. the model input
    """

    def __new__(cls, channels=None, height=None, width=None, stride=None):
        return super().__new__(cls, channels, height, width, stride)
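

# Example usage (a sketch): every field is optional, so a model can report
# only the parts of the shape it knows statically:
#
#   spec = ShapeSpec(channels=256, stride=16)
#   spec.channels  # 256
#   spec.height    # None (unknown until runtime)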
|
|
|
|
|
def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": nn.BatchNorm2d,
            "SyncBN": nn.SyncBatchNorm,
            # "FrozenBN" should map to a BatchNorm2d variant whose statistics
            # and affine parameters are frozen (e.g. detectron2's
            # FrozenBatchNorm2d); such a class is not defined in this file.
            "GN": lambda channels: nn.GroupNorm(32, channels),
        }[norm]
    return norm(out_channels)
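

# Example usage (a sketch):
#
#   get_norm("GN", 64)   # -> nn.GroupNorm(32, 64)
#   get_norm("BN", 64)   # -> nn.BatchNorm2d(64)
#   get_norm("", 64)     # -> None (empty string disables normalization)
#   get_norm(lambda c: nn.GroupNorm(8, c), 64)  # callables are applied directly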
|
|
|
|
|
def c2_xavier_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    # Caffe2's XavierFill corresponds to kaiming_uniform_ in PyTorch with a=1:
    # both sample from uniform(-bound, bound) where
    # bound = sqrt(6 / ((1 + a^2) * fan_in)) = sqrt(3 / fan_in).
    nn.init.kaiming_uniform_(module.weight, a=1)
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)
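

# Example usage (a sketch):
#
#   conv = nn.Conv2d(256, 256, kernel_size=1)
#   c2_xavier_fill(conv)  # weight ~ uniform(+/- sqrt(3 / 256)), bias = 0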
|
|
|
|
|
def c2_msra_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    # Caffe2's MSRAFill corresponds to kaiming_normal_ with fan_out and relu:
    # weight ~ normal(0, sqrt(2 / fan_out)).
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)
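

# Example usage (a sketch):
#
#   conv = nn.Conv2d(64, 128, kernel_size=3)
#   c2_msra_fill(conv)  # weight ~ normal(0, sqrt(2 / fan_out)), bias = 0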
|
|
|
|
|
def cat(tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """
    Efficient version of torch.cat that avoids a copy if there is only a
    single element in the given list or tuple.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)
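

# Example usage (a sketch):
#
#   a, b = torch.ones(2, 3), torch.zeros(2, 3)
#   cat([a, b]).shape  # (4, 3): same as torch.cat
#   cat([a]) is a      # True: single-element lists skip the copy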