Caleb Spradlin
initial commit
ab687e7
raw
history blame
3.64 kB
import torch
import numpy as np
from numba import njit
# TRANSFORMS UTILS
class RandomResizedCropNP(object):
"""
Numpy implementation of RandomResizedCrop
"""
def __init__(self,
scale=(0.08, 1.0),
ratio=(3.0/4.0, 4.0/3.0)):
self.scale = scale
self.ratio = ratio
def __call__(self, img):
height, width = img.shape[:2]
area = height * width
for _ in range(10):
target_area = np.random.uniform(*self.scale) * area
aspect_ratio = np.random.uniform(*self.ratio)
w = int(round(np.sqrt(target_area * aspect_ratio)))
h = int(round(np.sqrt(target_area / aspect_ratio)))
if np.random.random() < 0.5:
w, h = h, w
if w <= width and h <= height:
x1 = np.random.randint(0, width - w + 1)
y1 = np.random.randint(0, height - h + 1)
cropped = img[y1:y1+h, x1:x1+w, :]
cropped = np.moveaxis(cropped, -1, 0)
cropped_resized = torch.nn.functional.interpolate(
torch.from_numpy(cropped).unsqueeze(0),
size=height,
mode='bicubic',
align_corners=False)
cropped_squeezed_numpy = cropped_resized.squeeze().numpy()
cropped_squeezed_numpy = np.moveaxis(
cropped_squeezed_numpy, 0, -1)
return cropped_squeezed_numpy
# if crop was not successful after 10 attempts, use center crop
w = min(width, height)
x1 = (width - w) // 2
y1 = (height - w) // 2
cropped = img[y1:y1+w, x1:x1+w, :]
cropped = np.moveaxis(cropped, -1, 0)
cropped_resized = torch.nn.functional.interpolate(torch.from_numpy(
cropped).unsqueeze(0),
size=height,
mode='bicubic',
align_corners=False)
cropped_squeezed_numpy = cropped_resized.squeeze().numpy()
cropped_squeezed_numpy = np.moveaxis(cropped_squeezed_numpy, 0, -1)
return cropped_squeezed_numpy
# MASKING
class SimmimMaskGenerator:
"""
Generates the masks for masked-image-modeling
"""
def __init__(self,
input_size=192,
mask_patch_size=32,
model_patch_size=4,
mask_ratio=0.6):
self.input_size = input_size
self.mask_patch_size = mask_patch_size
self.model_patch_size = model_patch_size
self.mask_ratio = mask_ratio
assert self.input_size % self.mask_patch_size == 0
assert self.mask_patch_size % self.model_patch_size == 0
self.rand_size = self.input_size // self.mask_patch_size
self.scale = self.mask_patch_size // self.model_patch_size
self.token_count = self.rand_size ** 2
self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
def __call__(self):
mask = make_simmim_mask(self.token_count, self.mask_count,
self.rand_size, self.scale)
mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
return mask
@njit()
def make_simmim_mask(token_count, mask_count, rand_size, scale):
"""JIT-compiled random mask generation
Args:
token_count
mask_count
rand_size
scale
Returns:
mask
"""
mask_idx = np.random.permutation(token_count)[:mask_count]
mask = np.zeros(token_count, dtype=np.int64)
mask[mask_idx] = 1
mask = mask.reshape((rand_size, rand_size))
return mask