# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmpose.models.backbones import TCN
from mmpose.models.backbones.tcn import BasicTemporalBlock
def test_basic_temporal_block():
    """Smoke-test ``BasicTemporalBlock``: invalid configs must raise, and
    every valid configuration must shrink the temporal axis as expected."""
    # padding (plus causal shift) must not exceed the temporal length
    with pytest.raises(AssertionError):
        block = BasicTemporalBlock(1024, 1024, dilation=81)
        block(torch.rand(2, 1024, 150))

    # with use_stride_conv, shift + kernel_size // 2 must not exceed
    # the temporal length either
    with pytest.raises(AssertionError):
        block = BasicTemporalBlock(
            1024, 1024, kernel_size=5, causal=True, use_stride_conv=True)
        block(torch.rand(2, 1024, 3))

    # (extra kwargs, input temporal length, expected output temporal length)
    configs = [
        (dict(), 241, 235),                                # causal == False
        (dict(causal=True), 241, 235),                     # causal == True
        (dict(residual=False), 241, 235),                  # residual == False
        (dict(use_stride_conv=True), 81, 27),              # strided conv
        (dict(use_stride_conv=True, causal=True), 81, 27), # strided + causal
    ]
    for kwargs, t_in, t_out in configs:
        block = BasicTemporalBlock(1024, 1024, **kwargs)
        out = block(torch.rand(2, 1024, t_in))
        assert out.shape == torch.Size([2, 1024, t_out])
def test_tcn_backbone():
    """Smoke-test the ``TCN`` backbone: invalid configs must raise, output
    shapes must match, weight-norm clipping must hold, and the strided and
    non-strided variants must agree on outputs and gradients."""
    # num_blocks must equal len(kernel_sizes) - 1
    with pytest.raises(AssertionError):
        TCN(in_channels=34, num_blocks=3, kernel_sizes=(3, 3, 3))

    # every kernel size must be odd
    with pytest.raises(AssertionError):
        TCN(in_channels=34, kernel_sizes=(3, 4, 3))

    # 2 blocks, regular (dilated) convolutions
    model = TCN(in_channels=34, num_blocks=2, kernel_sizes=(3, 3, 3))
    feats = model(torch.rand((2, 34, 243)))
    assert len(feats) == 2
    assert feats[0].shape == (2, 1024, 235)
    assert feats[1].shape == (2, 1024, 217)

    # 4 blocks with weight-norm clipping
    max_norm = 0.1
    model = TCN(
        in_channels=34,
        num_blocks=4,
        kernel_sizes=(3, 3, 3, 3, 3),
        max_norm=max_norm)
    feats = model(torch.rand((2, 34, 243)))
    assert len(feats) == 4
    for feat, length in zip(feats, (235, 217, 163, 1)):
        assert feat.shape == (2, 1024, length)
    # every conv weight norm should have been clipped down to max_norm
    for module in model.modules():
        if isinstance(module, torch.nn.modules.conv._ConvNd):
            norm = module.weight.norm().item()
            np.testing.assert_allclose(
                np.maximum(norm, max_norm), max_norm, rtol=1e-4)

    # 4 blocks, strided convolutions
    model = TCN(
        in_channels=34,
        num_blocks=4,
        kernel_sizes=(3, 3, 3, 3, 3),
        use_stride_conv=True)
    feats = model(torch.rand((2, 34, 243)))
    assert len(feats) == 4
    for feat, length in zip(feats, (27, 9, 3, 1)):
        assert feat.shape == (2, 1024, length)

    # Check that the model w. or w/o use_stride_conv produces the same
    # output and gradients after a forward+backward pass
    common_cfg = dict(
        in_channels=34,
        stem_channels=4,
        num_blocks=1,
        kernel_sizes=(3, 3),
        dropout=0,
        residual=False,
        norm_cfg=None)
    model1 = TCN(**common_cfg)
    model2 = TCN(use_stride_conv=True, **common_cfg)

    def _init_constant(model):
        # deterministic, identical weights so the two variants are comparable
        for m in model.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.constant_(m.weight, 0.5)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    _init_constant(model1)
    _init_constant(model2)

    inp1 = torch.rand((1, 34, 9))
    inp2 = inp1.clone()
    outs1 = model1(inp1)
    outs2 = model2(inp2)
    criterion = nn.MSELoss()
    for out1, out2 in zip(outs1, outs2):
        assert torch.isclose(out1, out2).all()
        # same random target for both losses so gradients are comparable
        target = torch.rand(out1.shape)
        criterion(out1, target).backward()
        criterion(out2, target).backward()
    for m1, m2 in zip(model1.modules(), model2.modules()):
        if isinstance(m1, nn.Conv1d):
            assert torch.isclose(m1.weight.grad, m2.weight.grad).all()