# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
                         get_dist_info)
from mmcv.utils import digit_version

from mmpose.core import DistEvalHook, EvalHook, build_optimizers
from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper
from mmpose.datasets import build_dataloader, build_dataset
from mmpose.utils import get_root_logger

try:
    from mmcv.runner import Fp16OptimizerHook
except ImportError:
    warnings.warn(
        'Fp16OptimizerHook from mmpose will be deprecated from '
        'v0.15.0. Please install mmcv>=1.1.4', DeprecationWarning)
    from mmpose.core import Fp16OptimizerHook


def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, a seed is generated randomly and then broadcast
    to all processes so that every process uses the same seed, which prevents
    some potential bugs.

    Args:
        seed (int, optional): The seed. Defaults to None.
        device (str): The device where the seed tensor is placed for
            broadcasting. Defaults to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
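

# Usage sketch (illustrative only; mirrors how a training script such as
# tools/train.py typically drives this helper -- ``args``, ``cfg``, ``meta``
# and the companion ``set_random_seed`` helper are assumed to exist in that
# calling script, they are not defined in this module):
#   seed = init_random_seed(args.seed)
#   set_random_seed(seed, deterministic=args.deterministic)
#   cfg.seed = seed
#   meta['seed'] = seed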


def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """Train model entry function.

    Args:
        model (nn.Module): The model to be trained.
        dataset (Dataset): Train dataset.
        cfg (dict): The config dict for training.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    # step 1: set default values and override them with cfg.data (if present)
    loader_cfg = {
        **dict(
            seed=cfg.get('seed'),
            drop_last=False,
            dist=distributed,
            num_gpus=len(cfg.gpu_ids)),
        **({} if torch.__version__ != 'parrots' else dict(
            prefetch_num=2,
            pin_memory=False,
        )),
        **dict((k, cfg.data[k]) for k in [
            'samples_per_gpu',
            'workers_per_gpu',
            'shuffle',
            'seed',
            'drop_last',
            'prefetch_num',
            'pin_memory',
            'persistent_workers',
        ] if k in cfg.data)
    }

    # step 2: cfg.data.train_dataloader has highest priority
    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))

    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
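    # Illustrative precedence example (config keys assumed): with
    #   cfg.data = dict(samples_per_gpu=64, workers_per_gpu=2,
    #                   train_dataloader=dict(samples_per_gpu=32), ...)
    # the effective train batch size per GPU is 32, because keys in
    # cfg.data.train_dataloader override the shared cfg.data defaults above.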

    # determine whether to use adversarial training or not
    use_adverserial_train = cfg.get('use_adversarial_train', False)

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        if use_adverserial_train:
            # Use DistributedDataParallelWrapper for adversarial training
            model = DistributedDataParallelWrapper(
                model,
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        if digit_version(mmcv.__version__) >= digit_version(
                '1.4.4') or torch.cuda.is_available():
            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
        else:
            warnings.warn(
                'We recommend to use MMCV >= 1.4.4 for CPU training. '
                'See https://github.com/open-mmlab/mmpose/pull/1157 for '
                'details.')
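    # Note (illustrative, config field assumed): in the non-distributed
    # branch ``cfg.gpu_ids`` is typically ``range(1)`` for single-GPU
    # training, so the model is wrapped by MMDataParallel on device 0; with
    # mmcv < 1.4.4 and no CUDA available, the model is left unwrapped and
    # training falls back to plain CPU execution.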

    # build runner
    optimizer = build_optimizers(model, cfg.optimizer)

    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    if use_adverserial_train:
        # The optimizer step is handled inside the model's train_step
        # function, so the runner should NOT register an optimizer hook.
        optimizer_config = None
    else:
        # fp16 setting
        fp16_cfg = cfg.get('fp16', None)
        if fp16_cfg is not None:
            optimizer_config = Fp16OptimizerHook(
                **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
        elif distributed and 'type' not in cfg.optimizer_config:
            optimizer_config = OptimizerHook(**cfg.optimizer_config)
        else:
            optimizer_config = cfg.optimizer_config
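    # Illustrative config snippets (keys assumed) that select the branches
    # above:
    #   optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
    #   fp16 = dict(loss_scale='dynamic')  # routes through Fp16OptimizerHook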

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        eval_cfg = cfg.get('evaluation', {})
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        dataloader_setting = dict(
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            drop_last=False,
            shuffle=False)
        dataloader_setting = dict(dataloader_setting,
                                  **cfg.data.get('val_dataloader', {}))
        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
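        # Illustrative evaluation config (keys assumed):
        #   evaluation = dict(interval=10, metric='mAP', save_best='AP')
        # ``interval`` and ``save_best`` are consumed by (Dist)EvalHook, and
        # remaining kwargs such as ``metric`` are forwarded to the dataset's
        # ``evaluate`` method.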

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
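

# Usage sketch (illustrative only; mirrors the way a tools/train.py-style
# script calls this entry point -- ``cfg``, ``args``, ``distributed``,
# ``timestamp`` and ``meta`` are assumed to be prepared by that script, and
# ``build_posenet`` comes from mmpose.models):
#   datasets = [build_dataset(cfg.data.train)]
#   model = build_posenet(cfg.model)
#   train_model(
#       model, datasets, cfg,
#       distributed=distributed,
#       validate=(not args.no_validate),
#       timestamp=timestamp,
#       meta=meta)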