"""Encoding Data Parallel"""
import functools
import threading

import torch
import torch.cuda.comm as comm
from torch.autograd import Function, Variable
from torch.nn.parallel._functions import Broadcast, ReduceAddCoalesced
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var

torch_ver = torch.__version__[:3]

__all__ = [
    "allreduce",
    "DataParallelModel",
    "DataParallelCriterion",
    "patch_replication_callback",
]


def allreduce(*inputs):
    """Cross-GPU all-reduce autograd operation, used to calculate the mean and
    variance in SyncBN.
    """
    return AllReduce.apply(*inputs)


class AllReduce(Function):
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        # ``inputs`` is a flat sequence holding ``num_inputs`` tensors per GPU,
        # grouped GPU by GPU: (t0_gpu0, t1_gpu0, ..., t0_gpu1, t1_gpu1, ...).
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [
            inputs[i].get_device() for i in range(0, len(inputs), num_inputs)
        ]
        inputs = [inputs[i : i + num_inputs] for i in range(0, len(inputs), num_inputs)]
        # Sort the per-GPU groups by device id, sum them on the first device,
        # then broadcast the summed tensors back to every device.
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        # The gradient of an all-reduce is another all-reduce over the incoming
        # gradients, using the same per-GPU grouping as in forward.
        inputs = [i.data for i in inputs]
        inputs = [
            inputs[i : i + ctx.num_inputs]
            for i in range(0, len(inputs), ctx.num_inputs)
        ]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
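

# A minimal sketch (not part of the original module) of the ``allreduce``
# calling convention: ``num_inputs`` tensors per GPU are passed in a flat,
# GPU-major order, and the reduced tensors come back in the same layout.
# Assumption: at least two CUDA devices; tensor names are illustrative only.
def _allreduce_example():
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return  # the sketch only makes sense with multiple GPUs
    mean0 = torch.randn(8, device="cuda:0")
    var0 = torch.rand(8, device="cuda:0")
    mean1 = torch.randn(8, device="cuda:1")
    var1 = torch.rand(8, device="cuda:1")
    # Two tensors (mean, var) per GPU -> num_inputs=2.
    out = allreduce(2, mean0, var0, mean1, var1)
    # The layout is preserved: out[0:2] live on cuda:0, out[2:4] on cuda:1,
    # and each tensor now holds the cross-GPU sum.
    assert out[0].get_device() == 0 and out[2].get_device() == 1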


class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        # One tensor per GPU; sum them all onto the device of the first input.
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        # Broadcasting the incoming gradient back to every source device is
        # the adjoint of the reduce-add performed in forward.
        return Broadcast.apply(ctx.target_gpus, gradOutput)


class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices, chunking along the
    batch dimension.
    In the forward pass, the module is replicated on each device, and each
    replica handles a portion of the input. During the backward pass,
    gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; use the compatible
    :class:`encoding.parallel.DataParallelCriterion` to compute the loss.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (and each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """

    def gather(self, outputs, output_device):
        # Leave the outputs on their devices (one entry per GPU);
        # DataParallelCriterion consumes them in place.
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


class DataParallelCriterion(DataParallel):
    """
    Calculates the loss on multiple GPUs, which balances GPU memory usage for
    semantic segmentation.

    The targets are split across the specified devices by chunking along the
    batch dimension. Use it together with
    :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """

    def forward(self, inputs, *targets, **kwargs):
        # ``inputs`` is the per-GPU output tuple produced by DataParallelModel;
        # only the targets still need to be scattered here.
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
        # Average the per-GPU losses into a single value on the first device.
        return Reduce.apply(*outputs) / len(outputs)
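

# A minimal training-step sketch (not part of the original module) that
# combines DataParallelModel and DataParallelCriterion as described in the
# docstrings above. Assumptions: at least two CUDA devices; the Linear model,
# CrossEntropyLoss criterion, and random data are placeholders.
def _parallel_training_step_example():
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return
    model = torch.nn.Linear(16, 4).cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()
    net = DataParallelModel(model, device_ids=[0, 1])
    parallel_criterion = DataParallelCriterion(criterion, device_ids=[0, 1])

    images = torch.randn(8, 16).cuda()
    labels = torch.randint(0, 4, (8,)).cuda()

    outputs = net(images)                       # per-GPU outputs, not gathered
    loss = parallel_criterion(outputs, labels)  # targets scattered, losses averaged
    loss.backward()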


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        if torch_ver != "0.3":
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # Coerce the input to the same container type as the target so
                # that both can be unpacked into the criterion together.
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                if type(input) != type(target):
                    if isinstance(target, tuple):
                        input = tuple(input)
                    elif isinstance(target, list):
                        input = list(input)
                    else:
                        raise TypeError(
                            "Unsupported target type: {}".format(type(target))
                        )
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [
            threading.Thread(
                target=_worker,
                args=(i, module, input, target, kwargs, device),
            )
            for i, (module, input, target, kwargs, device) in enumerate(
                zip(modules, inputs, targets, kwargs_tup, devices)
            )
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs


class CallbackContext(object):
    """Per-sub-module context shared across replicas on different devices."""


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module
    created by the original replication.

    The callback will be invoked with arguments
    `__data_parallel_replicate__(ctx, copy_id)`.

    Note that, as all replicated modules are isomorphic, we assign each
    sub-module a context shared among the copies of that sub-module on
    different devices. Through this context, different copies can share
    information.

    We guarantee that the callback on the master copy (the first copy) is
    called before the callback on any slave copy.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, "__data_parallel_replicate__"):
                m.__data_parallel_replicate__(ctxs[j], i)
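

# A minimal sketch (not part of the original module) of a module that opts
# into the replication callback described above. The class name and stored
# attributes are illustrative assumptions, not an existing API.
class _CallbackAwareExampleModule(torch.nn.Module):
    def __data_parallel_replicate__(self, ctx, copy_id):
        # ``ctx`` is the CallbackContext shared by all replicas of this
        # sub-module; ``copy_id`` is the replica index (0 is the master copy).
        self._copy_id = copy_id
        self._shared_ctx = ctx

    def forward(self, x):
        return x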


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object to add the replication
    callback. Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
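

# A hedged usage sketch (not part of the original module) showing
# ``patch_replication_callback`` applied to the stock ``DataParallel``,
# using the illustrative _CallbackAwareExampleModule defined above.
# Assumption: at least two CUDA devices are available.
def _patch_replication_example():
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return
    module = _CallbackAwareExampleModule().cuda()
    dp = DataParallel(module, device_ids=[0, 1])
    patch_replication_callback(dp)
    # Every forward pass now replicates ``module`` and invokes
    # ``__data_parallel_replicate__(ctx, copy_id)`` on each replica.
    dp(torch.randn(8, 4).cuda())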