camenduru's picture
thanks to show ❤
3bbb319
raw
history blame
No virus
4.12 kB
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.models.backbones.csp_darknet import CSPDarknet
from .utils import check_norm_state, is_norm
def test_csp_darknet_backbone():
with pytest.raises(ValueError):
# frozen_stages must in range(-1, len(arch_setting) + 1)
CSPDarknet(frozen_stages=6)
with pytest.raises(AssertionError):
# out_indices in range(len(arch_setting) + 1)
CSPDarknet(out_indices=[6])
# Test CSPDarknet with first stage frozen
frozen_stages = 1
model = CSPDarknet(frozen_stages=frozen_stages)
model.train()
for mod in model.stem.modules():
for param in mod.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f'stage{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False
# Test CSPDarknet with norm_eval=True
model = CSPDarknet(norm_eval=True)
model.train()
assert check_norm_state(model.modules(), False)
# Test CSPDarknet-P5 forward with widen_factor=0.5
model = CSPDarknet(arch='P5', widen_factor=0.25, out_indices=range(0, 5))
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 16, 32, 32))
assert feat[1].shape == torch.Size((1, 32, 16, 16))
assert feat[2].shape == torch.Size((1, 64, 8, 8))
assert feat[3].shape == torch.Size((1, 128, 4, 4))
assert feat[4].shape == torch.Size((1, 256, 2, 2))
# Test CSPDarknet-P6 forward with widen_factor=0.5
model = CSPDarknet(
arch='P6',
widen_factor=0.25,
out_indices=range(0, 6),
spp_kernal_sizes=(3, 5, 7))
model.train()
imgs = torch.randn(1, 3, 128, 128)
feat = model(imgs)
assert feat[0].shape == torch.Size((1, 16, 64, 64))
assert feat[1].shape == torch.Size((1, 32, 32, 32))
assert feat[2].shape == torch.Size((1, 64, 16, 16))
assert feat[3].shape == torch.Size((1, 128, 8, 8))
assert feat[4].shape == torch.Size((1, 192, 4, 4))
assert feat[5].shape == torch.Size((1, 256, 2, 2))
# Test CSPDarknet forward with dict(type='ReLU')
model = CSPDarknet(
widen_factor=0.125, act_cfg=dict(type='ReLU'), out_indices=range(0, 5))
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 8, 32, 32))
assert feat[1].shape == torch.Size((1, 16, 16, 16))
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))
# Test CSPDarknet with BatchNorm forward
model = CSPDarknet(widen_factor=0.125, out_indices=range(0, 5))
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 8, 32, 32))
assert feat[1].shape == torch.Size((1, 16, 16, 16))
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))
# Test CSPDarknet with custom arch forward
arch_ovewrite = [[32, 56, 3, True, False], [56, 224, 2, True, False],
[224, 512, 1, True, False]]
model = CSPDarknet(
arch_ovewrite=arch_ovewrite,
widen_factor=0.25,
out_indices=(0, 1, 2, 3))
model.train()
imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size((1, 8, 16, 16))
assert feat[1].shape == torch.Size((1, 14, 8, 8))
assert feat[2].shape == torch.Size((1, 56, 4, 4))
assert feat[3].shape == torch.Size((1, 128, 2, 2))