Upload folder using huggingface_hub
Browse files- deblur/experiment/.gitignore +1 -0
- deblur/src/data/__init__.py +79 -0
- deblur/src/data/common.py +165 -0
- deblur/src/data/dataset.py +154 -0
- deblur/src/data/demo.py +22 -0
- deblur/src/data/gopro_large.py +23 -0
- deblur/src/data/reds.py +28 -0
- deblur/src/data/sampler.py +115 -0
- deblur/src/lambda_networks/__init__.py +3 -0
- deblur/src/lambda_networks/lambda_networks.py +80 -0
- deblur/src/lambda_networks/rlambda_networks.py +93 -0
- deblur/src/launch.py +55 -0
- deblur/src/loss/__init__.py +500 -0
- deblur/src/loss/adversarial.py +52 -0
- deblur/src/loss/metric.py +112 -0
- deblur/src/main.py +67 -0
- deblur/src/model/LamResNet.py +72 -0
- deblur/src/model/MSResNet.py +67 -0
- deblur/src/model/MSResNetLambda.py +77 -0
- deblur/src/model/RaftNet.py +120 -0
- deblur/src/model/RecLamResNet.py +54 -0
- deblur/src/model/ResNet.py +41 -0
- deblur/src/model/__init__.py +136 -0
- deblur/src/model/common.py +161 -0
- deblur/src/model/discriminator.py +41 -0
- deblur/src/model/structure.py +56 -0
- deblur/src/optim/__init__.py +206 -0
- deblur/src/optim/warm_multi_step_lr.py +32 -0
- deblur/src/option.py +278 -0
- deblur/src/prepare.sh +16 -0
- deblur/src/template.py +9 -0
- deblur/src/train-rec-1.sh +4 -0
- deblur/src/train.py +249 -0
- deblur/src/train.sh +4 -0
- deblur/src/train01.sh +4 -0
- deblur/src/train23.sh +4 -0
- deblur/src/utils.py +200 -0
deblur/experiment/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
!.gitignore
|
deblur/src/data/__init__.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Generic dataset loader"""
|
2 |
+
|
3 |
+
from importlib import import_module
|
4 |
+
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from torch.utils.data import SequentialSampler, RandomSampler
|
7 |
+
from torch.utils.data.distributed import DistributedSampler
|
8 |
+
from .sampler import DistributedEvalSampler
|
9 |
+
|
10 |
+
class Data():
|
11 |
+
def __init__(self, args):
|
12 |
+
|
13 |
+
self.modes = ['train', 'val', 'test', 'demo']
|
14 |
+
|
15 |
+
self.action = {
|
16 |
+
'train': args.do_train,
|
17 |
+
'val': args.do_validate,
|
18 |
+
'test': args.do_test,
|
19 |
+
'demo': args.demo
|
20 |
+
}
|
21 |
+
|
22 |
+
self.dataset_name = {
|
23 |
+
'train': args.data_train,
|
24 |
+
'val': args.data_val,
|
25 |
+
'test': args.data_test,
|
26 |
+
'demo': 'Demo'
|
27 |
+
}
|
28 |
+
|
29 |
+
self.args = args
|
30 |
+
|
31 |
+
def _get_data_loader(mode='train'):
|
32 |
+
dataset_name = self.dataset_name[mode]
|
33 |
+
dataset = import_module('data.' + dataset_name.lower())
|
34 |
+
dataset = getattr(dataset, dataset_name)(args, mode)
|
35 |
+
|
36 |
+
if mode == 'train':
|
37 |
+
if args.distributed:
|
38 |
+
batch_size = int(args.batch_size / args.n_GPUs) # batch size per GPU (single-node training)
|
39 |
+
sampler = DistributedSampler(dataset, shuffle=True, num_replicas=args.world_size, rank=args.rank)
|
40 |
+
num_workers = int((args.num_workers + args.n_GPUs - 1) / args.n_GPUs) # num_workers per GPU (single-node training)
|
41 |
+
else:
|
42 |
+
batch_size = args.batch_size
|
43 |
+
sampler = RandomSampler(dataset, replacement=False)
|
44 |
+
num_workers = args.num_workers
|
45 |
+
drop_last = True
|
46 |
+
|
47 |
+
elif mode in ('val', 'test', 'demo'):
|
48 |
+
if args.distributed:
|
49 |
+
batch_size = 1 # 1 image per GPU
|
50 |
+
sampler = DistributedEvalSampler(dataset, shuffle=False, num_replicas=args.world_size, rank=args.rank)
|
51 |
+
num_workers = int((args.num_workers + args.n_GPUs - 1) / args.n_GPUs) # num_workers per GPU (single-node training)
|
52 |
+
else:
|
53 |
+
batch_size = args.n_GPUs # 1 image per GPU
|
54 |
+
sampler = SequentialSampler(dataset)
|
55 |
+
num_workers = args.num_workers
|
56 |
+
drop_last = False
|
57 |
+
|
58 |
+
loader = DataLoader(
|
59 |
+
dataset=dataset,
|
60 |
+
batch_size=batch_size,
|
61 |
+
shuffle=False,
|
62 |
+
sampler=sampler,
|
63 |
+
num_workers=num_workers,
|
64 |
+
pin_memory=True,
|
65 |
+
drop_last=drop_last,
|
66 |
+
)
|
67 |
+
|
68 |
+
return loader
|
69 |
+
|
70 |
+
self.loaders = {}
|
71 |
+
for mode in self.modes:
|
72 |
+
if self.action[mode]:
|
73 |
+
self.loaders[mode] = _get_data_loader(mode)
|
74 |
+
print('===> Loading {} dataset: {}'.format(mode, self.dataset_name[mode]))
|
75 |
+
else:
|
76 |
+
self.loaders[mode] = None
|
77 |
+
|
78 |
+
def get_loader(self):
|
79 |
+
return self.loaders
|
deblur/src/data/common.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import numpy as np
|
3 |
+
from skimage.color import rgb2hsv, hsv2rgb
|
4 |
+
# from skimage.transform import pyramid_gaussian
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
def _apply(func, x):
|
9 |
+
|
10 |
+
if isinstance(x, (list, tuple)):
|
11 |
+
return [_apply(func, x_i) for x_i in x]
|
12 |
+
elif isinstance(x, dict):
|
13 |
+
y = {}
|
14 |
+
for key, value in x.items():
|
15 |
+
y[key] = _apply(func, value)
|
16 |
+
return y
|
17 |
+
else:
|
18 |
+
return func(x)
|
19 |
+
|
20 |
+
def crop(*args, ps=256): # patch_size
|
21 |
+
# args = [input, target]
|
22 |
+
def _get_shape(*args):
|
23 |
+
if isinstance(args[0], (list, tuple)):
|
24 |
+
return _get_shape(args[0][0])
|
25 |
+
elif isinstance(args[0], dict):
|
26 |
+
return _get_shape(list(args[0].values())[0])
|
27 |
+
else:
|
28 |
+
return args[0].shape
|
29 |
+
|
30 |
+
h, w, _ = _get_shape(args)
|
31 |
+
|
32 |
+
py = random.randrange(0, h-ps+1)
|
33 |
+
px = random.randrange(0, w-ps+1)
|
34 |
+
|
35 |
+
def _crop(img):
|
36 |
+
if img.ndim == 2:
|
37 |
+
return img[py:py+ps, px:px+ps, np.newaxis]
|
38 |
+
else:
|
39 |
+
return img[py:py+ps, px:px+ps, :]
|
40 |
+
|
41 |
+
return _apply(_crop, args)
|
42 |
+
|
43 |
+
def add_noise(*args, sigma_sigma=2, rgb_range=255):
|
44 |
+
|
45 |
+
if len(args) == 1: # usually there is only a single input
|
46 |
+
args = args[0]
|
47 |
+
|
48 |
+
sigma = np.random.normal() * sigma_sigma * rgb_range/255
|
49 |
+
|
50 |
+
def _add_noise(img):
|
51 |
+
noise = np.random.randn(*img.shape).astype(np.float32) * sigma
|
52 |
+
return (img + noise).clip(0, rgb_range)
|
53 |
+
|
54 |
+
return _apply(_add_noise, args)
|
55 |
+
|
56 |
+
def augment(*args, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=255):
|
57 |
+
"""augmentation consistent to input and target"""
|
58 |
+
|
59 |
+
choices = (False, True)
|
60 |
+
|
61 |
+
hflip = hflip and random.choice(choices)
|
62 |
+
vflip = rot and random.choice(choices)
|
63 |
+
rot90 = rot and random.choice(choices)
|
64 |
+
# shuffle = shuffle
|
65 |
+
|
66 |
+
if shuffle:
|
67 |
+
rgb_order = list(range(3))
|
68 |
+
random.shuffle(rgb_order)
|
69 |
+
if rgb_order == list(range(3)):
|
70 |
+
shuffle = False
|
71 |
+
|
72 |
+
if change_saturation:
|
73 |
+
amp_factor = np.random.uniform(0.5, 1.5)
|
74 |
+
|
75 |
+
def _augment(img):
|
76 |
+
if hflip: img = img[:, ::-1, :]
|
77 |
+
if vflip: img = img[::-1, :, :]
|
78 |
+
if rot90: img = img.transpose(1, 0, 2)
|
79 |
+
if shuffle and img.ndim > 2:
|
80 |
+
if img.shape[-1] == 3: # RGB image only
|
81 |
+
img = img[..., rgb_order]
|
82 |
+
|
83 |
+
if change_saturation:
|
84 |
+
hsv_img = rgb2hsv(img)
|
85 |
+
hsv_img[..., 1] *= amp_factor
|
86 |
+
|
87 |
+
img = hsv2rgb(hsv_img).clip(0, 1) * rgb_range
|
88 |
+
|
89 |
+
return img.astype(np.float32)
|
90 |
+
|
91 |
+
return _apply(_augment, args)
|
92 |
+
|
93 |
+
def pad(img, divisor=4, pad_width=None, negative=False):
|
94 |
+
|
95 |
+
def _pad_numpy(img, divisor=4, pad_width=None, negative=False):
|
96 |
+
if pad_width is None:
|
97 |
+
(h, w, _) = img.shape
|
98 |
+
pad_h = -h % divisor
|
99 |
+
pad_w = -w % divisor
|
100 |
+
pad_width = ((0, pad_h), (0, pad_w), (0, 0))
|
101 |
+
|
102 |
+
img = np.pad(img, pad_width, mode='edge')
|
103 |
+
|
104 |
+
return img, pad_width
|
105 |
+
|
106 |
+
def _pad_tensor(img, divisor=4, pad_width=None, negative=False):
|
107 |
+
|
108 |
+
n, c, h, w = img.shape
|
109 |
+
if pad_width is None:
|
110 |
+
pad_h = -h % divisor
|
111 |
+
pad_w = -w % divisor
|
112 |
+
pad_width = (0, pad_w, 0, pad_h)
|
113 |
+
else:
|
114 |
+
try:
|
115 |
+
pad_h = pad_width[0][1]
|
116 |
+
pad_w = pad_width[1][1]
|
117 |
+
if isinstance(pad_h, torch.Tensor):
|
118 |
+
pad_h = pad_h.item()
|
119 |
+
if isinstance(pad_w, torch.Tensor):
|
120 |
+
pad_w = pad_w.item()
|
121 |
+
|
122 |
+
pad_width = (0, pad_w, 0, pad_h)
|
123 |
+
except:
|
124 |
+
pass
|
125 |
+
|
126 |
+
if negative:
|
127 |
+
pad_width = [-val for val in pad_width]
|
128 |
+
|
129 |
+
img = torch.nn.functional.pad(img, pad_width, 'reflect')
|
130 |
+
|
131 |
+
return img, pad_width
|
132 |
+
|
133 |
+
if isinstance(img, np.ndarray):
|
134 |
+
return _pad_numpy(img, divisor, pad_width, negative)
|
135 |
+
else: # torch.Tensor
|
136 |
+
return _pad_tensor(img, divisor, pad_width, negative)
|
137 |
+
|
138 |
+
def generate_pyramid(*args, n_scales):
|
139 |
+
|
140 |
+
def _generate_pyramid(img):
|
141 |
+
if img.dtype != np.float32:
|
142 |
+
img = img.astype(np.float32)
|
143 |
+
# pyramid = list(pyramid_gaussian(img, n_scales-1, multichannel=True))
|
144 |
+
# bypass pyramid, deliver the image as is
|
145 |
+
pyramid = [img]
|
146 |
+
|
147 |
+
return pyramid
|
148 |
+
|
149 |
+
return _apply(_generate_pyramid, args)
|
150 |
+
|
151 |
+
def np2tensor(*args, rgb_range=255):
|
152 |
+
def _np2tensor(x):
|
153 |
+
np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1))
|
154 |
+
tensor = torch.from_numpy(np_transpose).float()
|
155 |
+
tensor.mul_(rgb_range / 255)
|
156 |
+
return tensor
|
157 |
+
|
158 |
+
return _apply(_np2tensor, args)
|
159 |
+
|
160 |
+
def to(*args, device=None, dtype=torch.float):
|
161 |
+
|
162 |
+
def _to(x):
|
163 |
+
return x.to(device=device, dtype=dtype, non_blocking=True, copy=False)
|
164 |
+
|
165 |
+
return _apply(_to, args)
|
deblur/src/data/dataset.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import imageio
|
4 |
+
import numpy as np
|
5 |
+
import torch.utils.data as data
|
6 |
+
|
7 |
+
from data import common
|
8 |
+
|
9 |
+
from utils import interact
|
10 |
+
|
11 |
+
class Dataset(data.Dataset):
|
12 |
+
"""Basic dataloader class
|
13 |
+
"""
|
14 |
+
def __init__(self, args, mode='train'):
|
15 |
+
super(Dataset, self).__init__()
|
16 |
+
self.args = args
|
17 |
+
self.mode = mode
|
18 |
+
|
19 |
+
self.modes = ()
|
20 |
+
self.set_modes()
|
21 |
+
self._check_mode()
|
22 |
+
|
23 |
+
self.set_keys()
|
24 |
+
|
25 |
+
if self.mode == 'train':
|
26 |
+
dataset = args.data_train
|
27 |
+
elif self.mode == 'val':
|
28 |
+
dataset = args.data_val
|
29 |
+
elif self.mode == 'test':
|
30 |
+
dataset = args.data_test
|
31 |
+
elif self.mode == 'demo':
|
32 |
+
pass
|
33 |
+
else:
|
34 |
+
raise NotImplementedError('not implemented for this mode: {}!'.format(self.mode))
|
35 |
+
|
36 |
+
if self.mode == 'demo':
|
37 |
+
self.subset_root = args.demo_input_dir
|
38 |
+
else:
|
39 |
+
self.subset_root = os.path.join(args.data_root, dataset, self.mode)
|
40 |
+
|
41 |
+
self.blur_list = []
|
42 |
+
self.sharp_list = []
|
43 |
+
|
44 |
+
self._scan()
|
45 |
+
|
46 |
+
def set_modes(self):
|
47 |
+
self.modes = ('train', 'val', 'test', 'demo')
|
48 |
+
|
49 |
+
def _check_mode(self):
|
50 |
+
"""Should be called in the child class __init__() after super
|
51 |
+
"""
|
52 |
+
if self.mode not in self.modes:
|
53 |
+
raise NotImplementedError('mode error: not for {}'.format(self.mode))
|
54 |
+
|
55 |
+
return
|
56 |
+
|
57 |
+
def set_keys(self):
|
58 |
+
self.blur_key = 'blur' # to be overwritten by child class
|
59 |
+
self.sharp_key = 'sharp' # to be overwritten by child class
|
60 |
+
|
61 |
+
self.non_blur_keys = []
|
62 |
+
self.non_sharp_keys = []
|
63 |
+
|
64 |
+
return
|
65 |
+
|
66 |
+
def _scan(self, root=None):
|
67 |
+
"""Should be called in the child class __init__() after super
|
68 |
+
"""
|
69 |
+
if root is None:
|
70 |
+
root = self.subset_root
|
71 |
+
|
72 |
+
if self.blur_key in self.non_blur_keys:
|
73 |
+
self.non_blur_keys.remove(self.blur_key)
|
74 |
+
if self.sharp_key in self.non_sharp_keys:
|
75 |
+
self.non_sharp_keys.remove(self.sharp_key)
|
76 |
+
|
77 |
+
def _key_check(path, true_key, false_keys):
|
78 |
+
path = os.path.join(path, '')
|
79 |
+
if path.find(true_key) >= 0:
|
80 |
+
for false_key in false_keys:
|
81 |
+
if path.find(false_key) >= 0:
|
82 |
+
return False
|
83 |
+
|
84 |
+
return True
|
85 |
+
else:
|
86 |
+
return False
|
87 |
+
|
88 |
+
def _get_list_by_key(root, true_key, false_keys):
|
89 |
+
data_list = []
|
90 |
+
for sub, dirs, files in os.walk(root):
|
91 |
+
if not dirs:
|
92 |
+
file_list = [os.path.join(sub, f) for f in files]
|
93 |
+
if _key_check(sub, true_key, false_keys):
|
94 |
+
data_list += file_list
|
95 |
+
|
96 |
+
data_list.sort()
|
97 |
+
|
98 |
+
return data_list
|
99 |
+
|
100 |
+
def _rectify_keys():
|
101 |
+
self.blur_key = os.path.join(self.blur_key, '')
|
102 |
+
self.non_blur_keys = [os.path.join(non_blur_key, '') for non_blur_key in self.non_blur_keys]
|
103 |
+
self.sharp_key = os.path.join(self.sharp_key, '')
|
104 |
+
self.non_sharp_keys = [os.path.join(non_sharp_key, '') for non_sharp_key in self.non_sharp_keys]
|
105 |
+
|
106 |
+
_rectify_keys()
|
107 |
+
|
108 |
+
self.blur_list = _get_list_by_key(root, self.blur_key, self.non_blur_keys)
|
109 |
+
self.sharp_list = _get_list_by_key(root, self.sharp_key, self.non_sharp_keys)
|
110 |
+
|
111 |
+
if len(self.sharp_list) > 0:
|
112 |
+
assert(len(self.blur_list) == len(self.sharp_list))
|
113 |
+
|
114 |
+
return
|
115 |
+
|
116 |
+
def __getitem__(self, idx):
|
117 |
+
|
118 |
+
blur = imageio.imread(self.blur_list[idx], pilmode='RGB')
|
119 |
+
if len(self.sharp_list) > 0:
|
120 |
+
sharp = imageio.imread(self.sharp_list[idx], pilmode='RGB')
|
121 |
+
imgs = [blur, sharp]
|
122 |
+
else:
|
123 |
+
imgs = [blur]
|
124 |
+
|
125 |
+
pad_width = 0 # dummy value
|
126 |
+
if self.mode == 'train':
|
127 |
+
imgs = common.crop(*imgs, ps=self.args.patch_size)
|
128 |
+
if self.args.augment:
|
129 |
+
imgs = common.augment(*imgs, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=self.args.rgb_range)
|
130 |
+
imgs[0] = common.add_noise(imgs[0], sigma_sigma=2, rgb_range=self.args.rgb_range)
|
131 |
+
elif self.mode == 'demo':
|
132 |
+
imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1)) # pad in case of non-divisible size
|
133 |
+
else:
|
134 |
+
pass # deliver test image as is.
|
135 |
+
|
136 |
+
if self.args.gaussian_pyramid:
|
137 |
+
imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales)
|
138 |
+
|
139 |
+
imgs = common.np2tensor(*imgs, rgb_range=self.args.rgb_range)
|
140 |
+
relpath = os.path.relpath(self.blur_list[idx], self.subset_root)
|
141 |
+
|
142 |
+
blur = imgs[0]
|
143 |
+
sharp = imgs[1] if len(imgs) > 1 else False
|
144 |
+
|
145 |
+
return blur, sharp, pad_width, idx, relpath
|
146 |
+
|
147 |
+
def __len__(self):
|
148 |
+
return len(self.blur_list)
|
149 |
+
# return 32
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
|
deblur/src/data/demo.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data.dataset import Dataset
|
2 |
+
|
3 |
+
from utils import interact
|
4 |
+
|
5 |
+
class Demo(Dataset):
|
6 |
+
"""Demo train, test subset class
|
7 |
+
"""
|
8 |
+
def __init__(self, args, mode='demo'):
|
9 |
+
super(Demo, self).__init__(args, mode)
|
10 |
+
|
11 |
+
def set_modes(self):
|
12 |
+
self.modes = ('demo')
|
13 |
+
|
14 |
+
def set_keys(self):
|
15 |
+
super(Demo, self).set_keys()
|
16 |
+
self.blur_key = '' # all the files
|
17 |
+
self.non_sharp_keys = [''] # no files
|
18 |
+
|
19 |
+
def __getitem__(self, idx):
|
20 |
+
blur, sharp, pad_width, idx, relpath = super(Demo, self).__getitem__(idx)
|
21 |
+
|
22 |
+
return blur, sharp, pad_width, idx, relpath
|
deblur/src/data/gopro_large.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data.dataset import Dataset
|
2 |
+
|
3 |
+
from utils import interact
|
4 |
+
|
5 |
+
class GOPRO_Large(Dataset):
|
6 |
+
"""GOPRO_Large train, test subset class
|
7 |
+
"""
|
8 |
+
def __init__(self, args, mode='train'):
|
9 |
+
super(GOPRO_Large, self).__init__(args, mode)
|
10 |
+
|
11 |
+
def set_modes(self):
|
12 |
+
self.modes = ('train', 'test')
|
13 |
+
|
14 |
+
def set_keys(self):
|
15 |
+
super(GOPRO_Large, self).set_keys()
|
16 |
+
self.blur_key = 'blur_gamma'
|
17 |
+
# self.sharp_key = 'sharp'
|
18 |
+
|
19 |
+
def __getitem__(self, idx):
|
20 |
+
blur, sharp, pad_width, idx, relpath = super(GOPRO_Large, self).__getitem__(idx)
|
21 |
+
relpath = relpath.replace('{}/'.format(self.blur_key), '')
|
22 |
+
|
23 |
+
return blur, sharp, pad_width, idx, relpath
|
deblur/src/data/reds.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data.dataset import Dataset
|
2 |
+
|
3 |
+
from utils import interact
|
4 |
+
|
5 |
+
class REDS(Dataset):
|
6 |
+
"""REDS train, val, test subset class
|
7 |
+
"""
|
8 |
+
def __init__(self, args, mode='train'):
|
9 |
+
super(REDS, self).__init__(args, mode)
|
10 |
+
|
11 |
+
def set_modes(self):
|
12 |
+
self.modes = ('train', 'val', 'test')
|
13 |
+
|
14 |
+
def set_keys(self):
|
15 |
+
super(REDS, self).set_keys()
|
16 |
+
# self.blur_key = 'blur'
|
17 |
+
# self.sharp_key = 'sharp'
|
18 |
+
|
19 |
+
self.non_blur_keys = ['blur', 'blur_comp', 'blur_bicubic']
|
20 |
+
self.non_blur_keys.remove(self.blur_key)
|
21 |
+
self.non_sharp_keys = ['sharp_bicubic', 'sharp']
|
22 |
+
self.non_sharp_keys.remove(self.sharp_key)
|
23 |
+
|
24 |
+
def __getitem__(self, idx):
|
25 |
+
blur, sharp, pad_width, idx, relpath = super(REDS, self).__getitem__(idx)
|
26 |
+
relpath = relpath.replace('{}/{}/'.format(self.mode, self.blur_key), '')
|
27 |
+
|
28 |
+
return blur, sharp, pad_width, idx, relpath
|
deblur/src/data/sampler.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch.utils.data import Sampler
|
4 |
+
import torch.distributed as dist
|
5 |
+
|
6 |
+
|
7 |
+
class DistributedEvalSampler(Sampler):
|
8 |
+
r"""
|
9 |
+
DistributedEvalSampler is different from DistributedSampler.
|
10 |
+
It does NOT add extra samples to make it evenly divisible.
|
11 |
+
DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever.
|
12 |
+
See this issue for details: https://github.com/pytorch/pytorch/issues/22584
|
13 |
+
shuffle is disabled by default
|
14 |
+
|
15 |
+
DistributedEvalSampler is for evaluation purpose where synchronization does not happen every epoch.
|
16 |
+
Synchronization should be done outside the dataloader loop.
|
17 |
+
|
18 |
+
Sampler that restricts data loading to a subset of the dataset.
|
19 |
+
|
20 |
+
It is especially useful in conjunction with
|
21 |
+
:class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
|
22 |
+
process can pass a :class`~torch.utils.data.DistributedSampler` instance as a
|
23 |
+
:class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
|
24 |
+
original dataset that is exclusive to it.
|
25 |
+
|
26 |
+
.. note::
|
27 |
+
Dataset is assumed to be of constant size.
|
28 |
+
|
29 |
+
Arguments:
|
30 |
+
dataset: Dataset used for sampling.
|
31 |
+
num_replicas (int, optional): Number of processes participating in
|
32 |
+
distributed training. By default, :attr:`rank` is retrieved from the
|
33 |
+
current distributed group.
|
34 |
+
rank (int, optional): Rank of the current process within :attr:`num_replicas`.
|
35 |
+
By default, :attr:`rank` is retrieved from the current distributed
|
36 |
+
group.
|
37 |
+
shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
|
38 |
+
indices.
|
39 |
+
seed (int, optional): random seed used to shuffle the sampler if
|
40 |
+
:attr:`shuffle=True`. This number should be identical across all
|
41 |
+
processes in the distributed group. Default: ``0``.
|
42 |
+
|
43 |
+
.. warning::
|
44 |
+
In distributed mode, calling the :meth`set_epoch(epoch) <set_epoch>` method at
|
45 |
+
the beginning of each epoch **before** creating the :class:`DataLoader` iterator
|
46 |
+
is necessary to make shuffling work properly across multiple epochs. Otherwise,
|
47 |
+
the same ordering will be always used.
|
48 |
+
|
49 |
+
Example::
|
50 |
+
|
51 |
+
>>> sampler = DistributedSampler(dataset) if is_distributed else None
|
52 |
+
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
|
53 |
+
... sampler=sampler)
|
54 |
+
>>> for epoch in range(start_epoch, n_epochs):
|
55 |
+
... if is_distributed:
|
56 |
+
... sampler.set_epoch(epoch)
|
57 |
+
... train(loader)
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, seed=0):
|
61 |
+
if num_replicas is None:
|
62 |
+
if not dist.is_available():
|
63 |
+
raise RuntimeError("Requires distributed package to be available")
|
64 |
+
num_replicas = dist.get_world_size()
|
65 |
+
if rank is None:
|
66 |
+
if not dist.is_available():
|
67 |
+
raise RuntimeError("Requires distributed package to be available")
|
68 |
+
rank = dist.get_rank()
|
69 |
+
self.dataset = dataset
|
70 |
+
self.num_replicas = num_replicas
|
71 |
+
self.rank = rank
|
72 |
+
self.epoch = 0
|
73 |
+
# self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
74 |
+
# self.total_size = self.num_samples * self.num_replicas
|
75 |
+
self.total_size = len(self.dataset) # true value without extra samples
|
76 |
+
indices = list(range(self.total_size))
|
77 |
+
indices = indices[self.rank:self.total_size:self.num_replicas]
|
78 |
+
self.num_samples = len(indices) # true value without extra samples
|
79 |
+
|
80 |
+
self.shuffle = shuffle
|
81 |
+
self.seed = seed
|
82 |
+
|
83 |
+
def __iter__(self):
|
84 |
+
if self.shuffle:
|
85 |
+
# deterministically shuffle based on epoch and seed
|
86 |
+
g = torch.Generator()
|
87 |
+
g.manual_seed(self.seed + self.epoch)
|
88 |
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
89 |
+
else:
|
90 |
+
indices = list(range(len(self.dataset)))
|
91 |
+
|
92 |
+
|
93 |
+
# # add extra samples to make it evenly divisible
|
94 |
+
# indices += indices[:(self.total_size - len(indices))]
|
95 |
+
# assert len(indices) == self.total_size
|
96 |
+
|
97 |
+
# subsample
|
98 |
+
indices = indices[self.rank:self.total_size:self.num_replicas]
|
99 |
+
assert len(indices) == self.num_samples
|
100 |
+
|
101 |
+
return iter(indices)
|
102 |
+
|
103 |
+
def __len__(self):
|
104 |
+
return self.num_samples
|
105 |
+
|
106 |
+
def set_epoch(self, epoch):
|
107 |
+
r"""
|
108 |
+
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
|
109 |
+
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
110 |
+
sampler will yield the same ordering.
|
111 |
+
|
112 |
+
Arguments:
|
113 |
+
epoch (int): _epoch number.
|
114 |
+
"""
|
115 |
+
self.epoch = epoch
|
deblur/src/lambda_networks/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from lambda_networks.lambda_networks import LambdaLayer
|
2 |
+
from lambda_networks.rlambda_networks import RLambdaLayer
|
3 |
+
λLayer = LambdaLayer
|
deblur/src/lambda_networks/lambda_networks.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, einsum
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
# helpers functions
|
7 |
+
|
8 |
+
def exists(val):
|
9 |
+
return val is not None
|
10 |
+
|
11 |
+
def default(val, d):
|
12 |
+
return val if exists(val) else d
|
13 |
+
|
14 |
+
# lambda layer
|
15 |
+
|
16 |
+
class LambdaLayer(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
dim,
|
20 |
+
*,
|
21 |
+
dim_k,
|
22 |
+
n = None,
|
23 |
+
r = None,
|
24 |
+
heads = 4,
|
25 |
+
dim_out = None,
|
26 |
+
dim_u = 1):
|
27 |
+
super().__init__()
|
28 |
+
dim_out = default(dim_out, dim)
|
29 |
+
self.u = dim_u # intra-depth dimension
|
30 |
+
self.heads = heads
|
31 |
+
|
32 |
+
assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
|
33 |
+
dim_v = dim_out // heads
|
34 |
+
|
35 |
+
self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
|
36 |
+
self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
|
37 |
+
self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
|
38 |
+
|
39 |
+
self.norm_q = nn.BatchNorm2d(dim_k * heads)
|
40 |
+
self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
|
41 |
+
|
42 |
+
self.local_contexts = exists(r)
|
43 |
+
if exists(r):
|
44 |
+
assert (r % 2) == 1, 'Receptive kernel size should be odd'
|
45 |
+
self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
|
46 |
+
else:
|
47 |
+
assert exists(n), 'You must specify the total sequence length (h x w)'
|
48 |
+
self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
|
49 |
+
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
|
53 |
+
|
54 |
+
q = self.to_q(x)
|
55 |
+
k = self.to_k(x)
|
56 |
+
v = self.to_v(x)
|
57 |
+
|
58 |
+
q = self.norm_q(q)
|
59 |
+
v = self.norm_v(v)
|
60 |
+
|
61 |
+
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
|
62 |
+
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
|
63 |
+
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)
|
64 |
+
|
65 |
+
k = k.softmax(dim=-1)
|
66 |
+
|
67 |
+
λc = einsum('b u k m, b u v m -> b k v', k, v)
|
68 |
+
Yc = einsum('b h k n, b k v -> b h v n', q, λc)
|
69 |
+
|
70 |
+
if self.local_contexts:
|
71 |
+
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
|
72 |
+
λp = self.pos_conv(v)
|
73 |
+
Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
|
74 |
+
else:
|
75 |
+
λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
|
76 |
+
Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
|
77 |
+
|
78 |
+
Y = Yc + Yp
|
79 |
+
out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
|
80 |
+
return out
|
deblur/src/lambda_networks/rlambda_networks.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, einsum
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
|
7 |
+
# helpers functions
|
8 |
+
|
9 |
+
def exists(val):
|
10 |
+
return val is not None
|
11 |
+
|
12 |
+
|
13 |
+
def default(val, d):
|
14 |
+
return val if exists(val) else d
|
15 |
+
|
16 |
+
|
17 |
+
# lambda layer
|
18 |
+
|
19 |
+
class RLambdaLayer(nn.Module):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
dim,
|
23 |
+
*,
|
24 |
+
dim_k,
|
25 |
+
n=None,
|
26 |
+
r=None,
|
27 |
+
heads=4,
|
28 |
+
dim_out=None,
|
29 |
+
dim_u=1,
|
30 |
+
recurrence=None
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
dim_out = default(dim_out, dim)
|
34 |
+
self.u = dim_u # intra-depth dimension
|
35 |
+
self.heads = heads
|
36 |
+
|
37 |
+
assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
|
38 |
+
dim_v = dim_out // heads
|
39 |
+
|
40 |
+
self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
|
41 |
+
self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
|
42 |
+
self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)
|
43 |
+
|
44 |
+
self.norm_q = nn.BatchNorm2d(dim_k * heads)
|
45 |
+
self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
|
46 |
+
|
47 |
+
self.local_contexts = exists(r)
|
48 |
+
self.recurrence = recurrence
|
49 |
+
if exists(r):
|
50 |
+
assert (r % 2) == 1, 'Receptive kernel size should be odd'
|
51 |
+
self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
|
52 |
+
else:
|
53 |
+
assert exists(n), 'You must specify the total sequence length (h x w)'
|
54 |
+
self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
|
55 |
+
|
56 |
+
def apply_lambda(self, lambda_c, lambda_p, x):
|
57 |
+
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
|
58 |
+
q = self.to_q(x)
|
59 |
+
q = self.norm_q(q)
|
60 |
+
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)
|
61 |
+
Yc = einsum('b h k n, b k v -> b h v n', q, lambda_c)
|
62 |
+
if self.local_contexts:
|
63 |
+
Yp = einsum('b h k n, b k v n -> b h v n', q, lambda_p.flatten(3))
|
64 |
+
else:
|
65 |
+
Yp = einsum('b h k n, b n k v -> b h v n', q, lambda_p)
|
66 |
+
Y = Yc + Yp
|
67 |
+
out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
|
68 |
+
return out
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
|
72 |
+
|
73 |
+
k = self.to_k(x)
|
74 |
+
v = self.to_v(x)
|
75 |
+
|
76 |
+
v = self.norm_v(v)
|
77 |
+
|
78 |
+
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
|
79 |
+
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)
|
80 |
+
|
81 |
+
k = k.softmax(dim=-1)
|
82 |
+
|
83 |
+
λc = einsum('b u k m, b u v m -> b k v', k, v)
|
84 |
+
|
85 |
+
if self.local_contexts:
|
86 |
+
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
|
87 |
+
λp = self.pos_conv(v)
|
88 |
+
else:
|
89 |
+
λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
|
90 |
+
out = x
|
91 |
+
for i in range(self.recurrence):
|
92 |
+
out = self.apply_lambda(λc, λp, out)
|
93 |
+
return out
|
deblur/src/launch.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" distributed launcher adopted from torch.distributed.launch
|
2 |
+
usage example: https://github.com/facebookresearch/maskrcnn-benchmark
|
3 |
+
This enables using multiprocessing for each spawned process (as they are treated as main processes)
|
4 |
+
"""
|
5 |
+
import sys
|
6 |
+
import subprocess
|
7 |
+
from argparse import ArgumentParser, REMAINDER
|
8 |
+
|
9 |
+
from utils import str2bool, int2str
|
10 |
+
|
11 |
+
def parse_args():
|
12 |
+
parser = ArgumentParser(description="PyTorch distributed training launch "
|
13 |
+
"helper utilty that will spawn up "
|
14 |
+
"multiple distributed processes")
|
15 |
+
|
16 |
+
|
17 |
+
parser.add_argument('--n_GPUs', type=int, default=1, help='the number of GPUs for training')
|
18 |
+
|
19 |
+
# positional
|
20 |
+
parser.add_argument("training_script", type=str,
|
21 |
+
help="The full path to the single GPU training "
|
22 |
+
"program/script to be launched in parallel, "
|
23 |
+
"followed by all the arguments for the "
|
24 |
+
"training script")
|
25 |
+
|
26 |
+
# rest from the training program
|
27 |
+
parser.add_argument('training_script_args', nargs=REMAINDER)
|
28 |
+
return parser.parse_args()
|
29 |
+
|
30 |
+
def main():
|
31 |
+
args = parse_args()
|
32 |
+
|
33 |
+
processes = []
|
34 |
+
for rank in range(0, args.n_GPUs):
|
35 |
+
cmd = [sys.executable]
|
36 |
+
|
37 |
+
cmd.append(args.training_script)
|
38 |
+
cmd.extend(args.training_script_args)
|
39 |
+
|
40 |
+
cmd += ['--distributed', 'True']
|
41 |
+
cmd += ['--launched', 'True']
|
42 |
+
cmd += ['--n_GPUs', str(args.n_GPUs)]
|
43 |
+
cmd += ['--rank', str(rank)]
|
44 |
+
|
45 |
+
process = subprocess.Popen(cmd)
|
46 |
+
processes.append(process)
|
47 |
+
|
48 |
+
for process in processes:
|
49 |
+
process.wait()
|
50 |
+
if process.returncode != 0:
|
51 |
+
raise subprocess.CalledProcessError(returncode=process.returncode,
|
52 |
+
cmd=cmd)
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
main()
|
deblur/src/loss/__init__.py
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from importlib import import_module
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import torch.distributed as dist
|
7 |
+
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
plt.switch_backend('agg') # https://github.com/matplotlib/matplotlib/issues/3466
|
10 |
+
|
11 |
+
from .metric import PSNR, SSIM
|
12 |
+
|
13 |
+
from utils import interact
|
14 |
+
|
15 |
+
def sequence_loss(sr, hr, loss_func, gamma=0.8, max_val=None):
|
16 |
+
""" Loss function defined over sequence of flow predictions """
|
17 |
+
|
18 |
+
n_recurrence = len(sr)
|
19 |
+
total_loss = 0.0
|
20 |
+
buffer=[0.0]*n_recurrence
|
21 |
+
# exlude invalid pixels and extremely large diplacements
|
22 |
+
for i in range(n_recurrence):
|
23 |
+
i_weight = gamma**(n_recurrence - i - 1)
|
24 |
+
i_loss = loss_func(sr[i],hr)
|
25 |
+
buffer[i]=i_loss.item()
|
26 |
+
# total_loss += i_weight * (valid[:, None] * i_loss).mean()
|
27 |
+
total_loss += i_weight * (i_loss)
|
28 |
+
return total_loss,buffer
|
29 |
+
|
30 |
+
|
31 |
+
class Loss(torch.nn.modules.loss._Loss):
|
32 |
+
def __init__(self, args, epoch=None, model=None, optimizer=None):
|
33 |
+
"""
|
34 |
+
input:
|
35 |
+
args.loss use '+' to sum over different loss functions
|
36 |
+
use '*' to specify the loss weight
|
37 |
+
|
38 |
+
example:
|
39 |
+
1*MSE+0.5*VGG54
|
40 |
+
loss = sum of MSE and VGG54(weight=0.5)
|
41 |
+
|
42 |
+
args.measure similar to args.loss, but without weight
|
43 |
+
|
44 |
+
example:
|
45 |
+
MSE+PSNR
|
46 |
+
measure MSE and PSNR, independently
|
47 |
+
"""
|
48 |
+
super(Loss, self).__init__()
|
49 |
+
|
50 |
+
self.args = args
|
51 |
+
|
52 |
+
self.rgb_range = args.rgb_range
|
53 |
+
self.gamma=args.gamma
|
54 |
+
self.device_type = args.device_type
|
55 |
+
self.synchronized = False
|
56 |
+
|
57 |
+
self.epoch = args.start_epoch if epoch is None else epoch
|
58 |
+
self.save_dir = args.save_dir
|
59 |
+
self.save_name = os.path.join(self.save_dir, 'loss.pt')
|
60 |
+
|
61 |
+
# self.training = True
|
62 |
+
self.validating = False
|
63 |
+
self.testing = False
|
64 |
+
self.mode = 'train'
|
65 |
+
self.modes = ('train', 'val', 'test')
|
66 |
+
|
67 |
+
# Loss
|
68 |
+
self.loss = nn.ModuleDict()
|
69 |
+
self.loss_types = []
|
70 |
+
self.weight = {}
|
71 |
+
self.buffer=[0.0]*args.n_scales
|
72 |
+
|
73 |
+
self.loss_stat = {mode:{} for mode in self.modes}
|
74 |
+
# loss_stat[mode][loss_type][epoch] = loss_value
|
75 |
+
# loss_stat[mode]['Total'][epoch] = loss_total
|
76 |
+
|
77 |
+
for weighted_loss in args.loss.split('+'):
|
78 |
+
w, l = weighted_loss.split('*')
|
79 |
+
l = l.upper()
|
80 |
+
if l in ('ABS', 'L1'):
|
81 |
+
loss_type = 'L1'
|
82 |
+
func = nn.L1Loss()
|
83 |
+
elif l in ('MSE', 'L2'):
|
84 |
+
loss_type = 'L2'
|
85 |
+
func = nn.MSELoss()
|
86 |
+
elif l in ('ADV', 'GAN'):
|
87 |
+
loss_type = 'ADV'
|
88 |
+
m = import_module('loss.adversarial')
|
89 |
+
func = getattr(m, 'Adversarial')(args, model, optimizer)
|
90 |
+
else:
|
91 |
+
loss_type = l
|
92 |
+
m = import_module*'loss.{}'.format(l.lower())
|
93 |
+
func = getattr(m, l)(args)
|
94 |
+
|
95 |
+
self.loss_types += [loss_type]
|
96 |
+
self.loss[loss_type] = func
|
97 |
+
self.weight[loss_type] = float(w)
|
98 |
+
|
99 |
+
print('Loss function: {}'.format(args.loss))
|
100 |
+
|
101 |
+
# Metrics
|
102 |
+
self.do_measure = args.metric.lower() != 'none'
|
103 |
+
|
104 |
+
self.metric = nn.ModuleDict()
|
105 |
+
self.metric_types = []
|
106 |
+
self.metric_stat = {mode:{} for mode in self.modes}
|
107 |
+
# metric_stat[mode][metric_type][epoch] = metric_value
|
108 |
+
|
109 |
+
if self.do_measure:
|
110 |
+
for metric_type in args.metric.split(','):
|
111 |
+
metric_type = metric_type.upper()
|
112 |
+
if metric_type == 'PSNR':
|
113 |
+
metric_func = PSNR()
|
114 |
+
elif metric_type == 'SSIM':
|
115 |
+
metric_func = SSIM(args.device_type) # single precision
|
116 |
+
else:
|
117 |
+
raise NotImplementedError
|
118 |
+
|
119 |
+
self.metric_types += [metric_type]
|
120 |
+
self.metric[metric_type] = metric_func
|
121 |
+
|
122 |
+
print('Metrics: {}'.format(args.metric))
|
123 |
+
|
124 |
+
if args.start_epoch != 1:
|
125 |
+
self.load(args.start_epoch - 1)
|
126 |
+
|
127 |
+
for mode in self.modes:
|
128 |
+
for loss_type in self.loss:
|
129 |
+
if loss_type not in self.loss_stat[mode]:
|
130 |
+
self.loss_stat[mode][loss_type] = {} # initialize loss
|
131 |
+
|
132 |
+
if 'Total' not in self.loss_stat[mode]:
|
133 |
+
self.loss_stat[mode]['Total'] = {}
|
134 |
+
|
135 |
+
if self.do_measure:
|
136 |
+
for metric_type in self.metric:
|
137 |
+
if metric_type not in self.metric_stat[mode]:
|
138 |
+
self.metric_stat[mode][metric_type] = {}
|
139 |
+
|
140 |
+
self.count = 0
|
141 |
+
self.count_m = 0
|
142 |
+
|
143 |
+
self.to(args.device, dtype=args.dtype)
|
144 |
+
|
145 |
+
def train(self, mode=True):
|
146 |
+
super(Loss, self).train(mode)
|
147 |
+
if mode:
|
148 |
+
self.validating = False
|
149 |
+
self.testing = False
|
150 |
+
self.mode = 'train'
|
151 |
+
else: # default test mode
|
152 |
+
self.validating = False
|
153 |
+
self.testing = True
|
154 |
+
self.mode = 'test'
|
155 |
+
|
156 |
+
def validate(self):
|
157 |
+
super(Loss, self).eval()
|
158 |
+
# self.training = False
|
159 |
+
self.validating = True
|
160 |
+
self.testing = False
|
161 |
+
self.mode = 'val'
|
162 |
+
|
163 |
+
def test(self):
|
164 |
+
super(Loss, self).eval()
|
165 |
+
# self.training = False
|
166 |
+
self.validating = False
|
167 |
+
self.testing = True
|
168 |
+
self.mode = 'test'
|
169 |
+
|
170 |
+
def forward(self, input, target):
|
171 |
+
self.synchronized = False
|
172 |
+
|
173 |
+
loss = 0
|
174 |
+
weights=[0.32,0.08,0.02,0.01,0.005]
|
175 |
+
if len(input)>len(weights):
|
176 |
+
for i in range(len(input)-len(weights)):
|
177 |
+
weights.append(weights[-1]*0.5)
|
178 |
+
if len(input)==1:
|
179 |
+
weights=[1.0]
|
180 |
+
weights=weights[::-1]
|
181 |
+
def _ms_forward(input, target, func):
|
182 |
+
if isinstance(input, (list, tuple)): # loss for list output
|
183 |
+
_loss,buffer_lst=sequence_loss(input,target[0],func,gamma=self.gamma)
|
184 |
+
# first=func(input[0],target[0])
|
185 |
+
# _loss = first*weights[0]
|
186 |
+
# self.buffer=[first.item()]
|
187 |
+
# for i in range(1,len(input)):
|
188 |
+
# tmp=func(input[i],target[0])
|
189 |
+
# self.buffer.append(tmp.item())
|
190 |
+
# _loss+=tmp*weights[i]
|
191 |
+
self.buffer=buffer_lst
|
192 |
+
return _loss
|
193 |
+
elif isinstance(input, dict): # loss for dict output
|
194 |
+
_loss = []
|
195 |
+
for key in input:
|
196 |
+
_loss += [func(input[key], target[key])]
|
197 |
+
return sum(_loss)
|
198 |
+
else: # loss for tensor output
|
199 |
+
return func(input, target)
|
200 |
+
|
201 |
+
# initialize
|
202 |
+
if self.count == 0:
|
203 |
+
for loss_type in self.loss_types:
|
204 |
+
self.loss_stat[self.mode][loss_type][self.epoch] = 0
|
205 |
+
self.loss_stat[self.mode]['Total'][self.epoch] = 0
|
206 |
+
|
207 |
+
if isinstance(input, list):
|
208 |
+
count = input[0].shape[0]
|
209 |
+
else: # Tensor
|
210 |
+
count = input.shape[0] # batch size
|
211 |
+
|
212 |
+
isnan = False
|
213 |
+
for loss_type in self.loss_types:
|
214 |
+
|
215 |
+
if loss_type == 'ADV':
|
216 |
+
_loss = self.loss[loss_type](input[0], target[0], self.training) * self.weight[loss_type]
|
217 |
+
else:
|
218 |
+
_loss = _ms_forward(input, target, self.loss[loss_type]) * self.weight[loss_type]
|
219 |
+
|
220 |
+
if torch.isnan(_loss):
|
221 |
+
isnan = True # skip recording (will also be skipped at backprop)
|
222 |
+
else:
|
223 |
+
self.loss_stat[self.mode][loss_type][self.epoch] += _loss.item() * count
|
224 |
+
self.loss_stat[self.mode]['Total'][self.epoch] += _loss.item() * count
|
225 |
+
|
226 |
+
loss += _loss
|
227 |
+
|
228 |
+
if not isnan:
|
229 |
+
self.count += count
|
230 |
+
|
231 |
+
if not self.training and self.do_measure:
|
232 |
+
self.measure(input, target)
|
233 |
+
|
234 |
+
return loss
|
235 |
+
|
236 |
+
def measure(self, input, target):
|
237 |
+
if isinstance(input, (list, tuple)):
|
238 |
+
self.measure(input[0], target[0])
|
239 |
+
return
|
240 |
+
elif isinstance(input, dict):
|
241 |
+
first_key = list(input.keys())[0]
|
242 |
+
self.measure(input[first_key], target[first_key])
|
243 |
+
return
|
244 |
+
else:
|
245 |
+
pass
|
246 |
+
|
247 |
+
if self.count_m == 0:
|
248 |
+
for metric_type in self.metric_stat[self.mode]:
|
249 |
+
self.metric_stat[self.mode][metric_type][self.epoch] = 0
|
250 |
+
|
251 |
+
if isinstance(input, list):
|
252 |
+
count = input[0].shape[0]
|
253 |
+
else: # Tensor
|
254 |
+
count = input.shape[0] # batch size
|
255 |
+
|
256 |
+
for metric_type in self.metric_stat[self.mode]:
|
257 |
+
|
258 |
+
input = input.clamp(0, self.rgb_range) # not in_place
|
259 |
+
if self.rgb_range==1:
|
260 |
+
input*=255
|
261 |
+
target*=255
|
262 |
+
input.round_()
|
263 |
+
target.round_()
|
264 |
+
if self.rgb_range == 255:
|
265 |
+
target*=255
|
266 |
+
input.round_()
|
267 |
+
|
268 |
+
_metric = self.metric[metric_type](input, target)
|
269 |
+
self.metric_stat[self.mode][metric_type][self.epoch] += _metric.item() * count
|
270 |
+
|
271 |
+
self.count_m += count
|
272 |
+
|
273 |
+
return
|
274 |
+
|
275 |
+
def normalize(self):
|
276 |
+
if self.args.distributed:
|
277 |
+
dist.barrier()
|
278 |
+
if not self.synchronized:
|
279 |
+
self.all_reduce()
|
280 |
+
|
281 |
+
if self.count > 0:
|
282 |
+
for loss_type in self.loss_stat[self.mode]: # including 'Total'
|
283 |
+
self.loss_stat[self.mode][loss_type][self.epoch] /= self.count
|
284 |
+
self.count = 0
|
285 |
+
|
286 |
+
if self.count_m > 0:
|
287 |
+
for metric_type in self.metric_stat[self.mode]:
|
288 |
+
self.metric_stat[self.mode][metric_type][self.epoch] /= self.count_m
|
289 |
+
self.count_m = 0
|
290 |
+
|
291 |
+
return
|
292 |
+
|
293 |
+
def all_reduce(self, epoch=None):
|
294 |
+
# synchronize loss for distributed GPU processes
|
295 |
+
|
296 |
+
if epoch is None:
|
297 |
+
epoch = self.epoch
|
298 |
+
|
299 |
+
def _reduce_value(value, ReduceOp=dist.ReduceOp.SUM):
|
300 |
+
value_tensor = torch.Tensor([value]).to(self.args.device, self.args.dtype, non_blocking=True)
|
301 |
+
dist.all_reduce(value_tensor, ReduceOp, async_op=False)
|
302 |
+
value = value_tensor.item()
|
303 |
+
del value_tensor
|
304 |
+
|
305 |
+
return value
|
306 |
+
|
307 |
+
dist.barrier()
|
308 |
+
if self.count > 0: # I assume this should be true
|
309 |
+
self.count = _reduce_value(self.count, dist.ReduceOp.SUM)
|
310 |
+
|
311 |
+
for loss_type in self.loss_stat[self.mode]:
|
312 |
+
self.loss_stat[self.mode][loss_type][epoch] = _reduce_value(
|
313 |
+
self.loss_stat[self.mode][loss_type][epoch],
|
314 |
+
dist.ReduceOp.SUM
|
315 |
+
)
|
316 |
+
|
317 |
+
if self.count_m > 0:
|
318 |
+
self.count_m = _reduce_value(self.count_m, dist.ReduceOp.SUM)
|
319 |
+
|
320 |
+
for metric_type in self.metric_stat[self.mode]:
|
321 |
+
self.metric_stat[self.mode][metric_type][epoch] = _reduce_value(
|
322 |
+
self.metric_stat[self.mode][metric_type][epoch],
|
323 |
+
dist.ReduceOp.SUM
|
324 |
+
)
|
325 |
+
|
326 |
+
self.synchronized = True
|
327 |
+
|
328 |
+
return
|
329 |
+
|
330 |
+
def print_metrics(self):
|
331 |
+
|
332 |
+
print(self.get_metric_desc())
|
333 |
+
return
|
334 |
+
|
335 |
+
def get_last_loss(self):
|
336 |
+
return self.loss_stat[self.mode]['Total'][self.epoch]
|
337 |
+
|
338 |
+
def get_loss_desc(self):
|
339 |
+
|
340 |
+
if self.mode == 'train':
|
341 |
+
desc_prefix = 'Train'
|
342 |
+
elif self.mode == 'val':
|
343 |
+
desc_prefix = 'Validation'
|
344 |
+
else:
|
345 |
+
desc_prefix = 'Test'
|
346 |
+
|
347 |
+
loss = self.loss_stat[self.mode]['Total'][self.epoch]
|
348 |
+
if self.count > 0:
|
349 |
+
loss /= self.count
|
350 |
+
desc = '{} Loss: {:.1f}'.format(desc_prefix, loss)
|
351 |
+
|
352 |
+
if self.mode in ('val', 'test'):
|
353 |
+
metric_desc = self.get_metric_desc()
|
354 |
+
desc = '{}{}'.format(desc, metric_desc)
|
355 |
+
|
356 |
+
return desc
|
357 |
+
|
358 |
+
def get_metric_desc(self):
|
359 |
+
desc = ''
|
360 |
+
for metric_type in self.metric_stat[self.mode]:
|
361 |
+
measured = self.metric_stat[self.mode][metric_type][self.epoch]
|
362 |
+
if self.count_m > 0:
|
363 |
+
measured /= self.count_m
|
364 |
+
|
365 |
+
if metric_type == 'PSNR':
|
366 |
+
desc += ' {}: {:2.2f}'.format(metric_type, measured)
|
367 |
+
elif metric_type == 'SSIM':
|
368 |
+
desc += ' {}: {:1.4f}'.format(metric_type, measured)
|
369 |
+
else:
|
370 |
+
desc += ' {}: {:2.4f}'.format(metric_type, measured)
|
371 |
+
|
372 |
+
return desc
|
373 |
+
|
374 |
+
def step(self, plot_name=None):
|
375 |
+
self.normalize()
|
376 |
+
self.plot(plot_name)
|
377 |
+
if not self.training and self.do_measure:
|
378 |
+
# self.print_metrics()
|
379 |
+
self.plot_metric()
|
380 |
+
# self.epoch += 1
|
381 |
+
|
382 |
+
return
|
383 |
+
|
384 |
+
def save(self):
|
385 |
+
|
386 |
+
state = {
|
387 |
+
'loss_stat': self.loss_stat,
|
388 |
+
'metric_stat': self.metric_stat,
|
389 |
+
}
|
390 |
+
torch.save(state, self.save_name)
|
391 |
+
|
392 |
+
return
|
393 |
+
|
394 |
+
def load(self, epoch=None):
|
395 |
+
|
396 |
+
print('Loading loss record from {}'.format(self.save_name))
|
397 |
+
if os.path.exists(self.save_name):
|
398 |
+
state = torch.load(self.save_name, map_location=self.args.device)
|
399 |
+
|
400 |
+
self.loss_stat = state['loss_stat']
|
401 |
+
if 'metric_stat' in state:
|
402 |
+
self.metric_stat = state['metric_stat']
|
403 |
+
else:
|
404 |
+
pass
|
405 |
+
else:
|
406 |
+
print('no loss record found for {}!'.format(self.save_name))
|
407 |
+
|
408 |
+
if epoch is not None:
|
409 |
+
self.epoch = epoch
|
410 |
+
|
411 |
+
return
|
412 |
+
|
413 |
+
def plot(self, plot_name=None, metric=False):
|
414 |
+
|
415 |
+
self.plot_loss(plot_name)
|
416 |
+
|
417 |
+
if metric:
|
418 |
+
self.plot_metric(plot_name)
|
419 |
+
# else:
|
420 |
+
# self.plot_loss(plot_name)
|
421 |
+
|
422 |
+
return
|
423 |
+
|
424 |
+
|
425 |
+
def plot_loss(self, plot_name=None):
|
426 |
+
if plot_name is None:
|
427 |
+
plot_name = os.path.join(self.save_dir, "{}_loss.pdf".format(self.mode))
|
428 |
+
|
429 |
+
title = "{} loss".format(self.mode)
|
430 |
+
|
431 |
+
fig = plt.figure()
|
432 |
+
plt.title(title)
|
433 |
+
plt.xlabel('epochs')
|
434 |
+
plt.ylabel('loss')
|
435 |
+
plt.grid(True, linestyle=':')
|
436 |
+
|
437 |
+
for loss_type, loss_record in self.loss_stat[self.mode].items(): # including Total
|
438 |
+
axis = sorted([epoch for epoch in loss_record.keys() if epoch <= self.epoch])
|
439 |
+
value = [self.loss_stat[self.mode][loss_type][epoch] for epoch in axis]
|
440 |
+
label = loss_type
|
441 |
+
|
442 |
+
plt.plot(axis, value, label=label)
|
443 |
+
|
444 |
+
plt.xlim(0, self.epoch)
|
445 |
+
plt.legend()
|
446 |
+
plt.savefig(plot_name)
|
447 |
+
plt.close(fig)
|
448 |
+
|
449 |
+
return
|
450 |
+
|
451 |
+
def plot_metric(self, plot_name=None):
|
452 |
+
# assume there are only max 2 metrics
|
453 |
+
if plot_name is None:
|
454 |
+
plot_name = os.path.join(self.save_dir, "{}_metric.pdf".format(self.mode))
|
455 |
+
|
456 |
+
title = "{} metrics".format(self.mode)
|
457 |
+
|
458 |
+
fig, ax1 = plt.subplots()
|
459 |
+
plt.title(title)
|
460 |
+
plt.grid(True, linestyle=':')
|
461 |
+
ax1.set_xlabel('epochs')
|
462 |
+
|
463 |
+
plots = None
|
464 |
+
for metric_type, metric_record in self.metric_stat[self.mode].items():
|
465 |
+
axis = sorted([epoch for epoch in metric_record.keys() if epoch <= self.epoch])
|
466 |
+
value = [metric_record[epoch] for epoch in axis]
|
467 |
+
label = metric_type
|
468 |
+
|
469 |
+
if metric_type == 'PSNR':
|
470 |
+
ax = ax1
|
471 |
+
color='C0'
|
472 |
+
elif metric_type == 'SSIM':
|
473 |
+
ax2 = ax1.twinx()
|
474 |
+
ax = ax2
|
475 |
+
color='C1'
|
476 |
+
|
477 |
+
ax.set_ylabel(metric_type)
|
478 |
+
if plots is None:
|
479 |
+
plots = ax.plot(axis, value, label=label, color=color)
|
480 |
+
else:
|
481 |
+
plots += ax.plot(axis, value, label=label, color=color)
|
482 |
+
|
483 |
+
labels = [plot.get_label() for plot in plots]
|
484 |
+
plt.legend(plots, labels)
|
485 |
+
plt.xlim(0, self.epoch)
|
486 |
+
plt.savefig(plot_name)
|
487 |
+
plt.close(fig)
|
488 |
+
|
489 |
+
return
|
490 |
+
|
491 |
+
def sort(self):
|
492 |
+
# sort the loss/metric record
|
493 |
+
for mode in self.modes:
|
494 |
+
for loss_type, loss_epochs in self.loss_stat[mode].items():
|
495 |
+
self.loss_stat[mode][loss_type] = {epoch: loss_epochs[epoch] for epoch in sorted(loss_epochs)}
|
496 |
+
|
497 |
+
for metric_type, metric_epochs in self.metric_stat[mode].items():
|
498 |
+
self.metric_stat[mode][metric_type] = {epoch: metric_epochs[epoch] for epoch in sorted(metric_epochs)}
|
499 |
+
|
500 |
+
return self
|
deblur/src/loss/adversarial.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from utils import interact
|
5 |
+
|
6 |
+
import torch.cuda.amp as amp
|
7 |
+
|
8 |
+
class Adversarial(nn.modules.loss._Loss):
|
9 |
+
# pure loss function without saving & loading option
|
10 |
+
# but trains deiscriminator
|
11 |
+
def __init__(self, args, model, optimizer):
|
12 |
+
super(Adversarial, self).__init__()
|
13 |
+
self.args = args
|
14 |
+
self.model = model.model
|
15 |
+
self.optimizer = optimizer
|
16 |
+
self.scaler = amp.GradScaler(
|
17 |
+
init_scale=self.args.init_scale,
|
18 |
+
enabled=self.args.amp
|
19 |
+
)
|
20 |
+
|
21 |
+
self.gan_k = 1
|
22 |
+
|
23 |
+
self.BCELoss = nn.BCEWithLogitsLoss()
|
24 |
+
|
25 |
+
def forward(self, fake, real, training=False):
|
26 |
+
if training:
|
27 |
+
# update discriminator
|
28 |
+
fake_detach = fake.detach()
|
29 |
+
for _ in range(self.gan_k):
|
30 |
+
self.optimizer.D.zero_grad()
|
31 |
+
# d: B x 1 tensor
|
32 |
+
with amp.autocast(self.args.amp):
|
33 |
+
d_fake = self.model.D(fake_detach)
|
34 |
+
d_real = self.model.D(real)
|
35 |
+
|
36 |
+
label_fake = torch.zeros_like(d_fake)
|
37 |
+
label_real = torch.ones_like(d_real)
|
38 |
+
|
39 |
+
loss_d = self.BCELoss(d_fake, label_fake) + self.BCELoss(d_real, label_real)
|
40 |
+
|
41 |
+
self.scaler.scale(loss_d).backward(retain_graph=False)
|
42 |
+
self.scaler.step(self.optimizer.D)
|
43 |
+
self.scaler.update()
|
44 |
+
else:
|
45 |
+
d_real = self.model.D(real)
|
46 |
+
label_real = torch.ones_like(d_real)
|
47 |
+
|
48 |
+
# update generator (outside here)
|
49 |
+
d_fake_bp = self.model.D(fake)
|
50 |
+
loss_g = self.BCELoss(d_fake_bp, label_real)
|
51 |
+
|
52 |
+
return loss_g
|
deblur/src/loss/metric.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from skimage.metrics import peak_signal_noise_ratio, structural_similarity
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
def _expand(img):
|
7 |
+
if img.ndim < 4:
|
8 |
+
img = img.expand([1] * (4-img.ndim) + list(img.shape))
|
9 |
+
|
10 |
+
return img
|
11 |
+
|
12 |
+
class PSNR(nn.Module):
|
13 |
+
def __init__(self):
|
14 |
+
super(PSNR, self).__init__()
|
15 |
+
|
16 |
+
def forward(self, im1, im2, data_range=None):
|
17 |
+
# tensor input, constant output
|
18 |
+
|
19 |
+
if data_range is None:
|
20 |
+
data_range = 255 if im1.max() > 1 else 1
|
21 |
+
|
22 |
+
se = (im1-im2)**2
|
23 |
+
se = _expand(se)
|
24 |
+
|
25 |
+
mse = se.mean(dim=list(range(1, se.ndim)))
|
26 |
+
psnr = 10 * (data_range**2/mse).log10().mean()
|
27 |
+
|
28 |
+
return psnr
|
29 |
+
|
30 |
+
class SSIM(nn.Module):
|
31 |
+
def __init__(self, device_type='cpu', dtype=torch.float32):
|
32 |
+
super(SSIM, self).__init__()
|
33 |
+
|
34 |
+
self.device_type = device_type
|
35 |
+
self.dtype = dtype # SSIM in half precision could be inaccurate
|
36 |
+
|
37 |
+
def _get_ssim_weight():
|
38 |
+
truncate = 3.5
|
39 |
+
sigma = 1.5
|
40 |
+
r = int(truncate * sigma + 0.5) # radius as in ndimage
|
41 |
+
win_size = 2 * r + 1
|
42 |
+
nch = 3
|
43 |
+
|
44 |
+
weight = torch.Tensor([-(x - win_size//2)**2/float(2*sigma**2) for x in range(win_size)]).exp().unsqueeze(1)
|
45 |
+
weight = weight.mm(weight.t())
|
46 |
+
weight /= weight.sum()
|
47 |
+
weight = weight.repeat(nch, 1, 1, 1)
|
48 |
+
|
49 |
+
return weight
|
50 |
+
|
51 |
+
self.weight = _get_ssim_weight().to(self.device_type, dtype=self.dtype, non_blocking=True)
|
52 |
+
|
53 |
+
def forward(self, im1, im2, data_range=None):
|
54 |
+
"""Implementation adopted from skimage.metrics.structural_similarity
|
55 |
+
Default arguments set to multichannel=True, gaussian_weight=True, use_sample_covariance=False
|
56 |
+
"""
|
57 |
+
|
58 |
+
im1 = im1.to(self.device_type, dtype=self.dtype, non_blocking=True)
|
59 |
+
im2 = im2.to(self.device_type, dtype=self.dtype, non_blocking=True)
|
60 |
+
|
61 |
+
K1 = 0.01
|
62 |
+
K2 = 0.03
|
63 |
+
sigma = 1.5
|
64 |
+
|
65 |
+
truncate = 3.5
|
66 |
+
r = int(truncate * sigma + 0.5) # radius as in ndimage
|
67 |
+
win_size = 2 * r + 1
|
68 |
+
|
69 |
+
im1 = _expand(im1)
|
70 |
+
im2 = _expand(im2)
|
71 |
+
|
72 |
+
nch = im1.shape[1]
|
73 |
+
|
74 |
+
if im1.shape[2] < win_size or im1.shape[3] < win_size:
|
75 |
+
raise ValueError(
|
76 |
+
"win_size exceeds image extent. If the input is a multichannel "
|
77 |
+
"(color) image, set multichannel=True.")
|
78 |
+
|
79 |
+
if data_range is None:
|
80 |
+
data_range = 255 if im1.max() > 1 else 1
|
81 |
+
|
82 |
+
def filter_func(img): # no padding
|
83 |
+
return nn.functional.conv2d(img, self.weight, groups=nch).to(self.dtype)
|
84 |
+
# return torch.conv2d(img, self.weight, groups=nch).to(self.dtype)
|
85 |
+
|
86 |
+
# compute (weighted) means
|
87 |
+
ux = filter_func(im1)
|
88 |
+
uy = filter_func(im2)
|
89 |
+
|
90 |
+
# compute (weighted) variances and covariances
|
91 |
+
uxx = filter_func(im1 * im1)
|
92 |
+
uyy = filter_func(im2 * im2)
|
93 |
+
uxy = filter_func(im1 * im2)
|
94 |
+
vx = (uxx - ux * ux)
|
95 |
+
vy = (uyy - uy * uy)
|
96 |
+
vxy = (uxy - ux * uy)
|
97 |
+
|
98 |
+
R = data_range
|
99 |
+
C1 = (K1 * R) ** 2
|
100 |
+
C2 = (K2 * R) ** 2
|
101 |
+
|
102 |
+
A1, A2, B1, B2 = ((2 * ux * uy + C1,
|
103 |
+
2 * vxy + C2,
|
104 |
+
ux ** 2 + uy ** 2 + C1,
|
105 |
+
vx + vy + C2))
|
106 |
+
D = B1 * B2
|
107 |
+
S = (A1 * A2) / D
|
108 |
+
|
109 |
+
# compute (weighted) mean of ssim
|
110 |
+
mssim = S.mean()
|
111 |
+
|
112 |
+
return mssim
|
deblur/src/main.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""main file that does everything"""
|
2 |
+
from utils import interact
|
3 |
+
|
4 |
+
from option import args, setup, cleanup
|
5 |
+
from data import Data
|
6 |
+
from model import Model
|
7 |
+
from loss import Loss
|
8 |
+
from optim import Optimizer
|
9 |
+
from train import Trainer
|
10 |
+
|
11 |
+
def main_worker(rank, args):
|
12 |
+
args.rank = rank
|
13 |
+
args = setup(args)
|
14 |
+
|
15 |
+
loaders = Data(args).get_loader()
|
16 |
+
model = Model(args)
|
17 |
+
model.parallelize()
|
18 |
+
optimizer = Optimizer(args, model)
|
19 |
+
|
20 |
+
criterion = Loss(args, model=model, optimizer=optimizer)
|
21 |
+
|
22 |
+
trainer = Trainer(args, model, criterion, optimizer, loaders)
|
23 |
+
|
24 |
+
if args.stay:
|
25 |
+
interact(local=locals())
|
26 |
+
exit()
|
27 |
+
|
28 |
+
if args.demo:
|
29 |
+
trainer.evaluate(epoch=args.start_epoch, mode='demo')
|
30 |
+
exit()
|
31 |
+
|
32 |
+
for epoch in range(1, args.start_epoch):
|
33 |
+
if args.do_validate:
|
34 |
+
if epoch % args.validate_every == 0:
|
35 |
+
trainer.fill_evaluation(epoch, 'val')
|
36 |
+
if args.do_test:
|
37 |
+
if epoch % args.test_every == 0:
|
38 |
+
trainer.fill_evaluation(epoch, 'test')
|
39 |
+
|
40 |
+
for epoch in range(args.start_epoch, args.end_epoch+1):
|
41 |
+
if args.do_train:
|
42 |
+
trainer.train(epoch)
|
43 |
+
|
44 |
+
if args.do_validate:
|
45 |
+
if epoch % args.validate_every == 0:
|
46 |
+
if trainer.epoch != epoch:
|
47 |
+
trainer.load(epoch)
|
48 |
+
trainer.validate(epoch)
|
49 |
+
|
50 |
+
if args.do_test:
|
51 |
+
if epoch % args.test_every == 0:
|
52 |
+
if trainer.epoch != epoch:
|
53 |
+
trainer.load(epoch)
|
54 |
+
trainer.test(epoch)
|
55 |
+
|
56 |
+
if args.rank == 0 or not args.launched:
|
57 |
+
print('')
|
58 |
+
|
59 |
+
trainer.imsaver.join_background()
|
60 |
+
|
61 |
+
cleanup(args)
|
62 |
+
|
63 |
+
def main():
|
64 |
+
main_worker(args.rank, args)
|
65 |
+
|
66 |
+
if __name__ == "__main__":
|
67 |
+
main()
|
deblur/src/model/LamResNet.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
from . import common
|
4 |
+
|
5 |
+
from lambda_networks import LambdaLayer
|
6 |
+
|
7 |
+
|
8 |
+
def build_model(args):
|
9 |
+
return ResNet(args)
|
10 |
+
|
11 |
+
|
12 |
+
class ResNet(nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
args,
|
16 |
+
in_channels=3,
|
17 |
+
out_channels=3,
|
18 |
+
n_feats=None,
|
19 |
+
kernel_size=None,
|
20 |
+
n_resblocks=None,
|
21 |
+
mean_shift=True,
|
22 |
+
):
|
23 |
+
super(ResNet, self).__init__()
|
24 |
+
|
25 |
+
self.in_channels = in_channels
|
26 |
+
self.out_channels = out_channels
|
27 |
+
|
28 |
+
self.n_feats = args.n_feats if n_feats is None else n_feats
|
29 |
+
self.kernel_size = args.kernel_size if kernel_size is None else kernel_size
|
30 |
+
self.n_resblocks = args.n_resblocks if n_resblocks is None else n_resblocks
|
31 |
+
|
32 |
+
self.mean_shift = mean_shift
|
33 |
+
self.rgb_range = args.rgb_range
|
34 |
+
self.mean = self.rgb_range / 2
|
35 |
+
|
36 |
+
modules = []
|
37 |
+
modules.append(
|
38 |
+
common.default_conv(self.in_channels, self.n_feats, self.kernel_size)
|
39 |
+
)
|
40 |
+
for _ in range(self.n_resblocks // 3):
|
41 |
+
modules.append(common.ResBlock(self.n_feats, self.kernel_size))
|
42 |
+
modules.append(
|
43 |
+
LambdaLayer(
|
44 |
+
dim=self.n_feats, dim_out=self.n_feats, r=23, dim_k=16, heads=4, dim_u=1
|
45 |
+
)
|
46 |
+
)
|
47 |
+
for _ in range(self.n_resblocks // 3):
|
48 |
+
modules.append(common.ResBlock(self.n_feats, self.kernel_size))
|
49 |
+
modules.append(
|
50 |
+
LambdaLayer(
|
51 |
+
dim=self.n_feats, dim_out=self.n_feats, r=7, dim_k=16, heads=4, dim_u=4
|
52 |
+
)
|
53 |
+
)
|
54 |
+
for _ in range(self.n_resblocks // 3):
|
55 |
+
modules.append(common.ResBlock(self.n_feats, self.kernel_size))
|
56 |
+
modules.append(
|
57 |
+
common.default_conv(self.n_feats, self.n_feats, self.kernel_size)
|
58 |
+
)
|
59 |
+
modules.append(common.default_conv(self.n_feats, self.out_channels, 1))
|
60 |
+
|
61 |
+
self.body = nn.Sequential(*modules)
|
62 |
+
|
63 |
+
def forward(self, input):
|
64 |
+
if self.mean_shift:
|
65 |
+
input = input - self.mean
|
66 |
+
|
67 |
+
output = self.body(input)
|
68 |
+
|
69 |
+
if self.mean_shift:
|
70 |
+
output = output + self.mean
|
71 |
+
|
72 |
+
return output
|
deblur/src/model/MSResNet.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from . import common
|
5 |
+
from .ResNet import ResNet
|
6 |
+
|
7 |
+
|
8 |
+
def build_model(args):
|
9 |
+
return MSResNet(args)
|
10 |
+
|
11 |
+
class conv_end(nn.Module):
|
12 |
+
def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
|
13 |
+
super(conv_end, self).__init__()
|
14 |
+
|
15 |
+
modules = [
|
16 |
+
common.default_conv(in_channels, out_channels, kernel_size),
|
17 |
+
nn.PixelShuffle(ratio)
|
18 |
+
]
|
19 |
+
|
20 |
+
self.uppath = nn.Sequential(*modules)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
return self.uppath(x)
|
24 |
+
|
25 |
+
class MSResNet(nn.Module):
|
26 |
+
def __init__(self, args):
|
27 |
+
super(MSResNet, self).__init__()
|
28 |
+
|
29 |
+
self.rgb_range = args.rgb_range
|
30 |
+
self.mean = self.rgb_range / 2
|
31 |
+
|
32 |
+
self.n_resblocks = args.n_resblocks
|
33 |
+
self.n_feats = args.n_feats
|
34 |
+
self.kernel_size = args.kernel_size
|
35 |
+
|
36 |
+
self.n_scales = args.n_scales
|
37 |
+
|
38 |
+
self.body_models = nn.ModuleList([
|
39 |
+
ResNet(args, 3, 3, mean_shift=False),
|
40 |
+
])
|
41 |
+
for _ in range(1, self.n_scales):
|
42 |
+
self.body_models.insert(0, ResNet(args, 6, 3, mean_shift=False))
|
43 |
+
|
44 |
+
self.conv_end_models = nn.ModuleList([None])
|
45 |
+
for _ in range(1, self.n_scales):
|
46 |
+
self.conv_end_models += [conv_end(3, 12)]
|
47 |
+
|
48 |
+
def forward(self, input_pyramid):
|
49 |
+
|
50 |
+
scales = range(self.n_scales-1, -1, -1) # 0: fine, 2: coarse
|
51 |
+
|
52 |
+
for s in scales:
|
53 |
+
input_pyramid[s] = input_pyramid[s] - self.mean
|
54 |
+
|
55 |
+
output_pyramid = [None] * self.n_scales
|
56 |
+
|
57 |
+
input_s = input_pyramid[-1]
|
58 |
+
for s in scales: # [2, 1, 0]
|
59 |
+
output_pyramid[s] = self.body_models[s](input_s)
|
60 |
+
if s > 0:
|
61 |
+
up_feat = self.conv_end_models[s](output_pyramid[s])
|
62 |
+
input_s = torch.cat((input_pyramid[s-1], up_feat), 1)
|
63 |
+
|
64 |
+
for s in scales:
|
65 |
+
output_pyramid[s] = output_pyramid[s] + self.mean
|
66 |
+
|
67 |
+
return output_pyramid
|
deblur/src/model/MSResNetLambda.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from . import common
|
5 |
+
from .ResNet import ResNet
|
6 |
+
from lambda_network import LambdaLayer
|
7 |
+
|
8 |
+
def build_model(args):
|
9 |
+
return MSResNet(args)
|
10 |
+
|
11 |
+
class conv_end(nn.Module):
|
12 |
+
def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
|
13 |
+
super(conv_end, self).__init__()
|
14 |
+
|
15 |
+
modules = [
|
16 |
+
common.default_conv(in_channels, out_channels, kernel_size),
|
17 |
+
nn.PixelShuffle(ratio)
|
18 |
+
]
|
19 |
+
|
20 |
+
self.uppath = nn.Sequential(*modules)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
return self.uppath(x)
|
24 |
+
|
25 |
+
class MSResNet(nn.Module):
|
26 |
+
def __init__(self, args):
|
27 |
+
super(MSResNet, self).__init__()
|
28 |
+
|
29 |
+
self.rgb_range = args.rgb_range
|
30 |
+
self.mean = self.rgb_range / 2
|
31 |
+
|
32 |
+
self.n_resblocks = args.n_resblocks
|
33 |
+
self.n_feats = args.n_feats
|
34 |
+
self.kernel_size = args.kernel_size
|
35 |
+
|
36 |
+
self.n_scales = args.n_scales
|
37 |
+
|
38 |
+
self.body_models = nn.ModuleList([
|
39 |
+
ResNet(args, 3, 3, mean_shift=False),
|
40 |
+
])
|
41 |
+
self.lambda_models = nn.ModuleList([
|
42 |
+
LambdaLayer(
|
43 |
+
dim = 32, # channels going in
|
44 |
+
dim_out = 32, # channels out
|
45 |
+
n = 64 * 64, # number of input pixels (64 x 64 image)
|
46 |
+
dim_k = 16, # key dimension
|
47 |
+
heads = 4, # number of heads, for multi-query
|
48 |
+
dim_u = 1 # 'intra-depth' dimension
|
49 |
+
)
|
50 |
+
])
|
51 |
+
for _ in range(1, self.n_scales):
|
52 |
+
self.body_models.insert(0, ResNet(args, 6, 3, mean_shift=False))
|
53 |
+
|
54 |
+
self.conv_end_models = nn.ModuleList([None])
|
55 |
+
for _ in range(1, self.n_scales):
|
56 |
+
self.conv_end_models += [conv_end(3, 12)]
|
57 |
+
|
58 |
+
def forward(self, input_pyramid):
|
59 |
+
|
60 |
+
scales = range(self.n_scales-1, -1, -1) # 0: fine, 2: coarse
|
61 |
+
|
62 |
+
for s in scales:
|
63 |
+
input_pyramid[s] = input_pyramid[s] - self.mean
|
64 |
+
|
65 |
+
output_pyramid = [None] * self.n_scales
|
66 |
+
|
67 |
+
input_s = input_pyramid[-1]
|
68 |
+
for s in scales: # [2, 1, 0]
|
69 |
+
output_pyramid[s] = self.body_models[s](input_s)
|
70 |
+
if s > 0:
|
71 |
+
up_feat = self.conv_end_models[s](output_pyramid[s])
|
72 |
+
input_s = torch.cat((input_pyramid[s-1], up_feat), 1)
|
73 |
+
|
74 |
+
for s in scales:
|
75 |
+
output_pyramid[s] = output_pyramid[s] + self.mean
|
76 |
+
|
77 |
+
return output_pyramid
|
deblur/src/model/RaftNet.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from . import common
|
5 |
+
|
6 |
+
from lambda_networks import LambdaLayer
|
7 |
+
|
8 |
+
|
9 |
+
def build_model(args):
|
10 |
+
return ResNet(args)
|
11 |
+
|
12 |
+
|
13 |
+
class ConvGRU(nn.Module):
|
14 |
+
def __init__(self, hidden_dim=128, input_dim=192+128):
|
15 |
+
super(ConvGRU, self).__init__()
|
16 |
+
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
17 |
+
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
18 |
+
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
19 |
+
|
20 |
+
def forward(self, h, x):
|
21 |
+
hx = torch.cat([h, x], dim=1)
|
22 |
+
|
23 |
+
z = torch.sigmoid(self.convz(hx))
|
24 |
+
r = torch.sigmoid(self.convr(hx))
|
25 |
+
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
|
26 |
+
|
27 |
+
# h = (1-z) * h + z * q
|
28 |
+
# return h
|
29 |
+
return (1-z) * h + z * q
|
30 |
+
|
31 |
+
class SepConvGRU(nn.Module):
|
32 |
+
def __init__(self, hidden_dim=128, input_dim=192+128):
|
33 |
+
super(SepConvGRU, self).__init__()
|
34 |
+
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
35 |
+
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
36 |
+
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
37 |
+
|
38 |
+
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
39 |
+
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
40 |
+
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
41 |
+
|
42 |
+
|
43 |
+
def forward(self, h, x):
|
44 |
+
# horizontal
|
45 |
+
hx = torch.cat([h, x], dim=1)
|
46 |
+
z = torch.sigmoid(self.convz1(hx))
|
47 |
+
r = torch.sigmoid(self.convr1(hx))
|
48 |
+
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
|
49 |
+
h = (1-z) * h + z * q
|
50 |
+
|
51 |
+
# vertical
|
52 |
+
hx = torch.cat([h, x], dim=1)
|
53 |
+
z = torch.sigmoid(self.convz2(hx))
|
54 |
+
r = torch.sigmoid(self.convr2(hx))
|
55 |
+
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
|
56 |
+
h = (1-z) * h + z * q
|
57 |
+
|
58 |
+
return h
|
59 |
+
|
60 |
+
|
61 |
+
class ResNet(nn.Module):
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
args,
|
65 |
+
in_channels=3,
|
66 |
+
out_channels=3,
|
67 |
+
n_feats=None,
|
68 |
+
kernel_size=None,
|
69 |
+
n_resblocks=None,
|
70 |
+
mean_shift=True,
|
71 |
+
):
|
72 |
+
super(ResNet, self).__init__()
|
73 |
+
|
74 |
+
self.in_channels = in_channels
|
75 |
+
self.out_channels = out_channels
|
76 |
+
|
77 |
+
self.n_feats = args.n_feats if n_feats is None else n_feats
|
78 |
+
self.kernel_size = args.kernel_size if kernel_size is None else kernel_size
|
79 |
+
self.n_resblocks = args.n_resblocks if n_resblocks is None else n_resblocks
|
80 |
+
|
81 |
+
self.mean_shift = mean_shift
|
82 |
+
self.rgb_range = args.rgb_range
|
83 |
+
self.mean = self.rgb_range / 2
|
84 |
+
|
85 |
+
modules = []
|
86 |
+
m_head=[common.default_conv(self.in_channels, self.n_feats, self.kernel_size)]
|
87 |
+
for i in range(3):
|
88 |
+
m_head.append(common.ResBlock(self.n_feats, self.kernel_size))
|
89 |
+
for _ in range(self.n_resblocks // 2):
|
90 |
+
modules.append(common.ResBlock(self.n_feats, self.kernel_size))
|
91 |
+
modules.append(
|
92 |
+
LambdaLayer(
|
93 |
+
dim=self.n_feats, dim_out=self.n_feats, r=23, dim_k=16, heads=4, dim_u=4
|
94 |
+
)
|
95 |
+
)
|
96 |
+
for _ in range(self.n_resblocks // 2):
|
97 |
+
modules.append(common.ResBlock(self.n_feats, self.kernel_size))
|
98 |
+
m_tail=[]
|
99 |
+
|
100 |
+
for i in range(3):
|
101 |
+
m_tail.append(common.ResBlock(self.n_feats, self.kernel_size))
|
102 |
+
|
103 |
+
m_tail.append(
|
104 |
+
common.default_conv(self.n_feats, self.out_channels, self.kernel_size)
|
105 |
+
)
|
106 |
+
self.head=nn.Sequential(*m_head)
|
107 |
+
self.body = nn.Sequential(*modules)
|
108 |
+
self.tail=nn.Sequential(*m_tail)
|
109 |
+
self.gru=SepConvGRU(hidden_dim=self.n_feats,input_dim=self.n_feats)
|
110 |
+
|
111 |
+
def forward(self, input):
|
112 |
+
if self.mean_shift:
|
113 |
+
input = input - self.mean
|
114 |
+
|
115 |
+
output = self.body(input)
|
116 |
+
|
117 |
+
if self.mean_shift:
|
118 |
+
output = output + self.mean
|
119 |
+
|
120 |
+
return output
|
deblur/src/model/RecLamResNet.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from . import common
|
5 |
+
from .LamResNet import ResNet
|
6 |
+
|
7 |
+
|
8 |
+
def build_model(args):
|
9 |
+
return RecLamResNet(args)
|
10 |
+
|
11 |
+
|
12 |
+
class conv_end(nn.Module):
|
13 |
+
def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
|
14 |
+
super(conv_end, self).__init__()
|
15 |
+
|
16 |
+
modules = [
|
17 |
+
common.default_conv(in_channels, out_channels, kernel_size),
|
18 |
+
nn.PixelShuffle(ratio),
|
19 |
+
]
|
20 |
+
|
21 |
+
self.uppath = nn.Sequential(*modules)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
return self.uppath(x)
|
25 |
+
|
26 |
+
|
27 |
+
class RecLamResNet(nn.Module):
|
28 |
+
def __init__(self, args):
|
29 |
+
super(RecLamResNet, self).__init__()
|
30 |
+
|
31 |
+
self.rgb_range = args.rgb_range
|
32 |
+
self.mean = self.rgb_range / 2
|
33 |
+
self.is_detach=args.detach
|
34 |
+
|
35 |
+
self.n_resblocks = args.n_resblocks
|
36 |
+
self.n_feats = args.n_feats
|
37 |
+
self.kernel_size = args.kernel_size
|
38 |
+
|
39 |
+
self.n_scales = args.n_scales
|
40 |
+
|
41 |
+
self.body_model = ResNet(args, 3, 3, mean_shift=False)
|
42 |
+
|
43 |
+
def forward(self, input_lst):
|
44 |
+
# we use a reversed list for better compact
|
45 |
+
input_lst[0] = input_lst[0] - self.mean
|
46 |
+
output_lst = [None] * self.n_scales
|
47 |
+
last_output = input_lst[0]
|
48 |
+
for i in range(self.n_scales):
|
49 |
+
if self.is_detach:
|
50 |
+
last_output=last_output.detach()
|
51 |
+
output = self.body_model(last_output) + last_output
|
52 |
+
output_lst[self.n_scales-i-1] = output + self.mean
|
53 |
+
last_output = output
|
54 |
+
return output_lst
|
deblur/src/model/ResNet.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
from . import common
|
4 |
+
|
5 |
+
def build_model(args):
|
6 |
+
return ResNet(args)
|
7 |
+
|
8 |
+
class ResNet(nn.Module):
|
9 |
+
def __init__(self, args, in_channels=3, out_channels=3, n_feats=None, kernel_size=None, n_resblocks=None, mean_shift=True):
|
10 |
+
super(ResNet, self).__init__()
|
11 |
+
|
12 |
+
self.in_channels = in_channels
|
13 |
+
self.out_channels = out_channels
|
14 |
+
|
15 |
+
self.n_feats = args.n_feats if n_feats is None else n_feats
|
16 |
+
self.kernel_size = args.kernel_size if kernel_size is None else kernel_size
|
17 |
+
self.n_resblocks = args.n_resblocks if n_resblocks is None else n_resblocks
|
18 |
+
|
19 |
+
self.mean_shift = mean_shift
|
20 |
+
self.rgb_range = args.rgb_range
|
21 |
+
self.mean = self.rgb_range / 2
|
22 |
+
|
23 |
+
modules = []
|
24 |
+
modules.append(common.default_conv(self.in_channels, self.n_feats, self.kernel_size))
|
25 |
+
for _ in range(self.n_resblocks):
|
26 |
+
modules.append(common.ResBlock(self.n_feats, self.kernel_size))
|
27 |
+
modules.append(common.default_conv(self.n_feats, self.out_channels, self.kernel_size))
|
28 |
+
|
29 |
+
self.body = nn.Sequential(*modules)
|
30 |
+
|
31 |
+
def forward(self, input):
|
32 |
+
if self.mean_shift:
|
33 |
+
input = input - self.mean
|
34 |
+
|
35 |
+
output = self.body(input)
|
36 |
+
|
37 |
+
if self.mean_shift:
|
38 |
+
output = output + self.mean
|
39 |
+
|
40 |
+
return output
|
41 |
+
|
deblur/src/model/__init__.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from importlib import import_module
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
8 |
+
|
9 |
+
import torch.distributed as dist
|
10 |
+
from torch.nn.utils import parameters_to_vector, vector_to_parameters
|
11 |
+
|
12 |
+
from .discriminator import Discriminator
|
13 |
+
|
14 |
+
from utils import interact
|
15 |
+
|
16 |
+
class Model(nn.Module):
|
17 |
+
def __init__(self, args):
|
18 |
+
super(Model, self).__init__()
|
19 |
+
|
20 |
+
self.args = args
|
21 |
+
self.device = args.device
|
22 |
+
self.n_GPUs = args.n_GPUs
|
23 |
+
self.save_dir = os.path.join(args.save_dir, 'models')
|
24 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
25 |
+
|
26 |
+
module = import_module('model.' + args.model)
|
27 |
+
|
28 |
+
self.model = nn.ModuleDict()
|
29 |
+
self.model.G = module.build_model(args)
|
30 |
+
if self.args.loss.lower().find('adv') >= 0:
|
31 |
+
self.model.D = Discriminator(self.args)
|
32 |
+
else:
|
33 |
+
self.model.D = None
|
34 |
+
|
35 |
+
self.to(args.device, dtype=args.dtype, non_blocking=True)
|
36 |
+
self.load(args.load_epoch, path=args.pretrained)
|
37 |
+
|
38 |
+
def parallelize(self):
|
39 |
+
if self.args.device_type == 'cuda':
|
40 |
+
if self.args.distributed:
|
41 |
+
Parallel = DistributedDataParallel
|
42 |
+
parallel_args = {
|
43 |
+
"device_ids": [self.args.rank],
|
44 |
+
"output_device": self.args.rank,
|
45 |
+
}
|
46 |
+
else:
|
47 |
+
Parallel = DataParallel
|
48 |
+
parallel_args = {
|
49 |
+
'device_ids': list(range(self.n_GPUs)),
|
50 |
+
'output_device': self.args.rank # always 0
|
51 |
+
}
|
52 |
+
|
53 |
+
for model_key in self.model:
|
54 |
+
if self.model[model_key] is not None:
|
55 |
+
self.model[model_key] = Parallel(self.model[model_key], **parallel_args)
|
56 |
+
|
57 |
+
def forward(self, input):
|
58 |
+
return self.model.G(input)
|
59 |
+
|
60 |
+
def _save_path(self, epoch):
|
61 |
+
model_path = os.path.join(self.save_dir, 'model-{:d}.pt'.format(epoch))
|
62 |
+
return model_path
|
63 |
+
|
64 |
+
def state_dict(self):
|
65 |
+
state_dict = {}
|
66 |
+
for model_key in self.model:
|
67 |
+
if self.model[model_key] is not None:
|
68 |
+
parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel))
|
69 |
+
if parallelized:
|
70 |
+
state_dict[model_key] = self.model[model_key].module.state_dict()
|
71 |
+
else:
|
72 |
+
state_dict[model_key] = self.model[model_key].state_dict()
|
73 |
+
|
74 |
+
return state_dict
|
75 |
+
|
76 |
+
def load_state_dict(self, state_dict, strict=True):
|
77 |
+
for model_key in self.model:
|
78 |
+
parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel))
|
79 |
+
if model_key in state_dict:
|
80 |
+
if parallelized:
|
81 |
+
self.model[model_key].module.load_state_dict(state_dict[model_key], strict)
|
82 |
+
else:
|
83 |
+
self.model[model_key].load_state_dict(state_dict[model_key], strict)
|
84 |
+
|
85 |
+
def save(self, epoch):
|
86 |
+
torch.save(self.state_dict(), self._save_path(epoch))
|
87 |
+
|
88 |
+
def load(self, epoch=None, path=None):
|
89 |
+
if path:
|
90 |
+
model_name = path
|
91 |
+
elif isinstance(epoch, int):
|
92 |
+
if epoch < 0:
|
93 |
+
epoch = self.get_last_epoch()
|
94 |
+
if epoch == 0: # epoch 0
|
95 |
+
# make sure model parameters are synchronized at initial
|
96 |
+
# for multi-node training (not in current implementation)
|
97 |
+
# self.synchronize()
|
98 |
+
|
99 |
+
return # leave model as initialized
|
100 |
+
|
101 |
+
model_name = self._save_path(epoch)
|
102 |
+
else:
|
103 |
+
raise Exception('no epoch number or model path specified!')
|
104 |
+
|
105 |
+
print('Loading model from {}'.format(model_name))
|
106 |
+
state_dict = torch.load(model_name, map_location=self.args.device)
|
107 |
+
self.load_state_dict(state_dict)
|
108 |
+
|
109 |
+
return
|
110 |
+
|
111 |
+
def synchronize(self):
|
112 |
+
if self.args.distributed:
|
113 |
+
# synchronize model parameters across nodes
|
114 |
+
vector = parameters_to_vector(self.parameters())
|
115 |
+
|
116 |
+
dist.broadcast(vector, 0) # broadcast parameters to other processes
|
117 |
+
if self.args.rank != 0:
|
118 |
+
vector_to_parameters(vector, self.parameters())
|
119 |
+
|
120 |
+
del vector
|
121 |
+
|
122 |
+
return
|
123 |
+
|
124 |
+
def get_last_epoch(self):
|
125 |
+
model_list = sorted(os.listdir(self.save_dir))
|
126 |
+
if len(model_list) == 0:
|
127 |
+
epoch = 0
|
128 |
+
else:
|
129 |
+
epoch = int(re.findall('\\d+', model_list[-1])[0]) # model example name model-100.pt
|
130 |
+
|
131 |
+
return epoch
|
132 |
+
|
133 |
+
def print(self):
|
134 |
+
print(self.model)
|
135 |
+
|
136 |
+
return
|
deblur/src/model/common.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
def default_conv(in_channels, out_channels, kernel_size, bias=True, groups=1):
|
7 |
+
return nn.Conv2d(
|
8 |
+
in_channels, out_channels, kernel_size,
|
9 |
+
padding=(kernel_size // 2), bias=bias, groups=groups)
|
10 |
+
|
11 |
+
def default_norm(n_feats):
|
12 |
+
return nn.BatchNorm2d(n_feats)
|
13 |
+
|
14 |
+
def default_act():
|
15 |
+
return nn.ReLU(True)
|
16 |
+
|
17 |
+
def empty_h(x, n_feats):
|
18 |
+
'''
|
19 |
+
create an empty hidden state
|
20 |
+
|
21 |
+
input
|
22 |
+
x: B x T x 3 x H x W
|
23 |
+
|
24 |
+
output
|
25 |
+
h: B x C x H/4 x W/4
|
26 |
+
'''
|
27 |
+
b = x.size(0)
|
28 |
+
h, w = x.size()[-2:]
|
29 |
+
return x.new_zeros((b, n_feats, h//4, w//4))
|
30 |
+
|
31 |
+
class Normalization(nn.Conv2d):
|
32 |
+
"""Normalize input tensor value with convolutional layer"""
|
33 |
+
def __init__(self, mean=(0, 0, 0), std=(1, 1, 1)):
|
34 |
+
super(Normalization, self).__init__(3, 3, kernel_size=1)
|
35 |
+
tensor_mean = torch.Tensor(mean)
|
36 |
+
tensor_inv_std = torch.Tensor(std).reciprocal()
|
37 |
+
|
38 |
+
self.weight.data = torch.eye(3).mul(tensor_inv_std).view(3, 3, 1, 1)
|
39 |
+
self.bias.data = torch.Tensor(-tensor_mean.mul(tensor_inv_std))
|
40 |
+
|
41 |
+
for params in self.parameters():
|
42 |
+
params.requires_grad = False
|
43 |
+
|
44 |
+
class BasicBlock(nn.Sequential):
|
45 |
+
"""Convolution layer + Activation layer"""
|
46 |
+
def __init__(
|
47 |
+
self, in_channels, out_channels, kernel_size, bias=True,
|
48 |
+
conv=default_conv, norm=False, act=default_act):
|
49 |
+
|
50 |
+
modules = []
|
51 |
+
modules.append(
|
52 |
+
conv(in_channels, out_channels, kernel_size, bias=bias))
|
53 |
+
if norm: modules.append(norm(out_channels))
|
54 |
+
if act: modules.append(act())
|
55 |
+
|
56 |
+
super(BasicBlock, self).__init__(*modules)
|
57 |
+
|
58 |
+
class ResBlock(nn.Module):
|
59 |
+
def __init__(
|
60 |
+
self, n_feats, kernel_size, bias=True,
|
61 |
+
conv=default_conv, norm=False, act=default_act):
|
62 |
+
|
63 |
+
super(ResBlock, self).__init__()
|
64 |
+
|
65 |
+
modules = []
|
66 |
+
for i in range(2):
|
67 |
+
modules.append(conv(n_feats, n_feats, kernel_size, bias=bias))
|
68 |
+
if norm: modules.append(norm(n_feats))
|
69 |
+
if act and i == 0: modules.append(act())
|
70 |
+
|
71 |
+
self.body = nn.Sequential(*modules)
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
res = self.body(x)
|
75 |
+
res += x
|
76 |
+
|
77 |
+
return res
|
78 |
+
|
79 |
+
class ResBlock_mobile(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self, n_feats, kernel_size, bias=True,
|
82 |
+
conv=default_conv, norm=False, act=default_act, dropout=False):
|
83 |
+
|
84 |
+
super(ResBlock_mobile, self).__init__()
|
85 |
+
|
86 |
+
modules = []
|
87 |
+
for i in range(2):
|
88 |
+
modules.append(conv(n_feats, n_feats, kernel_size, bias=False, groups=n_feats))
|
89 |
+
modules.append(conv(n_feats, n_feats, 1, bias=False))
|
90 |
+
if dropout and i == 0: modules.append(nn.Dropout2d(dropout))
|
91 |
+
if norm: modules.append(norm(n_feats))
|
92 |
+
if act and i == 0: modules.append(act())
|
93 |
+
|
94 |
+
self.body = nn.Sequential(*modules)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
res = self.body(x)
|
98 |
+
res += x
|
99 |
+
|
100 |
+
return res
|
101 |
+
|
102 |
+
class Upsampler(nn.Sequential):
|
103 |
+
def __init__(
|
104 |
+
self, scale, n_feats, bias=True,
|
105 |
+
conv=default_conv, norm=False, act=False):
|
106 |
+
|
107 |
+
modules = []
|
108 |
+
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
|
109 |
+
for _ in range(int(math.log(scale, 2))):
|
110 |
+
modules.append(conv(n_feats, 4 * n_feats, 3, bias))
|
111 |
+
modules.append(nn.PixelShuffle(2))
|
112 |
+
if norm: modules.append(norm(n_feats))
|
113 |
+
if act: modules.append(act())
|
114 |
+
elif scale == 3:
|
115 |
+
modules.append(conv(n_feats, 9 * n_feats, 3, bias))
|
116 |
+
modules.append(nn.PixelShuffle(3))
|
117 |
+
if norm: modules.append(norm(n_feats))
|
118 |
+
if act: modules.append(act())
|
119 |
+
else:
|
120 |
+
raise NotImplementedError
|
121 |
+
|
122 |
+
super(Upsampler, self).__init__(*modules)
|
123 |
+
|
124 |
+
# Only support 1 / 2
|
125 |
+
class PixelSort(nn.Module):
|
126 |
+
"""The inverse operation of PixelShuffle
|
127 |
+
Reduces the spatial resolution, increasing the number of channels.
|
128 |
+
Currently, scale 0.5 is supported only.
|
129 |
+
Later, torch.nn.functional.pixel_sort may be implemented.
|
130 |
+
Reference:
|
131 |
+
http://pytorch.org/docs/0.3.0/_modules/torch/nn/modules/pixelshuffle.html#PixelShuffle
|
132 |
+
http://pytorch.org/docs/0.3.0/_modules/torch/nn/functional.html#pixel_shuffle
|
133 |
+
"""
|
134 |
+
def __init__(self, upscale_factor=0.5):
|
135 |
+
super(PixelSort, self).__init__()
|
136 |
+
self.upscale_factor = upscale_factor
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
b, c, h, w = x.size()
|
140 |
+
x = x.view(b, c, 2, 2, h // 2, w // 2)
|
141 |
+
x = x.permute(0, 1, 5, 3, 2, 4).contiguous()
|
142 |
+
x = x.view(b, 4 * c, h // 2, w // 2)
|
143 |
+
|
144 |
+
return x
|
145 |
+
|
146 |
+
class Downsampler(nn.Sequential):
|
147 |
+
def __init__(
|
148 |
+
self, scale, n_feats, bias=True,
|
149 |
+
conv=default_conv, norm=False, act=False):
|
150 |
+
|
151 |
+
modules = []
|
152 |
+
if scale == 0.5:
|
153 |
+
modules.append(PixelSort())
|
154 |
+
modules.append(conv(4 * n_feats, n_feats, 3, bias))
|
155 |
+
if norm: modules.append(norm(n_feats))
|
156 |
+
if act: modules.append(act())
|
157 |
+
else:
|
158 |
+
raise NotImplementedError
|
159 |
+
|
160 |
+
super(Downsampler, self).__init__(*modules)
|
161 |
+
|
deblur/src/model/discriminator.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
class Discriminator(nn.Module):
|
4 |
+
def __init__(self, args):
|
5 |
+
super(Discriminator, self).__init__()
|
6 |
+
|
7 |
+
# self.args = args
|
8 |
+
n_feats = args.n_feats
|
9 |
+
kernel_size = args.kernel_size
|
10 |
+
|
11 |
+
def conv(kernel_size, in_channel, n_feats, stride, pad=None):
|
12 |
+
if pad is None:
|
13 |
+
pad = (kernel_size-1)//2
|
14 |
+
|
15 |
+
return nn.Conv2d(in_channel, n_feats, kernel_size, stride=stride, padding=pad, bias=False)
|
16 |
+
|
17 |
+
self.conv_layers = nn.ModuleList([
|
18 |
+
conv(kernel_size, 3, n_feats//2, 1), # 256
|
19 |
+
conv(kernel_size, n_feats//2, n_feats//2, 2), # 128
|
20 |
+
conv(kernel_size, n_feats//2, n_feats, 1),
|
21 |
+
conv(kernel_size, n_feats, n_feats, 2), # 64
|
22 |
+
conv(kernel_size, n_feats, n_feats*2, 1),
|
23 |
+
conv(kernel_size, n_feats*2, n_feats*2, 4), # 16
|
24 |
+
conv(kernel_size, n_feats*2, n_feats*4, 1),
|
25 |
+
conv(kernel_size, n_feats*4, n_feats*4, 4), # 4
|
26 |
+
conv(kernel_size, n_feats*4, n_feats*8, 1),
|
27 |
+
conv(4, n_feats*8, n_feats*8, 4, 0), # 1
|
28 |
+
])
|
29 |
+
|
30 |
+
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
31 |
+
self.dense = nn.Conv2d(n_feats*8, 1, 1, bias=False)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
|
35 |
+
for layer in self.conv_layers:
|
36 |
+
x = self.act(layer(x))
|
37 |
+
|
38 |
+
x = self.dense(x)
|
39 |
+
|
40 |
+
return x
|
41 |
+
|
deblur/src/model/structure.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
from .common import ResBlock, default_conv
|
4 |
+
|
5 |
+
def encoder(in_channels, n_feats):
|
6 |
+
"""RGB / IR feature encoder
|
7 |
+
"""
|
8 |
+
|
9 |
+
# in_channels == 1 or 3 or 4 or ....
|
10 |
+
# After 1st conv, B x n_feats x H x W
|
11 |
+
# After 2nd conv, B x 2n_feats x H/2 x W/2
|
12 |
+
# After 3rd conv, B x 3n_feats x H/4 x W/4
|
13 |
+
return nn.Sequential(
|
14 |
+
nn.Conv2d(in_channels, 1 * n_feats, 5, stride=1, padding=2),
|
15 |
+
nn.Conv2d(1 * n_feats, 2 * n_feats, 5, stride=2, padding=2),
|
16 |
+
nn.Conv2d(2 * n_feats, 3 * n_feats, 5, stride=2, padding=2),
|
17 |
+
)
|
18 |
+
|
19 |
+
def decoder(out_channels, n_feats):
|
20 |
+
"""RGB / IR / Depth decoder
|
21 |
+
"""
|
22 |
+
# After 1st deconv, B x 2n_feats x H/2 x W/2
|
23 |
+
# After 2nd deconv, B x n_feats x H x W
|
24 |
+
# After 3rd conv, B x out_channels x H x W
|
25 |
+
deconv_kargs = {'stride': 2, 'padding': 1, 'output_padding': 1}
|
26 |
+
|
27 |
+
return nn.Sequential(
|
28 |
+
nn.ConvTranspose2d(3 * n_feats, 2 * n_feats, 3, **deconv_kargs),
|
29 |
+
nn.ConvTranspose2d(2 * n_feats, 1 * n_feats, 3, **deconv_kargs),
|
30 |
+
nn.Conv2d(n_feats, out_channels, 5, stride=1, padding=2),
|
31 |
+
)
|
32 |
+
|
33 |
+
# def ResNet(n_feats, in_channels=None, out_channels=None):
|
34 |
+
def ResNet(n_feats, kernel_size, n_blocks, in_channels=None, out_channels=None):
|
35 |
+
"""sequential ResNet
|
36 |
+
"""
|
37 |
+
|
38 |
+
# if in_channels is None:
|
39 |
+
# in_channels = n_feats
|
40 |
+
# if out_channels is None:
|
41 |
+
# out_channels = n_feats
|
42 |
+
# # currently not implemented
|
43 |
+
|
44 |
+
m = []
|
45 |
+
|
46 |
+
if in_channels is not None:
|
47 |
+
m += [default_conv(in_channels, n_feats, kernel_size)]
|
48 |
+
|
49 |
+
m += [ResBlock(n_feats, 3)] * n_blocks
|
50 |
+
|
51 |
+
if out_channels is not None:
|
52 |
+
m += [default_conv(n_feats, out_channels, kernel_size)]
|
53 |
+
|
54 |
+
|
55 |
+
return nn.Sequential(*m)
|
56 |
+
|
deblur/src/optim/__init__.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.optim as optim
|
3 |
+
import torch.optim.lr_scheduler as lrs
|
4 |
+
|
5 |
+
import os
|
6 |
+
from collections import Counter
|
7 |
+
|
8 |
+
from model import Model
|
9 |
+
from utils import interact, Map
|
10 |
+
|
11 |
+
class Optimizer(object):
|
12 |
+
def __init__(self, args, model):
|
13 |
+
self.args = args
|
14 |
+
|
15 |
+
self.save_dir = os.path.join(self.args.save_dir, 'optim')
|
16 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
17 |
+
|
18 |
+
if isinstance(model, Model):
|
19 |
+
model = model.model
|
20 |
+
|
21 |
+
# set base arguments
|
22 |
+
kwargs_optimizer = {
|
23 |
+
'lr': args.lr,
|
24 |
+
'weight_decay': args.weight_decay
|
25 |
+
}
|
26 |
+
|
27 |
+
if args.optimizer == 'SGD':
|
28 |
+
optimizer_class = optim.SGD
|
29 |
+
kwargs_optimizer['momentum'] = args.momentum
|
30 |
+
elif args.optimizer == 'ADAM':
|
31 |
+
optimizer_class = optim.Adam
|
32 |
+
kwargs_optimizer['betas'] = args.betas
|
33 |
+
kwargs_optimizer['eps'] = args.epsilon
|
34 |
+
elif args.optimizer == 'RMSPROP':
|
35 |
+
optimizer_class = optim.RMSprop
|
36 |
+
kwargs_optimizer['eps'] = args.epsilon
|
37 |
+
|
38 |
+
# scheduler
|
39 |
+
if args.scheduler == 'step':
|
40 |
+
scheduler_class = lrs.MultiStepLR
|
41 |
+
kwargs_scheduler = {
|
42 |
+
'milestones': args.milestones,
|
43 |
+
'gamma': args.gamma,
|
44 |
+
}
|
45 |
+
elif args.scheduler == 'plateau':
|
46 |
+
scheduler_class = lrs.ReduceLROnPlateau
|
47 |
+
kwargs_scheduler = {
|
48 |
+
'mode': 'min',
|
49 |
+
'factor': args.gamma,
|
50 |
+
'patience': 10,
|
51 |
+
'verbose': True,
|
52 |
+
'threshold': 0,
|
53 |
+
'threshold_mode': 'abs',
|
54 |
+
'cooldown': 10,
|
55 |
+
}
|
56 |
+
|
57 |
+
self.kwargs_optimizer = kwargs_optimizer
|
58 |
+
self.scheduler_class = scheduler_class
|
59 |
+
self.kwargs_scheduler = kwargs_scheduler
|
60 |
+
|
61 |
+
def _get_optimizer(model):
|
62 |
+
|
63 |
+
class _Optimizer(optimizer_class):
|
64 |
+
def __init__(self, model, args, scheduler_class, kwargs_scheduler):
|
65 |
+
trainable = filter(lambda x: x.requires_grad, model.parameters())
|
66 |
+
super(_Optimizer, self).__init__(trainable, **kwargs_optimizer)
|
67 |
+
|
68 |
+
self.args = args
|
69 |
+
|
70 |
+
self._register_scheduler(scheduler_class, kwargs_scheduler)
|
71 |
+
|
72 |
+
def _register_scheduler(self, scheduler_class, kwargs_scheduler):
|
73 |
+
self.scheduler = scheduler_class(self, **kwargs_scheduler)
|
74 |
+
|
75 |
+
def schedule(self, metrics=None):
|
76 |
+
if isinstance(self, lrs.ReduceLROnPlateau):
|
77 |
+
self.scheduler.step(metrics)
|
78 |
+
else:
|
79 |
+
self.scheduler.step()
|
80 |
+
|
81 |
+
def get_last_epoch(self):
|
82 |
+
return self.scheduler.last_epoch
|
83 |
+
|
84 |
+
def get_lr(self):
|
85 |
+
return self.param_groups[0]['lr']
|
86 |
+
|
87 |
+
def get_last_lr(self):
|
88 |
+
return self.scheduler.get_last_lr()[0]
|
89 |
+
|
90 |
+
def state_dict(self):
|
91 |
+
state_dict = super(_Optimizer, self).state_dict() # {'state': ..., 'param_groups': ...}
|
92 |
+
state_dict['scheduler'] = self.scheduler.state_dict()
|
93 |
+
|
94 |
+
return state_dict
|
95 |
+
|
96 |
+
def load_state_dict(self, state_dict, epoch=None):
|
97 |
+
# optimizer
|
98 |
+
super(_Optimizer, self).load_state_dict(state_dict) # load 'state' and 'param_groups' only
|
99 |
+
# scheduler
|
100 |
+
self.scheduler.load_state_dict(state_dict['scheduler']) # should work for plateau or simple resuming
|
101 |
+
|
102 |
+
reschedule = False
|
103 |
+
if isinstance(self.scheduler, lrs.MultiStepLR):
|
104 |
+
if self.args.milestones != list(self.scheduler.milestones) or self.args.gamma != self.scheduler.gamma:
|
105 |
+
reschedule = True
|
106 |
+
|
107 |
+
if reschedule:
|
108 |
+
if epoch is None:
|
109 |
+
if self.scheduler.last_epoch > 1:
|
110 |
+
epoch = self.scheduler.last_epoch
|
111 |
+
else:
|
112 |
+
epoch = self.args.start_epoch - 1
|
113 |
+
|
114 |
+
# if False:
|
115 |
+
# # option 1. new scheduler
|
116 |
+
# for i, group in enumerate(self.param_groups):
|
117 |
+
# self.param_groups[i]['lr'] = group['initial_lr'] # reset optimizer learning rate to initial
|
118 |
+
# # self.scheduler = None
|
119 |
+
# self._register_scheduler(scheduler_class, kwargs_scheduler)
|
120 |
+
|
121 |
+
# self.zero_grad()
|
122 |
+
# self.step()
|
123 |
+
# for _ in range(epoch):
|
124 |
+
# self.scheduler.step()
|
125 |
+
# self._step_count -= 1
|
126 |
+
|
127 |
+
# else:
|
128 |
+
# option 2. modify existing scheduler
|
129 |
+
self.scheduler.milestones = Counter(self.args.milestones)
|
130 |
+
self.scheduler.gamma = self.args.gamma
|
131 |
+
for i, group in enumerate(self.param_groups):
|
132 |
+
self.param_groups[i]['lr'] = group['initial_lr'] # reset optimizer learning rate to initial
|
133 |
+
multiplier = 1
|
134 |
+
for milestone in self.scheduler.milestones:
|
135 |
+
if epoch >= milestone:
|
136 |
+
multiplier *= self.scheduler.gamma
|
137 |
+
|
138 |
+
self.param_groups[i]['lr'] *= multiplier
|
139 |
+
|
140 |
+
return _Optimizer(model, args, scheduler_class, kwargs_scheduler)
|
141 |
+
|
142 |
+
self.G = _get_optimizer(model.G)
|
143 |
+
if model.D is not None:
|
144 |
+
self.D = _get_optimizer(model.D)
|
145 |
+
else:
|
146 |
+
self.D = None
|
147 |
+
|
148 |
+
self.load(args.load_epoch)
|
149 |
+
|
150 |
+
def zero_grad(self):
|
151 |
+
self.G.zero_grad()
|
152 |
+
|
153 |
+
def step(self):
|
154 |
+
self.G.step()
|
155 |
+
|
156 |
+
def schedule(self, metrics=None):
|
157 |
+
self.G.schedule(metrics)
|
158 |
+
if self.D is not None:
|
159 |
+
self.D.schedule(metrics)
|
160 |
+
|
161 |
+
def get_last_epoch(self):
|
162 |
+
return self.G.get_last_epoch()
|
163 |
+
|
164 |
+
def get_lr(self):
|
165 |
+
return self.G.get_lr()
|
166 |
+
|
167 |
+
def get_last_lr(self):
|
168 |
+
return self.G.get_last_lr()
|
169 |
+
|
170 |
+
def state_dict(self):
|
171 |
+
state_dict = Map()
|
172 |
+
state_dict.G = self.G.state_dict()
|
173 |
+
if self.D is not None:
|
174 |
+
state_dict.D = self.D.state_dict()
|
175 |
+
|
176 |
+
return state_dict.toDict()
|
177 |
+
|
178 |
+
def load_state_dict(self, state_dict, epoch=None):
|
179 |
+
state_dict = Map(**state_dict)
|
180 |
+
self.G.load_state_dict(state_dict.G, epoch)
|
181 |
+
if self.D is not None:
|
182 |
+
self.D.load_state_dict(state_dict.D, epoch)
|
183 |
+
|
184 |
+
def _save_path(self, epoch=None):
|
185 |
+
epoch = epoch if epoch is not None else self.get_last_epoch()
|
186 |
+
save_path = os.path.join(self.save_dir, 'optim-{:d}.pt'.format(epoch))
|
187 |
+
|
188 |
+
return save_path
|
189 |
+
|
190 |
+
def save(self, epoch=None):
|
191 |
+
if epoch is None:
|
192 |
+
epoch = self.G.scheduler.last_epoch
|
193 |
+
torch.save(self.state_dict(), self._save_path(epoch))
|
194 |
+
|
195 |
+
def load(self, epoch):
|
196 |
+
if epoch > 0:
|
197 |
+
print('Loading optimizer from {}'.format(self._save_path(epoch)))
|
198 |
+
self.load_state_dict(torch.load(self._save_path(epoch), map_location=self.args.device), epoch=epoch)
|
199 |
+
|
200 |
+
elif epoch == 0:
|
201 |
+
pass
|
202 |
+
else:
|
203 |
+
raise NotImplementedError
|
204 |
+
|
205 |
+
return
|
206 |
+
|
deblur/src/optim/warm_multi_step_lr.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from bisect import bisect_right
|
3 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
4 |
+
|
5 |
+
# MultiStep learning rate scheduler with warm restart
|
6 |
+
class WarmMultiStepLR(_LRScheduler):
|
7 |
+
def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, scale=1):
|
8 |
+
if not list(milestones) == sorted(milestones):
|
9 |
+
raise ValueError(
|
10 |
+
'Milestones should be a list of increasing integers. Got {}',
|
11 |
+
milestones
|
12 |
+
)
|
13 |
+
|
14 |
+
self.milestones = milestones
|
15 |
+
self.gamma = gamma
|
16 |
+
self.scale = scale
|
17 |
+
|
18 |
+
self.warmup_epochs = 5
|
19 |
+
self.gradual = (self.scale - 1) / self.warmup_epochs
|
20 |
+
super(WarmMultiStepLR, self).__init__(optimizer, last_epoch)
|
21 |
+
|
22 |
+
def get_lr(self):
|
23 |
+
if self.last_epoch < self.warmup_epochs:
|
24 |
+
return [
|
25 |
+
base_lr * (1 + self.last_epoch * self.gradual) / self.scale
|
26 |
+
for base_lr in self.base_lrs
|
27 |
+
]
|
28 |
+
else:
|
29 |
+
return [
|
30 |
+
base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
|
31 |
+
for base_lr in self.base_lrs
|
32 |
+
]
|
deblur/src/option.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""optionional argument parsing"""
|
2 |
+
# pylint: disable=C0103, C0301
|
3 |
+
import argparse
|
4 |
+
import datetime
|
5 |
+
import os
|
6 |
+
import re
|
7 |
+
import shutil
|
8 |
+
import time
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.distributed as dist
|
12 |
+
import torch.backends.cudnn as cudnn
|
13 |
+
|
14 |
+
from utils import interact
|
15 |
+
from utils import str2bool, int2str
|
16 |
+
|
17 |
+
import template
|
18 |
+
|
19 |
+
# Training settings
|
20 |
+
parser = argparse.ArgumentParser(description='Dynamic Scene Deblurring')
|
21 |
+
|
22 |
+
# Device specifications
|
23 |
+
group_device = parser.add_argument_group('Device specs')
|
24 |
+
group_device.add_argument('--seed', type=int, default=-1, help='random seed')
|
25 |
+
group_device.add_argument('--num_workers', type=int, default=7, help='the number of dataloader workers')
|
26 |
+
group_device.add_argument('--device_type', type=str, choices=('cpu', 'cuda'), default='cuda', help='device to run models')
|
27 |
+
group_device.add_argument('--device_index', type=int, default=0, help='device id to run models')
|
28 |
+
group_device.add_argument('--n_GPUs', type=int, default=1, help='the number of GPUs for training')
|
29 |
+
group_device.add_argument('--distributed', type=str2bool, default=False, help='use DistributedDataParallel instead of DataParallel for better speed')
|
30 |
+
group_device.add_argument('--launched', type=str2bool, default=False, help='identify if main.py was executed from launch.py. Do not set this to be true using main.py.')
|
31 |
+
|
32 |
+
group_device.add_argument('--master_addr', type=str, default='127.0.0.1', help='master address for distributed')
|
33 |
+
group_device.add_argument('--master_port', type=int2str, default='8023', help='master port for distributed')
|
34 |
+
group_device.add_argument('--dist_backend', type=str, default='nccl', help='distributed backend')
|
35 |
+
group_device.add_argument('--init_method', type=str, default='env://', help='distributed init method URL to discover peers')
|
36 |
+
group_device.add_argument('--rank', type=int, default=0, help='rank of the distributed process (gpu id). 0 is the master process.')
|
37 |
+
group_device.add_argument('--world_size', type=int, default=1, help='world_size for distributed training (number of GPUs)')
|
38 |
+
|
39 |
+
# Data
|
40 |
+
group_data = parser.add_argument_group('Data specs')
|
41 |
+
group_data.add_argument('--data_root', type=str, default='/data/ssd/public/czli/deblur', help='dataset root location')
|
42 |
+
group_data.add_argument('--dataset', type=str, default=None, help='training/validation/test dataset name, has priority if not None')
|
43 |
+
group_data.add_argument('--data_train', type=str, default='GOPRO_Large', help='training dataset name')
|
44 |
+
group_data.add_argument('--data_val', type=str, default=None, help='validation dataset name')
|
45 |
+
group_data.add_argument('--data_test', type=str, default='GOPRO_Large', help='test dataset name')
|
46 |
+
group_data.add_argument('--blur_key', type=str, default='blur_gamma', choices=('blur', 'blur_gamma'), help='blur type from camera response function for GOPRO_Large dataset')
|
47 |
+
group_data.add_argument('--rgb_range', type=int, default=255, help='RGB pixel value ranging from 0')
|
48 |
+
|
49 |
+
# Model
|
50 |
+
group_model = parser.add_argument_group('Model specs')
|
51 |
+
group_model.add_argument('--model', type=str, default='RecLamResNet', help='model architecture')
|
52 |
+
group_model.add_argument('--pretrained', type=str, default='', help='pretrained model location')
|
53 |
+
group_model.add_argument('--n_scales', type=int, default=5, help='multi-scale deblurring level')
|
54 |
+
group_model.add_argument('--detach', type=str2bool, default=False, help='detach between recurrence')
|
55 |
+
group_model.add_argument('--gaussian_pyramid', type=str2bool, default=True, help='gaussian pyramid input/target')
|
56 |
+
group_model.add_argument('--n_resblocks', type=int, default=19, help='number of residual blocks per scale')
|
57 |
+
group_model.add_argument('--n_feats', type=int, default=64, help='number of feature maps')
|
58 |
+
group_model.add_argument('--kernel_size', type=int, default=5, help='size of conv kernel')
|
59 |
+
group_model.add_argument('--downsample', type=str, choices=('Gaussian', 'bicubic', 'stride'), default='Gaussian', help='input pyramid generation method')
|
60 |
+
|
61 |
+
group_model.add_argument('--precision', type=str, default='single', choices=('single', 'half'), help='FP precision for test(single | half)')
|
62 |
+
|
63 |
+
# amp
|
64 |
+
group_amp = parser.add_argument_group('AMP specs')
|
65 |
+
group_amp.add_argument('--amp', type=str2bool, default=False, help='use automatic mixed precision training')
|
66 |
+
group_amp.add_argument('--init_scale', type=float, default=1024., help='initial loss scale')
|
67 |
+
|
68 |
+
# Training
|
69 |
+
group_train = parser.add_argument_group('Training specs')
|
70 |
+
group_train.add_argument('--patch_size', type=int, default=256, help='training patch size')
|
71 |
+
group_train.add_argument('--batch_size', type=int, default=16, help='input batch size for training')
|
72 |
+
group_train.add_argument('--split_batch', type=int, default=1, help='split a minibatch into smaller chunks')
|
73 |
+
group_train.add_argument('--augment', type=str2bool, default=True, help='train with data augmentation')
|
74 |
+
|
75 |
+
# Testing
|
76 |
+
group_test = parser.add_argument_group('Testing specs')
|
77 |
+
group_test.add_argument('--validate_every', type=int, default=10, help='do validation at every N epochs')
|
78 |
+
group_test.add_argument('--test_every', type=int, default=10, help='do test at every N epochs')
|
79 |
+
# group_test.add_argument('--chop', type=str2bool, default=False, help='memory-efficient forward')
|
80 |
+
# group_test.add_argument('--self_ensemble', type=str2bool, default=False, help='self-ensembled testing')
|
81 |
+
|
82 |
+
# Action
|
83 |
+
group_action = parser.add_argument_group('Source behavior')
|
84 |
+
group_action.add_argument('--do_train', type=str2bool, default=True, help='do train the model')
|
85 |
+
group_action.add_argument('--do_validate', type=str2bool, default=True, help='do validate the model')
|
86 |
+
group_action.add_argument('--do_test', type=str2bool, default=True, help='do test the model')
|
87 |
+
group_action.add_argument('--demo', type=str2bool, default=False, help='demo')
|
88 |
+
group_action.add_argument('--demo_input_dir', type=str, default='', help='demo input directory')
|
89 |
+
group_action.add_argument('--demo_output_dir', type=str, default='', help='demo output directory')
|
90 |
+
|
91 |
+
# Optimization
|
92 |
+
group_optim = parser.add_argument_group('Optimization specs')
|
93 |
+
group_optim.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
94 |
+
group_optim.add_argument('--milestones', type=int, nargs='+', default=[500, 750, 900], help='learning rate decay per N epochs')
|
95 |
+
group_optim.add_argument('--scheduler', default='step', choices=('step', 'plateau'), help='learning rate scheduler type')
|
96 |
+
group_optim.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay')
|
97 |
+
group_optim.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM', 'RMSprop'), help='optimizer to use (SGD | ADAM | RMSProp)')
|
98 |
+
group_optim.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
|
99 |
+
group_optim.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999), help='ADAM betas')
|
100 |
+
group_optim.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon')
|
101 |
+
group_optim.add_argument('--weight_decay', type=float, default=0, help='weight decay')
|
102 |
+
group_optim.add_argument('--clip', type=float, default=0, help='weight decay')
|
103 |
+
|
104 |
+
# Loss
|
105 |
+
group_loss = parser.add_argument_group('Loss specs')
|
106 |
+
group_loss.add_argument('--loss', type=str, default='1*MSE', help='loss function configuration')
|
107 |
+
group_loss.add_argument('--metric', type=str, default='PSNR,SSIM', help='metric function configuration. ex) None | PSNR | SSIM | PSNR,SSIM')
|
108 |
+
group_loss.add_argument('--gamma', type=float, default=0.6, help='gamma decay')
|
109 |
+
|
110 |
+
|
111 |
+
# Logging
|
112 |
+
group_log = parser.add_argument_group('Logging specs')
|
113 |
+
group_log.add_argument('--save_dir', type=str, default='', help='subdirectory to save experiment logs')
|
114 |
+
# group_log.add_argument('--load_dir', type=str, default='', help='subdirectory to load experiment logs')
|
115 |
+
group_log.add_argument('--start_epoch', type=int, default=-1, help='(re)starting epoch number')
|
116 |
+
group_log.add_argument('--end_epoch', type=int, default=1000, help='ending epoch number')
|
117 |
+
group_log.add_argument('--load_epoch', type=int, default=-1, help='epoch number to load model (start_epoch-1 for training, start_epoch for testing)')
|
118 |
+
group_log.add_argument('--save_every', type=int, default=10, help='save model/optimizer at every N epochs')
|
119 |
+
group_log.add_argument('--save_results', type=str, default='part', choices=('none', 'part', 'all'), help='save none/part/all of result images')
|
120 |
+
|
121 |
+
# Debugging
|
122 |
+
group_debug = parser.add_argument_group('Debug specs')
|
123 |
+
group_debug.add_argument('--stay', type=str2bool, default=False, help='stay at interactive console after trainer initialization')
|
124 |
+
|
125 |
+
parser.add_argument('--template', type=str, default='', help='argument template option')
|
126 |
+
|
127 |
+
args = parser.parse_args()
|
128 |
+
template.set_template(args)
|
129 |
+
|
130 |
+
args.data_root = os.path.expanduser(args.data_root) # recognize home directory
|
131 |
+
now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
|
132 |
+
if args.save_dir == '':
|
133 |
+
args.save_dir = now
|
134 |
+
args.save_dir = os.path.join('../experiment', args.save_dir)
|
135 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
136 |
+
|
137 |
+
if args.start_epoch < 0: # start from scratch or continue from the last epoch
|
138 |
+
# check if there are any models saved before
|
139 |
+
model_dir = os.path.join(args.save_dir, 'models')
|
140 |
+
model_prefix = 'model-'
|
141 |
+
if os.path.exists(model_dir):
|
142 |
+
model_list = [name for name in os.listdir(model_dir) if name.startswith(model_prefix)]
|
143 |
+
last_epoch = 0
|
144 |
+
for name in model_list:
|
145 |
+
epochNumber = int(re.findall('\\d+', name)[0]) # model example name model-100.pt
|
146 |
+
if last_epoch < epochNumber:
|
147 |
+
last_epoch = epochNumber
|
148 |
+
|
149 |
+
args.start_epoch = last_epoch + 1
|
150 |
+
else:
|
151 |
+
# train from scratch
|
152 |
+
args.start_epoch = 1
|
153 |
+
elif args.start_epoch == 0:
|
154 |
+
# remove existing directory and start over
|
155 |
+
if args.rank == 0: # maybe local rank
|
156 |
+
shutil.rmtree(args.save_dir, ignore_errors=True)
|
157 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
158 |
+
args.start_epoch = 1
|
159 |
+
|
160 |
+
if args.load_epoch < 0: # load_epoch == start_epoch when doing a post-training test for a specific epoch
|
161 |
+
args.load_epoch = args.start_epoch - 1
|
162 |
+
|
163 |
+
if args.pretrained:
|
164 |
+
if args.start_epoch <= 1:
|
165 |
+
args.pretrained = os.path.join('../experiment', args.pretrained)
|
166 |
+
else:
|
167 |
+
print('starting from epoch {}! ignoring pretrained model path..'.format(args.start_epoch))
|
168 |
+
args.pretrained = ''
|
169 |
+
|
170 |
+
if args.model == 'MSResNet':
|
171 |
+
args.gaussian_pyramid = True
|
172 |
+
|
173 |
+
argname = os.path.join(args.save_dir, 'args.pt')
|
174 |
+
argname_txt = os.path.join(args.save_dir, 'args.txt')
|
175 |
+
if args.start_epoch > 1:
|
176 |
+
# load previous arguments and keep the necessary ones same
|
177 |
+
|
178 |
+
if os.path.exists(argname):
|
179 |
+
args_old = torch.load(argname)
|
180 |
+
|
181 |
+
load_list = [] # list of arguments that are fixed
|
182 |
+
# training
|
183 |
+
load_list += ['patch_size']
|
184 |
+
load_list += ['batch_size']
|
185 |
+
# data format
|
186 |
+
load_list += ['rgb_range']
|
187 |
+
load_list += ['blur_key']
|
188 |
+
# model architecture
|
189 |
+
load_list += ['n_scales']
|
190 |
+
load_list += ['n_resblocks']
|
191 |
+
load_list += ['n_feats']
|
192 |
+
|
193 |
+
for arg_part in load_list:
|
194 |
+
vars(args)[arg_part] = vars(args_old)[arg_part]
|
195 |
+
|
196 |
+
if args.dataset is not None:
|
197 |
+
args.data_train = args.dataset
|
198 |
+
args.data_val = args.dataset if args.dataset != 'GOPRO_Large' else None
|
199 |
+
args.data_test = args.dataset
|
200 |
+
|
201 |
+
if args.data_val is None:
|
202 |
+
args.do_validate = False
|
203 |
+
|
204 |
+
if args.demo_input_dir:
|
205 |
+
args.demo = True
|
206 |
+
|
207 |
+
if args.demo:
|
208 |
+
assert os.path.basename(args.save_dir) != now, 'You should specify pretrained directory by setting --save_dir SAVE_DIR'
|
209 |
+
|
210 |
+
args.data_train = ''
|
211 |
+
args.data_val = ''
|
212 |
+
args.data_test = ''
|
213 |
+
|
214 |
+
args.do_train = False
|
215 |
+
args.do_validate = False
|
216 |
+
args.do_test = False
|
217 |
+
|
218 |
+
assert len(args.demo_input_dir) > 0, 'Please specify demo_input_dir!'
|
219 |
+
args.demo_input_dir = os.path.expanduser(args.demo_input_dir)
|
220 |
+
if args.demo_output_dir:
|
221 |
+
args.demo_output_dir = os.path.expanduser(args.demo_output_dir)
|
222 |
+
|
223 |
+
args.save_results = 'all'
|
224 |
+
|
225 |
+
if args.amp:
|
226 |
+
args.precision = 'single' # model parameters should stay in fp32
|
227 |
+
|
228 |
+
if args.seed < 0:
|
229 |
+
args.seed = int(time.time())
|
230 |
+
|
231 |
+
# save arguments
|
232 |
+
if args.rank == 0:
|
233 |
+
torch.save(args, argname)
|
234 |
+
with open(argname_txt, 'a') as file:
|
235 |
+
file.write('execution at {}\n'.format(now))
|
236 |
+
|
237 |
+
for key in args.__dict__:
|
238 |
+
file.write(key + ': ' + str(args.__dict__[key]) + '\n')
|
239 |
+
|
240 |
+
file.write('\n')
|
241 |
+
|
242 |
+
# device and type
|
243 |
+
if args.device_type == 'cuda' and not torch.cuda.is_available():
|
244 |
+
raise Exception("GPU not available!")
|
245 |
+
|
246 |
+
if not args.distributed:
|
247 |
+
args.rank = 0
|
248 |
+
|
249 |
+
def setup(args):
|
250 |
+
cudnn.benchmark = True
|
251 |
+
|
252 |
+
if args.distributed:
|
253 |
+
os.environ['MASTER_ADDR'] = args.master_addr
|
254 |
+
os.environ['MASTER_PORT'] = args.master_port
|
255 |
+
|
256 |
+
args.device_index = args.rank
|
257 |
+
args.world_size = args.n_GPUs # consider single-node training
|
258 |
+
|
259 |
+
# initialize the process group
|
260 |
+
dist.init_process_group(args.dist_backend, init_method=args.init_method, rank=args.rank, world_size=args.world_size)
|
261 |
+
|
262 |
+
args.device = torch.device(args.device_type, args.device_index)
|
263 |
+
args.dtype = torch.float32
|
264 |
+
args.dtype_eval = torch.float32 if args.precision == 'single' else torch.float16
|
265 |
+
|
266 |
+
# set seed for processes (distributed: different seed for each process)
|
267 |
+
# model parameters are synchronized explicitly at initial
|
268 |
+
torch.manual_seed(args.seed)
|
269 |
+
if args.device_type == 'cuda':
|
270 |
+
torch.cuda.set_device(args.device)
|
271 |
+
if args.rank == 0:
|
272 |
+
torch.cuda.manual_seed_all(args.seed)
|
273 |
+
|
274 |
+
return args
|
275 |
+
|
276 |
+
def cleanup(args):
|
277 |
+
if args.distributed:
|
278 |
+
dist.destroy_process_group()
|
deblur/src/prepare.sh
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
workplace=`pwd -P`
|
3 |
+
|
4 |
+
if [ ! -d "/data/ssd/public/czli/deblur/GOPRO_Large" ]; then
|
5 |
+
workplace=`pwd`
|
6 |
+
echo "Copying dataset to ssd"
|
7 |
+
cd /research/dept7/liuhy/deblur/dataset
|
8 |
+
mkdir -p /data/ssd/public/czli/deblur/GOPRO_Large
|
9 |
+
cp GOPRO_Large.zip /data/ssd/public/czli/deblur/GOPRO_Large
|
10 |
+
cd /data/ssd/public/czli/deblur/GOPRO_Large
|
11 |
+
echo "Dumping zip in data path:" `pwd`
|
12 |
+
for f in *.zip; do unzip "$f"; done
|
13 |
+
fi
|
14 |
+
|
15 |
+
cd "$workplace"
|
16 |
+
echo "Workplace:" `pwd`
|
deblur/src/template.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def set_template(args):
|
2 |
+
if args.template.find('gopro') >= 0:
|
3 |
+
args.dataset = 'GOPRO_Large'
|
4 |
+
args.milestones = [500, 750, 900]
|
5 |
+
args.end_epoch = 1000
|
6 |
+
elif args.template.find('reds') >= 0:
|
7 |
+
args.dataset = 'REDS'
|
8 |
+
args.milestones = [100, 150, 180]
|
9 |
+
args.end_epoch = 200
|
deblur/src/train-rec-1.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
bash ./prepare.sh
|
4 |
+
python launch.py --n_GPUs 2 main.py --batch_size 16 --amp true --save_dir LamRes_L1 --n_scales 1 --loss 1*L1
|
deblur/src/train.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import data.common
|
7 |
+
from utils import interact, MultiSaver
|
8 |
+
from torch.utils.tensorboard import SummaryWriter
|
9 |
+
import torchvision
|
10 |
+
|
11 |
+
import torch.cuda.amp as amp
|
12 |
+
|
13 |
+
class Trainer():
|
14 |
+
|
15 |
+
def __init__(self, args, model, criterion, optimizer, loaders):
|
16 |
+
print('===> Initializing trainer')
|
17 |
+
self.args = args
|
18 |
+
self.mode = 'train' # 'val', 'test'
|
19 |
+
self.epoch = args.start_epoch
|
20 |
+
self.save_dir = args.save_dir
|
21 |
+
|
22 |
+
self.model = model
|
23 |
+
self.criterion = criterion
|
24 |
+
self.optimizer = optimizer
|
25 |
+
self.loaders = loaders
|
26 |
+
|
27 |
+
|
28 |
+
self.do_train = args.do_train
|
29 |
+
self.do_validate = args.do_validate
|
30 |
+
self.do_test = args.do_test
|
31 |
+
|
32 |
+
self.device = args.device
|
33 |
+
self.dtype = args.dtype
|
34 |
+
self.dtype_eval = torch.float32 if args.precision == 'single' else torch.float16
|
35 |
+
self.recurrence=args.n_scales
|
36 |
+
if self.args.demo and self.args.demo_output_dir:
|
37 |
+
self.result_dir = self.args.demo_output_dir
|
38 |
+
else:
|
39 |
+
self.result_dir = os.path.join(self.save_dir, 'result')
|
40 |
+
os.makedirs(self.result_dir, exist_ok=True)
|
41 |
+
print('results are saved in {}'.format(self.result_dir))
|
42 |
+
|
43 |
+
self.imsaver = MultiSaver(self.result_dir)
|
44 |
+
|
45 |
+
self.is_slave = self.args.launched and self.args.rank != 0
|
46 |
+
|
47 |
+
self.scaler = amp.GradScaler(
|
48 |
+
init_scale=self.args.init_scale,
|
49 |
+
enabled=self.args.amp
|
50 |
+
)
|
51 |
+
if not self.is_slave:
|
52 |
+
self.writter=SummaryWriter(f"runs/{args.save_dir}")
|
53 |
+
|
54 |
+
|
55 |
+
def save(self, epoch=None):
|
56 |
+
epoch = self.epoch if epoch is None else epoch
|
57 |
+
if epoch % self.args.save_every == 0:
|
58 |
+
if self.mode == 'train':
|
59 |
+
self.model.save(epoch)
|
60 |
+
self.optimizer.save(epoch)
|
61 |
+
self.criterion.save()
|
62 |
+
|
63 |
+
return
|
64 |
+
|
65 |
+
def load(self, epoch=None, pretrained=None):
|
66 |
+
if epoch is None:
|
67 |
+
epoch = self.args.load_epoch
|
68 |
+
self.epoch = epoch
|
69 |
+
self.model.load(epoch, pretrained)
|
70 |
+
self.optimizer.load(epoch)
|
71 |
+
self.criterion.load(epoch)
|
72 |
+
|
73 |
+
return
|
74 |
+
|
75 |
+
def train(self, epoch):
|
76 |
+
self.mode = 'train'
|
77 |
+
self.epoch = epoch
|
78 |
+
|
79 |
+
self.model.train()
|
80 |
+
self.model.to(dtype=self.dtype)
|
81 |
+
|
82 |
+
self.criterion.train()
|
83 |
+
self.criterion.epoch = epoch
|
84 |
+
|
85 |
+
if not self.is_slave:
|
86 |
+
print('[Epoch {} / lr {:.2e}]'.format(
|
87 |
+
epoch, self.optimizer.get_lr()
|
88 |
+
))
|
89 |
+
total=len(self.loaders[self.mode])
|
90 |
+
acc=0.0
|
91 |
+
if self.args.distributed:
|
92 |
+
self.loaders[self.mode].sampler.set_epoch(epoch)
|
93 |
+
if self.is_slave:
|
94 |
+
tq = self.loaders[self.mode]
|
95 |
+
else:
|
96 |
+
tq = tqdm(self.loaders[self.mode], ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')
|
97 |
+
buffer=[0.0]*self.recurrence
|
98 |
+
torch.set_grad_enabled(True)
|
99 |
+
for idx, batch in enumerate(tq):
|
100 |
+
self.optimizer.zero_grad()
|
101 |
+
|
102 |
+
input, target = data.common.to(
|
103 |
+
batch[0], batch[1], device=self.device, dtype=self.dtype)
|
104 |
+
|
105 |
+
with amp.autocast(self.args.amp):
|
106 |
+
output = self.model(input)
|
107 |
+
loss = self.criterion(output, target)
|
108 |
+
|
109 |
+
for i in range(self.recurrence):
|
110 |
+
buffer[i]+=self.criterion.buffer[i]
|
111 |
+
|
112 |
+
self.scaler.scale(loss).backward()
|
113 |
+
if self.args.clip>0:
|
114 |
+
self.scaler.unscale_(self.optimizer.G)
|
115 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
|
116 |
+
self.scaler.step(self.optimizer.G)
|
117 |
+
self.scaler.update()
|
118 |
+
|
119 |
+
if isinstance(tq, tqdm):
|
120 |
+
tq.set_description(self.criterion.get_loss_desc())
|
121 |
+
if not self.is_slave:
|
122 |
+
rgb_range=self.args.rgb_range
|
123 |
+
for i in range(len(output)):
|
124 |
+
grid=torchvision.utils.make_grid(output[i])
|
125 |
+
self.writter.add_image(f"Output{i}",grid/rgb_range,epoch)
|
126 |
+
self.writter.add_scalar(f"Loss{i}",buffer[i],epoch)
|
127 |
+
self.writter.add_image("Input",torchvision.utils.make_grid(input[0])/rgb_range,epoch)
|
128 |
+
self.writter.add_image("Target",torchvision.utils.make_grid(target[0])/rgb_range,epoch)
|
129 |
+
self.criterion.normalize()
|
130 |
+
if isinstance(tq, tqdm):
|
131 |
+
tq.set_description(self.criterion.get_loss_desc())
|
132 |
+
tq.display(pos=-1) # overwrite with synchronized loss
|
133 |
+
|
134 |
+
self.criterion.step()
|
135 |
+
self.optimizer.schedule(self.criterion.get_last_loss())
|
136 |
+
|
137 |
+
if self.args.rank == 0:
|
138 |
+
self.save(epoch)
|
139 |
+
|
140 |
+
return
|
141 |
+
|
142 |
+
def evaluate(self, epoch, mode='val'):
|
143 |
+
self.mode = mode
|
144 |
+
self.epoch = epoch
|
145 |
+
|
146 |
+
self.model.eval()
|
147 |
+
self.model.to(dtype=self.dtype_eval)
|
148 |
+
|
149 |
+
if mode == 'val':
|
150 |
+
self.criterion.validate()
|
151 |
+
elif mode == 'test':
|
152 |
+
self.criterion.test()
|
153 |
+
self.criterion.epoch = epoch
|
154 |
+
|
155 |
+
self.imsaver.join_background()
|
156 |
+
|
157 |
+
if self.is_slave:
|
158 |
+
tq = self.loaders[self.mode]
|
159 |
+
else:
|
160 |
+
tq = tqdm(self.loaders[self.mode], ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')
|
161 |
+
|
162 |
+
compute_loss = True
|
163 |
+
torch.set_grad_enabled(False)
|
164 |
+
for idx, batch in enumerate(tq):
|
165 |
+
input, target = data.common.to(
|
166 |
+
batch[0], batch[1], device=self.device, dtype=self.dtype_eval)
|
167 |
+
with amp.autocast(self.args.amp):
|
168 |
+
output = self.model(input)
|
169 |
+
|
170 |
+
# if self.args.rgb_range==1:
|
171 |
+
# output=output*255
|
172 |
+
# target=target*255
|
173 |
+
|
174 |
+
if mode == 'demo': # remove padded part
|
175 |
+
pad_width = batch[2]
|
176 |
+
output[0], _ = data.common.pad(output[0], pad_width=pad_width, negative=True)
|
177 |
+
|
178 |
+
if isinstance(batch[1], torch.BoolTensor):
|
179 |
+
compute_loss = False
|
180 |
+
|
181 |
+
if compute_loss:
|
182 |
+
self.criterion(output, target)
|
183 |
+
if isinstance(tq, tqdm):
|
184 |
+
tq.set_description(self.criterion.get_loss_desc())
|
185 |
+
|
186 |
+
if self.args.save_results != 'none':
|
187 |
+
if isinstance(output, (list, tuple)):
|
188 |
+
result = output[-1] # select last output in a pyramid
|
189 |
+
elif isinstance(output, torch.Tensor):
|
190 |
+
result = output
|
191 |
+
|
192 |
+
names = batch[-1]
|
193 |
+
|
194 |
+
if self.args.save_results == 'part' and compute_loss: # save all when GT not available
|
195 |
+
indices = batch[-2]
|
196 |
+
save_ids = [save_id for save_id, idx in enumerate(indices) if idx % 10 == 0]
|
197 |
+
|
198 |
+
result = result[save_ids]
|
199 |
+
names = [names[save_id] for save_id in save_ids]
|
200 |
+
|
201 |
+
self.imsaver.save_image(result, names)
|
202 |
+
|
203 |
+
if compute_loss:
|
204 |
+
self.criterion.normalize()
|
205 |
+
if isinstance(tq, tqdm):
|
206 |
+
tq.set_description(self.criterion.get_loss_desc())
|
207 |
+
tq.display(pos=-1) # overwrite with synchronized loss
|
208 |
+
|
209 |
+
self.criterion.step()
|
210 |
+
if self.args.rank == 0:
|
211 |
+
self.save()
|
212 |
+
|
213 |
+
self.imsaver.end_background()
|
214 |
+
|
215 |
+
def validate(self, epoch):
|
216 |
+
self.evaluate(epoch, 'val')
|
217 |
+
return
|
218 |
+
|
219 |
+
def test(self, epoch):
|
220 |
+
self.evaluate(epoch, 'test')
|
221 |
+
return
|
222 |
+
|
223 |
+
def fill_evaluation(self, epoch, mode=None, force=False):
|
224 |
+
if epoch <= 0:
|
225 |
+
return
|
226 |
+
|
227 |
+
if mode is not None:
|
228 |
+
self.mode = mode
|
229 |
+
|
230 |
+
do_eval = force
|
231 |
+
if not force:
|
232 |
+
loss_missing = epoch not in self.criterion.loss_stat[self.mode]['Total'] # should it switch to all loss types?
|
233 |
+
|
234 |
+
metric_missing = False
|
235 |
+
for metric_type in self.criterion.metric:
|
236 |
+
if epoch not in self.criterion.metric_stat[mode][metric_type]:
|
237 |
+
metric_missing = True
|
238 |
+
|
239 |
+
do_eval = loss_missing or metric_missing
|
240 |
+
|
241 |
+
if do_eval:
|
242 |
+
try:
|
243 |
+
self.load(epoch)
|
244 |
+
self.evaluate(epoch, self.mode)
|
245 |
+
except:
|
246 |
+
# print('saved model/optimizer at epoch {} not found!'.format(epoch))
|
247 |
+
pass
|
248 |
+
|
249 |
+
return
|
deblur/src/train.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
bash ./prepare.sh
|
4 |
+
python launch.py --n_GPUs 2 main.py --batch_size 16 --amp true --save_dir RecLamRes_MSE
|
deblur/src/train01.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
bash ./prepare.sh
|
4 |
+
CUDA_VISIBLE_DEVICES=0,1 python launch.py --n_GPUs 2 main.py --batch_size 16 --amp true --save_dir RecLamRes_MSE
|
deblur/src/train23.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
bash ./prepare.sh
|
4 |
+
CUDA_VISIBLE_DEVICES=2,3 python launch.py --n_GPUs 2 main.py --batch_size 8 --amp true --save_dir RecLamRes_MSE_detach --detach true --master_port 8025
|
deblur/src/utils.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import readline
|
2 |
+
import rlcompleter
|
3 |
+
readline.parse_and_bind("tab: complete")
|
4 |
+
import code
|
5 |
+
import pdb
|
6 |
+
|
7 |
+
import time
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import imageio
|
11 |
+
import torch
|
12 |
+
import torch.multiprocessing as mp
|
13 |
+
|
14 |
+
# debugging tools
|
15 |
+
def interact(local=None):
|
16 |
+
"""interactive console with autocomplete function. Useful for debugging.
|
17 |
+
interact(locals())
|
18 |
+
"""
|
19 |
+
if local is None:
|
20 |
+
local=dict(globals(), **locals())
|
21 |
+
|
22 |
+
readline.set_completer(rlcompleter.Completer(local).complete)
|
23 |
+
code.interact(local=local)
|
24 |
+
|
25 |
+
def set_trace(local=None):
|
26 |
+
"""debugging with pdb
|
27 |
+
"""
|
28 |
+
if local is None:
|
29 |
+
local=dict(globals(), **locals())
|
30 |
+
|
31 |
+
pdb.Pdb.complete = rlcompleter.Completer(local).complete
|
32 |
+
pdb.set_trace()
|
33 |
+
|
34 |
+
# timer
|
35 |
+
class Timer():
|
36 |
+
"""Brought from https://github.com/thstkdgus35/EDSR-PyTorch
|
37 |
+
"""
|
38 |
+
def __init__(self):
|
39 |
+
self.acc = 0
|
40 |
+
self.tic()
|
41 |
+
|
42 |
+
def tic(self):
|
43 |
+
self.t0 = time.time()
|
44 |
+
|
45 |
+
def toc(self):
|
46 |
+
return time.time() - self.t0
|
47 |
+
|
48 |
+
def hold(self):
|
49 |
+
self.acc += self.toc()
|
50 |
+
|
51 |
+
def release(self):
|
52 |
+
ret = self.acc
|
53 |
+
self.acc = 0
|
54 |
+
|
55 |
+
return ret
|
56 |
+
|
57 |
+
def reset(self):
|
58 |
+
self.acc = 0
|
59 |
+
|
60 |
+
|
61 |
+
# argument parser type casting functions
|
62 |
+
def str2bool(val):
|
63 |
+
"""enable default constant true arguments"""
|
64 |
+
# https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
|
65 |
+
if isinstance(val, bool):
|
66 |
+
return val
|
67 |
+
elif val.lower() == 'true':
|
68 |
+
return True
|
69 |
+
elif val.lower() == 'false':
|
70 |
+
return False
|
71 |
+
else:
|
72 |
+
raise argparse.ArgumentTypeError('Boolean value expected')
|
73 |
+
|
74 |
+
def int2str(val):
|
75 |
+
"""convert int to str for environment variable related arguments"""
|
76 |
+
if isinstance(val, int):
|
77 |
+
return str(val)
|
78 |
+
elif isinstance(val, str):
|
79 |
+
return val
|
80 |
+
else:
|
81 |
+
raise argparse.ArgumentTypeError('number value expected')
|
82 |
+
|
83 |
+
|
84 |
+
# image saver using multiprocessing queue
|
85 |
+
class MultiSaver():
|
86 |
+
def __init__(self, result_dir=None):
|
87 |
+
self.queue = None
|
88 |
+
self.process = None
|
89 |
+
self.result_dir = result_dir
|
90 |
+
|
91 |
+
def begin_background(self):
|
92 |
+
self.queue = mp.Queue()
|
93 |
+
|
94 |
+
def t(queue):
|
95 |
+
while True:
|
96 |
+
if queue.empty():
|
97 |
+
continue
|
98 |
+
img, name = queue.get()
|
99 |
+
if name:
|
100 |
+
try:
|
101 |
+
basename, ext = os.path.splitext(name)
|
102 |
+
if ext != '.png':
|
103 |
+
name = '{}.png'.format(basename)
|
104 |
+
imageio.imwrite(name, img)
|
105 |
+
except Exception as e:
|
106 |
+
print(e)
|
107 |
+
else:
|
108 |
+
return
|
109 |
+
|
110 |
+
worker = lambda: mp.Process(target=t, args=(self.queue,), daemon=False)
|
111 |
+
cpu_count = min(8, mp.cpu_count() - 1)
|
112 |
+
self.process = [worker() for _ in range(cpu_count)]
|
113 |
+
for p in self.process:
|
114 |
+
p.start()
|
115 |
+
|
116 |
+
def end_background(self):
|
117 |
+
if self.queue is None:
|
118 |
+
return
|
119 |
+
|
120 |
+
for _ in self.process:
|
121 |
+
self.queue.put((None, None))
|
122 |
+
|
123 |
+
def join_background(self):
|
124 |
+
if self.queue is None:
|
125 |
+
return
|
126 |
+
|
127 |
+
while not self.queue.empty():
|
128 |
+
time.sleep(0.5)
|
129 |
+
|
130 |
+
for p in self.process:
|
131 |
+
p.join()
|
132 |
+
|
133 |
+
self.queue = None
|
134 |
+
|
135 |
+
def save_image(self, output, save_names, result_dir=None):
|
136 |
+
result_dir = result_dir if self.result_dir is None else self.result_dir
|
137 |
+
if result_dir is None:
|
138 |
+
raise Exception('no result dir specified!')
|
139 |
+
|
140 |
+
if self.queue is None:
|
141 |
+
try:
|
142 |
+
self.begin_background()
|
143 |
+
except Exception as e:
|
144 |
+
print(e)
|
145 |
+
return
|
146 |
+
|
147 |
+
# assume NCHW format
|
148 |
+
if output.ndim == 2:
|
149 |
+
output = output.expand([1, 1] + list(output.shape))
|
150 |
+
elif output.ndim == 3:
|
151 |
+
output = output.expand([1] + list(output.shape))
|
152 |
+
|
153 |
+
for output_img, save_name in zip(output, save_names):
|
154 |
+
# assume image range [0, 255]
|
155 |
+
output_img = output_img.add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
|
156 |
+
|
157 |
+
save_name = os.path.join(result_dir, save_name)
|
158 |
+
save_dir = os.path.dirname(save_name)
|
159 |
+
os.makedirs(save_dir, exist_ok=True)
|
160 |
+
|
161 |
+
self.queue.put((output_img, save_name))
|
162 |
+
|
163 |
+
return
|
164 |
+
|
165 |
+
class Map(dict):
|
166 |
+
"""
|
167 |
+
https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
|
168 |
+
Example:
|
169 |
+
m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
|
170 |
+
"""
|
171 |
+
def __init__(self, *args, **kwargs):
|
172 |
+
super(Map, self).__init__(*args, **kwargs)
|
173 |
+
for arg in args:
|
174 |
+
if isinstance(arg, dict):
|
175 |
+
for k, v in arg.items():
|
176 |
+
self[k] = v
|
177 |
+
|
178 |
+
if kwargs:
|
179 |
+
for k, v in kwargs.items():
|
180 |
+
self[k] = v
|
181 |
+
|
182 |
+
def __getattr__(self, attr):
|
183 |
+
return self.get(attr)
|
184 |
+
|
185 |
+
def __setattr__(self, key, value):
|
186 |
+
self.__setitem__(key, value)
|
187 |
+
|
188 |
+
def __setitem__(self, key, value):
|
189 |
+
super(Map, self).__setitem__(key, value)
|
190 |
+
self.__dict__.update({key: value})
|
191 |
+
|
192 |
+
def __delattr__(self, item):
|
193 |
+
self.__delitem__(item)
|
194 |
+
|
195 |
+
def __delitem__(self, key):
|
196 |
+
super(Map, self).__delitem__(key)
|
197 |
+
del self.__dict__[key]
|
198 |
+
|
199 |
+
def toDict(self):
|
200 |
+
return self.__dict__
|