|
|
|
import pytest |
|
import torch |
|
from torch.nn.modules.batchnorm import _BatchNorm |
|
|
|
from mmpose.models.backbones import LiteHRNet |
|
from mmpose.models.backbones.litehrnet import LiteHRModule |
|
from mmpose.models.backbones.resnet import Bottleneck |
|
|
|
|
|
def is_norm(modules):
    """Check whether a module is a normalization layer.

    Args:
        modules: Module instance to check.

    Returns:
        bool: True if ``modules`` is a batch-norm layer
        (an instance of ``_BatchNorm``).
    """
    # Directly return the isinstance result instead of the
    # `if ...: return True / return False` anti-pattern.
    return isinstance(modules, _BatchNorm)
|
|
|
|
|
def all_zeros(modules):
    """Check if the weight (and bias, if present) of a module is all zero.

    Args:
        modules: Module with a ``weight`` attribute and, optionally, a
            ``bias`` attribute.

    Returns:
        bool: True if every element of ``weight`` (and ``bias`` when it is
        a tensor) equals zero.
    """
    weight_zero = torch.equal(modules.weight.data,
                              torch.zeros_like(modules.weight.data))
    # ``bias`` may exist but be None (e.g. ``Conv2d(..., bias=False)``);
    # a plain hasattr() check would then crash on ``modules.bias.data``.
    # Treat a missing or None bias as trivially zero.
    bias = getattr(modules, 'bias', None)
    if bias is not None:
        bias_zero = torch.equal(bias.data, torch.zeros_like(bias.data))
    else:
        bias_zero = True

    return weight_zero and bias_zero
|
|
|
|
|
def test_litehrmodule():
    """Forward-test LiteHRModule in LITE and NAIVE modes and verify that
    an unknown ``module_type`` raises ``ValueError``."""

    def _build(module_type):
        # Build a single-branch module; a fresh config per call avoids
        # sharing mutable state between instances.
        return LiteHRModule(
            num_branches=1,
            num_blocks=1,
            in_channels=[40],
            reduce_ratio=8,
            module_type=module_type)

    # LITE modules consume a nested list of branch inputs.
    lite_module = _build('LITE')
    feats = torch.randn(2, 40, 56, 56)
    lite_out = lite_module([[feats]])
    assert lite_out[0][0].shape == torch.Size([2, 40, 56, 56])

    # NAIVE modules consume a flat list of branch inputs.
    naive_module = _build('NAIVE')
    feats = torch.randn(2, 40, 56, 56)
    naive_out = naive_module([feats])
    assert naive_out[0].shape == torch.Size([2, 40, 56, 56])

    # Any other module type is rejected at construction time.
    with pytest.raises(ValueError):
        _build('none')
|
|
|
|
|
def _make_litehrnet_extra(module_type):
    """Build a 3-stage LiteHRNet ``extra`` config using ``module_type``
    for every stage."""
    return dict(
        stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
        num_stages=3,
        stages_spec=dict(
            num_modules=(2, 4, 2),
            num_branches=(2, 3, 4),
            num_blocks=(2, 2, 2),
            module_type=(module_type, module_type, module_type),
            with_fuse=(True, True, True),
            reduce_ratios=(8, 8, 8),
            num_channels=(
                (40, 80),
                (40, 80, 160),
                (40, 80, 160, 320),
            )),
        with_head=True)


def test_litehrnet_backbone():
    """Check LiteHRNet forward shapes and weight initialization.

    The original test duplicated the whole body for the LITE and NAIVE
    module types; the two halves are folded into one loop over the
    module type, with the config built by ``_make_litehrnet_extra``.
    """
    for module_type in ('LITE', 'NAIVE'):
        extra = _make_litehrnet_extra(module_type)

        # Forward pass without explicit weight initialization.
        model = LiteHRNet(extra, in_channels=3)
        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        assert len(feat) == 1
        assert feat[0].shape == torch.Size([2, 40, 56, 56])

        # Forward pass after init_weights(); any Bottleneck.norm3 must be
        # zero-initialized (zero-init residual convention).
        model = LiteHRNet(extra, in_channels=3)
        model.init_weights()
        for m in model.modules():
            if isinstance(m, Bottleneck):
                assert all_zeros(m.norm3)
        model.train()

        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        assert len(feat) == 1
        assert feat[0].shape == torch.Size([2, 40, 56, 56])
|
|