from collections import namedtuple
from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F


class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is applied before the activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip this code in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. the features needed to export a module to torchscript were added in
        #    PyTorch 1.6 or later, and `Conv2d` in those versions already supports
        #    empty inputs.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # https://github.com/pytorch/pytorch/issues/12013
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
    """
    A simple structure that contains basic shape specification about a tensor.
    It is often used as the auxiliary inputs/outputs of models, to complement
    the lack of shape inference ability among pytorch modules.

    Attributes:
        channels (int): number of channels
        height (int): spatial height
        width (int): spatial width
        stride (int): stride of the feature relative to the network input
    """

    def __new__(cls, channels=None, height=None, width=None, stride=None):
        return super().__new__(cls, channels, height, width, stride)


def get_norm(norm, out_channels):  # todo: replace with syncbn
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; or a callable
            that takes a channel number and returns the normalization layer as an
            nn.Module. Note that only "GN" is enabled below; the other string keys
            are commented out.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            # "BN": BatchNorm2d,
            # # Fixed in https://github.com/pytorch/pytorch/pull/36382
            # "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            # "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            # "nnSyncBN": nn.SyncBatchNorm,
            # "naiveSyncBN": NaiveSyncBatchNorm,
            # # expose stats_mode N as an option to caller, required for zero-len inputs
            # "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
        }[norm]
    return norm(out_channels)


def c2_xavier_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    # Caffe2 implementation of XavierFill in fact
    # corresponds to kaiming_uniform_ in PyTorch
    nn.init.kaiming_uniform_(module.weight, a=1)
    if module.bias is not None:
        # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
        #  torch.Tensor]`.
        nn.init.constant_(module.bias, 0)
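

# The sketch below is illustrative and not part of the original module: it shows
# how `Conv2d`, `get_norm`, and `c2_xavier_fill` are typically combined. The
# function name `_demo_conv2d_usage` is hypothetical.
def _demo_conv2d_usage() -> None:
    """Minimal usage sketch: a conv wrapper with fused GroupNorm and ReLU."""
    conv = Conv2d(
        3,
        64,
        kernel_size=3,
        padding=1,
        norm=get_norm("GN", 64),  # nn.GroupNorm(32, 64); norm runs before activation
        activation=F.relu,
    )
    c2_xavier_fill(conv)  # Caffe2-style XavierFill, i.e. kaiming_uniform_ with a=1
    y = conv(torch.randn(2, 3, 32, 32))
    assert y.shape == (2, 64, 32, 32)  # spatial size preserved: k=3 with padding=1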


def c2_msra_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    if module.bias is not None:
        # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
        #  torch.Tensor]`.
        nn.init.constant_(module.bias, 0)


def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single
    element in the list.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)
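

# Another illustrative sketch, not part of the original file (the helper name
# `_demo_helpers` is hypothetical): it demonstrates `cat`'s single-element
# fast path, `ShapeSpec`'s None defaults, and `c2_msra_fill`'s effect on bias.
def _demo_helpers() -> None:
    t = torch.arange(4)
    assert cat([t]) is t  # no copy: the single element is returned as-is
    assert cat([t, t], dim=0).shape == (8,)  # multiple elements fall back to torch.cat

    spec = ShapeSpec(channels=64, stride=4)
    assert spec.channels == 64 and spec.height is None  # unset fields default to None

    conv = nn.Conv2d(8, 16, kernel_size=3)
    c2_msra_fill(conv)  # kaiming_normal_ (fan_out, relu) for weight, zeros for bias
    assert torch.all(conv.bias == 0)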