File size: 9,558 Bytes
9991887 4d7d25c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 |
"""Adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py"""
import math
from functools import partial
import torch
import torch.nn as nn
import triton
import triton.language as tl
from torch.distributed._tensor import Partial, Replicate, Shard
from torch.distributed._tensor.experimental import local_map
from torch._utils import _get_available_device_type, _get_device_module
def get_device_info():
device_type = _get_available_device_type()
if device_type is None:
device_type = "cuda" # Default to CUDA
device_module = _get_device_module(device_type)
return device_type, device_module
# Resolved once at import time. TritonFusedRMSNorm.backward reads
# ``device_module`` to query the device's multiprocessor count.
device_type, device_module = get_device_info()
def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
"""
Builds the specified normalization layer based on the norm_type.
Args:
norm_type (str): The type of normalization layer to build.
Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
dim (int): The dimension of the normalization layer.
eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
Returns:
The built normalization layer.
Raises:
NotImplementedError: If an unknown norm_type is provided.
"""
norm_type = norm_type.lower() # Normalize to lowercase
if norm_type == "layernorm":
return nn.LayerNorm(dim, eps=eps, bias=False)
elif norm_type == "np_layernorm":
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
elif norm_type == "rmsnorm":
return RMSNorm(dim, eps=eps)
elif norm_type == "fused_rmsnorm":
return FusedRMSNorm(dim, eps=eps)
else:
raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
class FusedRMSNorm(nn.Module):
    """RMSNorm layer whose forward pass calls ``fused_rms_norm_fn`` (Triton kernel).

    Args:
        dim (int): Size of the normalized dimension; also the length of ``weight``.
        eps (float, optional): Stability constant forwarded to the kernel.
            Defaults to 1e-6.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale, initialized to ones.
        self.weight = nn.Parameter(torch.ones(dim))
        # Kernel entry point, looked up from module scope at construction time.
        self.fused_rms_norm_fn = fused_rms_norm_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` with the Triton fused RMS Norm kernel, scaled by ``weight``."""
        kernel_fn = self.fused_rms_norm_fn
        return kernel_fn(x, self.weight, eps=self.eps)

    def reset_parameters(self):
        """Reset ``weight`` to ones."""
        torch.nn.init.ones_(self.weight)  # type: ignore
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable scale.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps) * weight``. The normalization
    runs in float32 and is cast back to the input dtype before the scale is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Create the layer.

        Args:
            dim (int): Size of the last (normalized) dimension of the input.
            eps (float, optional): Added to the mean square before the reciprocal
                square root, keeping it finite for all-zero rows. Default is 1e-6.

        Attributes:
            eps (float): The stability constant passed in.
            weight (nn.Parameter): Learnable scale of shape ``(dim,)``, initialized to ones.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Multiply ``x`` by the reciprocal RMS of its last dimension.

        Args:
            x (torch.Tensor): Tensor to normalize.

        Returns:
            torch.Tensor: ``x`` scaled so each last-dim row has unit RMS (up to eps).
        """
        inv_rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms

    def forward(self, x):
        """
        Apply RMS normalization followed by the learnable scale.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor multiplied by ``weight``.
        """
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight

    def reset_parameters(self):
        """Reset ``weight`` to ones."""
        torch.nn.init.ones_(self.weight)  # type: ignore
# FusedRMSNorm in Triton
# Credit
# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_fwd_kernel(
    X,
    stride_x,
    Y,
    stride_y,
    W,
    Rstd,
    eps,
    M,  # num rows (not read by this kernel)
    N,  # num cols
    block_N: tl.constexpr,
):
    """RMSNorm forward: each program normalizes one row of X.

    For its row, a program computes ``rstd = 1 / sqrt(sum(x*x) / N + eps)`` in
    float32. It stores ``rstd`` to ``Rstd[row]`` and writes ``x * rstd * W`` to
    the matching row of Y. Only the first ``block_N`` columns are covered, so
    ``N <= block_N`` is required.
    """
    row = tl.program_id(0)
    # Widen before multiplying by the row stride. With 32-bit arithmetic,
    # row * stride overflows once the flattened tensor exceeds 2**31 - 1 elements.
    row_off = row.to(tl.int64)
    cols = tl.arange(0, block_N)
    # Load input data and weights
    mask = cols < N
    x = tl.load(X + row_off * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    # Mean of squares over the valid columns. Masked lanes were loaded as 0.0,
    # so they add nothing to the sum.
    xbar = tl.where(cols < N, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Store the reciprocal RMS for the backward pass
    tl.store(Rstd + row, rstd)
    # Normalize and apply the per-feature scale
    x_hat = x * rstd
    y = x_hat * w
    # Write output
    tl.store(Y + row_off * stride_y + cols, y, mask=mask)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_bwd_kernel_sm(
    X,
    stride_x,
    W,
    DY,
    stride_dy,
    DX,
    stride_dx,
    Rstd,
    DW,
    eps,  # not read by this kernel
    M,  # num rows
    N,  # num cols
    rows_per_program,
    block_N: tl.constexpr,
):
    """RMSNorm backward over a contiguous chunk of rows.

    Program ``pid`` handles rows ``[pid * rows_per_program,
    min((pid + 1) * rows_per_program, M))``. For each row it writes
    ``dx = (w*dy - x_hat * mean(x_hat * w*dy)) * rstd``, where
    ``x_hat = x * rstd``. It also accumulates ``sum(dy * x_hat)`` over its rows
    and stores that partial weight gradient in row ``pid`` of DW. Programs whose
    range is empty store zeros. Only the first ``block_N`` columns are covered,
    so ``N <= block_N`` is required.
    """
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, block_N)
    mask = cols < N
    # Load weights
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    # Accumulate gradients for weights
    dw = tl.zeros((block_N,), dtype=tl.float32)
    row_end = min(row_start + rows_per_program, M)
    for row in range(row_start, row_end):
        # Widen before multiplying by the row strides to avoid 32-bit offset
        # overflow on tensors with more than 2**31 - 1 elements.
        row_off = row.to(tl.int64)
        # Load input, output gradient, and reciprocal RMS
        x = tl.load(X + row_off * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + row_off * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + row)
        # Compute normalized input and gradients
        x_hat = x * rstd
        wdy = w * dy
        dw += dy * x_hat
        c1 = tl.sum(x_hat * wdy, axis=0) / N
        dx = (wdy - x_hat * c1) * rstd
        # Store input gradient
        tl.store(DX + row_off * stride_dx + cols, dx, mask=mask)
    # Store this program's partial weight gradient
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
class TritonFusedRMSNorm(torch.autograd.Function):
    """Autograd function running RMSNorm forward/backward with Triton kernels.

    Both passes are wrapped in ``local_map``. The declared input placements are
    ``x`` as ``Shard(1)`` and ``weight`` as ``Replicate()``. The output and
    ``dx`` are ``Shard(1)``, and ``dw`` is returned as ``Partial()``.
    """

    @partial(
        local_map,
        out_placements=[Shard(1)],
        in_placements=(None, [Shard(1)], [Replicate()], None),
    )
    @staticmethod
    def forward(ctx, x, weight, eps):
        """Normalize ``x`` over its last dimension and scale by ``weight``.

        Args:
            ctx: Autograd context. Receives ``eps``, the original input shape,
                and the tensors ``backward`` needs.
            x: Input tensor whose last dimension has size ``N``.
            weight: Per-feature scale of length ``N``.
            eps: Added to the mean square before the reciprocal square root.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``N`` exceeds the per-row block, which holds at most
                64 KiB worth of ``x`` elements.
        """
        x_shape_start = x.shape
        # Flatten leading dims. reshape (not view) is used because view raises
        # on layouts that cannot be flattened without a copy, such as a
        # transposed input. When a view is possible, reshape returns one.
        x = x.reshape(-1, x.shape[-1])
        # The kernel walks columns with unit stride.
        if x.stride(-1) != 1:
            x = x.contiguous()
        if weight.stride(-1) != 1:
            weight = weight.contiguous()

        M, N = x.shape
        y = torch.empty_like(x)
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # One program holds an entire row; cap the block at 64 KiB of x elements.
        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))
        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        # One program per row.
        grid = lambda meta: (M,)
        _rms_norm_fwd_kernel[grid](
            x,
            x.stride(0),
            y,
            y.stride(0),
            weight,
            rstd,
            eps,
            M,
            N,
            block_N,
        )

        ctx.eps = eps
        ctx.save_for_backward(x, weight, rstd)
        ctx.x_shape_start = x_shape_start
        return y.reshape(x_shape_start)

    @partial(
        local_map,
        out_placements=([Shard(1)], [Partial()], None),
        in_placements=(None, [Shard(1)]),
    )
    @staticmethod
    def backward(ctx, dy):
        """Return ``(dx, dw, None)`` for the inputs ``(x, weight, eps)``.

        Raises:
            ValueError: If ``N`` exceeds the per-row block (64 KiB of ``x`` elements).
        """
        x, weight, rstd = ctx.saved_tensors
        eps = ctx.eps
        x_shape_start = ctx.x_shape_start

        # Flatten the incoming gradient. reshape (not view) is used because
        # gradients can arrive with layouts that view cannot flatten.
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        M, N = dy.shape
        dx = torch.empty_like(x)

        # One program per multiprocessor. Each writes one row of partial dW,
        # and those rows are summed below.
        sm_count = device_module.get_device_properties(x.device).multi_processor_count
        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))
        rows_per_sm = math.ceil(M / sm_count)
        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (sm_count,)
        _rms_norm_bwd_kernel_sm[grid](
            x,
            x.stride(0),
            weight,
            dy,
            dy.stride(0),
            dx,
            dx.stride(0),
            rstd,
            _dw,
            eps,
            M,
            N,
            rows_per_sm,
            block_N,
        )
        dw = _dw.sum(0).to(weight.dtype)
        dx = dx.reshape(x_shape_start)
        return dx, dw, None
# expose fusedRMSNorm as a function
def fused_rms_norm_fn(
    x,
    weight,
    eps=1e-6,
):
    """Functional entry point for the Triton fused RMSNorm.

    Args:
        x: Input tensor, normalized over its last dimension.
        weight: Per-feature scale applied after normalization.
        eps: Added to the mean square before the reciprocal square root.

    Returns:
        The normalized, scaled tensor, shaped like ``x``.
    """
    apply_fn = TritonFusedRMSNorm.apply
    return apply_fn(x, weight, eps)