File size: 8,699 Bytes
8d015d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d
def single_head_full_attention(q, k, v):
    """Global single-head scaled dot-product attention.

    Every query position attends to every key position of the same batch
    element; the logits are divided by ``sqrt(C)`` with ``C = q.size(2)``.

    Args:
        q: query tensor, [B, L, C].
        k: key tensor, [B, L, C].
        v: value tensor, [B, L, C].

    Returns:
        Attention-weighted values, [B, L, C].

    Raises:
        AssertionError: if any of ``q``, ``k``, ``v`` is not 3-dimensional.
    """
    assert q.dim() == k.dim() == v.dim() == 3

    channels = q.size(2)
    keys_t = k.permute(0, 2, 1)  # [B, C, L]

    logits = torch.matmul(q, keys_t) / (channels ** .5)  # [B, L, L]
    weights = torch.softmax(logits, dim=2)  # normalised over the key axis

    return torch.matmul(weights, v)  # [B, L, C]
def single_head_full_attention_1d(q, k, v,
                                  h=None,
                                  w=None,
                                  ):
    """Single-head attention along the width axis only.

    The flattened ``L = h * w`` tokens are viewed as an ``h x w`` grid and
    each token attends to the ``w`` tokens that share its row.

    Args:
        q: query tensor, [B, L, C].
        k: key tensor, [B, L, C].
        v: value tensor, [B, L, C].
        h: grid height (required).
        w: grid width (required).

    Returns:
        Attention-weighted values, [B, H*W, C].

    Raises:
        AssertionError: if ``h`` or ``w`` is missing, or ``q.size(1) != h * w``.
    """
    assert h is not None and w is not None
    assert q.size(1) == h * w

    batch, _, channels = q.size()

    # Expose the row axis so that matmul batches over (B, H).
    q_rows = q.view(batch, h, w, channels)  # [B, H, W, C]
    k_rows = k.view(batch, h, w, channels)
    v_rows = v.view(batch, h, w, channels)

    logits = torch.matmul(q_rows, k_rows.permute(0, 1, 3, 2)) / (channels ** 0.5)  # [B, H, W, W]
    weights = torch.softmax(logits, dim=-1)

    attended = torch.matmul(weights, v_rows)  # [B, H, W, C]
    return attended.view(batch, -1, channels)  # [B, H*W, C]
def single_head_split_window_attention(q, k, v,
                                       num_splits=1,
                                       with_shift=False,
                                       h=None,
                                       w=None,
                                       attn_mask=None,
                                       ):
    """Single-head attention restricted to 2-D local windows.

    The ``h x w`` token grid is partitioned by ``split_feature`` into
    ``num_splits * num_splits`` windows per batch element and attention is
    computed independently inside each window. With ``with_shift`` the grid
    is rolled by half a window before partitioning, ``attn_mask`` is added to
    the logits, and the roll is undone on the result.

    ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

    Args:
        q: query tensor, [B, L, C] with ``L == h * w``.
        k: key tensor, [B, L, C].
        v: value tensor, [B, L, C].
        num_splits: number of windows along each spatial axis (K).
        with_shift: whether to use shifted windows.
        h: grid height (required).
        w: grid width (required).
        attn_mask: additive mask for the shifted case, required when
            ``with_shift`` is true. It is tiled ``B`` times along dim 0, so
            it is presumably [K*K, H/K*W/K, H/K*W/K] -- confirm with caller.

    Returns:
        Attention-weighted values, [B, H*W, C].

    Raises:
        AssertionError: on non-3-D inputs, missing ``h``/``w``, a token count
            different from ``h * w``, or a missing mask when shifting.
    """
    assert q.dim() == k.dim() == v.dim() == 3
    assert h is not None and w is not None
    assert q.size(1) == h * w

    batch, _, channels = q.size()
    total_windows = batch * num_splits * num_splits
    win_h = h // num_splits
    win_w = w // num_splits

    # Back to a spatial layout so that rolling / splitting act on H and W.
    q, k, v = [t.view(batch, h, w, channels) for t in (q, k, v)]  # [B, H, W, C]

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_h = win_h // 2
        shift_w = win_w // 2
        q, k, v = [torch.roll(t, shifts=(-shift_h, -shift_w), dims=(1, 2))
                   for t in (q, k, v)]

    # [B*K*K, H/K, W/K, C]
    q, k, v = [split_feature(t, num_splits=num_splits, channel_last=True)
               for t in (q, k, v)]

    q_tokens = q.view(total_windows, -1, channels)
    k_tokens = k.view(total_windows, -1, channels)
    logits = torch.matmul(q_tokens, k_tokens.permute(0, 2, 1)) / (channels ** 0.5)  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        # The mask is shared across the batch, hence the tiling.
        logits += attn_mask.repeat(batch, 1, 1)

    weights = torch.softmax(logits, dim=-1)
    attended = torch.matmul(weights, v.view(total_windows, -1, channels))  # [B*K*K, H/K*W/K, C]

    merged = merge_splits(attended.view(total_windows, win_h, win_w, channels),
                          num_splits=num_splits, channel_last=True)  # [B, H, W, C]

    if with_shift:
        # Undo the cyclic shift applied before partitioning.
        merged = torch.roll(merged, shifts=(shift_h, shift_w), dims=(1, 2))

    return merged.view(batch, -1, channels)
