""" Relative position embedding modules and functions
Hacked together by / Copyright 2022 Ross Wightman
"""
import math
import os
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .grid import ndgrid
from .interpolate import RegularGridInterpolator
from .mlp import Mlp
from .weight_init import trunc_normal_
_USE_SCIPY = int(os.environ.get('TIMM_USE_SCIPY_INTERP', 0)) > 0
def gen_relative_position_index(
q_size: Tuple[int, int],
k_size: Optional[Tuple[int, int]] = None,
class_token: bool = False,
) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window
assert k_size is None, 'Different q & k sizes not currently supported' # FIXME
coords = torch.stack(ndgrid(torch.arange(q_size[0]), torch.arange(q_size[1]))).flatten(1) # 2, Wh, Ww
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
relative_coords[:, :, 0] += q_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += q_size[1] - 1
relative_coords[:, :, 0] *= 2 * q_size[1] - 1
num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1)
# else:
# # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
# q_coords = torch.stack(
# ndgrid(
# torch.arange(q_size[0]),
# torch.arange(q_size[1])
# )
# ).flatten(1) # 2, Wh, Ww
# k_coords = torch.stack(
# ndgrid(
# torch.arange(k_size[0]),
# torch.arange(k_size[1])
# )
# ).flatten(1)
# relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
# relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
# relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0
# relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
# relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
# relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw
# num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
if class_token:
        # handle cls-to-token, token-to-cls & cls-to-cls as per BEiT rel pos bias
# NOTE not intended or tested with MLP log-coords
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance
relative_position_index[0:, 0] = num_relative_distance + 1
relative_position_index[0, 0] = num_relative_distance + 2
return relative_position_index.contiguous()
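# Usage sketch (illustrative only, not part of the module API; the 7x7 window is arbitrary):
#   idx = gen_relative_position_index((7, 7))                         # LongTensor of shape (49, 49)
#   idx_cls = gen_relative_position_index((7, 7), class_token=True)   # (50, 50), extra row/col for cls
#   # values index into a table of (2*7-1) * (2*7-1) = 169 relative distances (+3 with class_token)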
def resize_rel_pos_bias_table_simple(
rel_pos_bias,
new_window_size: Tuple[int, int],
new_bias_shape: Tuple[int, ...],
):
dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
if rel_pos_bias.ndim == 3:
# TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
_, dst_h, dst_w = new_bias_shape
num_attn_heads, src_h, src_w = rel_pos_bias.shape
assert dst_h == dst_size[0] and dst_w == dst_size[1]
if src_h != dst_h or src_w != dst_w:
rel_pos_bias = torch.nn.functional.interpolate(
rel_pos_bias.unsqueeze(0),
size=dst_size,
mode="bicubic",
align_corners=False,
).squeeze(0)
else:
assert rel_pos_bias.ndim == 2
# (num_pos, num_heads) (aka flat) bias shape
dst_num_pos, _ = new_bias_shape
src_num_pos, num_attn_heads = rel_pos_bias.shape
num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
src_size = (src_size, src_size) # FIXME could support non-equal src if argument passed
if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
if num_extra_tokens:
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
else:
extra_tokens = None
rel_pos_bias = torch.nn.functional.interpolate(
rel_pos_bias.transpose(1, 0).reshape((1, -1, src_size[0], src_size[1])),
size=dst_size,
mode="bicubic",
align_corners=False,
).view(-1, dst_num_pos - num_extra_tokens).transpose(0, 1)
if extra_tokens is not None:
rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
return rel_pos_bias
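# Usage sketch (hypothetical shapes, adapting a 7x7-window checkpoint to an 8x8 window):
#   old_table = torch.randn(13 * 13, 4)                               # (2*7-1)^2 positions, 4 heads
#   new_table = resize_rel_pos_bias_table_simple(
#       old_table, new_window_size=(8, 8), new_bias_shape=(15 * 15, 4))
#   # new_table.shape -> (225, 4); any trailing extra-token rows (e.g. cls) are carried over unchanged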
def resize_rel_pos_bias_table_levit(
position_bias_table,
new_size,
interpolation: str = 'bicubic',
antialias: bool = True,
):
"""
Resample relative position bias table suggested in LeVit
Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py
"""
L1, nH1 = position_bias_table.size()
L2, nH2 = new_size
assert nH1 == nH2
if L1 != L2:
orig_dtype = position_bias_table.dtype
position_bias_table = position_bias_table.float()
        # bicubic interpolation of relative_position_bias_table if sizes don't match
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
relative_position_bias_table_resized = F.interpolate(
position_bias_table.permute(1, 0).view(1, nH1, S1, S1),
size=(S2, S2),
mode=interpolation,
antialias=antialias)
relative_position_bias_table_resized = \
relative_position_bias_table_resized.view(nH2, L2).permute(1, 0)
        relative_position_bias_table_resized = relative_position_bias_table_resized.to(orig_dtype)
return relative_position_bias_table_resized
else:
return position_bias_table
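# Usage sketch (arbitrary sizes; a 169 -> 225 resize corresponds to a 7x7 -> 8x8 window change):
#   table = torch.randn(169, 8)                                       # (L1, num_heads)
#   resized = resize_rel_pos_bias_table_levit(table, new_size=(225, 8))
#   # resized.shape -> (225, 8)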
def resize_rel_pos_bias_table(
rel_pos_bias,
new_window_size: Tuple[int, int],
new_bias_shape: Tuple[int, ...],
):
""" Resize relative position bias table using more advanced interpolation.
Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc).
https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351
    Args:
        rel_pos_bias: Relative position bias table to resize, either flat (num_pos, num_heads)
            or MaxViT-style (num_heads, H, W).
        new_window_size: Target window size (height, width).
        new_bias_shape: Target shape of the resized bias table.
    Returns:
        Resized relative position bias table.
    """
if _USE_SCIPY:
from scipy import interpolate
dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
if rel_pos_bias.ndim == 3:
# TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
num_extra_tokens = 0
_, dst_h, dst_w = new_bias_shape
assert dst_h == dst_size[0] and dst_w == dst_size[1]
num_attn_heads, src_h, src_w = rel_pos_bias.shape
src_size = (src_h, src_w)
has_flat_shape = False
else:
assert rel_pos_bias.ndim == 2
# (num_pos, num_heads) (aka flat) bias shape
dst_num_pos, _ = new_bias_shape
src_num_pos, num_attn_heads = rel_pos_bias.shape
num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
src_size = (src_size, src_size)
has_flat_shape = True
if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
# print("Interpolating position from %dx%d to %dx%d" % (src_size[0], src_size[1], dst_size[0], dst_size[1]))
if num_extra_tokens:
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
else:
extra_tokens = None
def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)
def _calc(src, dst):
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src // 2)
if gp > dst // 2:
right = q
else:
left = q
dis = []
cur = 1
for i in range(src // 2):
dis.append(cur)
cur += q ** (i + 1)
r_ids = [-_ for _ in reversed(dis)]
return r_ids + [0] + dis
y = _calc(src_size[0], dst_size[0])
x = _calc(src_size[1], dst_size[1])
yx = [torch.tensor(y), torch.tensor(x)]
# print("Original positions = %s" % str(x))
ty = dst_size[0] // 2.0
tx = dst_size[1] // 2.0
dy = torch.arange(-ty, ty + 0.1, 1.0)
dx = torch.arange(-tx, tx + 0.1, 1.0)
dyx = ndgrid(dy, dx)
# print("Target positions = %s" % str(dx))
all_rel_pos_bias = []
for i in range(num_attn_heads):
if has_flat_shape:
z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float()
else:
z = rel_pos_bias[i, :, :].float()
if _USE_SCIPY:
# Original beit code uses scipy w/ cubic interpolation
f = interpolate.interp2d(x, y, z.numpy(), kind='cubic')
r = torch.Tensor(f(dx, dy)).contiguous().to(rel_pos_bias.device)
else:
# Without scipy dependency, I've found a reasonably simple impl
# that supports uneven spaced interpolation pts with 'linear' interp.
# Results are comparable to scipy for model accuracy in most cases.
f = RegularGridInterpolator(yx, z)
r = f(dyx).contiguous().to(rel_pos_bias.device)
if has_flat_shape:
r = r.view(-1, 1)
all_rel_pos_bias.append(r)
if has_flat_shape:
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
else:
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=0)
if extra_tokens is not None:
assert has_flat_shape
rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
return rel_pos_bias
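# Usage sketch (hypothetical BEiT-style flat table with 3 extra cls-related entries; sizes are arbitrary):
#   old_table = torch.randn(13 * 13 + 3, 12)                          # 7x7 window + 3 cls entries, 12 heads
#   new_table = resize_rel_pos_bias_table(
#       old_table, new_window_size=(8, 8), new_bias_shape=(15 * 15 + 3, 12))
#   # new_table.shape -> (228, 12); the 3 extra entries are re-appended unchanged after interpolation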
class RelPosBias(nn.Module):
""" Relative Position Bias
Adapted from Swin-V1 relative position bias impl, modularized.
"""
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0).view(-1),
persistent=False,
)
self.init_weights()
def init_weights(self):
trunc_normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index]
# win_h * win_w, win_h * win_w, num_heads
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
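# Usage sketch (arbitrary sizes; `attn` stands in for a pre-softmax attention map):
#   rel_pos = RelPosBias(window_size=(7, 7), num_heads=4)
#   attn = torch.randn(2, 4, 49, 49)                                  # (batch * windows, heads, q, k)
#   attn = rel_pos(attn)                                              # adds a learned (1, 4, 49, 49) bias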
def gen_relative_log_coords(
win_size: Tuple[int, int],
pretrained_win_size: Tuple[int, int] = (0, 0),
mode='swin',
):
assert mode in ('swin', 'cr')
# as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
if mode == 'swin':
if pretrained_win_size[0] > 0:
relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
else:
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / math.log2(8)
else:
# mode == 'cr'
relative_coords_table = torch.sign(relative_coords_table) * torch.log(
1.0 + relative_coords_table.abs())
return relative_coords_table
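# Usage sketch (arbitrary window size):
#   coords = gen_relative_log_coords((7, 7), mode='swin')             # (13, 13, 2), normalized log2 coords
#   coords_cr = gen_relative_log_coords((7, 7), mode='cr')            # (13, 13, 2), natural-log coords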
class RelPosMlp(nn.Module):
""" Log-Coordinate Relative Position MLP
Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883)
    This impl covers the 'swin' implementation as well as a timm-specific 'cr' log-coordinate mode.
"""
def __init__(
self,
window_size,
num_heads=8,
hidden_dim=128,
prefix_tokens=0,
mode='cr',
pretrained_window_size=(0, 0)
):
super().__init__()
self.window_size = window_size
self.window_area = self.window_size[0] * self.window_size[1]
self.prefix_tokens = prefix_tokens
self.num_heads = num_heads
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
if mode == 'swin':
self.bias_act = nn.Sigmoid()
self.bias_gain = 16
mlp_bias = (True, False)
else:
self.bias_act = nn.Identity()
self.bias_gain = None
mlp_bias = True
self.mlp = Mlp(
2, # x, y
hidden_features=hidden_dim,
out_features=num_heads,
act_layer=nn.ReLU,
bias=mlp_bias,
drop=(0.125, 0.)
)
self.register_buffer(
"relative_position_index",
gen_relative_position_index(window_size).view(-1),
persistent=False)
# get relative_coords_table
self.register_buffer(
"rel_coords_log",
gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
persistent=False)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.mlp(self.rel_coords_log)
if self.relative_position_index is not None:
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[self.relative_position_index]
relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
relative_position_bias = self.bias_act(relative_position_bias)
if self.bias_gain is not None:
relative_position_bias = self.bias_gain * relative_position_bias
if self.prefix_tokens:
relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
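# Usage sketch (arbitrary sizes; mirrors the RelPosBias example above but with an MLP-generated bias):
#   rel_pos = RelPosMlp(window_size=(7, 7), num_heads=4, hidden_dim=128, mode='cr')
#   attn = torch.randn(2, 4, 49, 49)
#   attn = rel_pos(attn)                                              # bias of shape (1, 4, 49, 49)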
def generate_lookup_tensor(
length: int,
max_relative_position: Optional[int] = None,
):
"""Generate a one_hot lookup tensor to reindex embeddings along one dimension.
Args:
length: the length to reindex to.
max_relative_position: the maximum relative position to consider.
Relative position embeddings for distances above this threshold
are zeroed out.
Returns:
a lookup Tensor of size [length, length, vocab_size] that satisfies
ret[n,m,v] = 1{m - n + max_relative_position = v}.
"""
if max_relative_position is None:
max_relative_position = length - 1
    # Build a one-hot lookup over the relative position vocab; distances beyond the max are left zero.
vocab_size = 2 * max_relative_position + 1
ret = torch.zeros(length, length, vocab_size)
for i in range(length):
for x in range(length):
v = x - i + max_relative_position
if abs(x - i) > max_relative_position:
continue
ret[i, x, v] = 1
return ret
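# Usage sketch (arbitrary length):
#   lut = generate_lookup_tensor(7)                                   # (7, 7, 13), one-hot over 2*6+1 distances
#   assert lut[0, 3, 3 - 0 + 6] == 1                                  # ret[n, m, m - n + max_relative_position]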
def reindex_2d_einsum_lookup(
relative_position_tensor,
height: int,
width: int,
height_lookup: torch.Tensor,
width_lookup: torch.Tensor,
) -> torch.Tensor:
"""Reindex 2d relative position bias with 2 independent einsum lookups.
Adapted from:
https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
Args:
relative_position_tensor: tensor of shape
[..., vocab_height, vocab_width, ...].
height: height to reindex to.
width: width to reindex to.
height_lookup: one-hot height lookup
width_lookup: one-hot width lookup
Returns:
reindexed_tensor: a Tensor of shape
[..., height * width, height * width, ...]
"""
reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
area = height * width
return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)
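# Usage sketch (arbitrary sizes; 4 heads, 7x7 window):
#   bias_table = torch.randn(4, 13, 13)                               # (heads, 2*H-1, 2*W-1)
#   h_lut = generate_lookup_tensor(7)
#   w_lut = generate_lookup_tensor(7)
#   bias = reindex_2d_einsum_lookup(bias_table, 7, 7, h_lut, w_lut)   # -> (4, 49, 49)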
class RelPosBiasTf(nn.Module):
""" Relative Position Bias Impl (Compatible with Tensorflow MaxViT models)
Adapted from:
https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
"""
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.num_heads = num_heads
vocab_height = 2 * window_size[0] - 1
vocab_width = 2 * window_size[1] - 1
self.bias_shape = (self.num_heads, vocab_height, vocab_width)
self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape))
self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False)
self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False)
self.init_weights()
def init_weights(self):
nn.init.normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
# FIXME change to not use one-hot/einsum?
return reindex_2d_einsum_lookup(
self.relative_position_bias_table,
self.window_size[0],
self.window_size[1],
self.height_lookup,
self.width_lookup
)
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
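# Usage sketch (arbitrary sizes; note the TF-style bias has no leading batch dim and broadcasts over it):
#   rel_pos = RelPosBiasTf(window_size=(7, 7), num_heads=4)
#   attn = torch.randn(2, 4, 49, 49)
#   attn = rel_pos(attn)                                              # adds a (4, 49, 49) bias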