"""
This is a base Lightning module that can be used to train a model.

The benefit of this abstraction is that all of the logic outside of the model definition (optimisation, loss logging, and validation-time plotting) can be reused across different models.
"""
import inspect
from abc import ABC
from typing import Any, Dict

import torch
from lightning import LightningModule
from lightning.pytorch.utilities import grad_norm

from matcha import utils
from matcha.utils.utils import plot_tensor

log = utils.get_pylogger(__name__)


class BaseLightningClass(LightningModule, ABC):
    def update_data_statistics(self, data_statistics):
        """Register the dataset's mel mean/std as buffers so they are stored in the checkpoint."""
        if data_statistics is None:
            data_statistics = {
                "mel_mean": 0.0,
                "mel_std": 1.0,
            }

        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))

    def configure_optimizers(self) -> Any:
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler not in (None, {}):
            scheduler_args = {}
            # Manage last_epoch for schedulers that support it (e.g. exponential decay),
            # so a run resumed from a checkpoint continues the schedule where it left off.
            if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
                if hasattr(self, "ckpt_loaded_epoch"):
                    current_epoch = self.ckpt_loaded_epoch - 1
                else:
                    current_epoch = -1
            else:
                current_epoch = None

            scheduler_args.update({"optimizer": optimizer})
            scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
            if current_epoch is not None:
                scheduler.last_epoch = current_epoch
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": self.hparams.scheduler.lightning_args.interval,
                    "frequency": self.hparams.scheduler.lightning_args.frequency,
                    "name": "learning_rate",
                },
            }

        return {"optimizer": optimizer}

    def get_losses(self, batch):
        """Run the model's forward pass on a batch and collect the individual loss terms."""
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]

        dur_loss, prior_loss, diff_loss = self(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
            out_size=self.out_size,
        )
        return {
            "dur_loss": dur_loss,
            "prior_loss": prior_loss,
            "diff_loss": diff_loss,
        }
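    # Note: ``get_losses`` above assumes batches with the keys "x"/"x_lengths" (text inputs and
    # their lengths), "y"/"y_lengths" (target mel spectrograms and their lengths) and "spks"
    # (speaker IDs, or None for single-speaker data).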

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Remember the epoch the checkpoint was saved at so ``configure_optimizers`` can
        # restore the LR scheduler's ``last_epoch`` when training resumes.
        self.ckpt_loaded_epoch = checkpoint["epoch"]

    def training_step(self, batch: Any, batch_idx: int):
        loss_dict = self.get_losses(batch)
        self.log(
            "step",
            float(self.global_step),
            on_step=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        self.log(
            "sub_loss/train_dur_loss",
            loss_dict["dur_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/train_prior_loss",
            loss_dict["prior_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/train_diff_loss",
            loss_dict["diff_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/train",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return {"loss": total_loss, "log": loss_dict}

    def validation_step(self, batch: Any, batch_idx: int):
        loss_dict = self.get_losses(batch)
        self.log(
            "sub_loss/val_dur_loss",
            loss_dict["dur_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/val_prior_loss",
            loss_dict["prior_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/val_diff_loss",
            loss_dict["diff_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/val",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return total_loss

    def on_validation_end(self) -> None:
        # Only the rank-zero process plots samples, to avoid duplicate images in the logger.
        if self.trainer.is_global_zero:
            one_batch = next(iter(self.trainer.val_dataloaders))
            if self.current_epoch == 0:
                log.debug("Plotting original samples")
                for i in range(2):
                    y = one_batch["y"][i].unsqueeze(0).to(self.device)
                    self.logger.experiment.add_image(
                        f"original/{i}",
                        plot_tensor(y.squeeze().cpu()),
                        self.current_epoch,
                        dataformats="HWC",
                    )

            log.debug("Synthesising...")
            for i in range(2):
                x = one_batch["x"][i].unsqueeze(0).to(self.device)
                x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
                spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
                # Trim padding before synthesis; ``x_lengths`` is a one-element tensor and acts as a scalar index here.
                output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
                y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
                attn = output["attn"]
                self.logger.experiment.add_image(
                    f"generated_enc/{i}",
                    plot_tensor(y_enc.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"generated_dec/{i}",
                    plot_tensor(y_dec.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"alignment/{i}",
                    plot_tensor(attn.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )

    def on_before_optimizer_step(self, optimizer):
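        # ``grad_norm`` (from lightning.pytorch.utilities) returns a dict of per-parameter gradient
        # norms plus a total norm; here they are re-keyed under "grad_norm/" so the logger groups
        # them together. Exact key names depend on the Lightning version.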
        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})