|
import numpy as np |
|
|
|
|
|
class TubeMaskingGenerator:
    """Tube masking: one random spatial mask shared by every frame.

    A single per-frame mask containing ``num_masks_per_frame`` ones is
    shuffled once per call and then repeated for all ``frames``, so the same
    spatial positions are masked throughout the clip.

    Args:
        input_size: ``(frames, height, width)`` measured in patches, or a
            single int used for all three dimensions.
        mask_ratio: Fraction of patches per frame to mask, in ``[0, 1]``.
            The per-frame count is truncated with ``int()``.

    Raises:
        ValueError: If ``mask_ratio`` is outside ``[0, 1]``.
    """

    def __init__(self, input_size, mask_ratio):
        # Accept a scalar (as RandomMaskingGenerator does); any sequence of
        # three sizes (tuple, list, ...) is unpacked directly.
        if np.ndim(input_size) == 0:
            input_size = (input_size,) * 3
        self.frames, self.height, self.width = input_size

        # Fail early with a clear message instead of a "negative dimensions"
        # error from np.zeros/np.ones on every __call__.
        if not 0 <= mask_ratio <= 1:
            raise ValueError(
                "mask_ratio must be in [0, 1], got {}".format(mask_ratio))

        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame

    def __repr__(self):
        return "Mask: total patches {}, mask patches {}".format(
            self.total_patches, self.total_masks)

    def __call__(self):
        """Return a flat float mask of length ``total_patches``.

        Ones mark the masked positions. Randomness comes from the global
        ``np.random`` state.
        """
        mask_per_frame = np.hstack([
            np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
            np.ones(self.num_masks_per_frame),
        ])
        np.random.shuffle(mask_per_frame)
        # Repeat the same spatial pattern for every frame, then flatten
        # frame-major: index = t * (height * width) + spatial_index.
        return np.tile(mask_per_frame, (self.frames, 1)).flatten()
|
|
|
|
|
class RandomMaskingGenerator:
    """Uniform random masking over all space-time patches.

    Exactly ``num_mask`` of the ``num_patches`` positions are set to one,
    chosen by shuffling on each call.

    Args:
        input_size: ``(frames, height, width)`` measured in patches, or a
            single int used for all three dimensions.
        mask_ratio: Fraction of all patches to mask, in ``[0, 1]``. The
            count is truncated with ``int()``.

    Raises:
        ValueError: If ``mask_ratio`` is outside ``[0, 1]``.
    """

    def __init__(self, input_size, mask_ratio):
        # Only a true scalar is replicated; previously any non-tuple (e.g. a
        # list like [8, 14, 14]) was wrongly replicated and crashed later.
        if np.ndim(input_size) == 0:
            input_size = (input_size,) * 3
        self.frames, self.height, self.width = input_size

        # Fail early with a clear message instead of a "negative dimensions"
        # error from np.zeros/np.ones on every __call__.
        if not 0 <= mask_ratio <= 1:
            raise ValueError(
                "mask_ratio must be in [0, 1], got {}".format(mask_ratio))

        self.num_patches = self.frames * self.height * self.width
        self.num_mask = int(mask_ratio * self.num_patches)

    def __repr__(self):
        return "Mask: total patches {}, mask patches {}".format(
            self.num_patches, self.num_mask)

    def __call__(self):
        """Return a flat float mask of length ``num_patches``.

        Ones mark the masked positions. Randomness comes from the global
        ``np.random`` state.
        """
        mask = np.hstack([
            np.zeros(self.num_patches - self.num_mask),
            np.ones(self.num_mask),
        ])
        np.random.shuffle(mask)
        return mask
|
|