""" Normalization layers and wrappers

Norm layer definitions that support fast norm and consistent channel arg order (always first arg).

Hacked together by / Copyright 2022 Ross Wightman
"""
import numbers
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm


class GroupNorm(nn.GroupNorm):
    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
        # num_channels comes first (unlike nn.GroupNorm) to keep the consistent channel-first arg order
        super().__init__(num_groups, num_channels, eps=eps, affine=affine)
        self.fast_norm = is_fast_norm()

    def forward(self, x):
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
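

# Example (illustrative, not part of the module): because num_channels is the first positional
# arg, GroupNorm can be constructed from a channel count alone, the same way a BatchNorm-style
# layer would be.
#
#   norm = GroupNorm(64)                    # 64 channels, default 32 groups
#   y = norm(torch.randn(2, 64, 8, 8))      # y.shape == (2, 64, 8, 8)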


class GroupNorm1(nn.GroupNorm):
    """ Group Normalization with 1 group.
    Input: tensor in shape [B, C, *]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)
        self.fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
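

# Note (illustrative): with a single group the normalization statistics are computed over C and
# all spatial dims jointly for each sample, e.g.
#
#   y = GroupNorm1(32)(torch.randn(2, 32, 4, 4))   # stats reduced over dims (1, 2, 3)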


class LayerNorm(nn.LayerNorm):
    """ LayerNorm w/ fast norm option
    """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x


class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # permute to NHWC so normalization runs over the channel dim, then permute back
        x = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x
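

# Example (illustrative, not part of the module): LayerNorm2d accepts NCHW input directly and
# normalizes each spatial position over its C channels.
#
#   norm = LayerNorm2d(64)
#   y = norm(torch.randn(2, 64, 7, 7))   # same shape out, per-position channel norm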


def _is_contiguous(tensor: torch.Tensor) -> bool:
    # the memory_format kwarg isn't usable under torchscript, so fall back to the default check
    if torch.jit.is_scripting():
        return tensor.is_contiguous()
    else:
        return tensor.is_contiguous(memory_format=torch.contiguous_format)


@torch.jit.script
def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    # channels-first layer norm for NCHW tensors: normalize over dim 1 (C) w/ biased variance
    s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight[:, None, None] + bias[:, None, None]
    return x


def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    # same channels-first norm, but variance computed as E[x^2] - E[x]^2 (clamped at 0 for stability)
    u = x.mean(dim=1, keepdim=True)
    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
    return x
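

# Note (illustrative): both helpers implement the same channels-first normalization,
# y = (x - E[x]) / sqrt(Var[x] + eps) * w + b over dim 1; they differ only in how Var[x]
# is obtained (torch.var_mean vs the E[x^2] - E[x]^2 identity, clamped at zero).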


class LayerNormExp2d(nn.LayerNorm):
    """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).

    Experimental implementation w/ manual norm for non-contiguous tensors.

    This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
    layout. However, benefits are not always clear and it can perform worse on other GPUs.
    """

    def __init__(self, num_channels, eps=1e-6):
        super().__init__(num_channels, eps=eps)

    def forward(self, x) -> torch.Tensor:
        if _is_contiguous(x):
            # contiguous NCHW input: permute to NHWC, use the built-in op, permute back
            x = F.layer_norm(
                x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
        else:
            # non-contiguous input (e.g. channels_last layout): use the scripted manual norm
            x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
        return x
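

# Example (illustrative, not part of the module): a channels_last tensor is not contiguous in the
# default memory format, so it takes the manual-norm branch above.
#
#   norm = LayerNormExp2d(64)
#   x = torch.randn(2, 64, 7, 7).to(memory_format=torch.channels_last)
#   y = norm(x)   # handled by _layer_norm_cf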


class RmsNorm(nn.Module):
    """ RmsNorm w/ fast (apex) norm if available
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        normalized_shape = channels
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fast_rms_norm provides the fast (apex) path when available and a fallback otherwise
        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
        return x
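

# Note (illustrative): RMS norm rescales by the root-mean-square over the normalized dims rather
# than centering and dividing by the standard deviation, i.e. roughly
#
#   y = x / sqrt((x * x).mean(dim=-1, keepdim=True) + eps) * weight
#
# exact numerics depend on the fast_rms_norm implementation (apex-backed when available).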