# Web-page capture residue (not part of the module): "Spaces: Runtime error / Runtime error".
# Reference: https://github.com/google-research/deeplab2/blob/main/model/pixel_decoder/kmax.py | |
# Modified by Qihang Yu | |
from turtle import forward | |
from typing import Dict, List | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from timm.models.layers import DropPath | |
from timm.models.layers import trunc_normal_tf_ as trunc_normal_ | |
from detectron2.config import configurable | |
from detectron2.layers import ShapeSpec | |
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY | |
from torch.cuda.amp import autocast | |
from ..backbone.convnext import LayerNorm | |
import math | |
def get_activation(name):
    """Build the activation module selected by ``name``.

    Args:
        name: ``None`` or ``'none'`` for a pass-through ``nn.Identity``;
            ``'relu'`` or ``'gelu'`` for the matching ``torch.nn`` module.
            Matching is case-insensitive.

    Returns:
        A newly constructed ``nn.Module``.

    Raises:
        ValueError: If ``name`` is not a supported activation. Failing here
            avoids handing ``None`` back to a caller that would only crash
            later when it tries to call the activation.
    """
    if name is None:
        return nn.Identity()
    key = name.lower()
    if key == 'none':
        return nn.Identity()
    if key == 'relu':
        return nn.ReLU()
    if key == 'gelu':
        return nn.GELU()
    raise ValueError(f"Unsupported activation: {name!r} (expected None, 'none', 'relu' or 'gelu').")
class SyncBNCPU(nn.SyncBatchNorm):
    """``nn.SyncBatchNorm`` subclass whose forward pass never synchronizes.

    ``forward`` always calls ``F.batch_norm`` with ``training=False``, so the
    input is normalized with the stored running statistics regardless of the
    module's train/eval mode.
    """

    def forward(self, input):
        """Normalize ``input`` with the running mean/variance and affine params."""
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # Mirrors the parent's handling of ``momentum=None``; the value is
        # simply forwarded to ``F.batch_norm``.
        momentum = 0.0 if self.momentum is None else self.momentum

        # Plain framework batch norm in inference mode.
        return F.batch_norm(
            input,
            self.running_mean,
            self.running_var,
            weight=self.weight,
            bias=self.bias,
            training=False,
            momentum=momentum,
            eps=self.eps,
        )
def get_norm(name, channels):
    """Build the normalization module selected by ``name``.

    Args:
        name: ``None`` or ``'none'`` for a pass-through ``nn.Identity``;
            ``'syncbn'`` for a ``SyncBNCPU`` layer. Matching is
            case-insensitive.
        channels: Number of feature channels given to the norm layer.

    Returns:
        A newly constructed ``nn.Module``.

    Raises:
        ValueError: If ``name`` is not a supported norm. Failing here avoids
            returning ``None`` to a caller that would only crash later.
    """
    if name is None:
        return nn.Identity()
    key = name.lower()
    if key == 'none':
        return nn.Identity()
    if key == 'syncbn':
        return SyncBNCPU(channels, eps=1e-3, momentum=0.01)
    raise ValueError(f"Unsupported norm: {name!r} (expected None, 'none' or 'syncbn').")
