from collections import namedtuple
from functools import wraps

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from packaging import version
from torch import einsum, nn


def exists(val):
    return val is not None


def once(fn):
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner


print_once = once(print)


class Attend(nn.Module):
    def __init__(self, dropout=0.0, causal=False, use_flash=False):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.use_flash = use_flash
        assert not (
            use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
        ), "in order to use flash attention, you must be using pytorch 2.0 or above"

        # determine efficient attention configs for cuda and cpu
        self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
        self.cpu_config = self.config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
            self.cuda_config = self.config(True, False, False)
        else:
            print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda")
            self.cuda_config = self.config(False, True, True)

    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask=None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # expand single-headed keys / values so they match the query head dimension
        if k.ndim == 3:
            k = rearrange(k, "b ... -> b 1 ...").expand_as(q)

        if v.ndim == 3:
            v = rearrange(v, "b ... -> b 1 ...").expand_as(q)

        # expand the key padding mask to (batch, heads, query length, key length)
        if exists(mask):
            mask = rearrange(mask, "b j -> b 1 1 j")
            mask = mask.expand(-1, heads, q_len, -1)

        # pick the sdp backend config depending on whether the tensors are on cuda
        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 scaled_dot_product_attention handles scaling and dropout internally
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal
            )

        return out

    def forward(self, q, k, v, mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device = q.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        if self.use_flash:
            return self.flash_attn(q, k, v, mask=mask)

        kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, "b j -> b 1 1 j")
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out
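

# A minimal usage sketch for Attend (illustrative only; this helper is not part of the
# original API). Queries are shaped (batch, heads, seq, dim_head); keys / values may either
# carry their own head dimension ("b h j d") or be shared across heads ("b j d").
def _attend_shape_example():
    attend = Attend(causal=True)
    q = torch.randn(2, 8, 16, 64)  # b h i d
    k = torch.randn(2, 16, 64)     # b j d, shared across all 8 heads
    v = torch.randn(2, 16, 64)
    return attend(q, k, v)         # -> (2, 8, 16, 64)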


def Sequential(*mods):
    return nn.Sequential(*filter(exists, mods))


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


class RMSNorm(nn.Module):
    def __init__(self, dim, scale=True, dim_cond=None):
        super().__init__()
        self.cond = exists(dim_cond)
        self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None

        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim)) if scale else None

    def forward(self, x, cond=None):
        gamma = default(self.gamma, 1)
        out = F.normalize(x, dim=-1) * self.scale * gamma

        if not self.cond:
            return out

        assert exists(cond)
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
        gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
        return out * gamma + beta


class CausalConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        (kernel_size,) = self.kernel_size
        (dilation,) = self.dilation
        (stride,) = self.stride

        assert stride == 1
        self.causal_padding = dilation * (kernel_size - 1)

    def forward(self, x):
        # left-pad so each output position only depends on current and past inputs
        causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
        return super().forward(causal_padded_x)


class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.gelu(gate) * x


def FeedForward(dim, mult=4, causal_conv=False):
    # GEGLU halves the hidden dimension, so widen by 2/3 to keep the parameter count comparable
    dim_inner = int(dim * mult * 2 / 3)

    conv = None
    if causal_conv:
        conv = nn.Sequential(
            Rearrange("b n d -> b d n"),
            CausalConv1d(dim_inner, dim_inner, 3),
            Rearrange("b d n -> b n d"),
        )

    return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim))


class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=2,
        dim_context=None,
        num_latents=32,
        dim_head=64,
        heads=8,
        ff_mult=4,
        use_flash_attn=False,
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()

        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        nn.init.normal_(self.latents, std=0.02)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(
                            dim=dim,
                            dim_head=dim_head,
                            heads=heads,
                            use_flash=use_flash_attn,
                            cross_attn_include_queries=True,
                        ),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = RMSNorm(dim)

    def forward(self, x, mask=None):
        batch = x.shape[0]

        x = self.proj_context(x)

        latents = repeat(self.latents, "n d -> b n d", b=batch)

        for attn, ff in self.layers:
            latents = attn(latents, x, mask=mask) + latents
            latents = ff(latents) + latents

        return self.norm(latents)


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_context=None,
        causal=False,
        dim_head=64,
        heads=8,
        dropout=0.0,
        use_flash=False,
        cross_attn_include_queries=False,
    ):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.cross_attn_include_queries = cross_attn_include_queries

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
        self.to_q = nn.Linear(dim, dim_inner, bias=False)
        self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
        self.to_out = nn.Linear(dim_inner, dim, bias=False)

    def forward(self, x, context=None, mask=None):
        h, has_context = self.heads, exists(context)

        context = default(context, x)

        if has_context and self.cross_attn_include_queries:
            # prepend the queries to the context so the latents can also attend to themselves
            context = torch.cat((x, context), dim=-2)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = self.attend(q, k, v, mask=mask)

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
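

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original module): resample a
    # random 80-dim feature sequence down to 32 latent vectors and check the output shape.
    resampler = PerceiverResampler(dim=256, dim_context=80, depth=2, num_latents=32)
    features = torch.randn(2, 100, 80)  # (batch, sequence, feature dim)
    latents = resampler(features)
    print(latents.shape)  # expected: torch.Size([2, 32, 256])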