File size: 4,033 Bytes
32b2aaa 1df74c6 32b2aaa 1df74c6 32b2aaa 627d3d7 32b2aaa 627d3d7 32b2aaa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import logging
from dataclasses import dataclass
from typing import Union
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.utils.parametrizations import weight_norm
from ...common import Normalizer
logger = logging.getLogger(__name__)
@dataclass
class IRMAEOutput:
    """Result bundle produced by the autoencoder's forward pass.

    Attributes:
        latent: latent vector.
        decoded: decoder output (includes the extra dim), or ``None`` when
            the decoding step was skipped.
    """

    latent: Tensor
    decoded: Union[Tensor, None]
class ResBlock(nn.Sequential):
def __init__(self, channels, dilations=[1, 2, 4, 8]):
wn = weight_norm
super().__init__(
nn.GroupNorm(32, channels),
nn.GELU(),
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
nn.GroupNorm(32, channels),
nn.GELU(),
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
nn.GroupNorm(32, channels),
nn.GELU(),
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
nn.GroupNorm(32, channels),
nn.GELU(),
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
)
def forward(self, x: Tensor):
return x + super().forward(x)
class IRMAE(nn.Module):
    """Convolutional autoencoder with implicit rank minimization layers.

    The encoder is a conv stem, four ``ResBlock``s, ``num_irms`` bias-free
    1x1 convolutions and a final ``tanh``. The decoder maps the latent back
    to ``output_dim`` channels, and ``head`` projects the decoder output to
    ``input_dim`` channels so it can be compared against the input.

    Attributes set at call time:
        stats: dict of scalar latent statistics, rewritten by ``encode``.
        losses: dict with the ``mse`` reconstruction loss, written by
            ``forward`` only when decoding is performed.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        latent_dim,
        hidden_dim=1024,
        num_irms=4,
    ):
        """
        Args:
            input_dim: input dimension (channels of ``x``)
            output_dim: output dimension (channels produced by the decoder)
            latent_dim: latent dimension
            hidden_dim: hidden layer dimension
            num_irms: number of implicit rank minimization matrices, i.e.
                consecutive bias-free 1x1 convolutions closing the encoder
        """
        super().__init__()
        self.input_dim = input_dim

        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            # Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
            *[
                nn.Conv1d(
                    hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False
                )
                for i in range(num_irms)
            ],
            nn.Tanh(),
        )

        self.decoder = nn.Sequential(
            nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            nn.Conv1d(hidden_dim, output_dim, 1),
        )

        # Maps the decoder output back to the input's channel count for the
        # reconstruction loss computed in forward().
        self.head = nn.Sequential(
            nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
            nn.GELU(),
            nn.Conv1d(hidden_dim, input_dim, 1),
        )

        self.estimator = Normalizer()

    def encode(self, x):
        """Encode ``x`` into the latent and record latent statistics.

        Args:
            x: (b c t) tensor

        Returns:
            The latent ``z`` as a (b c t) tensor.
        """
        z = self.encoder(x)  # (b c t)

        # Called for its side effect only; the return value is unused.
        self.estimator(z)  # Estimate the global mean and std of z

        # The statistics are logging-only scalars, so compute them on a
        # detached tensor to avoid recording autograd history for them.
        z_detached = z.detach()
        # torch.quantile only accepts float/double inputs, hence the cast.
        z_abs = z_detached.float().abs()
        # 0.6827 / 0.9545 / 0.9973 are the 1/2/3-sigma coverage
        # probabilities of a normal distribution.
        self.stats = {
            "z_mean": z_detached.mean().item(),
            "z_std": z_detached.std().item(),
            "z_abs_68": z_abs.quantile(0.6827).item(),
            "z_abs_95": z_abs.quantile(0.9545).item(),
            "z_abs_99": z_abs.quantile(0.9973).item(),
        }

        return z

    def decode(self, z):
        """Decode a latent back to ``output_dim`` channels.

        Args:
            z: (b c t) tensor
        """
        return self.decoder(z)

    def forward(self, x, skip_decoding=False):
        """Encode ``x`` and, unless skipped, decode it and compute the loss.

        Args:
            x: (b c t) tensor
            skip_decoding: if True, skip the decoding step

        Returns:
            ``IRMAEOutput`` with the latent and the decoder output
            (``None`` when ``skip_decoding`` is True).
        """
        z = self.encode(x)  # q(z|x)

        if skip_decoding:
            # This speeds up the training in cfm only mode
            # NOTE: self.losses is left untouched on this path.
            decoded = None
        else:
            decoded = self.decode(z)  # p(x|z)
            predicted = self.head(decoded)
            self.losses = dict(mse=F.mse_loss(predicted, x))

        return IRMAEOutput(latent=z, decoded=decoded)
|