"""
File copied from
https://github.com/nicola-decao/diffmask/blob/master/diffmask/models/distributions.py
"""
import torch
import torch.distributions as distr
import torch.nn.functional as F
from torch import Tensor
class BinaryConcrete(distr.relaxed_bernoulli.RelaxedBernoulli):
    """Binary Concrete distribution with an explicit CDF, a support-aware
    ``log_prob`` and the ``log_expected_L0`` term used for L0 regularization.

    Args:
        temperature: relaxation temperature (> 0).
        logits: log-odds of the underlying Bernoulli.
    """

    def __init__(self, temperature: Tensor, logits: Tensor):
        super().__init__(temperature=temperature, logits=logits)
        # Convenience handle so wrappers can place auxiliary tensors
        # on the same device as the parameters.
        self.device = self.temperature.device

    def cdf(self, value: Tensor) -> Tensor:
        """CDF: sigmoid(temperature * logit(value) - logits)."""
        logit_value = value.log() - (1.0 - value).log()
        return torch.sigmoid(logit_value * self.temperature - self.logits)

    def log_prob(self, value: Tensor) -> Tensor:
        """Log-density; -inf outside the open interval (0, 1)."""
        inside_support = (value > 0) & (value < 1)
        neg_inf = torch.full_like(value, -float("inf"))
        return torch.where(inside_support, super().log_prob(value), neg_inf)

    def log_expected_L0(self, value: Tensor) -> Tensor:
        """log(1 - cdf(value)) = log P(X > value), via softplus for
        numerical stability."""
        logit_value = value.log() - (1 - value).log()
        return -F.softplus(logit_value * self.temperature - self.logits)
class Streched(distr.TransformedDistribution):
    """Affinely stretched distribution (name kept as-is for compatibility;
    sic "Stretched"): maps a base distribution on (0, 1) onto (l, r) via
    x -> l + (r - l) * x.

    Args:
        base_dist: distribution exposing ``device`` and ``log_expected_L0``
            (e.g. :class:`BinaryConcrete`).
        l: left edge of the stretched support.
        r: right edge of the stretched support.
    """

    def __init__(self, base_dist, l: float = -0.1, r: float = 1.1):
        super().__init__(base_dist, distr.AffineTransform(loc=l, scale=r - l))
        # BUGFIX: torch >= 1.9 flattens a nested TransformedDistribution, so
        # self.base_dist may become the *innermost* base (which lacks .device
        # and .log_expected_L0) and the original code raised AttributeError.
        # Keep our own reference to the distribution we were handed so the
        # methods below work on every torch version.
        self._orig_base_dist = base_dist

    def log_expected_L0(self) -> Tensor:
        """log P(X > 0): pull 0 back through the affine stretch and delegate
        to the base distribution's ``log_expected_L0``."""
        value = torch.tensor(0.0, device=self._orig_base_dist.device)
        # self.transforms[-1] is exactly the AffineTransform added in
        # __init__, whether or not torch flattened base_dist's transforms.
        value = self.transforms[-1].inv(value)
        if self._validate_args:
            self._orig_base_dist._validate_sample(value)
        value = self._orig_base_dist.log_expected_L0(value)
        # All transforms here have positive sign, so this is a no-op; kept
        # for parity with TransformedDistribution.cdf.
        value = self._monotonize_cdf(value)
        return value

    def expected_L0(self) -> Tensor:
        """P(X > 0), i.e. exp of :meth:`log_expected_L0`."""
        return self.log_expected_L0().exp()
class RectifiedStreched(Streched):
    """Hard-concrete style rectification of a :class:`Streched` distribution:
    samples are clamped to [0, 1], so stretched mass below 0 collapses onto
    exactly 0 and mass above 1 onto exactly 1.

    Note: the original pass-through ``__init__(*args, **kwargs)`` was pure
    boilerplate and has been removed; the inherited constructor is identical.
    """

    @torch.no_grad()
    def sample(self, sample_shape: torch.Size = torch.Size([])) -> Tensor:
        """Non-differentiable sampling: rsample with gradients disabled."""
        return self.rsample(sample_shape)

    def rsample(self, sample_shape: torch.Size = torch.Size([])) -> Tensor:
        """Reparameterized sample, rectified into [0, 1]."""
        return super().rsample(sample_shape).clamp(0, 1)