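"""Layer definitions for the STU-426M model.

Defines two pre-norm residual blocks: `STULayer` (an STU sublayer followed by a
SwiGLU MLP) and `AttentionLayer` (an attention sublayer followed by a SwiGLU
MLP), using Triton-based kernels from Liger-Kernel when they are available."""
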
import torch
import torch.nn as nn

from .modules import STU, MLP, Attention

try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP

    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm

    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm

    triton_norm = False


class STULayer(nn.Module):
    """Pre-norm residual block: an STU sublayer followed by a SwiGLU MLP sublayer."""

    def __init__(self, config, phi, n):
        super().__init__()
        # Resolve the working dtype from the config (it may be given as a string
        # such as "bfloat16").
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype

        self.stu_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.stu = STU(config, phi, n)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)

        # Cast all submodules to the configured dtype.
        # TODO: Open an issue in the Liger-Kernel repo to support a user-defined dtype for the MLP.
        self.stu = self.stu.to(dtype=torch_dtype)
        self.stu_norm = self.stu_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # STU sublayer: normalize, match the dtype expected by the STU, and add
        # the result back to the residual stream in `x`'s dtype.
        x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype)
        x_stu = self.stu(x_normed).to(dtype=x.dtype)
        x = x + x_stu

        # MLP sublayer: same pattern, matching the MLP weights' dtype.
        x_normed_mlp = self.mlp_norm(x).to(dtype=self.mlp.gate_proj.weight.dtype)
        x_mlp = self.mlp(x_normed_mlp).to(dtype=x.dtype)
        x = x + x_mlp
        return x


class AttentionLayer(nn.Module):
    """Pre-norm residual block: an attention sublayer followed by a SwiGLU MLP sublayer."""

    def __init__(self, config) -> None:
        super().__init__()
        # Resolve the working dtype from the config (it may be given as a string
        # such as "bfloat16").
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype

        self.attn_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.attn = Attention(config)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)

        # Cast all submodules to the configured dtype.
        # TODO: Open an issue in the Liger-Kernel repo to support a user-defined dtype for the MLP.
        self.attn = self.attn.to(dtype=torch_dtype)
        self.attn_norm = self.attn_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual connections around the attention and MLP sublayers.
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
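

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original model code).
# It shows how the two layer types above might be composed into a hybrid block
# stack. The `n_layers` config attribute and the even/odd interleaving pattern
# are assumptions made for illustration; `phi` and `n` are simply forwarded to
# `STULayer` as in its constructor.
# ---------------------------------------------------------------------------
class _HybridStackSketch(nn.Module):
    def __init__(self, config, phi, n):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                STULayer(config, phi, n) if i % 2 == 0 else AttentionLayer(config)
                for i in range(config.n_layers)  # `n_layers` is an assumed config field
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply each residual block in sequence; every layer preserves the
        # (batch, seq_len, n_embd) shape of the input.
        for layer in self.layers:
            x = layer(x)
        return x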