"""Adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py""" |
|
|
|
import math |
|
|
|
from functools import partial |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
import triton |
|
import triton.language as tl |
|
|
|
from torch.distributed._tensor import Partial, Replicate, Shard |
|
from torch.distributed._tensor.experimental import local_map |
|
from torch._utils import _get_available_device_type, _get_device_module |
|
|
|
|
|
def get_device_info(): |
|
device_type = _get_available_device_type() |
|
|
|
if device_type is None: |
|
device_type = "cuda" |
|
|
|
device_module = _get_device_module(device_type) |
|
return device_type, device_module |
|
|
|
device_type, device_module = get_device_info() |
def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
    """
    Builds the specified normalization layer based on the norm_type.

    Args:
        norm_type (str): The type of normalization layer to build.
            Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
        dim (int): The dimension of the normalization layer.
        eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.

    Returns:
        The built normalization layer.

    Raises:
        NotImplementedError: If an unknown norm_type is provided.
    """
    norm_type = norm_type.lower()

    if norm_type == "layernorm":
        return nn.LayerNorm(dim, eps=eps, bias=False)
    elif norm_type == "np_layernorm":
        return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
    elif norm_type == "rmsnorm":
        return RMSNorm(dim, eps=eps)
    elif norm_type == "fused_rmsnorm":
        return FusedRMSNorm(dim, eps=eps)
    else:
        raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
class FusedRMSNorm(nn.Module):
    """Fused RMS Norm, wraps a fused Triton kernel"""

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.fused_rms_norm_fn = fused_rms_norm_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Leverages the Triton fused RMS norm kernel."""
        return self.fused_rms_norm_fn(
            x,
            self.weight,
            eps=self.eps,
        )

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
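# For reference, RMSNorm computes, over the last dimension,
#
#     y = x / sqrt(mean(x ** 2) + eps) * weight
#
# with the reduction done in float32 and the result cast back to the input dtype
# before the learnable scale is applied. A quick illustrative check (a sketch, not
# part of the original module; shapes are arbitrary):
#
#   rms = RMSNorm(dim=8)
#   x = torch.randn(2, 3, 8)
#   expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * rms.weight
#   torch.testing.assert_close(rms(x), expected)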
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_fwd_kernel(
    X,
    stride_x,
    Y,
    stride_y,
    W,
    Rstd,
    eps,
    M,
    N,
    block_N: tl.constexpr,
):
    # One program per row: each program normalizes a single row of length N.
    row = tl.program_id(0)
    cols = tl.arange(0, block_N)

    # Load the row and the weight vector in float32; out-of-range columns are masked to 0.
    mask = cols < N
    x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Reciprocal root-mean-square of the row.
    xbar = tl.where(cols < N, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Stash rstd so the backward kernel can reuse it instead of recomputing.
    tl.store(Rstd + row, rstd)

    # Normalize and scale.
    x_hat = x * rstd
    y = x_hat * w

    tl.store(Y + row * stride_y + cols, y, mask=mask)
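# Note: each program instance above handles one full row, so a row of N elements
# must fit in a single block. TritonFusedRMSNorm below caps block_N at
# 65536 // element_size (i.e. 64 KiB of row data) and raises a ValueError when N
# exceeds that, so the fused path only supports hidden sizes within this limit.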
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_bwd_kernel_sm(
    X,
    stride_x,
    W,
    DY,
    stride_dy,
    DX,
    stride_dx,
    Rstd,
    DW,
    eps,
    M,
    N,
    rows_per_program,
    block_N: tl.constexpr,
):
    # One program per SM: each program processes a contiguous block of rows and
    # accumulates a partial weight gradient that is summed on the host afterwards.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, block_N)
    mask = cols < N

    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Partial dW accumulator for this program's rows.
    dw = tl.zeros((block_N,), dtype=tl.float32)

    row_end = min(row_start + rows_per_program, M)
    for row in range(row_start, row_end):
        x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + row)

        # Rebuild x_hat from the saved rstd, then apply the RMSNorm backward formula:
        # dx = (w * dy - x_hat * mean(x_hat * w * dy)) * rstd, dw += dy * x_hat.
        x_hat = x * rstd
        wdy = w * dy
        dw += dy * x_hat
        c1 = tl.sum(x_hat * wdy, axis=0) / N
        dx = (wdy - x_hat * c1) * rstd

        tl.store(DX + row * stride_dx + cols, dx, mask=mask)

    # Write this program's partial weight-gradient row; the host sums over programs.
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
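# Derivation sketch for the dx formula above: with rstd = (mean(x ** 2) + eps) ** -0.5,
# x_hat = x * rstd and y = x_hat * w, the chain rule gives, per row,
#
#     dL/dw_j = sum over rows of dy_j * x_hat_j
#     dL/dx_j = rstd * (w_j * dy_j - x_hat_j * mean_k(x_hat_k * w_k * dy_k))
#
# which is exactly c1 = tl.sum(x_hat * wdy) / N and dx = (wdy - x_hat * c1) * rstd
# as computed in the loop.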
class TritonFusedRMSNorm(torch.autograd.Function):
    @partial(
        local_map,
        out_placements=[Shard(1)],
        in_placements=(None, [Shard(1)], [Replicate()], None),
    )
    @staticmethod
    def forward(ctx, x, weight, eps):
        x_shape_start = x.shape

        # Flatten to 2D (rows, hidden) and make the last dim contiguous for the kernel.
        x = x.view(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if weight.stride(-1) != 1:
            weight = weight.contiguous()

        M, N = x.shape
        y = torch.empty_like(x)
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # Each program loads a full row, so cap the block at 64 KiB of row data.
        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))

        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (M,)
        _rms_norm_fwd_kernel[grid](
            x,
            x.stride(0),
            y,
            y.stride(0),
            weight,
            rstd,
            eps,
            M,
            N,
            block_N,
        )

        ctx.eps = eps
        ctx.save_for_backward(x, weight, rstd)
        ctx.x_shape_start = x_shape_start

        y = y.reshape(x_shape_start)
        return y

    @partial(
        local_map,
        out_placements=([Shard(1)], [Partial()], None),
        in_placements=(None, [Shard(1)]),
    )
    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        eps = ctx.eps
        x_shape_start = ctx.x_shape_start

        dy = dy.view(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()

        M, N = dy.shape
        dx = torch.empty_like(x)

        # One program per SM; each writes a partial dW row that is reduced below.
        sm_count = device_module.get_device_properties(x.device).multi_processor_count
        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))
        rows_per_sm = math.ceil(M / sm_count)

        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (sm_count,)
        _rms_norm_bwd_kernel_sm[grid](
            x,
            x.stride(0),
            weight,
            dy,
            dy.stride(0),
            dx,
            dx.stride(0),
            rstd,
            _dw,
            eps,
            M,
            N,
            rows_per_sm,
            block_N,
        )
        dw = _dw.sum(0).to(weight.dtype)
        dx = dx.view(x_shape_start)
        return dx, dw, None
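# The local_map placements above describe how this Function participates in
# DTensor-based parallelism (e.g. sequence parallel): the activation enters and
# leaves sharded on dim 1 (Shard(1), the sequence dimension of a
# [batch, seq, hidden] input), the weight is replicated across ranks, and the
# weight gradient is returned as Partial(), i.e. a pending sum over ranks.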
def fused_rms_norm_fn(
    x,
    weight,
    eps=1e-6,
):
    return TritonFusedRMSNorm.apply(
        x,
        weight,
        eps,
    )
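# Minimal smoke test (a hedged sketch, not part of the original torchtitan module):
# it assumes Triton is installed and that the fused path accepts plain (non-DTensor)
# CUDA tensors, and checks that the fused kernel matches the eager RMSNorm in both
# forward and backward. Shapes and tolerances below are arbitrary choices.
if __name__ == "__main__":
    if torch.cuda.is_available():
        dim = 512
        x = torch.randn(2, 16, dim, device="cuda", requires_grad=True)

        eager = RMSNorm(dim).cuda()
        fused = FusedRMSNorm(dim).cuda()

        y_eager = eager(x)
        y_fused = fused(x)
        torch.testing.assert_close(y_eager, y_fused, rtol=1e-4, atol=1e-4)

        grad = torch.randn_like(y_eager)
        (dx_eager,) = torch.autograd.grad(y_eager, x, grad)
        (dx_fused,) = torch.autograd.grad(y_fused, x, grad)
        torch.testing.assert_close(dx_eager, dx_fused, rtol=1e-4, atol=1e-4)
        print("fused and eager RMSNorm match")
    else:
        print("CUDA not available; skipping fused RMSNorm smoke test")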