|
import torch |
|
import torchvision.transforms as transforms |
|
import torch.utils.data as data |
|
from util import task |
|
from .image_folder import make_dataset |
|
import random |
|
import numpy as np |
|
import copy |
|
import skimage.morphology as sm |
|
from PIL import Image, ImageFile, ImageOps |
|
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
|
|
|
|
|
|
|
|
|
class CreateDataset(data.Dataset):
    """Image-completion dataset pairing each image with a corruption mask.

    Each item is a dict with:
      - 'img_org':  image tensor produced by the configurable transform
                    (un-normalized, variable size unless batch-padded)
      - 'img':      the same image run through a deterministic pipeline that
                    scales its long side to ``opt.fixed_size``
      - 'img_path': path of the source image file
      - 'mask':     mask tensor selected/generated per ``opt.mask_type``
      - 'pad_mask': 1 over the valid image area, 0 over batch padding
    """

    def __init__(self, opt):
        self.opt = opt
        self.img_paths, self.img_size = make_dataset(opt.img_file)
        # External mask files are only consumed by mask_type 3 in _load_mask.
        if opt.mask_file != 'none':
            self.mask_paths, self.mask_size = make_dataset(opt.mask_file)
        # PIL -> PIL pipeline (convert=False): ToTensor is applied later in
        # _load_img so the PIL image can also feed transform_fixed.
        self.transform = get_transform(opt, convert=False, augment=False)
        # Second, deterministic pipeline producing a fixed-size tensor:
        # long side scaled to opt.fixed_size, flipping disabled.
        fixed_opt = copy.deepcopy(opt)
        fixed_opt.preprocess = 'scale_longside'
        fixed_opt.load_size = fixed_opt.fixed_size
        fixed_opt.no_flip = True
        self.transform_fixed = get_transform(fixed_opt, convert=True, augment=False)

    def __len__(self):
        """return the total number of examples in the dataset"""
        return self.img_size

    def __getitem__(self, item):
        """return a data point and its metadata information"""

        img_org, img, img_path = self._load_img(item)
        # When batching, pad every sample to a common square size so the
        # default collate can stack them.  NOTE(review): tensors are (C, H, W)
        # and _load_img assigns self.img_w = H, self.img_h = W — the names are
        # swapped relative to the usual width/height convention; the pad
        # arguments below (left, top, right, bottom) are consistent with that
        # swap (right pad = fine_size - width, bottom pad = fine_size - height).
        if self.opt.batch_size > 1:
            img_org = transforms.functional.pad(img_org, (0, 0, self.opt.fine_size-self.img_h, self.opt.fine_size-self.img_w))
            img = transforms.functional.pad(img, (0, 0, self.opt.fixed_size - img.size(-1), self.opt.fixed_size - img.size(-2)))
        # pad_mask marks the original (un-padded) image region with ones.
        pad_mask = torch.zeros_like(img_org)
        pad_mask[:, :self.img_w, :self.img_h] = 1

        mask, mask_type = self._load_mask(item, img_org)
        if self.opt.reverse_mask:
            # Training: invert the mask with probability 0.2; testing: always invert.
            if self.opt.isTrain:
                mask = 1 - mask if random.random() > 0.8 else mask
            else:
                mask = 1 - mask
        return {'img_org': img_org, 'img': img, 'img_path': img_path, 'mask': mask, 'pad_mask': pad_mask}

    def name(self):
        return ""

    def _load_img(self, item):
        """load the original image and preprocess image"""
        img_path = self.img_paths[item % self.img_size]
        img_pil = Image.open(img_path).convert('RGB')
        img_org = self.transform(img_pil)  # PIL -> PIL (pipeline built with convert=False)
        img = self.transform_fixed(img_org)  # PIL -> fixed-size tensor
        img_org = transforms.ToTensor()(img_org)
        img_pil.close()
        # size() is (C, H, W); note img_w receives H and img_h receives W
        # (swapped names — __getitem__ relies on exactly this assignment).
        self.img_c, self.img_w, self.img_h = img_org.size()
        return img_org, img, img_path

    def _mask_dilation(self, mask):
        """mask erosion for different region"""
        # NOTE(review): despite the method name, this applies *erosion* with a
        # randomly sized (3-24 px) square structuring element, shrinking the
        # bright region of the mask.
        mask = np.array(mask)
        pixel = np.random.randint(3, 25)
        mask = sm.erosion(mask, sm.square(pixel)).astype(np.uint8)

        return mask

    def _load_mask(self, item, img):
        """load the mask for image completion task"""
        c, h, w = img.size()
        # mask_type may be a list of allowed types; pick one at random per sample.
        if isinstance(self.opt.mask_type, list):
            mask_type_index = random.randint(0, len(self.opt.mask_type) - 1)
            mask_type = self.opt.mask_type[mask_type_index]
        else:
            mask_type = self.opt.mask_type

        # 0: random regular mask (70%) / center mask (30%) while training;
        #    always the center mask at test time.
        if mask_type == 0:
            if random.random() > 0.3 and self.opt.isTrain:
                return task.random_regular_mask(img), mask_type
            return task.center_mask(img), mask_type
        # 1: random regular (rectangular) mask
        elif mask_type == 1:
            return task.random_regular_mask(img), mask_type
        # 2: random irregular (free-form) mask
        elif mask_type == 2:
            return task.random_irregular_mask(img), mask_type
        # 3: external mask image loaded from opt.mask_file
        elif mask_type == 3:

            if self.opt.isTrain:
                # Random image/mask pairing plus mask augmentation for training.
                mask_index = random.randint(0, self.mask_size-1)
                mask_transform = transforms.Compose(
                    [
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomRotation(10),
                        transforms.RandomCrop([self.opt.fine_size + 64, self.opt.fine_size + 64]),
                        transforms.Resize([h, w])
                    ]
                )
            else:
                # Deterministic pairing at test time: mask `item` goes with
                # image `item` (assumes mask_size >= img_size — TODO confirm).
                mask_index = item
                mask_transform = transforms.Compose(
                    [
                        transforms.Resize([h, w])
                    ]
                )
            mask_pil = Image.open(self.mask_paths[mask_index]).convert('L')
            mask = mask_transform(mask_pil)
            mask_pil.close()
            if self.opt.isTrain:
                # NOTE(review): the eroded training mask keeps 0..255 gray
                # values, while the eval branch below thresholds to {0, 1}
                # with inverted polarity (dark pixels -> 1).  Confirm that
                # downstream code normalizes the training mask accordingly.
                mask = self._mask_dilation(mask)
            else:
                mask = np.array(mask) < 128
            mask = torch.tensor(mask).view(1, h, w).float()
            return mask, mask_type
        else:
            raise NotImplementedError('mask type [%s] is not implemented' % str(mask_type))
