File size: 3,606 Bytes
24cfc1b
 
 
 
 
 
 
 
 
 
 
 
 
 
93ca8bb
24cfc1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105

import tensorflow as tf
import numpy as np
from einops import rearrange
from decord import VideoReader

# Number of frames sampled from each video clip.
num_frames = 16
# Square spatial size (pixels) every frame is resized to.
input_size = 224
# Spatial patch size (height, width) in pixels.
patch_size = (16, 16)
# Per-channel normalisation statistics. Despite the names, these are uniform
# values (0.45 / 0.225 for every channel), not per-channel ImageNet stats;
# presumably they match the pretrained model's preprocessing — confirm.
IMAGENET_MEAN = np.full(3, 0.45)
IMAGENET_STD = np.full(3, 0.225)

def format_frames(frame, output_size):
    """Resize frames and normalise them with the module's mean/std.

    Args:
        frame: image or stack of images; presumably uint8 frames from decord
            with channels last — confirm against callers.
        output_size: target (height, width) passed to ``tf.image.resize``.

    Returns:
        Float tensor: pixels scaled to [0, 1], then shifted by
        ``IMAGENET_MEAN`` and divided by ``IMAGENET_STD``.
    """
    as_uint8 = tf.image.convert_image_dtype(frame, tf.uint8)
    resized = tf.image.resize(as_uint8, size=output_size)
    unit_range = resized / 255.
    centred = unit_range - IMAGENET_MEAN
    return centred / IMAGENET_STD

def read_video(file_path):
    """Open ``file_path`` with decord and hand back the ``VideoReader``.

    Frames are not fetched here; callers such as ``frame_sampling`` pull
    them out with ``get_batch``.
    """
    return VideoReader(file_path)

