import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d


def single_head_full_attention(q, k, v):
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** 0.5)  # [B, L, L]
    attn = torch.softmax(scores, dim=2)  # [B, L, L]
    out = torch.matmul(attn, v)  # [B, L, C]

    return out


def single_head_full_attention_1d(q, k, v,
                                  h=None,
                                  w=None,
                                  ):
    # q, k, v: [B, H*W, C]; attention is restricted to each row of length W
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c ** 0.5

    scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor  # [B, H, W, W]

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v).view(b, -1, c)  # [B, H*W, C]

    return out


def single_head_split_window_attention(q, k, v,
                                       num_splits=1,
                                       with_shift=False,
                                       h=None,
                                       w=None,
                                       attn_mask=None,
                                       ):
    # Shifted local window attention in the style of Swin Transformer.
    # q, k, v: [B, H*W, C]
    assert q.dim() == k.dim() == v.dim() == 3

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_new = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c ** 0.5

    if with_shift:
        assert attn_mask is not None  # compute once beforehand
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
                          ) / scale_factor  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        scores += attn_mask.repeat(b, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]

    out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
                       num_splits=num_splits, channel_last=True)  # [B, H, W, C]

    # undo the cyclic shift
    if with_shift:
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    out = out.view(b, -1, c)  # [B, H*W, C]

    return out


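# A minimal sketch (an assumption, not part of the original file) of how the
# Swin-style `attn_mask` consumed above when `with_shift=True` can be built:
# after the cyclic shift, positions that came from different windows must not
# attend to each other, so their score entries are pushed to a large negative
# value before the softmax. The accompanying `utils` module may already provide
# an equivalent generator; this is only an illustrative reconstruction.
def _shift_window_attn_mask_sketch(h, w, num_splits, device=None):
    window_h, window_w = h // num_splits, w // num_splits
    shift_h, shift_w = window_h // 2, window_w // 2

    # label each pixel with the id of its pre-shift region
    img_mask = torch.zeros(1, h, w, 1, device=device)
    cnt = 0
    for hs in (slice(0, -window_h), slice(-window_h, -shift_h), slice(-shift_h, None)):
        for ws in (slice(0, -window_w), slice(-window_w, -shift_w), slice(-shift_w, None)):
            img_mask[:, hs, ws, :] = cnt
            cnt += 1

    # window partition, then compare labels pairwise within each window
    mask_windows = split_feature(img_mask, num_splits=num_splits,
                                 channel_last=True).view(-1, window_h * window_w)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [K*K, N, N]
    return attn_mask.masked_fill(attn_mask != 0, -100.).masked_fill(attn_mask == 0, 0.)

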
def single_head_split_window_attention_1d(q, k, v,
                                          relative_position_bias=None,
                                          num_splits=1,
                                          with_shift=False,
                                          h=None,
                                          w=None,
                                          attn_mask=None,
                                          ):
    # q, k, v: [B, H*W, C]; 1D window attention along the W dimension only
    # (`relative_position_bias` is accepted but unused here)
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_new = b * num_splits * h

    window_size_w = w // num_splits

    q = q.view(b * h, w, c)  # [B*H, W, C]
    k = k.view(b * h, w, c)
    v = v.view(b * h, w, c)

    scale_factor = c ** 0.5

    if with_shift:
        assert attn_mask is not None  # compute once beforehand
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=-shift_size_w, dims=1)
        k = torch.roll(k, shifts=-shift_size_w, dims=1)
        v = torch.roll(v, shifts=-shift_size_w, dims=1)

    q = split_feature_1d(q, num_splits=num_splits)  # [B*H*K, W/K, C]
    k = split_feature_1d(k, num_splits=num_splits)
    v = split_feature_1d(v, num_splits=num_splits)

    scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
                          ) / scale_factor  # [B*H*K, W/K, W/K]

    if with_shift:
        scores += attn_mask.repeat(b * h, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*H*K, W/K, C]

    out = merge_splits_1d(out, h, num_splits=num_splits)  # [B, H, W, C]

    # undo the cyclic shift
    if with_shift:
        out = torch.roll(out, shifts=shift_size_w, dims=2)

    out = out.view(b, -1, c)  # [B, H*W, C]

    return out


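# Analogous sketch (again an assumption, mirroring the 2D case above) of the 1D
# mask used by `single_head_split_window_attention_1d`: each row is processed
# independently, so only the W dimension is labeled and windowed.
def _shift_window_attn_mask_1d_sketch(w, num_splits, device=None):
    window_w = w // num_splits
    shift_w = window_w // 2

    seq_mask = torch.zeros(1, w, 1, device=device)
    cnt = 0
    for ws in (slice(0, -window_w), slice(-window_w, -shift_w), slice(-shift_w, None)):
        seq_mask[:, ws, :] = cnt
        cnt += 1

    mask_windows = split_feature_1d(seq_mask, num_splits=num_splits).view(-1, window_w)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [K, W/K, W/K]
    return attn_mask.masked_fill(attn_mask != 0, -100.).masked_fill(attn_mask == 0, 0.)

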
class SelfAttnPropagation(nn.Module):
    """
    Flow propagation with self-attention on features:
    query: feature0, key: feature0, value: flow
    """

    def __init__(self, in_channels,
                 **kwargs,
                 ):
        super(SelfAttnPropagation, self).__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, flow,
                local_window_attn=False,
                local_window_radius=1,
                **kwargs,
                ):
        # feature0: [B, C, H, W]; flow: [B, 2, H, W] (or [B, 1, H, W] for depth)
        if local_window_attn:
            return self.forward_local_window_attn(feature0, flow,
                                                  local_window_radius=local_window_radius)

        b, c, h, w = feature0.size()

        query = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        query = self.q_proj(query)  # [B, H*W, C]
        # NOTE: the key is projected from the already-projected query rather than
        # from the raw feature. Since both projections are linear, the composition
        # is still a single linear map, so the behavior is equivalent; it is kept
        # as-is for compatibility with weights trained under this formulation.
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)  # [B, H*W, 2]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

        return out

    def forward_local_window_attn(self, feature0, flow,
                                  local_window_radius=1,
                                  ):
        # each position attends only to a local (2R+1) x (2R+1) window
        assert flow.size(1) == 2 or flow.size(1) == 1  # flow or depth
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        value_channel = flow.size(1)

        feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
                                       ).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]

        kernel_size = 2 * local_window_radius + 1

        feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)
                                    ).permute(0, 2, 1).reshape(b, c, h, w)  # [B, C, H, W]

        feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
                                   padding=local_window_radius)  # [B, C*(2R+1)^2, H*W]

        feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
            0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2)  # [B*H*W, C, (2R+1)^2]

        flow_window = F.unfold(flow, kernel_size=kernel_size,
                               padding=local_window_radius)  # [B, 2*(2R+1)^2, H*W]

        flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute(
            0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel)  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]

        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, flow_window).view(b, h, w, value_channel
                                                   ).permute(0, 3, 1, 2).contiguous()  # [B, 2, H, W]

        return out

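
if __name__ == "__main__":
    # Lightweight smoke test (illustrative shapes only, not from the original
    # file). Because of the relative import above, run it as a module from the
    # package root, e.g. ``python -m <package>.attention`` with the real
    # package name substituted.
    b, c, h, w = 2, 32, 8, 16
    feat = torch.randn(b, h * w, c)

    assert single_head_full_attention(feat, feat, feat).shape == (b, h * w, c)
    assert single_head_full_attention_1d(feat, feat, feat, h=h, w=w).shape == (b, h * w, c)

    # shifted-window variants, using the sketch mask builders defined above
    mask_2d = _shift_window_attn_mask_sketch(h, w, num_splits=2)
    out = single_head_split_window_attention(feat, feat, feat, num_splits=2,
                                             with_shift=True, h=h, w=w, attn_mask=mask_2d)
    assert out.shape == (b, h * w, c)

    mask_1d = _shift_window_attn_mask_1d_sketch(w, num_splits=2)
    out = single_head_split_window_attention_1d(feat, feat, feat, num_splits=2,
                                                with_shift=True, h=h, w=w, attn_mask=mask_1d)
    assert out.shape == (b, h * w, c)

    # flow propagation: global and local-window paths
    prop = SelfAttnPropagation(in_channels=c)
    feature0 = torch.randn(b, c, h, w)
    flow = torch.randn(b, 2, h, w)
    assert prop(feature0, flow).shape == flow.shape
    assert prop(feature0, flow, local_window_attn=True).shape == flow.shape
    print("all attention smoke tests passed")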