|
|
|
|
|
def dataloader(opt):
    """Build a DataLoader over a ``CreateDataset`` configured by ``opt``."""
    return data.DataLoader(
        CreateDataset(opt),
        batch_size=opt.batch_size,
        shuffle=not opt.no_shuffle,
        num_workers=int(opt.nThreads),
        drop_last=True,
    )
|
|
|
|
|
|
|
|
|
|
|
def _make_power_2(img, power, method=Image.BICUBIC):
    """Snap both image dimensions to the nearest multiple of ``2 ** power``."""
    base = 2 ** power
    width, height = img.size
    # Round each side to a multiple of `base`, never collapsing below `base`.
    new_w = int(max(1, round(width / base)) * base)
    new_h = int(max(1, round(height / base)) * base)
    if (new_w, new_h) == (width, height):
        return img
    return img.resize((new_w, new_h), method)
|
|
|
|
|
def _random_zoom(img, target_width, method=Image.BICUBIC):
    """Rescale width/height by independent random factors in [0.8, 1.0),
    never letting either side drop below ``target_width``."""
    zoom_w, zoom_h = np.random.uniform(0.8, 1.0, size=[2])
    width, height = img.size
    new_w = int(round(max(target_width, width * zoom_w)))
    new_h = int(round(max(target_width, height * zoom_h)))
    return img.resize((new_w, new_h), method)
|
|
|
|
|
def _scale_shortside(img, target_width, method=Image.BICUBIC):
    """Resize so the shorter side equals ``target_width``, keeping aspect ratio."""
    width, height = img.size
    factor = target_width / min(width, height)
    return img.resize((round(width * factor), round(height * factor)), method)
