import torch
import torch.nn as nn

from .modules import STU
from .modules import MLP
from .modules import Attention

try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP

    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm

    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm

    triton_norm = False


class STULayer(nn.Module):
    def __init__(self, config, phi, n):
        super(STULayer, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype

        self.stu_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.stu = STU(config, phi, n)
        self.stu = self.stu.to(dtype=torch_dtype)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = (
            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
        )  # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP

        self.stu_norm = self.stu_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize and apply STU
        x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype)  # Match dtype for STU
        x_stu = self.stu(x_normed).to(dtype=x.dtype)  # Ensure output matches `x`'s dtype
        x = x + x_stu

        # Normalize and apply MLP
        x_normed_mlp = self.mlp_norm(x).to(dtype=self.mlp.gate_proj.weight.dtype)  # Match dtype for MLP
        x_mlp = self.mlp(x_normed_mlp).to(dtype=x.dtype)  # Ensure output matches `x`'s dtype
        x = x + x_mlp

        return x


class AttentionLayer(nn.Module):
    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype

        self.attn_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.attn = Attention(config)
        self.attn = self.attn.to(dtype=torch_dtype)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = (
            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
        )  # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP

        self.attn_norm = self.attn_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
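

# ---------------------------------------------------------------------------
# Illustrative usage (a sketch only, not executed on import). The exact config
# fields required by STU / MLP / Attention are defined in `.modules`, and the
# spectral filters `phi` and the filter count `n` come from the surrounding
# model setup, so the names below (`config`, `phi`, `n`, `batch_size`,
# `seq_len`) are hypothetical stand-ins rather than a working recipe:
#
#     torch_dtype = (
#         getattr(torch, config.torch_dtype)
#         if isinstance(config.torch_dtype, str)
#         else config.torch_dtype
#     )
#     layer = STULayer(config, phi, n)          # or AttentionLayer(config)
#     x = torch.randn(batch_size, seq_len, config.n_embd, dtype=torch_dtype)
#     y = layer(x)  # pre-norm STU (or attention) + MLP, each with a residual
#                   # connection; `y` has the same shape and dtype as `x`
# ---------------------------------------------------------------------------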