import torch
import torch.nn as nn
from .modules import STU
from .modules import MLP
from .modules import Attention

try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm
    triton_norm = False


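# Note: both layer classes below accept config.torch_dtype either as a string
# (e.g. "bfloat16") or as a torch.dtype; strings are resolved via getattr(torch, ...).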
class STULayer(nn.Module):
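    """Pre-norm residual block that applies an STU sublayer followed by an MLP.

    Computes x = x + STU(norm(x)), then x = x + MLP(norm(x)). `phi` and `n` are
    forwarded unchanged to the STU module. Triton (Liger) kernels are used for
    the norms and the MLP when available; otherwise the fallbacks imported above
    are used.
    """
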
    def __init__(self, config, phi, n):
        super().__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        self.stu_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.stu = STU(config, phi, n)
        self.stu = self.stu.to(dtype=torch_dtype)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = (
            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
        )
        # TODO: Write an issue in the Liger-Kernel repo to support a user-defined dtype for the MLP.
        self.stu_norm = self.stu_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize and apply STU
        x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype)  # Match dtype for STU
        x_stu = self.stu(x_normed).to(dtype=x.dtype)  # Ensure the output matches `x`'s dtype
        x = x + x_stu

        # Normalize and apply MLP
        x_normed_mlp = self.mlp_norm(x).to(dtype=self.mlp.gate_proj.weight.dtype)  # Match dtype for MLP
        x_mlp = self.mlp(x_normed_mlp).to(dtype=x.dtype)  # Ensure the output matches `x`'s dtype
        x = x + x_mlp

        return x


class AttentionLayer(nn.Module):
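    """Pre-norm residual block that applies an attention sublayer followed by an MLP.

    Computes x = x + Attention(norm(x)), then x = x + MLP(norm(x)).
    """
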
    def __init__(self, config) -> None:
        super().__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        self.attn_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.attn = Attention(config)
        self.attn = self.attn.to(dtype=torch_dtype)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = (
            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
        )
        # TODO: Write an issue in the Liger-Kernel repo to support a user-defined dtype for the MLP.
        self.attn_norm = self.attn_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
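

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original module): it exercises the
    # pre-norm residual pattern used by STULayer and AttentionLayer with plain
    # nn.Linear stand-ins for the STU / Attention / MLP sublayers, since building
    # a full config for the .modules implementations is out of scope here.
    # Run as part of the package (python -m <package>.<this_module>) so the
    # relative imports above resolve.
    torch.manual_seed(0)
    n_embd = 16
    norm_1, norm_2 = nn.RMSNorm(n_embd), nn.RMSNorm(n_embd)
    sublayer, mlp = nn.Linear(n_embd, n_embd), nn.Linear(n_embd, n_embd)

    x = torch.randn(2, 8, n_embd)  # (batch, seq_len, n_embd)
    x = x + sublayer(norm_1(x))  # residual add around the normalized sublayer
    x = x + mlp(norm_2(x))  # residual add around the normalized MLP stand-in
    print(x.shape)  # torch.Size([2, 8, 16])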