# Transformer_500M/layers.py
import torch
import torch.nn as nn

from .attn import FlexAttention
from .modules import MLP, Attention

# Prefer the fused Triton SwiGLU MLP from liger_kernel when it is installed.
try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP

    triton_mlp = True
except ImportError as e:
    print(f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead.")
    triton_mlp = False

# Prefer the fused Triton RMSNorm from liger_kernel when it is installed.
try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm

    triton_norm = True
except ImportError as e:
    print(f"Unable to import Triton-based RMSNorm: {e}. Falling back to the PyTorch implementation.")
    from torch.nn import RMSNorm
    triton_norm = False
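

# Hypothetical helpers (not in the original file): a minimal sketch of how the
# triton_mlp / triton_norm flags above are typically consumed, assuming `config`
# exposes the fields each MLP/norm variant expects. The real model may wire the
# fallback differently.
def build_mlp(config):
    """Return the Triton-backed SwiGLU MLP when available, else the vanilla MLP."""
    return TritonMLP(config) if triton_mlp else MLP(config)


def build_norm(config):
    """Return the Triton-backed RMSNorm when available, else torch.nn.RMSNorm."""
    return TritonNorm(config.dim) if triton_norm else nn.RMSNorm(config.dim)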


class AttentionLayer(nn.Module):
    """Pre-norm transformer block: FlexAttention and an MLP, each wrapped in a residual connection."""

    def __init__(self, config, mask_mod, score_mod=None) -> None:
        super().__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = FlexAttention(
            config=config,
            mask_mod=mask_mod,
            score_mod=score_mod,
        )
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor | None = None) -> torch.Tensor:
        # Pre-norm residual attention sublayer (rotary frequencies passed through to FlexAttention).
        x = x + self.attn(self.attn_norm(x), freqs_cis=freqs_cis)
        # Pre-norm residual MLP sublayer.
        x = x + self.mlp(self.mlp_norm(x))
        return x
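

# Usage sketch (assumption, not part of the original file): `config` needs at least a
# `dim` attribute matching the model width, and `mask_mod` presumably follows the
# torch flex_attention (b, h, q_idx, kv_idx) -> bool convention, e.g. a causal mask.
#
#     layer = AttentionLayer(config, mask_mod=causal_mask)
#     y = layer(x, freqs_cis=freqs_cis)   # x: (batch, seq_len, config.dim)
#
# Both sublayers are pre-norm residual blocks: RMSNorm -> sublayer -> residual add.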