import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

from .seecoder_utils import with_pos_embed
from lib.model_zoo.common.get_model import get_model, register

symbol = 'seecoder'

###########
# helpers #
###########

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

def c2_xavier_fill(module):
    # Caffe2 implementation of XavierFill in fact
    nn.init.kaiming_uniform_(module.weight, a=1)
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)

def with_pos_embed(x, pos):
    return x if pos is None else x + pos

###########
# Modules #
###########

class Conv2d_Convenience(nn.Conv2d):
    """nn.Conv2d with optional norm and activation applied after the convolution."""
    def __init__(self, *args, **kwargs):
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)
        self.norm = norm
        self.activation = activation

    def forward(self, x):
        x = F.conv2d(
            x, self.weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

class DecoderLayer(nn.Module):
    def __init__(self,
                 dim=256,
                 feedforward_dim=1024,
                 dropout=0.1,
                 activation="relu",
                 n_heads=8,):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.linear1 = nn.Linear(dim, feedforward_dim)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(feedforward_dim, dim)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # self-attention block with residual connection
        h = x
        h1 = self.self_attn(x, x, x, attn_mask=None)[0]
        h = h + self.dropout1(h1)
        h = self.norm1(h)
        # feed-forward block with residual connection
        h2 = self.linear2(self.dropout2(self.activation(self.linear1(h))))
        h = h + self.dropout3(h2)
        h = self.norm2(h)
        return h

class DecoderLayerStacked(nn.Module):
    def __init__(self, layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, x):
        h = x
        for layer in self.layers:
            h = layer(h)
        if self.norm is not None:
            h = self.norm(h)
        return h

class SelfAttentionLayer(nn.Module):
    def __init__(self,
                 channels,
                 nhead,
                 dropout=0.0,
                 activation="relu",
                 normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(channels, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(channels)
        self.dropout = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward_post(self,
                     qkv,
                     qk_pos=None,
                     mask=None,):
        h = qkv
        # nn.MultiheadAttention expects (seq, batch, channel) by default
        qk = with_pos_embed(qkv, qk_pos).transpose(0, 1)
        v = qkv.transpose(0, 1)
        h1 = self.self_attn(qk, qk, v, attn_mask=mask)[0]
        h1 = h1.transpose(0, 1)
        h = h + self.dropout(h1)
        h = self.norm(h)
        return h

    def forward_pre(self,
                    tgt,
                    tgt_mask=None,
                    tgt_key_padding_mask=None,
                    query_pos=None):
        # deprecated
        assert False
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, *args, **kwargs):
        if self.normalize_before:
            return self.forward_pre(*args, **kwargs)
        return self.forward_post(*args, **kwargs)
class CrossAttentionLayer(nn.Module):
    def __init__(self,
                 channels,
                 nhead,
                 dropout=0.0,
                 activation="relu",
                 normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(channels, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(channels)
        self.dropout = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward_post(self,
                     q,
                     kv,
                     q_pos=None,
                     k_pos=None,
                     mask=None,):
        h = q
        # nn.MultiheadAttention expects (seq, batch, channel) by default
        q = with_pos_embed(q, q_pos).transpose(0, 1)
        k = with_pos_embed(kv, k_pos).transpose(0, 1)
        v = kv.transpose(0, 1)
        h1 = self.multihead_attn(q, k, v, attn_mask=mask)[0]
        h1 = h1.transpose(0, 1)
        h = h + self.dropout(h1)
        h = self.norm(h)
        return h

    def forward_pre(self,
                    tgt,
                    memory,
                    memory_mask=None,
                    memory_key_padding_mask=None,
                    pos=None,
                    query_pos=None):
        # Deprecated
        assert False
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, *args, **kwargs):
        if self.normalize_before:
            return self.forward_pre(*args, **kwargs)
        return self.forward_post(*args, **kwargs)

class FeedForwardLayer(nn.Module):
    def __init__(self,
                 channels,
                 hidden_channels=2048,
                 dropout=0.0,
                 activation="relu",
                 normalize_before=False):
        super().__init__()
        self.linear1 = nn.Linear(channels, hidden_channels)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_channels, channels)
        self.norm = nn.LayerNorm(channels)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward_post(self, x):
        h = x
        h1 = self.linear2(self.dropout(self.activation(self.linear1(h))))
        h = h + self.dropout(h1)
        h = self.norm(h)
        return h

    def forward_pre(self, x):
        xn = self.norm(x)
        h = x
        h1 = self.linear2(self.dropout(self.activation(self.linear1(xn))))
        h = h + self.dropout(h1)
        return h

    def forward(self, *args, **kwargs):
        if self.normalize_before:
            return self.forward_pre(*args, **kwargs)
        return self.forward_post(*args, **kwargs)

class MLP(nn.Module):
    def __init__(self, in_channels, channels, out_channels, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [channels] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([in_channels]+h, h+[out_channels]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

class PPE_MLP(nn.Module):
    def __init__(self, freq_num=20, freq_max=None, out_channel=768, mlp_layer=3):
        import math
        super().__init__()
        self.freq_num = freq_num
        self.freq_max = freq_max
        self.out_channel = out_channel
        self.mlp_layer = mlp_layer
        self.twopi = 2 * math.pi

        mlp = []
        in_channel = freq_num*4
        for idx in range(mlp_layer):
            linear = nn.Linear(in_channel, out_channel, bias=True)
            nn.init.xavier_normal_(linear.weight)
            nn.init.constant_(linear.bias, 0)
            mlp.append(linear)
            if idx != mlp_layer-1:
                mlp.append(nn.SiLU())
            in_channel = out_channel
        self.mlp = nn.Sequential(*mlp)
        nn.init.constant_(self.mlp[-1].weight, 0)

    def forward(self, x, mask=None):
        assert mask is None, "Mask not implemented"
        h, w = x.shape[-2:]
        minlen = min(h, w)

        h_embed, w_embed = torch.meshgrid(
            torch.arange(h), torch.arange(w), indexing='ij')

        if self.training:
            import numpy.random as npr
            pertube_h, pertube_w = npr.uniform(-0.5, 0.5), npr.uniform(-0.5, 0.5)
        else:
            pertube_h, pertube_w = 0, 0

        h_embed = (h_embed+0.5 - h/2 + pertube_h) / (minlen) * self.twopi
        w_embed = (w_embed+0.5 - w/2 + pertube_w) / (minlen) * self.twopi
        h_embed = h_embed.to(x.device).to(x.dtype)
        w_embed = w_embed.to(x.device).to(x.dtype)

        dim_t = torch.linspace(0, 1, self.freq_num, dtype=torch.float32, device=x.device)
        freq_max = self.freq_max if self.freq_max is not None else minlen/2
        dim_t = freq_max ** dim_t.to(x.dtype)

        pos_h = h_embed[:, :, None] * dim_t
        pos_w = w_embed[:, :, None] * dim_t
        pos = torch.cat((pos_h.sin(), pos_h.cos(), pos_w.sin(), pos_w.cos()), dim=-1)
        pos = self.mlp(pos)
        pos = pos.permute(2, 0, 1)[None]
        return pos

    def __repr__(self, _repr_indent=4):
        head = "Positional encoding " + self.__class__.__name__
        body = [
            "freq_num: {}".format(self.freq_num),
            "freq_max: {}".format(self.freq_max),
            "out_channel: {}".format(self.out_channel),
            "mlp_layer: {}".format(self.mlp_layer),
        ]
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
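# Note: PPE_MLP maps a feature map of shape (B, C, H, W) to a positional encoding of
# shape (1, out_channel, H, W); the random half-pixel grid perturbation is applied
# only in training mode. For example (illustrative values), calling
# PPE_MLP(freq_num=20, out_channel=256) on a (1, 256, 16, 16) tensor returns a
# (1, 256, 16, 16) positional encoding.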
###########
# Decoder #
###########

@register('seecoder_decoder')
class Decoder(nn.Module):
    def __init__(
            self,
            inchannels,
            trans_input_tags,
            trans_num_layers,
            trans_dim,
            trans_nheads,
            trans_dropout,
            trans_feedforward_dim,):
        super().__init__()
        trans_inchannels = {
            k: v for k, v in inchannels.items() if k in trans_input_tags}
        fpn_inchannels = {
            k: v for k, v in inchannels.items() if k not in trans_input_tags}

        self.trans_tags = sorted(list(trans_inchannels.keys()))
        self.fpn_tags = sorted(list(fpn_inchannels.keys()))
        self.all_tags = sorted(list(inchannels.keys()))
        if len(self.trans_tags) == 0:
            assert False  # Not allowed
        self.num_trans_lvls = len(self.trans_tags)

        # 1x1 projections bringing the transformer levels to trans_dim channels
        self.inproj_layers = nn.ModuleDict()
        for tagi in self.trans_tags:
            layeri = nn.Sequential(
                nn.Conv2d(trans_inchannels[tagi], trans_dim, kernel_size=1),
                nn.GroupNorm(32, trans_dim),)
            nn.init.xavier_uniform_(layeri[0].weight, gain=1)
            nn.init.constant_(layeri[0].bias, 0)
            self.inproj_layers[tagi] = layeri

        tlayer = DecoderLayer(
            dim=trans_dim,
            n_heads=trans_nheads,
            dropout=trans_dropout,
            feedforward_dim=trans_feedforward_dim,
            activation='relu',)
        self.transformer = DecoderLayerStacked(tlayer, trans_num_layers)
        for p in self.transformer.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        self.level_embed = nn.Parameter(torch.Tensor(len(self.trans_tags), trans_dim))
        nn.init.normal_(self.level_embed)

        self.lateral_layers = nn.ModuleDict()
        self.output_layers = nn.ModuleDict()
        for tagi in self.all_tags:
            lateral_conv = Conv2d_Convenience(
                inchannels[tagi], trans_dim, kernel_size=1,
                bias=False, norm=nn.GroupNorm(32, trans_dim))
            c2_xavier_fill(lateral_conv)
            self.lateral_layers[tagi] = lateral_conv

        for tagi in self.fpn_tags:
            output_conv = Conv2d_Convenience(
                trans_dim, trans_dim, kernel_size=3, stride=1, padding=1,
                bias=False, norm=nn.GroupNorm(32, trans_dim), activation=F.relu,)
            c2_xavier_fill(output_conv)
            self.output_layers[tagi] = output_conv

    def forward(self, features):
        # Flatten the transformer levels (highest level first), add level embeddings,
        # and run the stacked decoder layers over the concatenated tokens.
        x = []
        spatial_shapes = {}
        for idx, tagi in enumerate(self.trans_tags[::-1]):
            xi = features[tagi]
            xi = self.inproj_layers[tagi](xi)
            bs, _, h, w = xi.shape
            spatial_shapes[tagi] = (h, w)
            xi = xi.flatten(2).transpose(1, 2) + self.level_embed[idx].view(1, 1, -1)
            x.append(xi)

        x_length = [xi.shape[1] for xi in x]
        x_concat = torch.cat(x, 1)
        y_concat = self.transformer(x_concat)
        y = torch.split(y_concat, x_length, dim=1)

        # Fold the token sequences back into 2D feature maps
        out = {}
        for idx, tagi in enumerate(self.trans_tags[::-1]):
            h, w = spatial_shapes[tagi]
            yi = y[idx].transpose(1, 2).view(bs, -1, h, w)
            out[tagi] = yi

        # Add lateral features; remaining levels are fused top-down, FPN-style
        for idx, tagi in enumerate(self.all_tags[::-1]):
            lconv = self.lateral_layers[tagi]
            if tagi in self.trans_tags:
                out[tagi] = out[tagi] + lconv(features[tagi])
                tag_save = tagi
            else:
                oconv = self.output_layers[tagi]
                h = lconv(features[tagi])
                oprev = out[tag_save]
                h = h + F.interpolate(
                    oconv(oprev), size=h.shape[-2:],
                    mode="bilinear", align_corners=False)
                out[tagi] = h
        return out
#####################
# Query Transformer #
#####################

@register('seecoder_query_transformer')
class QueryTransformer(nn.Module):
    def __init__(
            self,
            in_channels,
            hidden_dim,
            num_queries=[8, 144],
            nheads=8,
            num_layers=9,
            feedforward_dim=2048,
            mask_dim=256,
            pre_norm=False,
            num_feature_levels=3,
            enforce_input_project=False,
            with_fea2d_pos=True):
        super().__init__()

        if with_fea2d_pos:
            self.pe_layer = PPE_MLP(
                freq_num=20, freq_max=None, out_channel=hidden_dim, mlp_layer=3)
        else:
            self.pe_layer = None

        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj = nn.ModuleList()
            for _ in range(num_feature_levels):
                self.input_proj.append(nn.Conv2d(in_channels, hidden_dim, kernel_size=1))
                c2_xavier_fill(self.input_proj[-1])
        else:
            self.input_proj = None

        self.num_heads = nheads
        self.num_layers = num_layers
        self.transformer_selfatt_layers = nn.ModuleList()
        self.transformer_crossatt_layers = nn.ModuleList()
        self.transformer_feedforward_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_selfatt_layers.append(
                SelfAttentionLayer(
                    channels=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,))
            self.transformer_crossatt_layers.append(
                CrossAttentionLayer(
                    channels=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,))
            self.transformer_feedforward_layers.append(
                FeedForwardLayer(
                    channels=hidden_dim,
                    hidden_channels=feedforward_dim,
                    dropout=0.0,
                    normalize_before=pre_norm,))

        self.num_queries = num_queries
        num_gq, num_lq = self.num_queries
        # Learnable global and local queries plus their positional embeddings
        self.init_query = nn.Embedding(num_gq+num_lq, hidden_dim)
        self.query_pos_embedding = nn.Embedding(num_gq+num_lq, hidden_dim)

        self.num_feature_levels = num_feature_levels
        self.level_embed = nn.Embedding(num_feature_levels, hidden_dim)

    def forward(self, x):
        # x is a list of multi-scale features
        assert len(x) == self.num_feature_levels
        fea2d = []
        fea2d_pos = []
        size_list = []

        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            if self.pe_layer is not None:
                pi = self.pe_layer(x[i], None).flatten(2)
                pi = pi.transpose(1, 2)
            else:
                pi = None
            xi = self.input_proj[i](x[i]) if self.input_proj is not None else x[i]
            xi = xi.flatten(2) + self.level_embed.weight[i][None, :, None]
            xi = xi.transpose(1, 2)
            fea2d.append(xi)
            fea2d_pos.append(pi)

        bs, _, _ = fea2d[0].shape
        num_gq, num_lq = self.num_queries
        gquery = self.init_query.weight[:num_gq].unsqueeze(0).repeat(bs, 1, 1)
        lquery = self.init_query.weight[num_gq:].unsqueeze(0).repeat(bs, 1, 1)
        gquery_pos = self.query_pos_embedding.weight[:num_gq].unsqueeze(0).repeat(bs, 1, 1)
        lquery_pos = self.query_pos_embedding.weight[num_gq:].unsqueeze(0).repeat(bs, 1, 1)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # local queries cross-attend to one feature level per layer
            qout = self.transformer_crossatt_layers[i](
                q=lquery,
                kv=fea2d[level_index],
                q_pos=lquery_pos,
                k_pos=fea2d_pos[level_index],
                mask=None,)
            lquery = qout

            # global and local queries jointly self-attend, then pass the feed-forward layer
            qout = self.transformer_selfatt_layers[i](
                qkv=torch.cat([gquery, lquery], dim=1),
                qk_pos=torch.cat([gquery_pos, lquery_pos], dim=1),)
            qout = self.transformer_feedforward_layers[i](qout)
            gquery = qout[:, :num_gq]
            lquery = qout[:, num_gq:]

        output = torch.cat([gquery, lquery], dim=1)
        return output
##################
# Main structure #
##################

@register('seecoder')
class SemanticExtractionEncoder(nn.Module):
    def __init__(self, imencoder_cfg, imdecoder_cfg, qtransformer_cfg):
        super().__init__()
        self.imencoder = get_model()(imencoder_cfg)
        self.imdecoder = get_model()(imdecoder_cfg)
        self.qtransformer = get_model()(qtransformer_cfg)

    def forward(self, x):
        fea = self.imencoder(x)
        hs = {'res3': fea['res3'],
              'res4': fea['res4'],
              'res5': fea['res5'], }
        hs = self.imdecoder(hs)
        hs = [hs['res3'], hs['res4'], hs['res5']]
        q = self.qtransformer(hs)
        return q

    def encode(self, x):
        return self(x)
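if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch). The tags, channel counts, query counts,
    # and feature-map sizes below are example values, not values from any released
    # SeeCoder config. It assumes the @register decorator returns the class unchanged
    # and that the imports at the top of this file resolve (run as a module, not a script).
    decoder = Decoder(
        inchannels={'res3': 512, 'res4': 1024, 'res5': 2048},
        trans_input_tags=['res3', 'res4', 'res5'],
        trans_num_layers=2,
        trans_dim=256,
        trans_nheads=8,
        trans_dropout=0.1,
        trans_feedforward_dim=1024,)
    features = {
        'res3': torch.randn(1, 512, 32, 32),
        'res4': torch.randn(1, 1024, 16, 16),
        'res5': torch.randn(1, 2048, 8, 8), }
    out = decoder(features)

    qtransformer = QueryTransformer(
        in_channels=256,
        hidden_dim=256,
        num_queries=[4, 144],
        nheads=8,
        num_layers=3,
        feedforward_dim=1024,)
    queries = qtransformer([out['res3'], out['res4'], out['res5']])
    print(queries.shape)  # expected: torch.Size([1, 148, 256])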