""" MLP module w/ dropout and configurable activation layer |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
from functools import partial |
|
|
|
from torch import nn as nn |
|
|
|
from .grn import GlobalResponseNorm |
|
from .helpers import to_2tuple |
|
|
|
|
|
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
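
# Example usage (an illustrative sketch, not part of the original module; it assumes
# `import torch` and arbitrary ViT-like shapes):
#
#   mlp = Mlp(in_features=192, hidden_features=768, drop=0.1)
#   tokens = torch.randn(2, 197, 192)      # (batch, tokens, channels)
#   out = mlp(tokens)                      # -> (2, 197, 192)
#
# With use_conv=True the two projections become 1x1 nn.Conv2d layers and the module
# expects NCHW feature maps instead of token sequences.
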
class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.Sigmoid,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            gate_last=True,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert hidden_features % 2 == 0
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        self.chunk_dim = 1 if use_conv else -1
        self.gate_last = gate_last  # if True, the activated (gate) branch is the last chunk

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # custom init for the second half of fc1: ones bias, near-zero weight
        fc1_mid = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[fc1_mid:])
        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=self.chunk_dim)
        x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
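
# Example usage (illustrative, with assumed sizes; assumes `import torch`): fc1 expands
# to hidden_features, the result is chunked in half, and one half gates the other, so
# the effective hidden width is hidden_features // 2.
#
#   glu = GluMlp(in_features=256, hidden_features=1024)     # default: x1 * sigmoid(x2)
#   out = glu(torch.randn(8, 64, 256))                      # -> (8, 64, 256)
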
SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)
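
# Illustrative note (sizes are assumptions; assumes `import torch`): SwiGLUPacked keeps
# both SwiGLU branches packed in a single fc1, so hidden_features counts both halves --
# pass twice the desired MLP width.
#
#   ff = SwiGLUPacked(in_features=256, hidden_features=2 * 512)
#   out = ff(torch.randn(4, 16, 256))    # SiLU(first half) * second half -> (4, 16, 256)
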
class SwiGLU(nn.Module):
    """ SwiGLU
    NOTE: GluMlp above can implement SwiGLU, but this impl has a split fc1 and
    better matches other common implementations, which makes mapping checkpoints simpler.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.SiLU,
            norm_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # custom init for the gate branch: ones bias, near-zero weight
        nn.init.ones_(self.fc1_g.bias)
        nn.init.normal_(self.fc1_g.weight, std=1e-6)

    def forward(self, x):
        x_gate = self.fc1_g(x)
        x = self.fc1_x(x)
        x = self.act(x_gate) * x
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
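
# Example usage (illustrative sketch with assumed sizes; assumes `import torch`): the
# gate and value branches are separate nn.Linear layers here, so hidden_features is the
# width of each branch rather than their packed sum as in SwiGLUPacked above.
#
#   ff = SwiGLU(in_features=256, hidden_features=512)
#   out = ff(torch.randn(4, 16, 256))    # SiLU(fc1_g(x)) * fc1_x(x) -> (4, 16, 256)
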
class GatedMlp(nn.Module):
    """ MLP as used in gMLP
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            gate_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            hidden_features = hidden_features // 2  # assumes the gate halves the channel dim
        else:
            self.gate = nn.Identity()
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.gate(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
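
# Example usage (illustrative; the SpatialGatingUnit import below is an assumption,
# referring to timm's mlp_mixer implementation, and sizes are arbitrary): the gate
# layer receives hidden_features and is expected to halve the channel dim, which is
# why fc2 consumes hidden_features // 2. Without gate_layer it behaves like Mlp above.
#
#   from timm.models.mlp_mixer import SpatialGatingUnit   # assumed location
#   gate = partial(SpatialGatingUnit, seq_len=196)
#   gmlp = GatedMlp(in_features=256, hidden_features=1536, gate_layer=gate)
#   out = gmlp(torch.randn(2, 196, 256))                   # -> (2, 196, 256)
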
class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.ReLU,
            norm_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
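
# Example usage (illustrative sketch, assumed sizes; assumes `import torch`): operates
# directly on NCHW feature maps, expanding and reducing only the channel dim while
# H and W stay unchanged.
#
#   cmlp = ConvMlp(in_features=64, hidden_features=256, norm_layer=nn.BatchNorm2d)
#   out = cmlp(torch.randn(2, 64, 56, 56))   # -> (2, 64, 56, 56)
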
class GlobalResponseNormMlp(nn.Module):
    """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv)
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.grn(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
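
# Example usage (illustrative, assumed sizes; assumes `import torch`): with use_conv=False
# the GRN is built with channels_last=True, so inputs are expected in NHWC layout; with
# use_conv=True it runs on NCHW maps via 1x1 convs.
#
#   grn_mlp = GlobalResponseNormMlp(in_features=96, hidden_features=384)
#   out = grn_mlp(torch.randn(2, 56, 56, 96))   # NHWC in, NHWC out -> (2, 56, 56, 96)
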