# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import itertools
import unittest
import warnings

import torch
from pytorchvideo.models.vision_transformers import (
create_multiscale_vision_transformers,
)


class TestVisionTransformers(unittest.TestCase):
def setUp(self):
super().setUp()
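        # Pin the global RNG state so weight init and fake inputs are
        # reproducible across runs.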
torch.set_rng_state(torch.manual_seed(42).get_state())

    def test_create_mvit(self):
"""
Test MViT.
"""
# Test MViT with 3D case.
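        # `num_head` is the classifier output size passed as head_num_classes,
        # not the number of attention heads.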
num_head = 100
batch_size = 1
fake_input = torch.rand(batch_size, 3, 4, 28, 28)
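        # pool_kv_stride_adaptive sets the initial (T, H, W) key/value pooling
        # stride; per the pytorchvideo docstring it is reduced further at
        # blocks where q is pooled.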
model = create_multiscale_vision_transformers(
spatial_size=28,
temporal_size=4,
patch_embed_dim=12,
depth=1,
head_num_classes=num_head,
pool_kv_stride_adaptive=[1, 2, 2],
)
output = model(fake_input)
gt_shape_tensor = torch.rand(batch_size, num_head)
self.assertEqual(output.shape, gt_shape_tensor.shape)

        # Test MViT with a 3D input and pool_first enabled.
num_head = 100
batch_size = 1
fake_input = torch.rand(batch_size, 3, 4, 28, 28)
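        # Each pool_q_stride_size entry is [block_index, stride_t, stride_h, stride_w].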
model = create_multiscale_vision_transformers(
spatial_size=28,
temporal_size=4,
patch_embed_dim=12,
depth=1,
head_num_classes=num_head,
pool_first=True,
pool_q_stride_size=[[0, 1, 2, 2]],
)
output = model(fake_input)
gt_shape_tensor = torch.rand(batch_size, num_head)
self.assertEqual(output.shape, gt_shape_tensor.shape)

        # Test MViT with a 2D patch embed for images.
conv_patch_kernel = (7, 7)
conv_patch_stride = (4, 4)
conv_patch_padding = (3, 3)
num_head = 100
batch_size = 1
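        # 2D image input (batch, channels, height, width); use_2d_patch
        # requires temporal_size == 1 (asserted below).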
fake_input = torch.rand(batch_size, 3, 28, 28)
model = create_multiscale_vision_transformers(
spatial_size=(28, 28),
temporal_size=1,
patch_embed_dim=12,
depth=1,
head_num_classes=num_head,
use_2d_patch=True,
conv_patch_embed_kernel=conv_patch_kernel,
conv_patch_embed_stride=conv_patch_stride,
conv_patch_embed_padding=conv_patch_padding,
)
output = model(fake_input)
gt_shape_tensor = torch.rand(batch_size, num_head)
self.assertEqual(output.shape, gt_shape_tensor.shape)

        # Test MViT without patch_embed.
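        # Here the input is already tokenized: (batch, seq_len, embed_dim) with
        # seq_len = temporal_size * H * W = 1 * 8 * 1 = 8 and embed_dim = 12.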
num_head = 100
batch_size = 1
fake_input = torch.rand(batch_size, 8, 12)
model = create_multiscale_vision_transformers(
spatial_size=(8, 1),
temporal_size=1,
patch_embed_dim=12,
depth=1,
enable_patch_embed=False,
head_num_classes=num_head,
)
output = model(fake_input)
gt_shape_tensor = torch.rand(batch_size, num_head)
self.assertEqual(output.shape, gt_shape_tensor.shape)
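
        # Misconfigured models should fail fast: a 2D patch embed with
        # temporal_size > 1, adaptive and explicit KV strides together, and an
        # unknown norm layer.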
self.assertRaises(
AssertionError,
create_multiscale_vision_transformers,
spatial_size=28,
temporal_size=4,
use_2d_patch=True,
)
self.assertRaises(
AssertionError,
create_multiscale_vision_transformers,
spatial_size=28,
temporal_size=1,
pool_kv_stride_adaptive=[[2, 2, 2]],
pool_kv_stride_size=[[1, 1, 2, 2]],
)
self.assertRaises(
NotImplementedError,
create_multiscale_vision_transformers,
spatial_size=28,
temporal_size=1,
norm="fakenorm",
)

    def test_mvit_is_torchscriptable(self):
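        """
        Script MViT across binary config combinations and check that the
        scripted model matches the eager model's output.
        """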
batch_size = 2
num_head = 4
spatial_size = (28, 28)
temporal_size = 4
depth = 2
patch_embed_dim = 96
        # The following binary settings are covered by `test_layers_attention.py`:
        # `qkv_bias`, `depthwise_conv`, `separate_qkv`, `bias_on`, `pool_first`,
        # `residual_pool`.
true_false_opts = [
"cls_embed_on",
"sep_pos_embed",
"enable_patch_embed",
"enable_patch_embed_norm",
]
# Loop over `2 ^ len(true_false_opts)` configurations
for true_false_settings in itertools.product(
*([[True, False]] * len(true_false_opts))
):
named_tf_settings = dict(zip(true_false_opts, true_false_settings))
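            # Build eagerly, then script; create_scriptable_model=False keeps
            # the standard (still scriptable) implementation.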
model = create_multiscale_vision_transformers(
spatial_size=spatial_size,
temporal_size=temporal_size,
depth=depth,
head_num_classes=num_head,
patch_embed_dim=patch_embed_dim,
pool_kv_stride_adaptive=[1, 2, 2],
**named_tf_settings,
create_scriptable_model=False,
).eval()
ts_model = torch.jit.script(model)
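            # Without the patch embed the model consumes pre-tokenized input of
            # shape (batch, seq_len, embed_dim).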
input_shape = (
(3, temporal_size, spatial_size[0], spatial_size[1])
if named_tf_settings["enable_patch_embed"]
else (
temporal_size * spatial_size[0] * spatial_size[1],
patch_embed_dim,
)
)
fake_input = torch.rand(batch_size, *input_shape)
expected = model(fake_input)
actual = ts_model(fake_input)
            # assert_close supersedes the deprecated torch.testing.assert_allclose.
            torch.testing.assert_close(actual, expected)

    def test_mvit_create_scriptable_model_is_deprecated(self):
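        """
        create_scriptable_model=True should emit exactly one DeprecationWarning.
        """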
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
create_multiscale_vision_transformers(
spatial_size=28,
temporal_size=4,
norm="batchnorm",
depth=2,
head_num_classes=100,
create_scriptable_model=True,
)
assert len(w) == 1
assert issubclass(w[-1].category, DeprecationWarning)
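

# Allow running this file directly; unittest/pytest discovery also works.
if __name__ == "__main__":
    unittest.main()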