Upload train_vocoder.py with huggingface_hub
Browse files- train_vocoder.py +77 -0
train_vocoder.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
|
4 |
+
from trainer import Trainer, TrainerArgs
|
5 |
+
|
6 |
+
from TTS.config import load_config, register_config
|
7 |
+
from TTS.utils.audio import AudioProcessor
|
8 |
+
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
9 |
+
from TTS.vocoder.models import setup_model
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
class TrainVocoderArgs(TrainerArgs):
    """Trainer arguments extended with the location of the vocoder config file."""

    # Path of the config file to load; left as None when no file is given.
    config_path: str = field(
        default=None,
        metadata={"help": "Path to the config file."},
    )
|
15 |
+
|
16 |
+
|
17 |
+
def main():
    """Run vocoder model training from a `config.json` file or console arguments.

    Trainer arguments are parsed from the command line; arguments the trainer
    parser does not consume are kept and applied as config overrides.

    The training config is resolved in this order:

    1. ``config_path`` is set: load that config file.
    2. ``continue_path`` is set: load ``config.json`` from that directory.
    3. neither is set: build the config from the console arguments alone.
    """
    # init trainer args
    train_args = TrainVocoderArgs()
    parser = train_args.init_argparse(arg_prefix="")

    # override trainer args from command-line args; leftovers become config overrides
    args, config_overrides = parser.parse_known_args()
    train_args.parse_args(args)

    # load config.json and register
    if args.config_path or args.continue_path:
        # An explicit config file takes priority over the config stored in a
        # previous experiment directory.
        config_file = args.config_path or os.path.join(args.continue_path, "config.json")
        config = load_config(config_file)
        if config_overrides:
            config.parse_known_args(config_overrides, relaxed_parser=True)
    else:
        # init from console args
        from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

        # The base config is only used to read the model name out of the
        # overrides; the concrete config class is then looked up by that name.
        config_base = BaseTrainingConfig()
        config_base.parse_known_args(config_overrides)
        config = register_config(config_base.model)()

    # load training samples
    if "feature_path" in config and config.feature_path:
        # load pre-computed features
        print(f" > Loading features from: {config.feature_path}")
        eval_samples, train_samples = load_wav_feat_data(config.data_path, config.feature_path, config.eval_split_size)
    else:
        # load data raw wav files
        eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**config.audio)

    # init the model from config
    model = setup_model(config)

    # init the trainer and 🚀
    trainer = Trainer(
        train_args,
        config,
        config.output_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        training_assets={"audio_processor": ap},
        # command-line args were already parsed above, so the trainer must not re-parse them
        parse_command_line_args=False,
    )
    trainer.fit()
|
74 |
+
|
75 |
+
|
76 |
+
# Start training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|