import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, elementwise_affine: bool = True, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool, optional): If True, adds a learnable
                per-element scaling parameter. Default is True.
            eps (float, optional): A small value added to the denominator for
                numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter (only present when
                elementwise_affine is True).
        """
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


class GroupNorm32(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)

    def forward(self, x):
        # Cast the result back to the input dtype, regardless of the parameter dtype.
        y = super().forward(x).to(x.dtype)
        return y


def normalization(channels, dtype=None):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :param dtype: parameter dtype passed to the underlying GroupNorm.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)
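

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module API).
    # The tensor shapes and sizes below are arbitrary example assumptions.

    # RMSNorm normalizes over the last dimension of a (batch, seq_len, dim) activation.
    rms = RMSNorm(dim=512)
    tokens = torch.randn(2, 16, 512)
    normed_tokens = rms(tokens)  # same shape and dtype as `tokens`

    # `normalization` wraps GroupNorm32 with 32 groups, so `channels` must be
    # divisible by 32; it expects a (batch, channels, ...) feature map.
    gn = normalization(channels=64)
    features = torch.randn(2, 64, 8, 8)
    normed_features = gn(features)

    print(normed_tokens.shape, normed_features.shape)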