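"""
Recurrent lambda layer (RLambdaLayer).

A lambda layer extended with a `recurrence` argument: the content lambda and the
position lambdas are computed once from the input's keys and values, then applied
`recurrence` times to the evolving output. For recurrence > 1, dim_out must equal
dim so the output can be fed back through the query projection.
"""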
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange


# helper functions

def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# lambda layer

class RLambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n=None,             # total number of positions (h * w); required for global positional embeddings
        r=None,             # local receptive field size; required for local positional lambdas
        heads=4,
        dim_out=None,
        dim_u=1,
        recurrence=None     # how many times the lambdas are re-applied (defaults to 1)
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u  # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)

        self.norm_q = nn.BatchNorm2d(dim_k * heads)
        self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = exists(r)
        self.recurrence = default(recurrence, 1)  # number of times the lambdas are applied in forward
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
        else:
            assert exists(n), 'You must specify the total sequence length (h x w)'
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))

    def apply_lambda(self, lambda_c, lambda_p, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        # queries are recomputed from the current input at every recurrence step
        q = self.to_q(x)
        q = self.norm_q(q)
        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)

        # content output: every position attends through the same content lambda
        Yc = einsum('b h k n, b k v -> b h v n', q, lambda_c)

        # position output: each position n gets its own positional lambda
        if self.local_contexts:
            Yp = einsum('b h k n, b k v n -> b h v n', q, lambda_p.flatten(3))
        else:
            Yp = einsum('b h k n, b n k v -> b h v n', q, lambda_p)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
        return out

    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        k = self.to_k(x)
        v = self.to_v(x)

        v = self.norm_v(v)

        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)

        # normalize keys over positions so the content lambda is a weighted average of the values
        k = k.softmax(dim=-1)

        # content lambda, shared by all query positions
        λc = einsum('b u k m, b u v m -> b k v', k, v)

        # position lambdas, one per query position: from a local 3d convolution
        # over the values, or from global positional embeddings
        if self.local_contexts:
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
            λp = self.pos_conv(v)
        else:
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)

        # the lambdas are computed once, then applied recurrently to the evolving output
        # (for recurrence > 1 this requires dim_out == dim, since the output is fed back through to_q)
        out = x
        for _ in range(self.recurrence):
            out = self.apply_lambda(λc, λp, out)
        return out
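

# example usage
# minimal illustrative sketch, not part of the original module; all hyperparameter
# values below are arbitrary example choices for the local-context variant

if __name__ == '__main__':
    layer = RLambdaLayer(
        dim=32,        # input channels
        dim_k=16,      # query/key dimension
        r=15,          # local receptive field; pass n=h*w instead for global positional embeddings
        heads=4,
        dim_out=32,    # must equal dim when recurrence > 1, since the output is fed back through to_q
        dim_u=1,
        recurrence=2
    )
    x = torch.randn(1, 32, 64, 64)
    out = layer(x)
    print(out.shape)  # expected: torch.Size([1, 32, 64, 64])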