Spaces:
Sleeping
Sleeping
# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import torch | |
import torch.nn as nn | |
from .block import Block | |
class GroupNorm(Block):
    """``Block`` wrapper around an affine :class:`torch.nn.GroupNorm`.

    Args:
        num_channels: number of input channels, passed to ``torch.nn.GroupNorm``.
        num_groups: number of channel groups (default 32).
        eps: numerical-stability term passed to ``torch.nn.GroupNorm`` (default 1e-6).
        *args, **kwargs: forwarded unchanged to ``Block.__init__``.
    """

    def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # ``eps`` used to be accepted but silently ignored (a literal 1e-6 was
        # passed instead); honor the caller's value. The default is unchanged.
        self.norm = torch.nn.GroupNorm(
            num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True
        )

    def forward(self, x):
        """Return ``x`` normalized by the wrapped GroupNorm layer."""
        return self.norm(x)
def Normalize(in_channels, num_groups=32, eps=1e-6):
    """Build an affine :class:`torch.nn.GroupNorm` layer.

    Args:
        in_channels: number of channels the layer normalizes.
        num_groups: number of channel groups (default 32); torch requires it
            to divide ``in_channels``.
        eps: numerical-stability term passed to GroupNorm. Defaults to 1e-6,
            the value previously hard-coded here, so existing callers are unaffected.

    Returns:
        A new ``torch.nn.GroupNorm`` module with learnable affine parameters.
    """
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True
    )
class ActNorm(nn.Module):
    """Activation normalization with data-dependent initialization.

    Computes ``scale * (x + loc)``, where ``loc`` and ``scale`` are learnable
    per-channel parameters of shape ``(1, num_features, 1, 1)``. On the first
    call made in training mode, the parameters are set from the batch
    statistics so the output has per-channel zero mean and ~unit std.

    Inputs are either 4-D ``(N, C, H, W)`` or 2-D ``(N, C)``. 2-D inputs are
    processed as ``(N, C, 1, 1)`` and squeezed back to 2-D on return.

    Args:
        num_features: number of channels ``C``.
        logdet: if True, ``forward`` returns ``(h, logdet)``, where ``logdet``
            is ``H * W * sum(log|scale|)`` repeated for each of the N samples.
        affine: must be True; only the affine variant is implemented.
        allow_reverse_init: if True, allow data-dependent initialization to be
            triggered by a ``reverse`` call; otherwise that raises RuntimeError.
    """

    def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
        # Only the affine variant is implemented.
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init
        # Kept as a buffer (not a Python bool) so the flag is saved and
        # restored together with the module's state_dict.
        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    @staticmethod
    def _as_4d(x):
        """Return ``(x4, squeezed)``; a 2-D ``(N, C)`` tensor becomes ``(N, C, 1, 1)``."""
        if x.dim() == 2:
            return x[:, :, None, None], True
        return x, False

    def initialize(self, input):
        """Set ``loc``/``scale`` from the per-channel mean/std of ``input``.

        Afterwards ``scale * (input + loc)`` has ~zero mean and ~unit std per
        channel. Accepts 4-D input, or 2-D input which is treated as 4-D.
        """
        input, _ = self._as_4d(input)
        with torch.no_grad():
            # (N, C, H, W) -> (C, N*H*W): one row of samples per channel.
            flatten = input.transpose(0, 1).reshape(input.shape[1], -1)
            mean = flatten.mean(1).view(1, -1, 1, 1)
            std = flatten.std(1).view(1, -1, 1, 1)
            self.loc.data.copy_(-mean)
            # 1e-6 guards against division by zero for constant channels.
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Apply the normalization (or its inverse when ``reverse=True``).

        Returns ``h``, or ``(h, logdet)`` when the module was built with
        ``logdet=True`` and ``reverse`` is False.
        """
        if reverse:
            return self.reverse(input)
        input, squeeze = self._as_4d(input)
        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            # The elementwise scale is applied at every one of the H*W positions.
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet
        return h

    def reverse(self, output):
        """Invert ``forward``: return ``output / scale - loc``.

        Raises:
            RuntimeError: if the module is uninitialized, in training mode, and
                ``allow_reverse_init`` is False.
        """
        # Expand 2-D input *before* any initialization. Previously initialize()
        # ran on the raw 2-D tensor and crashed in its 4-D permute.
        output, squeeze = self._as_4d(output)

        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            self.initialize(output)
            self.initialized.fill_(1)

        h = output / self.scale - self.loc
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h