hyliu committed
Commit e98653e · verified · 1 Parent(s): 3ef0208

Upload folder using huggingface_hub
deblur/experiment/.gitignore ADDED
@@ -0,0 +1 @@
+ !.gitignore
deblur/src/data/__init__.py ADDED
@@ -0,0 +1,79 @@
+ """Generic dataset loader"""
+
+ from importlib import import_module
+
+ from torch.utils.data import DataLoader
+ from torch.utils.data import SequentialSampler, RandomSampler
+ from torch.utils.data.distributed import DistributedSampler
+ from .sampler import DistributedEvalSampler
+
+ class Data():
+     def __init__(self, args):
+
+         self.modes = ['train', 'val', 'test', 'demo']
+
+         self.action = {
+             'train': args.do_train,
+             'val': args.do_validate,
+             'test': args.do_test,
+             'demo': args.demo
+         }
+
+         self.dataset_name = {
+             'train': args.data_train,
+             'val': args.data_val,
+             'test': args.data_test,
+             'demo': 'Demo'
+         }
+
+         self.args = args
+
+         def _get_data_loader(mode='train'):
+             dataset_name = self.dataset_name[mode]
+             dataset = import_module('data.' + dataset_name.lower())
+             dataset = getattr(dataset, dataset_name)(args, mode)
+
+             if mode == 'train':
+                 if args.distributed:
+                     batch_size = int(args.batch_size / args.n_GPUs)  # batch size per GPU (single-node training)
+                     sampler = DistributedSampler(dataset, shuffle=True, num_replicas=args.world_size, rank=args.rank)
+                     num_workers = int((args.num_workers + args.n_GPUs - 1) / args.n_GPUs)  # num_workers per GPU (single-node training)
+                 else:
+                     batch_size = args.batch_size
+                     sampler = RandomSampler(dataset, replacement=False)
+                     num_workers = args.num_workers
+                 drop_last = True
+
+             elif mode in ('val', 'test', 'demo'):
+                 if args.distributed:
+                     batch_size = 1  # 1 image per GPU
+                     sampler = DistributedEvalSampler(dataset, shuffle=False, num_replicas=args.world_size, rank=args.rank)
+                     num_workers = int((args.num_workers + args.n_GPUs - 1) / args.n_GPUs)  # num_workers per GPU (single-node training)
+                 else:
+                     batch_size = args.n_GPUs  # 1 image per GPU
+                     sampler = SequentialSampler(dataset)
+                     num_workers = args.num_workers
+                 drop_last = False
+
+             loader = DataLoader(
+                 dataset=dataset,
+                 batch_size=batch_size,
+                 shuffle=False,
+                 sampler=sampler,
+                 num_workers=num_workers,
+                 pin_memory=True,
+                 drop_last=drop_last,
+             )
+
+             return loader
+
+         self.loaders = {}
+         for mode in self.modes:
+             if self.action[mode]:
+                 self.loaders[mode] = _get_data_loader(mode)
+                 print('===> Loading {} dataset: {}'.format(mode, self.dataset_name[mode]))
+             else:
+                 self.loaders[mode] = None
+
+     def get_loader(self):
+         return self.loaders
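
A minimal usage sketch (not part of the commit): Data reads an argparse-style namespace; the attribute names below are the ones the loader code above actually touches, and the values are hypothetical. The dataset class itself reads further fields (data_root, patch_size, augment, rgb_range, gaussian_pyramid, n_scales), omitted here for brevity.

    from types import SimpleNamespace
    from data import Data

    # hypothetical single-GPU, train-only configuration
    args = SimpleNamespace(
        do_train=True, do_validate=False, do_test=False, demo=False,
        data_train='GOPRO_Large', data_val=None, data_test=None,
        distributed=False, batch_size=8, num_workers=4, n_GPUs=1,
        world_size=1, rank=0,
    )

    loaders = Data(args).get_loader()  # {'train': DataLoader, 'val': None, ...}
    for blur, sharp, pad_width, idx, relpath in loaders['train']:
        pass  # one training batch per iteration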
deblur/src/data/common.py ADDED
@@ -0,0 +1,165 @@
+ import random
+ import numpy as np
+ from skimage.color import rgb2hsv, hsv2rgb
+ # from skimage.transform import pyramid_gaussian
+
+ import torch
+
+ def _apply(func, x):
+
+     if isinstance(x, (list, tuple)):
+         return [_apply(func, x_i) for x_i in x]
+     elif isinstance(x, dict):
+         y = {}
+         for key, value in x.items():
+             y[key] = _apply(func, value)
+         return y
+     else:
+         return func(x)
+
+ def crop(*args, ps=256):  # patch_size
+     # args = [input, target]
+     def _get_shape(*args):
+         if isinstance(args[0], (list, tuple)):
+             return _get_shape(args[0][0])
+         elif isinstance(args[0], dict):
+             return _get_shape(list(args[0].values())[0])
+         else:
+             return args[0].shape
+
+     h, w, _ = _get_shape(args)
+
+     py = random.randrange(0, h-ps+1)
+     px = random.randrange(0, w-ps+1)
+
+     def _crop(img):
+         if img.ndim == 2:
+             return img[py:py+ps, px:px+ps, np.newaxis]
+         else:
+             return img[py:py+ps, px:px+ps, :]
+
+     return _apply(_crop, args)
+
+ def add_noise(*args, sigma_sigma=2, rgb_range=255):
+
+     if len(args) == 1:  # usually there is only a single input
+         args = args[0]
+
+     sigma = np.random.normal() * sigma_sigma * rgb_range/255
+
+     def _add_noise(img):
+         noise = np.random.randn(*img.shape).astype(np.float32) * sigma
+         return (img + noise).clip(0, rgb_range)
+
+     return _apply(_add_noise, args)
+
+ def augment(*args, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=255):
+     """Augmentation applied consistently to input and target"""
+
+     choices = (False, True)
+
+     hflip = hflip and random.choice(choices)
+     vflip = rot and random.choice(choices)
+     rot90 = rot and random.choice(choices)
+     # shuffle = shuffle
+
+     if shuffle:
+         rgb_order = list(range(3))
+         random.shuffle(rgb_order)
+         if rgb_order == list(range(3)):
+             shuffle = False
+
+     if change_saturation:
+         amp_factor = np.random.uniform(0.5, 1.5)
+
+     def _augment(img):
+         if hflip: img = img[:, ::-1, :]
+         if vflip: img = img[::-1, :, :]
+         if rot90: img = img.transpose(1, 0, 2)
+         if shuffle and img.ndim > 2:
+             if img.shape[-1] == 3:  # RGB image only
+                 img = img[..., rgb_order]
+
+         if change_saturation:
+             hsv_img = rgb2hsv(img)
+             hsv_img[..., 1] *= amp_factor
+
+             img = hsv2rgb(hsv_img).clip(0, 1) * rgb_range
+
+         return img.astype(np.float32)
+
+     return _apply(_augment, args)
+
+ def pad(img, divisor=4, pad_width=None, negative=False):
+
+     def _pad_numpy(img, divisor=4, pad_width=None, negative=False):
+         if pad_width is None:
+             (h, w, _) = img.shape
+             pad_h = -h % divisor
+             pad_w = -w % divisor
+             pad_width = ((0, pad_h), (0, pad_w), (0, 0))
+
+         img = np.pad(img, pad_width, mode='edge')
+
+         return img, pad_width
+
+     def _pad_tensor(img, divisor=4, pad_width=None, negative=False):
+
+         n, c, h, w = img.shape
+         if pad_width is None:
+             pad_h = -h % divisor
+             pad_w = -w % divisor
+             pad_width = (0, pad_w, 0, pad_h)
+         else:
+             try:
+                 pad_h = pad_width[0][1]
+                 pad_w = pad_width[1][1]
+                 if isinstance(pad_h, torch.Tensor):
+                     pad_h = pad_h.item()
+                 if isinstance(pad_w, torch.Tensor):
+                     pad_w = pad_w.item()
+
+                 pad_width = (0, pad_w, 0, pad_h)
+             except:
+                 pass
+
+         if negative:
+             pad_width = [-val for val in pad_width]
+
+         img = torch.nn.functional.pad(img, pad_width, 'reflect')
+
+         return img, pad_width
+
+     if isinstance(img, np.ndarray):
+         return _pad_numpy(img, divisor, pad_width, negative)
+     else:  # torch.Tensor
+         return _pad_tensor(img, divisor, pad_width, negative)
+
+ def generate_pyramid(*args, n_scales):
+
+     def _generate_pyramid(img):
+         if img.dtype != np.float32:
+             img = img.astype(np.float32)
+         # pyramid = list(pyramid_gaussian(img, n_scales-1, multichannel=True))
+         # bypass pyramid, deliver the image as is
+         pyramid = [img]
+
+         return pyramid
+
+     return _apply(_generate_pyramid, args)
+
+ def np2tensor(*args, rgb_range=255):
+     def _np2tensor(x):
+         np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1))
+         tensor = torch.from_numpy(np_transpose).float()
+         tensor.mul_(rgb_range / 255)
+         return tensor
+
+     return _apply(_np2tensor, args)
+
+ def to(*args, device=None, dtype=torch.float):
+
+     def _to(x):
+         return x.to(device=device, dtype=dtype, non_blocking=True, copy=False)
+
+     return _apply(_to, args)
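
The -h % divisor arithmetic in pad() yields the smallest non-negative padding that makes a dimension divisible by the divisor; a quick check:

    for h in (719, 720, 721):
        print(h, -h % 4)  # 719 -> 1 (pads to 720), 720 -> 0, 721 -> 3 (pads to 724)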
deblur/src/data/dataset.py ADDED
@@ -0,0 +1,154 @@
+ import os
+ import random
+ import imageio
+ import numpy as np
+ import torch.utils.data as data
+
+ from data import common
+
+ from utils import interact
+
+ class Dataset(data.Dataset):
+     """Basic dataloader class
+     """
+     def __init__(self, args, mode='train'):
+         super(Dataset, self).__init__()
+         self.args = args
+         self.mode = mode
+
+         self.modes = ()
+         self.set_modes()
+         self._check_mode()
+
+         self.set_keys()
+
+         if self.mode == 'train':
+             dataset = args.data_train
+         elif self.mode == 'val':
+             dataset = args.data_val
+         elif self.mode == 'test':
+             dataset = args.data_test
+         elif self.mode == 'demo':
+             pass
+         else:
+             raise NotImplementedError('not implemented for this mode: {}!'.format(self.mode))
+
+         if self.mode == 'demo':
+             self.subset_root = args.demo_input_dir
+         else:
+             self.subset_root = os.path.join(args.data_root, dataset, self.mode)
+
+         self.blur_list = []
+         self.sharp_list = []
+
+         self._scan()
+
+     def set_modes(self):
+         self.modes = ('train', 'val', 'test', 'demo')
+
+     def _check_mode(self):
+         """Should be called in the child class __init__() after super
+         """
+         if self.mode not in self.modes:
+             raise NotImplementedError('mode error: not for {}'.format(self.mode))
+
+         return
+
+     def set_keys(self):
+         self.blur_key = 'blur'    # to be overwritten by child class
+         self.sharp_key = 'sharp'  # to be overwritten by child class
+
+         self.non_blur_keys = []
+         self.non_sharp_keys = []
+
+         return
+
+     def _scan(self, root=None):
+         """Should be called in the child class __init__() after super
+         """
+         if root is None:
+             root = self.subset_root
+
+         if self.blur_key in self.non_blur_keys:
+             self.non_blur_keys.remove(self.blur_key)
+         if self.sharp_key in self.non_sharp_keys:
+             self.non_sharp_keys.remove(self.sharp_key)
+
+         def _key_check(path, true_key, false_keys):
+             path = os.path.join(path, '')
+             if path.find(true_key) >= 0:
+                 for false_key in false_keys:
+                     if path.find(false_key) >= 0:
+                         return False
+
+                 return True
+             else:
+                 return False
+
+         def _get_list_by_key(root, true_key, false_keys):
+             data_list = []
+             for sub, dirs, files in os.walk(root):
+                 if not dirs:
+                     file_list = [os.path.join(sub, f) for f in files]
+                     if _key_check(sub, true_key, false_keys):
+                         data_list += file_list
+
+             data_list.sort()
+
+             return data_list
+
+         def _rectify_keys():
+             self.blur_key = os.path.join(self.blur_key, '')
+             self.non_blur_keys = [os.path.join(non_blur_key, '') for non_blur_key in self.non_blur_keys]
+             self.sharp_key = os.path.join(self.sharp_key, '')
+             self.non_sharp_keys = [os.path.join(non_sharp_key, '') for non_sharp_key in self.non_sharp_keys]
+
+         _rectify_keys()
+
+         self.blur_list = _get_list_by_key(root, self.blur_key, self.non_blur_keys)
+         self.sharp_list = _get_list_by_key(root, self.sharp_key, self.non_sharp_keys)
+
+         if len(self.sharp_list) > 0:
+             assert(len(self.blur_list) == len(self.sharp_list))
+
+         return
+
+     def __getitem__(self, idx):
+
+         blur = imageio.imread(self.blur_list[idx], pilmode='RGB')
+         if len(self.sharp_list) > 0:
+             sharp = imageio.imread(self.sharp_list[idx], pilmode='RGB')
+             imgs = [blur, sharp]
+         else:
+             imgs = [blur]
+
+         pad_width = 0  # dummy value
+         if self.mode == 'train':
+             imgs = common.crop(*imgs, ps=self.args.patch_size)
+             if self.args.augment:
+                 imgs = common.augment(*imgs, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=self.args.rgb_range)
+                 imgs[0] = common.add_noise(imgs[0], sigma_sigma=2, rgb_range=self.args.rgb_range)
+         elif self.mode == 'demo':
+             imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1))  # pad in case of non-divisible size
+         else:
+             pass  # deliver test image as is
+
+         if self.args.gaussian_pyramid:
+             imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales)
+
+         imgs = common.np2tensor(*imgs, rgb_range=self.args.rgb_range)
+         relpath = os.path.relpath(self.blur_list[idx], self.subset_root)
+
+         blur = imgs[0]
+         sharp = imgs[1] if len(imgs) > 1 else False
+
+         return blur, sharp, pad_width, idx, relpath
+
+     def __len__(self):
+         return len(self.blur_list)
+         # return 32
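
A note on the key matching above: _rectify_keys appends a trailing separator to each key so that substring matching over directory paths cannot confuse sibling folders. A small check with hypothetical GOPRO-style paths:

    import os

    # layout walked by _scan (leaf directories only), hypothetical:
    #   <data_root>/GOPRO_Large/train/GOPR0372_07_00/blur_gamma/000047.png
    #   <data_root>/GOPRO_Large/train/GOPR0372_07_00/sharp/000047.png
    blur_key = os.path.join('blur', '')                       # 'blur/'
    path = os.path.join('GOPR0372_07_00', 'blur_gamma', '')   # 'GOPR0372_07_00/blur_gamma/'
    print(path.find(blur_key))  # -1: 'blur/' does not match 'blur_gamma/'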
deblur/src/data/demo.py ADDED
@@ -0,0 +1,22 @@
+ from data.dataset import Dataset
+
+ from utils import interact
+
+ class Demo(Dataset):
+     """Demo subset class
+     """
+     def __init__(self, args, mode='demo'):
+         super(Demo, self).__init__(args, mode)
+
+     def set_modes(self):
+         self.modes = ('demo',)  # one-element tuple
+
+     def set_keys(self):
+         super(Demo, self).set_keys()
+         self.blur_key = ''          # all the files
+         self.non_sharp_keys = ['']  # no files
+
+     def __getitem__(self, idx):
+         blur, sharp, pad_width, idx, relpath = super(Demo, self).__getitem__(idx)
+
+         return blur, sharp, pad_width, idx, relpath
deblur/src/data/gopro_large.py ADDED
@@ -0,0 +1,23 @@
+ from data.dataset import Dataset
+
+ from utils import interact
+
+ class GOPRO_Large(Dataset):
+     """GOPRO_Large train, test subset class
+     """
+     def __init__(self, args, mode='train'):
+         super(GOPRO_Large, self).__init__(args, mode)
+
+     def set_modes(self):
+         self.modes = ('train', 'test')
+
+     def set_keys(self):
+         super(GOPRO_Large, self).set_keys()
+         self.blur_key = 'blur_gamma'
+         # self.sharp_key = 'sharp'
+
+     def __getitem__(self, idx):
+         blur, sharp, pad_width, idx, relpath = super(GOPRO_Large, self).__getitem__(idx)
+         relpath = relpath.replace('{}/'.format(self.blur_key), '')
+
+         return blur, sharp, pad_width, idx, relpath
deblur/src/data/reds.py ADDED
@@ -0,0 +1,28 @@
+ from data.dataset import Dataset
+
+ from utils import interact
+
+ class REDS(Dataset):
+     """REDS train, val, test subset class
+     """
+     def __init__(self, args, mode='train'):
+         super(REDS, self).__init__(args, mode)
+
+     def set_modes(self):
+         self.modes = ('train', 'val', 'test')
+
+     def set_keys(self):
+         super(REDS, self).set_keys()
+         # self.blur_key = 'blur'
+         # self.sharp_key = 'sharp'
+
+         self.non_blur_keys = ['blur', 'blur_comp', 'blur_bicubic']
+         self.non_blur_keys.remove(self.blur_key)
+         self.non_sharp_keys = ['sharp_bicubic', 'sharp']
+         self.non_sharp_keys.remove(self.sharp_key)
+
+     def __getitem__(self, idx):
+         blur, sharp, pad_width, idx, relpath = super(REDS, self).__getitem__(idx)
+         relpath = relpath.replace('{}/{}/'.format(self.mode, self.blur_key), '')
+
+         return blur, sharp, pad_width, idx, relpath
deblur/src/data/sampler.py ADDED
@@ -0,0 +1,115 @@
+ import math
+ import torch
+ from torch.utils.data import Sampler
+ import torch.distributed as dist
+
+
+ class DistributedEvalSampler(Sampler):
+     r"""
+     DistributedEvalSampler is different from DistributedSampler.
+     It does NOT add extra samples to make the data evenly divisible.
+     DistributedEvalSampler should NOT be used for training; the distributed processes could hang forever.
+     See this issue for details: https://github.com/pytorch/pytorch/issues/22584
+     shuffle is disabled by default
+
+     DistributedEvalSampler is for evaluation purposes where synchronization does not happen every epoch.
+     Synchronization should be done outside the dataloader loop.
+
+     Sampler that restricts data loading to a subset of the dataset.
+
+     It is especially useful in conjunction with
+     :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
+     process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
+     :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
+     original dataset that is exclusive to it.
+
+     .. note::
+         Dataset is assumed to be of constant size.
+
+     Arguments:
+         dataset: Dataset used for sampling.
+         num_replicas (int, optional): Number of processes participating in
+             distributed training. By default, :attr:`num_replicas` is retrieved
+             from the current distributed group.
+         rank (int, optional): Rank of the current process within :attr:`num_replicas`.
+             By default, :attr:`rank` is retrieved from the current distributed
+             group.
+         shuffle (bool, optional): If ``True``, sampler will shuffle the
+             indices. Default: ``False``.
+         seed (int, optional): random seed used to shuffle the sampler if
+             :attr:`shuffle=True`. This number should be identical across all
+             processes in the distributed group. Default: ``0``.
+
+     .. warning::
+         In distributed mode, calling the :meth:`set_epoch(epoch) <set_epoch>` method at
+         the beginning of each epoch **before** creating the :class:`DataLoader` iterator
+         is necessary to make shuffling work properly across multiple epochs. Otherwise,
+         the same ordering will always be used.
+
+     Example::
+
+         >>> sampler = DistributedEvalSampler(dataset) if is_distributed else None
+         >>> loader = DataLoader(dataset, shuffle=(sampler is None),
+         ...                     sampler=sampler)
+         >>> for epoch in range(start_epoch, n_epochs):
+         ...     if is_distributed:
+         ...         sampler.set_epoch(epoch)
+         ...     evaluate(loader)
+     """
+
+     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, seed=0):
+         if num_replicas is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             num_replicas = dist.get_world_size()
+         if rank is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             rank = dist.get_rank()
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         # self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+         # self.total_size = self.num_samples * self.num_replicas
+         self.total_size = len(self.dataset)  # true value without extra samples
+         indices = list(range(self.total_size))
+         indices = indices[self.rank:self.total_size:self.num_replicas]
+         self.num_samples = len(indices)  # true value without extra samples
+
+         self.shuffle = shuffle
+         self.seed = seed
+
+     def __iter__(self):
+         if self.shuffle:
+             # deterministically shuffle based on epoch and seed
+             g = torch.Generator()
+             g.manual_seed(self.seed + self.epoch)
+             indices = torch.randperm(len(self.dataset), generator=g).tolist()
+         else:
+             indices = list(range(len(self.dataset)))
+
+         # # add extra samples to make it evenly divisible
+         # indices += indices[:(self.total_size - len(indices))]
+         # assert len(indices) == self.total_size
+
+         # subsample
+         indices = indices[self.rank:self.total_size:self.num_replicas]
+         assert len(indices) == self.num_samples
+
+         return iter(indices)
+
+     def __len__(self):
+         return self.num_samples
+
+     def set_epoch(self, epoch):
+         r"""
+         Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
+         use a different random ordering for each epoch. Otherwise, the next iteration of this
+         sampler will yield the same ordering.
+
+         Arguments:
+             epoch (int): Epoch number.
+         """
+         self.epoch = epoch
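
A worked example of the rank subsampling above: 10 samples across 4 processes, with no padded duplicates, so (unlike DistributedSampler) per-rank lengths may differ by one:

    total_size = 10
    for rank in range(4):
        print(rank, list(range(total_size))[rank:total_size:4])
    # 0 [0, 4, 8]
    # 1 [1, 5, 9]
    # 2 [2, 6]
    # 3 [3, 7]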
deblur/src/lambda_networks/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from lambda_networks.lambda_networks import LambdaLayer
+ from lambda_networks.rlambda_networks import RLambdaLayer
+ λLayer = LambdaLayer
deblur/src/lambda_networks/lambda_networks.py ADDED
@@ -0,0 +1,80 @@
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+ from einops import rearrange
+
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     return val if exists(val) else d
+
+ # lambda layer
+
+ class LambdaLayer(nn.Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         dim_k,
+         n=None,
+         r=None,
+         heads=4,
+         dim_out=None,
+         dim_u=1):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+         self.u = dim_u  # intra-depth dimension
+         self.heads = heads
+
+         assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
+         dim_v = dim_out // heads
+
+         self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
+         self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
+         self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)
+
+         self.norm_q = nn.BatchNorm2d(dim_k * heads)
+         self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
+
+         self.local_contexts = exists(r)
+         if exists(r):
+             assert (r % 2) == 1, 'Receptive kernel size should be odd'
+             self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
+         else:
+             assert exists(n), 'You must specify the total sequence length (h x w)'
+             self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
+
+     def forward(self, x):
+         b, c, hh, ww, u, h = *x.shape, self.u, self.heads
+
+         q = self.to_q(x)
+         k = self.to_k(x)
+         v = self.to_v(x)
+
+         q = self.norm_q(q)
+         v = self.norm_v(v)
+
+         q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)
+         k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
+         v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)
+
+         k = k.softmax(dim=-1)
+
+         λc = einsum('b u k m, b u v m -> b k v', k, v)
+         Yc = einsum('b h k n, b k v -> b h v n', q, λc)
+
+         if self.local_contexts:
+             v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
+             λp = self.pos_conv(v)
+             Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
+         else:
+             λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
+             Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
+
+         Y = Yc + Yp
+         out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
+         return out
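
A shape check for the local-context branch (sizes hypothetical): with r set, position lambdas come from the 3D convolution, so any spatial size works and the output matches the input resolution:

    import torch
    from lambda_networks import LambdaLayer

    layer = LambdaLayer(dim=32, dim_out=32, r=23, dim_k=16, heads=4, dim_u=1)
    x = torch.randn(1, 32, 64, 64)
    print(layer(x).shape)  # torch.Size([1, 32, 64, 64]); batch and spatial size preserved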
deblur/src/lambda_networks/rlambda_networks.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+ from einops import rearrange
+
+
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+
+ def default(val, d):
+     return val if exists(val) else d
+
+
+ # recurrent lambda layer
+
+ class RLambdaLayer(nn.Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         dim_k,
+         n=None,
+         r=None,
+         heads=4,
+         dim_out=None,
+         dim_u=1,
+         recurrence=None  # number of recurrent applications; must be set to an int before forward()
+     ):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+         self.u = dim_u  # intra-depth dimension
+         self.heads = heads
+
+         assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
+         dim_v = dim_out // heads
+
+         self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
+         self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
+         self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)
+
+         self.norm_q = nn.BatchNorm2d(dim_k * heads)
+         self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
+
+         self.local_contexts = exists(r)
+         self.recurrence = recurrence
+         if exists(r):
+             assert (r % 2) == 1, 'Receptive kernel size should be odd'
+             self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
+         else:
+             assert exists(n), 'You must specify the total sequence length (h x w)'
+             self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
+
+     def apply_lambda(self, lambda_c, lambda_p, x):
+         b, c, hh, ww, u, h = *x.shape, self.u, self.heads
+         q = self.to_q(x)
+         q = self.norm_q(q)
+         q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)
+         Yc = einsum('b h k n, b k v -> b h v n', q, lambda_c)
+         if self.local_contexts:
+             Yp = einsum('b h k n, b k v n -> b h v n', q, lambda_p.flatten(3))
+         else:
+             Yp = einsum('b h k n, b n k v -> b h v n', q, lambda_p)
+         Y = Yc + Yp
+         out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
+         return out
+
+     def forward(self, x):
+         b, c, hh, ww, u, h = *x.shape, self.u, self.heads
+
+         k = self.to_k(x)
+         v = self.to_v(x)
+
+         v = self.norm_v(v)
+
+         k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
+         v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)
+
+         k = k.softmax(dim=-1)
+
+         λc = einsum('b u k m, b u v m -> b k v', k, v)
+
+         if self.local_contexts:
+             v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
+             λp = self.pos_conv(v)
+         else:
+             λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
+         out = x
+         for i in range(self.recurrence):
+             out = self.apply_lambda(λc, λp, out)
+         return out
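
RLambdaLayer computes λc and λp once from the input, then applies them recurrence times with a fresh query each pass (so dim_out must equal dim for the output to feed back in). A hypothetical shape check:

    import torch
    from lambda_networks import RLambdaLayer

    rlayer = RLambdaLayer(dim=32, dim_out=32, r=7, dim_k=16, heads=4, dim_u=4, recurrence=3)
    x = torch.randn(1, 32, 48, 48)
    print(rlayer(x).shape)  # torch.Size([1, 32, 48, 48]) after 3 recurrent applications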
deblur/src/launch.py ADDED
@@ -0,0 +1,55 @@
+ """ distributed launcher adopted from torch.distributed.launch
+ usage example: https://github.com/facebookresearch/maskrcnn-benchmark
+ This enables using multiprocessing for each spawned process (as they are treated as main processes)
+ """
+ import sys
+ import subprocess
+ from argparse import ArgumentParser, REMAINDER
+
+ from utils import str2bool, int2str
+
+ def parse_args():
+     parser = ArgumentParser(description="PyTorch distributed training launch "
+                                         "helper utility that will spawn up "
+                                         "multiple distributed processes")
+
+     parser.add_argument('--n_GPUs', type=int, default=1, help='the number of GPUs for training')
+
+     # positional
+     parser.add_argument("training_script", type=str,
+                         help="The full path to the single GPU training "
+                              "program/script to be launched in parallel, "
+                              "followed by all the arguments for the "
+                              "training script")
+
+     # rest from the training program
+     parser.add_argument('training_script_args', nargs=REMAINDER)
+     return parser.parse_args()
+
+ def main():
+     args = parse_args()
+
+     processes = []
+     for rank in range(0, args.n_GPUs):
+         cmd = [sys.executable]
+
+         cmd.append(args.training_script)
+         cmd.extend(args.training_script_args)
+
+         cmd += ['--distributed', 'True']
+         cmd += ['--launched', 'True']
+         cmd += ['--n_GPUs', str(args.n_GPUs)]
+         cmd += ['--rank', str(rank)]
+
+         process = subprocess.Popen(cmd)
+         processes.append(process)
+
+     for process in processes:
+         process.wait()
+         if process.returncode != 0:
+             raise subprocess.CalledProcessError(returncode=process.returncode,
+                                                 cmd=cmd)
+
+ if __name__ == "__main__":
+     main()
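
A hypothetical invocation (the training-script flags are illustrative): the launcher spawns one full Python process per rank and appends its bookkeeping flags to the training script's own arguments:

    python launch.py --n_GPUs 2 main.py --batch_size 16
    # effectively runs, in parallel:
    #   python main.py --batch_size 16 --distributed True --launched True --n_GPUs 2 --rank 0
    #   python main.py --batch_size 16 --distributed True --launched True --n_GPUs 2 --rank 1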
deblur/src/loss/__init__.py ADDED
@@ -0,0 +1,500 @@
+ import os
+ from importlib import import_module
+
+ import torch
+ from torch import nn
+ import torch.distributed as dist
+
+ import matplotlib.pyplot as plt
+ plt.switch_backend('agg')  # https://github.com/matplotlib/matplotlib/issues/3466
+
+ from .metric import PSNR, SSIM
+
+ from utils import interact
+
+ def sequence_loss(sr, hr, loss_func, gamma=0.8, max_val=None):
+     """Loss function defined over a sequence of predictions"""
+
+     n_recurrence = len(sr)
+     total_loss = 0.0
+     buffer = [0.0] * n_recurrence
+     # exclude invalid pixels and extremely large displacements
+     for i in range(n_recurrence):
+         i_weight = gamma**(n_recurrence - i - 1)
+         i_loss = loss_func(sr[i], hr)
+         buffer[i] = i_loss.item()
+         # total_loss += i_weight * (valid[:, None] * i_loss).mean()
+         total_loss += i_weight * i_loss
+     return total_loss, buffer
+
+
+ class Loss(torch.nn.modules.loss._Loss):
+     def __init__(self, args, epoch=None, model=None, optimizer=None):
+         """
+         input:
+             args.loss       use '+' to sum over different loss functions
+                             use '*' to specify the loss weight
+
+                             example:
+                                 1*MSE+0.5*VGG54
+                                 loss = sum of MSE and VGG54 (weight=0.5)
+
+             args.measure    similar to args.loss, but without weight
+
+                             example:
+                                 MSE+PSNR
+                                 measure MSE and PSNR, independently
+         """
+         super(Loss, self).__init__()
+
+         self.args = args
+
+         self.rgb_range = args.rgb_range
+         self.gamma = args.gamma
+         self.device_type = args.device_type
+         self.synchronized = False
+
+         self.epoch = args.start_epoch if epoch is None else epoch
+         self.save_dir = args.save_dir
+         self.save_name = os.path.join(self.save_dir, 'loss.pt')
+
+         # self.training = True
+         self.validating = False
+         self.testing = False
+         self.mode = 'train'
+         self.modes = ('train', 'val', 'test')
+
+         # Loss
+         self.loss = nn.ModuleDict()
+         self.loss_types = []
+         self.weight = {}
+         self.buffer = [0.0] * args.n_scales
+
+         self.loss_stat = {mode: {} for mode in self.modes}
+         # loss_stat[mode][loss_type][epoch] = loss_value
+         # loss_stat[mode]['Total'][epoch] = loss_total
+
+         for weighted_loss in args.loss.split('+'):
+             w, l = weighted_loss.split('*')
+             l = l.upper()
+             if l in ('ABS', 'L1'):
+                 loss_type = 'L1'
+                 func = nn.L1Loss()
+             elif l in ('MSE', 'L2'):
+                 loss_type = 'L2'
+                 func = nn.MSELoss()
+             elif l in ('ADV', 'GAN'):
+                 loss_type = 'ADV'
+                 m = import_module('loss.adversarial')
+                 func = getattr(m, 'Adversarial')(args, model, optimizer)
+             else:
+                 loss_type = l
+                 m = import_module('loss.{}'.format(l.lower()))
+                 func = getattr(m, l)(args)
+
+             self.loss_types += [loss_type]
+             self.loss[loss_type] = func
+             self.weight[loss_type] = float(w)
+
+         print('Loss function: {}'.format(args.loss))
+
+         # Metrics
+         self.do_measure = args.metric.lower() != 'none'
+
+         self.metric = nn.ModuleDict()
+         self.metric_types = []
+         self.metric_stat = {mode: {} for mode in self.modes}
+         # metric_stat[mode][metric_type][epoch] = metric_value
+
+         if self.do_measure:
+             for metric_type in args.metric.split(','):
+                 metric_type = metric_type.upper()
+                 if metric_type == 'PSNR':
+                     metric_func = PSNR()
+                 elif metric_type == 'SSIM':
+                     metric_func = SSIM(args.device_type)  # single precision
+                 else:
+                     raise NotImplementedError
+
+                 self.metric_types += [metric_type]
+                 self.metric[metric_type] = metric_func
+
+         print('Metrics: {}'.format(args.metric))
+
+         if args.start_epoch != 1:
+             self.load(args.start_epoch - 1)
+
+         for mode in self.modes:
+             for loss_type in self.loss:
+                 if loss_type not in self.loss_stat[mode]:
+                     self.loss_stat[mode][loss_type] = {}  # initialize loss
+
+             if 'Total' not in self.loss_stat[mode]:
+                 self.loss_stat[mode]['Total'] = {}
+
+             if self.do_measure:
+                 for metric_type in self.metric:
+                     if metric_type not in self.metric_stat[mode]:
+                         self.metric_stat[mode][metric_type] = {}
+
+         self.count = 0
+         self.count_m = 0
+
+         self.to(args.device, dtype=args.dtype)
+
+     def train(self, mode=True):
+         super(Loss, self).train(mode)
+         if mode:
+             self.validating = False
+             self.testing = False
+             self.mode = 'train'
+         else:  # default test mode
+             self.validating = False
+             self.testing = True
+             self.mode = 'test'
+
+     def validate(self):
+         super(Loss, self).eval()
+         # self.training = False
+         self.validating = True
+         self.testing = False
+         self.mode = 'val'
+
+     def test(self):
+         super(Loss, self).eval()
+         # self.training = False
+         self.validating = False
+         self.testing = True
+         self.mode = 'test'
+
+     def forward(self, input, target):
+         self.synchronized = False
+
+         loss = 0
+         weights = [0.32, 0.08, 0.02, 0.01, 0.005]
+         if len(input) > len(weights):
+             for i in range(len(input) - len(weights)):
+                 weights.append(weights[-1] * 0.5)
+         if len(input) == 1:
+             weights = [1.0]
+         weights = weights[::-1]
+
+         def _ms_forward(input, target, func):
+             if isinstance(input, (list, tuple)):  # loss for list output
+                 _loss, buffer_lst = sequence_loss(input, target[0], func, gamma=self.gamma)
+                 # first = func(input[0], target[0])
+                 # _loss = first * weights[0]
+                 # self.buffer = [first.item()]
+                 # for i in range(1, len(input)):
+                 #     tmp = func(input[i], target[0])
+                 #     self.buffer.append(tmp.item())
+                 #     _loss += tmp * weights[i]
+                 self.buffer = buffer_lst
+                 return _loss
+             elif isinstance(input, dict):  # loss for dict output
+                 _loss = []
+                 for key in input:
+                     _loss += [func(input[key], target[key])]
+                 return sum(_loss)
+             else:  # loss for tensor output
+                 return func(input, target)
+
+         # initialize
+         if self.count == 0:
+             for loss_type in self.loss_types:
+                 self.loss_stat[self.mode][loss_type][self.epoch] = 0
+             self.loss_stat[self.mode]['Total'][self.epoch] = 0
+
+         if isinstance(input, list):
+             count = input[0].shape[0]
+         else:  # Tensor
+             count = input.shape[0]  # batch size
+
+         isnan = False
+         for loss_type in self.loss_types:
+
+             if loss_type == 'ADV':
+                 _loss = self.loss[loss_type](input[0], target[0], self.training) * self.weight[loss_type]
+             else:
+                 _loss = _ms_forward(input, target, self.loss[loss_type]) * self.weight[loss_type]
+
+             if torch.isnan(_loss):
+                 isnan = True  # skip recording (will also be skipped at backprop)
+             else:
+                 self.loss_stat[self.mode][loss_type][self.epoch] += _loss.item() * count
+                 self.loss_stat[self.mode]['Total'][self.epoch] += _loss.item() * count
+
+             loss += _loss
+
+         if not isnan:
+             self.count += count
+
+         if not self.training and self.do_measure:
+             self.measure(input, target)
+
+         return loss
+
+     def measure(self, input, target):
+         if isinstance(input, (list, tuple)):
+             self.measure(input[0], target[0])
+             return
+         elif isinstance(input, dict):
+             first_key = list(input.keys())[0]
+             self.measure(input[first_key], target[first_key])
+             return
+         else:
+             pass
+
+         if self.count_m == 0:
+             for metric_type in self.metric_stat[self.mode]:
+                 self.metric_stat[self.mode][metric_type][self.epoch] = 0
+
+         if isinstance(input, list):
+             count = input[0].shape[0]
+         else:  # Tensor
+             count = input.shape[0]  # batch size
+
+         # normalize to the 0-255 range once, before the metric loop;
+         # scaling inside the loop would rescale target for every metric
+         input = input.clamp(0, self.rgb_range)  # not in_place
+         if self.rgb_range == 1:
+             input = input * 255
+             target = target * 255
+         input = input.round()
+         target = target.round()
+
+         for metric_type in self.metric_stat[self.mode]:
+             _metric = self.metric[metric_type](input, target)
+             self.metric_stat[self.mode][metric_type][self.epoch] += _metric.item() * count
+
+         self.count_m += count
+
+         return
+
+     def normalize(self):
+         if self.args.distributed:
+             dist.barrier()
+             if not self.synchronized:
+                 self.all_reduce()
+
+         if self.count > 0:
+             for loss_type in self.loss_stat[self.mode]:  # including 'Total'
+                 self.loss_stat[self.mode][loss_type][self.epoch] /= self.count
+             self.count = 0
+
+         if self.count_m > 0:
+             for metric_type in self.metric_stat[self.mode]:
+                 self.metric_stat[self.mode][metric_type][self.epoch] /= self.count_m
+             self.count_m = 0
+
+         return
+
+     def all_reduce(self, epoch=None):
+         # synchronize loss for distributed GPU processes
+
+         if epoch is None:
+             epoch = self.epoch
+
+         def _reduce_value(value, ReduceOp=dist.ReduceOp.SUM):
+             value_tensor = torch.Tensor([value]).to(self.args.device, self.args.dtype, non_blocking=True)
+             dist.all_reduce(value_tensor, ReduceOp, async_op=False)
+             value = value_tensor.item()
+             del value_tensor
+
+             return value
+
+         dist.barrier()
+         if self.count > 0:  # I assume this should be true
+             self.count = _reduce_value(self.count, dist.ReduceOp.SUM)
+
+             for loss_type in self.loss_stat[self.mode]:
+                 self.loss_stat[self.mode][loss_type][epoch] = _reduce_value(
+                     self.loss_stat[self.mode][loss_type][epoch],
+                     dist.ReduceOp.SUM
+                 )
+
+         if self.count_m > 0:
+             self.count_m = _reduce_value(self.count_m, dist.ReduceOp.SUM)
+
+             for metric_type in self.metric_stat[self.mode]:
+                 self.metric_stat[self.mode][metric_type][epoch] = _reduce_value(
+                     self.metric_stat[self.mode][metric_type][epoch],
+                     dist.ReduceOp.SUM
+                 )
+
+         self.synchronized = True
+
+         return
+
+     def print_metrics(self):
+         print(self.get_metric_desc())
+         return
+
+     def get_last_loss(self):
+         return self.loss_stat[self.mode]['Total'][self.epoch]
+
+     def get_loss_desc(self):
+         if self.mode == 'train':
+             desc_prefix = 'Train'
+         elif self.mode == 'val':
+             desc_prefix = 'Validation'
+         else:
+             desc_prefix = 'Test'
+
+         loss = self.loss_stat[self.mode]['Total'][self.epoch]
+         if self.count > 0:
+             loss /= self.count
+         desc = '{} Loss: {:.1f}'.format(desc_prefix, loss)
+
+         if self.mode in ('val', 'test'):
+             metric_desc = self.get_metric_desc()
+             desc = '{}{}'.format(desc, metric_desc)
+
+         return desc
+
+     def get_metric_desc(self):
+         desc = ''
+         for metric_type in self.metric_stat[self.mode]:
+             measured = self.metric_stat[self.mode][metric_type][self.epoch]
+             if self.count_m > 0:
+                 measured /= self.count_m
+
+             if metric_type == 'PSNR':
+                 desc += ' {}: {:2.2f}'.format(metric_type, measured)
+             elif metric_type == 'SSIM':
+                 desc += ' {}: {:1.4f}'.format(metric_type, measured)
+             else:
+                 desc += ' {}: {:2.4f}'.format(metric_type, measured)
+
+         return desc
+
+     def step(self, plot_name=None):
+         self.normalize()
+         self.plot(plot_name)
+         if not self.training and self.do_measure:
+             # self.print_metrics()
+             self.plot_metric()
+         # self.epoch += 1
+
+         return
+
+     def save(self):
+         state = {
+             'loss_stat': self.loss_stat,
+             'metric_stat': self.metric_stat,
+         }
+         torch.save(state, self.save_name)
+
+         return
+
+     def load(self, epoch=None):
+         print('Loading loss record from {}'.format(self.save_name))
+         if os.path.exists(self.save_name):
+             state = torch.load(self.save_name, map_location=self.args.device)
+
+             self.loss_stat = state['loss_stat']
+             if 'metric_stat' in state:
+                 self.metric_stat = state['metric_stat']
+             else:
+                 pass
+         else:
+             print('no loss record found for {}!'.format(self.save_name))
+
+         if epoch is not None:
+             self.epoch = epoch
+
+         return
+
+     def plot(self, plot_name=None, metric=False):
+         self.plot_loss(plot_name)
+
+         if metric:
+             self.plot_metric(plot_name)
+         # else:
+         #     self.plot_loss(plot_name)
+
+         return
+
+     def plot_loss(self, plot_name=None):
+         if plot_name is None:
+             plot_name = os.path.join(self.save_dir, "{}_loss.pdf".format(self.mode))
+
+         title = "{} loss".format(self.mode)
+
+         fig = plt.figure()
+         plt.title(title)
+         plt.xlabel('epochs')
+         plt.ylabel('loss')
+         plt.grid(True, linestyle=':')
+
+         for loss_type, loss_record in self.loss_stat[self.mode].items():  # including Total
+             axis = sorted([epoch for epoch in loss_record.keys() if epoch <= self.epoch])
+             value = [self.loss_stat[self.mode][loss_type][epoch] for epoch in axis]
+             label = loss_type
+
+             plt.plot(axis, value, label=label)
+
+         plt.xlim(0, self.epoch)
+         plt.legend()
+         plt.savefig(plot_name)
+         plt.close(fig)
+
+         return
+
+     def plot_metric(self, plot_name=None):
+         # assume there are only max 2 metrics
+         if plot_name is None:
+             plot_name = os.path.join(self.save_dir, "{}_metric.pdf".format(self.mode))
+
+         title = "{} metrics".format(self.mode)
+
+         fig, ax1 = plt.subplots()
+         plt.title(title)
+         plt.grid(True, linestyle=':')
+         ax1.set_xlabel('epochs')
+
+         plots = None
+         for metric_type, metric_record in self.metric_stat[self.mode].items():
+             axis = sorted([epoch for epoch in metric_record.keys() if epoch <= self.epoch])
+             value = [metric_record[epoch] for epoch in axis]
+             label = metric_type
+
+             if metric_type == 'PSNR':
+                 ax = ax1
+                 color = 'C0'
+             elif metric_type == 'SSIM':
+                 ax2 = ax1.twinx()
+                 ax = ax2
+                 color = 'C1'
+
+             ax.set_ylabel(metric_type)
+             if plots is None:
+                 plots = ax.plot(axis, value, label=label, color=color)
+             else:
+                 plots += ax.plot(axis, value, label=label, color=color)
+
+         labels = [plot.get_label() for plot in plots]
+         plt.legend(plots, labels)
+         plt.xlim(0, self.epoch)
+         plt.savefig(plot_name)
+         plt.close(fig)
+
+         return
+
+     def sort(self):
+         # sort the loss/metric record
+         for mode in self.modes:
+             for loss_type, loss_epochs in self.loss_stat[mode].items():
+                 self.loss_stat[mode][loss_type] = {epoch: loss_epochs[epoch] for epoch in sorted(loss_epochs)}
+
+             for metric_type, metric_epochs in self.metric_stat[mode].items():
+                 self.metric_stat[mode][metric_type] = {epoch: metric_epochs[epoch] for epoch in sorted(metric_epochs)}
+
+         return self
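
The weighted-sum spec accepted by args.loss parses as weight*NAME terms joined by '+'; for example (hypothetical spec string), using the same split logic as above:

    spec = '1*L1+0.5*ADV'
    for weighted_loss in spec.split('+'):
        w, l = weighted_loss.split('*')
        print(float(w), l.upper())
    # 1.0 L1
    # 0.5 ADV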
deblur/src/loss/adversarial.py ADDED
@@ -0,0 +1,52 @@
+ import torch
+ import torch.nn as nn
+
+ from utils import interact
+
+ import torch.cuda.amp as amp
+
+ class Adversarial(nn.modules.loss._Loss):
+     # pure loss function without saving & loading option
+     # but trains the discriminator
+     def __init__(self, args, model, optimizer):
+         super(Adversarial, self).__init__()
+         self.args = args
+         self.model = model.model
+         self.optimizer = optimizer
+         self.scaler = amp.GradScaler(
+             init_scale=self.args.init_scale,
+             enabled=self.args.amp
+         )
+
+         self.gan_k = 1
+
+         self.BCELoss = nn.BCEWithLogitsLoss()
+
+     def forward(self, fake, real, training=False):
+         if training:
+             # update discriminator
+             fake_detach = fake.detach()
+             for _ in range(self.gan_k):
+                 self.optimizer.D.zero_grad()
+                 # d: B x 1 tensor
+                 with amp.autocast(self.args.amp):
+                     d_fake = self.model.D(fake_detach)
+                     d_real = self.model.D(real)
+
+                     label_fake = torch.zeros_like(d_fake)
+                     label_real = torch.ones_like(d_real)
+
+                     loss_d = self.BCELoss(d_fake, label_fake) + self.BCELoss(d_real, label_real)
+
+                 self.scaler.scale(loss_d).backward(retain_graph=False)
+                 self.scaler.step(self.optimizer.D)
+                 self.scaler.update()
+         else:
+             d_real = self.model.D(real)
+             label_real = torch.ones_like(d_real)
+
+         # update generator (outside here)
+         d_fake_bp = self.model.D(fake)
+         loss_g = self.BCELoss(d_fake_bp, label_real)
+
+         return loss_g
deblur/src/loss/metric.py ADDED
@@ -0,0 +1,112 @@
+ # from skimage.metrics import peak_signal_noise_ratio, structural_similarity
+
+ import torch
+ from torch import nn
+
+ def _expand(img):
+     if img.ndim < 4:
+         img = img.expand([1] * (4-img.ndim) + list(img.shape))
+
+     return img
+
+ class PSNR(nn.Module):
+     def __init__(self):
+         super(PSNR, self).__init__()
+
+     def forward(self, im1, im2, data_range=None):
+         # tensor input, constant output
+
+         if data_range is None:
+             data_range = 255 if im1.max() > 1 else 1
+
+         se = (im1-im2)**2
+         se = _expand(se)
+
+         mse = se.mean(dim=list(range(1, se.ndim)))
+         psnr = 10 * (data_range**2/mse).log10().mean()
+
+         return psnr
+
+ class SSIM(nn.Module):
+     def __init__(self, device_type='cpu', dtype=torch.float32):
+         super(SSIM, self).__init__()
+
+         self.device_type = device_type
+         self.dtype = dtype  # SSIM in half precision could be inaccurate
+
+         def _get_ssim_weight():
+             truncate = 3.5
+             sigma = 1.5
+             r = int(truncate * sigma + 0.5)  # radius as in ndimage
+             win_size = 2 * r + 1
+             nch = 3
+
+             weight = torch.Tensor([-(x - win_size//2)**2/float(2*sigma**2) for x in range(win_size)]).exp().unsqueeze(1)
+             weight = weight.mm(weight.t())
+             weight /= weight.sum()
+             weight = weight.repeat(nch, 1, 1, 1)
+
+             return weight
+
+         self.weight = _get_ssim_weight().to(self.device_type, dtype=self.dtype, non_blocking=True)
+
+     def forward(self, im1, im2, data_range=None):
+         """Implementation adopted from skimage.metrics.structural_similarity
+         Default arguments set to multichannel=True, gaussian_weight=True, use_sample_covariance=False
+         """
+
+         im1 = im1.to(self.device_type, dtype=self.dtype, non_blocking=True)
+         im2 = im2.to(self.device_type, dtype=self.dtype, non_blocking=True)
+
+         K1 = 0.01
+         K2 = 0.03
+         sigma = 1.5
+
+         truncate = 3.5
+         r = int(truncate * sigma + 0.5)  # radius as in ndimage
+         win_size = 2 * r + 1
+
+         im1 = _expand(im1)
+         im2 = _expand(im2)
+
+         nch = im1.shape[1]
+
+         if im1.shape[2] < win_size or im1.shape[3] < win_size:
+             raise ValueError(
+                 "win_size exceeds image extent. If the input is a multichannel "
+                 "(color) image, set multichannel=True.")
+
+         if data_range is None:
+             data_range = 255 if im1.max() > 1 else 1
+
+         def filter_func(img):  # no padding
+             return nn.functional.conv2d(img, self.weight, groups=nch).to(self.dtype)
+             # return torch.conv2d(img, self.weight, groups=nch).to(self.dtype)
+
+         # compute (weighted) means
+         ux = filter_func(im1)
+         uy = filter_func(im2)
+
+         # compute (weighted) variances and covariances
+         uxx = filter_func(im1 * im1)
+         uyy = filter_func(im2 * im2)
+         uxy = filter_func(im1 * im2)
+         vx = (uxx - ux * ux)
+         vy = (uyy - uy * uy)
+         vxy = (uxy - ux * uy)
+
+         R = data_range
+         C1 = (K1 * R) ** 2
+         C2 = (K2 * R) ** 2
+
+         A1, A2, B1, B2 = ((2 * ux * uy + C1,
+                            2 * vxy + C2,
+                            ux ** 2 + uy ** 2 + C1,
+                            vx + vy + C2))
+         D = B1 * B2
+         S = (A1 * A2) / D
+
+         # compute (weighted) mean of ssim
+         mssim = S.mean()
+
+         return mssim
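
A quick sanity check of the PSNR above: a uniform error of one grey level on a 0-255 scale gives MSE = 1, so PSNR = 10 * log10(255^2), roughly 48.13 dB (tensors hypothetical):

    import torch
    from loss.metric import PSNR

    psnr = PSNR()
    im1 = torch.zeros(1, 3, 8, 8)
    im2 = torch.ones(1, 3, 8, 8)           # uniform error of 1 grey level
    print(psnr(im1, im2, data_range=255))  # ~48.13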
deblur/src/main.py ADDED
@@ -0,0 +1,67 @@
+ """main file that does everything"""
+ from utils import interact
+
+ from option import args, setup, cleanup
+ from data import Data
+ from model import Model
+ from loss import Loss
+ from optim import Optimizer
+ from train import Trainer
+
+ def main_worker(rank, args):
+     args.rank = rank
+     args = setup(args)
+
+     loaders = Data(args).get_loader()
+     model = Model(args)
+     model.parallelize()
+     optimizer = Optimizer(args, model)
+
+     criterion = Loss(args, model=model, optimizer=optimizer)
+
+     trainer = Trainer(args, model, criterion, optimizer, loaders)
+
+     if args.stay:
+         interact(local=locals())
+         exit()
+
+     if args.demo:
+         trainer.evaluate(epoch=args.start_epoch, mode='demo')
+         exit()
+
+     for epoch in range(1, args.start_epoch):
+         if args.do_validate:
+             if epoch % args.validate_every == 0:
+                 trainer.fill_evaluation(epoch, 'val')
+         if args.do_test:
+             if epoch % args.test_every == 0:
+                 trainer.fill_evaluation(epoch, 'test')
+
+     for epoch in range(args.start_epoch, args.end_epoch+1):
+         if args.do_train:
+             trainer.train(epoch)
+
+         if args.do_validate:
+             if epoch % args.validate_every == 0:
+                 if trainer.epoch != epoch:
+                     trainer.load(epoch)
+                 trainer.validate(epoch)
+
+         if args.do_test:
+             if epoch % args.test_every == 0:
+                 if trainer.epoch != epoch:
+                     trainer.load(epoch)
+                 trainer.test(epoch)
+
+         if args.rank == 0 or not args.launched:
+             print('')
+
+     trainer.imsaver.join_background()
+
+     cleanup(args)
+
+ def main():
+     main_worker(args.rank, args)
+
+ if __name__ == "__main__":
+     main()
deblur/src/model/LamResNet.py ADDED
@@ -0,0 +1,72 @@
+ import torch.nn as nn
+
+ from . import common
+
+ from lambda_networks import LambdaLayer
+
+
+ def build_model(args):
+     return ResNet(args)
+
+
+ class ResNet(nn.Module):
+     def __init__(
+         self,
+         args,
+         in_channels=3,
+         out_channels=3,
+         n_feats=None,
+         kernel_size=None,
+         n_resblocks=None,
+         mean_shift=True,
+     ):
+         super(ResNet, self).__init__()
+
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+
+         self.n_feats = args.n_feats if n_feats is None else n_feats
+         self.kernel_size = args.kernel_size if kernel_size is None else kernel_size
+         self.n_resblocks = args.n_resblocks if n_resblocks is None else n_resblocks
+
+         self.mean_shift = mean_shift
+         self.rgb_range = args.rgb_range
+         self.mean = self.rgb_range / 2
+
+         modules = []
+         modules.append(
+             common.default_conv(self.in_channels, self.n_feats, self.kernel_size)
+         )
+         for _ in range(self.n_resblocks // 3):
+             modules.append(common.ResBlock(self.n_feats, self.kernel_size))
+         modules.append(
+             LambdaLayer(
+                 dim=self.n_feats, dim_out=self.n_feats, r=23, dim_k=16, heads=4, dim_u=1
+             )
+         )
+         for _ in range(self.n_resblocks // 3):
+             modules.append(common.ResBlock(self.n_feats, self.kernel_size))
+         modules.append(
+             LambdaLayer(
+                 dim=self.n_feats, dim_out=self.n_feats, r=7, dim_k=16, heads=4, dim_u=4
+             )
+         )
+         for _ in range(self.n_resblocks // 3):
+             modules.append(common.ResBlock(self.n_feats, self.kernel_size))
+         modules.append(
+             common.default_conv(self.n_feats, self.n_feats, self.kernel_size)
+         )
+         modules.append(common.default_conv(self.n_feats, self.out_channels, 1))
+
+         self.body = nn.Sequential(*modules)
+
+     def forward(self, input):
+         if self.mean_shift:
+             input = input - self.mean
+
+         output = self.body(input)
+
+         if self.mean_shift:
+             output = output + self.mean
+
+         return output
deblur/src/model/MSResNet.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+ import torch.nn as nn
+
+ from . import common
+ from .ResNet import ResNet
+
+
+ def build_model(args):
+     return MSResNet(args)
+
+ class conv_end(nn.Module):
+     def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
+         super(conv_end, self).__init__()
+
+         modules = [
+             common.default_conv(in_channels, out_channels, kernel_size),
+             nn.PixelShuffle(ratio)
+         ]
+
+         self.uppath = nn.Sequential(*modules)
+
+     def forward(self, x):
+         return self.uppath(x)
+
+ class MSResNet(nn.Module):
+     def __init__(self, args):
+         super(MSResNet, self).__init__()
+
+         self.rgb_range = args.rgb_range
+         self.mean = self.rgb_range / 2
+
+         self.n_resblocks = args.n_resblocks
+         self.n_feats = args.n_feats
+         self.kernel_size = args.kernel_size
+
+         self.n_scales = args.n_scales
+
+         self.body_models = nn.ModuleList([
+             ResNet(args, 3, 3, mean_shift=False),
+         ])
+         for _ in range(1, self.n_scales):
+             self.body_models.insert(0, ResNet(args, 6, 3, mean_shift=False))
+
+         self.conv_end_models = nn.ModuleList([None])
+         for _ in range(1, self.n_scales):
+             self.conv_end_models += [conv_end(3, 12)]
+
+     def forward(self, input_pyramid):
+
+         scales = range(self.n_scales-1, -1, -1)  # 0: fine, 2: coarse
+
+         for s in scales:
+             input_pyramid[s] = input_pyramid[s] - self.mean
+
+         output_pyramid = [None] * self.n_scales
+
+         input_s = input_pyramid[-1]
+         for s in scales:  # [2, 1, 0]
+             output_pyramid[s] = self.body_models[s](input_s)
+             if s > 0:
+                 up_feat = self.conv_end_models[s](output_pyramid[s])
+                 input_s = torch.cat((input_pyramid[s-1], up_feat), 1)
+
+         for s in scales:
+             output_pyramid[s] = output_pyramid[s] + self.mean
+
+         return output_pyramid
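
For orientation, the coarse-to-fine data flow for a hypothetical 3-scale pyramid at full resolution H x W:

    # input_pyramid[2] (B, 3, H/4, W/4) -> body_models[2] -> output_pyramid[2]
    # conv_end_models[2](output_pyramid[2]): conv 3 -> 12 channels, then PixelShuffle(2)
    #   -> (B, 3, H/2, W/2) upsampled features
    # cat with input_pyramid[1] -> (B, 6, H/2, W/2) -> body_models[1] -> ...
    # body_models[0] likewise takes 6 channels and emits the full-resolution output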
deblur/src/model/MSResNetLambda.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from . import common
5
+ from .ResNet import ResNet
6
+ from lambda_network import LambdaLayer
7
+
8
+ def build_model(args):
9
+ return MSResNet(args)
10
+
11
+ class conv_end(nn.Module):
12
+ def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
13
+ super(conv_end, self).__init__()
14
+
15
+ modules = [
16
+ common.default_conv(in_channels, out_channels, kernel_size),
17
+ nn.PixelShuffle(ratio)
18
+ ]
19
+
20
+ self.uppath = nn.Sequential(*modules)
21
+
22
+ def forward(self, x):
23
+ return self.uppath(x)
24
+
25
+ class MSResNet(nn.Module):
26
+ def __init__(self, args):
27
+ super(MSResNet, self).__init__()
28
+
29
+ self.rgb_range = args.rgb_range
30
+ self.mean = self.rgb_range / 2
31
+
32
+ self.n_resblocks = args.n_resblocks
33
+ self.n_feats = args.n_feats
34
+ self.kernel_size = args.kernel_size
35
+
36
+ self.n_scales = args.n_scales
37
+
38
+ self.body_models = nn.ModuleList([
39
+ ResNet(args, 3, 3, mean_shift=False),
40
+ ])
41
+ self.lambda_models = nn.ModuleList([
42
+ LambdaLayer(
43
+ dim = 32, # channels going in
44
+ dim_out = 32, # channels out
45
+ n = 64 * 64, # number of input pixels (64 x 64 image)
46
+ dim_k = 16, # key dimension
47
+ heads = 4, # number of heads, for multi-query
48
+ dim_u = 1 # 'intra-depth' dimension
49
+ )
50
+ ])
51
+ for _ in range(1, self.n_scales):
52
+ self.body_models.insert(0, ResNet(args, 6, 3, mean_shift=False))
53
+
54
+ self.conv_end_models = nn.ModuleList([None])
55
+ for _ in range(1, self.n_scales):
56
+ self.conv_end_models += [conv_end(3, 12)]
57
+
58
+ def forward(self, input_pyramid):
59
+
60
+ scales = range(self.n_scales-1, -1, -1) # 0: fine, 2: coarse
61
+
62
+ for s in scales:
63
+ input_pyramid[s] = input_pyramid[s] - self.mean
64
+
65
+ output_pyramid = [None] * self.n_scales
66
+
67
+ input_s = input_pyramid[-1]
68
+ for s in scales: # [2, 1, 0]
69
+ output_pyramid[s] = self.body_models[s](input_s)
70
+ if s > 0:
71
+ up_feat = self.conv_end_models[s](output_pyramid[s])
72
+ input_s = torch.cat((input_pyramid[s-1], up_feat), 1)
73
+
74
+ for s in scales:
75
+ output_pyramid[s] = output_pyramid[s] + self.mean
76
+
77
+ return output_pyramid
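The lambda_models branch pins n = 64 * 64, so the positional lambda only accepts 64x64 feature maps. A quick shape check, assuming the lambda_networks package version where n counts input pixels (as the comment above does):

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(dim=32, dim_out=32, n=64 * 64, dim_k=16, heads=4, dim_u=1)
x = torch.randn(1, 32, 64, 64)  # H * W must equal n
assert layer(x).shape == (1, 32, 64, 64)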
deblur/src/model/RaftNet.py ADDED
@@ -0,0 +1,120 @@
1
+ import torch.nn as nn
2
+
3
+ import torch
4
+ from . import common
5
+
6
+ from lambda_networks import LambdaLayer
7
+
8
+
9
+ def build_model(args):
10
+ return ResNet(args)
11
+
12
+
13
+ class ConvGRU(nn.Module):
14
+ def __init__(self, hidden_dim=128, input_dim=192+128):
15
+ super(ConvGRU, self).__init__()
16
+ self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
17
+ self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
18
+ self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
19
+
20
+ def forward(self, h, x):
21
+ hx = torch.cat([h, x], dim=1)
22
+
23
+ z = torch.sigmoid(self.convz(hx))
24
+ r = torch.sigmoid(self.convr(hx))
25
+ q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
26
+
27
+ # h = (1-z) * h + z * q
28
+ # return h
29
+ return (1-z) * h + z * q
30
+
31
+ class SepConvGRU(nn.Module):
32
+ def __init__(self, hidden_dim=128, input_dim=192+128):
33
+ super(SepConvGRU, self).__init__()
34
+ self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
35
+ self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
36
+ self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37
+
38
+ self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
39
+ self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
40
+ self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41
+
42
+
43
+ def forward(self, h, x):
44
+ # horizontal
45
+ hx = torch.cat([h, x], dim=1)
46
+ z = torch.sigmoid(self.convz1(hx))
47
+ r = torch.sigmoid(self.convr1(hx))
48
+ q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
49
+ h = (1-z) * h + z * q
50
+
51
+ # vertical
52
+ hx = torch.cat([h, x], dim=1)
53
+ z = torch.sigmoid(self.convz2(hx))
54
+ r = torch.sigmoid(self.convr2(hx))
55
+ q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
56
+ h = (1-z) * h + z * q
57
+
58
+ return h
59
+
60
+
61
+ class ResNet(nn.Module):
62
+ def __init__(
63
+ self,
64
+ args,
65
+ in_channels=3,
66
+ out_channels=3,
67
+ n_feats=None,
68
+ kernel_size=None,
69
+ n_resblocks=None,
70
+ mean_shift=True,
71
+ ):
72
+ super(ResNet, self).__init__()
73
+
74
+ self.in_channels = in_channels
75
+ self.out_channels = out_channels
76
+
77
+ self.n_feats = args.n_feats if n_feats is None else n_feats
78
+ self.kernel_size = args.kernel_size if kernel_size is None else kernel_size
79
+ self.n_resblocks = args.n_resblocks if n_resblocks is None else n_resblocks
80
+
81
+ self.mean_shift = mean_shift
82
+ self.rgb_range = args.rgb_range
83
+ self.mean = self.rgb_range / 2
84
+
85
+ modules = []
86
+ m_head = [common.default_conv(self.in_channels, self.n_feats, self.kernel_size)]
87
+ for _ in range(3):
88
+ m_head.append(common.ResBlock(self.n_feats, self.kernel_size))
89
+ for _ in range(self.n_resblocks // 2):
90
+ modules.append(common.ResBlock(self.n_feats, self.kernel_size))
91
+ modules.append(
92
+ LambdaLayer(
93
+ dim=self.n_feats, dim_out=self.n_feats, r=23, dim_k=16, heads=4, dim_u=4
94
+ )
95
+ )
96
+ for _ in range(self.n_resblocks // 2):
97
+ modules.append(common.ResBlock(self.n_feats, self.kernel_size))
98
+ m_tail = []
99
+
100
+ for _ in range(3):
101
+ m_tail.append(common.ResBlock(self.n_feats, self.kernel_size))
102
+
103
+ m_tail.append(
104
+ common.default_conv(self.n_feats, self.out_channels, self.kernel_size)
105
+ )
106
+ self.head = nn.Sequential(*m_head)
107
+ self.body = nn.Sequential(*modules)
108
+ self.tail = nn.Sequential(*m_tail)
109
+ self.gru = SepConvGRU(hidden_dim=self.n_feats, input_dim=self.n_feats)  # built but not yet wired into forward()
110
+
111
+ def forward(self, input):
112
+ if self.mean_shift:
113
+ input = input - self.mean
114
+
115
+ x = self.head(input)  # head/tail were built but unused; body alone cannot take the 3-channel input
+ output = self.tail(self.body(x))
116
+
117
+ if self.mean_shift:
118
+ output = output + self.mean
119
+
120
+ return output
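SepConvGRU factors the 5x5 GRU gates into 1x5 and 5x1 convolutions over a spatial hidden state, as in RAFT's update operator. A small usage sketch with hypothetical sizes:

import torch

gru = SepConvGRU(hidden_dim=64, input_dim=64)
h = torch.zeros(2, 64, 32, 32)   # hidden state
x = torch.randn(2, 64, 32, 32)   # per-step features
for _ in range(3):               # unrolled refinement steps
    h = gru(h, x)
print(h.shape)                   # torch.Size([2, 64, 32, 32])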
deblur/src/model/RecLamResNet.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from . import common
5
+ from .LamResNet import ResNet
6
+
7
+
8
+ def build_model(args):
9
+ return RecLamResNet(args)
10
+
11
+
12
+ class conv_end(nn.Module):
13
+ def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
14
+ super(conv_end, self).__init__()
15
+
16
+ modules = [
17
+ common.default_conv(in_channels, out_channels, kernel_size),
18
+ nn.PixelShuffle(ratio),
19
+ ]
20
+
21
+ self.uppath = nn.Sequential(*modules)
22
+
23
+ def forward(self, x):
24
+ return self.uppath(x)
25
+
26
+
27
+ class RecLamResNet(nn.Module):
28
+ def __init__(self, args):
29
+ super(RecLamResNet, self).__init__()
30
+
31
+ self.rgb_range = args.rgb_range
32
+ self.mean = self.rgb_range / 2
33
+ self.is_detach = args.detach
34
+
35
+ self.n_resblocks = args.n_resblocks
36
+ self.n_feats = args.n_feats
37
+ self.kernel_size = args.kernel_size
38
+
39
+ self.n_scales = args.n_scales
40
+
41
+ self.body_model = ResNet(args, 3, 3, mean_shift=False)
42
+
43
+ def forward(self, input_lst):
44
+ # lists are kept reversed for compactness: input_lst[0] is the full-resolution input, output_lst[0] the final refinement
45
+ input_lst[0] = input_lst[0] - self.mean
46
+ output_lst = [None] * self.n_scales
47
+ last_output = input_lst[0]
48
+ for i in range(self.n_scales):
49
+ if self.is_detach:
50
+ last_output=last_output.detach()
51
+ output = self.body_model(last_output) + last_output
52
+ output_lst[self.n_scales-i-1] = output + self.mean
53
+ last_output = output
54
+ return output_lst
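RecLamResNet runs one shared ResNet body n_scales times at full resolution, so output_lst[0] receives the final (most refined) result. A sketch of the calling convention, assuming LamResNet's ResNet only needs the args fields shown here:

import torch
from argparse import Namespace

args = Namespace(rgb_range=255, detach=False, n_resblocks=19,
                 n_feats=64, kernel_size=5, n_scales=3)
net = RecLamResNet(args)
x = torch.rand(1, 3, 64, 64) * 255
out_lst = net([x])      # forward reads input_lst[0] only
print(len(out_lst))     # 3: one entry per recurrence step, most refined first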
deblur/src/model/ResNet.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch.nn as nn
2
+
3
+ from . import common
4
+
5
+ def build_model(args):
6
+ return ResNet(args)
7
+
8
+ class ResNet(nn.Module):
9
+ def __init__(self, args, in_channels=3, out_channels=3, n_feats=None, kernel_size=None, n_resblocks=None, mean_shift=True):
10
+ super(ResNet, self).__init__()
11
+
12
+ self.in_channels = in_channels
13
+ self.out_channels = out_channels
14
+
15
+ self.n_feats = args.n_feats if n_feats is None else n_feats
16
+ self.kernel_size = args.kernel_size if kernel_size is None else kernel_size
17
+ self.n_resblocks = args.n_resblocks if n_resblocks is None else n_resblocks
18
+
19
+ self.mean_shift = mean_shift
20
+ self.rgb_range = args.rgb_range
21
+ self.mean = self.rgb_range / 2
22
+
23
+ modules = []
24
+ modules.append(common.default_conv(self.in_channels, self.n_feats, self.kernel_size))
25
+ for _ in range(self.n_resblocks):
26
+ modules.append(common.ResBlock(self.n_feats, self.kernel_size))
27
+ modules.append(common.default_conv(self.n_feats, self.out_channels, self.kernel_size))
28
+
29
+ self.body = nn.Sequential(*modules)
30
+
31
+ def forward(self, input):
32
+ if self.mean_shift:
33
+ input = input - self.mean
34
+
35
+ output = self.body(input)
36
+
37
+ if self.mean_shift:
38
+ output = output + self.mean
39
+
40
+ return output
41
+
deblur/src/model/__init__.py ADDED
@@ -0,0 +1,136 @@
1
+ import os
2
+ import re
3
+ from importlib import import_module
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
8
+
9
+ import torch.distributed as dist
10
+ from torch.nn.utils import parameters_to_vector, vector_to_parameters
11
+
12
+ from .discriminator import Discriminator
13
+
14
+ from utils import interact
15
+
16
+ class Model(nn.Module):
17
+ def __init__(self, args):
18
+ super(Model, self).__init__()
19
+
20
+ self.args = args
21
+ self.device = args.device
22
+ self.n_GPUs = args.n_GPUs
23
+ self.save_dir = os.path.join(args.save_dir, 'models')
24
+ os.makedirs(self.save_dir, exist_ok=True)
25
+
26
+ module = import_module('model.' + args.model)
27
+
28
+ self.model = nn.ModuleDict()
29
+ self.model.G = module.build_model(args)
30
+ if self.args.loss.lower().find('adv') >= 0:
31
+ self.model.D = Discriminator(self.args)
32
+ else:
33
+ self.model.D = None
34
+
35
+ self.to(args.device, dtype=args.dtype, non_blocking=True)
36
+ self.load(args.load_epoch, path=args.pretrained)
37
+
38
+ def parallelize(self):
39
+ if self.args.device_type == 'cuda':
40
+ if self.args.distributed:
41
+ Parallel = DistributedDataParallel
42
+ parallel_args = {
43
+ "device_ids": [self.args.rank],
44
+ "output_device": self.args.rank,
45
+ }
46
+ else:
47
+ Parallel = DataParallel
48
+ parallel_args = {
49
+ 'device_ids': list(range(self.n_GPUs)),
50
+ 'output_device': self.args.rank # always 0
51
+ }
52
+
53
+ for model_key in self.model:
54
+ if self.model[model_key] is not None:
55
+ self.model[model_key] = Parallel(self.model[model_key], **parallel_args)
56
+
57
+ def forward(self, input):
58
+ return self.model.G(input)
59
+
60
+ def _save_path(self, epoch):
61
+ model_path = os.path.join(self.save_dir, 'model-{:d}.pt'.format(epoch))
62
+ return model_path
63
+
64
+ def state_dict(self):
65
+ state_dict = {}
66
+ for model_key in self.model:
67
+ if self.model[model_key] is not None:
68
+ parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel))
69
+ if parallelized:
70
+ state_dict[model_key] = self.model[model_key].module.state_dict()
71
+ else:
72
+ state_dict[model_key] = self.model[model_key].state_dict()
73
+
74
+ return state_dict
75
+
76
+ def load_state_dict(self, state_dict, strict=True):
77
+ for model_key in self.model:
78
+ if self.model[model_key] is None: # e.g., no discriminator when the loss has no adversarial term
+ continue
+ parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel))
79
+ if model_key in state_dict:
80
+ if parallelized:
81
+ self.model[model_key].module.load_state_dict(state_dict[model_key], strict)
82
+ else:
83
+ self.model[model_key].load_state_dict(state_dict[model_key], strict)
84
+
85
+ def save(self, epoch):
86
+ torch.save(self.state_dict(), self._save_path(epoch))
87
+
88
+ def load(self, epoch=None, path=None):
89
+ if path:
90
+ model_name = path
91
+ elif isinstance(epoch, int):
92
+ if epoch < 0:
93
+ epoch = self.get_last_epoch()
94
+ if epoch == 0: # epoch 0
95
+ # make sure model parameters are synchronized at initial
96
+ # for multi-node training (not in current implementation)
97
+ # self.synchronize()
98
+
99
+ return # leave model as initialized
100
+
101
+ model_name = self._save_path(epoch)
102
+ else:
103
+ raise Exception('no epoch number or model path specified!')
104
+
105
+ print('Loading model from {}'.format(model_name))
106
+ state_dict = torch.load(model_name, map_location=self.args.device)
107
+ self.load_state_dict(state_dict)
108
+
109
+ return
110
+
111
+ def synchronize(self):
112
+ if self.args.distributed:
113
+ # synchronize model parameters across nodes
114
+ vector = parameters_to_vector(self.parameters())
115
+
116
+ dist.broadcast(vector, 0) # broadcast parameters to other processes
117
+ if self.args.rank != 0:
118
+ vector_to_parameters(vector, self.parameters())
119
+
120
+ del vector
121
+
122
+ return
123
+
124
+ def get_last_epoch(self):
125
+ model_list = os.listdir(self.save_dir)
126
+ if len(model_list) == 0:
127
+ epoch = 0
128
+ else:
129
+ # take the numerically largest epoch (e.g., model-100.pt); lexicographic sort would rank model-99.pt above model-100.pt
+ epoch = max(int(re.findall(r'\d+', name)[0]) for name in model_list)
130
+
131
+ return epoch
132
+
133
+ def print(self):
134
+ print(self.model)
135
+
136
+ return
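Checkpoints written by Model.save are plain dicts keyed by sub-network name: 'G', plus 'D' when an adversarial loss term is configured. A sketch of inspecting one outside the training code (the path is hypothetical):

import torch

state = torch.load('../experiment/RecLamRes_MSE/models/model-100.pt', map_location='cpu')
print(list(state.keys()))                            # ['G'] or ['G', 'D']
print(sum(v.numel() for v in state['G'].values()))   # generator parameter count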
deblur/src/model/common.py ADDED
@@ -0,0 +1,161 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ def default_conv(in_channels, out_channels, kernel_size, bias=True, groups=1):
7
+ return nn.Conv2d(
8
+ in_channels, out_channels, kernel_size,
9
+ padding=(kernel_size // 2), bias=bias, groups=groups)
10
+
11
+ def default_norm(n_feats):
12
+ return nn.BatchNorm2d(n_feats)
13
+
14
+ def default_act():
15
+ return nn.ReLU(True)
16
+
17
+ def empty_h(x, n_feats):
18
+ '''
19
+ create an empty hidden state
20
+
21
+ input
22
+ x: B x T x 3 x H x W
23
+
24
+ output
25
+ h: B x C x H/4 x W/4
26
+ '''
27
+ b = x.size(0)
28
+ h, w = x.size()[-2:]
29
+ return x.new_zeros((b, n_feats, h//4, w//4))
30
+
31
+ class Normalization(nn.Conv2d):
32
+ """Normalize input tensor value with convolutional layer"""
33
+ def __init__(self, mean=(0, 0, 0), std=(1, 1, 1)):
34
+ super(Normalization, self).__init__(3, 3, kernel_size=1)
35
+ tensor_mean = torch.Tensor(mean)
36
+ tensor_inv_std = torch.Tensor(std).reciprocal()
37
+
38
+ self.weight.data = torch.eye(3).mul(tensor_inv_std).view(3, 3, 1, 1)
39
+ self.bias.data = torch.Tensor(-tensor_mean.mul(tensor_inv_std))
40
+
41
+ for params in self.parameters():
42
+ params.requires_grad = False
43
+
44
+ class BasicBlock(nn.Sequential):
45
+ """Convolution layer + Activation layer"""
46
+ def __init__(
47
+ self, in_channels, out_channels, kernel_size, bias=True,
48
+ conv=default_conv, norm=False, act=default_act):
49
+
50
+ modules = []
51
+ modules.append(
52
+ conv(in_channels, out_channels, kernel_size, bias=bias))
53
+ if norm: modules.append(norm(out_channels))
54
+ if act: modules.append(act())
55
+
56
+ super(BasicBlock, self).__init__(*modules)
57
+
58
+ class ResBlock(nn.Module):
59
+ def __init__(
60
+ self, n_feats, kernel_size, bias=True,
61
+ conv=default_conv, norm=False, act=default_act):
62
+
63
+ super(ResBlock, self).__init__()
64
+
65
+ modules = []
66
+ for i in range(2):
67
+ modules.append(conv(n_feats, n_feats, kernel_size, bias=bias))
68
+ if norm: modules.append(norm(n_feats))
69
+ if act and i == 0: modules.append(act())
70
+
71
+ self.body = nn.Sequential(*modules)
72
+
73
+ def forward(self, x):
74
+ res = self.body(x)
75
+ res += x
76
+
77
+ return res
78
+
79
+ class ResBlock_mobile(nn.Module):
80
+ def __init__(
81
+ self, n_feats, kernel_size, bias=True,
82
+ conv=default_conv, norm=False, act=default_act, dropout=False):
83
+
84
+ super(ResBlock_mobile, self).__init__()
85
+
86
+ modules = []
87
+ for i in range(2):
88
+ modules.append(conv(n_feats, n_feats, kernel_size, bias=False, groups=n_feats))
89
+ modules.append(conv(n_feats, n_feats, 1, bias=False))
90
+ if dropout and i == 0: modules.append(nn.Dropout2d(dropout))
91
+ if norm: modules.append(norm(n_feats))
92
+ if act and i == 0: modules.append(act())
93
+
94
+ self.body = nn.Sequential(*modules)
95
+
96
+ def forward(self, x):
97
+ res = self.body(x)
98
+ res += x
99
+
100
+ return res
101
+
102
+ class Upsampler(nn.Sequential):
103
+ def __init__(
104
+ self, scale, n_feats, bias=True,
105
+ conv=default_conv, norm=False, act=False):
106
+
107
+ modules = []
108
+ if (scale & (scale - 1)) == 0: # Is scale = 2^n?
109
+ for _ in range(int(math.log(scale, 2))):
110
+ modules.append(conv(n_feats, 4 * n_feats, 3, bias))
111
+ modules.append(nn.PixelShuffle(2))
112
+ if norm: modules.append(norm(n_feats))
113
+ if act: modules.append(act())
114
+ elif scale == 3:
115
+ modules.append(conv(n_feats, 9 * n_feats, 3, bias))
116
+ modules.append(nn.PixelShuffle(3))
117
+ if norm: modules.append(norm(n_feats))
118
+ if act: modules.append(act())
119
+ else:
120
+ raise NotImplementedError
121
+
122
+ super(Upsampler, self).__init__(*modules)
123
+
124
+ # Only scale 1/2 is supported
125
+ class PixelSort(nn.Module):
126
+ """The inverse operation of PixelShuffle
127
+ Reduces the spatial resolution, increasing the number of channels.
128
+ Currently, scale 0.5 is supported only.
129
+ PyTorch >= 1.8 provides the equivalent torch.nn.functional.pixel_unshuffle.
130
+ Reference:
131
+ http://pytorch.org/docs/0.3.0/_modules/torch/nn/modules/pixelshuffle.html#PixelShuffle
132
+ http://pytorch.org/docs/0.3.0/_modules/torch/nn/functional.html#pixel_shuffle
133
+ """
134
+ def __init__(self, upscale_factor=0.5):
135
+ super(PixelSort, self).__init__()
136
+ self.upscale_factor = upscale_factor
137
+
138
+ def forward(self, x):
139
+ b, c, h, w = x.size()
140
+ x = x.view(b, c, h // 2, 2, w // 2, 2)
141
+ x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
142
+ x = x.view(b, 4 * c, h // 2, w // 2)
143
+
144
+ return x
145
+
146
+ class Downsampler(nn.Sequential):
147
+ def __init__(
148
+ self, scale, n_feats, bias=True,
149
+ conv=default_conv, norm=False, act=False):
150
+
151
+ modules = []
152
+ if scale == 0.5:
153
+ modules.append(PixelSort())
154
+ modules.append(conv(4 * n_feats, n_feats, 3, bias))
155
+ if norm: modules.append(norm(n_feats))
156
+ if act: modules.append(act())
157
+ else:
158
+ raise NotImplementedError
159
+
160
+ super(Downsampler, self).__init__(*modules)
161
+
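With the view/permute corrected above, PixelSort reproduces PyTorch's pixel_unshuffle (available since 1.8) and exactly inverts PixelShuffle; a quick self-check sketch:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
y = PixelSort()(x)                            # 2 x 12 x 4 x 4
assert torch.equal(y, F.pixel_unshuffle(x, 2))
assert torch.equal(F.pixel_shuffle(y, 2), x)  # round-trip is lossless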
deblur/src/model/discriminator.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch.nn as nn
2
+
3
+ class Discriminator(nn.Module):
4
+ def __init__(self, args):
5
+ super(Discriminator, self).__init__()
6
+
7
+ # self.args = args
8
+ n_feats = args.n_feats
9
+ kernel_size = args.kernel_size
10
+
11
+ def conv(kernel_size, in_channel, n_feats, stride, pad=None):
12
+ if pad is None:
13
+ pad = (kernel_size-1)//2
14
+
15
+ return nn.Conv2d(in_channel, n_feats, kernel_size, stride=stride, padding=pad, bias=False)
16
+
17
+ self.conv_layers = nn.ModuleList([
18
+ conv(kernel_size, 3, n_feats//2, 1), # 256
19
+ conv(kernel_size, n_feats//2, n_feats//2, 2), # 128
20
+ conv(kernel_size, n_feats//2, n_feats, 1),
21
+ conv(kernel_size, n_feats, n_feats, 2), # 64
22
+ conv(kernel_size, n_feats, n_feats*2, 1),
23
+ conv(kernel_size, n_feats*2, n_feats*2, 4), # 16
24
+ conv(kernel_size, n_feats*2, n_feats*4, 1),
25
+ conv(kernel_size, n_feats*4, n_feats*4, 4), # 4
26
+ conv(kernel_size, n_feats*4, n_feats*8, 1),
27
+ conv(4, n_feats*8, n_feats*8, 4, 0), # 1
28
+ ])
29
+
30
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
31
+ self.dense = nn.Conv2d(n_feats*8, 1, 1, bias=False)
32
+
33
+ def forward(self, x):
34
+
35
+ for layer in self.conv_layers:
36
+ x = self.act(layer(x))
37
+
38
+ x = self.dense(x)
39
+
40
+ return x
41
+
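The stride pattern 1,2,1,2,1,4,1,4,1,4 collapses a 256x256 patch to a single real/fake logit. A shape sanity check with the repository's default n_feats=64 and kernel_size=5:

import torch
from argparse import Namespace

d = Discriminator(Namespace(n_feats=64, kernel_size=5))
x = torch.randn(4, 3, 256, 256)
print(d(x).shape)   # torch.Size([4, 1, 1, 1])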
deblur/src/model/structure.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch.nn as nn
2
+
3
+ from .common import ResBlock, default_conv
4
+
5
+ def encoder(in_channels, n_feats):
6
+ """RGB / IR feature encoder
7
+ """
8
+
9
+ # in_channels == 1 or 3 or 4 or ....
10
+ # After 1st conv, B x n_feats x H x W
11
+ # After 2nd conv, B x 2n_feats x H/2 x W/2
12
+ # After 3rd conv, B x 3n_feats x H/4 x W/4
13
+ return nn.Sequential(
14
+ nn.Conv2d(in_channels, 1 * n_feats, 5, stride=1, padding=2),
15
+ nn.Conv2d(1 * n_feats, 2 * n_feats, 5, stride=2, padding=2),
16
+ nn.Conv2d(2 * n_feats, 3 * n_feats, 5, stride=2, padding=2),
17
+ )
18
+
19
+ def decoder(out_channels, n_feats):
20
+ """RGB / IR / Depth decoder
21
+ """
22
+ # After 1st deconv, B x 2n_feats x H/2 x W/2
23
+ # After 2nd deconv, B x n_feats x H x W
24
+ # After 3rd conv, B x out_channels x H x W
25
+ deconv_kargs = {'stride': 2, 'padding': 1, 'output_padding': 1}
26
+
27
+ return nn.Sequential(
28
+ nn.ConvTranspose2d(3 * n_feats, 2 * n_feats, 3, **deconv_kargs),
29
+ nn.ConvTranspose2d(2 * n_feats, 1 * n_feats, 3, **deconv_kargs),
30
+ nn.Conv2d(n_feats, out_channels, 5, stride=1, padding=2),
31
+ )
32
+
33
+ # def ResNet(n_feats, in_channels=None, out_channels=None):
34
+ def ResNet(n_feats, kernel_size, n_blocks, in_channels=None, out_channels=None):
35
+ """sequential ResNet
36
+ """
37
+
38
+ # if in_channels is None:
39
+ # in_channels = n_feats
40
+ # if out_channels is None:
41
+ # out_channels = n_feats
42
+ # # currently not implemented
43
+
44
+ m = []
45
+
46
+ if in_channels is not None:
47
+ m += [default_conv(in_channels, n_feats, kernel_size)]
48
+
49
+ m += [ResBlock(n_feats, 3)] * n_blocks
50
+
51
+ if out_channels is not None:
52
+ m += [default_conv(n_feats, out_channels, kernel_size)]
53
+
54
+
55
+ return nn.Sequential(*m)
56
+
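The encoder/decoder pair is symmetric around a 4x downsampled bottleneck; a shape sketch with an illustrative n_feats:

import torch

enc = encoder(in_channels=3, n_feats=16)
dec = decoder(out_channels=3, n_feats=16)
x = torch.randn(1, 3, 64, 64)
feat = enc(x)               # 1 x 48 x 16 x 16 (3*n_feats, H/4, W/4)
print(dec(feat).shape)      # torch.Size([1, 3, 64, 64])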
deblur/src/optim/__init__.py ADDED
@@ -0,0 +1,206 @@
1
+ import torch
2
+ import torch.optim as optim
3
+ import torch.optim.lr_scheduler as lrs
4
+
5
+ import os
6
+ from collections import Counter
7
+
8
+ from model import Model
9
+ from utils import interact, Map
10
+
11
+ class Optimizer(object):
12
+ def __init__(self, args, model):
13
+ self.args = args
14
+
15
+ self.save_dir = os.path.join(self.args.save_dir, 'optim')
16
+ os.makedirs(self.save_dir, exist_ok=True)
17
+
18
+ if isinstance(model, Model):
19
+ model = model.model
20
+
21
+ # set base arguments
22
+ kwargs_optimizer = {
23
+ 'lr': args.lr,
24
+ 'weight_decay': args.weight_decay
25
+ }
26
+
27
+ if args.optimizer == 'SGD':
28
+ optimizer_class = optim.SGD
29
+ kwargs_optimizer['momentum'] = args.momentum
30
+ elif args.optimizer == 'ADAM':
31
+ optimizer_class = optim.Adam
32
+ kwargs_optimizer['betas'] = args.betas
33
+ kwargs_optimizer['eps'] = args.epsilon
34
+ elif args.optimizer == 'RMSPROP':
35
+ optimizer_class = optim.RMSprop
36
+ kwargs_optimizer['eps'] = args.epsilon
37
+
38
+ # scheduler
39
+ if args.scheduler == 'step':
40
+ scheduler_class = lrs.MultiStepLR
41
+ kwargs_scheduler = {
42
+ 'milestones': args.milestones,
43
+ 'gamma': args.gamma,
44
+ }
45
+ elif args.scheduler == 'plateau':
46
+ scheduler_class = lrs.ReduceLROnPlateau
47
+ kwargs_scheduler = {
48
+ 'mode': 'min',
49
+ 'factor': args.gamma,
50
+ 'patience': 10,
51
+ 'verbose': True,
52
+ 'threshold': 0,
53
+ 'threshold_mode': 'abs',
54
+ 'cooldown': 10,
55
+ }
56
+
57
+ self.kwargs_optimizer = kwargs_optimizer
58
+ self.scheduler_class = scheduler_class
59
+ self.kwargs_scheduler = kwargs_scheduler
60
+
61
+ def _get_optimizer(model):
62
+
63
+ class _Optimizer(optimizer_class):
64
+ def __init__(self, model, args, scheduler_class, kwargs_scheduler):
65
+ trainable = filter(lambda x: x.requires_grad, model.parameters())
66
+ super(_Optimizer, self).__init__(trainable, **kwargs_optimizer)
67
+
68
+ self.args = args
69
+
70
+ self._register_scheduler(scheduler_class, kwargs_scheduler)
71
+
72
+ def _register_scheduler(self, scheduler_class, kwargs_scheduler):
73
+ self.scheduler = scheduler_class(self, **kwargs_scheduler)
74
+
75
+ def schedule(self, metrics=None):
76
+ if isinstance(self.scheduler, lrs.ReduceLROnPlateau):  # inspect the scheduler, not the optimizer
77
+ self.scheduler.step(metrics)
78
+ else:
79
+ self.scheduler.step()
80
+
81
+ def get_last_epoch(self):
82
+ return self.scheduler.last_epoch
83
+
84
+ def get_lr(self):
85
+ return self.param_groups[0]['lr']
86
+
87
+ def get_last_lr(self):
88
+ return self.scheduler.get_last_lr()[0]
89
+
90
+ def state_dict(self):
91
+ state_dict = super(_Optimizer, self).state_dict() # {'state': ..., 'param_groups': ...}
92
+ state_dict['scheduler'] = self.scheduler.state_dict()
93
+
94
+ return state_dict
95
+
96
+ def load_state_dict(self, state_dict, epoch=None):
97
+ # optimizer
98
+ super(_Optimizer, self).load_state_dict(state_dict) # load 'state' and 'param_groups' only
99
+ # scheduler
100
+ self.scheduler.load_state_dict(state_dict['scheduler']) # should work for plateau or simple resuming
101
+
102
+ reschedule = False
103
+ if isinstance(self.scheduler, lrs.MultiStepLR):
104
+ if self.args.milestones != list(self.scheduler.milestones) or self.args.gamma != self.scheduler.gamma:
105
+ reschedule = True
106
+
107
+ if reschedule:
108
+ if epoch is None:
109
+ if self.scheduler.last_epoch > 1:
110
+ epoch = self.scheduler.last_epoch
111
+ else:
112
+ epoch = self.args.start_epoch - 1
113
+
114
+ # if False:
115
+ # # option 1. new scheduler
116
+ # for i, group in enumerate(self.param_groups):
117
+ # self.param_groups[i]['lr'] = group['initial_lr'] # reset optimizer learning rate to initial
118
+ # # self.scheduler = None
119
+ # self._register_scheduler(scheduler_class, kwargs_scheduler)
120
+
121
+ # self.zero_grad()
122
+ # self.step()
123
+ # for _ in range(epoch):
124
+ # self.scheduler.step()
125
+ # self._step_count -= 1
126
+
127
+ # else:
128
+ # option 2. modify existing scheduler
129
+ self.scheduler.milestones = Counter(self.args.milestones)
130
+ self.scheduler.gamma = self.args.gamma
131
+ for i, group in enumerate(self.param_groups):
132
+ self.param_groups[i]['lr'] = group['initial_lr'] # reset optimizer learning rate to initial
133
+ multiplier = 1
134
+ for milestone in self.scheduler.milestones:
135
+ if epoch >= milestone:
136
+ multiplier *= self.scheduler.gamma
137
+
138
+ self.param_groups[i]['lr'] *= multiplier
139
+
140
+ return _Optimizer(model, args, scheduler_class, kwargs_scheduler)
141
+
142
+ self.G = _get_optimizer(model.G)
143
+ if model.D is not None:
144
+ self.D = _get_optimizer(model.D)
145
+ else:
146
+ self.D = None
147
+
148
+ self.load(args.load_epoch)
149
+
150
+ def zero_grad(self):
151
+ self.G.zero_grad()
152
+
153
+ def step(self):
154
+ self.G.step()
155
+
156
+ def schedule(self, metrics=None):
157
+ self.G.schedule(metrics)
158
+ if self.D is not None:
159
+ self.D.schedule(metrics)
160
+
161
+ def get_last_epoch(self):
162
+ return self.G.get_last_epoch()
163
+
164
+ def get_lr(self):
165
+ return self.G.get_lr()
166
+
167
+ def get_last_lr(self):
168
+ return self.G.get_last_lr()
169
+
170
+ def state_dict(self):
171
+ state_dict = Map()
172
+ state_dict.G = self.G.state_dict()
173
+ if self.D is not None:
174
+ state_dict.D = self.D.state_dict()
175
+
176
+ return state_dict.toDict()
177
+
178
+ def load_state_dict(self, state_dict, epoch=None):
179
+ state_dict = Map(**state_dict)
180
+ self.G.load_state_dict(state_dict.G, epoch)
181
+ if self.D is not None:
182
+ self.D.load_state_dict(state_dict.D, epoch)
183
+
184
+ def _save_path(self, epoch=None):
185
+ epoch = epoch if epoch is not None else self.get_last_epoch()
186
+ save_path = os.path.join(self.save_dir, 'optim-{:d}.pt'.format(epoch))
187
+
188
+ return save_path
189
+
190
+ def save(self, epoch=None):
191
+ if epoch is None:
192
+ epoch = self.G.scheduler.last_epoch
193
+ torch.save(self.state_dict(), self._save_path(epoch))
194
+
195
+ def load(self, epoch):
196
+ if epoch > 0:
197
+ print('Loading optimizer from {}'.format(self._save_path(epoch)))
198
+ self.load_state_dict(torch.load(self._save_path(epoch), map_location=self.args.device), epoch=epoch)
199
+
200
+ elif epoch == 0:
201
+ pass
202
+ else:
203
+ raise NotImplementedError
204
+
205
+ return
206
+
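How the wrapper is meant to be driven each epoch, schematically (a sketch, not the actual Trainer code; model, criterion, loader and epoch come from the surrounding modules):

optimizer = Optimizer(args, model)   # builds the SGD/ADAM/RMSprop instance plus its scheduler
for input, target in loader:
    optimizer.zero_grad()
    loss = criterion(model(input), target)
    loss.backward()
    optimizer.step()
optimizer.schedule(metrics=None)     # advance the LR schedule; pass a metric for 'plateau'
optimizer.save(epoch)                # -> <save_dir>/optim/optim-<epoch>.pt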
deblur/src/optim/warm_multi_step_lr.py ADDED
@@ -0,0 +1,32 @@
1
+ import math
2
+ from bisect import bisect_right
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+
5
+ # MultiStep learning rate scheduler with warm restart
6
+ class WarmMultiStepLR(_LRScheduler):
7
+ def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, scale=1):
8
+ if not list(milestones) == sorted(milestones):
9
+ raise ValueError(
10
+ 'Milestones should be a list of increasing integers. '
11
+ 'Got {}'.format(milestones)
12
+ )
13
+
14
+ self.milestones = milestones
15
+ self.gamma = gamma
16
+ self.scale = scale
17
+
18
+ self.warmup_epochs = 5
19
+ self.gradual = (self.scale - 1) / self.warmup_epochs
20
+ super(WarmMultiStepLR, self).__init__(optimizer, last_epoch)
21
+
22
+ def get_lr(self):
23
+ if self.last_epoch < self.warmup_epochs:
24
+ return [
25
+ base_lr * (1 + self.last_epoch * self.gradual) / self.scale
26
+ for base_lr in self.base_lrs
27
+ ]
28
+ else:
29
+ return [
30
+ base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
31
+ for base_lr in self.base_lrs
32
+ ]
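With scale=1 the warmup factor is constant and this reduces to plain MultiStepLR; with scale>1 the LR ramps from base_lr/scale toward base_lr over the first warmup_epochs. A worked sketch:

import torch
from torch.optim import SGD

params = [torch.nn.Parameter(torch.zeros(1))]
opt = SGD(params, lr=1.0)
sched = WarmMultiStepLR(opt, milestones=[8], gamma=0.1, scale=4)
for epoch in range(10):
    print(epoch, round(opt.param_groups[0]['lr'], 4))
    opt.step()
    sched.step()
# epochs 0-4 print 0.25, 0.4, 0.55, 0.7, 0.85; epochs 5-7 print 1.0; epochs 8-9 print 0.1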
deblur/src/option.py ADDED
@@ -0,0 +1,278 @@
1
+ """optional argument parsing"""
2
+ # pylint: disable=C0103, C0301
3
+ import argparse
4
+ import datetime
5
+ import os
6
+ import re
7
+ import shutil
8
+ import time
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ import torch.backends.cudnn as cudnn
13
+
14
+ from utils import interact
15
+ from utils import str2bool, int2str
16
+
17
+ import template
18
+
19
+ # Training settings
20
+ parser = argparse.ArgumentParser(description='Dynamic Scene Deblurring')
21
+
22
+ # Device specifications
23
+ group_device = parser.add_argument_group('Device specs')
24
+ group_device.add_argument('--seed', type=int, default=-1, help='random seed')
25
+ group_device.add_argument('--num_workers', type=int, default=7, help='the number of dataloader workers')
26
+ group_device.add_argument('--device_type', type=str, choices=('cpu', 'cuda'), default='cuda', help='device to run models')
27
+ group_device.add_argument('--device_index', type=int, default=0, help='device id to run models')
28
+ group_device.add_argument('--n_GPUs', type=int, default=1, help='the number of GPUs for training')
29
+ group_device.add_argument('--distributed', type=str2bool, default=False, help='use DistributedDataParallel instead of DataParallel for better speed')
30
+ group_device.add_argument('--launched', type=str2bool, default=False, help='identify if main.py was executed from launch.py. Do not set this to be true using main.py.')
31
+
32
+ group_device.add_argument('--master_addr', type=str, default='127.0.0.1', help='master address for distributed')
33
+ group_device.add_argument('--master_port', type=int2str, default='8023', help='master port for distributed')
34
+ group_device.add_argument('--dist_backend', type=str, default='nccl', help='distributed backend')
35
+ group_device.add_argument('--init_method', type=str, default='env://', help='distributed init method URL to discover peers')
36
+ group_device.add_argument('--rank', type=int, default=0, help='rank of the distributed process (gpu id). 0 is the master process.')
37
+ group_device.add_argument('--world_size', type=int, default=1, help='world_size for distributed training (number of GPUs)')
38
+
39
+ # Data
40
+ group_data = parser.add_argument_group('Data specs')
41
+ group_data.add_argument('--data_root', type=str, default='/data/ssd/public/czli/deblur', help='dataset root location')
42
+ group_data.add_argument('--dataset', type=str, default=None, help='training/validation/test dataset name, has priority if not None')
43
+ group_data.add_argument('--data_train', type=str, default='GOPRO_Large', help='training dataset name')
44
+ group_data.add_argument('--data_val', type=str, default=None, help='validation dataset name')
45
+ group_data.add_argument('--data_test', type=str, default='GOPRO_Large', help='test dataset name')
46
+ group_data.add_argument('--blur_key', type=str, default='blur_gamma', choices=('blur', 'blur_gamma'), help='blur type from camera response function for GOPRO_Large dataset')
47
+ group_data.add_argument('--rgb_range', type=int, default=255, help='RGB pixel value ranging from 0')
48
+
49
+ # Model
50
+ group_model = parser.add_argument_group('Model specs')
51
+ group_model.add_argument('--model', type=str, default='RecLamResNet', help='model architecture')
52
+ group_model.add_argument('--pretrained', type=str, default='', help='pretrained model location')
53
+ group_model.add_argument('--n_scales', type=int, default=5, help='multi-scale deblurring level')
54
+ group_model.add_argument('--detach', type=str2bool, default=False, help='detach between recurrence')
55
+ group_model.add_argument('--gaussian_pyramid', type=str2bool, default=True, help='gaussian pyramid input/target')
56
+ group_model.add_argument('--n_resblocks', type=int, default=19, help='number of residual blocks per scale')
57
+ group_model.add_argument('--n_feats', type=int, default=64, help='number of feature maps')
58
+ group_model.add_argument('--kernel_size', type=int, default=5, help='size of conv kernel')
59
+ group_model.add_argument('--downsample', type=str, choices=('Gaussian', 'bicubic', 'stride'), default='Gaussian', help='input pyramid generation method')
60
+
61
+ group_model.add_argument('--precision', type=str, default='single', choices=('single', 'half'), help='FP precision for test(single | half)')
62
+
63
+ # amp
64
+ group_amp = parser.add_argument_group('AMP specs')
65
+ group_amp.add_argument('--amp', type=str2bool, default=False, help='use automatic mixed precision training')
66
+ group_amp.add_argument('--init_scale', type=float, default=1024., help='initial loss scale')
67
+
68
+ # Training
69
+ group_train = parser.add_argument_group('Training specs')
70
+ group_train.add_argument('--patch_size', type=int, default=256, help='training patch size')
71
+ group_train.add_argument('--batch_size', type=int, default=16, help='input batch size for training')
72
+ group_train.add_argument('--split_batch', type=int, default=1, help='split a minibatch into smaller chunks')
73
+ group_train.add_argument('--augment', type=str2bool, default=True, help='train with data augmentation')
74
+
75
+ # Testing
76
+ group_test = parser.add_argument_group('Testing specs')
77
+ group_test.add_argument('--validate_every', type=int, default=10, help='do validation at every N epochs')
78
+ group_test.add_argument('--test_every', type=int, default=10, help='do test at every N epochs')
79
+ # group_test.add_argument('--chop', type=str2bool, default=False, help='memory-efficient forward')
80
+ # group_test.add_argument('--self_ensemble', type=str2bool, default=False, help='self-ensembled testing')
81
+
82
+ # Action
83
+ group_action = parser.add_argument_group('Source behavior')
84
+ group_action.add_argument('--do_train', type=str2bool, default=True, help='do train the model')
85
+ group_action.add_argument('--do_validate', type=str2bool, default=True, help='do validate the model')
86
+ group_action.add_argument('--do_test', type=str2bool, default=True, help='do test the model')
87
+ group_action.add_argument('--demo', type=str2bool, default=False, help='demo')
88
+ group_action.add_argument('--demo_input_dir', type=str, default='', help='demo input directory')
89
+ group_action.add_argument('--demo_output_dir', type=str, default='', help='demo output directory')
90
+
91
+ # Optimization
92
+ group_optim = parser.add_argument_group('Optimization specs')
93
+ group_optim.add_argument('--lr', type=float, default=1e-4, help='learning rate')
94
+ group_optim.add_argument('--milestones', type=int, nargs='+', default=[500, 750, 900], help='learning rate decay per N epochs')
95
+ group_optim.add_argument('--scheduler', default='step', choices=('step', 'plateau'), help='learning rate scheduler type')
96
+ group_optim.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay')
97
+ group_optim.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM', 'RMSPROP'), help='optimizer to use (SGD | ADAM | RMSPROP)')  # spelled to match the checks in optim/__init__.py
98
+ group_optim.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
99
+ group_optim.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999), help='ADAM betas')
100
+ group_optim.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon')
101
+ group_optim.add_argument('--weight_decay', type=float, default=0, help='weight decay')
102
+ group_optim.add_argument('--clip', type=float, default=0, help='gradient clipping max-norm (0: no clipping)')
103
+
104
+ # Loss
105
+ group_loss = parser.add_argument_group('Loss specs')
106
+ group_loss.add_argument('--loss', type=str, default='1*MSE', help='loss function configuration')
107
+ group_loss.add_argument('--metric', type=str, default='PSNR,SSIM', help='metric function configuration. ex) None | PSNR | SSIM | PSNR,SSIM')
108
+ group_loss.add_argument('--loss_gamma', type=float, default=0.6, help='per-scale loss decay factor')  # renamed: '--gamma' is already defined in the optimization group and argparse rejects duplicate flags
109
+
110
+
111
+ # Logging
112
+ group_log = parser.add_argument_group('Logging specs')
113
+ group_log.add_argument('--save_dir', type=str, default='', help='subdirectory to save experiment logs')
114
+ # group_log.add_argument('--load_dir', type=str, default='', help='subdirectory to load experiment logs')
115
+ group_log.add_argument('--start_epoch', type=int, default=-1, help='(re)starting epoch number')
116
+ group_log.add_argument('--end_epoch', type=int, default=1000, help='ending epoch number')
117
+ group_log.add_argument('--load_epoch', type=int, default=-1, help='epoch number to load model (start_epoch-1 for training, start_epoch for testing)')
118
+ group_log.add_argument('--save_every', type=int, default=10, help='save model/optimizer at every N epochs')
119
+ group_log.add_argument('--save_results', type=str, default='part', choices=('none', 'part', 'all'), help='save none/part/all of result images')
120
+
121
+ # Debugging
122
+ group_debug = parser.add_argument_group('Debug specs')
123
+ group_debug.add_argument('--stay', type=str2bool, default=False, help='stay at interactive console after trainer initialization')
124
+
125
+ parser.add_argument('--template', type=str, default='', help='argument template option')
126
+
127
+ args = parser.parse_args()
128
+ template.set_template(args)
129
+
130
+ args.data_root = os.path.expanduser(args.data_root) # recognize home directory
131
+ now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
132
+ if args.save_dir == '':
133
+ args.save_dir = now
134
+ args.save_dir = os.path.join('../experiment', args.save_dir)
135
+ os.makedirs(args.save_dir, exist_ok=True)
136
+
137
+ if args.start_epoch < 0: # start from scratch or continue from the last epoch
138
+ # check if there are any models saved before
139
+ model_dir = os.path.join(args.save_dir, 'models')
140
+ model_prefix = 'model-'
141
+ if os.path.exists(model_dir):
142
+ model_list = [name for name in os.listdir(model_dir) if name.startswith(model_prefix)]
143
+ last_epoch = 0
144
+ for name in model_list:
145
+ epochNumber = int(re.findall('\\d+', name)[0]) # model example name model-100.pt
146
+ if last_epoch < epochNumber:
147
+ last_epoch = epochNumber
148
+
149
+ args.start_epoch = last_epoch + 1
150
+ else:
151
+ # train from scratch
152
+ args.start_epoch = 1
153
+ elif args.start_epoch == 0:
154
+ # remove existing directory and start over
155
+ if args.rank == 0: # maybe local rank
156
+ shutil.rmtree(args.save_dir, ignore_errors=True)
157
+ os.makedirs(args.save_dir, exist_ok=True)
158
+ args.start_epoch = 1
159
+
160
+ if args.load_epoch < 0: # load_epoch == start_epoch when doing a post-training test for a specific epoch
161
+ args.load_epoch = args.start_epoch - 1
162
+
163
+ if args.pretrained:
164
+ if args.start_epoch <= 1:
165
+ args.pretrained = os.path.join('../experiment', args.pretrained)
166
+ else:
167
+ print('starting from epoch {}! ignoring pretrained model path..'.format(args.start_epoch))
168
+ args.pretrained = ''
169
+
170
+ if args.model == 'MSResNet':
171
+ args.gaussian_pyramid = True
172
+
173
+ argname = os.path.join(args.save_dir, 'args.pt')
174
+ argname_txt = os.path.join(args.save_dir, 'args.txt')
175
+ if args.start_epoch > 1:
176
+ # load previous arguments and keep the necessary ones same
177
+
178
+ if os.path.exists(argname):
179
+ args_old = torch.load(argname)
180
+
181
+ load_list = [] # list of arguments that are fixed
182
+ # training
183
+ load_list += ['patch_size']
184
+ load_list += ['batch_size']
185
+ # data format
186
+ load_list += ['rgb_range']
187
+ load_list += ['blur_key']
188
+ # model architecture
189
+ load_list += ['n_scales']
190
+ load_list += ['n_resblocks']
191
+ load_list += ['n_feats']
192
+
193
+ for arg_part in load_list:
194
+ vars(args)[arg_part] = vars(args_old)[arg_part]
195
+
196
+ if args.dataset is not None:
197
+ args.data_train = args.dataset
198
+ args.data_val = args.dataset if args.dataset != 'GOPRO_Large' else None
199
+ args.data_test = args.dataset
200
+
201
+ if args.data_val is None:
202
+ args.do_validate = False
203
+
204
+ if args.demo_input_dir:
205
+ args.demo = True
206
+
207
+ if args.demo:
208
+ assert os.path.basename(args.save_dir) != now, 'You should specify pretrained directory by setting --save_dir SAVE_DIR'
209
+
210
+ args.data_train = ''
211
+ args.data_val = ''
212
+ args.data_test = ''
213
+
214
+ args.do_train = False
215
+ args.do_validate = False
216
+ args.do_test = False
217
+
218
+ assert len(args.demo_input_dir) > 0, 'Please specify demo_input_dir!'
219
+ args.demo_input_dir = os.path.expanduser(args.demo_input_dir)
220
+ if args.demo_output_dir:
221
+ args.demo_output_dir = os.path.expanduser(args.demo_output_dir)
222
+
223
+ args.save_results = 'all'
224
+
225
+ if args.amp:
226
+ args.precision = 'single' # model parameters should stay in fp32
227
+
228
+ if args.seed < 0:
229
+ args.seed = int(time.time())
230
+
231
+ # save arguments
232
+ if args.rank == 0:
233
+ torch.save(args, argname)
234
+ with open(argname_txt, 'a') as file:
235
+ file.write('execution at {}\n'.format(now))
236
+
237
+ for key in args.__dict__:
238
+ file.write(key + ': ' + str(args.__dict__[key]) + '\n')
239
+
240
+ file.write('\n')
241
+
242
+ # device and type
243
+ if args.device_type == 'cuda' and not torch.cuda.is_available():
244
+ raise Exception("GPU not available!")
245
+
246
+ if not args.distributed:
247
+ args.rank = 0
248
+
249
+ def setup(args):
250
+ cudnn.benchmark = True
251
+
252
+ if args.distributed:
253
+ os.environ['MASTER_ADDR'] = args.master_addr
254
+ os.environ['MASTER_PORT'] = args.master_port
255
+
256
+ args.device_index = args.rank
257
+ args.world_size = args.n_GPUs # consider single-node training
258
+
259
+ # initialize the process group
260
+ dist.init_process_group(args.dist_backend, init_method=args.init_method, rank=args.rank, world_size=args.world_size)
261
+
262
+ args.device = torch.device(args.device_type, args.device_index)
263
+ args.dtype = torch.float32
264
+ args.dtype_eval = torch.float32 if args.precision == 'single' else torch.float16
265
+
266
+ # set seed for processes (distributed: different seed for each process)
267
+ # model parameters are synchronized explicitly at initial
268
+ torch.manual_seed(args.seed)
269
+ if args.device_type == 'cuda':
270
+ torch.cuda.set_device(args.device)
271
+ if args.rank == 0:
272
+ torch.cuda.manual_seed_all(args.seed)
273
+
274
+ return args
275
+
276
+ def cleanup(args):
277
+ if args.distributed:
278
+ dist.destroy_process_group()
deblur/src/prepare.sh ADDED
@@ -0,0 +1,16 @@
1
+ #!/bin/bash
2
+ workplace=`pwd -P`
3
+
4
+ if [ ! -d "/data/ssd/public/czli/deblur/GOPRO_Large" ]; then
5
+ workplace=`pwd`
6
+ echo "Copying dataset to ssd"
7
+ cd /research/dept7/liuhy/deblur/dataset
8
+ mkdir -p /data/ssd/public/czli/deblur/GOPRO_Large
9
+ cp GOPRO_Large.zip /data/ssd/public/czli/deblur/GOPRO_Large
10
+ cd /data/ssd/public/czli/deblur/GOPRO_Large
11
+ echo "Dumping zip in data path:" `pwd`
12
+ for f in *.zip; do unzip "$f"; done
13
+ fi
14
+
15
+ cd "$workplace"
16
+ echo "Workplace:" `pwd`
deblur/src/template.py ADDED
@@ -0,0 +1,9 @@
1
+ def set_template(args):
2
+ if args.template.find('gopro') >= 0:
3
+ args.dataset = 'GOPRO_Large'
4
+ args.milestones = [500, 750, 900]
5
+ args.end_epoch = 1000
6
+ elif args.template.find('reds') >= 0:
7
+ args.dataset = 'REDS'
8
+ args.milestones = [100, 150, 180]
9
+ args.end_epoch = 200
deblur/src/train-rec-1.sh ADDED
@@ -0,0 +1,4 @@
1
+ #!/bin/bash
2
+
3
+ bash ./prepare.sh
4
+ python launch.py --n_GPUs 2 main.py --batch_size 16 --amp true --save_dir LamRes_L1 --n_scales 1 --loss 1*L1
deblur/src/train.py ADDED
@@ -0,0 +1,249 @@
1
+ import os
2
+ from tqdm import tqdm
3
+
4
+ import torch
5
+
6
+ import data.common
7
+ from utils import interact, MultiSaver
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ import torchvision
10
+
11
+ import torch.cuda.amp as amp
12
+
13
+ class Trainer():
14
+
15
+ def __init__(self, args, model, criterion, optimizer, loaders):
16
+ print('===> Initializing trainer')
17
+ self.args = args
18
+ self.mode = 'train' # 'val', 'test'
19
+ self.epoch = args.start_epoch
20
+ self.save_dir = args.save_dir
21
+
22
+ self.model = model
23
+ self.criterion = criterion
24
+ self.optimizer = optimizer
25
+ self.loaders = loaders
26
+
27
+
28
+ self.do_train = args.do_train
29
+ self.do_validate = args.do_validate
30
+ self.do_test = args.do_test
31
+
32
+ self.device = args.device
33
+ self.dtype = args.dtype
34
+ self.dtype_eval = torch.float32 if args.precision == 'single' else torch.float16
35
+ self.recurrence = args.n_scales
36
+ if self.args.demo and self.args.demo_output_dir:
37
+ self.result_dir = self.args.demo_output_dir
38
+ else:
39
+ self.result_dir = os.path.join(self.save_dir, 'result')
40
+ os.makedirs(self.result_dir, exist_ok=True)
41
+ print('results are saved in {}'.format(self.result_dir))
42
+
43
+ self.imsaver = MultiSaver(self.result_dir)
44
+
45
+ self.is_slave = self.args.launched and self.args.rank != 0
46
+
47
+ self.scaler = amp.GradScaler(
48
+ init_scale=self.args.init_scale,
49
+ enabled=self.args.amp
50
+ )
51
+ if not self.is_slave:
52
+ self.writer = SummaryWriter(f"runs/{args.save_dir}")
53
+
54
+
55
+ def save(self, epoch=None):
56
+ epoch = self.epoch if epoch is None else epoch
57
+ if epoch % self.args.save_every == 0:
58
+ if self.mode == 'train':
59
+ self.model.save(epoch)
60
+ self.optimizer.save(epoch)
61
+ self.criterion.save()
62
+
63
+ return
64
+
65
+ def load(self, epoch=None, pretrained=None):
66
+ if epoch is None:
67
+ epoch = self.args.load_epoch
68
+ self.epoch = epoch
69
+ self.model.load(epoch, pretrained)
70
+ self.optimizer.load(epoch)
71
+ self.criterion.load(epoch)
72
+
73
+ return
74
+
75
+ def train(self, epoch):
76
+ self.mode = 'train'
77
+ self.epoch = epoch
78
+
79
+ self.model.train()
80
+ self.model.to(dtype=self.dtype)
81
+
82
+ self.criterion.train()
83
+ self.criterion.epoch = epoch
84
+
85
+ if not self.is_slave:
86
+ print('[Epoch {} / lr {:.2e}]'.format(
87
+ epoch, self.optimizer.get_lr()
88
+ ))
89
+ total = len(self.loaders[self.mode])  # number of batches (currently unused)
90
+ acc = 0.0  # running accumulator (currently unused)
91
+ if self.args.distributed:
92
+ self.loaders[self.mode].sampler.set_epoch(epoch)
93
+ if self.is_slave:
94
+ tq = self.loaders[self.mode]
95
+ else:
96
+ tq = tqdm(self.loaders[self.mode], ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')
97
+ buffer = [0.0] * self.recurrence  # per-scale running loss for TensorBoard
98
+ torch.set_grad_enabled(True)
99
+ for idx, batch in enumerate(tq):
100
+ self.optimizer.zero_grad()
101
+
102
+ input, target = data.common.to(
103
+ batch[0], batch[1], device=self.device, dtype=self.dtype)
104
+
105
+ with amp.autocast(self.args.amp):
106
+ output = self.model(input)
107
+ loss = self.criterion(output, target)
108
+
109
+ for i in range(self.recurrence):
110
+ buffer[i] += self.criterion.buffer[i]
111
+
112
+ self.scaler.scale(loss).backward()
113
+ if self.args.clip > 0:
114
+ self.scaler.unscale_(self.optimizer.G)
115
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
116
+ self.scaler.step(self.optimizer.G)
117
+ self.scaler.update()
118
+
119
+ if isinstance(tq, tqdm):
120
+ tq.set_description(self.criterion.get_loss_desc())
121
+ if not self.is_slave:
122
+ rgb_range = self.args.rgb_range
123
+ for i in range(len(output)):
124
+ grid = torchvision.utils.make_grid(output[i])
125
+ self.writer.add_image(f"Output{i}", grid / rgb_range, epoch)
126
+ self.writer.add_scalar(f"Loss{i}", buffer[i], epoch)
127
+ self.writer.add_image("Input", torchvision.utils.make_grid(input[0]) / rgb_range, epoch)
128
+ self.writer.add_image("Target", torchvision.utils.make_grid(target[0]) / rgb_range, epoch)
129
+ self.criterion.normalize()
130
+ if isinstance(tq, tqdm):
131
+ tq.set_description(self.criterion.get_loss_desc())
132
+ tq.display(pos=-1) # overwrite with synchronized loss
133
+
134
+ self.criterion.step()
135
+ self.optimizer.schedule(self.criterion.get_last_loss())
136
+
137
+ if self.args.rank == 0:
138
+ self.save(epoch)
139
+
140
+ return
141
+
142
+ def evaluate(self, epoch, mode='val'):
143
+ self.mode = mode
144
+ self.epoch = epoch
145
+
146
+ self.model.eval()
147
+ self.model.to(dtype=self.dtype_eval)
148
+
149
+ if mode == 'val':
150
+ self.criterion.validate()
151
+ elif mode == 'test':
152
+ self.criterion.test()
153
+ self.criterion.epoch = epoch
154
+
155
+ self.imsaver.join_background()
156
+
157
+ if self.is_slave:
158
+ tq = self.loaders[self.mode]
159
+ else:
160
+ tq = tqdm(self.loaders[self.mode], ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')
161
+
162
+ compute_loss = True
163
+ torch.set_grad_enabled(False)
164
+ for idx, batch in enumerate(tq):
165
+ input, target = data.common.to(
166
+ batch[0], batch[1], device=self.device, dtype=self.dtype_eval)
167
+ with amp.autocast(self.args.amp):
168
+ output = self.model(input)
169
+
170
+ # if self.args.rgb_range==1:
171
+ # output=output*255
172
+ # target=target*255
173
+
174
+ if mode == 'demo': # remove padded part
175
+ pad_width = batch[2]
176
+ output[0], _ = data.common.pad(output[0], pad_width=pad_width, negative=True)
177
+
178
+ if isinstance(batch[1], torch.BoolTensor):
179
+ compute_loss = False
180
+
181
+ if compute_loss:
182
+ self.criterion(output, target)
183
+ if isinstance(tq, tqdm):
184
+ tq.set_description(self.criterion.get_loss_desc())
185
+
186
+ if self.args.save_results != 'none':
187
+ if isinstance(output, (list, tuple)):
188
+ result = output[-1] # select last output in a pyramid
189
+ elif isinstance(output, torch.Tensor):
190
+ result = output
191
+
192
+ names = batch[-1]
193
+
194
+ if self.args.save_results == 'part' and compute_loss: # save all when GT not available
195
+ indices = batch[-2]
196
+ save_ids = [save_id for save_id, idx in enumerate(indices) if idx % 10 == 0]
197
+
198
+ result = result[save_ids]
199
+ names = [names[save_id] for save_id in save_ids]
200
+
201
+ self.imsaver.save_image(result, names)
202
+
203
+ if compute_loss:
204
+ self.criterion.normalize()
205
+ if isinstance(tq, tqdm):
206
+ tq.set_description(self.criterion.get_loss_desc())
207
+ tq.display(pos=-1) # overwrite with synchronized loss
208
+
209
+ self.criterion.step()
210
+ if self.args.rank == 0:
211
+ self.save()
212
+
213
+ self.imsaver.end_background()
214
+
215
+ def validate(self, epoch):
216
+ self.evaluate(epoch, 'val')
217
+ return
218
+
219
+ def test(self, epoch):
220
+ self.evaluate(epoch, 'test')
221
+ return
222
+
223
+ def fill_evaluation(self, epoch, mode=None, force=False):
224
+ if epoch <= 0:
225
+ return
226
+
227
+ if mode is not None:
228
+ self.mode = mode
229
+
230
+ do_eval = force
231
+ if not force:
232
+ loss_missing = epoch not in self.criterion.loss_stat[self.mode]['Total'] # should it switch to all loss types?
233
+
234
+ metric_missing = False
235
+ for metric_type in self.criterion.metric:
236
+ if epoch not in self.criterion.metric_stat[mode][metric_type]:
237
+ metric_missing = True
238
+
239
+ do_eval = loss_missing or metric_missing
240
+
241
+ if do_eval:
242
+ try:
243
+ self.load(epoch)
244
+ self.evaluate(epoch, self.mode)
245
+ except Exception:  # tolerate a missing checkpoint for this epoch
246
+ # print('saved model/optimizer at epoch {} not found!'.format(epoch))
247
+ pass
248
+
249
+ return
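A condensed sketch of how an entry point would drive this Trainer (main.py itself is not part of this diff, so the surrounding names are assumptions):

trainer = Trainer(args, model, criterion, optimizer, loaders)
trainer.load(args.load_epoch, args.pretrained)
for epoch in range(args.start_epoch, args.end_epoch + 1):
    if args.do_train:
        trainer.train(epoch)
    if args.do_validate and epoch % args.validate_every == 0:
        trainer.validate(epoch)
    if args.do_test and epoch % args.test_every == 0:
        trainer.test(epoch)
trainer.imsaver.join_background()   # flush queued image writes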
deblur/src/train.sh ADDED
@@ -0,0 +1,4 @@
1
+ #!/bin/bash
2
+
3
+ bash ./prepare.sh
4
+ python launch.py --n_GPUs 2 main.py --batch_size 16 --amp true --save_dir RecLamRes_MSE
deblur/src/train01.sh ADDED
@@ -0,0 +1,4 @@
1
+ #!/bin/bash
2
+
3
+ bash ./prepare.sh
4
+ CUDA_VISIBLE_DEVICES=0,1 python launch.py --n_GPUs 2 main.py --batch_size 16 --amp true --save_dir RecLamRes_MSE
deblur/src/train23.sh ADDED
@@ -0,0 +1,4 @@
1
+ #!/bin/bash
2
+
3
+ bash ./prepare.sh
4
+ CUDA_VISIBLE_DEVICES=2,3 python launch.py --n_GPUs 2 main.py --batch_size 8 --amp true --save_dir RecLamRes_MSE_detach --detach true --master_port 8025
deblur/src/utils.py ADDED
@@ -0,0 +1,200 @@
1
+ import readline
2
+ import rlcompleter
3
+ readline.parse_and_bind("tab: complete")
4
+ import code
5
+ import pdb
6
+
7
+ import time
8
+ import argparse
9
+ import os
10
+ import imageio
11
+ import torch
12
+ import torch.multiprocessing as mp
13
+
14
+ # debugging tools
15
+ def interact(local=None):
16
+ """interactive console with autocomplete function. Useful for debugging.
17
+ interact(locals())
18
+ """
19
+ if local is None:
20
+ local=dict(globals(), **locals())
21
+
22
+ readline.set_completer(rlcompleter.Completer(local).complete)
23
+ code.interact(local=local)
24
+
25
+ def set_trace(local=None):
26
+ """debugging with pdb
27
+ """
28
+ if local is None:
29
+ local=dict(globals(), **locals())
30
+
31
+ pdb.Pdb.complete = rlcompleter.Completer(local).complete
32
+ pdb.set_trace()
33
+
34
+ # timer
35
+ class Timer():
36
+ """Brought from https://github.com/thstkdgus35/EDSR-PyTorch
37
+ """
38
+ def __init__(self):
39
+ self.acc = 0
40
+ self.tic()
41
+
42
+ def tic(self):
43
+ self.t0 = time.time()
44
+
45
+ def toc(self):
46
+ return time.time() - self.t0
47
+
48
+ def hold(self):
49
+ self.acc += self.toc()
50
+
51
+ def release(self):
52
+ ret = self.acc
53
+ self.acc = 0
54
+
55
+ return ret
56
+
57
+ def reset(self):
58
+ self.acc = 0
59
+
60
+
61
+ # argument parser type casting functions
62
+ def str2bool(val):
63
+ """enable default constant true arguments"""
64
+ # https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
65
+ if isinstance(val, bool):
66
+ return val
67
+ elif val.lower() == 'true':
68
+ return True
69
+ elif val.lower() == 'false':
70
+ return False
71
+ else:
72
+ raise argparse.ArgumentTypeError('Boolean value expected')
73
+
74
+ def int2str(val):
75
+ """convert int to str for environment variable related arguments"""
76
+ if isinstance(val, int):
77
+ return str(val)
78
+ elif isinstance(val, str):
79
+ return val
80
+ else:
81
+ raise argparse.ArgumentTypeError('number value expected')
82
+
83
+
84
+ # image saver using multiprocessing queue
85
+ class MultiSaver():
86
+ def __init__(self, result_dir=None):
87
+ self.queue = None
88
+ self.process = None
89
+ self.result_dir = result_dir
90
+
91
+ def begin_background(self):
92
+ self.queue = mp.Queue()
93
+
94
+ def t(queue):
95
+ while True:
96
+ img, name = queue.get()  # blocking get instead of busy-waiting on queue.empty()
99
+ if name:
100
+ try:
101
+ basename, ext = os.path.splitext(name)
102
+ if ext != '.png':
103
+ name = '{}.png'.format(basename)
104
+ imageio.imwrite(name, img)
105
+ except Exception as e:
106
+ print(e)
107
+ else:
108
+ return
109
+
110
+ worker = lambda: mp.Process(target=t, args=(self.queue,), daemon=False)
111
+ cpu_count = min(8, mp.cpu_count() - 1)
112
+ self.process = [worker() for _ in range(cpu_count)]
113
+ for p in self.process:
114
+ p.start()
115
+
116
+ def end_background(self):
117
+ if self.queue is None:
118
+ return
119
+
120
+ for _ in self.process:
121
+ self.queue.put((None, None))
122
+
123
+ def join_background(self):
124
+ if self.queue is None:
125
+ return
126
+
127
+ while not self.queue.empty():
128
+ time.sleep(0.5)
129
+
130
+ for p in self.process:
131
+ p.join()
132
+
133
+ self.queue = None
134
+
135
+ def save_image(self, output, save_names, result_dir=None):
136
+ result_dir = result_dir if result_dir is not None else self.result_dir  # explicit argument takes priority
137
+ if result_dir is None:
138
+ raise Exception('no result dir specified!')
139
+
140
+ if self.queue is None:
141
+ try:
142
+ self.begin_background()
143
+ except Exception as e:
144
+ print(e)
145
+ return
146
+
147
+ # assume NCHW format
148
+ if output.ndim == 2:
149
+ output = output.expand([1, 1] + list(output.shape))
150
+ elif output.ndim == 3:
151
+ output = output.expand([1] + list(output.shape))
152
+
153
+ for output_img, save_name in zip(output, save_names):
154
+ # assume image range [0, 255]
155
+ output_img = output_img.add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
156
+
157
+ save_name = os.path.join(result_dir, save_name)
158
+ save_dir = os.path.dirname(save_name)
159
+ os.makedirs(save_dir, exist_ok=True)
160
+
161
+ self.queue.put((output_img, save_name))
162
+
163
+ return
164
+
165
+ class Map(dict):
166
+ """
167
+ https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
168
+ Example:
169
+ m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
170
+ """
171
+ def __init__(self, *args, **kwargs):
172
+ super(Map, self).__init__(*args, **kwargs)
173
+ for arg in args:
174
+ if isinstance(arg, dict):
175
+ for k, v in arg.items():
176
+ self[k] = v
177
+
178
+ if kwargs:
179
+ for k, v in kwargs.items():
180
+ self[k] = v
181
+
182
+ def __getattr__(self, attr):
183
+ return self.get(attr)
184
+
185
+ def __setattr__(self, key, value):
186
+ self.__setitem__(key, value)
187
+
188
+ def __setitem__(self, key, value):
189
+ super(Map, self).__setitem__(key, value)
190
+ self.__dict__.update({key: value})
191
+
192
+ def __delattr__(self, item):
193
+ self.__delitem__(item)
194
+
195
+ def __delitem__(self, key):
196
+ super(Map, self).__delitem__(key)
197
+ del self.__dict__[key]
198
+
199
+ def toDict(self):
200
+ return self.__dict__
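Map round-trips with plain dicts, which Optimizer.state_dict relies on; a brief usage sketch:

m = Map({'G': {'lr': 1e-4}}, D=None)
print(m.G['lr'])    # attribute-style access -> 0.0001
print(m.missing)    # absent keys return None instead of raising
d = m.toDict()      # back to a plain dict, e.g. for torch.save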