import argparse
import datetime
import importlib
import os
from shutil import copyfile

import torch
import torch.distributed as dist
from omegaconf import OmegaConf

from utils.dist_utils import get_world_size
from utils.utils import seed_all
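
# Training entry point: parse the CLI arguments, load the OmegaConf config,
# and hand off to a Trainer class selected by config['trainer_type'].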
parser = argparse.ArgumentParser(description='VFI')
parser.add_argument('-c', '--config', type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
parser.add_argument('--local_rank', default='0')

args = parser.parse_args()
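# Note: parsing happens at import time, so `args` is a module-level global
# that main_worker() below reads directly.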


def main_worker(rank, config):
    # Mirror the process rank into the config so downstream components can
    # tell which process they run in.
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank
    if torch.cuda.is_available():
        print(f'Rank {rank}: CUDA is available')
        config['device'] = f'cuda:{rank}'
        if config['distributed']:
            # NCCL with the default env:// rendezvous: MASTER_ADDR, MASTER_PORT,
            # RANK and WORLD_SIZE are expected to come from the launcher.
            dist.init_process_group(backend='nccl',
                                    timeout=datetime.timedelta(seconds=5400))
    else:
        config['device'] = 'cpu'
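
    # Build a unique experiment directory, <save_dir>/<config name>_<exp_name>,
    # with a ckpts/ subfolder and a snapshot of the config for reproducibility.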
    cfg_name = os.path.basename(args.config).split('.')[0]
    config['exp_name'] = cfg_name + '_' + config['exp_name']
    config['save_dir'] = os.path.join(config['save_dir'], config['exp_name'])

    if (not config['distributed']) or rank == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        os.makedirs(f'{config["save_dir"]}/ckpts', exist_ok=True)
        config_path = os.path.join(config['save_dir'],
                                   os.path.basename(args.config))
        if not os.path.isfile(config_path):
            copyfile(args.config, config_path)
        print('[**] create folder {}'.format(config['save_dir']))
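
    # Trainers are plug-ins: trainers/<trainer_type>.py must expose a Trainer
    # class; 'base_trainer' is the default when the config names none.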
    trainer_name = config.get('trainer_type', 'base_trainer')
    print(f'using GPU {rank} for training')
    if rank == 0:
        print(trainer_name)
    trainer_pack = importlib.import_module('trainers.' + trainer_name)
    trainer = trainer_pack.Trainer(config)

    trainer.train()


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    cfg = OmegaConf.load(args.config)
    seed_all(cfg.seed)
    rank = int(args.local_rank)
    # -p/--port was parsed but never used; feed it to the env:// rendezvous
    # (a launcher-provided MASTER_PORT, if any, takes precedence).
    os.environ.setdefault('MASTER_PORT', args.port)
    # Only pin a CUDA device when one exists; main_worker falls back to CPU.
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    cfg['world_size'] = get_world_size()
    cfg['local_rank'] = rank
    if rank == 0:
        print('world_size:', cfg['world_size'])
    main_worker(rank, cfg)
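
# Example launch (a sketch, not from the original repo): assuming this file is
# train.py and training runs single-node multi-GPU via torch.distributed.launch,
# which supplies --local_rank and the MASTER_ADDR/RANK/WORLD_SIZE environment
# variables, with a config defining `seed`, `distributed`, `exp_name`, `save_dir`:
#
#   python -m torch.distributed.launch --nproc_per_node=4 train.py \
#       -c configs/train.yaml -p 23455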