from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_


class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query

    A learned latent query of shape (1, latent_len, embed_dim) cross-attends to the
    input tokens. The result goes through an output projection and a residual
    (norm -> MLP) block, and is then optionally pooled over the latent tokens.

    Args:
        in_features: Channel dim of the input tokens ``x`` (last dim of ``x``).
        out_features: Accepted for API compatibility; defaults to ``in_features``.
            The block's output channel dim is ``embed_dim`` (the MLP preserves it).
        embed_dim: Attention / latent dim; defaults to ``in_features``. Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-dim multiplier of the residual MLP.
        qkv_bias: Add a bias to the q and kv projections.
        qk_norm: Normalize q and k per head with ``norm_layer``, or with
            ``nn.LayerNorm`` when ``norm_layer`` is None.
        latent_len: Number of learned latent query tokens.
        latent_dim: Only used to scale the init std of the latent
            (``latent_dim ** -0.5``); defaults to ``embed_dim``.
        pos_embed: ``'abs'`` adds a learned absolute position embedding to the
            input tokens (requires ``feat_size``); any other value disables it.
        pool_type: ``'token'`` returns the first latent token, ``'avg'`` the mean
            over latent tokens; any other value returns all latent tokens.
        norm_layer: Norm layer constructor used before the MLP (and for q/k norm).
            None disables the pre-MLP norm.
        drop: Dropout probability applied after the output projection.
        feat_size: Number of input tokens; required when ``pos_embed == 'abs'``.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: Optional[int] = None,
            embed_dim: Optional[int] = None,
            num_heads: int = 8,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: Optional[int] = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
            feat_size: Optional[int] = None,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.embed_dim = embed_dim
        self.feat_size = feat_size
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            # Fixed-size table: the input must carry exactly `feat_size` tokens
            # (no interpolation is done in forward).
            if feat_size is None:
                raise ValueError("feat_size must be provided when pos_embed == 'abs'")
            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        # Queries come from the latent (embed_dim); keys/values are projected from
        # the input tokens (in_features) into embed_dim.
        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(in_features, embed_dim * 2, bias=qkv_bias)
        # Fall back to LayerNorm so qk_norm=True works without an explicit norm_layer.
        qk_norm_layer = norm_layer or nn.LayerNorm
        self.q_norm = qk_norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = qk_norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        # The residual MLP branch operates on the projected latent (embed_dim), so
        # the pre-MLP norm must be sized to embed_dim as well.
        self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()

    def init_weights(self):
        """Initialize the position embedding (if any) and the latent query."""
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool tokens ``x`` of shape (B, N, in_features) via latent cross-attention.

        Returns (B, embed_dim) for ``pool_type`` 'token'/'avg', otherwise
        (B, latent_len, embed_dim).
        """
        B, N, _ = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        # (B, heads, latent_len, head_dim)
        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        # (2, B, heads, N, head_dim) -> k, v each (B, heads, N, head_dim)
        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            # Default SDPA scale is head_dim ** -0.5, matching self.scale below.
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        # Merge heads back into the attention dim (embed_dim), not the input dim.
        x = x.transpose(1, 2).reshape(B, self.latent_len, self.embed_dim)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x