# Reference: https://github.com/google-research/deeplab2/blob/main/model/pixel_decoder/kmax.py
# Modified by Qihang Yu
from typing import Dict, List
import torch
from torch import nn
from torch.nn import functional as F
from timm.models.layers import DropPath
from timm.models.layers import trunc_normal_tf_ as trunc_normal_
from detectron2.config import configurable
from detectron2.layers import ShapeSpec
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from torch.cuda.amp import autocast
from ..backbone.convnext import LayerNorm
import math
def get_activation(name):
    if name is None or name.lower() == 'none':
        return nn.Identity()
    if name == 'relu':
        return nn.ReLU()
    elif name == 'gelu':
        return nn.GELU()
    raise ValueError(f'Unsupported activation: {name}')
class SyncBNCPU(nn.SyncBatchNorm):
    """SyncBatchNorm subclass whose forward skips cross-process synchronization,
    so modules built with it also run on CPU / single-process setups. It always
    normalizes with the running statistics and never updates them."""
    def forward(self, input):
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum
        # Always evaluate with the running estimates (inference-mode BN).
        bn_training = False
        running_mean = self.running_mean
        running_var = self.running_var
        # Fall back to framework BN when synchronization is not necessary.
        return F.batch_norm(
            input,
            running_mean,
            running_var,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
def get_norm(name, channels):
    if name is None or name.lower() == 'none':
        return nn.Identity()
    if name.lower() == 'syncbn':
        return SyncBNCPU(channels, eps=1e-3, momentum=0.01)
    raise ValueError(f'Unsupported norm: {name}')
class ConvBN(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, norm=None, act=None,
conv_type='2d', conv_init='he_normal', norm_init=1.0):
super().__init__()
        if conv_type == '2d':
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        elif conv_type == '1d':
            self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        else:
            raise ValueError(f'Unsupported conv_type: {conv_type}')
self.norm = get_norm(norm, out_channels)
self.act = get_activation(act)
if conv_init == 'normal':
nn.init.normal_(self.conv.weight, std=.02)
elif conv_init == 'trunc_normal':
trunc_normal_(self.conv.weight, std=.02)
elif conv_init == 'he_normal':
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal
trunc_normal_(self.conv.weight, std=math.sqrt(2.0 / in_channels))
        elif conv_init == 'xavier_uniform':
            nn.init.xavier_uniform_(self.conv.weight)
        else:
            raise ValueError(f'Unsupported conv_init: {conv_init}')
if bias:
nn.init.zeros_(self.conv.bias)
        if norm is not None and not isinstance(self.norm, nn.Identity):
            nn.init.constant_(self.norm.weight, norm_init)
def forward(self, x):
return self.act(self.norm(self.conv(x)))
MAX_SPAN = 255
def _compute_relative_distance_matrix(query_length, key_length):
if (key_length - query_length) % 2:
raise ValueError('Key_length should be query_length + 2 * memory_flange.')
key_index = torch.arange(key_length)
query_index = torch.arange(query_length) + (key_length - query_length) // 2
distance_matrix = key_index[None, :] - query_index[:, None]
    # Shift the distance_matrix so that it is >= 0. Each entry of the shifted
    # distance_matrix will index a relative positional embedding.
distance_matrix = distance_matrix + MAX_SPAN - 1
return distance_matrix
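# Worked example (illustrative, not in the original deeplab2 code): for
# query_length=3 and key_length=5 (memory_flange=1), the raw distances are
#   [[-1,  0,  1,  2,  3],
#    [-2, -1,  0,  1,  2],
#    [-3, -2, -1,  0,  1]]
# and the +(MAX_SPAN - 1) shift maps them into [0, 2 * MAX_SPAN - 2], so each
# entry indexes a row of an embedding table with 2 * MAX_SPAN - 1 rows.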
class RelativePositionalEncoding(nn.Module):
def __init__(self, query_length, key_length, depth):
super().__init__()
self._embeddings = nn.Embedding(MAX_SPAN * 2 - 1, depth)
trunc_normal_(self._embeddings.weight, std=1.0)
        # Register as a non-persistent buffer so the index matrix moves with the
        # module across devices without entering the state_dict.
        self.register_buffer('_relative_distance_matrix',
                             _compute_relative_distance_matrix(query_length, key_length),
                             persistent=False)
self.query_length = query_length
self.key_length = key_length
self.depth = depth
def forward(self):
return self._embeddings.weight[self._relative_distance_matrix.reshape(-1)].reshape(self.query_length, self.key_length, self.depth)
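# RelativePositionalEncoding.forward() returns a (query_length, key_length, depth)
# tensor: one learned embedding per relative offset, shared across batch and
# heads by the einsums in AxialAttention below.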
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_layers.py#L36
class AxialAttention(nn.Module):
    def __init__(self, in_planes, query_shape=56, total_key_depth=512, total_value_depth=1024, num_heads=8):
        super().__init__()
        assert (total_key_depth % num_heads == 0) and (total_value_depth % num_heads == 0)
self._in_planes = in_planes
self._query_shape = query_shape
self._total_key_depth = total_key_depth
self._total_value_depth = total_value_depth
self._num_heads = num_heads
self._key_depth_per_head = total_key_depth // num_heads
self.qkv_transform = ConvBN(in_planes, self._total_key_depth * 2 + self._total_value_depth, kernel_size=1, stride=1,
padding=0, bias=False, norm=None, act=None, conv_type='1d')
trunc_normal_(self.qkv_transform.conv.weight, std=in_planes ** -0.5)
self._query_rpe = RelativePositionalEncoding(query_shape, query_shape, self._key_depth_per_head)
self._key_rpe = RelativePositionalEncoding(query_shape, query_shape, self._key_depth_per_head)
self._value_rpe = RelativePositionalEncoding(query_shape, query_shape, total_value_depth // num_heads)
self._batch_norm_qkv = get_norm('syncbn', self._total_key_depth * 2 + self._total_value_depth)
self._batch_norm_similarity = get_norm('syncbn', num_heads * 3)
self._batch_norm_retrieved_output = get_norm('syncbn', self._total_value_depth * 2)
def forward(self, x):
N, C, L = x.shape
qkv = self._batch_norm_qkv(self.qkv_transform(x))
q, k, v = torch.split(qkv, [self._total_key_depth, self._total_key_depth, self._total_value_depth], dim=1)
q = q.reshape(N, self._num_heads, self._total_key_depth // self._num_heads, L)
k = k.reshape(N, self._num_heads, self._total_key_depth // self._num_heads, L)
v = v.reshape(N, self._num_heads, self._total_value_depth // self._num_heads, L)
content_similarity = torch.einsum('bhdl,bhdm->bhlm', q, k)
query_rpe = self._query_rpe()
query_rpe_similarity = torch.einsum('bhdl,lmd->bhlm', q, query_rpe)
key_rpe = self._key_rpe()
key_rpe_similarity = torch.einsum('bhdm,lmd->bhlm', k, key_rpe)
similarity_logits = torch.cat([content_similarity, query_rpe_similarity, key_rpe_similarity], dim=1)
similarity_logits = self._batch_norm_similarity(similarity_logits).reshape(N, 3, self._num_heads, L, L).sum(dim=1)
        # Compute the softmax in float32 for numerical stability under autocast.
        with autocast(enabled=False):
            weights = F.softmax(similarity_logits.float(), dim=-1)
retrieved_content = torch.einsum('bhlm,bhdm->bhdl', weights, v)
value_rpe = self._value_rpe()
retrieved_rpe = torch.einsum('bhlm,lmd->bhdl', weights, value_rpe)
retrieved_output = torch.cat([retrieved_content, retrieved_rpe], dim=1).reshape(N, 2*self._total_value_depth, L)
retrieved_output = self._batch_norm_retrieved_output(retrieved_output).reshape(N, 2, self._total_value_depth, L).sum(1)
return retrieved_output
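# Shape sketch (assumed, for illustration): AxialAttention is a 1D operator, so
# an (N, C, L) input with the defaults above yields an (N, total_value_depth, L)
# output, e.g.
#   attn = AxialAttention(in_planes=512, query_shape=56)
#   out = attn(torch.randn(2, 512, 56))  # -> (2, 1024, 56)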
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_layers.py#L316
class AxialAttention2D(nn.Module):
def __init__(self, in_planes, query_shape=[56, 56], filters=512, key_expansion=1, value_expansion=2, num_heads=8):
super().__init__()
total_key_depth = int(round(filters * key_expansion))
total_value_depth = int(round(filters * value_expansion))
self._total_key_depth = total_key_depth
self._total_value_depth = total_value_depth
self._height_axis = AxialAttention(
in_planes=in_planes,
query_shape=query_shape[0],
total_key_depth=total_key_depth,
total_value_depth=total_value_depth,
num_heads=num_heads)
self._width_axis = AxialAttention(
in_planes=total_value_depth,
query_shape=query_shape[1],
total_key_depth=total_key_depth,
total_value_depth=total_value_depth,
num_heads=num_heads)
def forward(self, x):
# N C H W -> N W C H
N, C, H, W = x.shape
x = x.permute(0, 3, 1, 2).contiguous()
x = x.reshape(N*W, C, H)
x = self._height_axis(x)
# N W C H -> N H C W
x = x.reshape(N, W, self._total_value_depth, H).permute(0, 3, 2, 1).contiguous()
x = x.reshape(N*H, self._total_value_depth, W)
x = self._width_axis(x)
x = x.reshape(N, H, self._total_value_depth, W).permute(0, 2, 1, 3).contiguous()
x = x.reshape(N, self._total_value_depth, H, W)
return x
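# The 2D attention above is factorized into two 1D passes: height-axis attention
# runs over each of the W columns, then width-axis attention over each of the H
# rows, e.g. (2, 512, 56, 56) -> (2, 1024, 56, 56) with the defaults
# (total_value_depth = filters * value_expansion = 512 * 2).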
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_blocks.py#L36
class SingleBlock(nn.Module):
def __init__(self, inplanes, filter_list, block_type, query_shape=[56, 56], key_expansion=1, value_expansion=2, num_heads=8, drop_path_prob=0.0):
        super().__init__()
self._block_type = block_type.lower()
self._filter_list = filter_list
self._conv1_bn_act = ConvBN(inplanes, self._filter_list[0], kernel_size=1, bias=False, norm='syncbn', act='gelu')
if self._block_type == 'axial':
self._attention = AxialAttention2D(in_planes=self._filter_list[0], query_shape=query_shape, filters=self._filter_list[1],
key_expansion=key_expansion, value_expansion=value_expansion, num_heads=num_heads)
output_channel = filter_list[1] * value_expansion
        elif self._block_type == 'bottleneck':
            self._conv2_bn_act = ConvBN(self._filter_list[0], self._filter_list[1], kernel_size=3, padding=1, bias=False, norm='syncbn', act='gelu')
            output_channel = filter_list[1]
        else:
            raise ValueError(f'Unsupported block_type: {block_type}')
self._conv3_bn = ConvBN(output_channel, self._filter_list[2], kernel_size=1, bias=False, norm='syncbn', act=None, norm_init=0.0)
self._shortcut = None
if inplanes != self._filter_list[-1]:
self._shortcut = ConvBN(inplanes, self._filter_list[-1], kernel_size=1, bias=False, norm='syncbn', act=None)
self.drop_path = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()
def forward(self, x):
x = F.gelu(x)
shortcut = x
if self._shortcut is not None:
shortcut = self._shortcut(shortcut)
x = self._conv1_bn_act(x)
if self._block_type == 'axial':
x = self._attention(x)
x = F.gelu(x)
elif self._block_type == 'bottleneck':
x = self._conv2_bn_act(x)
x = self._conv3_bn(x)
x = self.drop_path(x) + shortcut
return x
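# Note: SingleBlock mirrors the pre-activation residual design of deeplab2's
# axial_blocks.py: the block input is pre-activated with GELU, and the final BN
# is zero-initialized (norm_init=0.0), so at initialization each block reduces
# to its (possibly projected) shortcut.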
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L42
class BlockGroup(nn.Module):
def __init__(self, inplanes, base_filter, num_blocks, block_type, **kwargs):
super().__init__()
self._num_blocks = num_blocks
block_type = block_type.lower()
if block_type == 'axial':
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L247
filter_list = [base_filter * 2, base_filter, base_filter * 4]
        elif block_type == 'bottleneck':
            # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L250
            filter_list = [base_filter, base_filter, base_filter * 4]
        else:
            raise ValueError(f'Unsupported block_type: {block_type}')
self._blocks = nn.ModuleList()
for i in range(num_blocks):
self._blocks.append(SingleBlock(inplanes=inplanes, filter_list=filter_list, block_type=block_type, **kwargs))
inplanes = filter_list[-1]
def forward(self, x):
for i in range(self._num_blocks):
x = self._blocks[i](x)
return x
# https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/layers/resized_fuse.py#L31
class ResizedFuse(nn.Module):
def __init__(self, low_in_channels, high_in_channels, out_channels):
super().__init__()
self.low_in_channels = low_in_channels
self.high_in_channels = high_in_channels
self.out_channels = out_channels
if low_in_channels != out_channels:
self._conv_bn_low = ConvBN(low_in_channels, out_channels, kernel_size=1, bias=False, norm='syncbn', act=None)
if high_in_channels != out_channels:
self._conv_bn_high = ConvBN(high_in_channels, out_channels, kernel_size=1, bias=False, norm='syncbn', act=None)
    def forward(self, lowres_x, highres_x):
        # Follow deeplab2's resizing convention: align corners only when the
        # spatial size is odd.
        align_corners = (lowres_x.shape[-1] % 2 == 1)
        if self.low_in_channels != self.out_channels:
            lowres_x = self._conv_bn_low(F.gelu(lowres_x))
        lowres_x = F.interpolate(lowres_x, size=highres_x.shape[2:], mode='bilinear', align_corners=align_corners)
        if self.high_in_channels != self.out_channels:
            highres_x = self._conv_bn_high(F.gelu(highres_x))
        return lowres_x + highres_x
@SEM_SEG_HEADS_REGISTRY.register()
class kMaXPixelDecoder(nn.Module):
@configurable
def __init__(
self,
input_shape: Dict[str, ShapeSpec],
*,
dec_layers: List[int],
dec_channels: List[int],
layer_types: List[str],
drop_path_prob: float,
spatial_shape: List[int],
):
"""
NOTE: this interface is experimental.
Args:
"""
super().__init__()
self.num_stages = len(input_shape)
assert self.num_stages == len(dec_layers) and self.num_stages == len(dec_channels) and self.num_stages == len(layer_types)
        # For now, the block types are hard-coded; the layer_types argument is
        # validated above but not yet used.
        block_types = ['axial', 'axial', 'bottleneck', 'bottleneck']
input_shape = sorted(input_shape.items(), key=lambda x: -x[1].stride)
self.in_features = [k for k, v in input_shape] # starting from "res5" to "res2"
in_channels = [v.channels for k, v in input_shape]
add_one = (spatial_shape[0] % 2, spatial_shape[1] % 2)
query_shape = [
(spatial_shape[0]//32+add_one[0], spatial_shape[1]//32+add_one[1]),
(spatial_shape[0]//16+add_one[0], spatial_shape[1]//16+add_one[1]),
(spatial_shape[0]//8+add_one[0], spatial_shape[1]//8+add_one[1]),
(spatial_shape[0]//4+add_one[0], spatial_shape[1]//4+add_one[1])]
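        # Example (illustrative): spatial_shape = [1025, 2049] gives add_one = (1, 1)
        # and query_shape = [(33, 65), (65, 129), (129, 257), (257, 513)], matching
        # the feature-map sizes a stride-32/16/8/4 backbone produces on odd inputs.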
self._in_norms = nn.ModuleList()
self._stages = nn.ModuleList()
self._resized_fuses = nn.ModuleList()
for i in range(self.num_stages):
self._in_norms.append(LayerNorm(in_channels[i], data_format="channels_first"))
inplanes = in_channels[i] if i == 0 else dec_channels[i]
            self._stages.append(BlockGroup(inplanes=inplanes,
                base_filter=dec_channels[i], num_blocks=dec_layers[i], block_type=block_types[i],
                query_shape=query_shape[i], key_expansion=1, value_expansion=2, num_heads=8,
                drop_path_prob=drop_path_prob))
if i > 0:
self._resized_fuses.append(ResizedFuse(
low_in_channels=dec_channels[i-1] * 4,
high_in_channels=in_channels[i],
out_channels=dec_channels[i]))
@classmethod
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
ret = {}
ret["input_shape"] = {
k: v for k, v in input_shape.items() if k in cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.IN_FEATURES
}
ret["dec_layers"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_LAYERS
ret["dec_channels"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_CHANNELS
ret["layer_types"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.LAYER_TYPES
ret["drop_path_prob"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DROP_PATH_PROB
ret["spatial_shape"] = cfg.INPUT.IMAGE_SIZE # We expect the height == width
return ret
def forward_features(self, features):
out = []
multi_scale_features = []
x = self._in_norms[0](features[self.in_features[0]])
for idx in range(self.num_stages - 1):
x = self._stages[idx](x)
out.append(x)
x = self._resized_fuses[idx](
lowres_x=x,
highres_x=self._in_norms[idx+1](features[self.in_features[idx+1]]))
x = self._stages[-1](x)
out.append(x)
        multi_scale_features = out[:3] # OS32, OS16, OS8; used by the kMaX transformer decoder.
        panoptic_features = out[-1] # OS4; used for the final mask prediction.
        # Backbone features at OS32, OS8, OS4.
        semantic_features = [features[self.in_features[0]], features[self.in_features[2]], features[self.in_features[3]]]
return panoptic_features, semantic_features, multi_scale_features
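
if __name__ == '__main__':
    # Minimal CPU smoke test, added as an illustrative sketch (not part of the
    # upstream file): it only checks tensor shapes through the axial path.
    x = torch.randn(2, 512, 8, 8)
    attn = AxialAttention2D(in_planes=512, query_shape=[8, 8], filters=128)
    y = attn(x)
    # key_expansion=1 and value_expansion=2 give total_value_depth = 256,
    # so y should have shape (2, 256, 8, 8).
    print(y.shape)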