File size: 4,703 Bytes
32b2aaa 1df74c6 32b2aaa 1df74c6 32b2aaa 1df74c6 32b2aaa 1df74c6 32b2aaa 1df74c6 32b2aaa 1df74c6 32b2aaa 1df74c6 32b2aaa 1df74c6 32b2aaa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import logging
from enum import Enum
from typing import Union
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import Tensor, nn
from .cfm import CFM
from .irmae import IRMAE, IRMAEOutput
# Module-level logger; LCFM.set_mode_ reports through it which sub-module it froze.
logger = logging.getLogger(__name__)
def freeze_(module):
    """Turn off gradient tracking for every parameter of ``module``, in place.

    Args:
        module: An object exposing ``parameters()`` (e.g. an ``nn.Module``).
            Its parameters are modified in place; nothing is returned.
    """
    trainable = False
    for weight in module.parameters():
        weight.requires_grad_(trainable)
class LCFM(nn.Module):
    """Pair an autoencoder (``ae``) with a CFM model (``cfm``) that works on its latents.

    Latents are multiplied by ``z_scale`` before they are handed to ``cfm`` and
    divided by it again before ``ae.decode``. Only one of the two sub-modules is
    meant to be trained at a time; see :meth:`set_mode_`.
    """

    class Mode(Enum):
        """Which sub-module is trained; the other one gets frozen."""

        AE = "ae"
        CFM = "cfm"

    def __init__(self, ae: "IRMAE", cfm: "CFM", z_scale: float = 1.0):
        """
        Args:
            ae: Autoencoder providing ``encode``/``decode``; calling it returns an
                ``IRMAEOutput`` with ``latent`` and ``decoded`` fields.
            cfm: Model called on the condition mel (plus optional target latent and ψ0).
            z_scale: Factor applied to latents before they enter ``cfm``.
        """
        super().__init__()
        self.ae = ae
        self.cfm = cfm
        self.z_scale = z_scale
        self._mode = None  # stays None until set_mode_() is called
        self._eval_tau = 0.5  # noise weight used to perturb ψ0 when not training

    @property
    def mode(self):
        """The current :class:`Mode`, or ``None`` if :meth:`set_mode_` was never called."""
        return self._mode

    def set_mode_(self, mode):
        """Select the training mode and freeze the sub-module that is not trained.

        Args:
            mode: A :class:`Mode` member or its value (``"ae"`` / ``"cfm"``).

        Raises:
            ValueError: If ``mode`` is not a valid :class:`Mode`.

        Note: this only freezes; a module frozen by an earlier call is not unfrozen.
        """
        mode = self.Mode(mode)  # raises ValueError for unknown values
        self._mode = mode
        if mode == self.Mode.AE:
            freeze_(self.cfm)
            logger.info("Freeze cfm")
        elif mode == self.Mode.CFM:
            freeze_(self.ae)
            logger.info("Freeze ae (encoder and decoder)")
        else:  # defensive; self.Mode(mode) has already validated the value
            raise ValueError(f"Unknown training mode: {mode}")

    def get_running_train_loop(self):
        """Return ``TrainLoop.get_running_loop()``, or ``None`` if it cannot be imported."""
        try:
            # Lazy import: the training utilities are optional at inference time.
            from ...utils.train_loop import TrainLoop
        except ImportError:
            return None
        return TrainLoop.get_running_loop()

    @property
    def global_step(self):
        """Global step of the running train loop, or ``None`` when there is no loop."""
        loop = self.get_running_train_loop()
        if loop is None:
            return None
        return loop.global_step

    @staticmethod
    def _imshow_mel(position, mel, title):
        """Draw a 2-D array into subplot ``position`` of the current figure."""
        plt.subplot(position)
        plt.imshow(
            mel.detach().cpu().numpy(),
            aspect="auto",
            origin="lower",
            interpolation="none",
        )
        plt.title(title)

    @torch.no_grad()
    def _visualize(self, x, y, y_):
        """Save a 2x2 figure (GT, posterior, conditional prior, prior) for the current step.

        Args:
            x: Condition mel batch (only the first item is plotted).
            y: Target mel batch.
            y_: Reconstruction of ``y``; cropped to ``y.shape[1]`` bins before plotting.
        """
        loop = self.get_running_train_loop()
        if loop is None:
            return
        n_bins = y.shape[1]
        try:
            self._imshow_mel(221, y[0], "GT")
            self._imshow_mel(222, y_[0, :n_bins], "Posterior")

            # cfm produces scaled latents; unscale exactly like forward() does.
            z_ = self._unscale(self.cfm(x))
            self._imshow_mel(223, self.ae.decode(z_)[0, :n_bins], "C-Prior")

            # Decode pure Gaussian noise shaped like the latent.
            z_ = torch.randn_like(z_)
            self._imshow_mel(224, self.ae.decode(z_)[0, :n_bins], "Prior")
            del z_

            path = loop.make_current_step_viz_path("recon", ".png")
            path.parent.mkdir(exist_ok=True, parents=True)
            plt.tight_layout()
            plt.savefig(path, dpi=500)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close()

    def _scale(self, z: Tensor):
        """Map an AE latent into the space ``cfm`` works in."""
        return z * self.z_scale

    def _unscale(self, z: Tensor):
        """Inverse of :meth:`_scale`: map a ``cfm`` output back to AE latent space."""
        return z / self.z_scale

    def eval_tau_(self, tau):
        """Set the noise weight used to perturb ψ0 when the module is not training."""
        self._eval_tau = tau

    def forward(self, x, y: Union[Tensor, None] = None, ψ0: Union[Tensor, None] = None):
        """
        Args:
            x: (b d t), condition mel
            y: (b d t), target mel; when given, the AE reconstructs it (and in CFM
                mode the cfm is also run on its scaled, detached latent)
            ψ0: (b d t), starting mel; it is encoded, scaled and blended with noise
                as ``tau * noise + (1 - tau) * latent``

        Returns:
            ``ae.decode(...)`` of the AE (y is None) or CFM-generated latent, or
            ``ae_output.decoded`` when ``y`` is given (presumably ``None`` in CFM
            mode, where decoding is skipped).
        """
        if self.mode == self.Mode.CFM:
            self.ae.eval()  # Always set to eval when training cfm

        if ψ0 is not None:
            ψ0 = self._scale(self.ae.encode(ψ0))
            if self.training:
                # One random tau per batch element, broadcast over the rest.
                tau = torch.rand_like(ψ0[:, :1, :1])
            else:
                tau = self._eval_tau
            ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0

        if y is None:
            if self.mode == self.Mode.AE:
                with torch.no_grad():
                    training = self.ae.training
                    self.ae.eval()
                    try:
                        z = self.ae.encode(x)
                    finally:
                        # Restore the caller's train/eval state even if encode fails.
                        self.ae.train(training)
            else:
                z = self._unscale(self.cfm(x, ψ0=ψ0))
            h = self.ae.decode(z)
        else:
            ae_output: IRMAEOutput = self.ae(
                y, skip_decoding=self.mode == self.Mode.CFM
            )
            if self.mode == self.Mode.CFM:
                _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
            h = ae_output.decoded

        # The visualization plots the target mel, so it needs y as well as h.
        if h is not None and y is not None:
            step = self.global_step
            if step is not None and step % 100 == 0:
                self._visualize(x[:1], y[:1], h[:1])

        return h
|