Spaces:
Running
on
Zero
Running
on
Zero
from abc import abstractmethod | |
from typing import Any, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ....modules.distributions.distributions import DiagonalGaussianDistribution | |
from .base import AbstractRegularizer | |
class DiagonalGaussianRegularizer(AbstractRegularizer):
    """Regularize a latent by treating it as the parameters of a diagonal Gaussian.

    ``forward`` wraps its input in a ``DiagonalGaussianDistribution``. It then
    returns either a sample from that distribution or its mode, together with
    a log dict. The log dict carries the distribution's KL term, summed and
    divided by the size of its first dimension.
    """

    def __init__(self, sample: bool = True):
        """Configure how the latent is drawn.

        Args:
            sample: When True, ``forward`` returns ``posterior.sample()``;
                when False, it returns ``posterior.mode()``.
        """
        super().__init__()
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        """Yield no parameters; this regularizer owns nothing trainable."""
        # Kept as a generator function so callers receive a (empty) generator.
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """Draw a latent from the posterior described by ``z`` and report its KL.

        Args:
            z: Tensor handed to ``DiagonalGaussianDistribution``. Presumably it
                holds mean and log-variance stacked along the channel axis.
                NOTE(review): confirm against DiagonalGaussianDistribution.

        Returns:
            A pair ``(latent, log)``:
            - ``latent`` is a sample or the mode, depending on ``self.sample``.
            - ``log["kl_loss"]`` is the summed KL divided by the size of the
              KL tensor's first dimension.
        """
        posterior = DiagonalGaussianDistribution(z)
        latent = posterior.sample() if self.sample else posterior.mode()

        kl_terms = posterior.kl()
        # Normalize by the leading dimension; presumably one KL value per
        # batch item, so this is a per-sample mean — TODO confirm.
        leading_size = kl_terms.shape[0]
        mean_kl = torch.sum(kl_terms) / leading_size

        return latent, {"kl_loss": mean_kl}