Spaces:
Running
Running
Delete s1_train.py
Browse files- s1_train.py +0 -171
s1_train.py
DELETED
@@ -1,171 +0,0 @@
|
|
1 |
-
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
|
2 |
-
import os
import pdb

# Copy the forwarded "_CUDA_VISIBLE_DEVICES" value (when present) into
# "CUDA_VISIBLE_DEVICES". This runs before `import torch` below.
# NOTE(review): presumably a launcher sets the underscore-prefixed variable so
# the device mask is in place before CUDA initialises - confirm with the caller.
_forwarded_device_mask = os.environ.get("_CUDA_VISIBLE_DEVICES")
if _forwarded_device_mask is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = _forwarded_device_mask
del _forwarded_device_mask
import argparse
import logging
from pathlib import Path

import torch
import platform
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config

# Raise the log threshold of two chatty third-party loggers to WARNING.
for _noisy_logger in ("numba", "matplotlib"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger
torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt

from collections import OrderedDict
|
27 |
-
|
28 |
-
|
29 |
-
class my_model_ckpt(ModelCheckpoint):
    """Checkpoint callback with "keep only the latest" and half-precision export.

    On every epoch that matches ``every_n_epochs`` the callback saves a regular
    checkpoint through the parent class and, depending on its flags:

    * removes the files that were in ``dirpath`` before that save, and/or
    * writes the module's state dict, cast with ``.half()``, to
      ``<half_weights_save_dir>/<exp_name>-e<epoch>.ckpt``.

    Args:
        config: Object stored under the ``"config"`` key of every exported
            half-precision weights file.
        if_save_latest: When equal to ``True``, files found in ``dirpath``
            before a save are deleted after the new checkpoint is written.
        if_save_every_weights: When equal to ``True``, a half-precision
            weights file is exported on every matching epoch.
        half_weights_save_dir: Directory the exported weights are written to.
        exp_name: Prefix of the exported weights file name.
        **kwargs: Forwarded unchanged to ``ModelCheckpoint.__init__``.
    """

    def __init__(
        self,
        config,
        if_save_latest,
        if_save_every_weights,
        half_weights_save_dir,
        exp_name,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.if_save_latest = if_save_latest
        self.if_save_every_weights = if_save_every_weights
        self.half_weights_save_dir = half_weights_save_dir
        self.exp_name = exp_name
        self.config = config

    def on_train_epoch_end(self, trainer, pl_module):
        """Save checkpoints at the end of a training epoch.

        Args:
            trainer: The running trainer; its ``current_epoch`` decides
                whether this epoch matches the save interval.
            pl_module: Unused; part of the callback hook signature.
        """
        if self._should_skip_saving_checkpoint(trainer):
            return
        if not self._should_save_on_train_epoch_end(trainer):
            return

        monitor_candidates = self._monitor_candidates(trainer)
        epoch_number = trainer.current_epoch + 1
        if self._every_n_epochs >= 1 and epoch_number % self._every_n_epochs == 0:
            # Strict comparison kept on purpose so that non-bool config values
            # behave exactly as before.
            keep_latest_only = self.if_save_latest == True  # noqa: E712
            # If only the latest ckpt is to be kept, every ckpt that existed
            # before this save is cleaned up once the next one is written.
            stale_names = os.listdir(self.dirpath) if keep_latest_only else []
            self._save_topk_checkpoint(trainer, monitor_candidates)
            self._remove_stale_files(stale_names)
            if self.if_save_every_weights == True:  # noqa: E712
                self._export_half_weights(trainer, epoch_number)
        self._save_last_checkpoint(trainer, monitor_candidates)

    def _remove_stale_files(self, names):
        """Best-effort removal of the entries ``names`` inside ``self.dirpath``."""
        for name in names:
            try:
                os.remove("%s/%s" % (self.dirpath, name))
            except OSError:
                # Cleanup is best effort: the entry may be a directory, may
                # already be gone, or may not be removable. Only filesystem
                # errors are ignored; KeyboardInterrupt and programming errors
                # now propagate instead of being swallowed.
                continue

    def _export_half_weights(self, trainer, epoch_number):
        """Write the module's state dict, cast with ``.half()``, plus config and info tag."""
        state = trainer.strategy._lightning_module.state_dict()
        to_save_od = OrderedDict()
        to_save_od["weight"] = OrderedDict(
            (key, tensor.half()) for key, tensor in state.items()
        )
        to_save_od["config"] = self.config
        to_save_od["info"] = "GPT-e%s" % epoch_number
        torch.save(
            to_save_od,
            "%s/%s-e%s.ckpt"
            % (
                self.half_weights_save_dir,
                self.exp_name,
                epoch_number,
            ),
        )
|
84 |
-
|
85 |
-
|
86 |
-
def main(args):
    """Run text-to-semantic training as described by a YAML config file.

    Loads the config at ``args.config_file``, creates ``output_dir`` and its
    ``ckpt`` sub-directory, seeds the run, builds the checkpoint callback,
    TensorBoard logger, trainer, model and data module, then starts
    ``trainer.fit`` - resuming from a checkpoint in the ``ckpt`` directory
    when one can be selected.

    Args:
        args: Parsed command-line namespace; only ``config_file`` is read.
    """
    config = load_yaml_config(args.config_file)

    output_dir = Path(config["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)

    ckpt_dir = output_dir / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    train_cfg = config["train"]
    seed_everything(train_cfg["seed"], workers=True)

    checkpoint_cb = my_model_ckpt(
        config=config,
        if_save_latest=train_cfg["if_save_latest"],
        if_save_every_weights=train_cfg["if_save_every_weights"],
        half_weights_save_dir=train_cfg["half_weights_save_dir"],
        exp_name=train_cfg["exp_name"],
        save_top_k=-1,
        monitor="top_3_acc",
        mode="max",
        save_on_train_epoch_end=True,
        every_n_epochs=train_cfg["save_every_n_epoch"],
        dirpath=ckpt_dir,
    )
    tb_logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)

    # "nccl" everywhere except Windows, where "gloo" is used instead.
    if platform.system() == "Windows":
        ddp_backend = "gloo"
    else:
        ddp_backend = "nccl"

    # Validation is switched off: zero validation batches and no sanity-check
    # steps are requested.
    trainer = Trainer(
        max_epochs=train_cfg["epochs"],
        accelerator="gpu",
        limit_val_batches=0,
        devices=-1,
        benchmark=False,
        fast_dev_run=False,
        strategy=DDPStrategy(process_group_backend=ddp_backend),
        precision=train_cfg["precision"],
        logger=tb_logger,
        num_sanity_val_steps=0,
        callbacks=[checkpoint_cb],
    )

    model = Text2SemanticLightningModule(config, output_dir)

    # Only the training paths are passed; no dev/validation paths are given.
    data_module = Text2SemanticDataModule(
        config,
        train_semantic_path=config["train_semantic_path"],
        train_phoneme_path=config["train_phoneme_path"],
    )

    # Pick a checkpoint to resume from. Per the original comment,
    # get_newest_ckpt matches the numeric part of the file names with a regex
    # and sorts by that number (helper not visible here - confirm). Any
    # failure while selecting means "start from scratch".
    try:
        ckpt_path = ckpt_dir / get_newest_ckpt(os.listdir(ckpt_dir))
    except Exception:
        ckpt_path = None
    print("ckpt_path:", ckpt_path)
    trainer.fit(model, data_module, ckpt_path=ckpt_path)
|
148 |
-
|
149 |
-
|
150 |
-
# Example: srun --gpus-per-node=1 --ntasks-per-node=1 python s1_train.py --config_file configs/s1longer.yaml
|
151 |
-
if __name__ == "__main__":
    # Command-line entry point: the YAML config path is the only option.
    # Dataset paths and the output directory are read from that config
    # inside main().
    config_option = {
        "type": str,
        "default": "configs/s1longer.yaml",
        "help": "path of config file",
    }
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config_file", **config_option)

    args = parser.parse_args()
    logging.info(str(args))
    main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|