from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_


class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: Optional[int] = None,
            embed_dim: Optional[int] = None,
            num_heads: int = 8,
            feat_size: Optional[int] = None,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: Optional[int] = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            # Absolute position embedding requires a known input sequence length.
            assert feat_size is not None
            spatial_len = feat_size
            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        # Learned latent query token(s); the input sequence is pooled by cross-attending from these.
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()

    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x):
        B, N, C = x.shape

        if self.pos_embed is not None:
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        # Queries come from the learned latent tokens, keys/values from the input sequence.
        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # Optional pooling over the latent tokens when latent_len > 1.
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x
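

# --- Usage sketch (illustrative only, not part of the module API) ---
# A minimal example of pooling a ViT-style token sequence down to one vector
# per sample with AttentionPoolLatent. The shapes and the choice of
# nn.LayerNorm as `norm_layer` are assumptions made for illustration.
#
#     pool = AttentionPoolLatent(
#         in_features=768,
#         num_heads=12,
#         norm_layer=nn.LayerNorm,
#     )
#     tokens = torch.randn(2, 196, 768)   # (batch, seq_len, embed_dim)
#     pooled = pool(tokens)               # (2, 768) with the default pool_type='token'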