# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import unittest
from collections import Counter
from itertools import permutations

import numpy as np
import torch
from pytorchvideo.data.utils import thwc_to_cthw
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    AugMix,
    create_video_transform,
    CutMix,
    MixUp,
    MixVideo,
    Normalize,
    OpSampler,
    Permute,
    RandAugment,
    RandomResizedCrop,
    RandomShortSideScale,
    ShortSideScale,
    UniformCropVideo,
    UniformTemporalSubsample,
)
from pytorchvideo.transforms.functional import (
    clip_boxes_to_image,
    convert_to_one_hot,
    div_255,
    horizontal_flip_with_boxes,
    random_crop_with_boxes,
    random_short_side_scale_with_boxes,
    short_side_scale,
    short_side_scale_with_boxes,
    uniform_crop,
    uniform_crop_with_boxes,
    uniform_temporal_subsample,
    uniform_temporal_subsample_repeated,
)
from torchvision.transforms import Compose
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
    RandomCropVideo,
    RandomHorizontalFlipVideo,
)
from utils import create_dummy_video_frames, create_random_bbox


class TestTransforms(unittest.TestCase):
    def test_compose_with_video_transforms(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        test_clip = {"video": video, "label": 0}

        # Compose using torchvision and pytorchvideo transforms to ensure they
        # interact correctly.
        num_subsample = 10
        transform = Compose(
            [
                ApplyTransformToKey(
                    key="video",
                    transform=Compose(
                        [
                            UniformTemporalSubsample(num_subsample),
                            NormalizeVideo([video.mean()] * 3, [video.std()] * 3),
                            RandomShortSideScale(min_size=15, max_size=25),
                            RandomCropVideo(10),
                            RandomHorizontalFlipVideo(p=0.5),
                        ]
                    ),
                )
            ]
        )
        actual = transform(test_clip)
        c, t, h, w = actual["video"].shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, 10)
        self.assertEqual(w, 10)

    def test_uniform_temporal_subsample(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        actual = uniform_temporal_subsample(video, video.shape[1])
        self.assertTrue(actual.equal(video))

        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        actual = uniform_temporal_subsample(video, video.shape[1] // 2)
        self.assertTrue(actual.equal(video[:, [0, 2, 4, 6, 8, 10, 12, 14, 16, 19]]))

        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        actual = uniform_temporal_subsample(video, 1)
        self.assertTrue(actual.equal(video[:, 0:1]))

    def test_short_side_scale_width_shorter_pytorch(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 20, 10)).to(
            dtype=torch.float32
        )
        actual = short_side_scale(video, 5, backend="pytorch")
        self.assertEqual(actual.shape, (3, 20, 10, 5))

    def test_short_side_scale_height_shorter_pytorch(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 10, 20)).to(
            dtype=torch.float32
        )
        actual = short_side_scale(video, 5, backend="pytorch")
        self.assertEqual(actual.shape, (3, 20, 5, 10))

    def test_short_side_scale_equal_size_pytorch(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 10, 10)).to(
            dtype=torch.float32
        )
        actual = short_side_scale(video, 10, backend="pytorch")
        self.assertEqual(actual.shape, (3, 20, 10, 10))

    def test_short_side_scale_width_shorter_opencv(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 20, 10)).to(
            dtype=torch.float32
        )
        actual = short_side_scale(video, 5, backend="opencv")
        self.assertEqual(actual.shape, (3, 20, 10, 5))

    def test_short_side_scale_height_shorter_opencv(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 10, 20)).to(
            dtype=torch.float32
        )
        actual = short_side_scale(video, 5, backend="opencv")
        self.assertEqual(actual.shape, (3, 20, 5, 10))

    def test_short_side_scale_equal_size_opencv(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 10, 10)).to(
            dtype=torch.float32
        )
        actual = short_side_scale(video, 10, backend="opencv")
        self.assertEqual(actual.shape, (3, 20, 10, 10))

    def test_random_short_side_scale_height_shorter_pytorch_with_boxes(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 10, 20)).to(
            dtype=torch.float32
        )
        boxes = create_random_bbox(7, 10, 20)
        actual, scaled_boxes = random_short_side_scale_with_boxes(
            video, min_size=4, max_size=8, backend="pytorch", boxes=boxes
        )
        self.assertEqual(actual.shape[0], 3)
        self.assertEqual(actual.shape[1], 20)
        self.assertTrue(actual.shape[2] <= 8 and actual.shape[2] >= 4)
        self._check_boxes(7, actual.shape[2], actual.shape[3], boxes)

    def test_short_side_scale_height_shorter_pytorch_with_boxes(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 10, 20)).to(
            dtype=torch.float32
        )
        boxes = create_random_bbox(7, 10, 20)
        actual, scaled_boxes = short_side_scale_with_boxes(
            video,
            boxes=boxes,
            size=5,
            backend="pytorch",
        )
        self.assertEqual(actual.shape, (3, 20, 5, 10))
        self._check_boxes(7, 5, 10, boxes)

    def test_torchscriptable_input_output(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )

        # Test all the torchscriptable tensors.
        for transform in [UniformTemporalSubsample(10), RandomShortSideScale(10, 20)]:
            transform_script = torch.jit.script(transform)
            self.assertTrue(isinstance(transform_script, torch.jit.ScriptModule))

            # Seed before each transform to force determinism.
            torch.manual_seed(0)
            output = transform(video)
            torch.manual_seed(0)
            script_output = transform_script(video)
            self.assertTrue(output.equal(script_output))

    def test_uniform_temporal_subsample_repeated(self):
        video = thwc_to_cthw(create_dummy_video_frames(32, 10, 10)).to(
            dtype=torch.float32
        )
        actual = uniform_temporal_subsample_repeated(video, (1, 4))
        expected_shape = ((3, 32, 10, 10), (3, 8, 10, 10))
        for idx in range(len(actual)):
            self.assertEqual(actual[idx].shape, expected_shape[idx])

    def test_uniform_crop(self):
        # For videos with height < width.
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        # Left crop.
        actual = uniform_crop(video, size=20, spatial_idx=0)
        self.assertTrue(actual.equal(video[:, :, 5:25, :20]))
        # Center crop.
        actual = uniform_crop(video, size=20, spatial_idx=1)
        self.assertTrue(actual.equal(video[:, :, 5:25, 10:30]))
        # Right crop.
        actual = uniform_crop(video, size=20, spatial_idx=2)
        self.assertTrue(actual.equal(video[:, :, 5:25, 20:]))

        # For videos with height > width.
        video = thwc_to_cthw(create_dummy_video_frames(20, 40, 30)).to(
            dtype=torch.float32
        )
        # Top crop.
        actual = uniform_crop(video, size=20, spatial_idx=0)
        self.assertTrue(actual.equal(video[:, :, :20, 5:25]))
        # Center crop.
        actual = uniform_crop(video, size=20, spatial_idx=1)
        self.assertTrue(actual.equal(video[:, :, 10:30, 5:25]))
        # Bottom crop.
        actual = uniform_crop(video, size=20, spatial_idx=2)
        self.assertTrue(actual.equal(video[:, :, 20:, 5:25]))

    def test_uniform_crop_with_boxes(self):
        # For videos with height < width.
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        boxes_inp = create_random_bbox(7, 30, 40)
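        # spatial_idx selects the crop position along the longer side:
        # 0 / 1 / 2 map to left / center / right for landscape clips and to
        # top / center / bottom for portrait clips (see test_uniform_crop above).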
        # Left crop.
        actual, boxes = uniform_crop_with_boxes(
            video, size=20, spatial_idx=0, boxes=boxes_inp
        )
        self.assertTrue(actual.equal(video[:, :, 5:25, :20]))
        self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
        # Center crop.
        actual, boxes = uniform_crop_with_boxes(
            video, size=20, spatial_idx=1, boxes=boxes_inp
        )
        self.assertTrue(actual.equal(video[:, :, 5:25, 10:30]))
        self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
        # Right crop.
        actual, boxes = uniform_crop_with_boxes(
            video, size=20, spatial_idx=2, boxes=boxes_inp
        )
        self.assertTrue(actual.equal(video[:, :, 5:25, 20:]))
        self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)

        # For videos with height > width.
        video = thwc_to_cthw(create_dummy_video_frames(20, 40, 30)).to(
            dtype=torch.float32
        )
        # Top crop.
        actual, boxes = uniform_crop_with_boxes(
            video, size=20, spatial_idx=0, boxes=boxes_inp
        )
        self.assertTrue(actual.equal(video[:, :, :20, 5:25]))
        self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
        # Center crop.
        actual, boxes = uniform_crop_with_boxes(
            video, size=20, spatial_idx=1, boxes=boxes_inp
        )
        self.assertTrue(actual.equal(video[:, :, 10:30, 5:25]))
        self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
        # Bottom crop.
        actual, boxes = uniform_crop_with_boxes(
            video, size=20, spatial_idx=2, boxes=boxes_inp
        )
        self.assertTrue(actual.equal(video[:, :, 20:, 5:25]))
        self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)

    def test_random_crop_with_boxes(self):
        # For videos with height < width.
        video = thwc_to_cthw(create_dummy_video_frames(15, 30, 40)).to(
            dtype=torch.float32
        )
        boxes_inp = create_random_bbox(7, 30, 40)
        actual, boxes = random_crop_with_boxes(video, size=20, boxes=boxes_inp)
        self.assertEqual(actual.shape, (3, 15, 20, 20))
        self._check_boxes(7, actual.shape[2], actual.shape[3], boxes)

    def test_uniform_crop_transform(self):
        video = thwc_to_cthw(create_dummy_video_frames(10, 30, 40)).to(
            dtype=torch.float32
        )
        test_clip = {"video": video, "aug_index": 1, "label": 0}

        transform = UniformCropVideo(20)
        actual = transform(test_clip)
        c, t, h, w = actual["video"].shape
        self.assertEqual(c, 3)
        self.assertEqual(t, 10)
        self.assertEqual(h, 20)
        self.assertEqual(w, 20)
        self.assertTrue(actual["video"].equal(video[:, :, 5:25, 10:30]))

    def test_clip_boxes(self):
        boxes_inp = create_random_bbox(7, 40, 80)
        clipped_boxes = clip_boxes_to_image(boxes_inp, 20, 40)
        self._check_boxes(7, 20, 40, clipped_boxes)

    def test_horizontal_flip_with_boxes(self):
        video = thwc_to_cthw(create_dummy_video_frames(10, 20, 40)).to(
            dtype=torch.float32
        )
        boxes_inp = create_random_bbox(7, 20, 40)

        actual, boxes = horizontal_flip_with_boxes(0.0, video, boxes_inp)
        self.assertTrue(actual.equal(video))
        self.assertTrue(boxes.equal(boxes_inp))

        actual, boxes = horizontal_flip_with_boxes(1.0, video, boxes_inp)
        self.assertEqual(actual.shape, video.shape)
        self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
        self.assertTrue(actual.flip((-1)).equal(video))

    def test_normalize(self):
        video = thwc_to_cthw(create_dummy_video_frames(10, 30, 40)).to(
            dtype=torch.float32
        )
        transform = Normalize(video.mean(), video.std())
        actual = transform(video)
        self.assertAlmostEqual(actual.mean().item(), 0)
        self.assertAlmostEqual(actual.std().item(), 1)

    def test_center_crop(self):
        video = thwc_to_cthw(create_dummy_video_frames(10, 30, 40)).to(
            dtype=torch.float32
        )
        transform = CenterCropVideo(10)
        actual = transform(video)
        c, t, h, w = actual.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, 10)
        self.assertEqual(h, 10)
        self.assertEqual(w, 10)
        self.assertTrue(actual.equal(video[:, :, 10:20, 15:25]))

    def test_convert_to_one_hot(self):
        # Test without label smoothing.
        num_class = 5
        num_samples = 10
        labels = torch.arange(0, num_samples) % num_class
        one_hot = convert_to_one_hot(labels, num_class)
        self.assertEqual(one_hot.sum(), num_samples)
        label_value = 1.0
        for index in range(num_samples):
            label = labels[index]
            self.assertEqual(one_hot[index][label], label_value)

        # Test with label smoothing.
        labels = torch.arange(0, num_samples) % num_class
        label_smooth = 0.1
        one_hot_smooth = convert_to_one_hot(
            labels, num_class, label_smooth=label_smooth
        )
        self.assertEqual(one_hot_smooth.sum(), num_samples)
        label_value_smooth = 1 - label_smooth + label_smooth / num_class
        for index in range(num_samples):
            label = labels[index]
            self.assertEqual(one_hot_smooth[index][label], label_value_smooth)

    def test_OpSampler(self):
        # Test with equal weights.
        n_transform = 3
        transform_list = [lambda x, i=i: x.fill_(i) for i in range(n_transform)]
        transform_weight = [1] * n_transform
        transform = OpSampler(transform_list, transform_weight)
        input_tensor = torch.rand(1)
        out_tensor = transform(input_tensor)
        self.assertTrue(out_tensor.sum() in list(range(n_transform)))

        # Test without weights.
        input_tensor = torch.rand(1)
        transform_no_weight = OpSampler(transform_list)
        out_tensor = transform_no_weight(input_tensor)
        self.assertTrue(out_tensor.sum() in list(range(n_transform)))

        # Make sure each transform is sampled without replacement.
        transform_op_values = [3, 5, 7]
        all_possible_out = [15, 21, 35]
        transform_list = [lambda x, i=i: x * i for i in transform_op_values]
        test_time = 100
        transform_no_replacement = OpSampler(transform_list, num_sample_op=2)
        for _ in range(test_time):
            input_tensor = torch.ones(1)
            out_tensor = transform_no_replacement(input_tensor)
            self.assertTrue(out_tensor.sum() in all_possible_out)

        # Make sure transforms can be sampled with replacement.
        transform_op_values = [3, 5, 7]
        possible_replacement_out = [9, 25, 49]
        input_tensor = torch.ones(1)
        transform_list = [lambda x, i=i: x * i for i in transform_op_values]
        test_time = 100
        transform_with_replacement = OpSampler(
            transform_list, replacement=True, num_sample_op=2
        )
        replace_time = 0
        for _ in range(test_time):
            input_tensor = torch.ones(1)
            out_tensor = transform_with_replacement(input_tensor)
            if out_tensor.sum() in possible_replacement_out:
                replace_time += 1
        self.assertTrue(replace_time > 0)

        # Make sure sampling frequencies follow the given weights.
        transform_op_values = [3.0, 5.0, 7.0]
        input_tensor = torch.ones(1)
        transform_list = [lambda x, i=i: x * i for i in transform_op_values]
        test_time = 10000
        weights = [10.0, 2.0, 1.0]
        transform_weighted = OpSampler(transform_list, weights)
        weight_counter = Counter()
        for _ in range(test_time):
            input_tensor = torch.ones(1)
            out_tensor = transform_weighted(input_tensor)
            weight_counter[out_tensor.sum().item()] += 1

        for index, w in enumerate(weights):
            gt_dis = w / sum(weights)
            out_key = transform_op_values[index]
            self.assertTrue(
                np.allclose(weight_counter[out_key] / test_time, gt_dis, rtol=0.2)
            )

    def test_mixup(self):
        # Test images.
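        # With one all-zeros and one all-ones sample, MixUp's convex
        # combination lam * x_i + (1 - lam) * x_j keeps the total pixel sum
        # over the pair at h * w * c for any lam, which is asserted below.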
        batch_size = 2
        h_size = 10
        w_size = 10
        c_size = 3
        input_images = torch.rand(batch_size, c_size, h_size, w_size)
        input_images[0, :].fill_(0)
        input_images[1, :].fill_(1)

        alpha = 1.0
        label_smoothing = 0.0
        num_classes = 5
        transform_mixup = MixUp(
            alpha=alpha,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
        )
        labels = torch.arange(0, batch_size) % num_classes
        mixed_images, mixed_labels = transform_mixup(input_images, labels)

        gt_image_sum = h_size * w_size * c_size
        label_sum = batch_size
        self.assertTrue(
            np.allclose(mixed_images.sum().item(), gt_image_sum, rtol=0.001)
        )
        self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
        self.assertEqual(mixed_labels.size(0), batch_size)
        self.assertEqual(mixed_labels.size(1), num_classes)

        # Test videos.
        batch_size = 2
        h_size = 10
        w_size = 10
        c_size = 3
        t_size = 2
        input_video = torch.rand(batch_size, c_size, t_size, h_size, w_size)
        input_video[0, :].fill_(0)
        input_video[1, :].fill_(1)

        alpha = 1.0
        label_smoothing = 0.0
        num_classes = 5
        transform_mixup = MixUp(
            alpha=alpha,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
        )
        labels = torch.arange(0, batch_size) % num_classes
        mixed_videos, mixed_labels = transform_mixup(input_video, labels)

        gt_video_sum = h_size * w_size * c_size * t_size
        label_sum = batch_size
        self.assertTrue(
            np.allclose(mixed_videos.sum().item(), gt_video_sum, rtol=0.001)
        )
        self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
        self.assertEqual(mixed_labels.size(0), batch_size)
        self.assertEqual(mixed_labels.size(1), num_classes)

        # Test videos with label smoothing.
        input_video = torch.rand(batch_size, c_size, t_size, h_size, w_size)
        input_video[0, :].fill_(0)
        input_video[1, :].fill_(1)

        alpha = 1.0
        label_smoothing = 0.2
        num_classes = 5
        transform_mixup = MixUp(
            alpha=alpha,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
        )
        labels = torch.arange(0, batch_size) % num_classes
        mixed_videos, mixed_labels = transform_mixup(input_video, labels)

        gt_video_sum = h_size * w_size * c_size * t_size
        label_sum = batch_size
        self.assertTrue(
            np.allclose(mixed_videos.sum().item(), gt_video_sum, rtol=0.001)
        )
        self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
        self.assertEqual(mixed_labels.size(0), batch_size)
        self.assertEqual(mixed_labels.size(1), num_classes)
        # Check the smoothing value is in the label.
        smooth_value = label_smoothing / num_classes
        self.assertTrue(smooth_value in torch.unique(mixed_labels))

    def test_cutmix(self):
        torch.manual_seed(0)

        # Test images.
        batch_size = 2
        h_size = 10
        w_size = 10
        c_size = 3
        input_images = torch.rand(batch_size, c_size, h_size, w_size)
        input_images[0, :].fill_(0)
        input_images[1, :].fill_(1)

        alpha = 1.0
        label_smoothing = 0.0
        num_classes = 5
        transform_cutmix = CutMix(
            alpha=alpha,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
        )
        labels = torch.arange(0, batch_size) % num_classes
        mixed_images, mixed_labels = transform_cutmix(input_images, labels)

        gt_image_sum = h_size * w_size * c_size
        label_sum = batch_size
        self.assertTrue(
            np.allclose(mixed_images.sum().item(), gt_image_sum, rtol=0.001)
        )
        self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
        self.assertEqual(mixed_labels.size(0), batch_size)
        self.assertEqual(mixed_labels.size(1), num_classes)

        # Test videos.
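        # CutMix pastes a box from the all-ones clip into the all-zeros clip
        # and vice versa, so (as in the image case above) the total pixel sum
        # over the pair is preserved regardless of the sampled box.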
        batch_size = 2
        h_size = 10
        w_size = 10
        c_size = 3
        t_size = 2
        input_video = torch.rand(batch_size, c_size, t_size, h_size, w_size)
        input_video[0, :].fill_(0)
        input_video[1, :].fill_(1)

        alpha = 1.0
        label_smoothing = 0.0
        num_classes = 5
        transform_cutmix = CutMix(
            alpha=alpha,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
        )
        labels = torch.arange(0, batch_size) % num_classes
        mixed_videos, mixed_labels = transform_cutmix(input_video, labels)

        gt_video_sum = h_size * w_size * c_size * t_size
        label_sum = batch_size
        self.assertTrue(
            np.allclose(mixed_videos.sum().item(), gt_video_sum, rtol=0.001)
        )
        self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
        self.assertEqual(mixed_labels.size(0), batch_size)
        self.assertEqual(mixed_labels.size(1), num_classes)

        # Test videos with label smoothing.
        input_video = torch.rand(batch_size, c_size, t_size, h_size, w_size)
        input_video[0, :].fill_(0)
        input_video[1, :].fill_(1)

        alpha = 1.0
        label_smoothing = 0.2
        num_classes = 5
        transform_cutmix = CutMix(
            alpha=alpha,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
        )
        labels = torch.arange(0, batch_size) % num_classes
        mixed_videos, mixed_labels = transform_cutmix(input_video, labels)

        gt_video_sum = h_size * w_size * c_size * t_size
        label_sum = batch_size
        self.assertTrue(
            np.allclose(mixed_videos.sum().item(), gt_video_sum, rtol=0.001)
        )
        self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
        self.assertEqual(mixed_labels.size(0), batch_size)
        self.assertEqual(mixed_labels.size(1), num_classes)
        # Check the smoothing value is in the label.
        smooth_value = label_smoothing / num_classes
        self.assertTrue(smooth_value in torch.unique(mixed_labels))

        # Check the cutmixed video has both 0 and 1.
        # Run 20 times to avoid rare cases where the random box is empty.
        test_times = 20
        seen_all_value1 = False
        seen_all_value2 = False
        for _ in range(test_times):
            mixed_videos, mixed_labels = transform_cutmix(input_video, labels)
            if 0 in mixed_videos[0, :] and 1 in mixed_videos[0, :]:
                seen_all_value1 = True
            if 0 in mixed_videos[1, :] and 1 in mixed_videos[1, :]:
                seen_all_value2 = True
            if seen_all_value1 and seen_all_value2:
                break
        self.assertTrue(seen_all_value1)
        self.assertTrue(seen_all_value2)

    def test_mixvideo(self):
        self.assertRaises(AssertionError, MixVideo, cutmix_prob=2.0)

        torch.manual_seed(0)

        # Test images.
        batch_size = 2
        h_size = 10
        w_size = 10
        c_size = 3
        input_images = torch.rand(batch_size, c_size, h_size, w_size)
        input_images[0, :].fill_(0)
        input_images[1, :].fill_(1)

        mixup_alpha = 1.0
        cutmix_alpha = 1.0
        label_smoothing = 0.0
        num_classes = 5
        transform_mix = MixVideo(
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
        )
        labels = torch.arange(0, batch_size) % num_classes
        mixed_images, mixed_labels = transform_mix(input_images, labels)

        gt_image_sum = h_size * w_size * c_size
        label_sum = batch_size
        self.assertTrue(
            np.allclose(mixed_images.sum().item(), gt_image_sum, rtol=0.001)
        )
        self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
        self.assertEqual(mixed_labels.size(0), batch_size)
        self.assertEqual(mixed_labels.size(1), num_classes)

        # Test videos.
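        # MixVideo dispatches each call to either CutMix or MixUp (chosen
        # with probability cutmix_prob, validated above); both preserve the
        # pair's pixel-sum invariant, so the same assertions apply below.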
        batch_size = 2
        h_size = 10
        w_size = 10
        c_size = 3
        t_size = 2
        input_video = torch.rand(batch_size, c_size, t_size, h_size, w_size)
        input_video[0, :].fill_(0)
        input_video[1, :].fill_(1)

        mixup_alpha = 1.0
        cutmix_alpha = 1.0
        label_smoothing = 0.0
        num_classes = 5
        transform_mix = MixVideo(
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
        )
        labels = torch.arange(0, batch_size) % num_classes
        mixed_videos, mixed_labels = transform_mix(input_video, labels)

        gt_video_sum = h_size * w_size * c_size * t_size
        label_sum = batch_size
        self.assertTrue(
            np.allclose(mixed_videos.sum().item(), gt_video_sum, rtol=0.001)
        )
        self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
        self.assertEqual(mixed_labels.size(0), batch_size)
        self.assertEqual(mixed_labels.size(1), num_classes)

    def _check_boxes(self, num_boxes, height, width, boxes):
        self.assertEqual(boxes.shape, (num_boxes, 4))
        self.assertTrue(boxes[:, [0, 2]].min() >= 0 and boxes[:, [0, 2]].max() < width)
        self.assertTrue(
            boxes[:, [1, 3]].min() >= 0 and boxes[:, [1, 3]].max() < height
        )

    def test_randaug(self):
        # Test default RandAugment.
        t, c, h, w = 8, 3, 200, 200
        test_time = 20
        video_tensor = torch.rand(t, c, h, w)
        video_rand_aug_fn = RandAugment()
        for _ in range(test_time):
            video_tensor_aug = video_rand_aug_fn(video_tensor)
            self.assertTrue(video_tensor.size() == video_tensor_aug.size())
            self.assertTrue(video_tensor.dtype == video_tensor_aug.dtype)
            # Make sure the video is in range.
            self.assertTrue(video_tensor_aug.max().item() <= 1)
            self.assertTrue(video_tensor_aug.min().item() >= 0)

        # Test RandAugment with uniform sampling.
        t, c, h, w = 8, 3, 200, 200
        test_time = 20
        video_tensor = torch.rand(t, c, h, w)
        video_rand_aug_fn = RandAugment(sampling_type="uniform")
        for _ in range(test_time):
            video_tensor_aug = video_rand_aug_fn(video_tensor)
            self.assertTrue(video_tensor.size() == video_tensor_aug.size())
            self.assertTrue(video_tensor.dtype == video_tensor_aug.dtype)
            # Make sure the video is in range.
            self.assertTrue(video_tensor_aug.max().item() <= 1)
            self.assertTrue(video_tensor_aug.min().item() >= 0)

        # Test if the default fill color is found.
        # Test multiple times due to randomness.
        t, c, h, w = 8, 3, 200, 200
        test_time = 40
        video_tensor = torch.ones(t, c, h, w)
        video_rand_aug_fn = RandAugment(
            num_layers=1,
            prob=1,
            sampling_type="gaussian",
        )
        found_fill_color = 0
        for _ in range(test_time):
            video_tensor_aug = video_rand_aug_fn(video_tensor)
            if 0.5 in video_tensor_aug:
                found_fill_color += 1
        self.assertTrue(found_fill_color >= 1)

    def test_random_resized_crop(self):
        # Test default parameters.
        crop_size = 10
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        transform = RandomResizedCrop(
            target_height=crop_size,
            target_width=crop_size,
            scale=(0.08, 1.0),
            aspect_ratio=(3.0 / 4.0, 4.0 / 3.0),
        )

        video_resized = transform(video)
        c, t, h, w = video_resized.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, 20)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)
        self.assertEqual(video_resized.dtype, torch.float32)

        # Test reversed parameters.
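        # The transform should accept scale and aspect_ratio bounds given in
        # reversed (max, min) order; this exercises that path, with shift=True.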
        crop_size = 29
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        transform = RandomResizedCrop(
            target_height=crop_size,
            target_width=crop_size,
            scale=(1.8, 0.08),
            aspect_ratio=(4.0 / 3.0, 3.0 / 4.0),
            shift=True,
        )

        video_resized = transform(video)
        c, t, h, w = video_resized.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, 20)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)
        self.assertEqual(video_resized.dtype, torch.float32)

        # Test one channel.
        crop_size = 10
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        transform = RandomResizedCrop(
            target_height=crop_size,
            target_width=crop_size,
            scale=(1.8, 1.2),
            aspect_ratio=(4.0 / 3.0, 3.0 / 4.0),
        )

        video_resized = transform(video[0:1, :, :, :])
        c, t, h, w = video_resized.shape
        self.assertEqual(c, 1)
        self.assertEqual(t, 20)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)
        self.assertEqual(video_resized.dtype, torch.float32)

        # Test interpolation.
        crop_size = 10
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        transform = RandomResizedCrop(
            target_height=crop_size,
            target_width=crop_size,
            scale=(0.08, 1.0),
            aspect_ratio=(3.0 / 4.0, 4.0 / 3.0),
            interpolation="bicubic",
        )

        video_resized = transform(video)
        c, t, h, w = video_resized.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, 20)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)
        self.assertEqual(video_resized.dtype, torch.float32)

        # Test log_uniform_ratio.
        crop_size = 10
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        transform = RandomResizedCrop(
            target_height=crop_size,
            target_width=crop_size,
            scale=(0.08, 1.0),
            aspect_ratio=(3.0 / 4.0, 4.0 / 3.0),
            log_uniform_ratio=False,
        )

        video_resized = transform(video)
        c, t, h, w = video_resized.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, 20)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)
        self.assertEqual(video_resized.dtype, torch.float32)

    def test_augmix(self):
        # Test default AugMix.
        t, c, h, w = 8, 3, 200, 200
        test_time = 20
        video_tensor = torch.rand(t, c, h, w)
        video_augmix_fn = AugMix()
        for _ in range(test_time):
            video_tensor_aug = video_augmix_fn(video_tensor)
            self.assertTrue(video_tensor.size() == video_tensor_aug.size())
            self.assertTrue(video_tensor.dtype == video_tensor_aug.dtype)
            # Make sure the video is in range.
            self.assertTrue(video_tensor_aug.max().item() <= 1)
            self.assertTrue(video_tensor_aug.min().item() >= 0)

        # Test AugMix with non-default parameters.
        t, c, h, w = 8, 3, 200, 200
        test_time = 20
        video_tensor = torch.rand(t, c, h, w)
        video_augmix_fn = AugMix(magnitude=9, alpha=0.5, width=4, depth=3)
        for _ in range(test_time):
            video_tensor_aug = video_augmix_fn(video_tensor)
            self.assertTrue(video_tensor.size() == video_tensor_aug.size())
            self.assertTrue(video_tensor.dtype == video_tensor_aug.dtype)
            # Make sure the video is in range.
            self.assertTrue(video_tensor_aug.max().item() <= 1)
            self.assertTrue(video_tensor_aug.min().item() >= 0)

        # Test AugMix with uint8 video.
        t, c, h, w = 8, 3, 200, 200
        test_time = 20
        video_tensor = torch.randint(0, 255, (t, c, h, w)).type(torch.uint8)
        video_augmix_fn = AugMix(transform_hparas={"fill": (128, 128, 128)})
        for _ in range(test_time):
            video_tensor_aug = video_augmix_fn(video_tensor)
            self.assertTrue(video_tensor.size() == video_tensor_aug.size())
            self.assertTrue(video_tensor.dtype == video_tensor_aug.dtype)
            # Make sure the video is in range.
            self.assertTrue(video_tensor_aug.max().item() <= 255)
            self.assertTrue(video_tensor_aug.min().item() >= 0)

        # Compare results of AugMix for uint8 and float.
        t, c, h, w = 8, 3, 200, 200
        test_time = 40
        video_tensor_uint8 = torch.randint(0, 255, (t, c, h, w)).type(torch.uint8)
        video_tensor_float = (video_tensor_uint8 / 255.0).type(torch.float32)
        video_augmix_fn_uint8 = AugMix(
            width=1, depth=1, transform_hparas={"fill": (128, 128, 128)}
        )
        video_augmix_fn_float = AugMix(width=1, depth=1)
        for i in range(test_time):
            torch.set_rng_state(torch.manual_seed(i).get_state())
            video_tensor_uint8_aug = video_augmix_fn_uint8(video_tensor_uint8)
            torch.set_rng_state(torch.manual_seed(i).get_state())
            video_tensor_float_aug = video_augmix_fn_float(video_tensor_float)

            self.assertTrue(
                torch.mean(
                    torch.abs((video_tensor_uint8_aug / 255.0) - video_tensor_float_aug)
                )
                < 0.01
            )
            self.assertTrue(video_tensor_uint8.size() == video_tensor_uint8_aug.size())
            self.assertTrue(video_tensor_uint8.dtype == video_tensor_uint8_aug.dtype)
            self.assertTrue(video_tensor_float.size() == video_tensor_float_aug.size())
            self.assertTrue(video_tensor_float.dtype == video_tensor_float_aug.dtype)
            # Make sure the video is in range.
            self.assertTrue(video_tensor_uint8_aug.max().item() <= 255)
            self.assertTrue(video_tensor_uint8_aug.min().item() >= 0)
            self.assertTrue(video_tensor_float_aug.max().item() <= 255)
            self.assertTrue(video_tensor_float_aug.min().item() >= 0)

        # Test asserts.
        self.assertRaises(AssertionError, AugMix, magnitude=11)
        self.assertRaises(AssertionError, AugMix, magnitude=1.1)
        self.assertRaises(AssertionError, AugMix, alpha=-0.3)
        self.assertRaises(AssertionError, AugMix, width=0)

    def test_permute(self):
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        for p in list(permutations(range(0, 4))):
            self.assertTrue(video.permute(*p).equal(Permute(p)(video)))

    def test_video_transform_factory(self):
        # Test asserts/raises.
        self.assertRaises(TypeError, create_video_transform, mode="val", crop_size="s")
        self.assertRaises(
            AssertionError,
            create_video_transform,
            mode="val",
            crop_size=30,
            min_size=10,
        )
        self.assertRaises(
            AssertionError,
            create_video_transform,
            mode="val",
            crop_size=(30, 40),
            min_size=35,
        )
        self.assertRaises(
            AssertionError, create_video_transform, mode="val", remove_key="key"
        )
        self.assertRaises(
            AssertionError,
            create_video_transform,
            mode="val",
            aug_paras={"magnitude": 10},
        )
        self.assertRaises(
            NotImplementedError, create_video_transform, mode="train", aug_type="xyz"
        )

        # Test train mode.
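        # In train mode the factory is expected to compose, roughly in order:
        # uniform temporal subsampling, (optional) uint8-to-float conversion,
        # normalization, random short-side scaling, random cropping, and
        # random horizontal flipping, mirroring the Compose pipelines used
        # elsewhere in this file.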
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        test_clip = {"video": video, "audio1": None, "audio2": None, "label": 0}

        num_subsample = 10
        crop_size = 10
        transform = create_video_transform(
            mode="train",
            num_samples=num_subsample,
            convert_to_float=False,
            video_mean=[video.mean()] * 3,
            video_std=[video.std()] * 3,
            min_size=15,
            crop_size=crop_size,
        )
        transform_dict = create_video_transform(
            mode="train",
            video_key="video",
            remove_key=["audio1", "audio2"],
            num_samples=num_subsample,
            convert_to_float=False,
            video_mean=[video.mean()] * 3,
            video_std=[video.std()] * 3,
            min_size=15,
            crop_size=crop_size,
        )
        transform_frame = create_video_transform(
            mode="train",
            num_samples=None,
            convert_to_float=False,
            video_mean=[video.mean()] * 3,
            video_std=[video.std()] * 3,
            min_size=15,
            crop_size=crop_size,
        )

        video_tensor_transformed = transform(video)
        video_dict_transformed = transform_dict(test_clip)
        video_frame_transformed = transform_frame(video[:, 0:1, :, :])

        c, t, h, w = video_tensor_transformed.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)

        c, t, h, w = video_dict_transformed["video"].shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)
        self.assertFalse("audio1" in video_dict_transformed)
        self.assertFalse("audio2" in video_dict_transformed)

        c, t, h, w = video_frame_transformed.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, 1)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)

        # Test val mode.
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
            dtype=torch.float32
        )
        test_clip = {"video": video, "audio": None, "label": 0}
        test_clip2 = {"video": video, "audio": None, "label": 0}

        num_subsample = 10
        transform = create_video_transform(
            mode="val",
            num_samples=num_subsample,
            convert_to_float=False,
            video_mean=[video.mean()] * 3,
            video_std=[video.std()] * 3,
            min_size=15,
            crop_size=crop_size,
        )
        transform_dict = create_video_transform(
            mode="val",
            video_key="video",
            num_samples=num_subsample,
            convert_to_float=False,
            video_mean=[video.mean()] * 3,
            video_std=[video.std()] * 3,
            min_size=15,
            crop_size=crop_size,
        )
        transform_comp = Compose(
            [
                ApplyTransformToKey(
                    key="video",
                    transform=Compose(
                        [
                            UniformTemporalSubsample(num_subsample),
                            NormalizeVideo([video.mean()] * 3, [video.std()] * 3),
                            ShortSideScale(size=15),
                            CenterCropVideo(crop_size),
                        ]
                    ),
                )
            ]
        )
        transform_frame = create_video_transform(
            mode="val",
            num_samples=None,
            convert_to_float=False,
            video_mean=[video.mean()] * 3,
            video_std=[video.std()] * 3,
            min_size=15,
            crop_size=crop_size,
        )

        video_tensor_transformed = transform(video)
        video_dict_transformed = transform_dict(test_clip)
        video_comp_transformed = transform_comp(test_clip2)
        video_frame_transformed = transform_frame(video[:, 0:1, :, :])

        self.assertTrue(
            video_tensor_transformed.equal(video_dict_transformed["video"])
        )
        self.assertTrue(
            video_dict_transformed["video"].equal(video_comp_transformed["video"])
        )
        torch.testing.assert_close(
            video_frame_transformed, video_tensor_transformed[:, 0:1, :, :]
        )

        c, t, h, w = video_dict_transformed["video"].shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)
        self.assertTrue("audio" in video_dict_transformed)

        c, t, h, w = video_frame_transformed.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, 1)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)

        # Test uint8 video.
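        # With convert_to_float=True the factory accepts uint8 input and
        # scales it by 1 / 255, so its output should exactly match the
        # convert_to_float=False pipeline applied to a pre-scaled float video.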
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40))
        test_clip = {"video": video, "audio": None, "label": 0}

        transform_uint8 = create_video_transform(
            mode="val",
            num_samples=num_subsample,
            convert_to_float=True,
            min_size=15,
            crop_size=crop_size,
        )
        transform_float32 = create_video_transform(
            mode="val",
            num_samples=num_subsample,
            convert_to_float=False,
            min_size=15,
            crop_size=crop_size,
        )

        video_uint8_transformed = transform_uint8(video)
        video_float32_transformed = transform_float32(
            video.to(dtype=torch.float32) / 255.0
        )
        self.assertRaises(
            AssertionError, transform_uint8, video.to(dtype=torch.float32)
        )
        self.assertTrue(video_uint8_transformed.equal(video_float32_transformed))

        c, t, h, w = video_uint8_transformed.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)

        c, t, h, w = video_float32_transformed.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)

        # Test augmentations.
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40))

        transform_randaug = create_video_transform(
            mode="train",
            num_samples=num_subsample,
            min_size=15,
            crop_size=crop_size,
            aug_type="randaug",
        )
        transform_augmix = create_video_transform(
            mode="train",
            num_samples=num_subsample,
            min_size=15,
            crop_size=crop_size,
            aug_type="augmix",
        )
        transform_randaug_paras = create_video_transform(
            mode="train",
            num_samples=num_subsample,
            min_size=15,
            crop_size=crop_size,
            aug_type="randaug",
            aug_paras={
                "magnitude": 8,
                "num_layers": 3,
                "prob": 0.7,
                "sampling_type": "uniform",
            },
        )
        transform_augmix_paras = create_video_transform(
            mode="train",
            num_samples=num_subsample,
            min_size=15,
            crop_size=crop_size,
            aug_type="augmix",
            aug_paras={"magnitude": 5, "alpha": 0.5, "width": 2, "depth": 3},
        )

        video_randaug_transformed = transform_randaug(video)
        video_augmix_transformed = transform_augmix(video)
        video_randaug_paras_transformed = transform_randaug_paras(video)
        video_augmix_paras_transformed = transform_augmix_paras(video)

        c, t, h, w = video_randaug_transformed.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)

        c, t, h, w = video_augmix_transformed.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)

        c, t, h, w = video_randaug_paras_transformed.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)

        c, t, h, w = video_augmix_paras_transformed.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)

        # Test Inception-style cropping.
        video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40))
        transform_inception = create_video_transform(
            mode="train",
            num_samples=num_subsample,
            min_size=15,
            crop_size=crop_size,
            random_resized_crop_paras={},
        )
        video_inception_transformed = transform_inception(video)
        c, t, h, w = video_inception_transformed.shape
        self.assertEqual(c, 3)
        self.assertEqual(t, num_subsample)
        self.assertEqual(h, crop_size)
        self.assertEqual(w, crop_size)

    def test_div_255(self):
        t, c, h, w = 8, 3, 200, 200
        video_tensor = torch.rand(t, c, h, w)
        output_tensor = div_255(video_tensor)
        expect_tensor = video_tensor / 255
        self.assertEqual(output_tensor.shape, video_tensor.shape)
        self.assertTrue(bool(torch.all(torch.eq(output_tensor, expect_tensor))))