# File size: 11,545 Bytes
# 8d015d4
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
import numpy as np

from einops import rearrange

from .utils import coords_grid, bilinear_sampler, upflow8
from .attention import BroadMultiHeadAttention, MultiHeadAttention, LinearPositionEmbeddingSine, \
    ExpPositionEmbeddingSine
from typing import Optional, Tuple
from .twins import Size_, PosConv

from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_


class PatchEmbed(nn.Module):
    """Turn an image-like map into patch tokens fused with a positional code.

    The input is zero-padded on the right/bottom to a multiple of
    ``patch_size``, reduced by a stack of stride-2 convolutions, concatenated
    channel-wise with a sine embedding of the patch-centre coordinates, mixed
    by a 1x1-conv feed-forward block and finally layer-normalised per token.

    Args:
        patch_size: Total stride of the convolutional stem. Only ``8`` (three
            stride-2 convs) and ``4`` (two stride-2 convs) have a stem.
        in_chans: Number of channels of the input map.
        embed_dim: Channels produced by the stem. The emitted tokens have
            ``2 * embed_dim`` channels (stem features + positional code).
        pe: Positional-embedding flavour, ``'linear'`` or ``'exp'``.

    Raises:
        ValueError: If ``patch_size`` or ``pe`` is not one of the supported
            values. Previously an unsupported ``patch_size`` was only printed
            and the failure surfaced later, in ``forward``, as an
            ``AttributeError`` on the missing ``proj`` module.
    """

    _SUPPORTED_PATCH_SIZES = (8, 4)
    _SUPPORTED_PE = ('linear', 'exp')

    def __init__(self, patch_size=16, in_chans=1, embed_dim=64, pe='linear'):
        super().__init__()
        # Fail fast: without a stem (or with an unknown embedding) the module
        # cannot run, so refuse to build a half-initialised instance.
        if patch_size not in self._SUPPORTED_PATCH_SIZES:
            raise ValueError(
                f"patch size = {patch_size} is unacceptable; "
                f"expected one of {self._SUPPORTED_PATCH_SIZES}.")
        if pe not in self._SUPPORTED_PE:
            raise ValueError(
                f"pe = {pe!r} is unacceptable; expected one of {self._SUPPORTED_PE}.")

        self.patch_size = patch_size
        self.dim = embed_dim
        self.pe = pe

        if patch_size == 8:
            self.proj = nn.Sequential(
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=6, stride=2, padding=2),
            )
        else:  # patch_size == 4
            self.proj = nn.Sequential(
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=6, stride=2, padding=2),
            )

        self.ffn_with_coord = nn.Sequential(
            nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1)
        )
        self.norm = nn.LayerNorm(embed_dim * 2)

    def forward(self, x) -> Tuple[torch.Tensor, Size_]:
        """Embed ``x`` of shape ``(B, C, H, W)``.

        Returns:
            A pair ``(tokens, out_size)`` where ``tokens`` has shape
            ``(B, out_h * out_w, 2 * embed_dim)`` and ``out_size`` is the
            spatial size ``(out_h, out_w)`` of the stem output.

        Raises:
            ValueError: If ``self.pe`` was changed to an unsupported value
                after construction.
        """
        B, _, H, W = x.shape

        # Pad only on the right/bottom so the spatial size is a multiple of
        # the patch size.
        pad_r = (self.patch_size - W % self.patch_size) % self.patch_size
        pad_b = (self.patch_size - H % self.patch_size) % self.patch_size
        x = F.pad(x, (0, pad_r, 0, pad_b))

        x = self.proj(x)
        out_size = x.shape[2:]

        # Patch-centre coordinates expressed in the (padded) input resolution.
        patch_coord = coords_grid(B, out_size[0], out_size[1]).to(
            x.device) * self.patch_size + self.patch_size / 2
        patch_coord = patch_coord.view(B, 2, -1).permute(0, 2, 1)
        if self.pe == 'linear':
            patch_coord_enc = LinearPositionEmbeddingSine(patch_coord, dim=self.dim)
        elif self.pe == 'exp':
            patch_coord_enc = ExpPositionEmbeddingSine(patch_coord, dim=self.dim)
        else:
            raise ValueError(
                f"pe = {self.pe!r} is unacceptable; expected one of {self._SUPPORTED_PE}.")
        patch_coord_enc = patch_coord_enc.permute(0, 2, 1).view(B, -1, out_size[0], out_size[1])

        # ffn_with_coord expects 2 * embed_dim channels: stem features plus
        # the positional code.
        x_pe = torch.cat([x, patch_coord_enc], dim=1)
        x = self.ffn_with_coord(x_pe)
        x = self.norm(x.flatten(2).transpose(1, 2))

        return x, out_size