|
|
|
|
|
def _scale_longside(img, target_width, method=Image.BICUBIC):
    """Resize so the longer side equals ``target_width``, keeping aspect ratio.

    Args:
        img: PIL image to resize.
        target_width: desired length of the longer side in pixels.
        method: PIL resampling filter.
    """
    ow, oh = img.size
    longsize = max(ow, oh)
    scale = target_width / longsize
    # Clamp each side to >= 1 px: for extreme aspect ratios the shorter side
    # can round to 0, and PIL's Image.resize rejects zero-sized targets.
    return img.resize((max(1, round(ow * scale)), max(1, round(oh * scale))), method)
|
|
|
|
|
def _scale_randomside(img, target_width, method=Image.BICUBIC):
    """Scale either the short or the long side to ``target_width``, 50/50 at random."""
    scaler = _scale_shortside if random.random() > 0.5 else _scale_longside
    return scaler(img, target_width, method)
|
|
|
|
|
def _crop(img, pos=None, size=None): |
|
"""crop the image based on the given pos and size""" |
|
ow, oh = img.size |
|
if size is None: |
|
return img |
|
nw = min(ow, size) |
|
nh = min(oh, size) |
|
if (ow > nw or oh > nh): |
|
if pos is None: |
|
x1 = np.random.randint(0, int(ow-nw)+1) |
|
y1 = np.random.randint(0, int(oh-nh)+1) |
|
else: |
|
x1, y1 = pos |
|
return img.crop((x1, y1, x1 + nw, y1 + nh)) |
|
return img |
|
|
|
|
|
def _pad(img):
    """Pad the image on the right/bottom edges to a square of its longer side."""
    width, height = img.size
    side = max(width, height)
    # centering=(0, 0) anchors the image at the top-left corner.
    return ImageOps.pad(img, (side, side), centering=(0, 0))
|
|
|
|
|
def _flip(img, flip): |
|
if flip: |
|
return img.transpose(Image.FLIP_LEFT_RIGHT) |
|
return img |
|
|
|
|
|
def get_transform(opt, params=None, method=Image.BICUBIC, convert=True, augment=False):
    """Compose the preprocessing pipeline described by ``opt.preprocess``.

    Args:
        opt: options object (reads preprocess, load_size, fine_size,
             data_powers, no_flip, isTrain).
        params: unused, kept for interface compatibility.
        method: PIL resampling filter for the resize helpers.
        convert: append ToTensor at the end of the pipeline.
        augment: add color jitter (training only).
    """
    pipeline = []

    # At most one resize strategy applies (first match wins).
    if 'resize' in opt.preprocess:
        pipeline.append(transforms.Resize([opt.load_size, opt.load_size]))
    elif 'scale_shortside' in opt.preprocess:
        pipeline.append(transforms.Lambda(lambda img: _scale_shortside(img, opt.load_size, method)))
    elif 'scale_longside' in opt.preprocess:
        pipeline.append(transforms.Lambda(lambda img: _scale_longside(img, opt.load_size, method)))
    elif "scale_randomside" in opt.preprocess:
        pipeline.append(transforms.Lambda(lambda img: _scale_randomside(img, opt.load_size, method)))

    if 'zoom' in opt.preprocess:
        pipeline.append(transforms.Lambda(lambda img: _random_zoom(img, opt.load_size, method)))

    # Random cropping is a training-only augmentation.
    if 'crop' in opt.preprocess and opt.isTrain:
        pipeline.append(transforms.Lambda(lambda img: _crop(img, size=opt.fine_size)))
    if 'pad' in opt.preprocess:
        pipeline.append(transforms.Lambda(lambda img: _pad(img)))

    # Always snap dimensions to a multiple of 2 ** opt.data_powers.
    pipeline.append(transforms.Lambda(lambda img: _make_power_2(img, opt.data_powers, method)))

    if not opt.no_flip and opt.isTrain:
        pipeline.append(transforms.RandomHorizontalFlip())

    if augment and opt.isTrain:
        pipeline.append(transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2))

    if convert:
        pipeline.append(transforms.ToTensor())

    return transforms.Compose(pipeline)