"""optional argument parsing"""

import argparse
import datetime
import os
import re
import shutil
import time

import torch
import torch.distributed as dist
import torch.backends.cudnn as cudnn

from utils import interact
from utils import str2bool, int2str

import template
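
# str2bool and int2str are helper type converters from utils (not shown here);
# presumably str2bool maps strings like 'true'/'false' to a bool (argparse's
# plain bool() would treat any non-empty string as True), and int2str yields
# the string form that os.environ['MASTER_PORT'] expects in setup() below.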

parser = argparse.ArgumentParser(description='Dynamic Scene Deblurring')

group_device = parser.add_argument_group('Device specs')
group_device.add_argument('--seed', type=int, default=-1, help='random seed (negative values draw a seed from the current time)')
group_device.add_argument('--num_workers', type=int, default=7, help='number of dataloader workers')
group_device.add_argument('--device_type', type=str, choices=('cpu', 'cuda'), default='cuda', help='device to run models on')
group_device.add_argument('--device_index', type=int, default=0, help='device id to run models on')
group_device.add_argument('--n_GPUs', type=int, default=1, help='number of GPUs for training')
group_device.add_argument('--distributed', type=str2bool, default=False, help='use DistributedDataParallel instead of DataParallel for better speed')
group_device.add_argument('--launched', type=str2bool, default=False, help='set automatically when main.py is executed from launch.py; do not set this manually')

group_device.add_argument('--master_addr', type=str, default='127.0.0.1', help='master address for distributed training')
group_device.add_argument('--master_port', type=int2str, default='8023', help='master port for distributed training')
group_device.add_argument('--dist_backend', type=str, default='nccl', help='distributed backend')
group_device.add_argument('--init_method', type=str, default='env://', help='distributed init method URL to discover peers')
group_device.add_argument('--rank', type=int, default=0, help='rank of the distributed process (gpu id); 0 is the master process')
group_device.add_argument('--world_size', type=int, default=1, help='world size for distributed training (number of GPUs)')

group_data = parser.add_argument_group('Data specs')
group_data.add_argument('--data_root', type=str, default='/data/ssd/public/czli/deblur', help='dataset root location')
group_data.add_argument('--dataset', type=str, default=None, help='dataset name applied to train/val/test at once; overrides the individual options if not None')
group_data.add_argument('--data_train', type=str, default='GOPRO_Large', help='training dataset name')
group_data.add_argument('--data_val', type=str, default=None, help='validation dataset name')
group_data.add_argument('--data_test', type=str, default='GOPRO_Large', help='test dataset name')
group_data.add_argument('--blur_key', type=str, default='blur_gamma', choices=('blur', 'blur_gamma'), help='blur type from the camera response function for the GOPRO_Large dataset')
group_data.add_argument('--rgb_range', type=int, default=255, help='maximum RGB pixel value (values range from 0)')

group_model = parser.add_argument_group('Model specs')
group_model.add_argument('--model', type=str, default='RecLamResNet', help='model architecture')
group_model.add_argument('--pretrained', type=str, default='', help='pretrained model location')
group_model.add_argument('--n_scales', type=int, default=5, help='multi-scale deblurring level')
group_model.add_argument('--detach', type=str2bool, default=False, help='detach gradients between recurrences')
group_model.add_argument('--gaussian_pyramid', type=str2bool, default=True, help='use Gaussian pyramid input/target')
group_model.add_argument('--n_resblocks', type=int, default=19, help='number of residual blocks per scale')
group_model.add_argument('--n_feats', type=int, default=64, help='number of feature maps')
group_model.add_argument('--kernel_size', type=int, default=5, help='size of conv kernel')
group_model.add_argument('--downsample', type=str, choices=('Gaussian', 'bicubic', 'stride'), default='Gaussian', help='input pyramid generation method')

group_model.add_argument('--precision', type=str, default='single', choices=('single', 'half'), help='FP precision for test (single | half)')

group_amp = parser.add_argument_group('AMP specs')
group_amp.add_argument('--amp', type=str2bool, default=False, help='use automatic mixed precision training')
group_amp.add_argument('--init_scale', type=float, default=1024., help='initial loss scale')

group_train = parser.add_argument_group('Training specs')
group_train.add_argument('--patch_size', type=int, default=256, help='training patch size')
group_train.add_argument('--batch_size', type=int, default=16, help='input batch size for training')
group_train.add_argument('--split_batch', type=int, default=1, help='split a minibatch into smaller chunks')
group_train.add_argument('--augment', type=str2bool, default=True, help='train with data augmentation')

group_test = parser.add_argument_group('Testing specs')
group_test.add_argument('--validate_every', type=int, default=10, help='validate every N epochs')
group_test.add_argument('--test_every', type=int, default=10, help='test every N epochs')

group_action = parser.add_argument_group('Source behavior')
group_action.add_argument('--do_train', type=str2bool, default=True, help='train the model')
group_action.add_argument('--do_validate', type=str2bool, default=True, help='validate the model')
group_action.add_argument('--do_test', type=str2bool, default=True, help='test the model')
group_action.add_argument('--demo', type=str2bool, default=False, help='run in demo (inference-only) mode')
group_action.add_argument('--demo_input_dir', type=str, default='', help='demo input directory')
group_action.add_argument('--demo_output_dir', type=str, default='', help='demo output directory')

group_optim = parser.add_argument_group('Optimization specs')
group_optim.add_argument('--lr', type=float, default=1e-4, help='learning rate')
group_optim.add_argument('--milestones', type=int, nargs='+', default=[500, 750, 900], help='epochs at which the learning rate decays')
group_optim.add_argument('--scheduler', default='step', choices=('step', 'plateau'), help='learning rate scheduler type')
group_optim.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay')
group_optim.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM', 'RMSprop'), help='optimizer to use (SGD | ADAM | RMSprop)')
group_optim.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
group_optim.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999), help='ADAM betas')
group_optim.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon')
group_optim.add_argument('--weight_decay', type=float, default=0, help='weight decay')
group_optim.add_argument('--clip', type=float, default=0, help='gradient clipping threshold (0: no clipping)')

group_loss = parser.add_argument_group('Loss specs')
group_loss.add_argument('--loss', type=str, default='1*MSE', help='loss function configuration')
group_loss.add_argument('--metric', type=str, default='PSNR,SSIM', help='metric function configuration, e.g. None | PSNR | SSIM | PSNR,SSIM')
# NOTE: '--gamma' is already registered in the optimizer group above, so adding
# it again would raise argparse.ArgumentError; the loss decay factor is exposed
# under a distinct flag here (downstream code should read args.loss_gamma).
group_loss.add_argument('--loss_gamma', type=float, default=0.6, help='loss decay factor (renamed from --gamma to avoid the optimizer flag)')
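# The '--loss' and '--metric' strings are parsed elsewhere (in the loss
# modules); the '1*MSE' default suggests a 'weight*TYPE' combination format,
# e.g. '1*MSE+0.5*L1', as in EDSR-style codebases.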

group_log = parser.add_argument_group('Logging specs')
group_log.add_argument('--save_dir', type=str, default='', help='subdirectory to save experiment logs')
group_log.add_argument('--start_epoch', type=int, default=-1, help='(re)starting epoch number')
group_log.add_argument('--end_epoch', type=int, default=1000, help='ending epoch number')
group_log.add_argument('--load_epoch', type=int, default=-1, help='epoch number to load the model from (start_epoch - 1 for training, start_epoch for testing)')
group_log.add_argument('--save_every', type=int, default=10, help='save model/optimizer every N epochs')
group_log.add_argument('--save_results', type=str, default='part', choices=('none', 'part', 'all'), help='save none/part/all of the result images')

group_debug = parser.add_argument_group('Debug specs')
group_debug.add_argument('--stay', type=str2bool, default=False, help='stay at the interactive console after trainer initialization')

parser.add_argument('--template', type=str, default='', help='argument template option')

args = parser.parse_args()
template.set_template(args)

args.data_root = os.path.expanduser(args.data_root)

now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
if args.save_dir == '':
    args.save_dir = now
args.save_dir = os.path.join('../experiment', args.save_dir)
os.makedirs(args.save_dir, exist_ok=True)
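# all experiment artifacts end up under ../experiment/<save_dir>, resolved
# relative to the working directory (presumably the source directory).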

if args.start_epoch < 0:
    # resume: find the latest checkpoint (files named like 'model-<epoch>...'
    # under <save_dir>/models; only the digits are parsed) and continue from it
    model_dir = os.path.join(args.save_dir, 'models')
    model_prefix = 'model-'
    if os.path.exists(model_dir):
        model_list = [name for name in os.listdir(model_dir) if name.startswith(model_prefix)]
        last_epoch = 0
        for name in model_list:
            epoch_number = int(re.findall(r'\d+', name)[0])
            if last_epoch < epoch_number:
                last_epoch = epoch_number

        args.start_epoch = last_epoch + 1
    else:
        # no checkpoints yet: train from scratch
        args.start_epoch = 1
elif args.start_epoch == 0:
    # an explicit 0 requests a fresh run: wipe the old save directory first
    if args.rank == 0:
        shutil.rmtree(args.save_dir, ignore_errors=True)
        os.makedirs(args.save_dir, exist_ok=True)
    args.start_epoch = 1

if args.load_epoch < 0:
    args.load_epoch = args.start_epoch - 1

if args.pretrained:
    if args.start_epoch <= 1:
        args.pretrained = os.path.join('../experiment', args.pretrained)
    else:
        print('starting from epoch {}! ignoring pretrained model path..'.format(args.start_epoch))
        args.pretrained = ''

if args.model == 'MSResNet':
    args.gaussian_pyramid = True

argname = os.path.join(args.save_dir, 'args.pt')
argname_txt = os.path.join(args.save_dir, 'args.txt')
if args.start_epoch > 1:
    # when resuming, keep data- and architecture-critical options consistent
    # with the previous run by restoring them from the saved arguments
    if os.path.exists(argname):
        args_old = torch.load(argname)

        load_list = []
        load_list += ['patch_size']
        load_list += ['batch_size']
        load_list += ['rgb_range']
        load_list += ['blur_key']
        load_list += ['n_scales']
        load_list += ['n_resblocks']
        load_list += ['n_feats']

        for arg_part in load_list:
            vars(args)[arg_part] = vars(args_old)[arg_part]

if args.dataset is not None:
    # a single --dataset overrides the train/val/test names; validation is
    # skipped for GOPRO_Large, presumably because it has no separate val split
    args.data_train = args.dataset
    args.data_val = args.dataset if args.dataset != 'GOPRO_Large' else None
    args.data_test = args.dataset

if args.data_val is None:
    args.do_validate = False

if args.demo_input_dir:
    args.demo = True

if args.demo:
    assert os.path.basename(args.save_dir) != now, 'You should specify the pretrained directory by setting --save_dir SAVE_DIR'

    # demo mode is inference-only: disable every dataset and training stage
    args.data_train = ''
    args.data_val = ''
    args.data_test = ''

    args.do_train = False
    args.do_validate = False
    args.do_test = False

    assert len(args.demo_input_dir) > 0, 'Please specify demo_input_dir!'
    args.demo_input_dir = os.path.expanduser(args.demo_input_dir)
    if args.demo_output_dir:
        args.demo_output_dir = os.path.expanduser(args.demo_output_dir)

    args.save_results = 'all'

if args.amp:
    # AMP manages mixed-precision casting itself, so keep the base precision
    # at float32 rather than manual half precision
    args.precision = 'single'

if args.seed < 0:
    args.seed = int(time.time())

if args.rank == 0:
    # record the arguments: a binary copy for resuming and a text log for humans
    torch.save(args, argname)
    with open(argname_txt, 'a') as file:
        file.write('execution at {}\n'.format(now))
        for key in args.__dict__:
            file.write(key + ': ' + str(args.__dict__[key]) + '\n')
        file.write('\n')

if args.device_type == 'cuda' and not torch.cuda.is_available():
    raise RuntimeError('GPU not available!')

if not args.distributed:
    args.rank = 0

def setup(args):
    """Select the device, seed the RNGs, and join the distributed process group if requested."""
    cudnn.benchmark = True

    if args.distributed:
        os.environ['MASTER_ADDR'] = args.master_addr
        # MASTER_PORT must be a string; the int2str argument type ensures this
        os.environ['MASTER_PORT'] = args.master_port

        args.device_index = args.rank
        args.world_size = args.n_GPUs

        dist.init_process_group(args.dist_backend, init_method=args.init_method, rank=args.rank, world_size=args.world_size)

    args.device = torch.device(args.device_type, args.device_index)
    args.dtype = torch.float32
    args.dtype_eval = torch.float32 if args.precision == 'single' else torch.float16

    torch.manual_seed(args.seed)
    if args.device_type == 'cuda':
        torch.cuda.set_device(args.device)
        if args.rank == 0:
            torch.cuda.manual_seed_all(args.seed)

    return args
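
# launch.py presumably spawns one process per GPU, passing a distinct --rank
# (and --launched True) to each; setup() then binds each process to the GPU
# matching its rank via args.device_index.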

def cleanup(args):
    if args.distributed:
        dist.destroy_process_group()
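
# Minimal usage sketch (illustrative only; main.py and launch.py are assumed
# to follow this pattern and are not defined in this file):
#
#   from option import args, setup, cleanup
#
#   args = setup(args)   # pick the device, seed RNGs, init the process group
#   ...                  # build dataloaders / model / trainer from args
#   cleanup(args)        # tear down the process group on exit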