# mvsoom's picture
# Upload folder using huggingface_hub
# 3133fdb
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import itertools
import unittest
from typing import Iterator

import numpy as np
import torch
from pytorchvideo.layers.convolutions import ConvReduce3D
from pytorchvideo.models.stem import (
    create_acoustic_res_basic_stem,
    create_res_basic_stem,
    ResNetBasicStem,
)
from torch import nn


class TestResNetBasicStem(unittest.TestCase):
def setUp(self):
super().setUp()
torch.set_rng_state(torch.manual_seed(42).get_state())
def test_create_simple_stem(self):
"""
Test simple ResNetBasicStem (without pooling layer).
"""
for input_dim, output_dim in itertools.product((2, 3), (4, 8, 16)):
model = ResNetBasicStem(
conv=nn.Conv3d(
input_dim,
output_dim,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
norm=nn.BatchNorm3d(output_dim),
activation=nn.ReLU(),
pool=None,
)
# Test forwarding.
for tensor in TestResNetBasicStem._get_inputs(input_dim):
if tensor.shape[1] != input_dim:
with self.assertRaises(RuntimeError):
output_tensor = model(tensor)
continue
else:
output_tensor = model(tensor)
input_shape = tensor.shape
output_shape = output_tensor.shape
output_shape_gt = (
input_shape[0],
output_dim,
input_shape[2],
input_shape[3],
input_shape[4],
)
self.assertEqual(
output_shape,
output_shape_gt,
"Output shape {} is different from expected shape {}".format(
output_shape, output_shape_gt
),
)
def test_create_stem_with_conv_reduced_3d(self):
"""
Test simple ResNetBasicStem with ConvReduce3D.
"""
for input_dim, output_dim in itertools.product((2, 3), (4, 8, 16)):
model = ResNetBasicStem(
conv=ConvReduce3D(
in_channels=input_dim,
out_channels=output_dim,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias=(False, False),
),
norm=nn.BatchNorm3d(output_dim),
activation=nn.ReLU(),
pool=None,
)
# Test forwarding.
for tensor in TestResNetBasicStem._get_inputs(input_dim):
if tensor.shape[1] != input_dim:
with self.assertRaises(RuntimeError):
output_tensor = model(tensor)
continue
else:
output_tensor = model(tensor)
input_shape = tensor.shape
output_shape = output_tensor.shape
output_shape_gt = (
input_shape[0],
output_dim,
input_shape[2],
input_shape[3],
input_shape[4],
)
self.assertEqual(
output_shape,
output_shape_gt,
"Output shape {} is different from expected shape {}".format(
output_shape, output_shape_gt
),
)
def test_create_complex_stem(self):
"""
Test complex ResNetBasicStem.
"""
for input_dim, output_dim in itertools.product((2, 3), (4, 8, 16)):
model = ResNetBasicStem(
conv=nn.Conv3d(
input_dim,
output_dim,
kernel_size=[3, 7, 7],
stride=[1, 2, 2],
padding=[1, 3, 3],
bias=False,
),
norm=nn.BatchNorm3d(output_dim),
activation=nn.ReLU(),
pool=nn.MaxPool3d(
kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
),
)
# Test forwarding.
for input_tensor in TestResNetBasicStem._get_inputs(input_dim):
if input_tensor.shape[1] != input_dim:
with self.assertRaises(Exception):
output_tensor = model(input_tensor)
continue
else:
output_tensor = model(input_tensor)
input_shape = input_tensor.shape
output_shape = output_tensor.shape
output_shape_gt = (
input_shape[0],
output_dim,
input_shape[2],
(((input_shape[3] - 1) // 2 + 1) - 1) // 2 + 1,
(((input_shape[4] - 1) // 2 + 1) - 1) // 2 + 1,
)
self.assertEqual(
output_shape,
output_shape_gt,
"Output shape {} is different from expected shape {}".format(
output_shape, output_shape_gt
),
)
def test_create_stem_with_callable(self):
"""
Test builder `create_res_basic_stem` with callable inputs.
"""
for (pool, activation, norm) in itertools.product(
(nn.AvgPool3d, nn.MaxPool3d, None),
(nn.ReLU, nn.Softmax, nn.Sigmoid, None),
(nn.BatchNorm3d, None),
):
model = create_res_basic_stem(
in_channels=3,
out_channels=64,
pool=pool,
activation=activation,
norm=norm,
)
model_gt = ResNetBasicStem(
conv=nn.Conv3d(
3,
64,
kernel_size=[3, 7, 7],
stride=[1, 2, 2],
padding=[1, 3, 3],
bias=False,
),
norm=None if norm is None else norm(64),
activation=None if activation is None else activation(),
pool=None
if pool is None
else pool(kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]),
)
model.load_state_dict(
model_gt.state_dict(), strict=True
) # explicitly use strict mode.
# Test forwarding.
for input_tensor in TestResNetBasicStem._get_inputs():
with torch.no_grad():
if input_tensor.shape[1] != 3:
with self.assertRaises(RuntimeError):
output_tensor = model(input_tensor)
continue
else:
output_tensor = model(input_tensor)
output_tensor_gt = model_gt(input_tensor)
self.assertEqual(
output_tensor.shape,
output_tensor_gt.shape,
"Output shape {} is different from expected shape {}".format(
output_tensor.shape, output_tensor_gt.shape
),
)
self.assertTrue(
np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
)
def test_create_acoustic_stem_with_callable(self):
"""
Test builder `create_acoustic_res_basic_stem` with callable
inputs.
"""
for (pool, activation, norm) in itertools.product(
(nn.AvgPool3d, nn.MaxPool3d, None),
(nn.ReLU, nn.Softmax, nn.Sigmoid, None),
(nn.BatchNorm3d, None),
):
model = create_acoustic_res_basic_stem(
in_channels=3,
out_channels=64,
pool=pool,
activation=activation,
norm=norm,
)
model_gt = ResNetBasicStem(
conv=ConvReduce3D(
in_channels=3,
out_channels=64,
kernel_size=((3, 1, 1), (1, 7, 7)),
stride=((1, 1, 1), (1, 1, 1)),
padding=((1, 0, 0), (0, 3, 3)),
bias=(False, False),
),
norm=None if norm is None else norm(64),
activation=None if activation is None else activation(),
pool=None
if pool is None
else pool(kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]),
)
model.load_state_dict(
model_gt.state_dict(), strict=True
) # explicitly use strict mode.
# Test forwarding.
for input_tensor in TestResNetBasicStem._get_inputs():
with torch.no_grad():
if input_tensor.shape[1] != 3:
with self.assertRaises(RuntimeError):
output_tensor = model(input_tensor)
continue
else:
output_tensor = model(input_tensor)
output_tensor_gt = model_gt(input_tensor)
self.assertEqual(
output_tensor.shape,
output_tensor_gt.shape,
"Output shape {} is different from expected shape {}".format(
output_tensor.shape, output_tensor_gt.shape
),
)
self.assertTrue(
np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
)
@staticmethod
def _get_inputs(input_dim: int = 3) -> torch.tensor:
"""
Provide different tensors as test cases.
Yield:
(torch.tensor): tensor as test case input.
"""
# Prepare random tensor as test cases.
shapes = (
# Forward succeeded.
(1, input_dim, 3, 7, 7),
(1, input_dim, 5, 7, 7),
(1, input_dim, 7, 7, 7),
(2, input_dim, 3, 7, 7),
(4, input_dim, 3, 7, 7),
(8, input_dim, 3, 7, 7),
(2, input_dim, 3, 7, 14),
(2, input_dim, 3, 14, 7),
(2, input_dim, 3, 14, 14),
# Forward failed.
(8, input_dim * 2, 3, 7, 7),
(8, input_dim * 4, 5, 7, 7),
)
for shape in shapes:
yield torch.rand(shape)