File size: 4,816 Bytes
8ec10cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
import random
import imageio
import numpy as np
import torch.utils.data as data
from data import common
from utils import interact
class Dataset(data.Dataset):
    """Basic blur/sharp image dataset.

    Scans ``self.subset_root`` for leaf directories whose path contains the
    blur key (resp. sharp key) and none of the matching exclusion keys, and
    serves blur images optionally paired with sharp images. Pairing is by
    position in the sorted file lists.

    Child classes customise behaviour by overriding ``set_modes`` (allowed
    modes) and ``set_keys`` (directory keys). Both are called from
    ``__init__`` before the directory scan.
    """

    def __init__(self, args, mode='train'):
        """Resolve the data directory for ``mode`` and scan it.

        Args:
            args: option namespace. Read here: ``data_train``, ``data_val``,
                ``data_test``, ``data_root``, ``demo_input_dir``. Read in
                ``__getitem__``: ``patch_size``, ``augment``, ``rgb_range``,
                ``n_scales``, ``gaussian_pyramid``.
            mode: one of ``self.modes`` (by default 'train', 'val', 'test',
                'demo').

        Raises:
            NotImplementedError: if ``mode`` is not supported.
        """
        super(Dataset, self).__init__()
        self.args = args
        self.mode = mode

        self.modes = ()
        self.set_modes()
        self._check_mode()
        self.set_keys()

        if self.mode == 'train':
            dataset = args.data_train
        elif self.mode == 'val':
            dataset = args.data_val
        elif self.mode == 'test':
            dataset = args.data_test
        elif self.mode == 'demo':
            dataset = None  # unused: demo reads from args.demo_input_dir
        else:
            # Reachable only if a child class adds extra modes in set_modes()
            # without also handling them here.
            raise NotImplementedError('not implemented for this mode: {}!'.format(self.mode))

        if self.mode == 'demo':
            self.subset_root = args.demo_input_dir
        else:
            self.subset_root = os.path.join(args.data_root, dataset, self.mode)

        self.blur_list = []
        self.sharp_list = []
        self._scan()

    def set_modes(self):
        """Define the modes this dataset accepts (child classes may override)."""
        self.modes = ('train', 'val', 'test', 'demo')

    def _check_mode(self):
        """Raise NotImplementedError if ``self.mode`` is not in ``self.modes``.

        Called from ``__init__`` right after ``set_modes()``.
        """
        if self.mode not in self.modes:
            raise NotImplementedError('mode error: not for {}'.format(self.mode))

    def set_keys(self):
        """Define the directory keys used by ``_scan`` (child classes may override).

        ``blur_key``/``sharp_key`` must appear in a directory path for its
        files to be collected. Paths containing any of
        ``non_blur_keys``/``non_sharp_keys`` are skipped.
        """
        self.blur_key = 'blur'      # to be overwritten by child class
        self.sharp_key = 'sharp'    # to be overwritten by child class
        self.non_blur_keys = []
        self.non_sharp_keys = []

    def _scan(self, root=None):
        """Populate ``self.blur_list`` and ``self.sharp_list`` from ``root``.

        Walks ``root`` and collects every file in a leaf directory (one with
        no subdirectories). A directory qualifies when its path, with a
        trailing separator, contains the blur (resp. sharp) key and none of
        the matching exclusion keys. Both lists are sorted, so the i-th blur
        file pairs with the i-th sharp file.

        Args:
            root: directory to scan; defaults to ``self.subset_root``.

        Raises:
            ValueError: if sharp files were found but their count differs
                from the number of blur files.
        """
        if root is None:
            root = self.subset_root

        def _as_dir(key):
            # Append a path separator so a key must be followed by one,
            # e.g. 'blur' matches '.../blur/' but not '.../blur_gamma/'.
            # Idempotent, so repeated scans are safe.
            return os.path.join(key, '')

        # Normalise first, then drop each key from its own exclusion list.
        # Doing it in this order means 'blur' and 'blur/' count as the same
        # key; otherwise the key could exclude itself and yield no files.
        self.blur_key = _as_dir(self.blur_key)
        self.sharp_key = _as_dir(self.sharp_key)
        self.non_blur_keys = [key for key in map(_as_dir, self.non_blur_keys)
                              if key != self.blur_key]
        self.non_sharp_keys = [key for key in map(_as_dir, self.non_sharp_keys)
                               if key != self.sharp_key]

        def _key_check(path, true_key, false_keys):
            """Return True if ``path/`` contains ``true_key`` and no ``false_keys``."""
            path = os.path.join(path, '')
            return true_key in path and not any(false_key in path for false_key in false_keys)

        def _get_list_by_key(true_key, false_keys):
            """Return the sorted files of all matching leaf directories under ``root``."""
            found = []
            for sub, dirs, files in os.walk(root):
                if not dirs and _key_check(sub, true_key, false_keys):
                    found.extend(os.path.join(sub, f) for f in files)
            found.sort()
            return found

        self.blur_list = _get_list_by_key(self.blur_key, self.non_blur_keys)
        self.sharp_list = _get_list_by_key(self.sharp_key, self.non_sharp_keys)

        # An empty sharp_list is allowed (__getitem__ then serves blur only).
        # Otherwise the lists are paired by index and must have equal length.
        # This is an explicit raise, not an assert, so it survives `python -O`.
        if self.sharp_list and len(self.blur_list) != len(self.sharp_list):
            raise ValueError(
                'blur/sharp image count mismatch under {}: {} blur vs {} sharp'.format(
                    root, len(self.blur_list), len(self.sharp_list)))

    def __getitem__(self, idx):
        """Load, preprocess and convert sample ``idx``.

        In 'train' mode the images are cropped to ``args.patch_size``. If
        ``args.augment`` is set, they are also augmented and noise is added
        to the blur image. In 'demo' mode the blur image is padded. 'val' and
        'test' images are passed through unchanged. If
        ``args.gaussian_pyramid`` is set, a pyramid of ``args.n_scales``
        levels is then built.

        Returns:
            tuple ``(blur, sharp, pad_width, idx, relpath)``:
            - ``blur``/``sharp``: first/second items returned by
              ``common.np2tensor``. ``sharp`` is False when no sharp images
              exist.
            - ``pad_width``: 0 unless mode is 'demo', where it comes from
              ``common.pad``.
            - ``idx``: the index passed in.
            - ``relpath``: the blur file path relative to
              ``self.subset_root``.
        """
        blur = imageio.imread(self.blur_list[idx], pilmode='RGB')
        if self.sharp_list:
            sharp = imageio.imread(self.sharp_list[idx], pilmode='RGB')
            imgs = [blur, sharp]
        else:
            imgs = [blur]

        pad_width = 0   # dummy value; only demo mode replaces it
        if self.mode == 'train':
            imgs = common.crop(*imgs, ps=self.args.patch_size)
            if self.args.augment:
                imgs = common.augment(*imgs, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=self.args.rgb_range)
                imgs[0] = common.add_noise(imgs[0], sigma_sigma=2, rgb_range=self.args.rgb_range)
        elif self.mode == 'demo':
            # Pad in case the size is not divisible by 2**(n_scales - 1).
            imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1))
        # Any other mode ('val', 'test') delivers the images as is.

        if self.args.gaussian_pyramid:
            imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales)

        imgs = common.np2tensor(*imgs)
        relpath = os.path.relpath(self.blur_list[idx], self.subset_root)

        blur = imgs[0]
        sharp = imgs[1] if len(imgs) > 1 else False

        return blur, sharp, pad_width, idx, relpath

    def __len__(self):
        """Return the number of blur images found by ``_scan``."""
        return len(self.blur_list)
|