camenduru's picture
thanks to show ❤
3bbb319
raw
history blame contribute delete
No virus
3.99 kB
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmpose.models.backbones import LiteHRNet
from mmpose.models.backbones.litehrnet import LiteHRModule
from mmpose.models.backbones.resnet import Bottleneck
def is_norm(modules):
    """Check if is one of the norms.

    Args:
        modules: Object to test.

    Returns:
        bool: True when ``modules`` is an instance of ``_BatchNorm``
        (or a subclass of it), False otherwise.
    """
    return isinstance(modules, _BatchNorm)
def all_zeros(modules):
    """Check if the weight (and bias, when present) is all zero.

    Args:
        modules: Object exposing a ``weight`` tensor and, optionally, a
            ``bias`` tensor.

    Returns:
        bool: True when every element of ``weight`` is zero and the bias is
        either absent, ``None``, or all zero; False otherwise.
    """
    weight = modules.weight.data
    if not torch.equal(weight, torch.zeros_like(weight)):
        return False

    # A layer built without a bias still carries a ``bias`` attribute set to
    # None, so ``hasattr`` alone cannot tell whether a bias tensor exists.
    bias = getattr(modules, 'bias', None)
    if bias is None:
        return True

    return torch.equal(bias.data, torch.zeros_like(bias.data))
def test_litehrmodule():
    """Check LiteHRModule forward output shapes and module_type validation.

    Builds a single-branch module for each of the 'LITE' and 'NAIVE' module
    types, runs one forward pass and checks the output shape, then checks
    that an unknown module type raises ``ValueError``.
    """
    # Test LiteHRModule forward
    block = LiteHRModule(
        num_branches=1,
        num_blocks=1,
        in_channels=[
            40,
        ],
        reduce_ratio=8,
        module_type='LITE')

    x = torch.randn(2, 40, 56, 56)
    # The 'LITE' module is fed a doubly nested list in this test and its
    # output is indexed twice to reach the tensor.
    x_out = block([[x]])
    assert x_out[0][0].shape == torch.Size([2, 40, 56, 56])

    block = LiteHRModule(
        num_branches=1,
        num_blocks=1,
        in_channels=[
            40,
        ],
        reduce_ratio=8,
        module_type='NAIVE')

    x = torch.randn(2, 40, 56, 56)
    # The 'NAIVE' module is fed a flat list of branch tensors instead.
    x_out = block([x])
    assert x_out[0].shape == torch.Size([2, 40, 56, 56])

    # An unsupported module_type must be rejected at construction time; the
    # instance is never used, so it is not bound to a name.
    with pytest.raises(ValueError):
        LiteHRModule(
            num_branches=1,
            num_blocks=1,
            in_channels=[
                40,
            ],
            reduce_ratio=8,
            module_type='none')
def test_litehrnet_backbone():
    """Check LiteHRNet forward output for 'LITE' and 'NAIVE' stage modules.

    For each module type the same three-stage configuration is used. A
    freshly built model is run on a random batch and its output checked,
    then a second model is built, ``init_weights()`` is called, and the
    forward check is repeated after ``model.train()``.
    """
    expected_shape = torch.Size([2, 40, 56, 56])

    # The two configurations differ only in the module type used by every
    # stage, so they share one code path.
    for module_type in ('LITE', 'NAIVE'):
        extra = dict(
            stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
            num_stages=3,
            stages_spec=dict(
                num_modules=(2, 4, 2),
                num_branches=(2, 3, 4),
                num_blocks=(2, 2, 2),
                module_type=(module_type, ) * 3,
                with_fuse=(True, True, True),
                reduce_ratios=(8, 8, 8),
                num_channels=(
                    (40, 80),
                    (40, 80, 160),
                    (40, 80, 160, 320),
                )),
            with_head=True)

        model = LiteHRNet(extra, in_channels=3)

        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        assert len(feat) == 1
        assert feat[0].shape == expected_shape

        # Test HRNet zero initialization of residual
        model = LiteHRNet(extra, in_channels=3)
        model.init_weights()
        # NOTE(review): this only asserts something if the model contains
        # ResNet ``Bottleneck`` modules; if LiteHRNet has none, the loop
        # passes vacuously -- confirm this check is still meaningful.
        for m in model.modules():
            if isinstance(m, Bottleneck):
                assert all_zeros(m.norm3)

        model.train()

        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        assert len(feat) == 1
        assert feat[0].shape == expected_shape