# ForkedHulk2/core/models/necks/simple_fpn.py
import torch
import torch.nn as nn


def _get_activation(activation):
    """Return an activation module given its name as a string."""
    if activation == "relu":
        return nn.ReLU()
    elif activation == "gelu":
        return nn.GELU()
    else:
        raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
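
# Illustrative only (not in the original file): the helper above maps a
# string to an activation module, e.g.
#   _get_activation("relu")  # -> nn.ReLU()
#   _get_activation("gelu")  # -> nn.GELU()
#   _get_activation("silu")  # -> RuntimeError: activation should be relu/gelu, not silu.
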
class MAEdecoder_proj_neck(nn.Module):
"""
the feature projection neck for MAE decoder with only a Linear layer, a LayerNorm (optional) and a type embedding
(optional). The type embedding is used for the decoder with a smaller dimension than that in the encoder.
:param mask_dim: int. the dimension of the mask
:param backbone: placeholder
:param task_sp_list: list. task specific list for DDP communication. Default: ().
:param mask_forward: bool. whether to forward the mask. Default: True.
:param modality: str. the modality of the input. Default: 'rgb'.
:param type_embed: bool. whether to use type embedding. Default: False.
:param type_embed_zero_init: bool. whether to initialize the type embedding with zeros. Always lead to better
performance when True .Default: False.
:param neck_layernorm: bool. whether to use LayerNorm in the neck. Default: False.
"""
def __init__(self,
mask_dim,
backbone, # placeholder
task_sp_list=(),
mask_forward=True,
modality='rgb',
type_embed=False,
type_embed_zero_init=False,
neck_layernorm=False,
conv_neck=False,
):
        super().__init__()
self.task_sp_list = task_sp_list
self.modality = modality
self.vis_token_dim = self.embed_dim = backbone.embed_dim
self.mask_dim = mask_dim
self.neck_layernorm = neck_layernorm
self.conv_neck = conv_neck
        # projector: a 1x1 conv for the v2 detection pipeline, otherwise a single Linear layer
        if self.conv_neck:
            self.mask_map = nn.Sequential(
                nn.Conv2d(self.embed_dim, mask_dim, 1)
            )
        else:
            self.mask_map = nn.Sequential(
                nn.Linear(self.embed_dim, mask_dim, bias=True)
            ) if mask_dim else False
self.neck_ln = nn.LayerNorm(mask_dim) if neck_layernorm else False
self.mask_forward = mask_forward
# the type embedding is used for the decoder with a smaller dimension than that in the encoder
self.task_embed_decoder = nn.Embedding(1, mask_dim) if type_embed else None
if type_embed and type_embed_zero_init:
self.task_embed_decoder.weight.data = torch.zeros(1, mask_dim)
self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # xavier_uniform init, following the official JAX ViT
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    def forward(self, features):
        # For compatibility with UniHCP v1, the neck still uses mask_map as the projector and
        # mask_features as the neck output; this will be deprecated in the future.
        if self.neck_ln and not self.conv_neck:
            # Linear projection followed by LayerNorm
            if self.mask_map and self.mask_forward:
                features.update({f'neck_output_{self.modality}': {
                    'mask_features': self.neck_ln(self.mask_map(features['backbone_output'])),
                    'multi_scale_features': [features['backbone_output']],
                    'task_embed_decoder': self.task_embed_decoder,
                }})
            else:
                features.update({'neck_output': {
                    'mask_features': None,
                    'multi_scale_features': [features['backbone_output']],
                }})
        elif self.conv_neck:
            # only used for v2 detection
            Hp = features['adapter_output_rgb'].N_H
            Wp = features['adapter_output_rgb'].N_W
            B = features['backbone_output'].shape[0]
            # (B, N, C) tokens -> (B, C, Hp, Wp) feature map -> 1x1 conv -> back to (B, N, mask_dim) tokens
            proj_feats = self.mask_map(
                features['backbone_output'].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
            ).flatten(2, 3).permute(0, 2, 1)
            if self.neck_ln:
                proj_feats = self.neck_ln(proj_feats)
            if self.mask_map and self.mask_forward:
                features.update({f'neck_output_{self.modality}': {
                    'mask_features': proj_feats,
                    'multi_scale_features': [features['backbone_output']],
                    'task_embed_decoder': self.task_embed_decoder,
                }})
        else:
            # plain Linear projection without LayerNorm
            if self.mask_map and self.mask_forward:
                features.update({f'neck_output_{self.modality}': {
                    'mask_features': self.mask_map(features['backbone_output']),
                    'multi_scale_features': [features['backbone_output']],
                    'task_embed_decoder': self.task_embed_decoder,
                }})
            else:
                features.update({'neck_output': {
                    'mask_features': None,
                    'multi_scale_features': [features['backbone_output']],
                }})
return features
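

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The backbone stub and
# the token shape below are assumptions for illustration only; in the real
# pipeline, `features` is produced by the backbone/adapter stages.
if __name__ == '__main__':
    class _BackboneStub:
        embed_dim = 768  # hypothetical embedding dimension (ViT-Base style)

    neck = MAEdecoder_proj_neck(mask_dim=256, backbone=_BackboneStub())
    feats = {'backbone_output': torch.randn(2, 196, 768)}  # B=2, 14x14 tokens (assumed)
    out = neck(feats)
    # with default flags, the Linear projector maps each token to mask_dim channels
    print(out['neck_output_rgb']['mask_features'].shape)  # torch.Size([2, 196, 256])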