File size: 6,040 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
""" Normalization layers and wrappers
Norm layer definitions that support fast norm and consistent channel arg order (always first arg).
Hacked together by / Copyright 2022 Ross Wightman
"""
import numbers
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm
class GroupNorm(nn.GroupNorm):
    """ GroupNorm w/ channel count as the first arg and an optional fast norm path

    The fast norm flag is read once via `is_fast_norm()` when the layer is built; `forward`
    then calls either `fast_group_norm` or `F.group_norm` with the same arguments.
    """

    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
        # NOTE nn.GroupNorm expects (num_groups, num_channels). Channels come first here so
        # this layer can be swapped in where BN-style layers take the channel count first.
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
        # Flag is stored on the instance: scripting cannot rely on module globals.
        self.fast_norm = is_fast_norm()

    def forward(self, x):
        if not self.fast_norm:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class GroupNorm1(nn.GroupNorm):
    """ Group Normalization fixed to a single group.

    Input: tensor in shape [B, C, *]

    Any extra keyword arguments are forwarded unchanged to `nn.GroupNorm`.
    """

    def __init__(self, num_channels, **kwargs):
        # Group count is hard-wired to 1; only the channel count is configurable.
        super().__init__(1, num_channels, **kwargs)
        # Flag is stored on the instance: scripting cannot rely on module globals.
        self.fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.fast_norm:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class LayerNorm(nn.LayerNorm):
    """ LayerNorm w/ fast norm option

    `affine` maps onto nn.LayerNorm's `elementwise_affine`. The fast norm flag is read once
    via `is_fast_norm()` at construction and selects between `fast_layer_norm` and
    `F.layer_norm` in `forward`.
    """

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        # Flag is stored on the instance: scripting cannot rely on module globals.
        self._fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            return fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm over the channel dim of '2D' spatial NCHW tensors

    The input is permuted so dim 1 becomes the last dim, normalized over
    `normalized_shape`, and permuted back to the original dim order.
    """

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        # Flag is stored on the instance: scripting cannot rely on module globals.
        self._fast_norm = is_fast_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (0, 1, 2, 3) -> (0, 2, 3, 1): move dim 1 to the end, where layer norm operates
        channels_last = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            normed = fast_layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        # restore the original dim order
        return normed.permute(0, 3, 1, 2)
def _is_contiguous(tensor: torch.Tensor) -> bool:
    """ Return whether `tensor` is contiguous in the default (torch.contiguous_format) layout.

    Under scripting the no-arg `is_contiguous()` is called; in eager mode the memory format
    is passed explicitly.
    """
    # jit is oh so lovely :/ -- keep this as an if/else statement so the scripting check
    # still guards the memory_format call.
    if torch.jit.is_scripting():
        contiguous = tensor.is_contiguous()
    else:
        contiguous = tensor.is_contiguous(memory_format=torch.contiguous_format)
    return contiguous
@torch.jit.script
def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    """ Manual layer norm over dim 1 of `x`, followed by a per-channel affine transform.

    `weight` and `bias` are indexed as `[:, None, None]`, so they are broadcast against
    two trailing dims of `x`.
    """
    # biased (population) variance and mean over dim 1
    var, mean = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
    normed = (x - mean) * torch.rsqrt(var + eps)
    return normed * weight[:, None, None] + bias[:, None, None]
def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    """ Manual layer norm over dim 1 of `x`, with variance computed as E[x^2] - E[x]^2.

    `weight` and `bias` are viewed as (1, -1, 1, 1) before the affine transform is applied.
    """
    mean = x.mean(dim=1, keepdim=True)
    mean_of_squares = (x * x).mean(dim=1, keepdim=True)
    # the difference can dip below zero from rounding, so clamp it at 0 before the rsqrt
    var = (mean_of_squares - (mean * mean)).clamp(0)
    normed = (x - mean) * torch.rsqrt(var + eps)
    scale = weight.view(1, -1, 1, 1)
    shift = bias.view(1, -1, 1, 1)
    return normed * scale + shift
class LayerNormExp2d(nn.LayerNorm):
    """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).

    Experimental implementation w/ manual norm for non-contiguous tensors.

    This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
    layout. However, benefits are not always clear and can perform worse on other GPUs.
    """

    def __init__(self, num_channels, eps=1e-6):
        super().__init__(num_channels, eps=eps)

    def forward(self, x) -> torch.Tensor:
        if not _is_contiguous(x):
            # non-contiguous input: normalize over dim 1 directly with the manual impl
            return _layer_norm_cf(x, self.weight, self.bias, self.eps)
        # contiguous input: permute dim 1 to the end, use the built-in layer norm, permute back
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.permute(0, 3, 1, 2)
class RmsNorm(nn.Module):
    """ RmsNorm w/ fast (apex) norm if available

    `channels` may be a single integer or a sequence of ints; it is stored as the tuple
    `normalized_shape`. With `affine=True` a learnable `weight` of that shape is created and
    initialized to ones, otherwise `weight` is registered as None.
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
        super().__init__()
        # a lone integer is promoted to a 1-tuple, anything else is converted with tuple()
        shape = (channels,) if isinstance(channels, numbers.Integral) else channels
        self.normalized_shape = tuple(shape)
        self.eps = eps
        self.elementwise_affine = affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, device=device, dtype=dtype))
        else:
            self.register_parameter('weight', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """ Set `weight` to ones when the layer is affine; otherwise do nothing. """
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
        return fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
|