# STU-426M / mlp.py
# Uploaded by yagizdevre ("Model uploaded."), revision ac3a9cf.
import torch
import torch.nn as nn
from torch.nn import functional as F
class MLP(nn.Module):
    """Gated feed-forward block (GeGLU variant, https://arxiv.org/pdf/2002.05202).

    Computes ``dropout(down_proj(gelu_tanh(gate_proj(x)) * up_proj(x)))``.

    Args:
        config: Object providing ``n_embd``, ``mlp_scale``, ``bias`` and
            ``dropout``. ``torch_dtype`` (a ``torch.dtype`` or its attribute
            name on ``torch``, e.g. ``"bfloat16"``) is optional.
        dtype: Explicit parameter dtype. It takes precedence over
            ``config.torch_dtype``. When neither resolves to a
            ``torch.dtype``, ``torch.float32`` is used.
    """

    def __init__(self, config, dtype=None):
        super().__init__()
        dtype = self._resolve_dtype(config, dtype)
        self.hidden_size = config.n_embd
        # nn.Linear requires an integer width; mlp_scale may be given as a float.
        self.intermediate_size = int(config.n_embd * config.mlp_scale)
        # All projections share the resolved dtype, so the `dtype` argument and
        # config.torch_dtype are honoured instead of a hard-coded bfloat16.
        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype
        )
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype
        )
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=config.bias, dtype=dtype
        )
        self.dropout = nn.Dropout(config.dropout)

    @staticmethod
    def _resolve_dtype(config, dtype):
        """Return the explicit dtype, else config.torch_dtype, else torch.float32."""
        if dtype is not None:
            return dtype
        cfg_dtype = getattr(config, "torch_dtype", None)
        if isinstance(cfg_dtype, torch.dtype):
            return cfg_dtype
        if isinstance(cfg_dtype, str):
            # Only accept names that resolve to an actual dtype (e.g. not "nn").
            resolved = getattr(torch, cfg_dtype, None)
            if isinstance(resolved, torch.dtype):
                return resolved
        return torch.float32

    def forward(self, x):
        """Apply the gated MLP to ``x``; the output has the projections' dtype.

        The input is cast to the projection weights' dtype. The last dimension
        of ``x`` is expected to equal ``hidden_size``.
        """
        dtype = self.gate_proj.weight.dtype
        x = x.to(dtype=dtype)
        # These casts are no-ops when the dtype already matches. Presumably they
        # keep intermediates in the weight dtype if autocast changes op outputs.
        gate = F.gelu(self.gate_proj(x), approximate="tanh").to(dtype=dtype)
        up = self.up_proj(x).to(dtype=dtype)
        outputs = self.down_proj(gate * up).to(dtype=dtype)
        return self.dropout(outputs)