|
from .utils import RandomResizedCropNP |
|
from .utils import SimmimMaskGenerator |
|
|
|
import torchvision.transforms as T |
|
|
|
|
|
class SimmimTransform:
    """
    Callable transform that pairs an augmented image with a MiM mask.

    The image pipeline chains ``RandomResizedCropNP`` (scale 0.67-1.0,
    aspect ratio 3/4-4/3), ``T.ToTensor`` and a ``T.Resize`` to a square
    of side ``config.DATA.IMG_SIZE``. The mask is produced by a
    ``SimmimMaskGenerator`` built from the same configuration.
    """

    def __init__(self, config):
        """
        Build the image pipeline and the mask generator.

        Args:
            config: configuration object. Reads ``MODEL.TYPE``,
                ``MODEL.SWINV2.PATCH_SIZE``, ``DATA.IMG_SIZE``,
                ``DATA.MASK_PATCH_SIZE`` and ``DATA.MASK_RATIO``.

        Raises:
            NotImplementedError: if ``config.MODEL.TYPE`` is neither
                ``'swin'`` nor ``'swinv2'``.
        """
        # Validate the model type first so an unsupported configuration
        # fails before any pipeline object is constructed, and say which
        # type was rejected.
        model_type = config.MODEL.TYPE
        if model_type not in ('swin', 'swinv2'):
            raise NotImplementedError(
                "SimmimTransform does not support model type {!r}; "
                "expected 'swin' or 'swinv2'".format(model_type))

        # NOTE(review): both supported types read the SWINV2 section for
        # the patch size -- confirm this is intended for 'swin'.
        model_patch_size = config.MODEL.SWINV2.PATCH_SIZE

        self.transform_img = \
            T.Compose([
                RandomResizedCropNP(scale=(0.67, 1.),
                                    ratio=(3. / 4., 4. / 3.)),
                T.ToTensor(),
                T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)),
            ])

        self.mask_generator = SimmimMaskGenerator(
            input_size=config.DATA.IMG_SIZE,
            mask_patch_size=config.DATA.MASK_PATCH_SIZE,
            model_patch_size=model_patch_size,
            mask_ratio=config.DATA.MASK_RATIO,
        )

    def __call__(self, img):
        """
        Transform ``img`` and generate a mask.

        Args:
            img: input imagery passed to the image pipeline.

        Returns:
            Tuple ``(img, mask)``: the transformed image and the output
            of the mask generator. The generator is called without
            arguments, so the mask does not depend on ``img``.
        """
        img = self.transform_img(img)
        mask = self.mask_generator()

        return img, mask
|
|
|
|
|
class TensorResizeTransform:
    """
    Callable transform that runs the input imagery through
    ``T.ToTensor`` followed by a ``T.Resize`` to a square of side
    ``config.DATA.IMG_SIZE``. No mask is generated.
    """

    def __init__(self, config):
        """
        Assemble the two-step image pipeline.

        Args:
            config: configuration object; only ``DATA.IMG_SIZE`` is read.
        """
        to_tensor = T.ToTensor()
        resize = T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE))
        self.transform_img = T.Compose([to_tensor, resize])

    def __call__(self, img):
        """
        Apply the pipeline to ``img`` and return the result.

        Args:
            img: input imagery passed to the image pipeline.

        Returns:
            Whatever the composed pipeline returns for ``img``.
        """
        return self.transform_img(img)
|
|