import torch
import torch.nn as nn
from torch.nn import functional as F


class MLP(nn.Module):
    """GELU-gated MLP block (gate/up/down projections) with dropout."""
|
    def __init__(self, config, dtype=None):
        super().__init__()
        # Resolve the parameter dtype: an explicit `dtype` argument wins,
        # otherwise fall back to config.torch_dtype (default float32).
        torch_dtype = getattr(torch, config.torch_dtype, torch.float32)
        dtype = dtype if dtype is not None else torch_dtype
        self.hidden_size = config.n_embd
        self.intermediate_size = config.n_embd * config.mlp_scale
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias, dtype=dtype)
        self.dropout = nn.Dropout(config.dropout)
|
    def forward(self, x):
        # Cast the input to the projection weight dtype so the block also
        # works when the surrounding model runs in a different precision.
        dtype = self.gate_proj.weight.dtype
        x = x.to(dtype=dtype)
        # GELU-gated MLP: gate(x) * up(x), then project back down.
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh").to(dtype=dtype)
        up = self.up_proj(x).to(dtype=dtype)
        fuse = gate * up
        outputs = self.down_proj(fuse).to(dtype=dtype)
        outputs = self.dropout(outputs)
        return outputs
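

# Minimal usage sketch (illustrative only): the config fields below mirror the
# attributes MLP reads (n_embd, mlp_scale, bias, dropout, torch_dtype), but the
# concrete values and the SimpleNamespace container are assumptions, not part
# of the original project's configuration.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        n_embd=768,             # hidden size (assumed)
        mlp_scale=4,            # expansion factor for the intermediate size (assumed)
        bias=False,             # whether the projections use bias terms (assumed)
        dropout=0.1,            # dropout probability (assumed)
        torch_dtype="float32",  # resolved via getattr(torch, ...) in __init__
    )

    mlp = MLP(config)
    x = torch.randn(2, 16, config.n_embd)  # (batch, seq_len, hidden_size)
    y = mlp(x)
    print(y.shape)  # torch.Size([2, 16, 768])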