"""
Wrappers around some nn functions, mainly to support empty tensors.

Ideally, add support directly in PyTorch for empty tensors in those functions.

These can be removed once https://github.com/pytorch/pytorch/issues/12013
is implemented.
"""

import warnings
from typing import List, Optional

import torch
from torch.nn import functional as F

from detectron2.utils.env import TORCH_VERSION


def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Turn a list of integer scalars or integer Tensor scalars into a vector,
    in a way that's both traceable and scriptable.

    In tracing, `x` should be a list of scalar Tensors, so the output can trace to the inputs.
    In scripting or eager, `x` should be a list of int.
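
    Example (illustrative sketch, not from the original docs; eager mode with plain ints):
        shapes_to_tensor([3, 5])                              # tensor([3, 5])
        shapes_to_tensor([3, 5], device=torch.device("cpu"))  # same values, on CPU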
    """
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing():
        assert all(
            [isinstance(t, torch.Tensor) for t in x]
        ), "Shape should be tensor during tracing!"

        # stack (rather than as_tensor) keeps the output traceable to the input tensors
        ret = torch.stack(x)
        if ret.device != device:  # avoid recording a hard-coded device if not necessary
            ret = ret.to(device=device)
        return ret
    return torch.as_tensor(x, device=device)


def check_if_dynamo_compiling():
    """
    Return True if this code is currently being compiled/traced by TorchDynamo (torch.compile).
    """
    if TORCH_VERSION >= (1, 14):
        from torch._dynamo import is_compiling

        return is_compiling()
    else:
        return False


def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list.
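
    Example (illustrative only):
        cat([torch.zeros(2, 4), torch.ones(3, 4)], dim=0)  # concatenated, shape (5, 4)
        cat([torch.zeros(2, 4)])                            # single element returned as-is, no copy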
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


def empty_input_loss_func_wrapper(loss_func):
    def wrapped_loss_func(input, target, *, reduction="mean", **kwargs):
        """
        Same as `loss_func`, but returns 0 (instead of nan) for empty inputs.
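
        Example (illustrative, using the wrapped `cross_entropy` defined below in this module):
            logits = torch.zeros(0, 10, requires_grad=True)  # a batch with no samples
            labels = torch.zeros(0, dtype=torch.long)
            loss = cross_entropy(logits, labels)             # 0.0 instead of nan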
        """
        if target.numel() == 0 and reduction == "mean":
            return input.sum() * 0.0  # connect the gradient
        return loss_func(input, target, reduction=reduction, **kwargs)

    return wrapped_loss_func


cross_entropy = empty_input_loss_func_wrapper(F.cross_entropy)


class _NewEmptyTensorOp(torch.autograd.Function):
    """
    Create an empty tensor of a new shape from `x`, as an autograd Function so that the
    result stays connected to `x` in the autograd graph.
    """

    @staticmethod
    def forward(ctx, x, new_shape):
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        return _NewEmptyTensorOp.apply(grad, shape), None


class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is used before activation.
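
        Example (illustrative sketch, not part of the original docstring):
            conv = Conv2d(
                256, 256, kernel_size=3, padding=1, bias=False,
                norm=torch.nn.BatchNorm2d(256),
                activation=F.relu,
            )
            out = conv(torch.zeros(2, 256, 32, 32))  # conv -> norm -> activation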
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        if not torch.jit.is_scripting():
            # Dynamo doesn't support context managers yet
            is_dynamo_compiling = check_if_dynamo_compiling()
            if not is_dynamo_compiling:
                with warnings.catch_warnings(record=True):
                    if x.numel() == 0 and self.training:
                        # https://github.com/pytorch/pytorch/issues/12013
                        assert not isinstance(
                            self.norm, torch.nn.SyncBatchNorm
                        ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


ConvTranspose2d = torch.nn.ConvTranspose2d
BatchNorm2d = torch.nn.BatchNorm2d
interpolate = F.interpolate
Linear = torch.nn.Linear


def nonzero_tuple(x):
    """
    An 'as_tuple=True' version of torch.nonzero to support torchscript,
    because of https://github.com/pytorch/pytorch/issues/38718
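
    Example (illustrative only):
        mask = torch.tensor([[True, False], [False, True]])
        rows, cols = nonzero_tuple(mask)  # rows == tensor([0, 1]), cols == tensor([0, 1])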
    """
    if torch.jit.is_scripting():
        if x.dim() == 0:
            return x.unsqueeze(0).nonzero().unbind(1)
        return x.nonzero().unbind(1)
    else:
        return x.nonzero(as_tuple=True)


@torch.jit.script_if_tracing
def move_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    """
    Tracing-friendly way to cast a tensor to another tensor's device. The device would otherwise
    be treated as a constant during tracing; scripting the cast as a whole works around this.
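
    Example (illustrative only):
        dst = torch.zeros(1, device="cuda" if torch.cuda.is_available() else "cpu")
        src = move_device_like(torch.arange(4), dst)  # src now lives on dst.device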
    """
    return src.to(dst.device)