|
import torch |
|
import torch.nn as nn |
|
from torch.nn.functional import conv2d |
|
|
|
|
|
class Whitening2d(nn.Module):
    """Whiten a batch of feature vectors using a Cholesky-based transform.

    Forward pass:
    1. Centre the input by a mean vector.
    2. Estimate the (unbiased) feature covariance ``C``.
    3. Optionally shrink it toward the identity:
       ``C_s = (1 - eps) * C + eps * I``.
    4. Multiply the centred input by ``L^{-T}``, where ``C_s = L L^T`` is the
       Cholesky factorisation.

    The output therefore has identity covariance with respect to ``C_s``.

    In training mode (or when ``track_running_stats`` is False) the mean and
    covariance come from the current batch, and gradients flow through them.
    In eval mode with ``track_running_stats`` they come from the running
    buffers.

    Input shape is ``(N, num_features)``; output shape is the same.

    Args:
        num_features: Number of features ``F`` (size of dim 1 of the input).
        momentum: Weight of the current batch in the running-statistics
            update: ``running = momentum * batch + (1 - momentum) * running``.
        track_running_stats: If True, keep ``running_mean`` (shape
            ``(1, F, 1, 1)``) and ``running_variance`` (shape ``(F, F)``)
            buffers and use them in eval mode.
        eps: Shrinkage factor toward the identity applied before the
            Cholesky factorisation.

    Note:
        The covariance estimate divides by ``N - 1``, so training batches need
        ``N >= 2``. With ``eps == 0`` the sample covariance is only positive
        definite (as Cholesky requires) when ``N > num_features`` and the
        samples span the feature space.
    """

    def __init__(self, num_features, momentum=0.01, track_running_stats=True, eps=0):
        super().__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        self.eps = eps

        if self.track_running_stats:
            self.register_buffer(
                "running_mean", torch.zeros([1, self.num_features, 1, 1])
            )
            # Initialised to the identity so an untrained module in eval mode
            # only centres its input.
            self.register_buffer("running_variance", torch.eye(self.num_features))

    def forward(self, x):
        """Whiten ``x`` of shape ``(N, num_features)``; returns the same shape."""
        # Treat the (N, F) input as an (N, F, 1, 1) image batch so the
        # whitening matrix can be applied as a 1x1 convolution.
        x = x.unsqueeze(2).unsqueeze(3)
        use_running_stats = not self.training and self.track_running_stats

        if use_running_stats:
            mean = self.running_mean
        else:
            mean = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1)
        xn = x - mean

        if use_running_stats:
            # Skip the batch covariance entirely; it would be discarded, and
            # for a single-sample batch it would divide by zero.
            f_cov = self.running_variance
        else:
            # (F, N) matrix of centred samples -> unbiased (F, F) covariance.
            T = xn.permute(1, 0, 2, 3).contiguous().view(self.num_features, -1)
            f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1)

        # Build the identity directly on the input's device/dtype instead of
        # allocating on the CPU and copying every call.
        eye = torch.eye(self.num_features, dtype=xn.dtype, device=xn.device)
        f_cov_shrunk = (1 - self.eps) * f_cov + self.eps * eye

        # With C_s = L L^T, W = L^{-1} satisfies W C_s W^T = I.
        chol = torch.linalg.cholesky(f_cov_shrunk)
        whitening = torch.linalg.solve_triangular(chol, eye, upper=False)
        whitening = whitening.contiguous().view(
            self.num_features, self.num_features, 1, 1
        )

        # 1x1 conv with weight W computes xn @ W^T per sample.
        decorrelated = conv2d(xn, whitening)

        if self.training and self.track_running_stats:
            # Running statistics are bookkeeping only; keep them out of the
            # autograd graph and update the registered buffers in place.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_variance.mul_(1 - self.momentum).add_(
                    self.momentum * f_cov
                )

        return decorrelated.squeeze(2).squeeze(2)

    def extra_repr(self):
        return "features={}, eps={}, momentum={}, track_running_stats={}".format(
            self.num_features, self.eps, self.momentum, self.track_running_stats
        )
|
|