File size: 3,458 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_
class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query

    A learned latent of shape (1, latent_len, embed_dim) is projected to queries that
    cross-attend over the input tokens (projected to keys / values). The attended latents go
    through an output projection, dropout and a residual `norm -> mlp` step, and are then
    optionally reduced over the latent axis.

    Args:
        in_features: Channel width of the input tokens; default for `embed_dim` and `out_features`.
        out_features: Width passed to `norm_layer` for the pre-MLP norm. Defaults to `in_features`.
        embed_dim: Attention width, must be divisible by `num_heads`. Defaults to `in_features`.
        num_heads: Number of attention heads.
        mlp_ratio: Multiplier on `embed_dim` for the second argument given to `Mlp`
            (presumably the hidden width -- confirm against `Mlp`).
        qkv_bias: Whether the q and kv linear projections use a bias.
        qk_norm: Apply a per-head norm to q and k. Uses `norm_layer` if given, else `nn.LayerNorm`.
        latent_len: Number of latent query tokens.
        latent_dim: Only used to derive the init std of the latent. Defaults to `embed_dim`.
        pos_embed: `'abs'` adds a learned (feat_size, in_features) position embedding to the input,
            any other value disables it.
        pool_type: `'token'` returns the first latent, `'avg'` the mean over latents, anything
            else returns all latents unpooled.
        norm_layer: Callable building a norm layer from a feature width, or None for no pre-MLP norm.
        drop: Dropout probability applied after the output projection.
        feat_size: Number of input tokens covered by the position embedding. Required when
            `pos_embed == 'abs'`, ignored otherwise.

    Raises:
        ValueError: If `pos_embed == 'abs'` and `feat_size` is not provided.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: Optional[int] = None,
            embed_dim: Optional[int] = None,
            num_heads: int = 8,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: Optional[int] = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
            feat_size: Optional[int] = None,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0, 'embed_dim must be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.feat_size = feat_size
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            # The embedding is added to the input as-is in forward() (no interpolation),
            # so its length has to be supplied up front.
            if feat_size is None:
                raise ValueError("feat_size must be specified when pos_embed == 'abs'")
            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        if qk_norm:
            # norm_layer defaults to None; fall back to LayerNorm instead of calling None.
            qk_norm_layer = norm_layer or nn.LayerNorm
            self.q_norm = qk_norm_layer(self.head_dim)
            self.k_norm = qk_norm_layer(self.head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        # NOTE(review): this norm is sized by out_features but is applied to the output of
        # self.proj, which is embed_dim wide -- the two must match for a width-dependent norm.
        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()

    def init_weights(self):
        """Initialise the position embedding (if any) and the latent via `trunc_normal_tf_`.

        The std is `dim ** -0.5`, using the embedding's channel width for `pos_embed` and
        `latent_dim` for the latent.
        """
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x):
        """Pool a token sequence with the latent queries.

        Args:
            x: Tensor of shape (B, N, C).

        Returns:
            The attended latents after the residual norm/MLP step: the first latent for
            `pool == 'token'`, the mean over latents for `pool == 'avg'`, otherwise all
            latents with the latent axis kept.
        """
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        # Queries come from the (batch-expanded) latent, keys / values from the input tokens.
        q_latent = self.latent.expand(B, -1, -1)
        # (B, latent_len, heads * head_dim) -> (B, heads, latent_len, head_dim)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        # (B, N, 2 * heads * head_dim) -> (2, B, heads, N, head_dim), then split into k and v
        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            # Manual scaled dot-product attention.
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        # Merge heads back: (B, heads, latent_len, head_dim) -> (B, latent_len, C)
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x