Spaces:
Runtime error
Runtime error
Upload utils/sync_batchnorm/replicate.py
Browse files
utils/sync_batchnorm/replicate.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# File : replicate.py
|
| 3 |
+
# Author : Jiayuan Mao
|
| 4 |
+
# Email : [email protected]
|
| 5 |
+
# Date : 27/01/2018
|
| 6 |
+
#
|
| 7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
| 8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
| 9 |
+
# Distributed under MIT License.
|
| 10 |
+
|
| 11 |
+
import functools
|
| 12 |
+
|
| 13 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
| 14 |
+
|
| 15 |
+
# Names exported by `from <this module> import *`.
__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback',
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CallbackContext(object):
    """
    Empty attribute container handed to `__data_parallel_replicate__` callbacks.

    The class defines nothing itself; callbacks are free to set arbitrary attributes on an instance.
    """
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def execute_replication_callbacks(modules):
    """
    Execute the replication callback `__data_parallel_replicate__` on each module created by original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all replicas are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.

    Args:
        modules: sequence of replicas. `modules[0]` is treated as the master copy, and every element must
            provide a `modules()` method yielding its sub-modules. An empty sequence is a no-op.
    """
    # Nothing was replicated (e.g. an empty device list): there is no master copy to inspect.
    if not modules:
        return

    master_copy = modules[0]
    # Count the sub-modules without materializing them into a temporary list.
    nr_modules = sum(1 for _ in master_copy.modules())
    # One context per sub-module position; the same context is passed to that position in every replica.
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    # Replicas are visited in order, so copy 0 (the master) always gets its callbacks first.
    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class DataParallelWithCallback(DataParallel):
    """
    `DataParallel` variant that runs replication callbacks.

    After the inherited `replicate` has produced the replicas, the `__data_parallel_replicate__`
    callback of every sub-module that defines one is invoked as `__data_parallel_replicate__(ctx, copy_id)`.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """Replicate `module` via the parent class, run the callbacks on the replicas, and return them."""
        replicas = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have customized `DataParallel` implementation.

    Patching is idempotent: calling this function again on an object that it has already patched
    leaves the object untouched, so the callbacks run only once per replication.

    Args:
        data_parallel: the `DataParallel` instance whose `replicate` attribute is replaced in place.

    Raises:
        AssertionError: if `data_parallel` is not a `DataParallel` instance.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel), \
        'patch_replication_callback expects a DataParallel instance'

    old_replicate = data_parallel.replicate

    # Already wrapped by an earlier call: wrapping again would execute every callback twice.
    if getattr(old_replicate, '_replication_callback_patched', False):
        return

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    # Marker checked above on subsequent calls.
    new_replicate._replication_callback_patched = True
    data_parallel.replicate = new_replicate
|