# flake8: noqa
import numpy as np
from types import SimpleNamespace
import torch
from torch import nn
import bias_act_cuda
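# Compiled CUDA extension providing the fused bias + activation kernels used below.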
# ----------------------------------------------------------------------------
activation_funcs = {
    'linear': SimpleNamespace(
        func=lambda x, **_: x,
        def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
    'relu': SimpleNamespace(
        func=lambda x, **_: torch.nn.functional.relu(x),
        def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
    'leakyrelu': SimpleNamespace(
        func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
        def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh': SimpleNamespace(
        func=lambda x, **_: torch.tanh(x),
        def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
    'sigmoid': SimpleNamespace(
        func=lambda x, **_: torch.sigmoid(x),
        def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
    'elu': SimpleNamespace(
        func=lambda x, **_: torch.nn.functional.elu(x),
        def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
    'selu': SimpleNamespace(
        func=lambda x, **_: torch.nn.functional.selu(x),
        def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
    'softplus': SimpleNamespace(
        func=lambda x, **_: torch.nn.functional.softplus(x),
        def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish': SimpleNamespace(
        func=lambda x, **_: torch.sigmoid(x) * x,
        def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
}
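# Each spec above pairs a reference implementation (`func`) with metadata used by the
# CUDA path:
#   def_alpha / def_gain -- defaults substituted when the caller passes alpha/gain=None.
#   cuda_idx             -- activation index handed to the compiled bias_act_cuda kernel.
#   ref                  -- which tensor the backward pass needs: 'x' (input), 'y' (output), or ''.
#   has_2nd_grad         -- whether second-order gradients need the dedicated kernel path.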
# ----------------------------------------------------------------------------
_null_tensor = torch.empty([0])
def _bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None,
              impl='cuda'):
    """Fused bias and activation: adds bias `b` along dimension `dim`, applies the
    activation `act` (with optional `alpha`), scales by `gain`, and optionally clamps
    the result to `[-clamp, clamp]`. Uses the custom CUDA op when `impl='cuda'` and
    `x` lives on a CUDA device; otherwise falls back to the reference implementation."""
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda':
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain,
                              clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain,
                         clamp=clamp)
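# Dispatch sketch (comments only; shapes are illustrative): the same call runs the
# fused kernel on a CUDA tensor and the reference path on CPU.
#
#   y = _bias_act(torch.randn(2, 8, 4, 4, device='cuda'),
#                 torch.zeros(8, device='cuda'), act='leakyrelu')   # fused CUDA kernel
#   y = _bias_act(torch.randn(2, 8, 4, 4), torch.zeros(8), act='leakyrelu')  # reference path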
# ----------------------------------------------------------------------------
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `_bias_act()` using standard PyTorch ops."""
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)
    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
    # Evaluate activation function.
    x = spec.func(x, alpha=alpha)
    # Scale by gain.
    if gain != 1:
        x = x * gain
    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp)  # pylint: disable=invalid-unary-operand-type
    return x
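# Worked example of the reference math (a hand-checked sketch, not a test):
#
#   x = torch.tensor([[-1.0, 2.0]])
#   y = _bias_act_ref(x, act='leakyrelu')   # leaky_relu(x, alpha=0.2) * sqrt(2)
#   # y ≈ [[-0.2828, 2.8284]]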
# ----------------------------------------------------------------------------
_bias_act_cuda_cache = dict()
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `_bias_act()` using the custom bias_act_cuda op."""
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b):  # pylint: disable=arguments-differ
            if x.ndim > 2 and x.stride()[1] == 1:
                ctx.memory_format = torch.channels_last
            else:
                ctx.memory_format = torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = bias_act_cuda.bias_act_cuda(x, b, _null_tensor, _null_tensor,
                                                _null_tensor, 0, dim, spec.cuda_idx,
                                                alpha, gain, clamp)
            # Save only the tensors the backward pass actually needs.
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy):  # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)
            if ctx.needs_input_grad[1]:
                db = dx.sum([i for i in range(dx.ndim) if i != dim])
            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y):  # pylint: disable=arguments-differ
            # Follow the layout of the incoming gradient rather than `x`, which may
            # be the empty placeholder tensor for activations that only save `y`.
            if dy.ndim > 2 and dy.stride()[1] == 1:
                ctx.memory_format = torch.channels_last
            else:
                ctx.memory_format = torch.contiguous_format
            dx = bias_act_cuda.bias_act_cuda(dy, b, x, y, _null_tensor, 1, dim,
                                             spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx):  # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None
            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or
                                      ctx.needs_input_grad[2]):
                d_x = bias_act_cuda.bias_act_cuda(d_dx, b, x, y, dy, 2, dim,
                                                  spec.cuda_idx, alpha, gain, clamp)
            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda
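# Note: the returned autograd.Function is specialized for one (dim, act, alpha, gain,
# clamp) configuration; the cache above ensures repeated calls with identical settings
# reuse the same class instead of redefining it.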
# ----------------------------------------------------------------------------
class FusedNonlinearity(nn.Module):
    """Module wrapper around `_bias_act()`: an optional learnable per-channel bias
    (scaled by `lr_mul`) fused with the chosen nonlinearity."""

    def __init__(self, nonlinearity, num_channels=None, lr_mul=1.0, alpha=None,
                 impl='cuda', gain=None):
        super().__init__()
        if num_channels is not None:
            self.bias = nn.Parameter(torch.zeros(num_channels))
        else:
            self.register_parameter('bias', None)
        self.nonlinearity = nonlinearity
        self.gain = gain
        self.alpha = alpha
        self.lr_mul = lr_mul
        self.impl = impl

    def forward(self, x):
        bias = self.bias.type_as(x) * self.lr_mul if self.bias is not None else None
        return _bias_act(
            x, b=bias, dim=1, act=self.nonlinearity,
            alpha=self.alpha, gain=self.gain, clamp=None, impl=self.impl
        )

    def __repr__(self):
        mod_str = f'{self.__class__.__name__}(type={self.nonlinearity}'
        if self.gain is not None:
            mod_str += f', gain={self.gain}'
        if self.alpha is not None:
            mod_str += f', alpha={self.alpha}'
        if self.lr_mul != 1:
            mod_str += f', lr_mul={self.lr_mul}'
        mod_str += ')'
        return mod_str
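# ----------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the module proper). Importing this file
# already requires the compiled `bias_act_cuda` extension; the call below runs on CPU
# tensors, which exercises the `_bias_act_ref` fallback rather than the fused kernel.
if __name__ == '__main__':
    layer = FusedNonlinearity('leakyrelu', num_channels=8)
    x = torch.randn(2, 8, 4, 4, requires_grad=True)
    y = layer(x)
    y.sum().backward()
    print(layer)                   # FusedNonlinearity(type=leakyrelu)
    print(y.shape, x.grad.shape)   # torch.Size([2, 8, 4, 4]) torch.Size([2, 8, 4, 4])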