import torch
from torch import nn
import timm
from hybridnets.model import BiFPN, Regressor, Classifier, BiFPNDecoder
from utils.utils import Anchors
from hybridnets.model import SegmentationHead
from encoders import get_encoder


class HybridNetsBackbone(nn.Module):
    """Multi-task network: anchor-based detection plus dense segmentation.

    Pipeline (as wired in ``forward``):
      encoder -> (P2, P3, P4, P5)
      BiFPN stack over (P3, P4, P5) -> pyramid features
      ``Regressor`` / ``Classifier`` heads over the pyramid features
      ``BiFPNDecoder`` over (P2, pyramid features) -> ``SegmentationHead``

    Args:
        num_classes: number of detection classes given to ``Classifier``.
        compound_coef: index into the per-scale configuration tables built
            in ``__init__`` (the tables have 9 entries, so 0..8).
        seg_classes: number of segmentation classes. When > 1 the head emits
            ``seg_classes + 1`` channels with a softmax2d activation
            (presumably the extra channel is background -- confirm);
            otherwise it emits ``seg_classes`` channels with a sigmoid.
        backbone_name: optional timm model name. When falsy, an EfficientNet
            encoder from ``encoders.get_encoder`` is used instead.
        **kwargs: ``ratios`` and ``scales`` are read for the anchor count;
            the whole dict is also forwarded to ``Anchors``.
    """

    def __init__(self, num_classes=80, compound_coef=0, seg_classes=1, backbone_name=None, **kwargs):
        super(HybridNetsBackbone, self).__init__()
        self.compound_coef = compound_coef
        self.seg_classes = seg_classes

        # Per-scale configuration tables, all indexed by compound_coef.
        self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6, 7]
        self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384, 384]
        self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8, 8]
        self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
        self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5, 5]
        self.pyramid_levels = [5, 5, 5, 5, 5, 5, 5, 5, 6]
        self.anchor_scale = [1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25]
        self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
        self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))

        # The channels of P3/P4/P5 fed into the first BiFPN cell.
        # NOTE(review): this table is keyed only by compound_coef and matches the
        # EfficientNet encoders; a custom timm backbone (backbone_name) likely has
        # different channel counts -- confirm before using one.
        conv_channel_coef = {
            0: [40, 112, 320],
            1: [40, 112, 320],
            2: [48, 120, 352],
            3: [48, 136, 384],
            4: [56, 160, 448],
            5: [64, 176, 512],
            6: [72, 200, 576],
            7: [72, 200, 576],
            8: [80, 224, 640],
        }

        num_anchors = len(self.aspect_ratios) * self.num_scales

        # Only the first cell (index 0) gets the "first time" flag (third
        # positional argument); attention is enabled for compound_coef < 6.
        self.bifpn = nn.Sequential(
            *[BiFPN(self.fpn_num_filters[self.compound_coef],
                    conv_channel_coef[compound_coef],
                    cell_idx == 0,
                    attention=compound_coef < 6,
                    use_p8=compound_coef > 7)
              for cell_idx in range(self.fpn_cell_repeats[compound_coef])])

        self.num_classes = num_classes
        self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef],
                                   num_anchors=num_anchors,
                                   num_layers=self.box_class_repeats[self.compound_coef],
                                   pyramid_levels=self.pyramid_levels[self.compound_coef])

        # Modified by Dat Vu: BiFPN-based decoder feeding the segmentation head.
        self.bifpndecoder = BiFPNDecoder(pyramid_channels=self.fpn_num_filters[self.compound_coef])

        # NOTE(review): in_channels=64 is hard-coded; presumably it matches the
        # BiFPNDecoder output width -- confirm against BiFPNDecoder.
        self.segmentation_head = SegmentationHead(
            in_channels=64,
            out_channels=self.seg_classes + 1 if self.seg_classes > 1 else self.seg_classes,
            activation='softmax2d' if self.seg_classes > 1 else 'sigmoid',
            kernel_size=1,
            upsampling=4,
        )

        self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef],
                                     num_anchors=num_anchors,
                                     num_classes=num_classes,
                                     num_layers=self.box_class_repeats[self.compound_coef],
                                     pyramid_levels=self.pyramid_levels[self.compound_coef])

        # Anchor pyramid levels start at 3 (P3, P4, ...).
        self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef],
                               pyramid_levels=(torch.arange(self.pyramid_levels[self.compound_coef]) + 3).tolist(),
                               **kwargs)

        if backbone_name:
            # Use timm to create another backbone that you prefer
            # https://github.com/rwightman/pytorch-image-models
            # forward() unpacks four maps (P2..P5), so four feature levels
            # must be requested; (2, 3, 4) yielded only three and crashed.
            self.encoder = timm.create_model(backbone_name, pretrained=True, features_only=True,
                                             out_indices=(1, 2, 3, 4))  # P2,P3,P4,P5
        else:
            # EfficientNet_Pytorch
            self.encoder = get_encoder(
                'efficientnet-b' + str(self.backbone_compound_coef[compound_coef]),
                in_channels=3,
                depth=5,
                weights='imagenet',
            )

        # Initialization order is kept fixed so the RNG stream is reproducible.
        self.initialize_decoder(self.bifpndecoder)
        self.initialize_head(self.segmentation_head)
        self.initialize_decoder(self.bifpn)

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` in this model into eval mode.

        A later call to ``self.train()`` switches them back to training mode,
        so call this after ``train()`` if frozen BN statistics are wanted.
        """
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def forward(self, inputs):
        """Run detection and segmentation on a batch of images.

        Args:
            inputs: image tensor; presumably (batch, 3, H, W) given the
                ``in_channels=3`` encoder -- confirm.

        Returns:
            Tuple ``(features, regression, classification, anchors,
            segmentation)`` where ``features`` is the BiFPN output, the next
            three come from the detection heads / anchor generator, and
            ``segmentation`` is the segmentation head output.
        """
        # The last four encoder outputs are used as P2..P5.
        p2, p3, p4, p5 = self.encoder(inputs)[-4:]

        features = self.bifpn((p3, p4, p5))
        # NOTE(review): this unpack assumes the BiFPN returns exactly five
        # levels; with compound_coef=8 (use_p8=True, 6 pyramid levels) it
        # presumably returns six and would fail here -- confirm.
        p3, p4, p5, p6, p7 = features

        # The segmentation branch additionally uses the higher-resolution P2.
        outputs = self.bifpndecoder((p2, p3, p4, p5, p6, p7))
        segmentation = self.segmentation_head(outputs)

        regression = self.regressor(features)
        classification = self.classifier(features)
        anchors = self.anchors(inputs, inputs.dtype)

        return features, regression, classification, anchors, segmentation

    def initialize_decoder(self, module):
        """Initialize the weights of every submodule of ``module`` in place.

        - ``nn.Conv2d``: Kaiming-uniform (fan_in, relu) weights, zero bias.
        - ``nn.BatchNorm2d``: weight 1, bias 0.
        - ``nn.Linear``: Xavier-uniform weights, zero bias.
        """
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def initialize_head(self, module):
        """Xavier-uniform init for every ``nn.Linear``/``nn.Conv2d`` in ``module``; zero biases."""
        for m in module.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)