from typing import List
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import namedtuple
class Conv2d(torch.nn.Conv2d):
"""
A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
"""
    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is applied before the activation.
        """
norm = kwargs.pop("norm", None)
activation = kwargs.pop("activation", None)
super().__init__(*args, **kwargs)
self.norm = norm
self.activation = activation
def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip this code path in torchscript because:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed for exporting a module to torchscript were added in PyTorch 1.6
        #    or later, and `Conv2d` in those versions already supports empty inputs.
if not torch.jit.is_scripting():
if x.numel() == 0 and self.training:
# https://github.com/pytorch/pytorch/issues/12013
assert not isinstance(
self.norm, torch.nn.SyncBatchNorm
), "SyncBatchNorm does not support empty inputs!"
x = F.conv2d(
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
)
if self.norm is not None:
x = self.norm(x)
if self.activation is not None:
x = self.activation(x)
return x
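
# A minimal usage sketch (illustrative, not part of the library): building a
# conv block with GroupNorm and ReLU via the extra keyword arguments above.
# The channel counts, group count of 32, and input shape are assumptions.
#
#   conv = Conv2d(64, 64, kernel_size=3, padding=1, bias=False,
#                 norm=nn.GroupNorm(32, 64), activation=F.relu)
#   out = conv(torch.randn(2, 64, 32, 32))  # -> shape (2, 64, 32, 32)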
class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
    """
    A simple structure that contains basic shape specification about a tensor.
    It is often used as the auxiliary inputs/outputs of models,
    to complement the lack of shape inference ability among PyTorch modules.

    Attributes:
        channels: number of channels
        height: spatial height
        width: spatial width
        stride: stride of the feature map relative to the network input
    """
def __new__(cls, channels=None, height=None, width=None, stride=None):
return super().__new__(cls, channels, height, width, stride)
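
# Example (a sketch): ShapeSpec behaves like a namedtuple whose fields all
# default to None, so callers can fill in only what they know.
#
#   spec = ShapeSpec(channels=256, stride=4)
#   spec.channels  # 256
#   spec.height    # None (unspecified)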
def get_norm(norm, out_channels):  # todo: replace with syncbn
    """
    Args:
        norm (str or callable): currently only the string "GN" is wired up below
            (the BN/SyncBN/FrozenBN variants are commented out); alternatively,
            a callable that takes a channel number and returns the normalization
            layer as an nn.Module.
        out_channels (int): number of channels normalized by the returned layer.

    Returns:
        nn.Module or None: the normalization layer
    """
if norm is None:
return None
if isinstance(norm, str):
if len(norm) == 0:
return None
norm = {
# "BN": BatchNorm2d,
# # Fixed in https://github.com/pytorch/pytorch/pull/36382
# "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
# "FrozenBN": FrozenBatchNorm2d,
"GN": lambda channels: nn.GroupNorm(32, channels),
# for debugging:
# "nnSyncBN": nn.SyncBatchNorm,
# "naiveSyncBN": NaiveSyncBatchNorm,
# # expose stats_mode N as an option to caller, required for zero-len inputs
# "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
}[norm]
return norm(out_channels)
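
# Example (a sketch): "GN" is the only string currently wired up; a callable
# or an empty string/None are also accepted, as the code above shows.
#
#   gn = get_norm("GN", 64)                              # nn.GroupNorm(32, 64)
#   custom = get_norm(lambda c: nn.GroupNorm(1, c), 64)  # callable form
#   assert get_norm("", 64) is None and get_norm(None, 64) is None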
def c2_xavier_fill(module: nn.Module) -> None:
"""
Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
Also initializes `module.bias` to 0.
Args:
module (torch.nn.Module): module to initialize.
"""
# Caffe2 implementation of XavierFill in fact
# corresponds to kaiming_uniform_ in PyTorch
nn.init.kaiming_uniform_(module.weight, a=1)
if module.bias is not None:
# pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
# torch.Tensor]`.
nn.init.constant_(module.bias, 0)
def c2_msra_fill(module: nn.Module) -> None:
"""
Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
Also initializes `module.bias` to 0.
Args:
module (torch.nn.Module): module to initialize.
"""
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
if module.bias is not None:
# pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
# torch.Tensor]`.
nn.init.constant_(module.bias, 0)
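
# Example (a sketch): applying the two Caffe2-style initializers to freshly
# constructed layers. Using MSRA fill for convs followed by ReLU is the usual
# convention; the layer shapes here are assumptions for illustration.
#
#   lateral = nn.Conv2d(256, 256, kernel_size=1)
#   head = nn.Conv2d(256, 256, kernel_size=3, padding=1)
#   c2_xavier_fill(lateral)
#   c2_msra_fill(head)  # followed by a ReLU in the model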
def cat(tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """
    Efficient version of `torch.cat` that avoids a copy if there is only a
    single element in the list.
    """
assert isinstance(tensors, (list, tuple))
if len(tensors) == 1:
return tensors[0]
return torch.cat(tensors, dim)
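
# Example (a sketch): a single-element list returns the tensor itself (no
# copy); longer lists fall through to torch.cat.
#
#   a = torch.ones(2, 3)
#   assert cat([a]) is a                       # no copy made
#   assert cat([a, a], dim=0).shape == (4, 3)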