Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
import timm | |
from hybridnets.model import BiFPN, Regressor, Classifier, BiFPNDecoder | |
from utils.utils import Anchors | |
from hybridnets.model import SegmentationHead | |
from encoders import get_encoder | |
class HybridNetsBackbone(nn.Module): | |
def __init__(self, num_classes=80, compound_coef=0, seg_classes=1, backbone_name=None, **kwargs): | |
super(HybridNetsBackbone, self).__init__() | |
self.compound_coef = compound_coef | |
self.seg_classes = seg_classes | |
self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6, 7] | |
self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384, 384] | |
self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8, 8] | |
self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536] | |
self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5, 5] | |
self.pyramid_levels = [5, 5, 5, 5, 5, 5, 5, 5, 6] | |
self.anchor_scale = [1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,] | |
self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]) | |
self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])) | |
conv_channel_coef = { | |
# the channels of P3/P4/P5. | |
0: [40, 112, 320], | |
1: [40, 112, 320], | |
2: [48, 120, 352], | |
3: [48, 136, 384], | |
4: [56, 160, 448], | |
5: [64, 176, 512], | |
6: [72, 200, 576], | |
7: [72, 200, 576], | |
8: [80, 224, 640], | |
} | |
num_anchors = len(self.aspect_ratios) * self.num_scales | |
self.bifpn = nn.Sequential( | |
*[BiFPN(self.fpn_num_filters[self.compound_coef], | |
conv_channel_coef[compound_coef], | |
True if _ == 0 else False, | |
attention=True if compound_coef < 6 else False, | |
use_p8=compound_coef > 7) | |
for _ in range(self.fpn_cell_repeats[compound_coef])]) | |
self.num_classes = num_classes | |
self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors, | |
num_layers=self.box_class_repeats[self.compound_coef], | |
pyramid_levels=self.pyramid_levels[self.compound_coef]) | |
'''Modified by Dat Vu''' | |
# self.decoder = DecoderModule() | |
self.bifpndecoder = BiFPNDecoder(pyramid_channels=self.fpn_num_filters[self.compound_coef]) | |
self.segmentation_head = SegmentationHead( | |
in_channels=64, | |
out_channels=self.seg_classes+1 if self.seg_classes > 1 else self.seg_classes, | |
activation='softmax2d' if self.seg_classes > 1 else 'sigmoid', | |
kernel_size=1, | |
upsampling=4, | |
) | |
self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors, | |
num_classes=num_classes, | |
num_layers=self.box_class_repeats[self.compound_coef], | |
pyramid_levels=self.pyramid_levels[self.compound_coef]) | |
self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef], | |
pyramid_levels=(torch.arange(self.pyramid_levels[self.compound_coef]) + 3).tolist(), | |
**kwargs) | |
if backbone_name: | |
# Use timm to create another backbone that you prefer | |
# https://github.com/rwightman/pytorch-image-models | |
self.encoder = timm.create_model(backbone_name, pretrained=True, features_only=True, out_indices=(2,3,4)) # P3,P4,P5 | |
else: | |
# EfficientNet_Pytorch | |
self.encoder = get_encoder( | |
'efficientnet-b' + str(self.backbone_compound_coef[compound_coef]), | |
in_channels=3, | |
depth=5, | |
weights='imagenet', | |
) | |
self.initialize_decoder(self.bifpndecoder) | |
self.initialize_head(self.segmentation_head) | |
self.initialize_decoder(self.bifpn) | |
def freeze_bn(self): | |
for m in self.modules(): | |
if isinstance(m, nn.BatchNorm2d): | |
m.eval() | |
def forward(self, inputs): | |
max_size = inputs.shape[-1] | |
# p1, p2, p3, p4, p5 = self.backbone_net(inputs) | |
p2, p3, p4, p5 = self.encoder(inputs)[-4:] # self.backbone_net(inputs) | |
features = (p3, p4, p5) | |
features = self.bifpn(features) | |
p3,p4,p5,p6,p7 = features | |
outputs = self.bifpndecoder((p2,p3,p4,p5,p6,p7)) | |
segmentation = self.segmentation_head(outputs) | |
regression = self.regressor(features) | |
classification = self.classifier(features) | |
anchors = self.anchors(inputs, inputs.dtype) | |
return features, regression, classification, anchors, segmentation | |
def initialize_decoder(self, module): | |
for m in module.modules(): | |
if isinstance(m, nn.Conv2d): | |
nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") | |
if m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.BatchNorm2d): | |
nn.init.constant_(m.weight, 1) | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.Linear): | |
nn.init.xavier_uniform_(m.weight) | |
if m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
def initialize_head(self, module): | |
for m in module.modules(): | |
if isinstance(m, (nn.Linear, nn.Conv2d)): | |
nn.init.xavier_uniform_(m.weight) | |
if m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |