import torch
import torch.nn as nn
from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from torch.cuda._utils import _get_device_index


@MODULE_WRAPPERS.register_module()
class DistributedDataParallelWrapper(nn.Module):
"""A DistributedDataParallel wrapper for models in 3D mesh estimation task. |
|
|
|
In some pieplines, there is a need to wrap different modules in |
|
the models with separate DistributedDataParallel. Otherwise, it will cause |
|
errors for GAN training. |
|
More specific, the GAN model, usually has two sub-modules: |
|
generator and discriminator. If we wrap both of them in one |
|
standard DistributedDataParallel, it will cause errors during training, |
|
because when we update the parameters of the generator (or discriminator), |
|
the parameters of the discriminator (or generator) is not updated, which is |
|
not allowed for DistributedDataParallel. |
|
So we design this wrapper to separately wrap DistributedDataParallel |
|
for generator and discriminator. |
|
In this wrapper, we perform two operations: |
|
1. Wrap the modules in the models with separate MMDistributedDataParallel. |
|
Note that only modules with parameters will be wrapped. |
|
2. Do scatter operation for 'forward', 'train_step' and 'val_step'. |
|
Note that the arguments of this wrapper is the same as those in |
|
`torch.nn.parallel.distributed.DistributedDataParallel`. |
|
Args: |
|
module (nn.Module): Module that needs to be wrapped. |
|
device_ids (list[int | `torch.device`]): Same as that in |
|
`torch.nn.parallel.distributed.DistributedDataParallel`. |
|
dim (int, optional): Same as that in the official scatter function in |
|
pytorch. Defaults to 0. |
|
broadcast_buffers (bool): Same as that in |
|
`torch.nn.parallel.distributed.DistributedDataParallel`. |
|
Defaults to False. |
|
find_unused_parameters (bool, optional): Same as that in |
|
`torch.nn.parallel.distributed.DistributedDataParallel`. |
|
Traverse the autograd graph of all tensors contained in returned |
|
value of the wrapped module’s forward function. Defaults to False. |
|
kwargs (dict): Other arguments used in |
|
`torch.nn.parallel.distributed.DistributedDataParallel`. |
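
    Example:
        A minimal usage sketch (it assumes ``torch.distributed`` has
        already been initialized, e.g. via ``init_process_group``;
        ``GANModel`` is a stand-in model with two trainable
        sub-modules)::

            import torch
            import torch.nn as nn

            class GANModel(nn.Module):
                def __init__(self):
                    super().__init__()
                    # Stand-in sub-modules; in practice these would be the
                    # real generator and discriminator networks.
                    self.generator = nn.Linear(8, 8)
                    self.discriminator = nn.Linear(8, 1)

                def forward(self, x):
                    return self.discriminator(self.generator(x))

            # Each trainable sub-module is wrapped with its own
            # MMDistributedDataParallel; inputs to forward/train_step/
            # val_step are scattered to the current CUDA device.
            model = DistributedDataParallelWrapper(
                GANModel(), device_ids=[torch.cuda.current_device()])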
    """

    def __init__(self,
                 module,
                 device_ids,
                 dim=0,
                 broadcast_buffers=False,
                 find_unused_parameters=False,
                 **kwargs):
        super().__init__()
        assert len(device_ids) == 1, (
            'Currently, DistributedDataParallelWrapper only supports a '
            'single CUDA device for each process. '
            f'The length of device_ids must be 1, but got {len(device_ids)}.')
        self.module = module
        self.dim = dim
        self.to_ddp(device_ids=device_ids,
                    dim=dim,
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
        self.output_device = _get_device_index(device_ids[0], True)

    def to_ddp(self, device_ids, dim, broadcast_buffers,
               find_unused_parameters, **kwargs):
        """Wrap models with separate MMDistributedDataParallel.

        It only wraps the modules with parameters.
        """
        for name, module in self.module._modules.items():
            if next(module.parameters(), None) is None:
                # The sub-module has no parameters; only move it to GPU.
                module = module.cuda()
            elif all(not p.requires_grad for p in module.parameters()):
                # All parameters are frozen; no gradient synchronization is
                # needed, so only move the sub-module to GPU.
                module = module.cuda()
            else:
                # Wrap each trainable sub-module with its own
                # MMDistributedDataParallel.
                module = MMDistributedDataParallel(
                    module.cuda(),
                    device_ids=device_ids,
                    dim=dim,
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
            self.module._modules[name] = module

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
            device_ids (list[int]): Device ids to which the inputs are
                scattered.
        """
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """Forward function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])

    def train_step(self, *inputs, **kwargs):
        """Train step function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.train_step(*inputs[0], **kwargs[0])
        return output

    def val_step(self, *inputs, **kwargs):
        """Validation step function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.val_step(*inputs[0], **kwargs[0])
        return output