import logging
import os
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import (
    ACT2FN,
    LLAMA_ATTENTION_CLASSES,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
)

logger = logging.getLogger(__name__)


class AriaMoELMConfig(LlamaConfig):
    """
    Configuration class for the AriaMoE language model.

    This class extends LlamaConfig with additional parameters specific to the
    Mixture of Experts (MoE) architecture.
    """

    model_type = "aria_moe_lm"

    def __init__(
        self,
        moe_intermediate_size: int = 4096,
        moe_num_experts: int = 8,
        moe_topk: int = 2,
        moe_z_loss_coeff: float = 1e-5,
        moe_aux_loss_coeff: float = 1e-3,
        moe_num_shared_experts: int = 2,
        **kwargs,
    ):
        """
        Initialize the AriaMoELMConfig.

        Args:
            moe_intermediate_size (int): The intermediate size for MoE layers. Default is 4096.
            moe_num_experts (int): The number of experts in the MoE layer. Default is 8.
            moe_topk (int): The number of top experts to route to for each token. Default is 2.
            moe_z_loss_coeff (float): The coefficient for the auxiliary z-loss. Default is 1e-5.
            moe_aux_loss_coeff (float): The coefficient for the auxiliary load balancing loss. Default is 1e-3.
            moe_num_shared_experts (int): The number of shared experts. Default is 2.
            **kwargs: Additional keyword arguments passed to the parent LlamaConfig.
        """
        super().__init__(**kwargs)
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_topk = moe_topk
        self.moe_z_loss_coeff = moe_z_loss_coeff
        self.moe_aux_loss_coeff = moe_aux_loss_coeff
        self.moe_num_shared_experts = moe_num_shared_experts
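
# Example (illustrative sketch, not executed on import): building a small
# config for quick experiments. The parameter values here are hypothetical;
# `hidden_size` and `num_hidden_layers` come from the parent LlamaConfig.
#
#     config = AriaMoELMConfig(
#         hidden_size=128,
#         num_hidden_layers=2,
#         moe_intermediate_size=256,
#         moe_num_experts=4,
#         moe_topk=2,
#     )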


class MoEAuxLossAutoScaler(torch.autograd.Function):
    """An autograd function that computes and scales the gradient for the auxiliary loss."""

    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
        """Preserve the aux_loss by storing it in the context to avoid garbage collection.

        Args:
            output (torch.Tensor): The output tensor.
            aux_loss (torch.Tensor): The auxiliary loss tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        ctx.save_for_backward(aux_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Compute and scale the gradient for the auxiliary loss.

        Args:
            grad_output (torch.Tensor): The gradient of the output.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output and the scaled auxiliary loss gradient.
        """
        (aux_loss,) = ctx.saved_tensors
        aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
        scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
        return grad_output, scaled_aux_loss_grad

    @staticmethod
    def set_loss_scale(scale: torch.Tensor):
        """Set the scale of the aux loss.

        Args:
            scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
        """
        MoEAuxLossAutoScaler.main_loss_backward_scale = scale
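
# Example (illustrative sketch, not executed on import): attaching an auxiliary
# loss to an activation so its gradient flows in backward without changing the
# forward value. The tensors below are hypothetical placeholders.
#
#     activation = torch.randn(4, 8, requires_grad=True)
#     aux_loss = activation.pow(2).mean()
#     out = MoEAuxLossAutoScaler.apply(activation, aux_loss)
#     # `out` equals `activation`; on backward, `aux_loss` receives a gradient
#     # of `main_loss_backward_scale` (1.0 by default).
#     MoEAuxLossAutoScaler.set_loss_scale(torch.tensor(0.5))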


def z_loss_func(logits, z_loss_coeff):
    """Encourages the router's logits to remain small to enhance stability.

    Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

    Args:
        logits (torch.Tensor): The logits of the router.
        z_loss_coeff (float): The coefficient for the z-loss.

    Returns:
        torch.Tensor: The computed z-loss.
    """
    z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
    return z_loss
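
# Example (illustrative sketch): for router logits of shape
# [num_tokens, num_experts], the z-loss is
# z_loss_coeff * mean(logsumexp(logits, dim=-1) ** 2),
# so larger-magnitude logits are penalized quadratically.
#
#     logits = torch.tensor([[0.0, 0.0], [2.0, -2.0]])
#     z_loss_func(logits, z_loss_coeff=1e-5)  # small positive scalar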


def switch_load_balancing_loss_func(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    topk: int,
    moe_aux_loss_coeff: float,
):
    """Calculate the auxiliary loss for better load balancing.

    Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.

    Args:
        probs (torch.Tensor): The softmax probs output by the router for each token. [num_tokens, num_experts]
        tokens_per_expert (torch.Tensor): The number of assigned tokens for each expert. [num_experts]
        topk (int): The number of experts each token is routed to.
        moe_aux_loss_coeff (float): The coefficient for the auxiliary loss.

    Returns:
        torch.Tensor: The auxiliary loss for load balancing.
    """
    num_tokens = probs.shape[0] * topk
    num_experts = probs.shape[1]

    probs_mean_per_expert = probs.mean(dim=0)
    aux_loss = torch.sum(probs_mean_per_expert * tokens_per_expert) * (
        num_experts / num_tokens * moe_aux_loss_coeff
    )
    return aux_loss
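
# Example (illustrative sketch): with perfectly uniform router probabilities
# and perfectly balanced assignments, the loss reduces exactly to
# moe_aux_loss_coeff; any imbalance increases it.
#
#     probs = torch.full((8, 4), 0.25)                # 8 tokens, 4 experts, uniform
#     tokens_per_expert = torch.tensor([4, 4, 4, 4])  # 8 tokens * topk=2, balanced
#     switch_load_balancing_loss_func(
#         probs, tokens_per_expert, topk=2, moe_aux_loss_coeff=1e-3
#     )  # == 1e-3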


class TopKRouter(nn.Module):
    """
    Top-K Router for Mixture of Experts (MoE) models.

    This router determines which experts should process each token based on the top-k scoring experts.
    It also applies auxiliary losses to encourage load balancing among experts.

    Args:
        config (AriaMoELMConfig): Configuration object containing MoE-related parameters.
    """

    def __init__(self, config: AriaMoELMConfig):
        super().__init__()
        self.config = config

        self.weight = nn.Parameter(
            torch.empty((self.config.moe_num_experts, self.config.hidden_size))
        )

    def gating(self, input: torch.Tensor) -> torch.Tensor:
        """
        Compute the gating logits for each token-expert pair.

        Args:
            input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size].

        Returns:
            torch.Tensor: Logits tensor of shape [batch_size * seq_len, num_experts].
        """
        logits = F.linear(input, self.weight)
        return logits

    def apply_z_loss(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply z-loss to encourage router logits to remain small for enhanced stability.

        Args:
            logits (torch.Tensor): Router logits.

        Returns:
            torch.Tensor: Logits with z-loss applied.
        """
        z_loss = z_loss_func(logits, self.config.moe_z_loss_coeff)
        logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
        return logits

    def apply_aux_loss(
        self,
        logits: torch.Tensor,
        tokens_per_expert: torch.Tensor,
        activation: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply auxiliary loss for load balancing among experts.

        Args:
            logits (torch.Tensor): Router logits.
            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
            activation (torch.Tensor): Activation values.

        Returns:
            torch.Tensor: Activation with auxiliary loss applied.
        """
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        aux_loss = switch_load_balancing_loss_func(
            probs,
            tokens_per_expert,
            self.config.moe_topk,
            self.config.moe_aux_loss_coeff,
        )
        return MoEAuxLossAutoScaler.apply(activation, aux_loss)

    def routing(
        self, logits: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Perform the routing operation to determine expert assignments.

        Args:
            logits (torch.Tensor): Router logits.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - scores: Softmax probabilities for top-k experts.
                - top_indices: Indices of top-k experts for each token.
                - tokens_per_expert: Number of tokens assigned to each expert.
        """
        logits = self.apply_z_loss(logits)

        top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
        scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)

        tokens_per_expert = torch.histc(
            top_indices.flatten(),
            bins=self.config.moe_num_experts,
            min=0,
            max=self.config.moe_num_experts - 1,
        )

        scores = self.apply_aux_loss(logits, tokens_per_expert, scores)
        return scores, top_indices, tokens_per_expert

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass of the TopKRouter.

        Args:
            input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size].

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - scores: Softmax probabilities for top-k experts.
                - top_indices: Indices of top-k experts for each token.
                - tokens_per_expert: Number of tokens assigned to each expert.
        """
        logits = self.gating(input)
        logits = logits.view(-1, self.config.moe_num_experts)
        scores, top_indices, tokens_per_expert = self.routing(logits)
        return scores, top_indices, tokens_per_expert
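
# Example (illustrative sketch, not executed on import): routing a flattened
# batch of token embeddings. Shapes follow the docstrings above; the config
# values are hypothetical.
#
#     config = AriaMoELMConfig(hidden_size=128, moe_num_experts=4, moe_topk=2)
#     router = TopKRouter(config)
#     # (the router weight is allocated with torch.empty; load or initialize it first)
#     tokens = torch.randn(16, 128)  # [batch_size * seq_len, hidden_size]
#     scores, top_indices, tokens_per_expert = router(tokens)
#     # scores: [16, 2], top_indices: [16, 2], tokens_per_expert: [4]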


class TokenDispatcher:
    """
    Handles the dispatching and gathering of tokens to and from experts.

    This class is responsible for permuting tokens based on expert assignments and
    unpermuting them after expert processing.

    Args:
        config (AriaMoELMConfig): Configuration object containing MoE-related parameters.
    """

    def __init__(self, config: AriaMoELMConfig):
        self.config = config
        self.hidden_states_shape = None
        self.reversed_input_permutation_mapping = None

    def token_permutation(
        self, hidden_states: torch.Tensor, indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Permute tokens based on expert assignments.

        Args:
            hidden_states (torch.Tensor): Input hidden states.
            indices (torch.Tensor): Expert assignment indices.

        Returns:
            torch.Tensor: Permuted tokens.
        """
        self.hidden_states_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
        flatten_indices = indices.flatten()
        sorted_indices = torch.argsort(flatten_indices, stable=True)
        # Each token appears `moe_topk` times in `flatten_indices`; dividing the
        # sorted positions by `moe_topk` recovers the source token for each copy.
        permuted_tokens = hidden_states.index_select(
            0, sorted_indices // self.config.moe_topk
        )
        self.reversed_input_permutation_mapping = sorted_indices
        return permuted_tokens

    def token_unpermutation(
        self, permuted_tokens: torch.Tensor, scores: torch.Tensor
    ) -> torch.Tensor:
        """
        Unpermute tokens and combine expert outputs.

        Args:
            permuted_tokens (torch.Tensor): Tokens after expert processing.
            scores (torch.Tensor): Expert assignment scores.

        Returns:
            torch.Tensor: Unpermuted and combined output.
        """
        num_unpermuted_tokens = scores.numel()
        unpermuted_tokens = torch.zeros(
            (num_unpermuted_tokens, permuted_tokens.size(1)),
            dtype=permuted_tokens.dtype,
            device=permuted_tokens.device,
        )
        unpermuted_tokens.index_copy_(
            0, self.reversed_input_permutation_mapping, permuted_tokens
        )
        unpermuted_tokens = unpermuted_tokens.reshape(
            -1, self.config.moe_topk, permuted_tokens.size(1)
        )

        # Weight each expert copy by its routing score and sum over the top-k copies.
        unpermuted_tokens = unpermuted_tokens * scores.unsqueeze(-1)
        unpermuted_tokens = unpermuted_tokens.sum(dim=1).type_as(permuted_tokens)
        output = unpermuted_tokens.view(self.hidden_states_shape)
        return output
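
# Example (illustrative sketch): round-tripping tokens through the dispatcher.
# With topk=2 each token is duplicated once per assigned expert, processed,
# then weighted by its routing score and summed back. All names below are
# hypothetical placeholders.
#
#     dispatcher = TokenDispatcher(config)
#     permuted = dispatcher.token_permutation(hidden_states, top_indices)
#     # permuted: [num_tokens * topk, hidden_size], grouped by expert id
#     expert_out = experts(permuted, tokens_per_expert)
#     combined = dispatcher.token_unpermutation(expert_out, scores)
#     # combined has the original hidden_states shape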


class SharedExpertMLP(LlamaMLP):
    """
    MLP module for shared experts.

    Unlike routed experts, shared experts process all tokens without routing.
    This class reconfigures the intermediate size in comparison to the LlamaMLP.

    Args:
        config (AriaMoELMConfig): Configuration object for the AriaMoE language model.
    """

    def __init__(self, config: AriaMoELMConfig):
        nn.Module.__init__(self)
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            config.moe_intermediate_size * config.moe_num_shared_experts
        )
        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.mlp_bias
        )
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.mlp_bias
        )
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=config.mlp_bias
        )
        self.act_fn = ACT2FN[config.hidden_act]


def sequential_gemm(input, weight, tokens_per_expert):
    """
    Compute the matrix multiplication (GEMM) for each expert sequentially. This approach
    is computationally inefficient, especially when dealing with a large number of experts.

    Args:
        input (torch.Tensor): Input tensor of shape (num_tokens, in_features).
        weight (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
        tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.

    Returns:
        torch.Tensor: Output tensor of shape (num_tokens, out_features).
    """
    num_tokens = input.shape[0]
    out_features = weight.shape[-1]
    output = torch.zeros(
        num_tokens, out_features, dtype=input.dtype, device=input.device
    )

    cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
    # Prepend a zero so that each expert's (start, end) slice can be read off directly.
    zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
    cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))

    for expert_num in range(weight.shape[0]):
        start = cumsum_num_tokens[expert_num]
        end = cumsum_num_tokens[expert_num + 1]
        tokens = input[start:end]

        out = torch.matmul(tokens, weight[expert_num])
        output[start:end] = out
    return output
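
# Example (illustrative sketch): two experts, three tokens already grouped by
# expert id (first two rows for expert 0, last row for expert 1).
#
#     inp = torch.randn(3, 8)    # (num_tokens, in_features)
#     w = torch.randn(2, 8, 16)  # (num_experts, in_features, out_features)
#     out = sequential_gemm(inp, w, torch.tensor([2, 1]))  # (3, 16)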


try:
    from grouped_gemm.ops import gmm as experts_gemm

    if os.environ.get("USE_GROUPED_GEMM", "1") == "0":
        logger.warning(
            "environment variable USE_GROUPED_GEMM is set to 0, using sequential GEMM instead."
        )
        experts_gemm = sequential_gemm
except ImportError:
    logger.warning(
        "`grouped_gemm` is not installed, using sequential GEMM, which is slower."
    )
    experts_gemm = sequential_gemm


class GroupedGEMM(nn.Module):
    """
    Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
    This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
    for optimized performance. If the grouped_gemm library is not installed, it gracefully
    falls back to a sequential GEMM implementation, which may be slower but ensures
    functionality.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        groups (int): Number of expert groups.
    """

    def __init__(self, in_features, out_features, groups):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups
        self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))

    def forward(self, input, tokens_per_expert):
        """
        Perform grouped matrix multiplication.

        Args:
            input (torch.Tensor): Input tensor of shape (num_tokens, in_features).
            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.

        Returns:
            torch.Tensor: Output tensor of shape (num_tokens, out_features).
        """
        # grouped_gemm expects the per-expert token counts on the CPU.
        tokens_per_expert = tokens_per_expert.cpu()

        # Ensure the current CUDA device matches the input's device before
        # launching the grouped GEMM kernels.
        torch.cuda.set_device(input.device)
        return experts_gemm(input, self.weight, tokens_per_expert)


class GroupedMLP(nn.Module):
    """
    Grouped MLP module for Mixture of Experts.

    Args:
        config (AriaMoELMConfig): Configuration object for the model.
    """

    def __init__(self, config: AriaMoELMConfig) -> None:
        super().__init__()
        self.config = config
        # fc1 packs the gate and up projections into one GEMM, hence the factor
        # of 2; the fused output is split again inside `glu`.
        self.fc1 = GroupedGEMM(
            config.hidden_size, config.moe_intermediate_size * 2, config.moe_num_experts
        )
        self.fc2 = GroupedGEMM(
            config.moe_intermediate_size, config.hidden_size, config.moe_num_experts
        )

        def glu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = glu

    def forward(self, permuted_tokens, tokens_per_expert):
        """
        Forward pass of the Grouped MLP.

        Args:
            permuted_tokens (torch.Tensor): Permuted input tokens.
            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.

        Returns:
            torch.Tensor: Output tensor after passing through the MLP.
        """
        fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
        fc1_output = self.activation_func(fc1_output)
        fc2_output = self.fc2(fc1_output, tokens_per_expert)
        return fc2_output


class MoELayer(nn.Module):
    """
    Mixture of Experts (MoE) Layer for the AriaMoE model.

    This layer implements the MoE mechanism, which routes input tokens to different experts
    based on a routing algorithm, processes them through the experts, and then combines
    the outputs.

    Args:
        config (AriaMoELMConfig): Configuration object for the MoE layer.
    """

    def __init__(self, config: AriaMoELMConfig):
        super().__init__()

        self.router = TopKRouter(config)
        self.token_dispatcher = TokenDispatcher(config)
        self.experts = GroupedMLP(config)
        self.shared_experts = SharedExpertMLP(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MoE Layer.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape (batch_size, sequence_length, hidden_size).

        Returns:
            torch.Tensor: Output tensor after passing through the MoE layer.

        Process:
        1. Route tokens to experts using the router.
        2. Permute tokens based on routing decisions.
        3. Process tokens through experts.
        4. Unpermute and combine expert outputs.
        5. Add shared expert output to the final result.
        """
        scores, indices, tokens_per_expert = self.router(hidden_states)

        permuted_tokens = self.token_dispatcher.token_permutation(
            hidden_states, indices
        )

        expert_output = self.experts(permuted_tokens, tokens_per_expert)

        output = self.token_dispatcher.token_unpermutation(expert_output, scores)

        shared_expert_output = self.shared_experts(hidden_states)
        output += shared_expert_output
        return output
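
# Example (illustrative sketch, not executed on import): one MoE layer forward
# pass. Routed experts see only their assigned tokens; the shared experts see
# every token, and both outputs are summed. `config` is a hypothetical
# AriaMoELMConfig instance.
#
#     layer = MoELayer(config)
#     hidden = torch.randn(2, 10, config.hidden_size)
#     out = layer(hidden)  # same shape as `hidden`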


class MoEDecoderLayer(LlamaDecoderLayer):
    """
    Custom Decoder Layer for the AriaMoE model which modifies the standard `LlamaDecoderLayer` by
    replacing the traditional MLP with a Mixture of Experts (MoE) Layer.

    Args:
        config (LlamaConfig): Configuration object for the layer.
        layer_idx (int): Index of the current layer in the model.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
            config=config, layer_idx=layer_idx
        )

        self.mlp = MoELayer(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )


class AriaMoELMModel(LlamaModel):
    """
    Custom LlamaModel for the AriaMoE model which modifies the standard LlamaModel by
    replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`.

    This model implements a Mixture of Experts (MoE) approach, where each layer contains
    multiple expert networks that specialize in different aspects of the input.

    Args:
        config (LlamaConfig): Configuration object for the model.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                MoEDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()


class AriaMoELMForCausalLM(LlamaForCausalLM):
    """
    AriaMoE model for causal language modeling tasks.

    This class extends LlamaForCausalLM to incorporate the Mixture of Experts (MoE) approach,
    allowing for more efficient and scalable language modeling.

    Args:
        config (AriaMoELMConfig): Configuration object for the model.
    """

    _tied_weights_keys = ["lm_head.weight"]
    config_class = AriaMoELMConfig
    _no_split_modules = ["MoEDecoderLayer"]

    def __init__(self, config):
        super().__init__(config)
        self.model = AriaMoELMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_z_loss_coeff(self, z_loss_coeff: float):
        """
        Set the coefficient for the z-loss in the MoE routing.

        Args:
            z_loss_coeff (float): The coefficient for the z-loss.
        """
        self.config.moe_z_loss_coeff = z_loss_coeff

    def set_aux_loss_coeff(self, aux_loss_coeff: float):
        """
        Set the coefficient for the auxiliary loss in the MoE routing.

        Args:
            aux_loss_coeff (float): The coefficient for the auxiliary loss.
        """
        self.config.moe_aux_loss_coeff = aux_loss_coeff
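

# Example (illustrative sketch, not executed on import): building a small model
# from a hypothetical config and running one forward pass. A CUDA device is
# assumed, since GroupedGEMM sets the CUDA device in its forward pass.
#
#     config = AriaMoELMConfig(
#         hidden_size=128,
#         intermediate_size=256,
#         num_hidden_layers=2,
#         num_attention_heads=4,
#         vocab_size=1000,
#         moe_intermediate_size=256,
#         moe_num_experts=4,
#         moe_topk=2,
#     )
#     model = AriaMoELMForCausalLM(config).cuda()
#     input_ids = torch.randint(0, config.vocab_size, (1, 16)).cuda()
#     outputs = model(input_ids=input_ids)
#     # outputs.logits: [1, 16, config.vocab_size]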