HaMeR / mmpose /apis /train.py
geopavlakos's picture
Initial commit
d7a991a
raw
history blame
7.37 kB
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
get_dist_info)
from mmcv.utils import digit_version
from mmpose.core import DistEvalHook, EvalHook, build_optimizers
from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper
from mmpose.datasets import build_dataloader, build_dataset
from mmpose.utils import get_root_logger
try:
from mmcv.runner import Fp16OptimizerHook
except ImportError:
warnings.warn(
'Fp16OptimizerHook from mmpose will be deprecated from '
'v0.15.0. Please install mmcv>=1.1.4', DeprecationWarning)
from mmpose.core import Fp16OptimizerHook
def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, a seed is drawn at random. In a multi-process
    run the value drawn on rank 0 is broadcast to every other rank so that
    all processes end up using the same seed.

    Args:
        seed (int, Optional): The seed. When given (i.e. not None) it is
            returned unchanged. Default to None.
        device (str): The device on which the temporary tensor used for
            the broadcast is created. Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()

    # Ask for int64 explicitly: with NumPy's platform-default integer type
    # being 32-bit (e.g. Windows with NumPy < 2.0) the exclusive upper bound
    # 2**31 is rejected with "high is out of bounds for int32". The drawn
    # value still lies in [0, 2**31), so it fits the int32 tensor below.
    # int() keeps the return type a plain Python int.
    seed = int(np.random.randint(2**31, dtype=np.int64))
    if world_size == 1:
        return seed

    # Only rank 0 contributes its value; the other ranks pass a placeholder
    # that is overwritten by the broadcast.
    payload = seed if rank == 0 else 0
    random_num = torch.tensor(payload, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """Train model entry function.

    Builds the training data loaders, wraps the model for (distributed)
    data-parallel execution, creates an ``EpochBasedRunner``, registers the
    training / evaluation hooks and finally starts the run.

    Args:
        model (nn.Module): The model to be trained.
        dataset (Dataset): Train dataset. A single dataset or a list/tuple
            of datasets; one data loader is built per dataset.
        cfg (dict): The config dict for training.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None
    """
    logger = get_root_logger(cfg.log_level)

    # ------------------------------------------------------------------
    # Data loaders
    # ------------------------------------------------------------------
    if not isinstance(dataset, (list, tuple)):
        dataset = [dataset]

    # Step 1: start from the defaults ...
    loader_cfg = dict(
        seed=cfg.get('seed'),
        drop_last=False,
        dist=distributed,
        num_gpus=len(cfg.gpu_ids))
    # ... add the extra defaults used when running under parrots ...
    if torch.__version__ == 'parrots':
        loader_cfg.update(prefetch_num=2, pin_memory=False)
    # ... and let the recognised keys present in cfg.data override them.
    overridable_keys = (
        'samples_per_gpu',
        'workers_per_gpu',
        'shuffle',
        'seed',
        'drop_last',
        'prefetch_num',
        'pin_memory',
        'persistent_workers',
    )
    for key in overridable_keys:
        if key in cfg.data:
            loader_cfg[key] = cfg.data[key]

    # Step 2: cfg.data.train_dataloader has the highest priority.
    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
    data_loaders = [
        build_dataloader(train_set, **train_loader_cfg)
        for train_set in dataset
    ]

    # Whether the adversarial training process is used.
    adversarial = cfg.get('use_adversarial_train', False)

    # ------------------------------------------------------------------
    # Put the model on GPUs
    # ------------------------------------------------------------------
    if not distributed:
        cpu_capable_mmcv = (
            digit_version(mmcv.__version__) >= digit_version('1.4.4'))
        if cpu_capable_mmcv or torch.cuda.is_available():
            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
        else:
            # The model is left unwrapped in this case; only a warning is
            # issued.
            warnings.warn(
                'We recommend to use MMCV >= 1.4.4 for CPU training. '
                'See https://github.com/open-mmlab/mmpose/pull/1157 for '
                'details.')
    else:
        # Forwarded as the `find_unused_parameters` argument of the
        # distributed wrapper.
        find_unused = cfg.get('find_unused_parameters', False)
        if adversarial:
            # Adversarial training uses DistributedDataParallelWrapper,
            # which receives the model as-is.
            wrapper_cls, module = DistributedDataParallelWrapper, model
        else:
            wrapper_cls, module = MMDistributedDataParallel, model.cuda()
        model = wrapper_cls(
            module,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused)

    # ------------------------------------------------------------------
    # Runner
    # ------------------------------------------------------------------
    runner = EpochBasedRunner(
        model,
        optimizer=build_optimizers(model, cfg.optimizer),
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # With adversarial training the optimizer step happens inside the
    # model's train_step, so the runner must NOT get an optimizer hook
    # (optimizer_config stays None).
    optimizer_config = None
    if not adversarial:
        fp16_cfg = cfg.get('fp16', None)
        if fp16_cfg is not None:
            # fp16 setting
            optimizer_config = Fp16OptimizerHook(
                **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
        elif distributed and 'type' not in cfg.optimizer_config:
            optimizer_config = OptimizerHook(**cfg.optimizer_config)
        else:
            optimizer_config = cfg.optimizer_config

    # ------------------------------------------------------------------
    # Hooks
    # ------------------------------------------------------------------
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    if validate:
        eval_cfg = cfg.get('evaluation', {})
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_loader_cfg = dict(
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            drop_last=False,
            shuffle=False)
        # cfg.data.val_dataloader overrides the defaults above.
        val_loader_cfg = dict(val_loader_cfg,
                              **cfg.data.get('val_dataloader', {}))
        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
        hook_cls = DistEvalHook if distributed else EvalHook
        runner.register_hook(hook_cls(val_dataloader, **eval_cfg))

    # ------------------------------------------------------------------
    # Checkpoint restore and launch; resuming takes precedence over
    # loading weights.
    # ------------------------------------------------------------------
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)