ychenhq's picture
Upload folder using huggingface_hub
04fbff5 verified
raw
history blame
2.31 kB
import os
import argparse
from shutil import copyfile
import torch.distributed as dist
import torch
import importlib
import datetime
from utils.dist_utils import (
get_world_size,
)
from omegaconf import OmegaConf
from utils.utils import seed_all
# Command-line interface. Parsing happens at module level so that ``args`` is
# available as a global to ``main_worker`` and the script entry point below.
parser = argparse.ArgumentParser(description='VFI')
parser.add_argument('-c', '--config', type=str,
                    help='path to the experiment configuration file')
# NOTE(review): the port value is not read anywhere in the visible part of
# this file; it is kept for command-line compatibility.
parser.add_argument('-p', '--port', default='23455', type=str,
                    help='port value, kept as a string')
# ``torch.distributed.launch`` passes the rank as ``--local_rank`` (older
# PyTorch) or ``--local-rank`` (PyTorch >= 2.0), while ``torchrun`` only
# exports the LOCAL_RANK environment variable. Accept all three so the script
# works under either launcher; an explicit flag still overrides the
# environment. The value stays a string and is converted with ``int`` by the
# entry point.
parser.add_argument('--local_rank', '--local-rank', dest='local_rank',
                    default=os.environ.get('LOCAL_RANK', '0'),
                    help='local rank of this process '
                         '(default: $LOCAL_RANK, or 0 when unset)')
args = parser.parse_args()
def main_worker(rank, config):
    """Prepare the run configuration for one process and launch training.

    Fills in rank and device entries, initialises the NCCL process group for
    distributed CUDA runs, derives the experiment name and output directory,
    copies the config file into that directory, then imports
    ``trainers.<trainer_type>`` and runs its ``Trainer``.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        config: Mutable mapping of settings, modified in place. The keys
            ``distributed``, ``exp_name`` and ``save_dir`` are read here;
            ``trainer_type`` is optional and defaults to ``'base_trainer'``.

    Note:
        The path of the config file is taken from the module-level ``args``.
    """
    # NOTE(review): when the caller already supplied 'local_rank',
    # 'global_rank' is left untouched (and may be absent) - confirm that the
    # trainer does not rely on it in that case.
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank

    if torch.cuda.is_available():
        print(f'Rank {rank} is available')
        config['device'] = f"cuda:{rank}"
        if config['distributed']:
            # 5400 s = 90 minutes before the process group gives up.
            dist.init_process_group(backend='nccl',
                                    timeout=datetime.timedelta(seconds=5400))
    else:
        config['device'] = 'cpu'

    # Derive the file name once with os.path so that the experiment name and
    # the copied config agree, and so paths using the platform's separator
    # are handled (the original mixed basename() with split('/')).
    config_file = os.path.basename(args.config)
    cfg_name = config_file.split('.')[0]
    config['exp_name'] = cfg_name + '_' + config['exp_name']
    config['save_dir'] = os.path.join(config['save_dir'], config['exp_name'])

    # Only one process touches the filesystem: every process in a
    # non-distributed run, otherwise rank 0 alone.
    if (not config['distributed']) or rank == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        os.makedirs(f'{config["save_dir"]}/ckpts', exist_ok=True)
        config_path = os.path.join(config['save_dir'], config_file)
        # Keep an existing copy rather than overwriting it.
        if not os.path.isfile(config_path):
            copyfile(args.config, config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    trainer_name = config.get('trainer_type', 'base_trainer')
    print(f'using GPU {rank} for training')
    if rank == 0:
        print(trainer_name)

    # The trainer implementation is selected by name at run time.
    trainer_pack = importlib.import_module('trainers.' + trainer_name)
    trainer = trainer_pack.Trainer(config)
    trainer.train()
if __name__ == "__main__":
    # Script entry point: load the config, seed, bind the device and hand
    # over to main_worker for this process's rank.
    torch.backends.cudnn.benchmark = True
    cfg = OmegaConf.load(args.config)
    # seed_all is a project helper from utils.utils; it receives the seed
    # stored in the config.
    seed_all(cfg.seed)
    rank = int(args.local_rank)
    # Bind a CUDA device only when one exists. main_worker falls back to
    # 'cpu' in that case, whereas an unconditional set_device call would
    # raise before the fallback is ever reached.
    if torch.cuda.is_available():
        torch.cuda.set_device(torch.device(f'cuda:{rank}'))
    # setting distributed configurations
    cfg['world_size'] = get_world_size()
    cfg['local_rank'] = rank
    if rank == 0:
        print('world_size: ', cfg['world_size'])
    main_worker(rank, cfg)