class ConvBN(nn.Module):
    """Convolution followed by an optional norm layer and activation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: Forwarded unchanged to ``nn.Conv2d`` / ``nn.Conv1d``.
        norm: Norm name understood by ``get_norm`` (``None`` for no norm).
        act: Activation name understood by ``get_activation``.
        conv_type: ``'2d'`` or ``'1d'``.
        conv_init: ``'normal'``, ``'trunc_normal'``, ``'he_normal'`` or
            ``'xavier_uniform'``. Any other value keeps the convolution's
            default initialization.
        norm_init: Constant written to the norm layer's weight, if it has one.

    Raises:
        ValueError: If ``conv_type`` is not ``'2d'`` or ``'1d'``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, norm=None, act=None,
                 conv_type='2d', conv_init='he_normal', norm_init=1.0):
        super().__init__()

        # Validate first so an unsupported type fails with a clear message
        # instead of leaving ``self.conv`` undefined.
        if conv_type == '2d':
            conv_cls = nn.Conv2d
        elif conv_type == '1d':
            conv_cls = nn.Conv1d
        else:
            raise ValueError(f"Unsupported conv_type: {conv_type!r} (expected '2d' or '1d').")

        self.conv = conv_cls(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                             dilation=dilation, groups=groups, bias=bias)
        self.norm = get_norm(norm, out_channels)
        self.act = get_activation(act)

        if conv_init == 'normal':
            nn.init.normal_(self.conv.weight, std=.02)
        elif conv_init == 'trunc_normal':
            trunc_normal_(self.conv.weight, std=.02)
        elif conv_init == 'he_normal':
            # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal
            trunc_normal_(self.conv.weight, std=math.sqrt(2.0 / in_channels))
        elif conv_init == 'xavier_uniform':
            nn.init.xavier_uniform_(self.conv.weight)

        if bias:
            nn.init.zeros_(self.conv.bias)

        # Only initialize a norm weight that actually exists: norm='none'
        # yields an nn.Identity, which has no ``weight`` attribute.
        norm_weight = getattr(self.norm, 'weight', None)
        if norm_weight is not None:
            nn.init.constant_(norm_weight, norm_init)

    def forward(self, x):
        """Apply convolution, then norm, then activation."""
        return self.act(self.norm(self.conv(x)))
# Largest supported span; relative offsets lie in [-(MAX_SPAN - 1), MAX_SPAN - 1],
# i.e. 2 * MAX_SPAN - 1 distinct positional embeddings.
MAX_SPAN = 255


def _compute_relative_distance_matrix(query_length, key_length):
    """Return shifted key-minus-query offsets used as embedding indices.

    The queries are centred inside the keys: query ``i`` sits at key position
    ``i + (key_length - query_length) // 2``.

    Args:
        query_length: Number of query positions.
        key_length: Number of key positions; must differ from
            ``query_length`` by an even amount.

    Returns:
        Integer tensor of shape ``(query_length, key_length)`` whose entries
        are ``key_pos - query_pos + MAX_SPAN - 1``.

    Raises:
        ValueError: If ``key_length - query_length`` is odd, or if the span is
            too large for every index to fit in ``[0, 2 * MAX_SPAN - 2]``.
    """
    if (key_length - query_length) % 2:
        raise ValueError('Key_length should be query_length + 2 * memory_flange.')
    memory_flange = (key_length - query_length) // 2
    # Both the smallest and the largest shifted index stay in range exactly
    # when this holds; otherwise the later embedding lookup would index
    # outside the (2 * MAX_SPAN - 1)-row table.
    if query_length + memory_flange > MAX_SPAN:
        raise ValueError(
            f'query_length + memory_flange = {query_length + memory_flange} exceeds MAX_SPAN = {MAX_SPAN}.')
    key_index = torch.arange(key_length)
    query_index = torch.arange(query_length) + memory_flange
    distance_matrix = key_index[None, :] - query_index[:, None]
    # Shift the distance_matrix so that it is >= 0. Each entry of the
    # distance_matrix distance will index a relative positional embedding.
    return distance_matrix + MAX_SPAN - 1
class RelativePositionalEncoding(nn.Module):
    """Learned relative positional encodings for one attention axis.

    Args:
        query_length: Number of query positions.
        key_length: Number of key positions.
        depth: Embedding dimension of each encoding.
    """

    def __init__(self, query_length, key_length, depth):
        super().__init__()
        self._embeddings = nn.Embedding(MAX_SPAN * 2 - 1, depth)
        trunc_normal_(self._embeddings.weight, std=1.0)
        # Register the index table as a buffer so it follows the module across
        # devices (``.to()`` / ``.cuda()``). ``persistent=False`` keeps it out
        # of ``state_dict`` so checkpoint keys are unchanged.
        self.register_buffer(
            '_relative_distance_matrix',
            _compute_relative_distance_matrix(query_length, key_length),
            persistent=False,
        )
        self.query_length = query_length
        self.key_length = key_length
        self.depth = depth

    def forward(self):
        """Return the encodings as a ``(query_length, key_length, depth)`` tensor."""
        # Indexing the (rows, depth) weight with a (query, key) index tensor
        # yields a (query, key, depth) result directly.
        return self._embeddings.weight[self._relative_distance_matrix]
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_layers.py#L36 | |
class AxialAttention(nn.Module):
    """Multi-head self-attention along one axis with relative positional terms.

    Input is ``(N, C, L)``; output is ``(N, total_value_depth, L)``.

    Args:
        in_planes: Number of input channels ``C``.
        query_shape: Axis length used to build the relative positional tables.
        total_key_depth: Query/key channels summed over heads.
        total_value_depth: Value channels summed over heads.
        num_heads: Number of attention heads; must divide both depths.
    """

    def __init__(self, in_planes, query_shape=56, total_key_depth=512, total_value_depth=1024, num_heads=8):
        assert total_key_depth % num_heads == 0
        assert total_value_depth % num_heads == 0
        super().__init__()
        self._in_planes = in_planes
        self._query_shape = query_shape
        self._total_key_depth = total_key_depth
        self._total_value_depth = total_value_depth
        self._num_heads = num_heads
        self._key_depth_per_head = total_key_depth // num_heads

        qkv_channels = total_key_depth * 2 + total_value_depth
        value_depth_per_head = total_value_depth // num_heads

        # A single 1x1 conv produces queries, keys and values together.
        self.qkv_transform = ConvBN(in_planes, qkv_channels, kernel_size=1, stride=1,
                                    padding=0, bias=False, norm=None, act=None, conv_type='1d')
        # Overrides the initialization ConvBN applied to the conv weight.
        trunc_normal_(self.qkv_transform.conv.weight, std=in_planes ** -0.5)

        self._query_rpe = RelativePositionalEncoding(query_shape, query_shape, self._key_depth_per_head)
        self._key_rpe = RelativePositionalEncoding(query_shape, query_shape, self._key_depth_per_head)
        self._value_rpe = RelativePositionalEncoding(query_shape, query_shape, value_depth_per_head)

        self._batch_norm_qkv = get_norm('syncbn', qkv_channels)
        # Three logit groups (content, query-rpe, key-rpe) per head.
        self._batch_norm_similarity = get_norm('syncbn', num_heads * 3)
        # Two retrieved groups (content, rpe) over the value channels.
        self._batch_norm_retrieved_output = get_norm('syncbn', total_value_depth * 2)

    def forward(self, x):
        """Attend over the last axis of ``x`` (``(N, C, L)``)."""
        batch, _, length = x.shape
        heads = self._num_heads
        key_depth = self._total_key_depth
        value_depth = self._total_value_depth

        qkv = self._batch_norm_qkv(self.qkv_transform(x))
        q, k, v = torch.split(qkv, [key_depth, key_depth, value_depth], dim=1)
        q = q.reshape(batch, heads, key_depth // heads, length)
        k = k.reshape(batch, heads, key_depth // heads, length)
        v = v.reshape(batch, heads, value_depth // heads, length)

        # Logits: content term plus query- and key-side positional terms.
        content_logits = torch.einsum('bhdl,bhdm->bhlm', q, k)
        query_rpe_logits = torch.einsum('bhdl,lmd->bhlm', q, self._query_rpe())
        key_rpe_logits = torch.einsum('bhdm,lmd->bhlm', k, self._key_rpe())

        # Normalize the three groups jointly, then sum them per head.
        stacked_logits = torch.cat([content_logits, query_rpe_logits, key_rpe_logits], dim=1)
        logits = self._batch_norm_similarity(stacked_logits)
        logits = logits.reshape(batch, 3, heads, length, length).sum(dim=1)

        # NOTE(review): the source's indentation was lost; only the softmax is
        # assumed to sit inside the autocast-disabled region — confirm.
        with autocast(enabled=False):
            weights = F.softmax(logits.float(), dim=-1)

        retrieved_content = torch.einsum('bhlm,bhdm->bhdl', weights, v)
        retrieved_rpe = torch.einsum('bhlm,lmd->bhdl', weights, self._value_rpe())

        # Normalize content and positional retrievals jointly, then sum them.
        stacked_output = torch.cat([retrieved_content, retrieved_rpe], dim=1)
        stacked_output = stacked_output.reshape(batch, 2 * value_depth, length)
        normalized = self._batch_norm_retrieved_output(stacked_output)
        return normalized.reshape(batch, 2, value_depth, length).sum(1)
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_layers.py#L316 | |
class AxialAttention2D(nn.Module):
    """Axial attention over a 2-D map: height axis first, then width axis.

    Input is ``(N, C, H, W)``; output is ``(N, total_value_depth, H, W)`` with
    ``total_value_depth = round(filters * value_expansion)``.

    Args:
        in_planes: Number of input channels.
        query_shape: ``[height, width]`` axis lengths for the two attentions.
        filters: Base channel count scaled by the two expansion factors.
        key_expansion: Multiplier giving the total key depth.
        value_expansion: Multiplier giving the total value depth.
        num_heads: Number of attention heads for both axes.
    """

    def __init__(self, in_planes, query_shape=[56, 56], filters=512, key_expansion=1, value_expansion=2, num_heads=8):
        super().__init__()
        self._total_key_depth = int(round(filters * key_expansion))
        self._total_value_depth = int(round(filters * value_expansion))

        shared = dict(
            total_key_depth=self._total_key_depth,
            total_value_depth=self._total_value_depth,
            num_heads=num_heads,
        )
        self._height_axis = AxialAttention(in_planes=in_planes, query_shape=query_shape[0], **shared)
        # The width attention consumes the height attention's output channels.
        self._width_axis = AxialAttention(in_planes=self._total_value_depth, query_shape=query_shape[1], **shared)

    def forward(self, x):
        """Apply height-axis then width-axis attention to ``x``."""
        batch, channels, height, width = x.shape
        depth = self._total_value_depth

        # N C H W -> (N*W) C H: each column becomes one sequence.
        columns = x.permute(0, 3, 1, 2).contiguous().reshape(batch * width, channels, height)
        columns = self._height_axis(columns)

        # (N*W) D H -> (N*H) D W: each row becomes one sequence.
        rows = columns.reshape(batch, width, depth, height).permute(0, 3, 2, 1).contiguous()
        rows = self._width_axis(rows.reshape(batch * height, depth, width))

        # (N*H) D W -> N D H W.
        return rows.reshape(batch, height, depth, width).permute(0, 2, 1, 3).contiguous()
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_blocks.py#L36 | |
class SingleBlock(nn.Module):
    """Residual block whose middle layer is axial attention or a 3x3 conv.

    Layout: GELU -> 1x1 ConvBN+GELU -> (axial attention + GELU | 3x3
    ConvBN+GELU) -> 1x1 ConvBN -> drop-path -> add shortcut.

    Args:
        inplanes: Number of input channels.
        filter_list: Three channel counts ``[conv1_out, middle, conv3_out]``.
        block_type: ``'axial'`` or ``'bottleneck'`` (case-insensitive).
        query_shape: ``[height, width]`` for the axial attention.
        key_expansion, value_expansion, num_heads: Forwarded to
            ``AxialAttention2D`` (axial blocks only); ``value_expansion`` also
            sets the channel count fed to the last conv.
        drop_path_prob: Drop-path probability; ``0`` disables drop-path.

    Raises:
        ValueError: If ``block_type`` is neither ``'axial'`` nor
            ``'bottleneck'``.
    """

    def __init__(self, inplanes, filter_list, block_type, query_shape=[56, 56], key_expansion=1, value_expansion=2, num_heads=8, drop_path_prob=0.0):
        super().__init__()
        self._block_type = block_type.lower()
        # Fail early with a clear message; otherwise ``output_channel`` below
        # would be left unbound for an unknown type.
        if self._block_type not in ('axial', 'bottleneck'):
            raise ValueError(f"Unsupported block_type: {block_type!r} (expected 'axial' or 'bottleneck').")
        self._filter_list = filter_list

        self._conv1_bn_act = ConvBN(inplanes, self._filter_list[0], kernel_size=1, bias=False, norm='syncbn', act='gelu')
        if self._block_type == 'axial':
            self._attention = AxialAttention2D(in_planes=self._filter_list[0], query_shape=query_shape, filters=self._filter_list[1],
                                               key_expansion=key_expansion, value_expansion=value_expansion, num_heads=num_heads)
            output_channel = filter_list[1] * value_expansion
        else:  # 'bottleneck'
            self._conv2_bn_act = ConvBN(self._filter_list[0], self._filter_list[1], kernel_size=3, padding=1, bias=False, norm='syncbn', act='gelu')
            output_channel = filter_list[1]
        # norm_init=0.0 zero-initializes the final norm weight.
        self._conv3_bn = ConvBN(output_channel, self._filter_list[2], kernel_size=1, bias=False, norm='syncbn', act=None, norm_init=0.0)

        # Project the shortcut only when the channel count changes.
        self._shortcut = None
        if inplanes != self._filter_list[-1]:
            self._shortcut = ConvBN(inplanes, self._filter_list[-1], kernel_size=1, bias=False, norm='syncbn', act=None)
        self.drop_path = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()

    def forward(self, x):
        """Run the block; the shortcut branches off after the leading GELU."""
        x = F.gelu(x)
        shortcut = x
        if self._shortcut is not None:
            shortcut = self._shortcut(shortcut)
        x = self._conv1_bn_act(x)
        if self._block_type == 'axial':
            x = self._attention(x)
            x = F.gelu(x)
        elif self._block_type == 'bottleneck':
            x = self._conv2_bn_act(x)
        x = self._conv3_bn(x)
        x = self.drop_path(x) + shortcut
        return x
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L42 | |
class BlockGroup(nn.Module):
    """A sequence of ``num_blocks`` ``SingleBlock`` modules of one type.

    Args:
        inplanes: Input channels of the first block. Later blocks take the
            previous block's output channels (``base_filter * 4``).
        base_filter: Base channel count from which the filter list is derived.
        num_blocks: Number of blocks to stack.
        block_type: ``'axial'`` or ``'bottleneck'`` (case-insensitive).
        **kwargs: Forwarded to every ``SingleBlock``.

    Raises:
        ValueError: If ``block_type`` is neither ``'axial'`` nor
            ``'bottleneck'``.
    """

    def __init__(self, inplanes, base_filter, num_blocks, block_type, **kwargs):
        super().__init__()
        self._num_blocks = num_blocks
        block_type = block_type.lower()
        if block_type == 'axial':
            # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L247
            filter_list = [base_filter * 2, base_filter, base_filter * 4]
        elif block_type == 'bottleneck':
            # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L250
            filter_list = [base_filter, base_filter, base_filter * 4]
        else:
            # Previously fell through and left ``filter_list`` unbound.
            raise ValueError(f"Unsupported block_type: {block_type!r} (expected 'axial' or 'bottleneck').")

        self._blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self._blocks.append(SingleBlock(inplanes=inplanes, filter_list=filter_list, block_type=block_type, **kwargs))
            inplanes = filter_list[-1]

    def forward(self, x):
        """Apply the blocks in order."""
        for block in self._blocks:
            x = block(x)
        return x
# https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/layers/resized_fuse.py#L31 | |
class ResizedFuse(nn.Module):
    """Fuse a low-resolution map into a high-resolution one by addition.

    Each input whose channel count differs from ``out_channels`` is first
    passed through GELU and a 1x1 ConvBN projection. The low-resolution map is
    bilinearly resized to the high-resolution spatial size before the sum.

    Args:
        low_in_channels: Channels of the low-resolution input.
        high_in_channels: Channels of the high-resolution input.
        out_channels: Channels of the fused output.
    """

    def __init__(self, low_in_channels, high_in_channels, out_channels):
        super().__init__()
        self.low_in_channels = low_in_channels
        self.high_in_channels = high_in_channels
        self.out_channels = out_channels
        # Projections exist only where a channel change is required.
        if low_in_channels != out_channels:
            self._conv_bn_low = ConvBN(low_in_channels, out_channels, kernel_size=1, bias=False, norm='syncbn', act=None)
        if high_in_channels != out_channels:
            self._conv_bn_high = ConvBN(high_in_channels, out_channels, kernel_size=1, bias=False, norm='syncbn', act=None)

    def forward(self, lowres_x, highres_x):
        """Return the resized ``lowres_x`` plus ``highres_x``."""
        # Odd low-resolution widths are resized with align_corners=True.
        use_align_corners = lowres_x.shape[-1] % 2 == 1
        target_size = highres_x.shape[2:]

        if self.low_in_channels != self.out_channels:
            lowres_x = self._conv_bn_low(F.gelu(lowres_x))
        resized = F.interpolate(lowres_x, size=target_size, mode='bilinear', align_corners=use_align_corners)

        if self.high_in_channels != self.out_channels:
            highres_x = self._conv_bn_high(F.gelu(highres_x))
        return resized + highres_x
class kMaXPixelDecoder(nn.Module):
    """Pixel decoder that refines backbone features from coarse to fine.

    Stages run from the largest-stride input to the smallest. Each stage is a
    ``BlockGroup``, and a ``ResizedFuse`` merges a stage's output with the
    next, higher-resolution, layer-normalized input feature.
    """

    # NOTE(review): this file imports ``configurable`` and
    # ``SEM_SEG_HEADS_REGISTRY`` but applies neither here; the decorators may
    # have been lost from this copy — confirm against the upstream project.

    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        dec_layers: List[int],
        dec_channels: List[int],
        layer_types: List[str],
        drop_path_prob: float,
        spatial_shape: List[int],
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: Feature name -> shape spec; ``stride`` and
                ``channels`` are read from each spec.
            dec_layers: Number of blocks per stage.
            dec_channels: Base filter count per stage.
            layer_types: Only its length is checked; block types are
                hard-coded below.
            drop_path_prob: Accepted but not used; every stage is built with
                ``drop_path_prob=0.0``.
            spatial_shape: ``[height, width]`` used to derive the per-stage
                axial-attention query shapes.
        """
        super().__init__()
        self.num_stages = len(input_shape)
        assert self.num_stages == len(dec_layers) and self.num_stages == len(dec_channels) and self.num_stages == len(layer_types)

        # For now, we hard code all hyper-parameters.
        block_types = ['axial', 'axial', 'bottleneck', 'bottleneck']

        # Largest stride first.
        input_shape = sorted(input_shape.items(), key=lambda x: -x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res5" to "res2"
        in_channels = [v.channels for k, v in input_shape]

        # One extra position per axis when the corresponding input size is odd.
        add_one = (spatial_shape[0] % 2, spatial_shape[1] % 2)
        query_shape = [
            (spatial_shape[0] // stride + add_one[0], spatial_shape[1] // stride + add_one[1])
            for stride in (32, 16, 8, 4)
        ]

        self._in_norms = nn.ModuleList()
        self._stages = nn.ModuleList()
        self._resized_fuses = nn.ModuleList()

        for i in range(self.num_stages):
            self._in_norms.append(LayerNorm(in_channels[i], data_format="channels_first"))
            # Stage 0 consumes the backbone feature directly; later stages
            # consume the fused feature, which has dec_channels[i] channels.
            inplanes = in_channels[i] if i == 0 else dec_channels[i]
            self._stages.append(BlockGroup(inplanes=inplanes,
                                           base_filter=dec_channels[i], num_blocks=dec_layers[i], block_type=block_types[i],
                                           query_shape=query_shape[i], key_expansion=1, value_expansion=2, num_heads=8, drop_path_prob=0.0))

            if i > 0:
                # The previous stage outputs dec_channels[i-1] * 4 channels.
                self._resized_fuses.append(ResizedFuse(
                    low_in_channels=dec_channels[i - 1] * 4,
                    high_in_channels=in_channels[i],
                    out_channels=dec_channels[i]))

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Translate ``cfg`` into the keyword arguments of ``__init__``.

        Declared as a classmethod so it can be called on the class itself,
        as its ``cls`` first parameter implies.
        """
        pixel_dec_cfg = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC
        ret = {}
        ret["input_shape"] = {
            k: v for k, v in input_shape.items() if k in pixel_dec_cfg.IN_FEATURES
        }
        ret["dec_layers"] = pixel_dec_cfg.DEC_LAYERS
        ret["dec_channels"] = pixel_dec_cfg.DEC_CHANNELS
        ret["layer_types"] = pixel_dec_cfg.LAYER_TYPES
        ret["drop_path_prob"] = pixel_dec_cfg.DROP_PATH_PROB
        ret["spatial_shape"] = cfg.INPUT.IMAGE_SIZE  # We expect the height == width
        return ret

    def forward_features(self, features):
        """Decode ``features`` (feature name -> tensor).

        Returns:
            ``(panoptic_features, semantic_features, multi_scale_features)``:
            the last stage's output; the raw input features at
            ``in_features[0]``, ``[2]`` and ``[3]``; and the first three stage
            outputs.
        """
        out = []
        x = self._in_norms[0](features[self.in_features[0]])

        for idx in range(self.num_stages - 1):
            x = self._stages[idx](x)
            out.append(x)
            x = self._resized_fuses[idx](
                lowres_x=x,
                highres_x=self._in_norms[idx + 1](features[self.in_features[idx + 1]]))

        x = self._stages[-1](x)
        out.append(x)

        multi_scale_features = out[:3]  # OS32, 16, 8, they are used for kmax_transformer_decoder.
        panoptic_features = out[-1]  # OS4, it is used for final mask prediction.
        # OS 32, 8, 4
        semantic_features = [features[self.in_features[0]], features[self.in_features[2]], features[self.in_features[3]]]
        return panoptic_features, semantic_features, multi_scale_features