import torch
import torch.nn as nn
from torch.nn import functional as F

class MLP(nn.Module):
    def __init__(self, config, dtype=None):
        # GLU-variant feed-forward block (GeGLU), see https://arxiv.org/pdf/2002.05202
        super().__init__()
        # Parameter dtype precedence: explicit dtype argument, then config.torch_dtype, then float32
        torch_dtype = getattr(torch, config.torch_dtype, torch.float32)
        dtype = dtype if dtype is not None else torch_dtype
        self.hidden_size = config.n_embd
        self.intermediate_size = config.n_embd * config.mlp_scale
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias, dtype=dtype)
        self.dropout = nn.Dropout(config.dropout)
    
    def forward(self, x):
        # Cast the input once to match the dtype of the projection weights
        x = x.to(self.gate_proj.weight.dtype)
        gate = F.gelu(self.gate_proj(x), approximate="tanh")  # gating branch
        up = self.up_proj(x)
        fuse = gate * up  # GeGLU: elementwise gate of the up projection
        outputs = self.down_proj(fuse)
        outputs = self.dropout(outputs)
        return outputs
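

# Minimal usage sketch. Assumption: the project's real config class is not shown here,
# so the MLPConfig dataclass below is hypothetical and only mirrors the fields this
# module reads (n_embd, mlp_scale, bias, dropout, torch_dtype).
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class MLPConfig:
        n_embd: int = 768
        mlp_scale: int = 4
        bias: bool = False
        dropout: float = 0.0
        torch_dtype: str = "bfloat16"

    config = MLPConfig()
    mlp = MLP(config)
    x = torch.randn(2, 16, config.n_embd)  # (batch, seq_len, n_embd), float32 input
    y = mlp(x)                             # cast internally to the module dtype
    print(y.shape, y.dtype)                # torch.Size([2, 16, 768]) torch.bfloat16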