import torch
import torch.nn as nn

from .modules import STU, MLP, Attention

try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP

    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm

    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    # RMSNorm is only referenced when the Triton kernel is unavailable, so it is
    # only imported on this fallback path.
    from torch.nn import RMSNorm

    triton_norm = False


class STULayer(nn.Module):
    def __init__(self, config, phi, n):
        super().__init__()
        # config.torch_dtype may be given as a string (e.g. "bfloat16"); resolve it
        # to the actual torch dtype.
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype

        self.stu_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.stu = STU(config, phi, n)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)

        # Cast all submodules to the configured dtype.
        self.stu = self.stu.to(dtype=torch_dtype)
        self.stu_norm = self.stu_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # STU block with a residual connection: cast the normalized input to the
        # STU parameter dtype, then cast the output back to the residual dtype.
        x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype)
        x_stu = self.stu(x_normed).to(dtype=x.dtype)
        x = x + x_stu

        # MLP block with a residual connection, using the same dtype handling.
        x_normed_mlp = self.mlp_norm(x).to(dtype=self.mlp.gate_proj.weight.dtype)
        x_mlp = self.mlp(x_normed_mlp).to(dtype=x.dtype)
        x = x + x_mlp

        return x


class AttentionLayer(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        # config.torch_dtype may be given as a string (e.g. "bfloat16"); resolve it
        # to the actual torch dtype.
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype

        self.attn_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.attn = Attention(config)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)

        # Cast all submodules to the configured dtype.
        self.attn = self.attn.to(dtype=torch_dtype)
        self.attn_norm = self.attn_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual blocks: attention, then MLP.
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x


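# Minimal usage sketch (illustrative only, so it is kept as comments). It shows one
# way these layers could be stacked. Everything below beyond the `n_embd` and
# `torch_dtype` fields used above is an assumption: `DummyConfig`, the shape of the
# placeholder `phi` filters, and the sequence length `n`. The real config must also
# provide whatever STU, Attention, and MLP expect.
#
# from dataclasses import dataclass
#
# @dataclass
# class DummyConfig:
#     n_embd: int = 512
#     torch_dtype: str = "bfloat16"
#     # ... plus the remaining fields required by STU / Attention / MLP.
#
# config = DummyConfig()
# n = 1024                                  # assumed sequence length
# phi = torch.randn(n, 24)                  # placeholder spectral filters for STU
# layers = nn.ModuleList([STULayer(config, phi, n), AttentionLayer(config)])
# x = torch.randn(2, n, config.n_embd, dtype=torch.bfloat16)
# for layer in layers:
#     x = layer(x)                          # shape preserved: (batch, n, n_embd)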