import torch
from torch import nn
from torch.nn import functional as F

from . import kernels
from .parallel_experts import ParallelExperts
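

# This module defines two Mixture-of-Experts feed-forward blocks built on the
# fused ParallelExperts layer and the grouped scatter/gather ops in kernels.ops:
#   * GLUMLP -- a GLU-style (gated) expert MLP, SiLU-gated by default.
#   * MLP    -- a plain two-layer expert MLP with a caller-supplied activation.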
class GLUMLP(nn.Module):
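    """
    Mixture-of-Experts GLU feed-forward block. Each selected expert applies a
    gated two-projection MLP (SiLU gate by default); the top-k expert outputs
    are recombined per token using the routing weights `expert_p`.
    """
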
    def __init__(
        self,
        input_size,
        hidden_size,
        num_experts,
        top_k,
        activation=nn.SiLU(),
    ):
        super(GLUMLP, self).__init__()
        self.num_experts = num_experts
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Fused up-projection: produces both the value and the gate halves (2 * hidden_size).
        self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size)
        # Down-projection back to the model dimension.
        self.output_experts = ParallelExperts(num_experts, hidden_size, input_size)
        self.top_k = min(top_k, self.num_experts)
        self.activation = activation

    def extra_repr(self):
        return 'k={}'.format(self.top_k)

    def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor):
        x_shape = x.size()
        x = x.view(-1, x_shape[-1])
        with torch.no_grad():
            # Sort the token-expert assignments so tokens routed to the same expert are contiguous.
            sorted_expert_idxs, sorted_scattered_idxs = kernels.ops.flatten_and_sort(expert_idxs)
            padded_block_idxs, expert_offsets = kernels.ops.padded_block_indices(sorted_expert_idxs, self.num_experts)

        # Up-project, keeping outputs grouped by expert, then split into value/gate halves.
        h, gates = self.experts(
            x, self.top_k,
            sorted_expert_idxs, sorted_scattered_idxs,
            padded_block_idxs, expert_offsets,
            grouped_out=True
        ).chunk(2, dim=-1)
        h = self.activation(gates) * h
        # Down-project and scatter back to token order, weighting each expert's
        # contribution by the routing probabilities `expert_p`.
        y = self.output_experts(
            h, 1, sorted_expert_idxs, sorted_scattered_idxs,
            padded_block_idxs, expert_offsets,
            grouped_in=True,
            gates=expert_p,
        )
        y = y.view(*x_shape[:-1], y.size(-1))
        return y


class MLP(nn.Module):
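    """
    Mixture-of-Experts feed-forward block without gating. Each selected expert
    applies a two-layer MLP; the top-k expert outputs are recombined per token
    using the routing weights `expert_p`.
    """
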
    def __init__(
        self,
        input_size,
        hidden_size,
        num_experts,
        top_k,
        activation=None,
    ):
        super(MLP, self).__init__()
        self.num_experts = num_experts
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.experts = ParallelExperts(num_experts, input_size, hidden_size)
        self.output_experts = ParallelExperts(num_experts, hidden_size, input_size)
        self.top_k = min(top_k, self.num_experts)
        # Note: the default is None, but forward() calls self.activation(h),
        # so a callable (e.g. nn.GELU()) must be provided.
        self.activation = activation

    def extra_repr(self):
        return 'k={}'.format(self.top_k)

    def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor):
        x_shape = x.size()
        x = x.view(-1, x_shape[-1])
        with torch.no_grad():
            # Sort the token-expert assignments so tokens routed to the same expert are contiguous.
            sorted_expert_idxs, sorted_scattered_idxs = kernels.ops.flatten_and_sort(expert_idxs)
            padded_block_idxs, expert_offsets = kernels.ops.padded_block_indices(sorted_expert_idxs, self.num_experts)

        # Up-project, keeping outputs grouped by expert.
        h = self.experts(
            x, self.top_k,
            sorted_expert_idxs, sorted_scattered_idxs,
            padded_block_idxs, expert_offsets,
            grouped_out=True
        )
        h = self.activation(h)
        # Down-project and scatter back to token order, weighted by `expert_p`.
        y = self.output_experts(
            h, 1, sorted_expert_idxs, sorted_scattered_idxs,
            padded_block_idxs, expert_offsets,
            grouped_in=True,
            gates=expert_p,
        )
        y = y.view(*x_shape[:-1], y.size(-1))
        return y
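

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module). It shows one way a
# router's top-k probabilities and indices could be fed into GLUMLP; the
# tensor shapes, the softmax/top-k routing step, and the CUDA/half-precision
# setup below are assumptions, not requirements documented in this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if torch.cuda.is_available():
        B, T, D, E, K = 2, 16, 128, 8, 2  # batch, tokens, model dim, experts, top-k (assumed)
        x = torch.randn(B, T, D, device="cuda", dtype=torch.float16)
        logits = torch.randn(B, T, E, device="cuda", dtype=torch.float16)  # stand-in for a router
        expert_p, expert_idxs = torch.topk(torch.softmax(logits, dim=-1), K, dim=-1)
        moe = GLUMLP(input_size=D, hidden_size=4 * D, num_experts=E, top_k=K)
        moe = moe.to(device="cuda", dtype=torch.float16)
        y = moe(x, expert_p, expert_idxs)  # expected shape: (B, T, D)
        print(y.shape)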