Upload utils/sync_batchnorm/batchnorm.py
utils/sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,394 @@
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : [email protected]
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections
import contextlib

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm

try:
    from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
    ReduceAddCoalesced = Broadcast = None

try:
    from jactorch.parallel.comm import SyncMaster
    from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
    from .comm import SyncMaster
    from .replicate import DataParallelWithCallback

__all__ = [
    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
    'patch_sync_batchnorm', 'convert_model'
]


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])

class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5

class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it is common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)

class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it is common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)

class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it is common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)

@contextlib.contextmanager
def patch_sync_batchnorm():
    """Temporarily replace `torch.nn.BatchNorm1d/2d/3d` with the synchronized
    versions inside the `with` block, restoring the originals on exit."""
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    yield

    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup

def convert_model(module):
    """Traverse the input module and its children recursively
       and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d
       with SynchronizedBatchNorm*N*d

    Args:
        module: the input module to be converted to a SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
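
Example usage (not part of the uploaded file): a minimal sketch of the two entry points above, convert_model and patch_sync_batchnorm. It assumes the surrounding utils.sync_batchnorm package re-exports these helpers together with DataParallelWithCallback from replicate.py, that at least two CUDA devices are available, and it uses a torchvision model purely for illustration.

# Illustrative sketch only, not part of batchnorm.py. Assumes the package
# `utils.sync_batchnorm` (this directory) re-exports convert_model,
# patch_sync_batchnorm and DataParallelWithCallback, and that >= 2 GPUs exist.
import torch
import torch.nn as nn
import torchvision

from utils.sync_batchnorm import (DataParallelWithCallback, convert_model,
                                  patch_sync_batchnorm)

# Option 1: convert an existing model. convert_model swaps every BatchNorm*d
# for its Synchronized counterpart and re-wraps nn.DataParallel in
# DataParallelWithCallback so __data_parallel_replicate__ is invoked.
m = nn.DataParallel(torchvision.models.resnet18().cuda(), device_ids=[0, 1])
m = convert_model(m)

# Option 2: build the model while nn.BatchNorm*d is temporarily patched,
# then wrap it explicitly with DataParallelWithCallback.
with patch_sync_batchnorm():
    net = torchvision.models.resnet18()  # built with SynchronizedBatchNorm2d
net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])

# In training mode, each replica sends (sum, ssum, sum_size) to the master,
# which reduces them, updates the running statistics, and broadcasts
# (mean, inv_std) back (see _data_parallel_master above).
x = torch.randn(8, 3, 224, 224).cuda()
y = net(x)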