###
# Author: Kai Li
# Date: 2022-05-26 18:09:54
# Email: [email protected]
# LastEditTime: 2024-01-24 00:00:28
###
import gc
from collections.abc import MutableMapping

import torch
import pytorch_lightning as pl
from omegaconf import ListConfig, OmegaConf
from torch.optim.lr_scheduler import ReduceLROnPlateau

def flatten_dict(d, parent_key="", sep="_"):
    """Flattens a nested dictionary into a single-level dictionary while
    preserving parent keys. Taken from
    `SO <https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys>`_

    Args:
        d (MutableMapping): Dictionary to be flattened.
        parent_key (str): String to use as a prefix to all subsequent keys.
        sep (str): String to use as a separator between two key levels.

    Returns:
        dict: Single-level dictionary, flattened.
    """
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            # Recurse into nested mappings, prefixing children with the parent key.
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)
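
# Illustrative behaviour (derived from the code above, not from the original docstring):
#   flatten_dict({"optim": {"lr": 1e-3, "betas": [0.9, 0.999]}, "seed": 42})
#   -> {"optim_lr": 0.001, "optim_betas": [0.9, 0.999], "seed": 42}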


class AudioLightningModule(pl.LightningModule):
    def __init__(
        self,
        model=None,
        discriminator=None,
        optimizer=None,
        loss_func=None,
        metrics=None,
        scheduler=None,
    ):
        super().__init__()
        self.audio_model = model
        self.discriminator = discriminator
        # One optimizer and one scheduler per network are expected:
        # generator first, discriminator second (see training_step).
        self.optimizer = list(optimizer)
        self.loss_func = loss_func
        self.metrics = metrics
        self.scheduler = list(scheduler)
        # Save lightning's AttributeDict under self.hparams
        self.default_monitor = "val_loss"
        # self.print(self.audio_model)
        self.validation_step_outputs = []
        self.test_step_outputs = []
        # GAN training: both networks are updated manually in training_step.
        self.automatic_optimization = False

    def forward(self, wav):
        """Applies the forward pass of the model.

        Returns:
            :class:`torch.Tensor`
        """
        return self.audio_model(wav)

    def training_step(self, batch, batch_nb):
        ori_data, codec_data = batch
        optimizer_g, optimizer_d = self.optimizers()
        # multiple schedulers
        scheduler_g, scheduler_d = self.lr_schedulers()

        # train discriminator (the generator output is detached so only D receives gradients)
        optimizer_g.zero_grad()
        output = self(codec_data)
        optimizer_d.zero_grad()
        est_outputs, _ = self.discriminator(output.detach(), sample_rate=44100)
        target_outputs, _ = self.discriminator(ori_data, sample_rate=44100)
        loss_d = self.loss_func["d"](target_outputs, est_outputs)
        self.manual_backward(loss_d)
        self.clip_gradients(optimizer_d, gradient_clip_val=5, gradient_clip_algorithm="norm")
        optimizer_d.step()

        # train generator (loss_func["g"] sees D outputs, both feature maps and both waveforms)
        est_outputs, est_feature_maps = self.discriminator(output, sample_rate=44100)
        _, targets_feature_maps = self.discriminator(ori_data, sample_rate=44100)
        loss_g = self.loss_func["g"](est_outputs, est_feature_maps, targets_feature_maps, output, ori_data)
        self.manual_backward(loss_g)
        self.clip_gradients(optimizer_g, gradient_clip_val=5, gradient_clip_algorithm="norm")
        optimizer_g.step()

        # step the epoch-wise schedulers once per epoch, on the last batch
        if self.trainer.is_last_batch:
            scheduler_g.step()
            scheduler_d.step()

        self.log(
            "train_loss_d",
            loss_d,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
            logger=True,
        )
        self.log(
            "train_loss_g",
            loss_g,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
            logger=True,
        )
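
    # training_step above assumes (not enforced here) that:
    #   * each batch is an (original_audio, codec_processed_audio) pair,
    #   * self.discriminator(wav, sample_rate=...) returns (outputs, feature_maps),
    #   * self.loss_func is a mapping with "d" and "g" criteria.
    # A toy wiring of these pieces is sketched at the bottom of this file.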

    def validation_step(self, batch, batch_nb):
        # validation loss: metric between the model output and the original audio
        ori_data, codec_data = batch
        est_sources = self(codec_data)
        loss = self.metrics(est_sources, ori_data)
        self.log(
            "val_loss",
            loss,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
            logger=True,
        )
        self.validation_step_outputs.append(loss)
        return {"val_loss": loss}

    def on_validation_epoch_end(self):
        # aggregate per-step validation losses across steps and devices
        avg_loss = torch.stack(self.validation_step_outputs).mean()
        val_loss = torch.mean(self.all_gather(avg_loss))
        self.log(
            "lr",
            self.optimizer[0].param_groups[0]["lr"],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        # self.logger.experiment is assumed to expose a wandb-style .log() API.
        self.logger.experiment.log(
            {"learning_rate": self.optimizer[0].param_groups[0]["lr"], "epoch": self.current_epoch}
        )
        self.logger.experiment.log(
            {"val_pit_sisnr": -val_loss, "epoch": self.current_epoch}
        )
        self.validation_step_outputs.clear()  # free memory
        torch.cuda.empty_cache()

    def test_step(self, batch, batch_nb):
        mixtures, targets = batch
        est_sources = self(mixtures)
        loss = self.metrics(est_sources, targets)
        self.log(
            "test_loss",
            loss,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
            logger=True,
        )
        self.test_step_outputs.append(loss)
        return {"test_loss": loss}

    def on_test_epoch_end(self):
        # aggregate per-step test losses across steps and devices
        avg_loss = torch.stack(self.test_step_outputs).mean()
        test_loss = torch.mean(self.all_gather(avg_loss))
        self.log(
            "lr",
            # self.optimizer is a list (see __init__), so index the generator optimizer.
            self.optimizer[0].param_groups[0]["lr"],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.logger.experiment.log(
            {"learning_rate": self.optimizer[0].param_groups[0]["lr"], "epoch": self.current_epoch}
        )
        self.logger.experiment.log(
            {"test_pit_sisnr": -test_loss, "epoch": self.current_epoch}
        )
        self.test_step_outputs.clear()

    def configure_optimizers(self):
        """Initialize optimizers, batch-wise and epoch-wise schedulers."""
        if self.scheduler is None:
            return self.optimizer

        if not isinstance(self.scheduler, (list, tuple)):
            self.scheduler = [self.scheduler]  # support multiple schedulers
        if not isinstance(self.optimizer, (list, tuple)):
            self.optimizer = [self.optimizer]  # support multiple optimizers

        epoch_schedulers = []
        for sched in self.scheduler:
            if not isinstance(sched, dict):
                if isinstance(sched, ReduceLROnPlateau):
                    # ReduceLROnPlateau needs a metric to monitor.
                    sched = {"scheduler": sched, "monitor": self.default_monitor}
                epoch_schedulers.append(sched)
            else:
                sched.setdefault("monitor", self.default_monitor)
                sched.setdefault("frequency", 1)
                # Backward compat: "batch" used to mean per-step scheduling.
                if sched["interval"] == "batch":
                    sched["interval"] = "step"
                assert sched["interval"] in [
                    "epoch",
                    "step",
                ], "Scheduler interval should be either step or epoch"
                epoch_schedulers.append(sched)
        return self.optimizer, epoch_schedulers
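

# Illustrative return shapes for configure_optimizers above (derived from the code):
# with two optimizers and two plain schedulers it returns the two-list form Lightning
# accepts, ([optimizer_g, optimizer_d], [scheduler_g, scheduler_d]); with two
# ReduceLROnPlateau schedulers it instead returns
#   ([optimizer_g, optimizer_d],
#    [{"scheduler": sched_g, "monitor": "val_loss"},
#     {"scheduler": sched_d, "monitor": "val_loss"}])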


def config_to_hparams(dic):
    """Sanitizes the config dict to be handled correctly by torch
    SummaryWriter. It flattens the config dict, converts ``None`` to
    ``"None"`` and any list or tuple into a torch.Tensor.

    Args:
        dic (dict): Dictionary to be transformed.

    Returns:
        dict: Transformed dictionary.
    """
    dic = flatten_dict(dic)
    for k, v in dic.items():
        if v is None:
            dic[k] = str(v)
        elif isinstance(v, (list, tuple)):
            dic[k] = torch.tensor(v)
    return dic
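

# Minimal wiring sketch (illustrative only, not part of the original module): the real
# project supplies its own generator, discriminator, loss functions, logger and data.
# The toy networks and least-squares GAN criteria below are placeholders that merely
# match the call signatures used by AudioLightningModule.training_step.
if __name__ == "__main__":
    import torch.nn as nn
    import torch.nn.functional as F

    class _ToyGenerator(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Conv1d(1, 1, kernel_size=3, padding=1)

        def forward(self, wav):
            return self.net(wav)

    class _ToyDiscriminator(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Conv1d(1, 1, kernel_size=3, padding=1)

        def forward(self, wav, sample_rate=44100):
            # Mirrors the (outputs, feature_maps) interface expected in training_step.
            feat = self.net(wav)
            return feat.mean(dim=(1, 2)), [feat]

    gen, disc = _ToyGenerator(), _ToyDiscriminator()
    opt_g = torch.optim.AdamW(gen.parameters(), lr=1e-4)
    opt_d = torch.optim.AdamW(disc.parameters(), lr=1e-4)
    sched_g = torch.optim.lr_scheduler.ExponentialLR(opt_g, gamma=0.99)
    sched_d = torch.optim.lr_scheduler.ExponentialLR(opt_d, gamma=0.99)

    module = AudioLightningModule(
        model=gen,
        discriminator=disc,
        optimizer=[opt_g, opt_d],
        loss_func={
            # Placeholder least-squares GAN criteria with the signatures used above.
            "d": lambda real_out, fake_out: ((real_out - 1) ** 2).mean() + (fake_out ** 2).mean(),
            "g": lambda fake_out, fake_feats, real_feats, est, ref: ((fake_out - 1) ** 2).mean()
            + F.l1_loss(est, ref),
        },
        metrics=lambda est, ref: F.l1_loss(est, ref),
        scheduler=[sched_g, sched_d],
    )
    print(module)

    # A real run additionally needs a wandb-backed logger (the *_epoch_end hooks call
    # self.logger.experiment.log) and dataloaders yielding (original, codec) pairs, e.g.:
    # trainer = pl.Trainer(max_epochs=1)
    # trainer.fit(module, train_dataloaders=..., val_dataloaders=...)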