#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : batchnorm_reimpl.py
# Author : acgtyrant
# Date   : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import torch
import torch.nn as nn
import torch.nn.init as init
__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of 2D batch normalization, used for testing the
    numerical stability.

    Per-channel statistics are computed in one pass from the sum and the sum
    of squares (``sumvar = sum(x^2) - sum(x) * mean``). This formula is kept
    deliberately, since numerical stability is what this class is meant to
    exercise.

    Author: acgtyrant

    See also:
        https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14

    Args:
        num_features: number of channels ``C`` of the ``(N, C, H, W)`` input.
        eps: value added to the variance before taking the square root.
        momentum: weight of the current batch in the running-statistics update
            ``running = (1 - momentum) * running + momentum * batch``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Learnable affine parameters; initialized by reset_parameters().
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        # Running statistics are buffers (saved in state_dict, not trained).
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset the running mean to 0 and the running variance to 1."""
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        """Reset the running statistics and re-initialize the affine parameters.

        The weight is filled by ``init.uniform_`` with its default bounds, and
        the bias is set to zero.
        """
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        """Normalize a 4-D ``(N, C, H, W)`` tensor per channel.

        In training mode, the batch statistics are used for normalization and
        the running statistics are updated in place. In eval mode, the stored
        running statistics are used and left untouched, as in
        ``nn.BatchNorm2d``.

        Raises:
            ValueError: in training mode, when there is at most one value per
                channel. The unbiased variance would otherwise divide by zero
                and corrupt ``running_var`` with NaN/inf.
        """
        input_size = input_.size()
        batchsize, channels, height, width = input_size
        numel = batchsize * height * width
        # (N, C, H, W) -> (C, N*H*W): one row of samples per channel.
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)

        if self.training:
            if numel <= 1:
                raise ValueError(
                    'Expected more than 1 value per channel when training, '
                    'got input size {}'.format(input_size))

            sum_ = input_.sum(1)
            sum_of_square = input_.pow(2).sum(1)
            mean = sum_ / numel
            # sum((x - mean)^2), computed with the one-pass formula on purpose.
            sumvar = sum_of_square - sum_ * mean

            # Update the buffers in place; no_grad keeps the update out of the
            # autograd graph (equivalent to the previous .detach() calls).
            with torch.no_grad():
                unbias_var = sumvar / (numel - 1)
                self.running_mean.mul_(1 - self.momentum).add_(
                    self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(
                    self.momentum * unbias_var)

            # Normalization itself uses the biased (population) variance.
            bias_var = sumvar / numel
        else:
            mean = self.running_mean
            bias_var = self.running_var

        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
        # (C, N*H*W) -> (N, C, H, W)
        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()