import tensorflow as tf
import numpy as np
from einops import rearrange
from decord import VideoReader

# Clip length, input resolution, and spatial patch size expected by the pretrained model.
num_frames = 16
input_size = 224
patch_size = (16, 16)

# Per-channel normalization statistics applied to every frame.
IMAGENET_MEAN = np.array([0.45, 0.45, 0.45])
IMAGENET_STD = np.array([0.225, 0.225, 0.225])

def format_frames(frame, output_size):
    # Resize to the target resolution, scale to [0, 1], and normalize with the
    # per-channel mean / std defined above.
    frame = tf.image.convert_image_dtype(frame, tf.uint8)
    frame = tf.image.resize(frame, size=output_size)
    frame = frame / 255.
    frame = frame - IMAGENET_MEAN
    frame = frame / IMAGENET_STD
    return frame

def read_video(file_path):
    # decord's VideoReader gives random access to the decoded frames.
    container = VideoReader(file_path)
    return container

def frame_sampling(container, num_frames):
    # Split the clip into `num_frames` equal segments and pick one frame at a
    # random offset inside each segment, then resize and normalize the batch.
    interval = len(container) // num_frames
    bids = np.arange(num_frames) * interval
    offset = np.random.randint(interval, size=bids.shape)
    frame_index = bids + offset
    frames = container.get_batch(frame_index).asnumpy()
    frames = np.stack(frames)
    frames = format_frames(frames, [input_size] * 2)
    return frames

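# Usage sketch: "sample.mp4" below is a hypothetical placeholder path, not part
# of the pipeline above; any readable video file works.
# container = read_video("sample.mp4")
# video = frame_sampling(container, num_frames)  # (16, 224, 224, 3), normalized
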
def denormalize(image):
    # Invert the mean / std normalization and map back to uint8 pixels for display.
    image = image.numpy() if not isinstance(image, np.ndarray) else image
    image = image * IMAGENET_STD + IMAGENET_MEAN
    image = (image * 255).clip(0, 255).astype('uint8')
    return image

def reconstrunction(input_frame, bool_mask, pretrained_pred):
    # Break the input clip into tubelets of 2 frames x 16 x 16 pixels, matching
    # the patch layout of the pretrained model.
    img_squeeze = rearrange(
        input_frame.numpy(),
        'b (t p0) (h p1) (w p2) c -> b (t h w) (p0 p1 p2) c',
        p0=2, p1=patch_size[0], p2=patch_size[1]
    )
    # Normalize each patch, then overwrite the masked patches with the model's
    # predictions.
    img_mean = np.mean(img_squeeze, axis=-2, keepdims=True)
    img_variance = np.var(img_squeeze, axis=-2, ddof=1, keepdims=True)
    img_norm = (img_squeeze - img_mean) / (np.sqrt(img_variance) + 1e-6)
    img_patch = rearrange(img_norm, 'b n p c -> b n (p c)')
    img_patch[bool_mask] = pretrained_pred

    # Build a pixel-level visibility mask (1 = visible, 0 = masked) laid out
    # like the original frames.
    mask = np.ones_like(img_patch)
    mask[bool_mask] = 0
    mask = rearrange(mask, 'b n (p c) -> b n p c', c=3)
    mask = rearrange(
        mask,
        'b (t h w) (p0 p1 p2) c -> b (t p0) (h p1) (w p2) c',
        p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14
    )

    # Un-normalize the patches with the per-patch statistics and fold them back
    # into a video tensor.
    rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=3)
    img_std = np.sqrt(img_variance + 1e-6)
    rec_img = rec_img * img_std + img_mean
    rec_img = rearrange(
        rec_img,
        'b (t h w) (p0 p1 p2) c -> b (t p0) (h p1) (w p2) c',
        p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14
    )

    return rec_img[0], mask[0]

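# Shape sketch for reconstrunction, assuming the layout implied by the rearrange
# patterns above (16 frames in tubelets of 2, a 14 x 14 patch grid at 224 / 16):
#   input_frame    : (1, 16, 224, 224, 3) normalized clip
#   bool_mask      : (1, 8 * 14 * 14) boolean, True marks masked patches
#   pretrained_pred: (num_masked, 2 * 16 * 16 * 3) per-patch predictions
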
class TubeMaskingGenerator:
    """Generates a tube mask: the same spatial patches are masked in every frame."""

    def __init__(self, input_size, mask_ratio):
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame

    def __repr__(self):
        repr_str = "Masks: total patches {}, mask patches {}".format(
            self.total_patches, self.total_masks
        )
        return repr_str

    def __call__(self):
        # Draw one random per-frame mask (1 = masked patch) and repeat it across
        # all frames so the masked positions form tubes through time.
        mask_per_frame = np.hstack([
            np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
            np.ones(self.num_masks_per_frame),
        ])
        np.random.shuffle(mask_per_frame)
        mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
        return mask
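

# Usage sketch; the 0.9 mask ratio is an assumption, not fixed by the class.
# input_size here is the patch grid (temporal, height, width): 16 frames in
# tubelets of 2 -> 8, and 224 / 16 -> 14.
tube_mask = TubeMaskingGenerator(input_size=(num_frames // 2, 14, 14), mask_ratio=0.9)
bool_mask = tube_mask().astype(bool)  # flat mask of length 8 * 14 * 14 = 1568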