# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import argparse
import logging
import platform
from pathlib import Path

import torch
import wandb
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy

from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from GPT_SoVITS.utils.wandb_logger import WandbLoggerWithConfig

# Silence noisy third-party loggers; allow TF32 float32 matmuls where supported.
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
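

# Build the ModelCheckpoint callback used during s1 (text-to-semantic) training.
# Note: `config`, `if_save_every_weights`, and `half_weights_save_dir` are accepted
# but currently unused in this function.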
def my_model_ckpt(
    config,
    if_save_latest,
    if_save_every_weights,
    half_weights_save_dir,
    exp_name,
    **kwargs,
):
    if if_save_latest:
        kwargs["save_last"] = True
    return ModelCheckpoint(
        **kwargs,
        filename=exp_name + "_{epoch}-{step}",
    )
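

# Training entry point: load the YAML config, build the data module and model,
# and launch PyTorch Lightning training with wandb logging.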
def main(args):
    config = load_yaml_config(args.config_file)

    output_dir = Path(config["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = output_dir / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    seed_everything(config["train"]["seed"], workers=True)

    # Initialize wandb logger
    wandb_logger = WandbLoggerWithConfig(config=config)

    ckpt_callback = my_model_ckpt(
        config=config,
        if_save_latest=config["train"]["if_save_latest"],
        if_save_every_weights=config["train"]["if_save_every_weights"],
        half_weights_save_dir=config["train"]["half_weights_save_dir"],
        exp_name=config["train"]["exp_name"],
        save_top_k=-1,
        monitor="loss",
        mode="min",
        save_on_train_epoch_end=True,
        every_n_epochs=config["train"]["save_every_n_epoch"],
        dirpath=ckpt_dir,
    )
    # Create data module
    data_module = Text2SemanticDataModule(
        config=config,
        train_semantic_path=config.get("train_semantic_path", ""),
        train_phoneme_path=config.get("train_phoneme_path", ""),
    )

    # Initialize the model
    model = Text2SemanticLightningModule(
        config=config,
        output_dir=output_dir,
        is_train=True,
    )

    # Watch the model in wandb
    wandb_logger.watch_model(model)
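
    # Configure the Trainer: all visible GPUs with DDP (NCCL, or gloo on Windows),
    # or a single CPU device when CUDA is unavailable.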
    trainer = Trainer(
        max_epochs=config["train"]["epochs"],
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=-1 if torch.cuda.is_available() else 1,
        benchmark=False,
        fast_dev_run=False,
        strategy=DDPStrategy(
            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
        )
        if torch.cuda.is_available()
        else "auto",
        precision=config["train"]["precision"],
        logger=wandb_logger,
        callbacks=[ckpt_callback],
        use_distributed_sampler=False,
    )

    trainer.fit(model, data_module)
    wandb.finish()
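

# CLI entry point, e.g. `python s1_train.py -c configs/s1.yaml`.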
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_file",
        type=str,
        default="configs/s1.yaml",
        help="path to the config file",
    )
    args = parser.parse_args()
    main(args)