""" Lambda Layer

Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
    - https://arxiv.org/abs/2102.08602

@misc{2102.08602,
Author = {Irwan Bello},
Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
Year = {2021},
}

Status:
This impl is a WIP. Code snippets in the paper were used as reference but
good chance some details are missing/wrong.

I've only implemented local lambda conv based pos embeddings.

For a PyTorch impl that includes other embedding options check out
https://github.com/lucidrains/lambda-networks

Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from torch import nn
import torch.nn.functional as F

from .grid import ndgrid
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_


def rel_pos_indices(size):
    """Build pairwise relative-position lookup indices for a 2D grid.

    Args:
        size: grid size, an int or an (H, W) pair (normalized with to_2tuple).

    Returns:
        Integer tensor of shape (2, H * W, H * W). Entry [0] holds the
        difference of the first grid coordinate between every pair of
        positions and entry [1] the difference of the second one, each
        shifted by (extent - 1) so that all values are non-negative and can
        be used to index a (2H - 1, 2W - 1) embedding table.
    """
    size = to_2tuple(size)
    # NOTE(review): ndgrid is assumed to behave like meshgrid with 'ij' indexing — confirm in .grid
    pos = torch.stack(ndgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
    rel_pos = pos[:, None, :] - pos[:, :, None]
    # shift offsets from [-(s - 1), s - 1] into [0, 2 * (s - 1)]
    rel_pos[0] += size[0] - 1
    rel_pos[1] += size[1] - 1
    return rel_pos  # 2, H * W, H * W


class LambdaLayer(nn.Module):
    """Lambda Layer

    Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
        - https://arxiv.org/abs/2102.08602

    NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.

    The internal dimensions of the lambda module are controlled via the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
      * the query (q) and key (k) dimension are determined by
        * dim_head if it is set (it is by default), otherwise
        * dim_head = make_divisible(dim_out * qk_ratio, 8) // num_heads
        * q = num_heads * dim_head, k = dim_head
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not set

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set
        feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W.
            Required (and only used) when r is None.
        stride (int): output stride of the module, avg pool used if stride == 2, any other
            value leaves the spatial size unchanged
        num_heads (int): parallel attention heads.
        dim_head (int): dimension of query and key heads, calculated from
            dim_out * qk_ratio // num_heads if not set (default: 16)
        r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not.
            Should be odd so that the r // 2 padding preserves H and W. (default: 9)
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set.
            (default: 1.0)
        qkv_bias (bool): add bias to q, k, and v projections
    """

    def __init__(
            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
            qk_ratio=1.0, qkv_bias=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0, ' should be divided by num_heads'
        self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.num_heads = num_heads
        self.dim_v = dim_out // num_heads

        # single 1x1 conv produces q (num_heads * dim_qk), k (dim_qk) and v (dim_v) together
        self.qkv = nn.Conv2d(
            dim,
            num_heads * self.dim_qk + self.dim_qk + self.dim_v,
            kernel_size=1, bias=qkv_bias)
        self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
        self.norm_v = nn.BatchNorm2d(self.dim_v)

        if r is not None:
            # local lambda convolution for pos
            self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
            self.pos_emb = None
            self.rel_pos_indices = None
        else:
            # relative pos embedding
            assert feat_size is not None
            feat_size = to_2tuple(feat_size)
            rel_size = [2 * s - 1 for s in feat_size]
            self.conv_lambda = None
            self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
            self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize the qkv projection and whichever positional parameters exist."""
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in
        if self.conv_lambda is not None:
            trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
        if self.pos_emb is not None:
            trunc_normal_(self.pos_emb, std=.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the lambda layer.

        Args:
            x: input feature map of shape (B, dim, H, W).

        Returns:
            Tensor of shape (B, dim_out, H, W), average pooled by 2 when the
            layer was built with stride == 2.
        """
        B, _, H, W = x.shape
        M = H * W
        qkv = self.qkv(x)
        q, k, v = torch.split(qkv, [
            self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
        q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2)  # B, num_heads, M, K
        v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
        k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M (normalized over positions)

        content_lam = k @ v  # B, K, V
        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V

        if self.pos_emb is None:
            position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, K, H, W, V
            position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
        else:
            # FIXME relative pos embedding path not fully verified
            pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
            position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1)  # B, 1, M, K, V
        position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2)  # B, num_heads, M, V

        # The output has num_heads * dim_v (== dim_out) channels. Using the input channel
        # count here would break whenever dim_out != dim.
        out = (content_out + position_out).transpose(-1, -2).reshape(
            B, self.num_heads * self.dim_v, H, W)  # B, dim_out (num_heads * V), H, W
        out = self.pool(out)
        return out