|
import pytest |
|
import torch |
|
from torch.nn.modules.batchnorm import _BatchNorm |
|
|
|
from mmdet.models.necks import FPN, ChannelMapper |
|
|
|
|
|
def _assert_fpn_outs(outs, fpn_model, out_channels, base_size):
    """Assert the count, channel width and spatial size of FPN outputs.

    Args:
        outs: Sequence of 4-D feature maps returned by ``fpn_model``.
        fpn_model: The neck that produced ``outs``; only ``num_outs`` is read.
        out_channels (int): Expected size of dimension 1 of every output.
        base_size (int): Expected height and width of ``outs[0]``. Each
            following level is expected to be half as large as the previous
            one.
    """
    assert len(outs) == fpn_model.num_outs
    for level, out in enumerate(outs):
        assert out.shape[1] == out_channels
        assert out.shape[2] == out.shape[3] == base_size // 2**level


def test_fpn():
    """Tests fpn.

    Covers constructor arguments that must be rejected and the output shapes
    produced with the different ``add_extra_convs``, ``norm_cfg`` and
    ``upsample_cfg`` settings.
    """
    s = 64
    in_channels = [8, 16, 32, 64]
    feat_sizes = [s // 2**i for i in range(4)]
    out_channels = 8
    start_level = 1
    # Every model below is built with start_level=1, so output 0 is expected
    # to have the resolution of feats[start_level], not of feats[0].
    base_size = feat_sizes[start_level]
    common = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=start_level)

    # Each of the following configurations is expected to be rejected at
    # construction time with an AssertionError.

    # num_outs=2 together with the default end_level.
    with pytest.raises(AssertionError):
        FPN(num_outs=2, **common)

    # end_level=4 with only four input levels.
    with pytest.raises(AssertionError):
        FPN(end_level=4, num_outs=2, **common)

    # end_level=3 with num_outs=1.
    with pytest.raises(AssertionError):
        FPN(end_level=3, num_outs=1, **common)

    # Unknown string value for add_extra_convs.
    with pytest.raises(AssertionError):
        FPN(add_extra_convs='on_xxx', num_outs=5, **common)

    feats = [
        torch.rand(1, channels, size, size)
        for channels, size in zip(in_channels, feat_sizes)
    ]

    # add_extra_convs=True is expected to be stored as 'on_input'.
    fpn_model = FPN(add_extra_convs=True, num_outs=5, **common)
    outs = fpn_model(feats)
    assert fpn_model.add_extra_convs == 'on_input'
    _assert_fpn_outs(outs, fpn_model, out_channels, base_size)

    # add_extra_convs=False must stay falsy.
    fpn_model = FPN(add_extra_convs=False, num_outs=5, **common)
    outs = fpn_model(feats)
    assert not fpn_model.add_extra_convs
    _assert_fpn_outs(outs, fpn_model, out_channels, base_size)

    # With a BN norm_cfg at least one batch-norm module must be present.
    fpn_model = FPN(
        add_extra_convs=True,
        no_norm_on_lateral=False,
        norm_cfg=dict(type='BN', requires_grad=True),
        num_outs=5,
        **common)
    outs = fpn_model(feats)
    assert fpn_model.add_extra_convs == 'on_input'
    _assert_fpn_outs(outs, fpn_model, out_channels, base_size)
    assert any(isinstance(m, _BatchNorm) for m in fpn_model.modules())

    # Bilinear upsampling.
    fpn_model = FPN(
        add_extra_convs=True,
        upsample_cfg=dict(mode='bilinear', align_corners=True),
        num_outs=5,
        **common)
    outs = fpn_model(feats)
    assert fpn_model.add_extra_convs == 'on_input'
    _assert_fpn_outs(outs, fpn_model, out_channels, base_size)

    # Upsampling configured through scale_factor.
    fpn_model = FPN(
        add_extra_convs=True,
        upsample_cfg=dict(scale_factor=2),
        num_outs=5,
        **common)
    outs = fpn_model(feats)
    _assert_fpn_outs(outs, fpn_model, out_channels, base_size)

    # Explicit string modes must be stored unchanged.
    for mode in ('on_input', 'on_lateral', 'on_output'):
        fpn_model = FPN(add_extra_convs=mode, num_outs=5, **common)
        assert fpn_model.add_extra_convs == mode
        outs = fpn_model(feats)
        _assert_fpn_outs(outs, fpn_model, out_channels, base_size)

    # extra_convs_on_inputs combined with add_extra_convs=True selects the
    # stored mode.
    for on_inputs, expected_mode in ((False, 'on_output'), (True, 'on_input')):
        fpn_model = FPN(
            add_extra_convs=True,
            extra_convs_on_inputs=on_inputs,
            num_outs=5,
            **common)
        assert fpn_model.add_extra_convs == expected_mode
        outs = fpn_model(feats)
        _assert_fpn_outs(outs, fpn_model, out_channels, base_size)
|
|
|
|
|
def test_channel_mapper():
    """Tests ChannelMapper.

    Checks that invalid usage raises AssertionError and that a valid mapper
    returns one output per input with ``out_channels`` channels and the
    spatial size of the corresponding input.
    """
    s = 64
    in_channels = [8, 16, 32, 64]
    feat_sizes = [s // 2**i for i in range(4)]
    out_channels = 8
    kernel_size = 3
    feats = [
        torch.rand(1, channels, size, size)
        for channels, size in zip(in_channels, feat_sizes)
    ]

    # A bare int for in_channels is expected to be rejected.
    with pytest.raises(AssertionError):
        ChannelMapper(
            in_channels=10, out_channels=out_channels, kernel_size=kernel_size)

    # A mapper built for three levels is expected to reject four inputs.
    # Construction stays outside the `raises` block so that only the forward
    # call can satisfy the expectation.
    channel_mapper = ChannelMapper(
        in_channels=in_channels[:-1],
        out_channels=out_channels,
        kernel_size=kernel_size)
    with pytest.raises(AssertionError):
        channel_mapper(feats)

    channel_mapper = ChannelMapper(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size)

    outs = channel_mapper(feats)
    assert len(outs) == len(feats)
    for out, size in zip(outs, feat_sizes):
        assert out.shape[1] == out_channels
        assert out.shape[2] == out.shape[3] == size
|
|