yagizdevre's picture
fix
ffefe3d
raw
history blame contribute delete
893 Bytes
import torch.nn as nn
from torch.nn import functional as F
import torch
class MLP(nn.Module):
    """Gated feed-forward block with a tanh-approximated GELU gate.

    Computes ``dropout(down_proj(gelu(gate_proj(x)) * up_proj(x)))``.
    Reference for the gated formulation: https://arxiv.org/pdf/2002.05202

    The ``config`` object is read for the following attributes:
        dim: input/output feature size of the block.
        mlp_scale: multiplier applied to ``dim`` to get the inner width.
        bias: whether the three linear layers carry a bias term.
        dropout: dropout probability applied to the block output.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.hidden_size = config.dim
        # ``mlp_scale`` may be given as a float (e.g. 2.0); ``nn.Linear``
        # only accepts integer sizes, so truncate the product to an int.
        # For integer scales this is identical to the plain product.
        self.intermediate_size = int(config.dim * config.mlp_scale)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x``.

        Args:
            x: input tensor whose last dimension must equal ``hidden_size``
                (required by the linear projections).

        Returns:
            Tensor with the same last dimension as ``x``, after the down
            projection and dropout.
        """
        # Gate branch: linear projection followed by tanh-approximated GELU.
        gate = F.gelu(self.gate_proj(x), approximate="tanh")
        # Value branch: plain linear projection, modulated elementwise by the gate.
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        return self.dropout(outputs)