import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cast_tuple(val, depth=1):
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def stable_softmax(t, dim=-1, alpha=32**2):
    # scale logits down before subtracting the max, then scale back up just before
    # the softmax, keeping the intermediate steps in a smaller numeric range
    # (mathematically equivalent to a plain softmax)
    t = t / alpha
    t = t - torch.amax(t, dim=dim, keepdim=True).detach()
    return (t * alpha).softmax(dim=dim)


def route_args(router, args, depth):
    # split the incoming kwargs into per-layer (attention, feed-forward) kwarg dicts,
    # following the (f, g) routes registered for each matching key
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for ind, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[ind] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args
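
# Illustrative example (not from the original source): with depth=2 and a router of
# {"mask": ((True, False), (True, False))}, calling route_args(router, {"mask": m}, 2)
# returns [({"mask": m}, {}), ({"mask": m}, {})] -- the attention branch (f) of every
# layer receives the mask, while the feed-forward branch (g) gets no extra kwargs.
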
class SequentialSequence(nn.Module):
    def __init__(self, layers, args_route={}, layer_dropout=0.0):
        super().__init__()
        assert all(
            len(route) == len(layers) for route in args_route.values()
        ), "each argument route map must have the same depth as the number of sequential layers"
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        # each entry is an (attention, feed-forward) pair, applied with residual connections
        for (f, g), (f_args, g_args) in layers_and_args:
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x


class DivideMax(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        maxes = x.amax(dim=self.dim, keepdim=True).detach()
        return x / maxes


class LayerScale(nn.Module):
    # scales each residual branch by a small learned per-channel factor, with a
    # depth-dependent initial value (LayerScale, as introduced in the CaiT paper)
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


class PreNorm(nn.Module):
    def __init__(self, dim, fn, sandwich=False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim, dropout=0.0, mult=4.0):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim * 2),  # doubled width: GEGLU splits it into value and gate
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
        )

    def forward(self, x):
        return self.net(x)
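
# Illustrative note (dims assumed, not from the original source): with dim=512 and
# mult=4, the first projection maps 512 -> 4096, GEGLU halves that to 2048 via
# value * gelu(gate), and the final projection maps 2048 back to 512.
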
class Attention(nn.Module):
    def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head**-0.5

        self.causal = causal

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, mask=None):
        b, n, _, h, device = *x.shape, self.heads, x.device

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

        q = q * self.scale

        dots = torch.einsum("b h i d, b h j d -> b h i j", q, k)
        mask_value = max_neg_value(dots)

        # optional padding mask: block attention to masked-out key positions
        if exists(mask):
            mask = rearrange(mask, "b j -> b () () j")
            dots.masked_fill_(~mask, mask_value)
            del mask

        # causal mask: each position attends only to itself and earlier positions
        if self.causal:
            i, j = dots.shape[-2:]
            mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
            dots.masked_fill_(mask, mask_value)

        attn = dots.softmax(dim=-1)

        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)
        return out
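
# Illustrative usage sketch (hyperparameters below are assumptions, not values from
# the original source):
#
#   attn = Attention(dim=512, seq_len=256, heads=8, dim_head=64)
#   x = torch.randn(1, 256, 512)
#   mask = torch.ones(1, 256).bool()
#   out = attn(x, mask=mask)  # -> shape (1, 256, 512), same as the input
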
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal=True,
        heads=8,
        dim_head=64,
        ff_mult=4,
        attn_dropout=0.0,
        ff_dropout=0.0,
        sparse_attn=False,
        sandwich_norm=False,
    ):
        super().__init__()
        layers = nn.ModuleList([])
        # sparse_attn is accepted per layer, but this version always builds dense Attention
        sparse_layer = cast_tuple(sparse_attn, depth)

        for ind, _sparse_attn in zip(range(depth), sparse_layer):
            attn = Attention(
                dim,
                causal=causal,
                seq_len=seq_len,
                heads=heads,
                dim_head=dim_head,
                dropout=attn_dropout,
            )

            ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)

            layers.append(
                nn.ModuleList(
                    [
                        LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)),
                        LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)),
                    ]
                )
            )

        execute_type = SequentialSequence
        # route the `mask` keyword argument to the attention branch of every layer only
        route_attn = ((True, False),) * depth
        attn_route_map = {"mask": route_attn}

        self.layers = execute_type(layers, args_route=attn_route_map)

    def forward(self, x, **kwargs):
        return self.layers(x, **kwargs)
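

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the hyperparameters below are assumptions,
    # not values taken from the original source).
    model = Transformer(dim=512, depth=6, seq_len=256, heads=8, dim_head=64)
    tokens = torch.randn(2, 256, 512)
    mask = torch.ones(2, 256).bool()
    out = model(tokens, mask=mask)
    print(out.shape)  # expected: torch.Size([2, 256, 512])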