def single_head_split_window_attention_1d(q, k, v,
                                          relative_position_bias=None,
                                          num_splits=1,
                                          with_shift=False,
                                          h=None,
                                          w=None,
                                          attn_mask=None,
                                          ):
    """Single-head attention restricted to 1-D windows along the width.

    Each of the ``B * h`` rows is cut by ``split_feature_1d`` into
    ``num_splits`` windows and attention is computed inside each window.
    With ``with_shift`` every row is rolled by half a window first,
    ``attn_mask`` is added to the logits, and the roll is undone at the end.

    Args:
        q: query tensor, [B, L, C] with ``L == h * w``.
        k: key tensor, [B, L, C].
        v: value tensor, [B, L, C].
        relative_position_bias: accepted but not used by this function.
        num_splits: number of windows per row (K).
        with_shift: whether to use shifted windows.
        h: grid height (required).
        w: grid width (required).
        attn_mask: additive mask [K, W/K, W/K], required when ``with_shift``
            is true.

    Returns:
        Attention-weighted values, [B, H*W, C].

    Raises:
        AssertionError: on missing ``h``/``w``, a token count different from
            ``h * w``, or a missing mask when shifting.
    """
    assert h is not None and w is not None
    assert q.size(1) == h * w

    batch, _, channels = q.size()
    num_rows = batch * h
    total_windows = batch * num_splits * h
    win_w = w // num_splits

    # Treat every image row as an independent sequence.
    q, k, v = [t.view(num_rows, w, channels) for t in (q, k, v)]  # [B*H, W, C]

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_w = win_w // 2
        q, k, v = [torch.roll(t, shifts=-shift_w, dims=1) for t in (q, k, v)]

    # [B*H*K, W/K, C]
    q, k, v = [split_feature_1d(t, num_splits=num_splits) for t in (q, k, v)]

    q_tokens = q.view(total_windows, -1, channels)
    k_tokens = k.view(total_windows, -1, channels)
    logits = torch.matmul(q_tokens, k_tokens.permute(0, 2, 1)) / (channels ** 0.5)  # [B*H*K, W/K, W/K]

    if with_shift:
        # attn_mask: [K, W/K, W/K], tiled once per row
        logits += attn_mask.repeat(num_rows, 1, 1)  # [B*H*K, W/K, W/K]

    weights = torch.softmax(logits, dim=-1)
    attended = torch.matmul(weights, v.view(total_windows, -1, channels))  # [B*H*K, W/K, C]

    merged = merge_splits_1d(attended, h, num_splits=num_splits)  # [B, H, W, C]

    if with_shift:
        # Undo the cyclic shift along the width axis.
        merged = torch.roll(merged, shifts=shift_w, dims=2)

    return merged.view(batch, -1, channels)
