Spaces:
Runtime error
Runtime error
"""
This is a base Lightning module that can be used to train a model.
The benefit of this abstraction is that all the logic outside of the model definition can be reused across different models.
"""
import inspect | |
from abc import ABC | |
from typing import Any, Dict | |
import torch | |
from lightning import LightningModule | |
from lightning.pytorch.utilities import grad_norm | |
from matcha import utils | |
from matcha.utils.utils import plot_tensor | |
# Module-level logger obtained from matcha's ``utils.get_pylogger`` helper.
log = utils.get_pylogger(__name__)
class BaseLightningClass(LightningModule, ABC):
    """Shared training, validation, optimizer and logging logic for Lightning models.

    This class does not define the model itself. The code below relies on the
    following being provided elsewhere (presumably by a subclass):

    * ``forward(x=..., x_lengths=..., y=..., y_lengths=..., spks=..., out_size=...)``
      returning a ``(dur_loss, prior_loss, diff_loss)`` triple (see ``get_losses``);
    * ``synthesise(x, x_lengths, n_timesteps=..., spks=...)`` returning a dict with
      ``"encoder_outputs"``, ``"decoder_outputs"`` and ``"attn"`` entries;
    * an ``out_size`` attribute;
    * ``hparams.optimizer`` and ``hparams.scheduler`` (see ``configure_optimizers``).
    """

    def update_data_statistics(self, data_statistics):
        """Register the ``mel_mean`` / ``mel_std`` data statistics as module buffers.

        Args:
            data_statistics: mapping with ``"mel_mean"`` and ``"mel_std"`` entries,
                or ``None`` to fall back to ``mel_mean=0.0`` and ``mel_std=1.0``.
        """
        if data_statistics is None:
            data_statistics = {
                "mel_mean": 0.0,
                "mel_std": 1.0,
            }

        # Buffers (not plain attributes) so the values are part of the state dict
        # and move with the module across devices.
        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))

    def configure_optimizers(self) -> Any:
        """Build the optimizer and, when configured, its learning-rate scheduler.

        ``self.hparams.optimizer`` is called with ``params=self.parameters()``.
        When ``self.hparams.scheduler`` is neither ``None`` nor ``{}``, its
        ``scheduler`` factory is called with ``optimizer=<the optimizer>``. The
        scheduling options are read from ``self.hparams.scheduler.lightning_args``.

        Returns:
            Lightning optimizer config: ``{"optimizer": ...}`` or
            ``{"optimizer": ..., "lr_scheduler": {...}}``.
        """
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler in (None, {}):
            return {"optimizer": optimizer}

        scheduler_factory = self.hparams.scheduler.scheduler

        # Manage last epoch for exponential schedulers: when the factory accepts
        # ``last_epoch``, resume from the loaded checkpoint's epoch (or -1 when no
        # checkpoint was loaded). Schedulers without that parameter are left
        # untouched. Previously this path read an unbound ``current_epoch``.
        resume_epoch = None
        if "last_epoch" in inspect.signature(scheduler_factory).parameters:
            if hasattr(self, "ckpt_loaded_epoch"):
                resume_epoch = self.ckpt_loaded_epoch - 1
            else:
                resume_epoch = -1

        scheduler = scheduler_factory(optimizer=optimizer)
        if resume_epoch is not None:
            scheduler.last_epoch = resume_epoch

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": self.hparams.scheduler.lightning_args.interval,
                "frequency": self.hparams.scheduler.lightning_args.frequency,
                "name": "learning_rate",
            },
        }

    def get_losses(self, batch):
        """Run a forward pass on ``batch`` and return its loss terms.

        Args:
            batch: mapping with ``"x"``, ``"x_lengths"``, ``"y"``, ``"y_lengths"``
                and ``"spks"`` entries, forwarded as keyword arguments to ``self(...)``
                together with ``out_size=self.out_size``.

        Returns:
            dict with ``"dur_loss"``, ``"prior_loss"`` and ``"diff_loss"`` entries.
        """
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]

        dur_loss, prior_loss, diff_loss = self(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
            out_size=self.out_size,
        )
        return {
            "dur_loss": dur_loss,
            "prior_loss": prior_loss,
            "diff_loss": diff_loss,
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Remember the checkpoint's epoch so ``configure_optimizers`` can resume schedulers."""
        self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init

    def _log_losses(self, loss_dict: Dict[str, Any], stage: str) -> Any:
        """Log every sub-loss and their sum for ``stage``; return the sum.

        Each entry is logged as ``sub_loss/<stage>_<name>`` and the total as
        ``loss/<stage>`` (the total is also shown in the progress bar).
        """
        for name, value in loss_dict.items():
            self.log(
                f"sub_loss/{stage}_{name}",
                value,
                on_step=True,
                on_epoch=True,
                logger=True,
                sync_dist=True,
            )

        total_loss = sum(loss_dict.values())
        self.log(
            f"loss/{stage}",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )
        return total_loss

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log training losses for one batch.

        Returns:
            ``{"loss": total_loss, "log": loss_dict}``; Lightning optimizes ``"loss"``.
        """
        loss_dict = self.get_losses(batch)

        self.log(
            "step",
            float(self.global_step),
            on_step=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        total_loss = self._log_losses(loss_dict, "train")

        return {"loss": total_loss, "log": loss_dict}

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute and log validation losses for one batch; return their sum."""
        loss_dict = self.get_losses(batch)
        return self._log_losses(loss_dict, "val")

    def on_validation_end(self) -> None:
        """Log plots for up to two examples of the first validation batch.

        Runs on the global-zero rank only. For each plotted example ``i``:

        * at epoch 0, the ground-truth ``y`` is logged under ``original/<i>``;
        * ``self.synthesise`` is called with ``n_timesteps=10``. Its encoder output,
          decoder output and attention are logged under ``generated_enc/<i>``,
          ``generated_dec/<i>`` and ``alignment/<i>``.

        NOTE(review): ``self.logger.experiment.add_image(..., dataformats="HWC")``
        assumes a TensorBoard ``SummaryWriter``-style logger; confirm.
        """
        if not self.trainer.is_global_zero:
            return

        one_batch = next(iter(self.trainer.val_dataloaders))
        # The first batch may hold fewer than two examples (batch_size=1, tiny
        # validation set). Indexing a missing example used to raise IndexError.
        n_examples = min(2, len(one_batch["x"]))

        if self.current_epoch == 0:
            log.debug("Plotting original samples")
            for i in range(n_examples):
                # Only plotted on CPU, so no round-trip through self.device is needed.
                y = one_batch["y"][i].squeeze().cpu()
                self.logger.experiment.add_image(
                    f"original/{i}",
                    plot_tensor(y),
                    self.current_epoch,
                    dataformats="HWC",
                )

        log.debug("Synthesising...")
        spks_batch = one_batch["spks"]
        for i in range(n_examples):
            x = one_batch["x"][i].unsqueeze(0).to(self.device)
            x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
            spks = spks_batch[i].unsqueeze(0).to(self.device) if spks_batch is not None else None
            # Keep only the first x_lengths positions (presumably dropping padding).
            output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)

            plots = {
                f"generated_enc/{i}": output["encoder_outputs"],
                f"generated_dec/{i}": output["decoder_outputs"],
                f"alignment/{i}": output["attn"],
            }
            for tag, tensor in plots.items():
                self.logger.experiment.add_image(
                    tag,
                    plot_tensor(tensor.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )

    def on_before_optimizer_step(self, optimizer):
        """Log the 2-norm gradient statistics from Lightning's ``grad_norm`` under ``grad_norm/<key>``."""
        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})