import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# lambda layer

class RLambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n=None,
        r=None,
        heads=4,
        dim_out=None,
        dim_u=1,
        recurrence=None
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u  # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)

        self.norm_q = nn.BatchNorm2d(dim_k * heads)
        self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = exists(r)
        self.recurrence = default(recurrence, 1)  # how many times the lambdas are applied; defaults to a single application

        if exists(r):
            # local context: position lambdas come from a 3d convolution over a local neighborhood
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
        else:
            # global context: learned relative position embeddings over the full sequence
            assert exists(n), 'You must specify the total sequence length (h x w)'
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))

    def apply_lambda(self, lambda_c, lambda_p, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        q = self.to_q(x)
        q = self.norm_q(q)
        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)

        # content and position contributions of the lambda function
        Yc = einsum('b h k n, b k v -> b h v n', q, lambda_c)

        if self.local_contexts:
            Yp = einsum('b h k n, b k v n -> b h v n', q, lambda_p.flatten(3))
        else:
            Yp = einsum('b h k n, b n k v -> b h v n', q, lambda_p)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
        return out

    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        k = self.to_k(x)
        v = self.to_v(x)
        v = self.norm_v(v)

        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)

        k = k.softmax(dim=-1)

        # content lambda, shared across all query positions
        λc = einsum('b u k m, b u v m -> b k v', k, v)

        # position lambdas, one per query position
        if self.local_contexts:
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
            λp = self.pos_conv(v)
        else:
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)

        # apply the same lambdas recurrently to the evolving output
        out = x
        for _ in range(self.recurrence):
            out = self.apply_lambda(λc, λp, out)
        return out
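
# Usage sketch (added for illustration, not part of the original module). The
# dimensions below are assumptions; they are chosen so that dim_out == dim,
# which recurrence > 1 requires since the output is fed back through to_q.
if __name__ == '__main__':
    layer = RLambdaLayer(
        dim=32,          # input channels
        dim_out=32,      # output channels (must equal dim when recurrence > 1)
        dim_k=16,        # key / query dimension
        heads=4,         # dim_out must be divisible by heads
        dim_u=1,         # intra-depth dimension
        n=64,            # total sequence length for an 8 x 8 feature map (global context)
        recurrence=2     # apply the computed lambdas twice
    )
    x = torch.randn(1, 32, 8, 8)
    out = layer(x)       # -> torch.Size([1, 32, 8, 8])
    print(out.shape)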