# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
                         OptimizerHook, get_dist_info)
from mmcv.utils import digit_version

from mmpose.core import DistEvalHook, EvalHook, build_optimizers
from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper
from mmpose.datasets import build_dataloader, build_dataset
from mmpose.utils import get_root_logger

try:
    from mmcv.runner import Fp16OptimizerHook
except ImportError:
    warnings.warn(
        'Fp16OptimizerHook from mmpose will be deprecated from '
        'v0.15.0. Please install mmcv>=1.1.4', DeprecationWarning)
    from mmpose.core import Fp16OptimizerHook

# Entries of ``cfg.data`` that are copied into the dataloader settings when
# they are present in the config.
_DATA_LOADER_KEYS = (
    'samples_per_gpu',
    'workers_per_gpu',
    'shuffle',
    'seed',
    'drop_last',
    'prefetch_num',
    'pin_memory',
    'persistent_workers',
)


def init_random_seed(seed=None, device='cuda'):
    """Pick the random seed to use, sharing it across all ranks.

    An explicitly given seed is returned untouched. Otherwise a seed is
    drawn at random; when more than one process is running, the value
    drawn on rank 0 is broadcast so every process ends up with the same
    seed.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device on which the tensor used for the
            broadcast is created. Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # All ranks must agree on one seed, see
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    # Only the value held by rank 0 matters; the other ranks provide a
    # placeholder that the broadcast overwrites.
    payload = seed if rank == 0 else 0
    seed_tensor = torch.tensor(payload, dtype=torch.int32, device=device)
    dist.broadcast(seed_tensor, src=0)
    return seed_tensor.item()


def _resolve_optimizer_config(cfg, distributed, adversarial):
    """Choose the optimizer-hook setting handed to the runner.

    Args:
        cfg (dict): The config dict for training.
        distributed (bool): Whether distributed training is used.
        adversarial (bool): Whether adversarial training is used.

    Returns:
        None when adversarial training is used, a ``Fp16OptimizerHook``
        when ``cfg`` has an ``fp16`` entry, an ``OptimizerHook`` for
        distributed training whose ``cfg.optimizer_config`` has no
        ``type`` key, and ``cfg.optimizer_config`` itself otherwise.
    """
    if adversarial:
        # The optimizer step is part of the model's train_step in this
        # mode, so the runner must NOT get an optimizer hook.
        return None

    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        return Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    if distributed and 'type' not in cfg.optimizer_config:
        return OptimizerHook(**cfg.optimizer_config)
    return cfg.optimizer_config


def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """Entry point that builds the training pipeline and runs it.

    Builds the training dataloaders, wraps the model for (distributed)
    data-parallel execution, creates an ``EpochBasedRunner``, registers
    the training and optional evaluation hooks, restores a checkpoint if
    one is configured, and finally starts the run.

    Args:
        model (nn.Module): The model to be trained.
        dataset (Dataset): Train dataset. A list or tuple of datasets is
            also accepted; one dataloader is built per entry.
        cfg (dict): The config dict for training.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None
    """
    logger = get_root_logger(cfg.log_level)

    # ---- data loaders ----------------------------------------------------
    if not isinstance(dataset, (list, tuple)):
        dataset = [dataset]

    # Lowest priority: built-in defaults.
    loader_cfg = dict(
        seed=cfg.get('seed'),
        drop_last=False,
        dist=distributed,
        num_gpus=len(cfg.gpu_ids))
    if torch.__version__ == 'parrots':
        loader_cfg.update(prefetch_num=2, pin_memory=False)
    # Next: values given directly in cfg.data.
    loader_cfg.update(
        {key: cfg.data[key]
         for key in _DATA_LOADER_KEYS if key in cfg.data})
    # Highest priority: cfg.data.train_dataloader.
    train_overrides = cfg.data.get('train_dataloader', {})
    train_loader_cfg = dict(loader_cfg, **train_overrides)

    data_loaders = [
        build_dataloader(train_set, **train_loader_cfg)
        for train_set in dataset
    ]

    # Whether the adversarial training process is used.
    adversarial = cfg.get('use_adversarial_train', False)

    # ---- put model on gpus -----------------------------------------------
    if distributed:
        # Forwarded as the `find_unused_parameters` argument of the
        # distributed wrapper.
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        if adversarial:
            # Adversarial training uses DistributedDataParallelWrapper,
            # which receives the model as it is.
            wrapper_cls = DistributedDataParallelWrapper
            module = model
        else:
            wrapper_cls = MMDistributedDataParallel
            module = model.cuda()
        model = wrapper_cls(
            module,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    elif (digit_version(mmcv.__version__) >= digit_version('1.4.4')
          or torch.cuda.is_available()):
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
    else:
        # Old MMCV without CUDA: the model is left unwrapped.
        warnings.warn(
            'We recommend to use MMCV >= 1.4.4 for CPU training. '
            'See https://github.com/open-mmlab/mmpose/pull/1157 for '
            'details.')

    # ---- runner ------------------------------------------------------------
    optimizer = build_optimizers(model, cfg.optimizer)
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    optimizer_config = _resolve_optimizer_config(cfg, distributed,
                                                 adversarial)

    # ---- hooks -------------------------------------------------------------
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    if validate:
        eval_cfg = cfg.get('evaluation', {})
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_loader_cfg = dict(
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            drop_last=False,
            shuffle=False)
        # cfg.data.val_dataloader overrides the defaults above.
        val_loader_cfg.update(cfg.data.get('val_dataloader', {}))
        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
        hook_cls = DistEvalHook if distributed else EvalHook
        runner.register_hook(hook_cls(val_dataloader, **eval_cfg))

    # ---- checkpoint restore and launch ------------------------------------
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)