josedolot committed
Commit fd6307b · 1 Parent(s): 59410c2

Upload utils/sync_batchnorm/batchnorm.py

Files changed (1)
  1. utils/sync_batchnorm/batchnorm.py +394 -0
utils/sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,394 @@
+ # -*- coding: utf-8 -*-
+ # File : batchnorm.py
+ # Author : Jiayuan Mao
+ # Email : [email protected]
+ # Date : 27/01/2018
+ #
+ # This file is part of Synchronized-BatchNorm-PyTorch.
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+ # Distributed under MIT License.
+
+ import collections
+ import contextlib
+
+ import torch
+ import torch.nn.functional as F
+
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+ try:
+     from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
+ except ImportError:
+     ReduceAddCoalesced = Broadcast = None
+
+ try:
+     from jactorch.parallel.comm import SyncMaster
+     from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
+ except ImportError:
+     from .comm import SyncMaster
+     from .replicate import DataParallelWithCallback
+
+ __all__ = [
+     'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
+     'patch_sync_batchnorm', 'convert_model'
+ ]
+
+
+ def _sum_ft(tensor):
+     """sum over the first and last dimension"""
+     return tensor.sum(dim=0).sum(dim=-1)
+
+
+ def _unsqueeze_ft(tensor):
+     """add new dimensions at the front and the tail"""
+     return tensor.unsqueeze(0).unsqueeze(-1)
+
+
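As a quick illustration of these two helpers (a minimal sketch, not from the uploaded file): for a flattened activation of shape (B, C, L), `_sum_ft` collapses everything except the channel dimension, and `_unsqueeze_ft` makes a per-channel statistic broadcastable against the input again:

    import torch

    x = torch.randn(4, 8, 16)                               # (B, C, L), as produced by input.view(B, C, -1)
    per_channel = x.sum(dim=0).sum(dim=-1)                   # what _sum_ft computes -> shape (8,)
    broadcastable = per_channel.unsqueeze(0).unsqueeze(-1)   # what _unsqueeze_ft computes -> shape (1, 8, 1)
    assert per_channel.shape == (8,)
    assert broadcastable.shape == (1, 8, 1)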
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
+
+
+ class _SynchronizedBatchNorm(_BatchNorm):
+     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+         assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
+
+         super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+
+         self._sync_master = SyncMaster(self._data_parallel_master)
+
+         self._is_parallel = False
+         self._parallel_id = None
+         self._slave_pipe = None
+
+     def forward(self, input):
+         # If it is not a parallel computation or it is in evaluation mode, use PyTorch's implementation.
+         if not (self._is_parallel and self.training):
+             return F.batch_norm(
+                 input, self.running_mean, self.running_var, self.weight, self.bias,
+                 self.training, self.momentum, self.eps)
+
+         # Resize the input to (B, C, -1).
+         input_shape = input.size()
+         input = input.view(input.size(0), self.num_features, -1)
+
+         # Compute the sum and square-sum.
+         sum_size = input.size(0) * input.size(2)
+         input_sum = _sum_ft(input)
+         input_ssum = _sum_ft(input ** 2)
+
+         # Reduce-and-broadcast the statistics.
+         if self._parallel_id == 0:
+             mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
+         else:
+             mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
+
+         # Compute the output.
+         if self.affine:
+             # MJY:: Fuse the multiplication for speed.
+             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
+         else:
+             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
+
+         # Reshape it.
+         return output.view(input_shape)
+
+     def __data_parallel_replicate__(self, ctx, copy_id):
+         self._is_parallel = True
+         self._parallel_id = copy_id
+
+         # parallel_id == 0 means master device.
+         if self._parallel_id == 0:
+             ctx.sync_master = self._sync_master
+         else:
+             self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+     def _data_parallel_master(self, intermediates):
+         """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+
+         # Always using the same "device order" makes the ReduceAdd operation faster.
+         # Thanks to:: Tete Xiao (http://tetexiao.com/)
+         intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+         to_reduce = [i[1][:2] for i in intermediates]
+         to_reduce = [j for i in to_reduce for j in i]  # flatten
+         target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+         sum_size = sum([i[1].sum_size for i in intermediates])
+         sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+         mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+         broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+         outputs = []
+         for i, rec in enumerate(intermediates):
+             outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+         return outputs
+
+     def _compute_mean_std(self, sum_, ssum, size):
+         """Compute the mean and standard-deviation from the sum and square-sum. This method
+         also maintains the moving average on the master device."""
+         assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+         mean = sum_ / size
+         sumvar = ssum - sum_ * mean
+         unbias_var = sumvar / (size - 1)
+         bias_var = sumvar / size
+
+         if hasattr(torch, 'no_grad'):
+             with torch.no_grad():
+                 self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+                 self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+         else:
+             self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+             self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+         return mean, bias_var.clamp(self.eps) ** -0.5
+
+
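A small worked check of the statistics used in `_compute_mean_std` (a sketch on a toy (B, C, L) batch, not from the uploaded file): `ssum - sum * mean` equals the per-channel sum of squared deviations, so dividing by `size` gives the biased variance that is then clamped and inverted into `inv_std`:

    import torch

    x = torch.randn(6, 3, 5)                      # toy (B, C, L) activations
    size = x.size(0) * x.size(2)                  # sum_size, as computed in forward()
    s = x.sum(dim=0).sum(dim=-1)                  # per-channel sum (input_sum)
    ss = (x ** 2).sum(dim=0).sum(dim=-1)          # per-channel square-sum (input_ssum)
    mean = s / size
    bias_var = (ss - s * mean) / size             # E[x^2] - E[x]^2, per channel
    reference = x.transpose(0, 1).reshape(3, -1).var(dim=1, unbiased=False)
    assert torch.allclose(bias_var, reference, atol=1e-5)
    inv_std = bias_var.clamp(1e-5) ** -0.5        # clamp-then-rsqrt, as in the code above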
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
+     r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
+     mini-batch.
+
+     .. math::
+
+         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+     This module differs from the built-in PyTorch BatchNorm1d as the mean and
+     standard-deviation are reduced across all devices during training.
+
+     For example, when one uses `nn.DataParallel` to wrap the network during
+     training, PyTorch's implementation normalizes the tensor on each device using
+     the statistics only on that device, which accelerates the computation and
+     is also easy to implement, but the statistics might be inaccurate.
+     Instead, in this synchronized version, the statistics will be computed
+     over all training samples distributed on multiple devices.
+
+     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+     as the built-in PyTorch implementation.
+
+     The mean and standard-deviation are calculated per-dimension over
+     the mini-batches and gamma and beta are learnable parameter vectors
+     of size C (where C is the input size).
+
+     During training, this layer keeps a running estimate of its computed mean
+     and variance. The running sum is kept with a default momentum of 0.1.
+
+     During evaluation, this running mean/variance is used for normalization.
+
+     Because the BatchNorm is done over the `C` dimension, computing statistics
+     on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.
+
+     Args:
+         num_features: num_features from an expected input of size
+             `batch_size x num_features [x width]`
+         eps: a value added to the denominator for numerical stability.
+             Default: 1e-5
+         momentum: the value used for the running_mean and running_var
+             computation. Default: 0.1
+         affine: a boolean value that when set to ``True``, gives the layer learnable
+             affine parameters. Default: ``True``
+
+     Shape::
+         - Input: :math:`(N, C)` or :math:`(N, C, L)`
+         - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+     Examples:
+         >>> # With Learnable Parameters
+         >>> m = SynchronizedBatchNorm1d(100)
+         >>> # Without Learnable Parameters
+         >>> m = SynchronizedBatchNorm1d(100, affine=False)
+         >>> input = torch.autograd.Variable(torch.randn(20, 100))
+         >>> output = m(input)
+     """
+
+     def _check_input_dim(self, input):
+         if input.dim() != 2 and input.dim() != 3:
+             raise ValueError('expected 2D or 3D input (got {}D input)'
+                              .format(input.dim()))
+         super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
+
+
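Since the docstring notes that the single-device case falls back to PyTorch's own batch norm, here is a minimal equivalence sketch (assuming a plain CPU run; not from the uploaded file):

    import torch
    import torch.nn as nn

    sync_bn = SynchronizedBatchNorm1d(100)
    ref_bn = nn.BatchNorm1d(100)
    ref_bn.load_state_dict(sync_bn.state_dict())   # same weight/bias/running stats

    x = torch.randn(20, 100)
    sync_bn.train()
    ref_bn.train()
    assert torch.allclose(sync_bn(x), ref_bn(x), atol=1e-6)   # identical on a single device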
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
+     r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
+     of 3d inputs.
+
+     .. math::
+
+         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+     This module differs from the built-in PyTorch BatchNorm2d as the mean and
+     standard-deviation are reduced across all devices during training.
+
+     For example, when one uses `nn.DataParallel` to wrap the network during
+     training, PyTorch's implementation normalizes the tensor on each device using
+     the statistics only on that device, which accelerates the computation and
+     is also easy to implement, but the statistics might be inaccurate.
+     Instead, in this synchronized version, the statistics will be computed
+     over all training samples distributed on multiple devices.
+
+     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+     as the built-in PyTorch implementation.
+
+     The mean and standard-deviation are calculated per-dimension over
+     the mini-batches and gamma and beta are learnable parameter vectors
+     of size C (where C is the input size).
+
+     During training, this layer keeps a running estimate of its computed mean
+     and variance. The running sum is kept with a default momentum of 0.1.
+
+     During evaluation, this running mean/variance is used for normalization.
+
+     Because the BatchNorm is done over the `C` dimension, computing statistics
+     on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.
+
+     Args:
+         num_features: num_features from an expected input of
+             size batch_size x num_features x height x width
+         eps: a value added to the denominator for numerical stability.
+             Default: 1e-5
+         momentum: the value used for the running_mean and running_var
+             computation. Default: 0.1
+         affine: a boolean value that when set to ``True``, gives the layer learnable
+             affine parameters. Default: ``True``
+
+     Shape::
+         - Input: :math:`(N, C, H, W)`
+         - Output: :math:`(N, C, H, W)` (same shape as input)
+
+     Examples:
+         >>> # With Learnable Parameters
+         >>> m = SynchronizedBatchNorm2d(100)
+         >>> # Without Learnable Parameters
+         >>> m = SynchronizedBatchNorm2d(100, affine=False)
+         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
+         >>> output = m(input)
+     """
+
+     def _check_input_dim(self, input):
+         if input.dim() != 4:
+             raise ValueError('expected 4D input (got {}D input)'
+                              .format(input.dim()))
+         super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
+
+
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
+     r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
+     of 4d inputs.
+
+     .. math::
+
+         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+     This module differs from the built-in PyTorch BatchNorm3d as the mean and
+     standard-deviation are reduced across all devices during training.
+
+     For example, when one uses `nn.DataParallel` to wrap the network during
+     training, PyTorch's implementation normalizes the tensor on each device using
+     the statistics only on that device, which accelerates the computation and
+     is also easy to implement, but the statistics might be inaccurate.
+     Instead, in this synchronized version, the statistics will be computed
+     over all training samples distributed on multiple devices.
+
+     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+     as the built-in PyTorch implementation.
+
+     The mean and standard-deviation are calculated per-dimension over
+     the mini-batches and gamma and beta are learnable parameter vectors
+     of size C (where C is the input size).
+
+     During training, this layer keeps a running estimate of its computed mean
+     and variance. The running sum is kept with a default momentum of 0.1.
+
+     During evaluation, this running mean/variance is used for normalization.
+
+     Because the BatchNorm is done over the `C` dimension, computing statistics
+     on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
+     or Spatio-temporal BatchNorm.
+
+     Args:
+         num_features: num_features from an expected input of
+             size batch_size x num_features x depth x height x width
+         eps: a value added to the denominator for numerical stability.
+             Default: 1e-5
+         momentum: the value used for the running_mean and running_var
+             computation. Default: 0.1
+         affine: a boolean value that when set to ``True``, gives the layer learnable
+             affine parameters. Default: ``True``
+
+     Shape::
+         - Input: :math:`(N, C, D, H, W)`
+         - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+     Examples:
+         >>> # With Learnable Parameters
+         >>> m = SynchronizedBatchNorm3d(100)
+         >>> # Without Learnable Parameters
+         >>> m = SynchronizedBatchNorm3d(100, affine=False)
+         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+         >>> output = m(input)
+     """
+
+     def _check_input_dim(self, input):
+         if input.dim() != 5:
+             raise ValueError('expected 5D input (got {}D input)'
+                              .format(input.dim()))
+         super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
+
+
+ @contextlib.contextmanager
+ def patch_sync_batchnorm():
+     import torch.nn as nn
+
+     backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
+
+     nn.BatchNorm1d = SynchronizedBatchNorm1d
+     nn.BatchNorm2d = SynchronizedBatchNorm2d
+     nn.BatchNorm3d = SynchronizedBatchNorm3d
+
+     yield
+
+     nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
+
+
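A minimal usage sketch for the context manager above (the tiny Sequential model is an illustrative assumption, not from the uploaded file): any BatchNorm layer constructed inside the `with` block becomes its synchronized counterpart, and the original classes are restored on exit:

    import torch.nn as nn

    with patch_sync_batchnorm():
        model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

    assert isinstance(model[1], SynchronizedBatchNorm2d)   # patched while inside the block
    assert nn.BatchNorm2d is not SynchronizedBatchNorm2d   # restored after exit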
+ def convert_model(module):
+     """Traverse the input module and its children recursively
+     and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d
+     with SynchronizedBatchNorm*N*d.
+
+     Args:
+         module: the input module that needs to be converted to a SyncBN model
+
+     Examples:
+         >>> import torch.nn as nn
+         >>> import torchvision
+         >>> # m is a standard pytorch model
+         >>> m = torchvision.models.resnet18(True)
+         >>> m = nn.DataParallel(m)
+         >>> # after convert, m is using SyncBN
+         >>> m = convert_model(m)
+     """
+     if isinstance(module, torch.nn.DataParallel):
+         mod = module.module
+         mod = convert_model(mod)
+         mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
+         return mod
+
+     mod = module
+     for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
+                                         torch.nn.modules.batchnorm.BatchNorm2d,
+                                         torch.nn.modules.batchnorm.BatchNorm3d],
+                                        [SynchronizedBatchNorm1d,
+                                         SynchronizedBatchNorm2d,
+                                         SynchronizedBatchNorm3d]):
+         if isinstance(module, pth_module):
+             mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
+             mod.running_mean = module.running_mean
+             mod.running_var = module.running_var
+             if module.affine:
+                 mod.weight.data = module.weight.data.clone().detach()
+                 mod.bias.data = module.bias.data.clone().detach()
+
+     for name, child in module.named_children():
+         mod.add_module(name, convert_model(child))
+
+     return mod
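For context, a hedged end-to-end sketch of how `convert_model` is typically combined with data parallelism (assumes torchvision and two visible GPUs; not from the uploaded file):

    import torch
    import torch.nn as nn
    import torchvision

    m = torchvision.models.resnet18()            # a standard model with nn.BatchNorm2d layers
    m = nn.DataParallel(m, device_ids=[0, 1])
    m = convert_model(m)                         # now a DataParallelWithCallback over SyncBN layers
    m = m.cuda()

    x = torch.randn(8, 3, 224, 224).cuda()
    y = m(x)                                     # training-time statistics are reduced across both GPUs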