from typing import List

import torch
import torch.nn as nn

from .args import MoeArgs


class MoeLayer(nn.Module):
    """Sparse Mixture-of-Experts layer.

    Each token is routed to its top-k experts by the gate, and the
    expert outputs are combined with softmax-normalized gate weights.
    """

    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor):
        # inputs: (num_tokens, dim); every token is routed independently.
        gate_logits = self.gate(inputs)  # (num_tokens, num_experts)
        # Select the k highest-scoring experts for each token.
        weights, selected_experts = torch.topk(
            gate_logits, self.args.num_experts_per_tok
        )
        # Normalize the selected logits in float32 for numerical stability,
        # then cast the weights back to the input dtype.
        weights = torch.nn.functional.softmax(weights, dim=1, dtype=torch.float).to(
            inputs.dtype
        )
        results = torch.zeros_like(inputs)
        # Run each expert once, on just the tokens routed to it, and
        # accumulate the weighted expert outputs.
        for i, expert in enumerate(self.experts):
            batch_idx, nth_expert = torch.where(selected_experts == i)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
                inputs[batch_idx]
            )
        return results
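

if __name__ == "__main__":
    # Minimal usage sketch (run as a module, e.g. `python -m <package>.<module>`,
    # so the relative import above resolves). It assumes MoeArgs can be
    # constructed with `num_experts` and `num_experts_per_tok` keywords; only
    # the latter is read by MoeLayer above. The Linear experts and gate are
    # illustrative stand-ins for real feed-forward blocks.
    dim, num_experts = 16, 4
    experts = [nn.Linear(dim, dim) for _ in range(num_experts)]
    gate = nn.Linear(dim, num_experts, bias=False)
    moe = MoeLayer(
        experts, gate, MoeArgs(num_experts=num_experts, num_experts_per_tok=2)
    )
    tokens = torch.randn(8, dim)  # (num_tokens, dim)
    out = moe(tokens)
    assert out.shape == tokens.shape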