import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops.layers.torch import Rearrange
from einops import rearrange
class BroadMultiHeadAttention(nn.Module):
    """Multi-head attention that broadcasts one shared set of queries over a batch of keys/values."""
    def __init__(self, dim, heads):
        super(BroadMultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K):
        # Q: [1, N, dim] -> [heads, N, d]; squeeze only the leading batch dim
        # so a single query (N == 1) is not squeezed away as well.
        Q = rearrange(Q.squeeze(0), 'i (heads d) -> heads i d', heads=self.heads)
        K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads)
        dots = einsum('hid, bhjd -> bhij', Q, K) * self.scale  # [B, heads, N, M]; here B is (b hw) and M the point number
        return self.attend(dots)

    def forward(self, Q, K, V):
        attn = self.attend_with_rpe(Q, K)
        B, _, _ = K.shape
        _, N, _ = Q.shape
        V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads)
        out = einsum('bhij, bhjd -> bhid', attn, V)            # [B, heads, N, d]
        out = rearrange(out, 'b heads n d -> b n (heads d)', b=B, n=N)
        return out
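# Usage sketch (illustrative, not from the original file; shapes are inferred from
# the rearranges above -- Q carries a leading batch dim of 1 and is broadcast over
# the batch dimension of K and V):
#
#   attn = BroadMultiHeadAttention(dim=128, heads=4)
#   Q = torch.randn(1, 10, 128)    # [1, N, dim] - one shared set of queries
#   K = torch.randn(32, 50, 128)   # [B, M, dim]
#   V = torch.randn(32, 50, 128)   # [B, M, dim]
#   out = attn(Q, K, V)            # [32, 10, 128], i.e. [B, N, dim]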
class MultiHeadAttention(nn.Module):
    """Standard batched multi-head attention."""
    def __init__(self, dim, heads):
        super(MultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K):
        Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads)
        K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads)
        dots = einsum('bhid, bhjd -> bhij', Q, K) * self.scale  # [B, heads, N, M]
        return self.attend(dots)

    def forward(self, Q, K, V):
        attn = self.attend_with_rpe(Q, K)
        B, HW, _ = Q.shape
        V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads)
        out = einsum('bhij, bhjd -> bhid', attn, V)
        out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW)
        return out
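# Usage sketch (illustrative, not from the original file):
#
#   attn = MultiHeadAttention(dim=128, heads=4)
#   Q = torch.randn(8, 100, 128)   # [B, HW, dim]
#   K = torch.randn(8, 100, 128)
#   V = torch.randn(8, 100, 128)
#   out = attn(Q, K, V)            # [8, 100, 128]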
# class MultiHeadAttentionRelative_encoder(nn.Module):
#     def __init__(self, dim, heads):
#         super(MultiHeadAttentionRelative_encoder, self).__init__()
#         self.dim = dim
#         self.heads = heads
#         self.scale = (dim / heads) ** -0.5
#         self.attend = nn.Softmax(dim=-1)
#
#     def attend_with_rpe(self, Q, K, Q_r, K_r):
#         """
#         Q: [BH1W1, H3W3, dim]
#         K: [BH1W1, H3W3, dim]
#         Q_r: [BH1W1, H3W3, H3W3, dim]
#         K_r: [BH1W1, H3W3, H3W3, dim]
#         """
#         Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads)      # [BH1W1, heads, H3W3, d]
#         K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads)      # [BH1W1, heads, H3W3, d]
#         K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads)  # [BH1W1, heads, H3W3, d]
#         Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads)  # [BH1W1, heads, H3W3, d]
#         # context-context similarity
#         c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale    # [BH1W1, heads, H3W3, H3W3]
#         # context-position similarity
#         c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale  # [BH1W1, heads, H3W3, H3W3]
#         # position-context similarity
#         p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:, :, :, None, :], K[:, :, :, None, :])
#         p_c = torch.squeeze(p_c, dim=4)
#         p_c = p_c.permute(0, 1, 3, 2)
#         dots = c_c + c_p + p_c
#         return self.attend(dots)
#
#     def forward(self, Q, K, V, Q_r, K_r):
#         attn = self.attend_with_rpe(Q, K, Q_r, K_r)
#         B, HW, _ = Q.shape
#         V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads)
#         out = einsum('bhij, bhjd -> bhid', attn, V)
#         out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW)
#         return out
class MultiHeadAttentionRelative(nn.Module):
    """Multi-head attention with relative positional terms:
    context-context, context-position, and position-context similarity."""
    def __init__(self, dim, heads):
        super(MultiHeadAttentionRelative, self).__init__()
        self.dim = dim
        self.heads = heads
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K, Q_r, K_r):
        """
        Q: [BH1W1, 1, dim]
        K: [BH1W1, H3W3, dim]
        Q_r: [BH1W1, H3W3, dim]
        K_r: [BH1W1, H3W3, dim]
        """
        Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads)      # [BH1W1, heads, 1, d]
        K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads)      # [BH1W1, heads, H3W3, d]
        K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads)  # [BH1W1, heads, H3W3, d]
        Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads)  # [BH1W1, heads, H3W3, d]
        # context-context similarity
        c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale    # [BH1W1, heads, 1, H3W3]
        # context-position similarity
        c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale  # [BH1W1, heads, 1, H3W3]
        # position-context similarity: per-position dot product between Q_r and K
        p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:, :, :, None, :], K[:, :, :, None, :]) * self.scale
        p_c = torch.squeeze(p_c, dim=4)  # [BH1W1, heads, H3W3, 1]
        p_c = p_c.permute(0, 1, 3, 2)    # [BH1W1, heads, 1, H3W3]
        dots = c_c + c_p + p_c
        return self.attend(dots)

    def forward(self, Q, K, V, Q_r, K_r):
        attn = self.attend_with_rpe(Q, K, Q_r, K_r)
        B, HW, _ = Q.shape
        V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads)
        out = einsum('bhij, bhjd -> bhid', attn, V)
        out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW)
        return out
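# Usage sketch (illustrative; shapes follow the docstring above, where the batch
# dim folds B with the H1*W1 query locations and H3W3 is the local window size):
#
#   attn = MultiHeadAttentionRelative(dim=128, heads=4)
#   BHW, H3W3 = 8 * 32 * 32, 81
#   Q = torch.randn(BHW, 1, 128)             # one query per location
#   K = V = torch.randn(BHW, H3W3, 128)      # window of keys/values per location
#   Q_r = K_r = torch.randn(BHW, H3W3, 128)  # relative positional embeddings
#   out = attn(Q, K, V, Q_r, K_r)            # [BHW, 1, 128]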
def LinearPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1/200):
    # 200 should be enough for an 8x-downsampled image
    # assume x to be [..., 2] coordinates
    freq_bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device)
    # band frequencies grow linearly with the band index; 3.14 approximates pi
    return torch.cat([
        torch.sin(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR),
        torch.cos(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR),
        torch.sin(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR),
        torch.cos(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR),
    ], dim=-1)
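# Usage sketch (illustrative, not from the original file): embed 2-D coordinates
# into dim channels; with dim=128 each coordinate gets 32 sine + 32 cosine bands.
#
#   coords = torch.rand(16, 1, 2) * 100                  # [..., 2] positions
#   pe = LinearPositionEmbeddingSine(coords, dim=128)    # [16, 1, 128]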
def ExpPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1/200):
    # 200 should be enough for an 8x-downsampled image
    # assume x to be [..., 2] coordinates
    freq_bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device)
    # band frequencies grow exponentially (powers of 2) with the band index
    return torch.cat([
        torch.sin(x[..., -2:-1] * (NORMALIZE_FACOR * 2 ** freq_bands)),
        torch.cos(x[..., -2:-1] * (NORMALIZE_FACOR * 2 ** freq_bands)),
        torch.sin(x[..., -1:] * (NORMALIZE_FACOR * 2 ** freq_bands)),
        torch.cos(x[..., -1:] * (NORMALIZE_FACOR * 2 ** freq_bands)),
    ], dim=-1)