#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Train Parallel WaveGAN."""

import argparse
import logging
import os
import sys

from collections import defaultdict

import matplotlib
import numpy as np
import soundfile as sf
import torch
import yaml

from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from tqdm import tqdm

import parallel_wavegan
import parallel_wavegan.models
import parallel_wavegan.optimizers

from parallel_wavegan.datasets import AudioMelDataset
from parallel_wavegan.datasets import AudioMelSCPDataset
from parallel_wavegan.layers import PQMF
from parallel_wavegan.losses import DiscriminatorAdversarialLoss
from parallel_wavegan.losses import FeatureMatchLoss
from parallel_wavegan.losses import GeneratorAdversarialLoss
from parallel_wavegan.losses import MelSpectrogramLoss
from parallel_wavegan.losses import MultiResolutionSTFTLoss
from parallel_wavegan.utils import read_hdf5

# set to avoid matplotlib error in CLI environment
matplotlib.use("Agg")


class Trainer(object):
    """Customized trainer module for Parallel WaveGAN training."""

    def __init__(
        self,
        steps,
        epochs,
        data_loader,
        sampler,
        model,
        criterion,
        optimizer,
        scheduler,
        config,
        device=torch.device("cpu"),
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
            sampler (dict): Dict of samplers. It must contain "train" and "dev" samplers.
            model (dict): Dict of models. It must contain "generator" and "discriminator" models.
            criterion (dict): Dict of criterions. It must contain "gen_adv" and "dis_adv" criterions.
            optimizer (dict): Dict of optimizers. It must contain "generator" and "discriminator" optimizers.
            scheduler (dict): Dict of schedulers. It must contain "generator" and "discriminator" schedulers.
            config (dict): Config dict loaded from yaml format configuration file.
            device (torch.device): PyTorch device instance.

        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.sampler = sampler
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        self.finish_train = False
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)

    def run(self):
        """Run training."""
        self.tqdm = tqdm(
            initial=self.steps, total=self.config["train_max_steps"], desc="[train]"
        )
        while True:
            # train one epoch
            self._train_epoch()

            # check whether training is finished
            if self.finish_train:
                break

        self.tqdm.close()
        logging.info("Finished training.")

    def save_checkpoint(self, checkpoint_path):
        """Save checkpoint.

        Args:
            checkpoint_path (str): Checkpoint path to be saved.

        """
        state_dict = {
            "optimizer": {
                "generator": self.optimizer["generator"].state_dict(),
                "discriminator": self.optimizer["discriminator"].state_dict(),
            },
            "scheduler": {
                "generator": self.scheduler["generator"].state_dict(),
                "discriminator": self.scheduler["discriminator"].state_dict(),
            },
            "steps": self.steps,
            "epochs": self.epochs,
        }
        if self.config["distributed"]:
            state_dict["model"] = {
                "generator": self.model["generator"].module.state_dict(),
                "discriminator": self.model["discriminator"].module.state_dict(),
            }
        else:
            state_dict["model"] = {
                "generator": self.model["generator"].state_dict(),
                "discriminator": self.model["discriminator"].state_dict(),
            }

        if not os.path.exists(os.path.dirname(checkpoint_path)):
            os.makedirs(os.path.dirname(checkpoint_path))
        torch.save(state_dict, checkpoint_path)
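
    # The saved checkpoint is a plain dict, so it can be inspected directly, e.g.:
    #   torch.load("checkpoint-400000steps.pkl", map_location="cpu").keys()
    #   -> dict_keys(['optimizer', 'scheduler', 'steps', 'epochs', 'model'])
    # (the file name above is just an example of what _check_save_interval produces).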

    def load_checkpoint(self, checkpoint_path, load_only_params=False):
        """Load checkpoint.

        Args:
            checkpoint_path (str): Checkpoint path to be loaded.
            load_only_params (bool): Whether to load only model parameters.

        """
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        if self.config["distributed"]:
            self.model["generator"].module.load_state_dict(
                state_dict["model"]["generator"]
            )
            self.model["discriminator"].module.load_state_dict(
                state_dict["model"]["discriminator"]
            )
        else:
            self.model["generator"].load_state_dict(state_dict["model"]["generator"])
            self.model["discriminator"].load_state_dict(
                state_dict["model"]["discriminator"]
            )
        if not load_only_params:
            self.steps = state_dict["steps"]
            self.epochs = state_dict["epochs"]
            self.optimizer["generator"].load_state_dict(
                state_dict["optimizer"]["generator"]
            )
            self.optimizer["discriminator"].load_state_dict(
                state_dict["optimizer"]["discriminator"]
            )
            self.scheduler["generator"].load_state_dict(
                state_dict["scheduler"]["generator"]
            )
            self.scheduler["discriminator"].load_state_dict(
                state_dict["scheduler"]["discriminator"]
            )

    def _train_step(self, batch):
        """Train model one step."""
        # parse batch
        x, y = batch
        x = tuple([x_.to(self.device) for x_ in x])
        y = y.to(self.device)

        #######################
        #      Generator      #
        #######################
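        # The generator objective built below is
        #   gen_loss = lambda_aux * aux_loss + lambda_adv * (adv_loss + lambda_feat_match * fm_loss),
        # where aux_loss combines the (sub-band) STFT and mel losses; the adversarial
        # term is only added after discriminator_train_start_steps.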
        if self.steps > self.config.get("generator_train_start_steps", 0):
            y_ = self.model["generator"](*x)

            # reconstruct the signal from multi-band signal
            if self.config["generator_params"]["out_channels"] > 1:
                y_mb_ = y_
                y_ = self.criterion["pqmf"].synthesis(y_mb_)

            # initialize
            gen_loss = 0.0

            # multi-resolution stft loss
            if self.config["use_stft_loss"]:
                sc_loss, mag_loss = self.criterion["stft"](y_, y)
                gen_loss += sc_loss + mag_loss
                self.total_train_loss[
                    "train/spectral_convergence_loss"
                ] += sc_loss.item()
                self.total_train_loss[
                    "train/log_stft_magnitude_loss"
                ] += mag_loss.item()

            # subband multi-resolution stft loss
            if self.config["use_subband_stft_loss"]:
                gen_loss *= 0.5  # for balancing with subband stft loss
                y_mb = self.criterion["pqmf"].analysis(y)
                sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
                gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
                self.total_train_loss[
                    "train/sub_spectral_convergence_loss"
                ] += sub_sc_loss.item()
                self.total_train_loss[
                    "train/sub_log_stft_magnitude_loss"
                ] += sub_mag_loss.item()

            # mel spectrogram loss
            if self.config["use_mel_loss"]:
                mel_loss = self.criterion["mel"](y_, y)
                gen_loss += mel_loss
                self.total_train_loss["train/mel_loss"] += mel_loss.item()

            # weighting aux loss
            gen_loss *= self.config.get("lambda_aux", 1.0)

            # adversarial loss
            if self.steps > self.config["discriminator_train_start_steps"]:
                p_ = self.model["discriminator"](y_)
                adv_loss = self.criterion["gen_adv"](p_)
                self.total_train_loss["train/adversarial_loss"] += adv_loss.item()

                # feature matching loss
                if self.config["use_feat_match_loss"]:
                    # no need to track gradients
                    with torch.no_grad():
                        p = self.model["discriminator"](y)
                    fm_loss = self.criterion["feat_match"](p_, p)
                    self.total_train_loss[
                        "train/feature_matching_loss"
                    ] += fm_loss.item()
                    adv_loss += self.config["lambda_feat_match"] * fm_loss

                # add adversarial loss to generator loss
                gen_loss += self.config["lambda_adv"] * adv_loss

            self.total_train_loss["train/generator_loss"] += gen_loss.item()

            # update generator
            self.optimizer["generator"].zero_grad()
            gen_loss.backward()
            if self.config["generator_grad_norm"] > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model["generator"].parameters(),
                    self.config["generator_grad_norm"],
                )
            self.optimizer["generator"].step()
            self.scheduler["generator"].step()

        #######################
        #    Discriminator    #
        #######################
        if self.steps > self.config["discriminator_train_start_steps"]:
            # re-compute y_ which leads to better quality
            with torch.no_grad():
                y_ = self.model["generator"](*x)
            if self.config["generator_params"]["out_channels"] > 1:
                y_ = self.criterion["pqmf"].synthesis(y_)

            # discriminator loss
            p = self.model["discriminator"](y)
            p_ = self.model["discriminator"](y_.detach())
            real_loss, fake_loss = self.criterion["dis_adv"](p_, p)
            dis_loss = real_loss + fake_loss
            self.total_train_loss["train/real_loss"] += real_loss.item()
            self.total_train_loss["train/fake_loss"] += fake_loss.item()
            self.total_train_loss["train/discriminator_loss"] += dis_loss.item()

            # update discriminator
            self.optimizer["discriminator"].zero_grad()
            dis_loss.backward()
            if self.config["discriminator_grad_norm"] > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model["discriminator"].parameters(),
                    self.config["discriminator_grad_norm"],
                )
            self.optimizer["discriminator"].step()
            self.scheduler["discriminator"].step()

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _train_epoch(self):
        """Train model one epoch."""
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            # train one step
            self._train_step(batch)

            # check interval
            if self.config["rank"] == 0:
                self._check_log_interval()
                self._check_eval_interval()
                self._check_save_interval()

            # check whether training is finished
            if self.finish_train:
                return

        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        logging.info(
            f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
            f"({self.train_steps_per_epoch} steps per epoch)."
        )

        # needed for shuffle in distributed training
        if self.config["distributed"]:
            self.sampler["train"].set_epoch(self.epochs)

    def _eval_step(self, batch):
        """Evaluate model one step."""
        # parse batch
        x, y = batch
        x = tuple([x_.to(self.device) for x_ in x])
        y = y.to(self.device)

        #######################
        #      Generator      #
        #######################
        y_ = self.model["generator"](*x)
        if self.config["generator_params"]["out_channels"] > 1:
            y_mb_ = y_
            y_ = self.criterion["pqmf"].synthesis(y_mb_)

        # initialize
        aux_loss = 0.0

        # multi-resolution stft loss
        if self.config["use_stft_loss"]:
            sc_loss, mag_loss = self.criterion["stft"](y_, y)
            aux_loss += sc_loss + mag_loss
            self.total_eval_loss["eval/spectral_convergence_loss"] += sc_loss.item()
            self.total_eval_loss["eval/log_stft_magnitude_loss"] += mag_loss.item()

        # subband multi-resolution stft loss
        if self.config.get("use_subband_stft_loss", False):
            aux_loss *= 0.5  # for balancing with subband stft loss
            y_mb = self.criterion["pqmf"].analysis(y)
            sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
            self.total_eval_loss[
                "eval/sub_spectral_convergence_loss"
            ] += sub_sc_loss.item()
            self.total_eval_loss[
                "eval/sub_log_stft_magnitude_loss"
            ] += sub_mag_loss.item()
            aux_loss += 0.5 * (sub_sc_loss + sub_mag_loss)

        # mel spectrogram loss
        if self.config["use_mel_loss"]:
            mel_loss = self.criterion["mel"](y_, y)
            aux_loss += mel_loss
            self.total_eval_loss["eval/mel_loss"] += mel_loss.item()

        # weighting aux loss
        aux_loss *= self.config.get("lambda_aux", 1.0)

        # adversarial loss
        p_ = self.model["discriminator"](y_)
        adv_loss = self.criterion["gen_adv"](p_)
        gen_loss = aux_loss + self.config["lambda_adv"] * adv_loss

        # feature matching loss
        if self.config["use_feat_match_loss"]:
            p = self.model["discriminator"](y)
            fm_loss = self.criterion["feat_match"](p_, p)
            self.total_eval_loss["eval/feature_matching_loss"] += fm_loss.item()
            gen_loss += (
                self.config["lambda_adv"] * self.config["lambda_feat_match"] * fm_loss
            )

        #######################
        #    Discriminator    #
        #######################
        p = self.model["discriminator"](y)
        p_ = self.model["discriminator"](y_)

        # discriminator loss
        real_loss, fake_loss = self.criterion["dis_adv"](p_, p)
        dis_loss = real_loss + fake_loss

        # add to total eval loss
        self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item()
        self.total_eval_loss["eval/generator_loss"] += gen_loss.item()
        self.total_eval_loss["eval/real_loss"] += real_loss.item()
        self.total_eval_loss["eval/fake_loss"] += fake_loss.item()
        self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item()

    def _eval_epoch(self):
        """Evaluate model one epoch."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")
        # change mode
        for key in self.model.keys():
            self.model[key].eval()

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
            tqdm(self.data_loader["dev"], desc="[eval]"), 1
        ):
            # eval one step
            self._eval_step(batch)

            # save intermediate result
            if eval_steps_per_epoch == 1:
                self._genearete_and_save_intermediate_result(batch)

        logging.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        # average loss
        for key in self.total_eval_loss.keys():
            self.total_eval_loss[key] /= eval_steps_per_epoch
            logging.info(
                f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
            )

        # record
        self._write_to_tensorboard(self.total_eval_loss)

        # reset
        self.total_eval_loss = defaultdict(float)

        # restore mode
        for key in self.model.keys():
            self.model[key].train()

    def _genearete_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        # delayed import to avoid the matplotlib backend error in CLI environment
        import matplotlib.pyplot as plt

        # generate
        x_batch, y_batch = batch
        x_batch = tuple([x.to(self.device) for x in x_batch])
        y_batch = y_batch.to(self.device)
        y_batch_ = self.model["generator"](*x_batch)
        if self.config["generator_params"]["out_channels"] > 1:
            y_batch_ = self.criterion["pqmf"].synthesis(y_batch_)

        # check directory
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 1):
            # convert to ndarray
            y, y_ = y.view(-1).cpu().numpy(), y_.view(-1).cpu().numpy()

            # plot figure and save it
            figname = os.path.join(dirname, f"{idx}.png")
            plt.subplot(2, 1, 1)
            plt.plot(y)
            plt.title("groundtruth speech")
            plt.subplot(2, 1, 2)
            plt.plot(y_)
            plt.title(f"generated speech @ {self.steps} steps")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

            # save as wavfile
            y = np.clip(y, -1, 1)
            y_ = np.clip(y_, -1, 1)
            sf.write(
                figname.replace(".png", "_ref.wav"),
                y,
                self.config["sampling_rate"],
                "PCM_16",
            )
            sf.write(
                figname.replace(".png", "_gen.wav"),
                y_,
                self.config["sampling_rate"],
                "PCM_16",
            )

            if idx >= self.config["num_save_intermediate_results"]:
                break

    def _write_to_tensorboard(self, loss):
        """Write to tensorboard."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        if self.steps % self.config["save_interval_steps"] == 0:
            self.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.steps}steps.pkl")
            )
            logging.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        if self.steps % self.config["log_interval_steps"] == 0:
            for key in self.total_train_loss.keys():
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logging.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)

            # reset
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        if self.steps >= self.config["train_max_steps"]:
            self.finish_train = True


class Collater(object):
    """Customized collater for PyTorch DataLoader in training."""

    def __init__(
        self,
        batch_max_steps=20480,
        hop_size=256,
        aux_context_window=2,
        use_noise_input=False,
    ):
        """Initialize customized collater for PyTorch DataLoader.

        Args:
            batch_max_steps (int): The maximum length of input signal in batch.
            hop_size (int): Hop size of auxiliary features.
            aux_context_window (int): Context window size for auxiliary feature conv.
            use_noise_input (bool): Whether to use noise input.

        """
        if batch_max_steps % hop_size != 0:
            batch_max_steps += -(batch_max_steps % hop_size)
        assert batch_max_steps % hop_size == 0
        self.batch_max_steps = batch_max_steps
        self.batch_max_frames = batch_max_steps // hop_size
        self.hop_size = hop_size
        self.aux_context_window = aux_context_window
        self.use_noise_input = use_noise_input

        # set useful values in random cutting
        self.start_offset = aux_context_window
        self.end_offset = -(self.batch_max_frames + aux_context_window)
        self.mel_threshold = self.batch_max_frames + 2 * aux_context_window
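
        # With the defaults above (batch_max_steps=20480, hop_size=256,
        # aux_context_window=2): batch_max_frames = 20480 // 256 = 80, so only
        # utterances with more than 80 + 2 * 2 = 84 mel frames are kept, each
        # feature segment is 84 frames long, and the corresponding audio segment
        # covers (84 - 2 * 2) * 256 = 20480 samples.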

    def __call__(self, batch):
        """Convert into batch tensors.

        Args:
            batch (list): list of tuple of the pair of audio and features.

        Returns:
            Tensor: Gaussian noise batch (B, 1, T).
            Tensor: Auxiliary feature batch (B, C, T'), where
                T = (T' - 2 * aux_context_window) * hop_size.
            Tensor: Target signal batch (B, 1, T).

        """
        # check length
        batch = [
            self._adjust_length(*b) for b in batch if len(b[1]) > self.mel_threshold
        ]
        xs, cs = [b[0] for b in batch], [b[1] for b in batch]

        # make batch with random cut
        c_lengths = [len(c) for c in cs]
        start_frames = np.array(
            [
                np.random.randint(self.start_offset, cl + self.end_offset)
                for cl in c_lengths
            ]
        )
        x_starts = start_frames * self.hop_size
        x_ends = x_starts + self.batch_max_steps
        c_starts = start_frames - self.aux_context_window
        c_ends = start_frames + self.batch_max_frames + self.aux_context_window
        y_batch = [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)]
        c_batch = [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)]

        # convert each batch to tensor, assume that each item in batch has the same length
        y_batch = torch.tensor(y_batch, dtype=torch.float).unsqueeze(1)  # (B, 1, T)
        c_batch = torch.tensor(c_batch, dtype=torch.float).transpose(2, 1)  # (B, C, T')

        # make input noise signal batch tensor
        if self.use_noise_input:
            z_batch = torch.randn(y_batch.size())  # (B, 1, T)
            return (z_batch, c_batch), y_batch
        else:
            return (c_batch,), y_batch

    def _adjust_length(self, x, c):
        """Adjust the audio and feature lengths.

        Note:
            Basically we assume that the lengths of x and c are adjusted
            in the preprocessing stage, but if we use features processed by
            another library, this adjustment will be needed.

        """
        if len(x) < len(c) * self.hop_size:
            x = np.pad(x, (0, len(c) * self.hop_size - len(x)), mode="edge")

        # check the length is valid
        assert len(x) == len(c) * self.hop_size

        return x, c


def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument(
        "--train-wav-scp",
        default=None,
        type=str,
        help="kaldi-style wav.scp file for training. "
        "you need to specify either train-*-scp or train-dumpdir.",
    )
    parser.add_argument(
        "--train-feats-scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file for training. "
        "you need to specify either train-*-scp or train-dumpdir.",
    )
    parser.add_argument(
        "--train-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for training.",
    )
    parser.add_argument(
        "--train-dumpdir",
        default=None,
        type=str,
        help="directory including training data. "
        "you need to specify either train-*-scp or train-dumpdir.",
    )
    parser.add_argument(
        "--dev-wav-scp",
        default=None,
        type=str,
        help="kaldi-style wav.scp file for validation. "
        "you need to specify either dev-*-scp or dev-dumpdir.",
    )
    parser.add_argument(
        "--dev-feats-scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file for validation. "
        "you need to specify either dev-*-scp or dev-dumpdir.",
    )
    parser.add_argument(
        "--dev-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for validation.",
    )
    parser.add_argument(
        "--dev-dumpdir",
        default=None,
        type=str,
        help="directory including development data. "
        "you need to specify either dev-*-scp or dev-dumpdir.",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="directory to save checkpoints.",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--pretrain",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to load pretrained params. (default="")',
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explicitly specify.",
    )
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if (args.train_feats_scp is not None and args.train_dumpdir is not None) or (
        args.train_feats_scp is None and args.train_dumpdir is None
    ):
        raise ValueError("Please specify either --train-dumpdir or --train-*-scp.")
    if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or (
        args.dev_feats_scp is None and args.dev_dumpdir is None
    ):
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"
        ] + 2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
    if args.train_wav_scp is None or args.dev_wav_scp is None:
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
    if args.train_dumpdir is not None:
        train_dataset = AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        train_dataset = AudioMelSCPDataset(
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of training files = {len(train_dataset)}.")
    if args.dev_dumpdir is not None:
        dev_dataset = AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.")
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get("aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get("generator_type", "ParallelWaveGANGenerator")
        in ["ParallelWaveGANGenerator"],
    )
    sampler = {"train": None, "dev": None}
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler

        sampler["train"] = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        sampler["dev"] = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["train"],
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["dev"],
            pin_memory=config["pin_memory"],
        ),
    }

    # define models
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": generator_class(
            **config["generator_params"],
        ).to(device),
        "discriminator": discriminator_class(
            **config["discriminator_params"],
        ).to(device),
    }

    # define criterions
    criterion = {
        "gen_adv": GeneratorAdversarialLoss(
            # keep compatibility
            **config.get("generator_adv_loss_params", {})
        ).to(device),
        "dis_adv": DiscriminatorAdversarialLoss(
            # keep compatibility
            **config.get("discriminator_adv_loss_params", {})
        ).to(device),
    }
    if config.get("use_stft_loss", True):  # keep compatibility
        config["use_stft_loss"] = True
        criterion["stft"] = MultiResolutionSTFTLoss(
            **config["stft_loss_params"],
        ).to(device)
    if config.get("use_subband_stft_loss", False):  # keep compatibility
        assert config["generator_params"]["out_channels"] > 1
        criterion["sub_stft"] = MultiResolutionSTFTLoss(
            **config["subband_stft_loss_params"],
        ).to(device)
    else:
        config["use_subband_stft_loss"] = False
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["feat_match"] = FeatureMatchLoss(
            # keep compatibility
            **config.get("feat_match_loss_params", {}),
        ).to(device)
    else:
        config["use_feat_match_loss"] = False
    if config.get("use_mel_loss", False):  # keep compatibility
        if config.get("mel_loss_params", None) is None:
            criterion["mel"] = MelSpectrogramLoss(
                fs=config["sampling_rate"],
                fft_size=config["fft_size"],
                hop_size=config["hop_size"],
                win_length=config["win_length"],
                window=config["window"],
                num_mels=config["num_mels"],
                fmin=config["fmin"],
                fmax=config["fmax"],
            ).to(device)
        else:
            criterion["mel"] = MelSpectrogramLoss(
                **config["mel_loss_params"],
            ).to(device)
    else:
        config["use_mel_loss"] = False

    # define special module for subband processing
    if config["generator_params"]["out_channels"] > 1:
        criterion["pqmf"] = PQMF(
            subbands=config["generator_params"]["out_channels"],
            # keep compatibility
            **config.get("pqmf_params", {}),
        ).to(device)
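
    # note: when the generator has multiple output channels (a multi-band setup),
    # it predicts subband waveforms; the trainer then uses criterion["pqmf"].synthesis()
    # to recombine them into a full-band waveform before computing the losses, and
    # criterion["pqmf"].analysis() to split the target signal for the subband stft loss.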

    # define optimizers and schedulers
    generator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("generator_optimizer_type", "RAdam"),
    )
    discriminator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("discriminator_optimizer_type", "RAdam"),
    )
    optimizer = {
        "generator": generator_optimizer_class(
            model["generator"].parameters(),
            **config["generator_optimizer_params"],
        ),
        "discriminator": discriminator_optimizer_class(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"],
        ),
    }
    generator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("generator_scheduler_type", "StepLR"),
    )
    discriminator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("discriminator_scheduler_type", "StepLR"),
    )
    scheduler = {
        "generator": generator_scheduler_class(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"],
        ),
        "discriminator": discriminator_scheduler_class(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"],
        ),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])

    # show settings
    logging.info(model["generator"])
    logging.info(model["discriminator"])
    logging.info(optimizer["generator"])
    logging.info(optimizer["discriminator"])
    logging.info(scheduler["generator"])
    logging.info(scheduler["discriminator"])
    for criterion_ in criterion.values():
        logging.info(criterion_)

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        sampler=sampler,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if len(args.pretrain) != 0:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully loaded parameters from {args.pretrain}.")

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl")
        )
        logging.info(f"Successfully saved checkpoint @ {trainer.steps} steps.")


if __name__ == "__main__":
    main()