|
|
|
|
|
import itertools
import unittest
from typing import Iterator

import numpy as np
import torch
from pytorchvideo.layers.convolutions import ConvReduce3D
from pytorchvideo.models.stem import (
    create_acoustic_res_basic_stem,
    create_res_basic_stem,
    ResNetBasicStem,
)
from torch import nn
|
|
|
|
|
class TestResNetBasicStem(unittest.TestCase):
    """
    Tests for ``ResNetBasicStem`` and for the ``create_res_basic_stem`` and
    ``create_acoustic_res_basic_stem`` builders.
    """

    def setUp(self):
        """
        Seed the default torch RNG so every test draws the same random values.
        """
        super().setUp()
        # manual_seed reseeds the default generator in place, so no explicit
        # get_state()/set_rng_state() round trip is needed.
        torch.manual_seed(42)

    def test_create_simple_stem(self):
        """
        Test simple ResNetBasicStem (without pooling layer).
        """
        for input_dim, output_dim in itertools.product((2, 3), (4, 8, 16)):
            with self.subTest(input_dim=input_dim, output_dim=output_dim):
                model = ResNetBasicStem(
                    conv=nn.Conv3d(
                        input_dim,
                        output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False,
                    ),
                    norm=nn.BatchNorm3d(output_dim),
                    activation=nn.ReLU(),
                    pool=None,
                )

                # Kernel 3 / stride 1 / padding 1: every non-channel size is
                # expected to come out unchanged.
                self._check_forward_shapes(model, input_dim, output_dim)

    def test_create_stem_with_conv_reduced_3d(self):
        """
        Test simple ResNetBasicStem with ConvReduce3D.
        """
        for input_dim, output_dim in itertools.product((2, 3), (4, 8, 16)):
            with self.subTest(input_dim=input_dim, output_dim=output_dim):
                model = ResNetBasicStem(
                    conv=ConvReduce3D(
                        in_channels=input_dim,
                        out_channels=output_dim,
                        kernel_size=(3, 3),
                        stride=(1, 1),
                        padding=(1, 1),
                        bias=(False, False),
                    ),
                    norm=nn.BatchNorm3d(output_dim),
                    activation=nn.ReLU(),
                    pool=None,
                )

                # Same expectation as the plain Conv3d stem: only the channel
                # dimension changes.
                self._check_forward_shapes(model, input_dim, output_dim)

    def test_create_complex_stem(self):
        """
        Test complex ResNetBasicStem.
        """
        for input_dim, output_dim in itertools.product((2, 3), (4, 8, 16)):
            with self.subTest(input_dim=input_dim, output_dim=output_dim):
                model = ResNetBasicStem(
                    conv=nn.Conv3d(
                        input_dim,
                        output_dim,
                        kernel_size=[3, 7, 7],
                        stride=[1, 2, 2],
                        padding=[1, 3, 3],
                        bias=False,
                    ),
                    norm=nn.BatchNorm3d(output_dim),
                    activation=nn.ReLU(),
                    pool=nn.MaxPool3d(
                        kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
                    ),
                )

                # Dimension 2 keeps its size (stride 1 in conv and pool) while
                # dimensions 3 and 4 go through two stride-2 stages.
                self._check_forward_shapes(
                    model,
                    input_dim,
                    output_dim,
                    expected_size=self._downsampled_size,
                )

    def test_create_stem_with_callable(self):
        """
        Test builder `create_res_basic_stem` with callable inputs.
        """
        for pool, activation, norm in itertools.product(
            (nn.AvgPool3d, nn.MaxPool3d, None),
            (nn.ReLU, nn.Softmax, nn.Sigmoid, None),
            (nn.BatchNorm3d, None),
        ):
            with self.subTest(pool=pool, activation=activation, norm=norm):
                model = create_res_basic_stem(
                    in_channels=3,
                    out_channels=64,
                    pool=pool,
                    activation=activation,
                    norm=norm,
                )
                # Hand-built stem the builder output is compared against.
                model_gt = ResNetBasicStem(
                    conv=nn.Conv3d(
                        3,
                        64,
                        kernel_size=[3, 7, 7],
                        stride=[1, 2, 2],
                        padding=[1, 3, 3],
                        bias=False,
                    ),
                    norm=None if norm is None else norm(64),
                    activation=None if activation is None else activation(),
                    pool=None
                    if pool is None
                    else pool(
                        kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
                    ),
                )

                self._check_matches_reference(model, model_gt)

    def test_create_acoustic_stem_with_callable(self):
        """
        Test builder `create_acoustic_res_basic_stem` with callable
        inputs.
        """
        for pool, activation, norm in itertools.product(
            (nn.AvgPool3d, nn.MaxPool3d, None),
            (nn.ReLU, nn.Softmax, nn.Sigmoid, None),
            (nn.BatchNorm3d, None),
        ):
            with self.subTest(pool=pool, activation=activation, norm=norm):
                model = create_acoustic_res_basic_stem(
                    in_channels=3,
                    out_channels=64,
                    pool=pool,
                    activation=activation,
                    norm=norm,
                )
                # Hand-built stem the builder output is compared against.
                model_gt = ResNetBasicStem(
                    conv=ConvReduce3D(
                        in_channels=3,
                        out_channels=64,
                        kernel_size=((3, 1, 1), (1, 7, 7)),
                        stride=((1, 1, 1), (1, 1, 1)),
                        padding=((1, 0, 0), (0, 3, 3)),
                        bias=(False, False),
                    ),
                    norm=None if norm is None else norm(64),
                    activation=None if activation is None else activation(),
                    pool=None
                    if pool is None
                    else pool(
                        kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
                    ),
                )

                self._check_matches_reference(model, model_gt)

    def _assert_same_shape(self, output_shape, expected_shape):
        """
        Assert that two shapes are equal, reporting both on failure.

        Args:
            output_shape (tuple): shape produced by the model under test.
            expected_shape (tuple): shape the test expects.
        """
        self.assertEqual(
            output_shape,
            expected_shape,
            f"Output shape {output_shape} is different from expected shape "
            f"{expected_shape}",
        )

    def _check_forward_shapes(self, model, input_dim, output_dim, expected_size=None):
        """
        Forward every tensor from `_get_inputs(input_dim)` through `model`.

        Tensors whose channel count differs from `input_dim` must raise a
        RuntimeError; all others must produce the expected output shape.

        Args:
            model (nn.Module): stem under test.
            input_dim (int): channel count the stem was built for.
            output_dim (int): channel count the stem should produce.
            expected_size (callable): maps the input size of dimensions 3 and 4
                to the expected output size. If None, those sizes are expected
                to be unchanged.
        """
        for input_tensor in self._get_inputs(input_dim):
            if input_tensor.shape[1] != input_dim:
                with self.assertRaises(RuntimeError):
                    model(input_tensor)
                continue

            output_tensor = model(input_tensor)

            input_shape = input_tensor.shape
            size_3, size_4 = input_shape[3], input_shape[4]
            if expected_size is not None:
                size_3, size_4 = expected_size(size_3), expected_size(size_4)

            output_shape_gt = (
                input_shape[0],
                output_dim,
                input_shape[2],
                size_3,
                size_4,
            )
            self._assert_same_shape(output_tensor.shape, output_shape_gt)

    def _check_matches_reference(self, model, model_gt):
        """
        Check that `model` behaves exactly like the reference `model_gt`.

        The reference weights are copied into `model` first, so both modules
        must have identical state-dict keys (`strict=True`). Inputs with a
        channel count other than 3 must raise a RuntimeError.

        Args:
            model (nn.Module): stem returned by a builder function.
            model_gt (nn.Module): hand-constructed stem used as ground truth.
        """
        # Share the weights so outputs are comparable value by value.
        model.load_state_dict(model_gt.state_dict(), strict=True)

        for input_tensor in self._get_inputs():
            with torch.no_grad():
                if input_tensor.shape[1] != 3:
                    with self.assertRaises(RuntimeError):
                        model(input_tensor)
                    continue

                output_tensor = model(input_tensor)
                output_tensor_gt = model_gt(input_tensor)

            self._assert_same_shape(output_tensor.shape, output_tensor_gt.shape)
            self.assertTrue(
                np.allclose(output_tensor.numpy(), output_tensor_gt.numpy()),
                "Builder output values differ from the reference stem output",
            )

    @staticmethod
    def _downsampled_size(size: int) -> int:
        """
        Expected size of dimensions 3 and 4 after the complex stem.

        The conv (kernel 7, stride 2, padding 3) and the pool (kernel 3,
        stride 2, padding 1) each map a size `s` to `(s - 1) // 2 + 1`.

        Args:
            size (int): input size along the dimension.

        Returns:
            (int): size after both stride-2 stages.
        """
        after_conv = (size - 1) // 2 + 1
        return (after_conv - 1) // 2 + 1

    @staticmethod
    def _get_inputs(input_dim: int = 3) -> Iterator[torch.Tensor]:
        """
        Provide different tensors as test cases.

        Args:
            input_dim (int): channel count (dimension 1) of the inputs that a
                stem built for `input_dim` channels should accept.

        Yield:
            (torch.Tensor): random 5-D tensor as test case input.
        """
        shapes = (
            # Channel count matches `input_dim`: callers expect forward to work.
            (1, input_dim, 3, 7, 7),
            (1, input_dim, 5, 7, 7),
            (1, input_dim, 7, 7, 7),
            (2, input_dim, 3, 7, 7),
            (4, input_dim, 3, 7, 7),
            (8, input_dim, 3, 7, 7),
            (2, input_dim, 3, 7, 14),
            (2, input_dim, 3, 14, 7),
            (2, input_dim, 3, 14, 14),
            # Channel count mismatch: callers expect forward to raise.
            (8, input_dim * 2, 3, 7, 7),
            (8, input_dim * 4, 5, 7, 7),
        )
        for shape in shapes:
            yield torch.rand(shape)
|
|