class SelfAttnPropagation(nn.Module):
    """
    flow propagation with self-attention on feature
    query: feature0, key: feature0, value: flow

    The attention weights are computed from ``feature0`` alone and are then
    used to average ``flow``, either over all positions (``forward``) or over
    a square neighbourhood (``forward_local_window_attn``).
    """

    def __init__(self, in_channels,
                 **kwargs,
                 ):
        """Create the query / key projections.

        Args:
            in_channels: channel count C of the feature map.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        # Xavier-initialise weight matrices; 1-D biases keep their default init.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, flow,
                local_window_attn=False,
                local_window_radius=1,
                **kwargs,
                ):
        """Propagate ``flow`` with attention weights derived from ``feature0``.

        Args:
            feature0: feature map, [B, C, H, W]. Need not be contiguous.
            flow: values to propagate, [B, 2, H, W] (any channel count works
                on the global path). Need not be contiguous.
            local_window_attn: if true, delegate to
                ``forward_local_window_attn``.
            local_window_radius: window radius used by the local path.
            **kwargs: accepted and ignored.

        Returns:
            Propagated values with the same [B, ?, H, W] layout as ``flow``.
        """
        # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
        if local_window_attn:
            return self.forward_local_window_attn(feature0, flow,
                                                  local_window_radius=local_window_radius)

        b, c, h, w = feature0.size()

        # reshape (not view) so that strided inputs, e.g. a cropped feature
        # map, are accepted; for contiguous inputs this is still a free view.
        query = feature0.reshape(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        # a note: the ``correct'' implementation should be:
        # ``query = self.q_proj(query), key = self.k_proj(query)''
        # this problem is observed while cleaning up the code
        # however, this doesn't affect the performance since the projection is a linear operation,
        # thus the two projection matrices for key can be merged
        # so I just leave it as is in order to not re-train all models :)
        query = self.q_proj(query)  # [B, H*W, C]
        key = self.k_proj(query)  # [B, H*W, C]  (deliberately fed the projected query)

        value = flow.reshape(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)  # [B, H*W, 2]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

        return out

    def forward_local_window_attn(self, feature0, flow,
                                  local_window_radius=1,
                                  ):
        """Propagate ``flow`` inside a ``(2R+1) x (2R+1)`` window per pixel.

        Unlike ``forward``, the key here is ``k_proj`` applied to the raw
        feature, not to the projected query. Windows are gathered with
        ``F.unfold`` using ``padding=local_window_radius``, so border windows
        also contain the padded entries.

        Args:
            feature0: feature map, [B, C, H, W]. Need not be contiguous.
            flow: flow / disparity / depth, [B, 2, H, W] or [B, 1, H, W].
            local_window_radius: window radius R, must be positive.

        Returns:
            Contiguous tensor shaped like ``flow``.

        Raises:
            AssertionError: if ``flow`` does not have 1 or 2 channels, or the
                radius is not positive.
        """
        assert flow.size(1) == 2 or flow.size(1) == 1  # flow or disparity or depth
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        value_channel = flow.size(1)
        kernel_size = 2 * local_window_radius + 1

        # Flatten once; reshape tolerates non-contiguous inputs.
        tokens = feature0.reshape(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]

        feature0_reshape = self.q_proj(tokens).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]

        feature0_proj = self.k_proj(tokens).permute(0, 2, 1).reshape(b, c, h, w)

        feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
                                   padding=local_window_radius)  # [B, C*(2R+1)^2), H*W]

        feature0_window = feature0_window.reshape(b, c, kernel_size ** 2, h, w).permute(
            0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2)  # [B*H*W, C, (2R+1)^2]

        flow_window = F.unfold(flow, kernel_size=kernel_size,
                               padding=local_window_radius)  # [B, 2*(2R+1)^2), H*W]

        flow_window = flow_window.reshape(b, value_channel, kernel_size ** 2, h, w).permute(
            0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel)  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]

        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, flow_window).view(b, h, w, value_channel
                                                   ).permute(0, 3, 1, 2).contiguous()  # [B, 2, H, W]

        return out
|