File size: 2,138 Bytes
8d015d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7a7cfe
8d015d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import os

import yaml
import torch
import torch.multiprocessing as mp

import utils
from utils.experiment import *
from trainer import Trainer


def parse_args(argv=None):
    """Parse the launcher's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with attributes ``cfg``, ``load_root``,
        ``save_root``, ``name``, ``tag``, ``cudnn``, ``port_offset`` and
        ``wandb_upload``.
    """
    parser = argparse.ArgumentParser()
    # --cfg is mandatory: make_cfg opens it unconditionally, so leaving it out
    # used to fail later with an opaque TypeError instead of a usage error.
    parser.add_argument('--cfg', required=True,
                        help='path to the YAML experiment config')
    parser.add_argument('--load-root', default='data',
                        help="value substituted for '$load_root$' in config strings")
    parser.add_argument('--save-root', default='save',
                        help='directory under which the run directory is created')
    parser.add_argument('--name', '-n', default=None,
                        help='experiment name (default: derived from the config file name)')
    parser.add_argument('--tag', default=None,
                        help='suffix appended to the experiment name')
    parser.add_argument('--cudnn', action='store_true',
                        help="stored as env['cudnn'] in the config")
    parser.add_argument('--port-offset', '-p', type=int, default=0,
                        help='offset added to base port 29600')
    parser.add_argument('--wandb-upload', '-w', action='store_true',
                        help="stored as env['wandb_upload'] in the config")
    return parser.parse_args(argv)


def make_cfg(args):
    """Load the YAML config named by ``args.cfg`` and attach runtime env settings.

    Every string value in the config, at any nesting depth (inside mappings
    and lists), has the placeholder ``$load_root$`` replaced by
    ``args.load_root``.

    Args:
        args: Parsed CLI namespace (as returned by ``parse_args``).

    Returns:
        The config dict with an added ``'env'`` sub-dict holding ``exp_name``,
        ``save_dir``, ``cudnn``, ``port`` (a string) and ``wandb_upload``.

    Raises:
        ValueError: if the YAML file's top level is not a mapping (e.g. the
            file is empty).
    """
    # NOTE(review): FullLoader can construct some Python-specific tags; if
    # configs could ever come from untrusted sources, switch to yaml.safe_load.
    with open(args.cfg, 'r', encoding='utf-8') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(cfg, dict):
        raise ValueError(
            f'config {args.cfg!r} must contain a YAML mapping, '
            f'got {type(cfg).__name__}')

    def translate_cfg_(node):
        # Walk mappings and sequences in place; lists are included so that
        # placeholders inside YAML sequences are substituted as well.
        if isinstance(node, dict):
            entries = node.items()
        elif isinstance(node, list):
            entries = enumerate(node)
        else:
            return
        for k, v in entries:
            if isinstance(v, str):
                # Reassigning an existing key/index does not resize the
                # container, so mutating while iterating is safe here.
                node[k] = v.replace('$load_root$', args.load_root)
            else:
                translate_cfg_(v)
    translate_cfg_(cfg)

    if args.name is None:
        # Default name: config file name up to its first '.', with the
        # '_benchmark' / '_demo' variant markers stripped.
        exp_name = os.path.basename(args.cfg).split('.')[0].replace('_benchmark', '').replace('_demo', '')
    else:
        exp_name = args.name
    if args.tag is not None:
        exp_name += '_' + args.tag

    full_name = exp_name + '_' + cfg['exp_name']
    cfg['env'] = {
        'exp_name': full_name,
        'save_dir': os.path.join(args.save_root, full_name),
        'cudnn': args.cudnn,
        # Base port 29600 shifted by --port-offset; presumably so several
        # runs on one host can use distinct ports.
        'port': str(29600 + args.port_offset),
        'wandb_upload': args.wandb_upload,
    }

    return cfg


def main():
    """Build the config from the CLI, initialise the run, and dispatch on mode.

    ``cfg['mode']`` selects the Trainer method to run: ``'train'``,
    ``'validate'``, ``'test'`` or ``'demo'``.

    Raises:
        ValueError: if ``cfg['mode']`` is not one of the supported modes
            (previously such a value made the script exit silently).
    """
    args = parse_args()

    cfgs = make_cfg(args)

    # The init_* helpers are presumably provided by
    # `from utils.experiment import *`; verify there.
    init_experiment(cfgs)
    init_distributed_mode(cfgs)
    init_deterministic(cfgs['seed'])

    trainer = Trainer(cfgs)

    valid_modes = ('train', 'validate', 'test', 'demo')
    mode = cfgs['mode']
    if mode not in valid_modes:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of {', '.join(valid_modes)}")
    # Look the method up lazily so only the selected one has to exist,
    # exactly as with the original if/elif chain.
    getattr(trainer, mode)()



if __name__ == '__main__':
    # Launch only when executed as a script, not when this module is imported.
    main()