File size: 27,005 Bytes
061483f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 |
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Tri Dao.
# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
# Implement residual + layer_norm / rms_norm, fused with a swish (SiLU) output gate y * o * sigmoid(o).
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from fla.utils import contiguous
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
    """Pure-PyTorch reference for (residual +) LayerNorm over the last dimension.

    Note: this reference does not apply the swish output gate.

    Args:
        x: input tensor; normalization is over the last dimension.
        weight: optional elementwise scale (``None`` for no scale).
        bias: optional elementwise shift (``None`` for no shift).
        residual: optional tensor added to ``x`` before normalization.
        eps: value added to the variance for numerical stability.
        prenorm: if True, also return the (residual-added) pre-norm input.
        upcast: if True, compute in float32.

    Returns:
        The normalized tensor cast back to ``x``'s original dtype, or
        ``(out, pre_norm_x)`` when ``prenorm`` is True.
    """
    dtype = x.dtype
    if upcast:
        weight = weight.float() if weight is not None else None
        bias = bias.float() if bias is not None else None
        x = x.float()
        residual = residual.float() if residual is not None else residual
    if residual is not None:
        x = (x + residual).to(x.dtype)
    # Normalize in the parameters' dtype; with no affine parameters fall back to x's dtype
    # (previously weight=None crashed on `weight.dtype`).
    if weight is not None:
        norm_dtype = weight.dtype
    elif bias is not None:
        norm_dtype = bias.dtype
    else:
        norm_dtype = x.dtype
    out = F.layer_norm(x.to(norm_dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(dtype)
    return out if not prenorm else (out, x)
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
    """Pure-PyTorch reference for (residual +) RMSNorm over the last dimension.

    Note: this reference does not apply the swish output gate.

    Args:
        x: input tensor; normalization is over the last dimension.
        weight: optional elementwise scale (``None`` for no scale).
        bias: optional elementwise shift (``None`` for no shift).
        residual: optional tensor added to ``x`` before normalization.
        eps: value added to the mean square for numerical stability.
        prenorm: if True, also return the (residual-added) pre-norm input.
        upcast: if True, compute in float32.

    Returns:
        The normalized tensor cast back to ``x``'s original dtype, or
        ``(out, pre_norm_x)`` when ``prenorm`` is True.
    """
    dtype = x.dtype
    if upcast:
        weight = weight.float() if weight is not None else None
        bias = bias.float() if bias is not None else None
        x = x.float()
        residual = residual.float() if residual is not None else residual
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + eps)
    out = x * rstd
    # Affine parameters are optional (previously weight=None raised a TypeError).
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    out = out.to(dtype)
    return out if not prenorm else (out, x)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    # NOTE(review): HAS_WEIGHT is not part of the autotune key.
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    O,  # pointer to the gate
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    RESIDUAL_OUT,  # pointer to the residual output (pre-norm row, written when STORE_RESIDUAL_OUT)
    Mean,  # pointer to the mean (written only for LayerNorm)
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr
):
    # Single-pass forward for one row:
    #   x (+ residual) -> LayerNorm / RMSNorm -> optional affine (W, B) -> y * o * sigmoid(o).
    # The whole row is handled as one BLOCK_N tile; columns >= N are masked off.
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    # NOTE: the gate is addressed with X's row stride, so O must share x's row layout.
    O += row * stride_x_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    # Padded lanes load as 0.0 so they do not contribute to the sums below.
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    # Store the pre-norm row (x, plus residual if present) for prenorm / the backward pass.
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        # Re-mask after centering: padded lanes would otherwise hold -mean.
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        # RMSNorm: "variance" is the mean square of x (no centering).
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    # W/B are loaded without `other`; padded lanes are never stored (see the masked store below).
    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    # NOTE: `mean` only exists in the LayerNorm path; this presumably relies on Triton
    # resolving the constexpr IS_RMS_NORM conditional at compile time.
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w if HAS_WEIGHT else x_hat
    if HAS_BIAS:
        y = y + b
    # Swish output gate
    o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)
    y = y * o * tl.sigmoid(o)
    # Write output
    tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(
    x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
):
    """Run the fused (residual +) LayerNorm/RMSNorm + swish-gate forward kernel on 2D inputs.

    Args:
        x: ``(M, N)`` input with unit column stride.
        o: ``(M, N)`` gate with unit column stride; must share ``x``'s row stride
            because the kernel addresses it with ``x.stride(0)``.
        weight: optional ``(N,)`` scale.
        bias: optional ``(N,)`` shift.
        eps: epsilon added to the variance.
        residual: optional ``(M, N)`` tensor added to ``x`` before normalization.
        out_dtype: dtype of the output (defaults to ``x.dtype``).
        residual_dtype: dtype of the stored pre-norm sum; overridden by
            ``residual.dtype`` when a residual is given.
        is_rms_norm: use RMSNorm instead of LayerNorm.

    Returns:
        ``(y, mean, rstd, residual_out)``. ``mean`` is ``None`` for RMSNorm, and
        ``residual_out`` is ``x`` itself when no separate buffer was needed.

    Raises:
        RuntimeError: if a row does not fit in a single 64KB tile.
    """
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    # The kernel reads the gate with x's row stride and unit column stride.
    assert o.shape == (M, N)
    assert o.stride(-1) == 1
    assert M <= 1 or o.stride(0) == x.stride(0)
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    if weight is not None:
        assert weight.shape == (N,)
        assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    # Allocate statistics on x's device. A bare "cuda" resolves to the *current* device,
    # which is not x's device on multi-GPU setups.
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError(
            "This layer norm doesn't support feature dim >= 64KB.")
    # One program per row; num_warps is selected by @triton.autotune on the kernel.
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            o,
            y,
            weight,
            bias,
            residual,
            residual_out,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            N,
            eps,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            weight is not None,
            bias is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, residual_out if residual_out is not None else x
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    O,  # pointer to the gate
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed (gated output, as produced by the forward kernel)
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DO,  # pointer to the gate gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    DRESIDUAL_IN,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Backward of (norm -> optional affine -> y * o * sigmoid(o)).
    # Each program walks `rows_per_program` consecutive rows, writes dx / do per row, and
    # accumulates dw / db in registers; the partial sums go to row `row_block_id` of DW / DB.
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    # The gate shares x's row stride; its gradient shares dx's row stride.
    O += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    DO += row_start * stride_dx_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    # The pre-gate output y feeds the gate gradient on every path, so the bias must be
    # loaded whenever it exists (previously it was only loaded when recomputing Y,
    # leaving `b` undefined in the HAS_BIAS and not RECOMPUTE_OUTPUT case).
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        o = tl.load(O + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        # Pre-gate output of the norm + affine.
        y = xhat * w if HAS_WEIGHT else xhat
        if HAS_BIAS:
            y = y + b
        sigmoid_o = tl.sigmoid(o)
        if RECOMPUTE_OUTPUT:
            # Store the *gated* output, i.e. exactly what the forward kernel produced and
            # what a following linear layer consumed (storing pre-gate y gave a wrong dW).
            tl.store(Y + cols, y * o * sigmoid_o, mask=mask)
        # d/do [o * sigmoid(o)] = sigmoid(o) + o * sigmoid(o) * (1 - sigmoid(o))
        do = dy * y * (sigmoid_o + o * sigmoid_o * (1 - sigmoid_o))
        # Gradient w.r.t. the pre-gate output y.
        dy = dy * o * sigmoid_o
        wdy = dy
        if HAS_WEIGHT:
            wdy = dy * w
            dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        tl.store(DX + cols, dx, mask=mask)
        tl.store(DO + cols, do, mask=mask)
        X += stride_x_row
        O += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        DO += stride_dx_row
    if HAS_WEIGHT:
        tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
def _layer_norm_bwd(
    dy,
    x,
    o,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    has_residual=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    """Run the fused norm + swish-gate backward kernel on 2D inputs.

    Args:
        dy: ``(M, N)`` gradient of the gated output.
        x: ``(M, N)`` input that was normalized in the forward pass.
        o: ``(M, N)`` gate; must share ``x``'s row stride (the kernel uses ``x.stride(0)``).
        weight, bias: optional ``(N,)`` affine parameters.
        eps: epsilon used in the forward pass.
        mean, rstd: per-row statistics saved by the forward pass (``mean`` is None for RMSNorm).
        dresidual: optional ``(M, N)`` gradient flowing into the pre-norm sum (prenorm case).
        has_residual: whether the forward pass had a residual input.
        is_rms_norm: use RMSNorm instead of LayerNorm.
        x_dtype: dtype for ``dx`` (defaults to ``x.dtype``).
        recompute_output: also recompute and return the gated forward output.

    Returns:
        ``(dx, do, dw, db, dresidual_in)``, plus the recomputed output ``y`` when
        ``recompute_output`` is True. ``dresidual_in`` is None without a residual and
        aliases ``dx`` when no dtype conversion is needed.

    Raises:
        RuntimeError: if a row does not fit in a single 64KB tile.
    """
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    # The kernel reads the gate with x's row stride and unit column stride.
    assert o.shape == (M, N)
    assert o.stride(-1) == 1
    assert M <= 1 or o.stride(0) == x.stride(0)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    if weight is not None:
        assert weight.shape == (N,)
        assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    # The gate gradient takes the gate's dtype (not x's); it is allocated contiguous
    # because the kernel writes DO with dx's row stride.
    do = torch.empty(M, N, dtype=o.dtype, device=x.device)
    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # One program per SM; each program writes one row of partial dw / db sums.
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = (
        torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
        if weight is not None
        else None
    )
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            o,
            weight,
            bias,
            y,
            dy,
            dx,
            do,
            _dw,
            _db,
            dresidual,
            dresidual_in,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            weight is not None,
            bias is not None,
        )
    # Reduce the per-program partial sums.
    dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype:
        dresidual_in = dx
    return (dx, do, dw, db, dresidual_in) if not recompute_output else (dx, do, dw, db, dresidual_in, y)
class LayerNormSwishGateFn(torch.autograd.Function):
    """Autograd function for fused (residual +) LayerNorm/RMSNorm with a swish output gate.

    Forward returns ``norm(x [+ residual]) * o * sigmoid(o)`` in ``x``'s shape, plus the
    pre-norm sum when ``prenorm`` is True. Inputs of any rank are flattened to 2D
    ``(rows, hidden)`` for the Triton kernels and reshaped back afterwards.
    """

    @staticmethod
    @contiguous
    def forward(
        ctx,
        x,
        o,
        weight,
        bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og, o_shape_og = x.shape, o.shape
        # Collapse leading dims so the kernels see (rows, hidden).
        x_2d = x.reshape(-1, x.shape[-1])
        o_2d = o.reshape(-1, o.shape[-1])
        if residual is None:
            residual_dtype = torch.float32 if residual_in_fp32 else None
        else:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            residual_dtype = residual.dtype
        y, mean, rstd, residual_out = _layer_norm_fwd(
            x_2d, o_2d, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
        )
        # The pre-norm sum (not the raw x) is what the backward pass normalizes again.
        ctx.save_for_backward(residual_out, o_2d, weight, bias, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.o_shape_og = o_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x_2d.dtype
        y = y.reshape(x_shape_og)
        if prenorm:
            return y, residual_out.reshape(x_shape_og)
        return y

    @staticmethod
    @contiguous
    def backward(ctx, dy, *args):
        x, o, weight, bias, mean, rstd = ctx.saved_tensors
        dy_2d = dy.reshape(-1, dy.shape[-1])
        assert dy_2d.shape == x.shape
        # With prenorm, the second forward output contributes its own gradient.
        dresidual = None
        if ctx.prenorm:
            dresidual = args[0].reshape(-1, args[0].shape[-1])
            assert dresidual.shape == x.shape
        dx, do, dw, db, dresidual_in = _layer_norm_bwd(
            dy_2d,
            x,
            o,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            ctx.has_residual,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        dresidual_grad = dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None
        # eps, prenorm, residual_in_fp32 and is_rms_norm get no gradient.
        no_grads = (None,) * 4
        return (
            dx.reshape(ctx.x_shape_og),
            do.reshape(ctx.o_shape_og),
            dw,
            db,
            dresidual_grad,
        ) + no_grads
class LayerNormSwishGateLinearFn(torch.autograd.Function):
    """Autograd function: fused (residual +) norm with a swish gate, followed by ``F.linear``.

    The gated norm output is not kept for backward; it is recomputed by the backward
    kernel and used to form the linear weight gradient.
    """

    @staticmethod
    @contiguous
    def forward(
        ctx,
        x,
        o,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og, o_shape_og = x.shape, o.shape
        # Collapse leading dims so the kernels see (rows, hidden).
        x_2d = x.reshape(-1, x.shape[-1])
        o_2d = o.reshape(-1, o.shape[-1])
        if residual is None:
            residual_dtype = torch.float32 if residual_in_fp32 else None
        else:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            residual_dtype = residual.dtype
        y, mean, rstd, residual_out = _layer_norm_fwd(
            x_2d,
            o_2d,
            norm_weight,
            norm_bias,
            eps,
            residual,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm
        )
        y = y.reshape(x_shape_og)
        # Run the projection in the autocast dtype when autocast is active.
        if torch.is_autocast_enabled():
            compute_dtype = torch.get_autocast_gpu_dtype()
        else:
            compute_dtype = y.dtype
        linear_weight = linear_weight.to(compute_dtype)
        if linear_bias is not None:
            linear_bias = linear_bias.to(compute_dtype)
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # y is deliberately not saved; backward recomputes it to save memory.
        ctx.save_for_backward(residual_out, o_2d, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.o_shape_og = o_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x_2d.dtype
        ctx.linear_bias_is_none = linear_bias is None
        if prenorm:
            return out, residual_out.reshape(x_shape_og)
        return out

    @staticmethod
    @contiguous
    def backward(ctx, dout, *args):
        x, o, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout_2d = dout.reshape(-1, dout.shape[-1])
        # Back-propagate through the projection: dy = dout @ linear_weight.
        dy = F.linear(dout_2d, linear_weight.t())
        dlinear_bias = dout_2d.sum(0) if not ctx.linear_bias_is_none else None
        assert dy.shape == x.shape
        dresidual = None
        if ctx.prenorm:
            dresidual = args[0].reshape(-1, args[0].shape[-1])
            assert dresidual.shape == x.shape
        dx, do, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
            dy,
            x,
            o,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout_2d, y)
        dresidual_grad = dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None
        # residual gradient, then eps, prenorm, residual_in_fp32, is_rms_norm (no gradient).
        return (
            dx.reshape(ctx.x_shape_og),
            do.reshape(ctx.o_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_grad,
        ) + (None,) * 4
def layer_norm_swish_gate_fn(
    x,
    o,
    weight,
    bias,
    residual=None,
    prenorm=False,
    residual_in_fp32=False,
    eps=1e-6
):
    """Functional (residual +) LayerNorm followed by the swish gate ``* o * sigmoid(o)``.

    Returns the gated output, or ``(output, residual_out)`` when ``prenorm`` is True.
    """
    use_rms_norm = False
    # Arguments follow LayerNormSwishGateFn.forward's positional order.
    args = (x, o, weight, bias, residual, eps, prenorm, residual_in_fp32, use_rms_norm)
    return LayerNormSwishGateFn.apply(*args)
def rms_norm_swish_gate_fn(
    x,
    o,
    weight,
    bias,
    residual=None,
    prenorm=False,
    residual_in_fp32=False,
    eps=1e-6
):
    """Functional (residual +) RMSNorm followed by the swish gate ``* o * sigmoid(o)``.

    Returns the gated output, or ``(output, residual_out)`` when ``prenorm`` is True.
    """
    use_rms_norm = True
    # Arguments follow LayerNormSwishGateFn.forward's positional order.
    args = (x, o, weight, bias, residual, eps, prenorm, residual_in_fp32, use_rms_norm)
    return LayerNormSwishGateFn.apply(*args)
def layer_norm_swish_gate_linear_fn(
    x,
    o,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    prenorm=False,
    residual_in_fp32=False,
    eps=1e-6
):
    """Functional (residual +) LayerNorm, swish gate, then a linear projection.

    Returns the projected output, or ``(output, residual_out)`` when ``prenorm`` is True.
    """
    use_rms_norm = False
    # Arguments follow LayerNormSwishGateLinearFn.forward's positional order.
    args = (
        x, o, norm_weight, norm_bias, linear_weight, linear_bias,
        residual, eps, prenorm, residual_in_fp32, use_rms_norm,
    )
    return LayerNormSwishGateLinearFn.apply(*args)
def rms_norm_swish_gate_linear_fn(
    x,
    o,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    prenorm=False,
    residual_in_fp32=False,
    eps=1e-6
):
    """Functional (residual +) RMSNorm, swish gate, then a linear projection.

    Returns the projected output, or ``(output, residual_out)`` when ``prenorm`` is True.
    """
    use_rms_norm = True
    # Arguments follow LayerNormSwishGateLinearFn.forward's positional order.
    args = (
        x, o, norm_weight, norm_bias, linear_weight, linear_bias,
        residual, eps, prenorm, residual_in_fp32, use_rms_norm,
    )
    return LayerNormSwishGateLinearFn.apply(*args)
class FusedLayerNormSwishGate(nn.Module):
    """LayerNorm followed by a swish output gate ``* o * sigmoid(o)``, via fused Triton kernels.

    Args:
        hidden_size: size of the normalized (last) dimension.
        elementwise_affine: if True, learn a per-feature scale initialized to ones.
        eps: value added to the variance for numerical stability.
    """

    def __init__(
        self,
        hidden_size,
        elementwise_affine: bool = True,
        eps=1e-5
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
        else:
            self.register_parameter("weight", None)
        # No learnable bias: it is registered as None and passed through as such.
        self.register_parameter("bias", None)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.hidden_size}"
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
        """Return ``layer_norm(x [+ residual]) * o * sigmoid(o)`` (plus the pre-norm sum if ``prenorm``)."""
        return layer_norm_swish_gate_fn(
            x,
            o,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32
        )
class FusedRMSNormSwishGate(nn.Module):
    """RMSNorm followed by a swish output gate ``* o * sigmoid(o)``, via fused Triton kernels.

    Args:
        hidden_size: size of the normalized (last) dimension.
        elementwise_affine: if True, learn a per-feature scale initialized to ones.
        eps: value added to the mean square for numerical stability.
    """

    def __init__(
        self,
        hidden_size,
        elementwise_affine: bool = True,
        eps=1e-5
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
        else:
            self.register_parameter("weight", None)
        # No learnable bias: it is registered as None and passed through as such.
        self.register_parameter("bias", None)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.hidden_size}"
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
        """Return ``rms_norm(x [+ residual]) * o * sigmoid(o)`` (plus the pre-norm sum if ``prenorm``)."""
        return rms_norm_swish_gate_fn(
            x,
            o,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32
        )
class FusedLayerNormSwishGateLinear(nn.Module):
    """LayerNorm + swish output gate + linear projection, via fused Triton kernels.

    The module owns only the norm parameters; the projection's weight and bias are
    passed to ``forward``.

    Args:
        hidden_size: size of the normalized (last) dimension.
        elementwise_affine: if True, learn a per-feature norm scale initialized to ones.
        eps: value added to the variance for numerical stability.
    """

    def __init__(
        self,
        hidden_size,
        elementwise_affine: bool = True,
        eps=1e-5
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
        else:
            self.register_parameter("weight", None)
        # No learnable norm bias: it is registered as None and passed through as such.
        self.register_parameter("bias", None)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.hidden_size}"
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
        """Return ``linear(layer_norm(x [+ residual]) * o * sigmoid(o), weight, bias)``.

        ``weight``/``bias`` belong to the linear projection, not the norm.
        """
        return layer_norm_swish_gate_linear_fn(
            x,
            o,
            self.weight,
            self.bias,
            weight,
            bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32
        )
class FusedRMSNormSwishGateLinear(nn.Module):
    """RMSNorm + swish output gate + linear projection, via fused Triton kernels.

    The module owns only the norm parameters; the projection's weight and bias are
    passed to ``forward``.

    Args:
        hidden_size: size of the normalized (last) dimension.
        elementwise_affine: if True, learn a per-feature norm scale initialized to ones.
        eps: value added to the mean square for numerical stability.
    """

    def __init__(
        self,
        hidden_size,
        elementwise_affine: bool = True,
        eps=1e-5
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
        else:
            self.register_parameter("weight", None)
        # No learnable norm bias: it is registered as None and passed through as such.
        self.register_parameter("bias", None)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.hidden_size}"
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
        """Return ``linear(rms_norm(x [+ residual]) * o * sigmoid(o), weight, bias)``.

        ``weight``/``bias`` belong to the linear projection, not the norm.
        """
        return rms_norm_swish_gate_linear_fn(
            x,
            o,
            self.weight,
            self.bias,
            weight,
            bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32
        )
|