from .twins import Block, CrossBlock


class VerticalSelfAttentionLayer(nn.Module):
    """Two ``Block`` modules applied back to back on the same token grid.

    The first block is built with ``ws=7`` and the second with ``ws=1``; all
    other hyper-parameters are shared (``mlp_ratio=4``, ``sr_ratio=4``,
    ``with_rpe=True``, no attention dropout, no drop-path).

    Args:
        dim: Token channel dimension.
        num_heads: Number of attention heads handed to both blocks.
        attn_drop, proj_drop, drop_path: Accepted for signature compatibility;
            not used by this layer.
        dropout: Forwarded to both blocks as their ``drop`` rate.
    """

    def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        # Settings common to both blocks; only the window size differs.
        shared_cfg = dict(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=4,
            drop=dropout,
            attn_drop=0.,
            drop_path=0.,
            sr_ratio=4,
            with_rpe=True,
        )
        self.local_block = Block(ws=7, **shared_cfg)
        self.global_block = Block(ws=1, **shared_cfg)

    def forward(self, x, size, context=None):
        """Run ``x`` through the ``ws=7`` block, then through the ``ws=1`` block."""
        for block in (self.local_block, self.global_block):
            x = block(x, size, context)
        return x

    def compute_params(self):
        """Return the total number of scalar entries over all parameters."""
        return sum(np.prod(p.size()) for p in self.parameters())


