Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# File : batchnorm.py | |
# Author : Jiayuan Mao | |
# Email : [email protected] | |
# Date : 27/01/2018 | |
# | |
# This file is part of Synchronized-BatchNorm-PyTorch. | |
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch | |
# Distributed under MIT License. | |
import collections | |
import torch | |
import torch.nn.functional as F | |
from torch.nn.modules.batchnorm import _BatchNorm | |
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast | |
from .comm import SyncMaster | |
# Public API of this module: the three dimension-specific synchronized BN layers.
__all__ = [
    'SynchronizedBatchNorm1d',
    'SynchronizedBatchNorm2d',
    'SynchronizedBatchNorm3d',
]
def _sum_ft(tensor): | |
"""sum over the first and last dimention""" | |
return tensor.sum(dim=0).sum(dim=-1) | |
def _unsqueeze_ft(tensor): | |
"""add new dementions at the front and the tail""" | |
return tensor.unsqueeze(0).unsqueeze(-1) | |
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) | |
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) | |
# _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size']) | |
class _SynchronizedBatchNorm(_BatchNorm): | |
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): | |
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) | |
self._sync_master = SyncMaster(self._data_parallel_master) | |
self._is_parallel = False | |
self._parallel_id = None | |
self._slave_pipe = None | |
def forward(self, input, gain=None, bias=None): | |
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. | |
if not (self._is_parallel and self.training): | |
out = F.batch_norm( | |
input, self.running_mean, self.running_var, self.weight, self.bias, | |
self.training, self.momentum, self.eps) | |
if gain is not None: | |
out = out + gain | |
if bias is not None: | |
out = out + bias | |
return out | |
# Resize the input to (B, C, -1). | |
input_shape = input.size() | |
# print(input_shape) | |
input = input.view(input.size(0), input.size(1), -1) | |
# Compute the sum and square-sum. | |
sum_size = input.size(0) * input.size(2) | |
input_sum = _sum_ft(input) | |
input_ssum = _sum_ft(input ** 2) | |
# Reduce-and-broadcast the statistics. | |
# print('it begins') | |
if self._parallel_id == 0: | |
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) | |
else: | |
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) | |
# if self._parallel_id == 0: | |
# # print('here') | |
# sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) | |
# else: | |
# # print('there') | |
# sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) | |
# print('how2') | |
# num = sum_size | |
# print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu()))) | |
# Fix the graph | |
# sum = (sum.detach() - input_sum.detach()) + input_sum | |
# ssum = (ssum.detach() - input_ssum.detach()) + input_ssum | |
# mean = sum / num | |
# var = ssum / num - mean ** 2 | |
# # var = (ssum - mean * sum) / num | |
# inv_std = torch.rsqrt(var + self.eps) | |
# Compute the output. | |
if gain is not None: | |
# print('gaining') | |
# scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1) | |
# shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1) | |
# output = input * scale - shift | |
output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1) | |
elif self.affine: | |
# MJY:: Fuse the multiplication for speed. | |
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) | |
else: | |
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) | |
# Reshape it. | |
return output.view(input_shape) | |
def __data_parallel_replicate__(self, ctx, copy_id): | |
self._is_parallel = True | |
self._parallel_id = copy_id | |
# parallel_id == 0 means master device. | |
if self._parallel_id == 0: | |
ctx.sync_master = self._sync_master | |
else: | |
self._slave_pipe = ctx.sync_master.register_slave(copy_id) | |
def _data_parallel_master(self, intermediates): | |
"""Reduce the sum and square-sum, compute the statistics, and broadcast it.""" | |
# Always using same "device order" makes the ReduceAdd operation faster. | |
# Thanks to:: Tete Xiao (http://tetexiao.com/) | |
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) | |
to_reduce = [i[1][:2] for i in intermediates] | |
to_reduce = [j for i in to_reduce for j in i] # flatten | |
target_gpus = [i[1].sum.get_device() for i in intermediates] | |
sum_size = sum([i[1].sum_size for i in intermediates]) | |
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) | |
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) | |
broadcasted = Broadcast.apply(target_gpus, mean, inv_std) | |
# print('a') | |
# print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size) | |
# broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device)) | |
# print('b') | |
outputs = [] | |
for i, rec in enumerate(intermediates): | |
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) | |
# outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3]))) | |
return outputs | |
def _compute_mean_std(self, sum_, ssum, size): | |
"""Compute the mean and standard-deviation with sum and square-sum. This method | |
also maintains the moving average on the master device.""" | |
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' | |
mean = sum_ / size | |
sumvar = ssum - sum_ * mean | |
unbias_var = sumvar / (size - 1) | |
bias_var = sumvar / size | |
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data | |
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data | |
return mean, torch.rsqrt(bias_var + self.eps) | |
# return mean, bias_var.clamp(self.eps) ** -0.5 | |
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device. This speeds up the computation and is
    easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2-D or 3-D."""
        # Do not delegate to the base class: in current PyTorch the base
        # ``_check_input_dim`` is abstract and raises NotImplementedError.
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device. This speeds up the computation and is
    easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4-D."""
        # Do not delegate to the base class: in current PyTorch the base
        # ``_check_input_dim`` is abstract and raises NotImplementedError.
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device. This speeds up the computation and is
    easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 5-D."""
        # Do not delegate to the base class: in current PyTorch the base
        # ``_check_input_dim`` is abstract and raises NotImplementedError.
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))