import torch
import torch.nn as nn

from .attn import FlexAttention
from .modules import MLP, Attention

# Prefer the Triton-fused SwiGLU MLP from Liger Kernel when it is installed;
# otherwise fall back to the vanilla SwiGLU MLP defined in .modules.
try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP

    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

# Prefer the Triton-fused RMSNorm from Liger Kernel when it is installed;
# otherwise fall back to PyTorch's built-in RMSNorm.
try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm

    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm

    triton_norm = False


class AttentionLayer(nn.Module):
    """Pre-norm transformer block: FlexAttention followed by an MLP, each
    wrapped in a residual connection."""

    def __init__(self, config, mask_mod, score_mod=None) -> None:
        super().__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = FlexAttention(
            config=config,
            mask_mod=mask_mod,
            score_mod=score_mod,
        )
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor | None = None) -> torch.Tensor:
        # Residual attention sub-block, then residual MLP sub-block.
        x = x + self.attn(self.attn_norm(x), freqs_cis=freqs_cis)
        x = x + self.mlp(self.mlp_norm(x))
        return x
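

# Usage sketch (illustrative only, not part of this module). It assumes that
# `config` also carries whatever fields FlexAttention and MLP read (only `dim`
# is referenced directly above), and that `mask_mod` follows the signature used
# by torch's flex_attention, (batch, head, q_idx, kv_idx) -> bool, e.g. a
# causal mask:
#
#     def causal_mask(b, h, q_idx, kv_idx):
#         return q_idx >= kv_idx
#
#     layer = AttentionLayer(config, mask_mod=causal_mask)
#     y = layer(x, freqs_cis=freqs_cis)  # x: (batch, seq_len, config.dim)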