class SelfAttentionLayer(nn.Module):
    """Pre-norm self-attention followed by a residual feed-forward block.

    Args:
        dim: Token channel dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        attn_drop: Accepted for signature compatibility; not used here.
        proj_drop: Dropout probability applied after the output projection.
        drop_path: Drop-path rate on the feed-forward branch (identity when
            not positive).
        dropout: Dropout probability inside the feed-forward block.
    """

    def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.multi_head_attn = MultiHeadAttention(dim, num_heads)

        # Separate linear maps for queries, keys and values.
        self.q = nn.Linear(dim, dim, bias=True)
        self.k = nn.Linear(dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the layer to tokens ``x`` (documented upstream as ``[BH1W1, H3W3, D]``)."""
        residual = x
        normed = self.norm1(x)

        attended = self.multi_head_attn(self.q(normed), self.k(normed), self.v(normed))

        # Attention branch: projection + dropout, added back to the input.
        x = residual + self.proj_drop(self.proj(attended))

        # Feed-forward branch on the re-normalised tokens.
        return x + self.drop_path(self.ffn(self.norm2(x)))

    def compute_params(self):
        """Return the total number of scalar entries over all parameters."""
        return sum(np.prod(p.size()) for p in self.parameters())


class CrossAttentionLayer(nn.Module):
    """Cross-attention from query tokens to target tokens plus a residual FFN.

    Projections:
        Query token:    [N, C]  -> [N, qk_dim]  (Q)
        Target token:   [M, D]  -> [M, qk_dim]  (K),    [M, v_dim]  (V)

    Args:
        qk_dim: Channel size of queries and keys; divisible by ``num_heads``.
        v_dim: Channel size of values; divisible by ``num_heads``.
        query_token_dim: Channel size ``C`` of the query tokens (also the
            output channel size).
        tgt_token_dim: Channel size ``D`` of the target tokens.
        num_heads: Number of attention heads.
        attn_drop: Accepted for signature compatibility; not used here.
        proj_drop: Dropout probability applied after the output projection.
        drop_path: Drop-path rate on the feed-forward branch (identity when
            not positive).
        dropout: Dropout probability inside the feed-forward block.
    """

    def __init__(self, qk_dim, v_dim, query_token_dim, tgt_token_dim, num_heads=8, attn_drop=0., proj_drop=0.,
                 drop_path=0., dropout=0.):
        super().__init__()
        assert qk_dim % num_heads == 0, f"dim {qk_dim} should be divided by num_heads {num_heads}."
        assert v_dim % num_heads == 0, f"dim {v_dim} should be divided by num_heads {num_heads}."

        self.num_heads = num_heads
        self.scale = (qk_dim // num_heads) ** -0.5

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads)

        # Queries come from the query tokens; keys and values from the targets.
        self.q = nn.Linear(query_token_dim, qk_dim, bias=True)
        self.k = nn.Linear(tgt_token_dim, qk_dim, bias=True)
        self.v = nn.Linear(tgt_token_dim, v_dim, bias=True)

        self.proj = nn.Linear(v_dim, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout)
        )

    def forward(self, query, tgt_token):
        """Attend from ``query`` to ``tgt_token``; the result has the query's channel size.

        Only the query is layer-normalised before projection; the target
        tokens are projected as given.
        """
        residual = query
        normed_query = self.norm1(query)

        q = self.q(normed_query)
        k = self.k(tgt_token)
        v = self.v(tgt_token)
        attended = self.multi_head_attn(q, k, v)

        x = residual + self.proj_drop(self.proj(attended))

        return x + self.drop_path(self.ffn(self.norm2(x)))


class CostPerceiverEncoder(nn.Module):
    """Compress each per-pixel cost map into a small set of latent tokens.

    Every source pixel's cost map is patch-embedded, summarised into
    ``cost_latent_token_num`` latent tokens through one cross-attention layer,
    and then refined by ``encoder_depth`` rounds that alternate attention
    across latent tokens (per pixel) with attention across source pixels
    (per latent token).

    Args:
        patch_size: Stored on the instance as ``self.patch_size``.
        encoder_depth: Number of alternating refinement rounds.
        cost_latent_token_num: Number of latent tokens per source pixel.
        cost_latent_dim: Channel size of the latent tokens.
        cost_latent_input_dim: ``embed_dim`` of the patch embedding; the
            embedded tokens carry twice this many channels.
        pe: Positional-embedding flavour forwarded to ``PatchEmbed``.
        dropout: Dropout rate forwarded to every attention layer.
    """

    def __init__(self, patch_size, encoder_depth, cost_latent_token_num, cost_latent_dim, cost_latent_input_dim, pe,
                 dropout):
        super().__init__()
        self.cost_latent_token_num = cost_latent_token_num
        self.patch_size = patch_size
        # NOTE(review): the embedding is built with a fixed patch size of 8,
        # not with the ``patch_size`` argument -- confirm this is intended.
        self.patch_embed = PatchEmbed(in_chans=1, patch_size=8,
                                      embed_dim=cost_latent_input_dim, pe=pe)

        self.depth = encoder_depth

        self.latent_tokens = nn.Parameter(torch.randn(1, cost_latent_token_num, cost_latent_dim))

        # Latent tokens are the queries; embedded cost patches (2x the input
        # dim because of the positional code) are the targets.
        self.input_layer = CrossAttentionLayer(
            cost_latent_dim,
            cost_latent_dim,
            cost_latent_dim,
            cost_latent_input_dim * 2,
            dropout=dropout,
        )

        self.encoder_layers = nn.ModuleList(
            [SelfAttentionLayer(cost_latent_dim, dropout=dropout) for _ in range(self.depth)])

        self.vertical_encoder_layers = nn.ModuleList(
            [VerticalSelfAttentionLayer(cost_latent_dim, dropout=dropout) for _ in range(self.depth)])

    def forward(self, cost_volume, context=None):
        """Encode a cost volume of shape ``(B, heads, H1, W1, H2, W2)``.

        The reshape to single-channel cost maps only succeeds when
        ``heads == 1``.

        Returns:
            ``(x, size)`` where ``x`` has shape
            ``(B * H1 * W1, cost_latent_token_num, C)`` and ``size`` is the
            spatial size reported by the patch embedding.
        """
        B, heads, H1, W1, H2, W2 = cost_volume.shape
        num_pixels = H1 * W1
        num_tokens = self.cost_latent_token_num

        # One single-channel H2 x W2 cost map per source pixel.
        cost_maps = cost_volume.permute(0, 2, 3, 1, 4, 5).contiguous()
        cost_maps = cost_maps.view(B * num_pixels, 1, H2, W2)

        patch_tokens, size = self.patch_embed(cost_maps)  # B*H1*W1, size[0]*size[1], C

        x = self.input_layer(self.latent_tokens, patch_tokens)
        short_cut = x

        for token_layer, pixel_layer in zip(self.encoder_layers, self.vertical_encoder_layers):
            # Attention among the latent tokens of each source pixel.
            x = token_layer(x)

            # Regroup so that each latent-token index owns an H1 x W1 grid.
            x = x.view(B, num_pixels, num_tokens, -1).permute(0, 2, 1, 3)
            x = x.reshape(B * num_tokens, num_pixels, -1)
            x = pixel_layer(x, (H1, W1), context)

            # Back to the per-pixel layout.
            x = x.view(B, num_tokens, num_pixels, -1).permute(0, 2, 1, 3)
            x = x.reshape(B * num_pixels, num_tokens, -1)

        x = x + short_cut
        return x, size


class MemoryEncoder(nn.Module):
    """Build a multi-head cost volume from two feature maps and encode it.

    Args:
        encoder_latent_dim: Channel count of the 1x1 ``channel_convertor``.
        cost_heads_num: Number of heads the feature channels are split into
            when computing correlations.
        feat_cross_attn: Stored on the instance; not read by any method of
            this class.
        patch_size, encoder_depth, cost_latent_token_num, cost_latent_dim,
        cost_latent_input_dim, pe, dropout: Forwarded unchanged to
            ``CostPerceiverEncoder``.
    """

    def __init__(self, encoder_latent_dim, cost_heads_num, feat_cross_attn, patch_size, encoder_depth,
                 cost_latent_token_num, cost_latent_dim, cost_latent_input_dim, pe, dropout):
        super(MemoryEncoder, self).__init__()
        self.feat_cross_attn = feat_cross_attn
        self.cost_heads_num = cost_heads_num
        # NOTE(review): not used by any method visible here; kept so the
        # module's parameter set (and therefore its state dict) is unchanged.
        self.channel_convertor = nn.Conv2d(encoder_latent_dim, encoder_latent_dim, 1, padding=0, bias=False)
        self.cost_perceiver_encoder = CostPerceiverEncoder(patch_size, encoder_depth, cost_latent_token_num,
                                                           cost_latent_dim, cost_latent_input_dim, pe, dropout)

    def corr(self, fmap1, fmap2):
        """Return all-pairs dot-product correlations between two feature maps.

        Args:
            fmap1: Source features, shape ``(B, C, H1, W1)``.
            fmap2: Target features, shape ``(B, C, H2, W2)``.

        Returns:
            Tensor of shape ``(B, cost_heads_num, H1, W1, H2, W2)`` where entry
            ``[b, h, i, j, k, l]`` is the dot product, over the channels of
            head ``h``, of ``fmap1[b, :, i, j]`` and ``fmap2[b, :, k, l]``.

        Raises:
            RuntimeError: If ``C`` is not divisible by ``cost_heads_num`` or
                the two maps disagree in batch/channel size.
        """
        batch, dim, ht1, wd1 = fmap1.shape
        ht2, wd2 = fmap2.shape[-2:]
        heads = self.cost_heads_num

        # 'b (heads d) h w -> b heads (h w) d'
        fmap1 = fmap1.reshape(batch, heads, dim // heads, ht1 * wd1).transpose(2, 3)
        fmap2 = fmap2.reshape(batch, heads, dim // heads, ht2 * wd2).transpose(2, 3)
        corr = einsum('bhid, bhjd -> bhij', fmap1, fmap2)

        # The earlier permute -> view -> view -> permute round trip was an
        # identity on the values, but its first ``view`` ran on a
        # non-contiguous tensor and raised for batch > 1 with more than one
        # head. A single reshape yields the same result for every shape.
        return corr.reshape(batch, heads, ht1, wd1, ht2, wd2)

    def forward(self, feat_s, feat_t, context=None):
        """Correlate ``feat_s`` with ``feat_t`` and encode the cost volume.

        Returns:
            ``(x, cost_volume, size)``: the encoded latent tokens, the raw
            cost volume from :meth:`corr`, and the patch-grid size reported by
            the cost encoder.
        """
        cost_volume = self.corr(feat_s, feat_t)
        x, size = self.cost_perceiver_encoder(cost_volume, context)

        return x, cost_volume, size