Caleb Spradlin
initial commit
ab687e7
raw
history blame
1.59 kB
from .utils import RandomResizedCropNP
from .utils import SimmimMaskGenerator
import torchvision.transforms as T
class SimmimTransform:
    """
    Callable transform that prepares an image for SimMIM pre-training
    and pairs it with a masked-image-modeling (MiM) mask.

    The image pipeline applies, in order: ``RandomResizedCropNP``
    (scale ``(0.67, 1.)``, ratio ``(3/4, 4/3)``), ``T.ToTensor`` and a
    ``T.Resize`` to ``(config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)``.
    A new mask is requested from ``SimmimMaskGenerator`` on every call.

    Args:
        config: configuration object exposing ``MODEL.TYPE``,
            ``MODEL.SWINV2.PATCH_SIZE``, ``DATA.IMG_SIZE``,
            ``DATA.MASK_PATCH_SIZE`` and ``DATA.MASK_RATIO``.

    Raises:
        NotImplementedError: if ``config.MODEL.TYPE`` is not one of the
            supported model types.
    """

    # Model types for which a patch size can be looked up.
    _SUPPORTED_MODEL_TYPES = ('swin', 'swinv2')

    def __init__(self, config):
        # Validate the model type before building anything so that an
        # unsupported configuration fails fast with a useful message.
        model_type = config.MODEL.TYPE
        if model_type not in self._SUPPORTED_MODEL_TYPES:
            raise NotImplementedError(
                f"SimmimTransform does not support MODEL.TYPE="
                f"{model_type!r}; supported types: "
                f"{', '.join(self._SUPPORTED_MODEL_TYPES)}")

        # NOTE(review): both 'swin' and 'swinv2' read the SWINV2 node;
        # confirm whether 'swin' should use its own config section.
        model_patch_size = config.MODEL.SWINV2.PATCH_SIZE

        self.transform_img = T.Compose([
            RandomResizedCropNP(scale=(0.67, 1.),
                                ratio=(3. / 4., 4. / 3.)),
            T.ToTensor(),
            T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)),
        ])

        self.mask_generator = SimmimMaskGenerator(
            input_size=config.DATA.IMG_SIZE,
            mask_patch_size=config.DATA.MASK_PATCH_SIZE,
            model_patch_size=model_patch_size,
            mask_ratio=config.DATA.MASK_RATIO,
        )

    def __call__(self, img):
        """
        Transform ``img`` and generate a mask for it.

        Returns:
            A ``(img, mask)`` tuple: the transformed image and the
            value returned by the mask generator.
        """
        img = self.transform_img(img)
        mask = self.mask_generator()
        return img, mask
class TensorResizeTransform:
    """
    Callable transform that converts the input imagery with
    ``T.ToTensor`` and resizes it to a square of side
    ``config.DATA.IMG_SIZE``. No mask is produced.

    Args:
        config: configuration object exposing ``DATA.IMG_SIZE``.
    """

    def __init__(self, config):
        # Build the two pipeline stages in application order.
        to_tensor = T.ToTensor()
        resize = T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE))
        self.transform_img = T.Compose([to_tensor, resize])

    def __call__(self, img):
        """Return ``img`` after the tensor-conversion and resize steps."""
        return self.transform_img(img)