|
import os |
|
import random |
|
import imageio |
|
import numpy as np |
|
import torch.utils.data as data |
|
|
|
from data import common |
|
|
|
from utils import interact |
|
|
|
class Dataset(data.Dataset):
    """Base dataset pairing blurry input images with optional sharp targets.

    Image paths are discovered by walking ``subset_root`` and collecting the
    files of every leaf directory whose path contains ``blur_key`` (inputs)
    or ``sharp_key`` (targets) and none of the matching exclusion keys.

    Subclasses customise behaviour by overriding :meth:`set_modes` and
    :meth:`set_keys`; both are invoked from ``__init__`` before the scan.
    """

    def __init__(self, args, mode='train'):
        """Resolve the subset directory for ``mode`` and scan it for images.

        Args:
            args: option namespace, stored as ``self.args``. This method
                reads ``demo_input_dir`` in ``'demo'`` mode; otherwise it
                reads ``data_root`` and one of ``data_train`` / ``data_val``
                / ``data_test`` (selected by ``mode``).
            mode: one of the names registered by :meth:`set_modes`.

        Raises:
            NotImplementedError: if ``mode`` is not in ``self.modes``, or if
                it is accepted by ``self.modes`` but has no dataset option.
        """
        super().__init__()
        self.args = args
        self.mode = mode

        self.modes = ()
        self.set_modes()
        self._check_mode()

        self.set_keys()

        if self.mode == 'demo':
            # Demo images come straight from a user-supplied directory.
            self.subset_root = args.demo_input_dir
        else:
            option_by_mode = {
                'train': 'data_train',
                'val': 'data_val',
                'test': 'data_test',
            }
            if self.mode not in option_by_mode:
                # Reachable only when a subclass registers an extra mode.
                raise NotImplementedError('not implemented for this mode: {}!'.format(self.mode))
            dataset = getattr(args, option_by_mode[self.mode])
            self.subset_root = os.path.join(args.data_root, dataset, self.mode)

        self.blur_list = []
        self.sharp_list = []

        self._scan()

    def set_modes(self):
        """Register the mode names this dataset accepts."""
        self.modes = ('train', 'val', 'test', 'demo')

    def _check_mode(self):
        """Validate ``self.mode`` against ``self.modes``.

        Raises:
            NotImplementedError: if the mode is not registered.
        """
        if self.mode not in self.modes:
            raise NotImplementedError('mode error: not for {}'.format(self.mode))

    def set_keys(self):
        """Set the directory-name keys used by :meth:`_scan`.

        ``blur_key`` / ``sharp_key`` must appear in a directory path for its
        files to be collected; ``non_blur_keys`` / ``non_sharp_keys`` list
        substrings that exclude a path from the respective list.
        """
        self.blur_key = 'blur'
        self.sharp_key = 'sharp'

        self.non_blur_keys = []
        self.non_sharp_keys = []

    def _scan(self, root=None):
        """Populate ``blur_list`` and ``sharp_list`` with sorted file paths.

        Called from ``__init__``; may be called again after changing keys.
        As a side effect, all keys are rewritten with a trailing path
        separator and the exclusion lists are rebuilt.

        Args:
            root: directory to walk. Defaults to ``self.subset_root``.

        Raises:
            AssertionError: if sharp images were found but their count
                differs from the number of blur images.
        """
        if root is None:
            root = self.subset_root

        def _with_sep(key):
            # A trailing separator makes the key match only where a directory
            # name ends, e.g. 'blur/' does not match 'blur_gamma/'.
            return os.path.join(key, '')

        # Normalise BEFORE dropping self-exclusions: otherwise an exclusion
        # spelled 'blur/' (or a duplicated entry) would survive, collide with
        # the normalised positive key and silently empty the file list.
        self.blur_key = _with_sep(self.blur_key)
        self.sharp_key = _with_sep(self.sharp_key)
        self.non_blur_keys = [
            key for key in map(_with_sep, self.non_blur_keys)
            if key != self.blur_key
        ]
        self.non_sharp_keys = [
            key for key in map(_with_sep, self.non_sharp_keys)
            if key != self.sharp_key
        ]

        def _matches(directory, true_key, false_keys):
            # Substring test on the separator-terminated directory path.
            directory = os.path.join(directory, '')
            if true_key not in directory:
                return False
            return not any(false_key in directory for false_key in false_keys)

        def _collect(top, true_key, false_keys):
            found = []
            for sub, dirs, files in os.walk(top):
                # Only leaf directories (no sub-directories) hold images.
                if dirs:
                    continue
                if _matches(sub, true_key, false_keys):
                    found.extend(os.path.join(sub, name) for name in files)
            found.sort()
            return found

        self.blur_list = _collect(root, self.blur_key, self.non_blur_keys)
        self.sharp_list = _collect(root, self.sharp_key, self.non_sharp_keys)

        if self.sharp_list and len(self.blur_list) != len(self.sharp_list):
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; the exception type is unchanged.
            raise AssertionError(
                'blur/sharp count mismatch under {}: {} blur vs {} sharp'.format(
                    root, len(self.blur_list), len(self.sharp_list)))

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at ``idx``.

        Reads ``patch_size``, ``augment``, ``rgb_range``, ``n_scales`` and
        ``gaussian_pyramid`` from ``self.args`` depending on the mode. The
        exact semantics of the ``common.*`` helpers are defined in
        ``data/common.py`` and are not visible from this file.

        Returns:
            tuple ``(blur, sharp, pad_width, idx, relpath)`` where ``sharp``
            is ``False`` when no sharp image is available, ``pad_width`` is
            ``0`` outside demo mode, and ``relpath`` is the blur image path
            relative to ``self.subset_root``.
        """
        blur_path = self.blur_list[idx]

        imgs = [imageio.imread(blur_path, pilmode='RGB')]
        if len(self.sharp_list) > 0:
            imgs.append(imageio.imread(self.sharp_list[idx], pilmode='RGB'))

        pad_width = 0
        if self.mode == 'train':
            imgs = common.crop(*imgs, ps=self.args.patch_size)
            if self.args.augment:
                imgs = common.augment(*imgs, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=self.args.rgb_range)
                # NOTE(review): indentation was lost in the source this was
                # recovered from; noise is assumed to be applied only when
                # augmentation is enabled, and only to the blur input — confirm.
                imgs[0] = common.add_noise(imgs[0], sigma_sigma=2, rgb_range=self.args.rgb_range)
        elif self.mode == 'demo':
            # Presumably pads so the size is divisible at every scale — confirm in common.pad.
            imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1))
        # Other modes ('val', 'test') pass the images through unchanged.

        if self.args.gaussian_pyramid:
            imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales)

        imgs = common.np2tensor(*imgs)
        relpath = os.path.relpath(blur_path, self.subset_root)

        blur = imgs[0]
        # False is returned in place of a missing target.
        sharp = imgs[1] if len(imgs) > 1 else False

        return blur, sharp, pad_width, idx, relpath

    def __len__(self):
        """Return the number of blur images found by the scan."""
        return len(self.blur_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|