# Spaces: Runtime error  (appears to be a status banner captured from the hosting page; not part of the module)
from typing import Iterable | |
import torch | |
from torch import nn | |
class NoiseRegularizer(nn.Module):
    """Multi-scale penalty on the spatial self-correlation of noise maps.

    For every noise tensor, the product of the tensor with a copy of itself
    rolled by one position along dim 3 and along dim 2 is averaged and
    squared. The tensor is then halved in resolution by 2x2 average pooling
    and the penalty is accumulated again, until the size of dim 2 is 8 or
    less.
    """

    def forward(self, noises: Iterable[torch.Tensor]):
        """Return the accumulated correlation penalty over all ``noises``.

        Args:
            noises: Tensors with at least four dimensions; dims 2 and 3 are
                treated as the spatial axes (presumably ``(N, C, H, W)`` --
                confirm against the caller). The pooling step uses the size
                of dim 2 for both spatial axes, so maps must be square, and
                every size above 8 that is reached must be even.

        Returns:
            A 0-dim tensor holding the summed penalty, or the integer ``0``
            when ``noises`` is empty.
        """
        loss = 0

        for noise in noises:
            size = noise.shape[2]

            while True:
                loss = (
                    loss
                    + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
                    + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
                )

                if size <= 8:
                    break

                # 2x2 average pooling via reshape + mean. The leading dims are
                # read from the tensor (instead of being fixed to 1, 1) so
                # that maps with several samples or channels are pooled
                # independently rather than failing in reshape.
                batch, channels = noise.shape[:2]
                noise = noise.reshape(
                    [batch, channels, size // 2, 2, size // 2, 2]
                )
                noise = noise.mean([3, 5])
                size //= 2

        return loss
def normalize(noises: Iterable[torch.Tensor]):
    """Standardize each noise tensor in place to zero mean and unit std.

    The update is applied through ``.data``, so it is not recorded by
    autograd. Statistics are taken over all elements of each tensor.

    Args:
        noises: Tensors to normalize; each one is modified in place.

    Returns:
        None.
    """
    for noise in noises:
        values = noise.data
        mean = values.mean()
        std = values.std()

        # A constant map has std == 0 and a single-element map has std == NaN;
        # dividing by either would fill the tensor with NaN. Fall back to a
        # divisor of 1 so such maps are only centered. torch.where keeps this
        # free of a host-side sync.
        usable = torch.isfinite(std) & (std > 0)
        scale = torch.where(usable, std, torch.ones_like(std))

        values.sub_(mean).div_(scale)