# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
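"""Training entry point for the text-to-semantic (s1 / AR) model.

Loads a YAML config, sets up a Weights & Biases logger, a ModelCheckpoint
callback, the Text2SemanticDataModule and Text2SemanticLightningModule, and
runs training with a PyTorch Lightning Trainer (DDP when CUDA is available).

Usage sketch (the filename train_t2s.py is assumed from the upstream source
linked above; adjust to wherever this script lives in your checkout):

    python train_t2s.py -c configs/s1.yaml
"""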
import os
import logging
import argparse
from pathlib import Path
import torch, platform
import wandb
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from GPT_SoVITS.utils.wandb_logger import WandbLoggerWithConfig

logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
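# Keys this script reads from the YAML config. The values below are an
# illustrative sketch only (not actual defaults); use your own configs/s1.yaml.
#
#   output_dir: logs/s1
#   train_semantic_path: data/semantic.tsv
#   train_phoneme_path: data/phoneme.txt
#   train:
#     seed: 1234
#     epochs: 15
#     precision: 16-mixed
#     exp_name: my_experiment
#     save_every_n_epoch: 1
#     if_save_latest: true
#     if_save_every_weights: true
#     half_weights_save_dir: logs/s1/half_weights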

def my_model_ckpt(
    config,
    if_save_latest,
    if_save_every_weights,
    half_weights_save_dir,
    exp_name,
    **kwargs,
):
    # config, if_save_every_weights and half_weights_save_dir are accepted to
    # mirror the training config but are not used by this plain ModelCheckpoint.
    if if_save_latest:
        kwargs["save_last"] = True
    return ModelCheckpoint(
        **kwargs,
        filename=exp_name + "_{epoch}-{step}",
    )

def main(args):
    config = load_yaml_config(args.config_file)
    
    output_dir = Path(config["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    
    ckpt_dir = output_dir / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    
    seed_everything(config["train"]["seed"], workers=True)
    
    # Initialize wandb logger
    wandb_logger = WandbLoggerWithConfig(config=config)
    
    ckpt_callback = my_model_ckpt(
        config=config,
        if_save_latest=config["train"]["if_save_latest"],
        if_save_every_weights=config["train"]["if_save_every_weights"],
        half_weights_save_dir=config["train"]["half_weights_save_dir"],
        exp_name=config["train"]["exp_name"],
        save_top_k=-1,
        monitor="loss",
        mode="min",
        save_on_train_epoch_end=True,
        every_n_epochs=config["train"]["save_every_n_epoch"],
        dirpath=ckpt_dir,
    )
    
    # Create data module
    data_module = Text2SemanticDataModule(
        config=config,
        train_semantic_path=config.get("train_semantic_path", ""),
        train_phoneme_path=config.get("train_phoneme_path", "")
    )
    
    # Initialize model with correct parameters
    model = Text2SemanticLightningModule(
        config=config,
        output_dir=output_dir,
        is_train=True
    )
    
    # Watch the model in wandb
    wandb_logger.watch_model(model)
    
    trainer = Trainer(
        max_epochs=config["train"]["epochs"],
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=-1 if torch.cuda.is_available() else 1,
        benchmark=False,
        fast_dev_run=False,
        strategy=DDPStrategy(
            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
        ) if torch.cuda.is_available() else "auto",
        precision=config["train"]["precision"],
        logger=wandb_logger,
        callbacks=[ckpt_callback],
        use_distributed_sampler=False,
    )
    
    trainer.fit(model, data_module)
    wandb.finish()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_file",
        type=str,
        default="configs/s1.yaml",
        help="path of config file",
    )
    args = parser.parse_args()
    main(args)