def frame_sampling(container, num_frames, output_size=None):
    """Sample ``num_frames`` frames from a video and preprocess them.

    The video is split into ``num_frames`` equal segments of
    ``len(container) // num_frames`` frames, and one frame is drawn uniformly
    at random from each segment. When the video has fewer frames than
    ``num_frames`` (segment length 0), evenly spaced indices are used
    instead, so some frames are repeated.

    Args:
        container: decord ``VideoReader`` (or any object supporting ``len()``
            and ``get_batch(indices).asnumpy()``).
        num_frames: number of frames to return; must be >= 1.
        output_size: (height, width) to resize to. Defaults to
            ``[input_size, input_size]``.

    Returns:
        The result of ``format_frames`` on the sampled frames.

    Raises:
        ValueError: if ``num_frames < 1`` or the video contains no frames.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")
    total = len(container)
    if total == 0:
        raise ValueError("cannot sample frames from an empty video")

    interval = total // num_frames
    if interval > 0:
        segment_starts = np.arange(num_frames) * interval
        offset = np.random.randint(interval, size=segment_starts.shape)
        frame_index = segment_starts + offset
    else:
        # np.random.randint(0) would raise; fall back to spreading the
        # requested count over the available frames (with repeats).
        frame_index = np.linspace(0, total - 1, num_frames).astype(int)

    # asnumpy() already returns one stacked ndarray, so no np.stack needed.
    frames = container.get_batch(frame_index).asnumpy()
    if output_size is None:
        output_size = [input_size] * 2
    return format_frames(frames, output_size)

def denormalize(image):
    """Invert the mean/std normalisation and return 8-bit pixels.

    Args:
        image: normalised image or video, either a NumPy array or an object
            with ``.numpy()`` (e.g. a TF eager tensor). The last axis must
            broadcast against the 3-element ``IMAGENET_MEAN``/``IMAGENET_STD``.

    Returns:
        uint8 array of the same shape, rounded to the nearest integer and
        clipped to [0, 255].
    """
    pixels = image if isinstance(image, np.ndarray) else image.numpy()
    pixels = (pixels * IMAGENET_STD + IMAGENET_MEAN) * 255
    # Round before casting: astype truncates toward zero, so tiny float error
    # from the normalise/denormalise round trip (e.g. 127.99998) would
    # otherwise drop a whole intensity level.
    return np.rint(pixels).clip(0, 255).astype('uint8')

def reconstrunction(input_frame, bool_mask, pretrained_pred, tubelet_size=2):
    """Rebuild a video from masked-patch predictions for visualisation.

    The input video is cut into tubelet patches of ``tubelet_size`` frames by
    ``patch_size`` pixels. Each patch is normalised by its own mean and
    unbiased standard deviation. Masked patches are overwritten with
    ``pretrained_pred``. Every patch is then de-normalised with the same
    per-patch statistics and folded back into a video, so visible patches
    reproduce the input.

    Args:
        input_frame: video of shape (b, T, H, W, C), as a NumPy array or an
            object with ``.numpy()`` (e.g. a TF eager tensor). T, H and W must
            be divisible by ``tubelet_size``, ``patch_size[0]`` and
            ``patch_size[1]`` respectively.
        bool_mask: patch mask with b * num_patches elements (any shape that
            reshapes to (b, num_patches)); truthy entries are replaced.
        pretrained_pred: predicted (patch-normalised) pixels for the masked
            patches; must broadcast to (num_masked, pixels_per_patch * C).
            Presumably the model's decoder output — confirm against caller.
        tubelet_size: frames per temporal patch. Defaults to 2.

    Returns:
        Tuple ``(rec_img, mask)`` for the first batch element, each of shape
        (T, H, W, C). ``mask`` is 1 where the patch was visible and 0 where it
        was masked (i.e. filled from ``pretrained_pred``).
    """
    video = input_frame if isinstance(input_frame, np.ndarray) else input_frame.numpy()
    patch_h, patch_w = patch_size[0], patch_size[1]
    grid_h = video.shape[2] // patch_h
    grid_w = video.shape[3] // patch_w
    channels = video.shape[-1]
    patch_axes = dict(p0=tubelet_size, p1=patch_h, p2=patch_w)
    to_patches = 'b (t p0) (h p1) (w p2) c -> b (t h w) (p0 p1 p2) c'
    to_video = 'b (t h w) (p0 p1 p2) c -> b (t p0) (h p1) (w p2) c'

    # Per-patch statistics over the pixel axis (-2). The same denominator is
    # used to normalise and de-normalise so visible patches round-trip.
    img_squeeze = rearrange(video, to_patches, **patch_axes)
    img_mean = np.mean(img_squeeze, axis=-2, keepdims=True)
    img_std = np.sqrt(np.var(img_squeeze, axis=-2, ddof=1, keepdims=True)) + 1e-6
    img_norm = (img_squeeze - img_mean) / img_std

    img_patch = rearrange(img_norm, 'b n p c -> b n (p c)')
    # Accept 0/1 float masks and un-batched masks as well as (b, n) bools.
    masked = np.reshape(np.asarray(bool_mask, dtype=bool), img_patch.shape[:2])
    img_patch[masked] = pretrained_pred

    # Visibility mask: 1 for kept patches, 0 for predicted ones.
    mask = np.ones_like(img_patch)
    mask[masked] = 0
    mask = rearrange(mask, 'b n (p c) -> b n p c', c=channels)
    mask = rearrange(mask, to_video, h=grid_h, w=grid_w, **patch_axes)

    # Undo the per-patch normalisation and fold patches back into frames.
    rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=channels)
    rec_img = rec_img * img_std + img_mean
    rec_img = rearrange(rec_img, to_video, h=grid_h, w=grid_w, **patch_axes)

    return (
        rec_img[0],
        mask[0]
    )


# Correctly spelled alias; the original (misspelled) name is kept for callers.
reconstruction = reconstrunction


class TubeMaskingGenerator:
    """Generate "tube" masks: one random spatial pattern repeated per frame.

    Args:
        input_size: (frames, height, width) of the patch grid; ``height *
            width`` is treated as the number of patches per frame. In
            VideoMAE-style models ``frames`` is presumably the number of
            temporal tubelets rather than raw frames — confirm with caller.
        mask_ratio: fraction of patches to mask per frame; the count is
            ``int(mask_ratio * height * width)`` (truncated toward zero).

    Raises:
        ValueError: if the per-frame mask count is negative or exceeds the
            number of patches per frame.
    """

    def __init__(self, input_size, mask_ratio):
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame
        # Fail fast here rather than with an opaque numpy error in __call__.
        if not 0 <= self.num_masks_per_frame <= self.num_patches_per_frame:
            raise ValueError(
                f"mask_ratio={mask_ratio} gives {self.num_masks_per_frame} masked "
                f"patches per frame; expected 0..{self.num_patches_per_frame}"
            )

    def __repr__(self):
        return (
            f"Mask: total patches {self.total_patches}, "
            f"mask patches {self.total_masks}"
        )

    def __call__(self):
        """Return a fresh flat float mask of length ``total_patches``.

        1.0 marks a masked patch and 0.0 a visible one. The same shuffled
        per-frame pattern is tiled ``frames`` times, frame-major.
        """
        mask_per_frame = np.hstack([
            np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
            np.ones(self.num_masks_per_frame),
        ])
        np.random.shuffle(mask_per_frame)
        mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
        return mask