import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
import numpy as np
from einops import rearrange
from .utils import coords_grid, bilinear_sampler, upflow8
from .attention import BroadMultiHeadAttention, MultiHeadAttention, LinearPositionEmbeddingSine, \
ExpPositionEmbeddingSine
from typing import Optional, Tuple
from .twins import Size_, PosConv
from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_
class PatchEmbed(nn.Module):
def __init__(self, patch_size=16, in_chans=1, embed_dim=64, pe='linear'):
super().__init__()
self.patch_size = patch_size
self.dim = embed_dim
self.pe = pe
# assert patch_size == 8
if patch_size == 8:
self.proj = nn.Sequential(
nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
nn.ReLU(),
nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=6, stride=2, padding=2),
nn.ReLU(),
nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=6, stride=2, padding=2),
)
elif patch_size == 4:
self.proj = nn.Sequential(
nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
nn.ReLU(),
nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=6, stride=2, padding=2),
)
        else:
            raise ValueError(f"patch size = {patch_size} is unacceptable.")
self.ffn_with_coord = nn.Sequential(
nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1),
nn.ReLU(),
nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1)
)
self.norm = nn.LayerNorm(embed_dim * 2)
def forward(self, x) -> Tuple[torch.Tensor, Size_]:
B, C, H, W = x.shape # C == 1
pad_l = pad_t = 0
pad_r = (self.patch_size - W % self.patch_size) % self.patch_size
pad_b = (self.patch_size - H % self.patch_size) % self.patch_size
x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
x = self.proj(x)
out_size = x.shape[2:]
        patch_coord = coords_grid(B, out_size[0], out_size[1]).to(x.device)
        patch_coord = patch_coord * self.patch_size + self.patch_size / 2  # patch centers in the input cost-map coordinates
patch_coord = patch_coord.view(B, 2, -1).permute(0, 2, 1)
if self.pe == 'linear':
patch_coord_enc = LinearPositionEmbeddingSine(patch_coord, dim=self.dim)
        elif self.pe == 'exp':
            patch_coord_enc = ExpPositionEmbeddingSine(patch_coord, dim=self.dim)
        else:
            raise ValueError(f"positional encoding '{self.pe}' is unsupported.")
patch_coord_enc = patch_coord_enc.permute(0, 2, 1).view(B, -1, out_size[0], out_size[1])
x_pe = torch.cat([x, patch_coord_enc], dim=1)
x = self.ffn_with_coord(x_pe)
x = self.norm(x.flatten(2).transpose(1, 2))
return x, out_size
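# A minimal usage sketch for PatchEmbed (illustrative only; the tensor sizes below are
# assumptions, not values from any config). Each single-channel cost map is patchified by
# stride-8 convolutions, concatenated with a sinusoidal encoding of the patch-center
# coordinates, and returned as tokens of width 2 * embed_dim:
#
#   embed = PatchEmbed(patch_size=8, in_chans=1, embed_dim=64, pe='linear')
#   cost_maps = torch.randn(6, 1, 46, 62)       # [B*H1*W1, 1, H2, W2]
#   tokens, size = embed(cost_maps)             # size == (6, 8), tokens: [6, 48, 128]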
from .twins import Block, CrossBlock
class VerticalSelfAttentionLayer(nn.Module):
def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.):
super(VerticalSelfAttentionLayer, self).__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
embed_dim = dim
mlp_ratio = 4
ws = 7
sr_ratio = 4
dpr = 0.
drop_rate = dropout
attn_drop_rate = 0.
self.local_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True)
self.global_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=1, with_rpe=True)
def forward(self, x, size, context=None):
x = self.local_block(x, size, context)
x = self.global_block(x, size, context)
return x
def compute_params(self):
num = 0
for param in self.parameters():
num += np.prod(param.size())
return num
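# A shape sketch for VerticalSelfAttentionLayer (illustrative; the sizes below are
# assumptions). It chains a windowed Twins Block (ws=7) with a globally-attending Block
# (ws=1, sr_ratio=4) over tokens laid out on the H1 x W1 source grid, preserving shape:
#
#   vsa = VerticalSelfAttentionLayer(dim=128, num_heads=8)
#   x = torch.randn(2 * 8, 46 * 62, 128)        # [B*K, H1*W1, dim]
#   x = vsa(x, (46, 62))                        # [16, 2852, 128]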
class SelfAttentionLayer(nn.Module):
def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.):
super(SelfAttentionLayer, self).__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.multi_head_attn = MultiHeadAttention(dim, num_heads)
        self.q = nn.Linear(dim, dim, bias=True)
        self.k = nn.Linear(dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.ffn = nn.Sequential(
nn.Linear(dim, dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
"""
x: [BH1W1, H3W3, D]
"""
short_cut = x
x = self.norm1(x)
q, k, v = self.q(x), self.k(x), self.v(x)
x = self.multi_head_attn(q, k, v)
x = self.proj(x)
x = short_cut + self.proj_drop(x)
x = x + self.drop_path(self.ffn(self.norm2(x)))
return x
def compute_params(self):
num = 0
for param in self.parameters():
num += np.prod(param.size())
return num
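# A shape sketch for SelfAttentionLayer (illustrative; sizes are assumptions). It is a
# pre-norm multi-head self-attention block with a residual FFN, so the token layout is
# unchanged:
#
#   sa = SelfAttentionLayer(dim=128, num_heads=8)
#   tokens = torch.randn(6, 8, 128)             # [B*H1*W1, K, dim]
#   tokens = sa(tokens)                         # [6, 8, 128]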
class CrossAttentionLayer(nn.Module):
def __init__(self, qk_dim, v_dim, query_token_dim, tgt_token_dim, num_heads=8, attn_drop=0., proj_drop=0.,
drop_path=0., dropout=0.):
super(CrossAttentionLayer, self).__init__()
assert qk_dim % num_heads == 0, f"dim {qk_dim} should be divided by num_heads {num_heads}."
assert v_dim % num_heads == 0, f"dim {v_dim} should be divided by num_heads {num_heads}."
"""
Query Token: [N, C] -> [N, qk_dim] (Q)
Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V)
"""
self.num_heads = num_heads
head_dim = qk_dim // num_heads
self.scale = head_dim ** -0.5
self.norm1 = nn.LayerNorm(query_token_dim)
self.norm2 = nn.LayerNorm(query_token_dim)
self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads)
        self.q = nn.Linear(query_token_dim, qk_dim, bias=True)
        self.k = nn.Linear(tgt_token_dim, qk_dim, bias=True)
        self.v = nn.Linear(tgt_token_dim, v_dim, bias=True)
self.proj = nn.Linear(v_dim, query_token_dim)
self.proj_drop = nn.Dropout(proj_drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.ffn = nn.Sequential(
nn.Linear(query_token_dim, query_token_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(query_token_dim, query_token_dim),
nn.Dropout(dropout)
)
def forward(self, query, tgt_token):
"""
x: [BH1W1, H3W3, D]
"""
short_cut = query
query = self.norm1(query)
q, k, v = self.q(query), self.k(tgt_token), self.v(tgt_token)
x = self.multi_head_attn(q, k, v)
x = short_cut + self.proj_drop(self.proj(x))
x = x + self.drop_path(self.ffn(self.norm2(x)))
return x
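# A shape sketch for CrossAttentionLayer (illustrative; sizes are assumptions). The query
# tokens attend to the target tokens through BroadMultiHeadAttention, which, per this
# repository's attention module, is intended to broadcast a batch-1 query (the shared
# latent tokens) over every target batch entry:
#
#   ca = CrossAttentionLayer(qk_dim=128, v_dim=128, query_token_dim=128, tgt_token_dim=128)
#   query = torch.randn(1, 8, 128)              # [1, K, query_token_dim]
#   target = torch.randn(6, 48, 128)            # [B*H1*W1, H3*W3, tgt_token_dim]
#   out = ca(query, target)                     # [6, 8, 128]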
class CostPerceiverEncoder(nn.Module):
def __init__(self, patch_size, encoder_depth, cost_latent_token_num, cost_latent_dim, cost_latent_input_dim, pe,
dropout):
super(CostPerceiverEncoder, self).__init__()
self.cost_latent_token_num = cost_latent_token_num
self.patch_size = patch_size
self.patch_embed = PatchEmbed(in_chans=1, patch_size=8,
embed_dim=cost_latent_input_dim, pe=pe)
self.depth = encoder_depth
self.latent_tokens = nn.Parameter(torch.randn(1, cost_latent_token_num, cost_latent_dim))
query_token_dim, tgt_token_dim = cost_latent_dim, cost_latent_input_dim * 2
qk_dim, v_dim = query_token_dim, query_token_dim
self.input_layer = CrossAttentionLayer(qk_dim, v_dim, query_token_dim, tgt_token_dim, dropout=dropout)
self.encoder_layers = nn.ModuleList(
[SelfAttentionLayer(cost_latent_dim, dropout=dropout) for _ in range(self.depth)])
self.vertical_encoder_layers = nn.ModuleList(
[VerticalSelfAttentionLayer(cost_latent_dim, dropout=dropout) for _ in range(self.depth)])
def forward(self, cost_volume, context=None):
B, heads, H1, W1, H2, W2 = cost_volume.shape
        # each (H2, W2) cost map becomes one single-channel image; the view assumes heads == 1
        cost_maps = cost_volume.permute(0, 2, 3, 1, 4, 5).contiguous().view(B * H1 * W1, 1, H2, W2)
x, size = self.patch_embed(cost_maps) # B*H1*W1, size[0]*size[1], C
x = self.input_layer(self.latent_tokens, x)
short_cut = x
for idx, layer in enumerate(self.encoder_layers):
x = layer(x)
            # [B*H1*W1, K, D] -> [B*K, H1*W1, D]: attend across the H1 x W1 source grid
            x = x.view(B, H1 * W1, self.cost_latent_token_num, -1).permute(0, 2, 1, 3)
            x = x.reshape(B * self.cost_latent_token_num, H1 * W1, -1)
            x = self.vertical_encoder_layers[idx](x, (H1, W1), context)
            # [B*K, H1*W1, D] -> [B*H1*W1, K, D]: back to per-position latent tokens
            x = x.view(B, self.cost_latent_token_num, H1 * W1, -1).permute(0, 2, 1, 3)
            x = x.reshape(B * H1 * W1, self.cost_latent_token_num, -1)
x = x + short_cut
return x, size
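# A shape walk-through for CostPerceiverEncoder (illustrative; the sizes are assumptions,
# and the hard-coded in_chans=1 / view(..., 1, H2, W2) implies cost_heads_num == 1):
#
#   enc = CostPerceiverEncoder(patch_size=8, encoder_depth=3, cost_latent_token_num=8,
#                              cost_latent_dim=128, cost_latent_input_dim=64,
#                              pe='linear', dropout=0.)
#   cost_volume = torch.randn(1, 1, 46, 62, 46, 62)   # [B, heads, H1, W1, H2, W2]
#   latents, size = enc(cost_volume)                  # latents: [B*H1*W1, 8, 128], size == (6, 8)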
class MemoryEncoder(nn.Module):
def __init__(self, encoder_latent_dim, cost_heads_num, feat_cross_attn, patch_size, encoder_depth,
cost_latent_token_num, cost_latent_dim, cost_latent_input_dim, pe, dropout):
super(MemoryEncoder, self).__init__()
self.feat_cross_attn = feat_cross_attn
self.cost_heads_num = cost_heads_num
self.channel_convertor = nn.Conv2d(encoder_latent_dim, encoder_latent_dim, 1, padding=0, bias=False)
self.cost_perceiver_encoder = CostPerceiverEncoder(patch_size, encoder_depth, cost_latent_token_num,
cost_latent_dim, cost_latent_input_dim, pe, dropout)
def corr(self, fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = rearrange(fmap1, 'b (heads d) h w -> b heads (h w) d', heads=self.cost_heads_num)
fmap2 = rearrange(fmap2, 'b (heads d) h w -> b heads (h w) d', heads=self.cost_heads_num)
corr = einsum('bhid, bhjd -> bhij', fmap1, fmap2)
corr = corr.permute(0, 2, 1, 3).view(batch * ht * wd, self.cost_heads_num, ht, wd)
corr = corr.view(batch, ht * wd, self.cost_heads_num, ht * wd).permute(0, 2, 1, 3)
corr = corr.view(batch, self.cost_heads_num, ht, wd, ht, wd)
return corr
def forward(self, feat_s, feat_t, context=None):
cost_volume = self.corr(feat_s, feat_t)
x, size = self.cost_perceiver_encoder(cost_volume, context)
return x, cost_volume, size
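if __name__ == "__main__":
    # A minimal smoke test for MemoryEncoder (a sketch; the feature sizes and constructor
    # arguments below are assumptions, not values from any config file). Because this
    # module uses relative imports, run it as part of its package, e.g. `python -m <pkg>.<module>`.
    feat_s = torch.randn(1, 256, 46, 62)              # source features  [B, C, H1, W1]
    feat_t = torch.randn(1, 256, 46, 62)              # target features  [B, C, H2, W2]
    encoder = MemoryEncoder(encoder_latent_dim=256, cost_heads_num=1, feat_cross_attn=False,
                            patch_size=8, encoder_depth=3, cost_latent_token_num=8,
                            cost_latent_dim=128, cost_latent_input_dim=64, pe='linear',
                            dropout=0.)
    latents, cost_volume, size = encoder(feat_s, feat_t)
    # expected: latents [B*H1*W1, 8, 128], cost_volume [B, 1, H1, W1, H2, W2], size = patch grid of each cost map
    print(latents.shape, cost_volume.shape, size)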