|
|
|
import os |
|
import pdb |
|
import logging |
|
import argparse |
|
from pathlib import Path |
|
import torch, platform |
|
from pytorch_lightning import seed_everything |
|
from pytorch_lightning import Trainer |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
from pytorch_lightning.strategies import DDPStrategy |
|
from AR.data.data_module import Text2SemanticDataModule |
|
from AR.models.t2s_lightning_module import Text2SemanticLightningModule |
|
from AR.utils.io import load_yaml_config |
|
from GPT_SoVITS.utils.wandb_logger import WandbLoggerWithConfig |
|
|
|
# Quiet noisy third-party libraries: only warnings and above are emitted.
logging.getLogger("numba").setLevel(logging.WARNING)

logging.getLogger("matplotlib").setLevel(logging.WARNING)

# Allow TensorFloat32 for float32 matmuls — trades a little precision for
# substantial speed on GPUs that support TF32.
torch.set_float32_matmul_precision("high")
|
|
|
def my_model_ckpt(
    config,
    if_save_latest,
    if_save_every_weights,
    half_weights_save_dir,
    exp_name,
    **kwargs,
):
    """Build the ModelCheckpoint callback used during training.

    Args:
        config: full experiment config. Currently unread here; kept so the
            call site's keyword interface stays stable.
        if_save_latest: when truthy, additionally maintain a rolling "last"
            checkpoint (sets ``save_last=True`` on the callback).
        if_save_every_weights: accepted but currently unused in this function.
        half_weights_save_dir: accepted but currently unused in this function.
        exp_name: experiment name; used as the checkpoint filename prefix.
        **kwargs: forwarded verbatim to ``ModelCheckpoint`` (e.g. ``dirpath``,
            ``monitor``, ``mode``, ``save_top_k``, ``every_n_epochs``).

    Returns:
        A configured ``pytorch_lightning.callbacks.ModelCheckpoint``.
    """
    if if_save_latest:
        kwargs["save_last"] = True
    # Return the callback directly — the original built a one-element list
    # only to index it back out, which added nothing.
    return ModelCheckpoint(
        **kwargs,
        filename=exp_name + "_{epoch}-{step}",
    )
|
|
|
def main(args):
    """Train the s1 text-to-semantic model described by a YAML config.

    Args:
        args: parsed CLI namespace; only ``args.config_file`` (path to the
            YAML experiment config) is read.

    Side effects:
        Creates ``output_dir`` and ``output_dir/ckpt``, seeds all RNGs,
        starts a Weights & Biases run, and runs a full Lightning fit loop.
    """
    config = load_yaml_config(args.config_file)

    output_dir = Path(config["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)

    # Checkpoints land under <output_dir>/ckpt.
    ckpt_dir = output_dir / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    seed_everything(config["train"]["seed"], workers=True)

    wandb_logger = WandbLoggerWithConfig(config=config)

    ckpt_callback = my_model_ckpt(
        config=config,
        if_save_latest=config["train"]["if_save_latest"],
        if_save_every_weights=config["train"]["if_save_every_weights"],
        half_weights_save_dir=config["train"]["half_weights_save_dir"],
        exp_name=config["train"]["exp_name"],
        save_top_k=-1,  # keep every periodic checkpoint, never prune
        monitor="loss",
        mode="min",
        save_on_train_epoch_end=True,
        every_n_epochs=config["train"]["save_every_n_epoch"],
        dirpath=ckpt_dir,
    )

    data_module = Text2SemanticDataModule(
        config=config,
        train_semantic_path=config.get("train_semantic_path", ""),
        train_phoneme_path=config.get("train_phoneme_path", ""),
    )

    model = Text2SemanticLightningModule(
        config=config,
        output_dir=output_dir,
        is_train=True,
    )

    # Register the model with W&B (gradient/parameter tracking).
    wandb_logger.watch_model(model)

    trainer = Trainer(
        max_epochs=config["train"]["epochs"],
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=-1 if torch.cuda.is_available() else 1,
        benchmark=False,
        fast_dev_run=False,
        # NCCL is not available on Windows, so fall back to gloo there;
        # without CUDA, let Lightning choose the strategy itself.
        strategy=DDPStrategy(
            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
        )
        if torch.cuda.is_available()
        else "auto",
        precision=config["train"]["precision"],
        logger=wandb_logger,
        callbacks=[ckpt_callback],
        # NOTE(review): presumably the data module does its own distributed
        # sampling, hence Lightning's sampler injection is disabled — confirm.
        use_distributed_sampler=False,
    )

    trainer.fit(model, data_module)

    # Bug fix: the original called `wandb.finish()` but `wandb` was never
    # imported, raising a NameError after training completed. Finish the run
    # through the logger's underlying wandb Run instead (WandbLogger exposes
    # it via `.experiment` — TODO confirm WandbLoggerWithConfig keeps this).
    wandb_logger.experiment.finish()
|
|
|
if __name__ == "__main__":
    # CLI entry point: resolve the experiment config path, then launch training.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-c",
        "--config_file",
        type=str,
        default="configs/s1.yaml",
        help="path of config file",
    )
    main(arg_parser.parse_args())
|
|