|
|
|
import pytest |
|
import torch |
|
|
|
from mmpose.models.backbones import SEResNeXt |
|
from mmpose.models.backbones.seresnext import SEBottleneck as SEBottleneckX |
|
|
|
|
|
def test_bottleneck():
    """Exercise SEBottleneckX construction, attributes and a forward pass."""
    # Constructing the block with style='tensorflow' must trip an assertion.
    with pytest.raises(AssertionError):
        SEBottleneckX(64, 64, groups=32, width_per_group=4, style='tensorflow')

    # Grouped configuration: 32 groups, width_per_group 4, stride 2.
    grouped = SEBottleneckX(
        64, 256, groups=32, width_per_group=4, stride=2, style='pytorch')
    assert grouped.width_per_group == 4
    grouped_conv = grouped.conv2
    assert grouped_conv.stride == (2, 2)
    assert grouped_conv.groups == 32
    assert grouped_conv.out_channels == 128
    assert grouped_conv.out_channels == grouped.mid_channels

    # Same arguments but with groups=1.
    ungrouped = SEBottleneckX(
        64, 256, groups=1, width_per_group=4, stride=2, style='pytorch')
    ungrouped_conv = ungrouped.conv2
    assert ungrouped_conv.stride == (2, 2)
    assert ungrouped_conv.groups == 1
    assert ungrouped_conv.out_channels == 64
    assert ungrouped.mid_channels == 64
    assert ungrouped_conv.out_channels == ungrouped.mid_channels

    # Forward pass: the output must have the same shape as the input here.
    forward_block = SEBottleneckX(
        64, 64, base_channels=16, groups=32, width_per_group=4)
    inputs = torch.randn(1, 64, 56, 56)
    outputs = forward_block(inputs)
    expected_shape = torch.Size([1, 64, 56, 56])
    assert outputs.shape == expected_shape
|
|
|
|
|
def test_seresnext():
    """Exercise SEResNeXt construction and forward output shapes."""
    # depth=18 is expected to be rejected with a KeyError.
    with pytest.raises(KeyError):
        SEResNeXt(depth=18)

    # Depth-50 model (groups 32, width_per_group 4) returning all four stages.
    multi_stage = SEResNeXt(
        depth=50, groups=32, width_per_group=4, out_indices=(0, 1, 2, 3))
    bottlenecks = (
        module for module in multi_stage.modules()
        if isinstance(module, SEBottleneckX))
    for bottleneck in bottlenecks:
        assert bottleneck.conv2.groups == 32
    multi_stage.init_weights()
    multi_stage.train()

    batch = torch.randn(1, 3, 224, 224)
    stage_feats = multi_stage(batch)
    assert len(stage_feats) == 4
    expected_shapes = (
        [1, 256, 56, 56],
        [1, 512, 28, 28],
        [1, 1024, 14, 14],
        [1, 2048, 7, 7],
    )
    for stage_feat, expected in zip(stage_feats, expected_shapes):
        assert stage_feat.shape == torch.Size(expected)

    # Same model but only the last stage requested: output is a single tensor.
    last_stage = SEResNeXt(
        depth=50, groups=32, width_per_group=4, out_indices=(3, ))
    bottlenecks = (
        module for module in last_stage.modules()
        if isinstance(module, SEBottleneckX))
    for bottleneck in bottlenecks:
        assert bottleneck.conv2.groups == 32
    last_stage.init_weights()
    last_stage.train()

    batch = torch.randn(1, 3, 224, 224)
    final_feat = last_stage(batch)
    assert final_feat.shape == torch.Size([1, 2048, 7, 7])
|
|