import os
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

from core.models.ops.utils import ShapeSpec
from core.models.ops.utils import c2_xavier_fill

from core.utils import NestedTensor


def _get_activation(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return nn.ReLU()
    elif activation == "gelu":
        return nn.GELU()
    else:
        raise RuntimeError(f"activation should be relu/gelu, not {activation}.")


class MAEdecoder_proj_neck(nn.Module):
    """
    The feature projection neck for the MAE decoder, consisting only of a Linear layer (or a 1x1 Conv
    when ``conv_neck=True``), an optional LayerNorm, and an optional type embedding. The type embedding
    is used when the decoder dimension is smaller than the encoder dimension.

    :param mask_dim: int. output dimension of the projected mask features.
    :param backbone: backbone module; only its ``embed_dim`` attribute is read here.
    :param task_sp_list: list. task-specific parameter list for DDP communication. Default: ().
    :param mask_forward: bool. whether to forward the mask features. Default: True.
    :param modality: str. modality of the input. Default: 'rgb'.
    :param type_embed: bool. whether to use a type embedding. Default: False.
    :param type_embed_zero_init: bool. whether to initialize the type embedding with zeros;
        setting this to True consistently leads to better performance. Default: False.
    :param neck_layernorm: bool. whether to apply LayerNorm in the neck. Default: False.
    :param conv_neck: bool. whether to use a 1x1 Conv2d instead of the Linear projection. Default: False.

    A minimal usage sketch is given at the bottom of this file.
    """
    def __init__(self,
                 mask_dim,
                 backbone,
                 task_sp_list=(),
                 mask_forward=True,
                 modality='rgb',
                 type_embed=False,
                 type_embed_zero_init=False,
                 neck_layernorm=False,
                 conv_neck=False,
                 ):
        super(MAEdecoder_proj_neck, self).__init__()
        self.task_sp_list = task_sp_list
        self.modality = modality

        self.vis_token_dim = self.embed_dim = backbone.embed_dim
        self.mask_dim = mask_dim
        self.neck_layernorm = neck_layernorm
        self.conv_neck = conv_neck

        # Linear projection from the encoder dimension to mask_dim (disabled when mask_dim is 0).
        self.mask_map = nn.Sequential(
            nn.Linear(self.embed_dim, mask_dim, bias=True)
        ) if mask_dim else False
        if self.conv_neck:
            # Replace the Linear projection with a 1x1 convolution applied on the spatial token grid.
            self.mask_map = nn.Sequential(
                nn.Conv2d(self.embed_dim, mask_dim, 1)
            )
        self.neck_ln = nn.LayerNorm(mask_dim) if neck_layernorm else False

        self.mask_forward = mask_forward

        # Optional learnable type embedding forwarded alongside the mask features.
        self.task_embed_decoder = nn.Embedding(1, mask_dim) if type_embed else None
        if type_embed and type_embed_zero_init:
            self.task_embed_decoder.weight.data = torch.zeros(1, mask_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Xavier-initialize Linear weights and zero their biases.
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, features):
        # `features` is a dict-like container; the conv_neck branch additionally relies on
        # attribute-style access (e.g. features.adapter_output_rgb).
        if self.neck_ln and not self.conv_neck:
            # Linear projection followed by LayerNorm on the (B, N, C) token features.
            if self.mask_map and self.mask_forward:
                features.update({f'neck_output_{self.modality}': {
                    'mask_features': self.neck_ln(self.mask_map(features['backbone_output'])),
                    'multi_scale_features': [features['backbone_output']],
                    'task_embed_decoder': self.task_embed_decoder,
                }})
            else:
                features.update({'neck_output': {'mask_features': None,
                                                 'multi_scale_features': [features['backbone_output']]}})
        elif self.conv_neck:
            # Recover the spatial token grid from the adapter output, project it with the
            # 1x1 convolution, then flatten back to a (B, N, mask_dim) token sequence.
            Hp = features.adapter_output_rgb.N_H
            Wp = features.adapter_output_rgb.N_W
            B = features.backbone_output.shape[0]

            proj_feats = self.mask_map(
                features['backbone_output'].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
            ).flatten(2, 3).permute(0, 2, 1)

            if self.neck_ln:
                proj_feats = self.neck_ln(proj_feats)
            if self.mask_map and self.mask_forward:
                features.update({f'neck_output_{self.modality}': {
                    'mask_features': proj_feats,
                    'multi_scale_features': [features['backbone_output']],
                    'task_embed_decoder': self.task_embed_decoder,
                }})
        else:
            # Plain Linear projection without LayerNorm.
            if self.mask_map and self.mask_forward:
                features.update({f'neck_output_{self.modality}': {
                    'mask_features': self.mask_map(features['backbone_output']),
                    'multi_scale_features': [features['backbone_output']],
                    'task_embed_decoder': self.task_embed_decoder,
                }})
            else:
                features.update({'neck_output': {'mask_features': None,
                                                 'multi_scale_features': [features['backbone_output']]}})

        return features
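

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# Assumptions: a dummy backbone exposing only `embed_dim`, and a plain dict
# carrying (B, N, C) token features under 'backbone_output'; the real pipeline
# passes richer adapter outputs, especially when conv_neck=True.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    class _DummyBackbone:  # hypothetical stand-in for the real ViT backbone
        embed_dim = 256

    neck = MAEdecoder_proj_neck(mask_dim=128,
                                backbone=_DummyBackbone(),
                                neck_layernorm=True,
                                type_embed=True,
                                type_embed_zero_init=True)
    feats = {'backbone_output': torch.randn(2, 196, 256)}  # 2 images, 196 tokens, dim 256
    out = neck(feats)
    print(out['neck_output_rgb']['mask_features'].shape)  # expected: torch.Size([2, 196, 128])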