import torch
import torch.nn as nn
from torch.nn import functional as F


class MLP(nn.Module):
    def __init__(self, config, dtype=None):
        # Gated MLP (GEGLU variant), see https://arxiv.org/pdf/2002.05202
        super().__init__()
        torch_dtype = getattr(torch, config.torch_dtype, torch.float32)  # Resolve dtype from config
        dtype = dtype if dtype is not None else torch_dtype              # Explicit argument wins over config
        self.hidden_size = config.n_embd
        self.intermediate_size = config.n_embd * config.mlp_scale
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias, dtype=dtype)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        dtype = self.gate_proj.weight.dtype      # Match the dtype of the projection layers
        x = x.to(dtype=dtype)                    # Convert the input to the same dtype
        gate = F.gelu(self.gate_proj(x), approximate="tanh")
        up = self.up_proj(x)
        fuse = gate * up                         # Element-wise gating
        outputs = self.down_proj(fuse)
        outputs = self.dropout(outputs)
        return outputs
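

# Minimal usage sketch (illustrative only): `SimpleNamespace` stands in for the real
# config object, and the field values below are assumptions, not defaults from the source.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        n_embd=768,            # hidden size (hypothetical)
        mlp_scale=4,           # intermediate_size = n_embd * mlp_scale
        bias=False,
        dropout=0.0,
        torch_dtype="float32",
    )
    mlp = MLP(config)
    x = torch.randn(2, 16, config.n_embd)  # (batch, seq_len, n_embd)
    y = mlp(x)
    print(y.shape)                         # torch.Size([2, 16, 768])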