|
|
|
import numpy as np |
|
import pytest |
|
import torch |
|
import torch.nn as nn |
|
|
|
from mmpose.models.backbones import TCN |
|
from mmpose.models.backbones.tcn import BasicTemporalBlock |
|
|
|
|
|
def test_basic_temporal_block():
    """Test input validation and output lengths of ``BasicTemporalBlock``.

    Blocks expected to fail are constructed outside ``pytest.raises`` so that
    only the forward pass can satisfy the ``AssertionError`` expectation: the
    rejected inputs differ only in temporal length, which the constructor
    never sees.
    """
    # Large dilation with a 150-frame input; the forward pass is expected to
    # reject this input (presumably because it is too short for the
    # receptive field).
    block = BasicTemporalBlock(1024, 1024, dilation=81)
    x = torch.rand(2, 1024, 150)
    with pytest.raises(AssertionError):
        block(x)

    # Causal strided block with kernel_size=5 fed a 3-frame input; the
    # forward pass is expected to reject it.
    block = BasicTemporalBlock(
        1024, 1024, kernel_size=5, causal=True, use_stride_conv=True)
    x = torch.rand(2, 1024, 3)
    with pytest.raises(AssertionError):
        block(x)

    # Non-strided blocks: 241 input frames -> 235 output frames for the
    # default, causal and residual-free variants alike.
    for kwargs in (dict(), dict(causal=True), dict(residual=False)):
        block = BasicTemporalBlock(1024, 1024, **kwargs)
        x = torch.rand(2, 1024, 241)
        x_out = block(x)
        assert x_out.shape == torch.Size([2, 1024, 235]), kwargs

    # Strided-conv blocks: 81 input frames -> 27 output frames, both
    # non-causal and causal.
    for kwargs in (dict(), dict(causal=True)):
        block = BasicTemporalBlock(1024, 1024, use_stride_conv=True, **kwargs)
        x = torch.rand(2, 1024, 81)
        x_out = block(x)
        assert x_out.shape == torch.Size([2, 1024, 27]), kwargs
|
|
|
|
|
def _init_conv1d_constant(model, weight=0.5, bias=0):
    """Set every ``nn.Conv1d`` in ``model`` to constant weights and biases."""
    for m in model.modules():
        if isinstance(m, nn.Conv1d):
            nn.init.constant_(m.weight, weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, bias)


def test_tcn_backbone():
    """Test argument validation, output shapes, max-norm weight clipping and
    the equivalence of dilated vs. strided convolutions in ``TCN``."""
    # In the accepted configurations below, len(kernel_sizes) is always
    # num_blocks + 1, so a length mismatch is expected to be rejected.
    with pytest.raises(AssertionError):
        TCN(in_channels=34, num_blocks=3, kernel_sizes=(3, 3, 3))

    # Presumably rejected because of the even kernel size; verify against TCN.
    with pytest.raises(AssertionError):
        TCN(in_channels=34, kernel_sizes=(3, 4, 3))

    # Dilated TCN: one feature map per block, each shorter than the last.
    model = TCN(in_channels=34, num_blocks=2, kernel_sizes=(3, 3, 3))
    pose2d = torch.rand((2, 34, 243))
    feat = model(pose2d)
    assert len(feat) == 2
    assert feat[0].shape == (2, 1024, 235)
    assert feat[1].shape == (2, 1024, 217)

    # Deeper dilated TCN with max-norm weight clipping enabled.
    max_norm = 0.1
    model = TCN(
        in_channels=34,
        num_blocks=4,
        kernel_sizes=(3, 3, 3, 3, 3),
        max_norm=max_norm)
    pose2d = torch.rand((2, 34, 243))
    feat = model(pose2d)
    assert len(feat) == 4
    assert feat[0].shape == (2, 1024, 235)
    assert feat[1].shape == (2, 1024, 217)
    assert feat[2].shape == (2, 1024, 163)
    assert feat[3].shape == (2, 1024, 1)

    # After a forward pass, no conv weight norm may exceed max_norm: the
    # clamped value equals max_norm exactly when norm <= max_norm (up to rtol).
    for module in model.modules():
        if isinstance(module, torch.nn.modules.conv._ConvNd):
            norm = module.weight.norm().item()
            np.testing.assert_allclose(
                np.maximum(norm, max_norm), max_norm, rtol=1e-4)

    # Strided-conv TCN: the temporal length is divided by 3 at every block.
    model = TCN(
        in_channels=34,
        num_blocks=4,
        kernel_sizes=(3, 3, 3, 3, 3),
        use_stride_conv=True)
    pose2d = torch.rand((2, 34, 243))
    feat = model(pose2d)
    assert len(feat) == 4
    assert feat[0].shape == (2, 1024, 27)
    assert feat[1].shape == (2, 1024, 9)
    assert feat[2].shape == (2, 1024, 3)
    assert feat[3].shape == (2, 1024, 1)

    # With dropout, residuals and normalization disabled and identical
    # constant conv weights, the dilated and strided variants are expected
    # to produce the same outputs and the same conv weight gradients.
    model1 = TCN(
        in_channels=34,
        stem_channels=4,
        num_blocks=1,
        kernel_sizes=(3, 3),
        dropout=0,
        residual=False,
        norm_cfg=None)
    model2 = TCN(
        in_channels=34,
        stem_channels=4,
        num_blocks=1,
        kernel_sizes=(3, 3),
        dropout=0,
        residual=False,
        norm_cfg=None,
        use_stride_conv=True)
    _init_conv1d_constant(model1)
    _init_conv1d_constant(model2)

    input1 = torch.rand((1, 34, 9))
    input2 = input1.clone()
    outputs1 = model1(input1)
    outputs2 = model2(input2)
    # zip() silently drops unmatched outputs, so compare the counts first.
    assert len(outputs1) == len(outputs2)
    for output1, output2 in zip(outputs1, outputs2):
        assert torch.isclose(output1, output2).all()

    # Back-propagate an identical MSE loss from the last output of each model
    # and compare the resulting conv weight gradients.
    last1, last2 = outputs1[-1], outputs2[-1]
    criterion = nn.MSELoss()
    target = torch.rand(last1.shape)
    loss1 = criterion(last1, target)
    loss2 = criterion(last2, target)
    loss1.backward()
    loss2.backward()
    for m1, m2 in zip(model1.modules(), model2.modules()):
        if isinstance(m1, nn.Conv1d):
            assert torch.isclose(m1.weight.grad, m2.weight.grad).all()
|
|