import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.models.necks import FPN, ChannelMapper


def _assert_pyramid_shapes(outs, out_channels, base_size):
    """Assert that ``outs`` forms a square feature pyramid.

    Level ``i`` must have ``out_channels`` channels and a spatial size of
    ``base_size // 2**i`` in both height and width.

    Args:
        outs: Sequence of 4-D tensors indexed as (N, C, H, W).
        out_channels: Expected channel count of every level.
        base_size: Expected height/width of the first level.
    """
    for level, out in enumerate(outs):
        expected_size = base_size // 2**level
        assert out.shape[1] == out_channels
        assert out.shape[2] == out.shape[3] == expected_size


def test_fpn():
    """Tests fpn."""
    s = 64
    in_channels = [8, 16, 32, 64]
    feat_sizes = [s // 2**i for i in range(4)]  # [64, 32, 16, 8]
    out_channels = 8
    start_level = 1
    # Inputs below `start_level` are skipped, so the first FPN output has
    # the spatial size of the input feature at `start_level`.
    base_size = feat_sizes[start_level]

    # `num_outs` is not equal to len(in_channels) - start_level
    with pytest.raises(AssertionError):
        FPN(in_channels=in_channels,
            out_channels=out_channels,
            start_level=start_level,
            num_outs=2)

    # `end_level` is larger than len(in_channels) - 1
    with pytest.raises(AssertionError):
        FPN(in_channels=in_channels,
            out_channels=out_channels,
            start_level=start_level,
            end_level=4,
            num_outs=2)

    # `num_outs` is not equal to end_level - start_level
    with pytest.raises(AssertionError):
        FPN(in_channels=in_channels,
            out_channels=out_channels,
            start_level=start_level,
            end_level=3,
            num_outs=1)

    # Invalid `add_extra_convs` option
    with pytest.raises(AssertionError):
        FPN(in_channels=in_channels,
            out_channels=out_channels,
            start_level=start_level,
            add_extra_convs='on_xxx',
            num_outs=5)

    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=start_level,
        add_extra_convs=True,
        num_outs=5)

    # FPN expects multiple levels of features per image
    feats = [
        torch.rand(1, channels, size, size)
        for channels, size in zip(in_channels, feat_sizes)
    ]
    outs = fpn_model(feats)
    assert fpn_model.add_extra_convs == 'on_input'
    assert len(outs) == fpn_model.num_outs
    _assert_pyramid_shapes(outs, out_channels, base_size)

    # Tests for fpn with no extra convs (pooling is used instead)
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=start_level,
        add_extra_convs=False,
        num_outs=5)
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    assert not fpn_model.add_extra_convs
    _assert_pyramid_shapes(outs, out_channels, base_size)

    # Tests for fpn with lateral bns
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=start_level,
        add_extra_convs=True,
        no_norm_on_lateral=False,
        norm_cfg=dict(type='BN', requires_grad=True),
        num_outs=5)
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    assert fpn_model.add_extra_convs == 'on_input'
    _assert_pyramid_shapes(outs, out_channels, base_size)
    # At least one batch-norm layer must have been built.
    assert any(isinstance(m, _BatchNorm) for m in fpn_model.modules())

    # Bilinear upsample
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=start_level,
        add_extra_convs=True,
        upsample_cfg=dict(mode='bilinear', align_corners=True),
        num_outs=5)
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    assert fpn_model.add_extra_convs == 'on_input'
    _assert_pyramid_shapes(outs, out_channels, base_size)

    # Upsample by scale factor instead of to a fixed size
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=start_level,
        add_extra_convs=True,
        upsample_cfg=dict(scale_factor=2),
        num_outs=5)
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    _assert_pyramid_shapes(outs, out_channels, base_size)

    # Extra convs source is 'inputs'
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs='on_input',
        start_level=start_level,
        num_outs=5)
    assert fpn_model.add_extra_convs == 'on_input'
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    _assert_pyramid_shapes(outs, out_channels, base_size)

    # Extra convs source is 'laterals'
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs='on_lateral',
        start_level=start_level,
        num_outs=5)
    assert fpn_model.add_extra_convs == 'on_lateral'
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    _assert_pyramid_shapes(outs, out_channels, base_size)

    # Extra convs source is 'outputs'
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs='on_output',
        start_level=start_level,
        num_outs=5)
    assert fpn_model.add_extra_convs == 'on_output'
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    _assert_pyramid_shapes(outs, out_channels, base_size)

    # extra_convs_on_inputs=False is equal to extra convs source is 'on_output'
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs=True,
        extra_convs_on_inputs=False,
        start_level=start_level,
        num_outs=5,
    )
    assert fpn_model.add_extra_convs == 'on_output'
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    _assert_pyramid_shapes(outs, out_channels, base_size)

    # extra_convs_on_inputs=True is equal to extra convs source is 'on_input'
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs=True,
        extra_convs_on_inputs=True,
        start_level=start_level,
        num_outs=5,
    )
    assert fpn_model.add_extra_convs == 'on_input'
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    _assert_pyramid_shapes(outs, out_channels, base_size)


def test_channel_mapper():
    """Tests ChannelMapper."""
    s = 64
    in_channels = [8, 16, 32, 64]
    feat_sizes = [s // 2**i for i in range(4)]  # [64, 32, 16, 8]
    out_channels = 8
    kernel_size = 3
    feats = [
        torch.rand(1, channels, size, size)
        for channels, size in zip(in_channels, feat_sizes)
    ]

    # in_channels must be a list
    with pytest.raises(AssertionError):
        ChannelMapper(
            in_channels=10,
            out_channels=out_channels,
            kernel_size=kernel_size)

    # the length of channel_mapper's inputs must be equal to the length of
    # in_channels
    with pytest.raises(AssertionError):
        channel_mapper = ChannelMapper(
            in_channels=in_channels[:-1],
            out_channels=out_channels,
            kernel_size=kernel_size)
        channel_mapper(feats)

    channel_mapper = ChannelMapper(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size)

    outs = channel_mapper(feats)
    assert len(outs) == len(feats)
    # Every input level is mapped, so output sizes follow the full pyramid
    # starting at `s`.
    _assert_pyramid_shapes(outs, out_channels, s)