import torchvision.transforms as T

from .utils import RandomResizedCropNP
from .utils import SimmimMaskGenerator


class SimmimTransform:
    """
    torchvision transform which transforms the input imagery in
    addition to generating a MiM mask.

    The image pipeline is: RandomResizedCropNP -> ToTensor -> Resize to
    (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE). A mask is produced by
    calling a SimmimMaskGenerator once per input image.

    Config fields read:
        DATA.IMG_SIZE, DATA.MASK_PATCH_SIZE, DATA.MASK_RATIO,
        MODEL.TYPE, MODEL.SWINV2.PATCH_SIZE
    """

    def __init__(self, config):
        """
        Build the image transform and the mask generator from `config`.

        Raises:
            NotImplementedError: if config.MODEL.TYPE is neither 'swin'
                nor 'swinv2'.
        """
        self.transform_img = T.Compose([
            RandomResizedCropNP(scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
            T.ToTensor(),
            T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)),
        ])

        if config.MODEL.TYPE in ['swin', 'swinv2']:
            # NOTE(review): both 'swin' and 'swinv2' read the SWINV2
            # patch size -- confirm this is intended for plain 'swin'.
            model_patch_size = config.MODEL.SWINV2.PATCH_SIZE
        else:
            raise NotImplementedError(
                f"SimmimTransform does not support model type "
                f"{config.MODEL.TYPE!r}; expected 'swin' or 'swinv2'"
            )

        self.mask_generator = SimmimMaskGenerator(
            input_size=config.DATA.IMG_SIZE,
            mask_patch_size=config.DATA.MASK_PATCH_SIZE,
            model_patch_size=model_patch_size,
            mask_ratio=config.DATA.MASK_RATIO,
        )

    def __call__(self, img):
        """
        Transform `img` and generate a fresh mask.

        Returns:
            A tuple (img, mask): the transformed image and the value
            returned by calling the mask generator.
        """
        img = self.transform_img(img)
        mask = self.mask_generator()
        return img, mask


class TensorResizeTransform:
    """
    torchvision transform which converts the input imagery to a tensor
    and resizes it to (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE).

    Unlike SimmimTransform, no cropping is applied and no mask is
    generated.
    """

    def __init__(self, config):
        """Build the ToTensor -> Resize pipeline from `config`."""
        self.transform_img = T.Compose([
            T.ToTensor(),
            T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)),
        ])

    def __call__(self, img):
        """Return the transformed (tensor-converted, resized) image."""
        img = self.transform_img(img)
        return img