# Repository web-view header (kept as a comment so the module parses):
# AK391 / files @ d380b77 (raw, history, blame) - 13.2 kB
import glob
import logging
import os
import random
import albumentations as A
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import webdataset
from omegaconf import open_dict, OmegaConf
from skimage.feature import canny
from skimage.transform import rescale, resize
from torch.utils.data import Dataset, IterableDataset, DataLoader, DistributedSampler, ConcatDataset
from saicinpainting.evaluation.data import InpaintingDataset as InpaintingEvaluationDataset, \
OurInpaintingDataset as OurInpaintingEvaluationDataset, ceil_modulo, InpaintingEvalOnlineDataset
from saicinpainting.training.data.aug import IAAAffine2, IAAPerspective2
from saicinpainting.training.data.masks import get_mask_generator
# Module-level logger used by the dataset/dataloader factory functions in this file.
LOGGER = logging.getLogger(__name__)
class InpaintingTrainDataset(Dataset):
    """Map-style training dataset: every ``*.jpg`` under ``indir`` paired with a generated mask.

    Each item is read with OpenCV, converted BGR -> RGB, passed through ``transform``,
    transposed with axes ``(2, 0, 1)`` (HWC -> CHW for an HWC transform output) and then
    handed to ``mask_generator`` together with a running item counter.
    """

    def __init__(self, indir, mask_generator, transform):
        """
        Args:
            indir: root directory searched recursively for ``*.jpg`` files.
            mask_generator: callable invoked as ``mask_generator(img, iter_i=...)``.
            transform: callable invoked as ``transform(image=img)`` returning a dict
                with an ``'image'`` entry.
        """
        # glob's result order is filesystem-dependent; sort it so the index -> file mapping
        # is reproducible across runs and identical on every DistributedSampler rank.
        self.in_files = sorted(glob.glob(os.path.join(indir, '**', '*.jpg'), recursive=True))
        self.mask_generator = mask_generator
        self.transform = transform
        # Number of items served so far; forwarded to the mask generator.
        # NOTE(review): with a multi-worker DataLoader each worker presumably advances its own copy.
        self.iter_i = 0

    def __len__(self):
        return len(self.in_files)

    def __getitem__(self, item):
        """Return ``dict(image=..., mask=...)`` for the ``item``-th file.

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        path = self.in_files[item]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f'Failed to read image {path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(image=img)['image']
        img = np.transpose(img, (2, 0, 1))
        # TODO: maybe generate mask before augmentations? slower, but better for segmentation-based masks
        mask = self.mask_generator(img, iter_i=self.iter_i)
        self.iter_i += 1
        return dict(image=img, mask=mask)
class InpaintingTrainWebDataset(IterableDataset):
    """Iterable training dataset streaming ``jpg`` samples from a WebDataset source.

    Samples go through a shuffle buffer of ``shuffle_buffer`` items and are decoded with
    the ``'rgb'`` decoder. Each image is then converted back to ``uint8`` and transformed.
    It is transposed with axes ``(2, 0, 1)`` and paired with a mask from ``mask_generator``.
    """

    def __init__(self, indir, mask_generator, transform, shuffle_buffer=200):
        self.impl = webdataset.Dataset(indir).shuffle(shuffle_buffer).decode('rgb').to_tuple('jpg')
        self.mask_generator = mask_generator
        self.transform = transform

    def __iter__(self):
        for iter_i, (img,) in enumerate(self.impl):
            # Assumes the 'rgb' decoder yields floats in [0, 1] (TODO confirm for the pinned
            # webdataset version). Round before casting: k / 255 * 255 is not always exactly k in
            # floating point, and a bare astype would truncate e.g. 127.99998 down to 127.
            img = np.clip(np.rint(img * 255), 0, 255).astype('uint8')
            img = self.transform(image=img)['image']
            img = np.transpose(img, (2, 0, 1))
            mask = self.mask_generator(img, iter_i=iter_i)
            yield dict(image=img, mask=mask)
class ImgSegmentationDataset(Dataset):
    """Dataset yielding an image, a generated mask and semantic-segmentation targets.

    Every ``*.jpg`` under ``indir`` (searched recursively) is paired with a grayscale
    ``.png`` label map stored at the same relative path under ``segm_indir``.
    """

    def __init__(self, indir, mask_generator, transform, out_size, segm_indir, semantic_seg_n_classes):
        self.indir = indir
        self.segm_indir = segm_indir
        self.mask_generator = mask_generator
        self.transform = transform
        self.out_size = out_size
        self.semantic_seg_n_classes = semantic_seg_n_classes
        # Sorted so the index -> file mapping does not depend on filesystem listing order.
        self.in_files = sorted(glob.glob(os.path.join(indir, '**', '*.jpg'), recursive=True))

    def __len__(self):
        return len(self.in_files)

    def __getitem__(self, item):
        """Return ``dict(image, mask, segm, segm_classes)`` for the ``item``-th file.

        Raises:
            OSError: if the image or its label map cannot be read.
        """
        path = self.in_files[item]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f'Failed to read image {path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.out_size, self.out_size))
        img = self.transform(image=img)['image']
        img = np.transpose(img, (2, 0, 1))
        mask = self.mask_generator(img)
        segm, segm_classes = self.load_semantic_segm(path)
        return dict(image=img,
                    mask=mask,
                    segm=segm,
                    segm_classes=segm_classes)

    def _get_segm_path(self, img_path):
        """Map an image path under ``indir`` to its ``.png`` label map under ``segm_indir``."""
        # relpath/splitext touch only the leading root and the final extension, unlike
        # str.replace, which also rewrote any later occurrence (e.g. 'train/train_01.jpg').
        rel_path = os.path.relpath(img_path, self.indir)
        return os.path.join(self.segm_indir, os.path.splitext(rel_path)[0] + '.png')

    def load_semantic_segm(self, img_path):
        """Load the label map for ``img_path`` as (one-hot float, class-id) tensors.

        Returns:
            ``(ohe, classes)``: ``ohe`` is ``n_classes x h x w`` float, ``classes`` is
            ``1 x h x w`` with ids shifted down by one and clipped at 0.
        """
        segm_path = self._get_segm_path(img_path)
        mask = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f'Failed to read segmentation map {segm_path}')
        # Label maps hold discrete class ids: nearest-neighbour keeps them intact, whereas the
        # default bilinear interpolation blends neighbouring ids into unrelated classes.
        mask = cv2.resize(mask, (self.out_size, self.out_size), interpolation=cv2.INTER_NEAREST)
        # Shift ids down by one; raw ids 0 and 1 both end up as class 0 because of the clip.
        tensor = torch.from_numpy(np.clip(mask.astype(int) - 1, 0, None))
        ohe = F.one_hot(tensor.long(), num_classes=self.semantic_seg_n_classes)  # h x w x n_classes
        return ohe.permute(2, 0, 1).float(), tensor.unsqueeze(0)
def _color_augs():
    """Return fresh instances of the photometric augmentations shared by the augmenting variants."""
    return [
        A.CLAHE(),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
        A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5),
    ]


def _distortion_augs(out_size, affine_scale, perspective_scale=(0.0, 0.06), rotate=(-40, 40),
                     shear=(-0.1, 0.1), affine_p=None, optical_distortion=True):
    """Return perspective + affine warp, pad, optional optical distortion, crop and flip steps.

    ``affine_p`` is forwarded to ``IAAAffine2`` only when given, so variants that never
    passed ``p`` keep ``IAAAffine2``'s own default probability.
    """
    affine_kwargs = dict(scale=affine_scale, rotate=rotate, shear=shear)
    if affine_p is not None:
        affine_kwargs['p'] = affine_p
    augs = [
        IAAPerspective2(scale=perspective_scale),
        IAAAffine2(**affine_kwargs),
        A.PadIfNeeded(min_height=out_size, min_width=out_size),
    ]
    if optical_distortion:
        augs.append(A.OpticalDistortion())
    augs += [
        A.RandomCrop(height=out_size, width=out_size),
        A.HorizontalFlip(),
    ]
    return augs


def get_transforms(transform_variant, out_size):
    """Build the albumentations pipeline for a named augmentation variant.

    Args:
        transform_variant: one of 'default', 'distortions', 'distortions_scale05_1',
            'distortions_scale03_12', 'distortions_scale03_07', 'distortions_light',
            'non_space_transform' or 'no_augs'.
        out_size: side of the square crop produced by the spatial variants.

    Returns:
        An ``A.Compose`` whose final step is always ``A.ToFloat()``.

    Raises:
        ValueError: for an unknown ``transform_variant``.
    """
    if transform_variant == 'default':
        spatial = [
            A.RandomScale(scale_limit=0.2),  # +/- 20%
            A.PadIfNeeded(min_height=out_size, min_width=out_size),
            A.RandomCrop(height=out_size, width=out_size),
            A.HorizontalFlip(),
        ]
    elif transform_variant == 'distortions':
        spatial = _distortion_augs(out_size, affine_scale=(0.7, 1.3))
    elif transform_variant == 'distortions_scale05_1':
        spatial = _distortion_augs(out_size, affine_scale=(0.5, 1.0), affine_p=1)
    elif transform_variant == 'distortions_scale03_12':
        spatial = _distortion_augs(out_size, affine_scale=(0.3, 1.2), affine_p=1)
    elif transform_variant == 'distortions_scale03_07':
        # scale 512 to 256 in average
        spatial = _distortion_augs(out_size, affine_scale=(0.3, 0.7), affine_p=1)
    elif transform_variant == 'distortions_light':
        spatial = _distortion_augs(out_size,
                                   affine_scale=(0.8, 1.8),
                                   perspective_scale=(0.0, 0.02),
                                   rotate=(-20, 20),
                                   shear=(-0.03, 0.03),
                                   optical_distortion=False)
    elif transform_variant == 'non_space_transform':
        spatial = []
    elif transform_variant == 'no_augs':
        return A.Compose([
            A.ToFloat()
        ])
    else:
        raise ValueError(f'Unexpected transform_variant {transform_variant}')
    return A.Compose(spatial + _color_augs() + [A.ToFloat()])
def make_default_train_dataloader(indir, kind='default', out_size=512, mask_gen_kwargs=None, transform_variant='default',
                                  mask_generator_kind="mixed", dataloader_kwargs=None, ddp_kwargs=None, **kwargs):
    """Build the training ``DataLoader`` for the requested dataset ``kind``.

    Args:
        indir: dataset location, forwarded to the dataset class.
        kind: 'default' (InpaintingTrainDataset), 'default_web' (InpaintingTrainWebDataset)
            or 'img_with_segm' (ImgSegmentationDataset).
        out_size: crop size passed to ``get_transforms`` (and to 'img_with_segm').
        mask_gen_kwargs: kwargs for ``get_mask_generator``.
        transform_variant: augmentation variant name for ``get_transforms``.
        mask_generator_kind: kind for ``get_mask_generator``.
        dataloader_kwargs: extra ``DataLoader`` arguments. A copy is used; the caller's
            mapping is never modified.
        ddp_kwargs: if given and the dataset is map-style, arguments for a
            ``DistributedSampler`` that replaces ``shuffle``.
        **kwargs: extra dataset-constructor arguments.

    Raises:
        ValueError: for an unknown ``kind``.
    """
    LOGGER.info(f'Make train dataloader {kind} from {indir}. Using mask generator={mask_generator_kind}')
    mask_generator = get_mask_generator(kind=mask_generator_kind, kwargs=mask_gen_kwargs)
    transform = get_transforms(transform_variant, out_size)

    if kind == 'default':
        dataset = InpaintingTrainDataset(indir=indir,
                                         mask_generator=mask_generator,
                                         transform=transform,
                                         **kwargs)
    elif kind == 'default_web':
        dataset = InpaintingTrainWebDataset(indir=indir,
                                            mask_generator=mask_generator,
                                            transform=transform,
                                            **kwargs)
    elif kind == 'img_with_segm':
        dataset = ImgSegmentationDataset(indir=indir,
                                         mask_generator=mask_generator,
                                         transform=transform,
                                         out_size=out_size,
                                         **kwargs)
    else:
        raise ValueError(f'Unknown train dataset kind {kind}')

    # Work on a plain-dict copy: the caller's mapping (possibly an OmegaConf config node) must not
    # be mutated across calls, and a config node is not a suitable holder for the sampler object.
    dataloader_kwargs = {} if dataloader_kwargs is None else dict(dataloader_kwargs)

    is_dataset_only_iterable = kind in ('default_web',)
    if ddp_kwargs is not None:
        if is_dataset_only_iterable:
            LOGGER.warning(f'ddp_kwargs are ignored for iterable train dataset kind {kind}')
        else:
            # The sampler does the shuffling; DataLoader forbids shuffle=True together with a sampler.
            dataloader_kwargs['shuffle'] = False
            dataloader_kwargs['sampler'] = DistributedSampler(dataset, **ddp_kwargs)

    if is_dataset_only_iterable:
        # Iterable datasets shuffle internally; DataLoader rejects a shuffle flag for them.
        dataloader_kwargs.pop('shuffle', None)

    return DataLoader(dataset, **dataloader_kwargs)
def make_default_val_dataset(indir, kind='default', out_size=512, transform_variant='default', **kwargs):
    """Build a validation dataset, or a ``ConcatDataset`` when ``indir`` is a list of paths.

    Args:
        indir: a directory, or a list/tuple/OmegaConf list of directories (one sub-dataset each).
        kind: 'default', 'our_eval', 'img_with_segm' or 'online'.
        out_size: size passed to ``get_transforms`` and to the 'img_with_segm'/'online' datasets.
        transform_variant: augmentation variant for ``get_transforms``; ``None`` means no transform.
        **kwargs: may contain 'mask_generator_kind' / 'mask_gen_kwargs' (read for
            ``get_mask_generator``). The kwargs are forwarded to the dataset constructor,
            except that those two keys are dropped for 'img_with_segm'.

    Raises:
        ValueError: for an unknown ``kind``, or 'img_with_segm' without a transform.
    """
    if OmegaConf.is_list(indir) or isinstance(indir, (tuple, list)):
        return ConcatDataset([
            make_default_val_dataset(idir, kind=kind, out_size=out_size, transform_variant=transform_variant, **kwargs)
            for idir in indir
        ])

    LOGGER.info(f'Make val dataloader {kind} from {indir}')
    mask_generator = get_mask_generator(kind=kwargs.get("mask_generator_kind"), kwargs=kwargs.get("mask_gen_kwargs"))
    # Default to None: previously transform_variant=None left `transform` unbound, so the
    # 'img_with_segm' and 'online' branches crashed with UnboundLocalError.
    transform = get_transforms(transform_variant, out_size) if transform_variant is not None else None

    if kind == 'default':
        dataset = InpaintingEvaluationDataset(indir, **kwargs)
    elif kind == 'our_eval':
        dataset = OurInpaintingEvaluationDataset(indir, **kwargs)
    elif kind == 'img_with_segm':
        if transform is None:
            raise ValueError("Val dataset kind 'img_with_segm' requires a transform_variant")
        # ImgSegmentationDataset's signature has no mask-generator config keys (they were
        # consumed above), so forwarding them would raise TypeError.
        segm_kwargs = {key: value for key, value in kwargs.items()
                       if key not in ('mask_generator_kind', 'mask_gen_kwargs')}
        dataset = ImgSegmentationDataset(indir=indir,
                                         mask_generator=mask_generator,
                                         transform=transform,
                                         out_size=out_size,
                                         **segm_kwargs)
    elif kind == 'online':
        dataset = InpaintingEvalOnlineDataset(indir=indir,
                                              mask_generator=mask_generator,
                                              transform=transform,
                                              out_size=out_size,
                                              **kwargs)
    else:
        raise ValueError(f'Unknown val dataset kind {kind}')
    return dataset
def make_default_val_dataloader(*args, dataloader_kwargs=None, **kwargs):
    """Wrap ``make_default_val_dataset(*args, **kwargs)`` in a ``DataLoader``.

    ``dataloader_kwargs`` (default: none) are passed straight to ``DataLoader``.
    """
    loader_options = {} if dataloader_kwargs is None else dataloader_kwargs
    val_dataset = make_default_val_dataset(*args, **kwargs)
    return DataLoader(val_dataset, **loader_options)
def make_constant_area_crop_params(img_height, img_width, min_size=128, max_size=512, area=256*256, round_to_mod=16):
    """Sample a random crop window whose area is roughly ``area``.

    With equal probability either the height or the width is drawn uniformly from
    ``[min_size, max_size]`` (both first clamped to the image's smaller side). The other
    side is ``area // drawn_side``. Both sides are rounded up via ``ceil_modulo(.., round_to_mod)``
    and capped at the clamped ``max_size``, so the window always fits inside the image.

    Returns:
        ``(start_y, start_x, out_height, out_width)``.
    """
    lower = min(img_height, img_width, min_size)
    upper = min(img_height, img_width, max_size)

    def snap(length):
        # Round up to the modulo, but never past the largest allowed side.
        return min(upper, ceil_modulo(length, round_to_mod))

    height_is_drawn = random.random() < 0.5
    drawn_side = snap(random.randint(lower, upper))
    derived_side = snap(area // drawn_side)
    if height_is_drawn:
        out_height, out_width = drawn_side, derived_side
    else:
        out_height, out_width = derived_side, drawn_side

    start_y = random.randint(0, img_height - out_height)
    start_x = random.randint(0, img_width - out_width)
    return start_y, start_x, out_height, out_width