from __future__ import print_function |
import argparse |
import datetime |
import logging |
logging.getLogger('matplotlib').setLevel(logging.WARNING) |
from copy import deepcopy |
import os |
import torch |
import torch.distributed as dist |
import deepspeed |
from hyperpyyaml import load_hyperpyyaml |
from torch.distributed.elastic.multiprocessing.errors import record |
from cosyvoice.utils.executor import Executor |
from cosyvoice.utils.train_utils import ( |
init_distributed, |
init_dataset_and_dataloader, |
init_optimizer_and_scheduler, |
init_summarywriter, save_model, |
wrap_cuda_model, check_modify_and_save_config) |
def get_args():
    """Parse the training command line and return the resulting namespace.

    The script's own options are declared in ``option_specs`` and registered
    in that order (the order only affects ``--help`` output). The parser is
    then handed to ``deepspeed.add_config_arguments`` so DeepSpeed can attach
    its options before ``sys.argv`` is parsed.

    Returns:
        argparse.Namespace: parsed arguments. ``--ddp.dist_backend`` is stored
        as ``dist_backend`` and ``--deepspeed.save_states`` as ``save_states``.
    """
    # (flags, add_argument keyword arguments) for every script-level option.
    option_specs = (
        (('--train_engine',), dict(default='torch_ddp',
                                   choices=['torch_ddp', 'deepspeed'],
                                   help='Engine for paralleled training')),
        (('--model',), dict(required=True, help='model which will be trained')),
        (('--config',), dict(required=True, help='config file')),
        (('--train_data',), dict(required=True, help='train data file')),
        (('--cv_data',), dict(required=True, help='cv data file')),
        (('--checkpoint',), dict(help='checkpoint model')),
        (('--model_dir',), dict(required=True, help='save model dir')),
        (('--tensorboard_dir',), dict(default='tensorboard',
                                      help='tensorboard log dir')),
        (('--ddp.dist_backend',), dict(dest='dist_backend',
                                       default='nccl',
                                       choices=['nccl', 'gloo'],
                                       help='distributed backend')),
        (('--num_workers',), dict(default=0,
                                  type=int,
                                  help='num of subprocess workers for reading')),
        (('--prefetch',), dict(default=100,
                               type=int,
                               help='prefetch number')),
        (('--pin_memory',), dict(action='store_true',
                                 default=False,
                                 help='Use pinned memory buffers used for reading')),
        (('--use_amp',), dict(action='store_true',
                              default=False,
                              help='Use automatic mixed precision training')),
        (('--deepspeed.save_states',), dict(dest='save_states',
                                            default='model_only',
                                            choices=['model_only', 'model+optimizer'],
                                            help='save model/optimizer states')),
        (('--timeout',), dict(default=60,
                              type=int,
                              help='timeout (in seconds) of cosyvoice_join.')),
    )

    parser = argparse.ArgumentParser(description='training your network')
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    # Keep using the parser object DeepSpeed returns, as the call rebinds it.
    parser = deepspeed.add_config_arguments(parser)
    return parser.parse_args()
@record
def main():
    """Training entry point: load the config, restore state, train by epoch.

    Steps, in order:
      1. Parse the command line and configure root logging at DEBUG level.
      2. Load the hyperpyyaml config, overriding every model entry other than
         the one being trained with ``None`` (``hift`` is also kept when the
         model is ``hifigan``); merge the CLI arguments into ``train_conf``.
      3. Initialise the distributed runtime, datasets/dataloaders and the
         summary writer through the ``cosyvoice.utils.train_utils`` helpers.
      4. If ``--checkpoint`` points at an existing file, load it non-strictly
         into the model and pick up its ``step`` / ``epoch`` counters.
      5. Wrap the model, build optimizers/schedulers, save an ``init``
         snapshot, then run one executor call per remaining epoch.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    # The GAN code path is selected purely by the model name.
    gan = args.model == 'hifigan'

    # Null out every model section that is not needed for this run.
    kept_models = {args.model, 'hift'} if gan else {args.model}
    override_dict = {
        name: None
        for name in ('llm', 'flow', 'hift', 'hifigan')
        if name not in kept_models
    }
    with open(args.config, 'r') as config_file:
        configs = load_hyperpyyaml(config_file, overrides=override_dict)
    if gan:
        configs['train_conf'] = configs['train_conf_gan']
    configs['train_conf'].update(vars(args))

    # Init env for ddp
    init_distributed(args)

    # Datasets and dataloaders; the cv dataset object itself is not used here.
    train_dataset, _cv_dataset, train_data_loader, cv_data_loader = \
        init_dataset_and_dataloader(args, configs, gan)

    # Validate/adjust the config and persist it (done by the helper).
    configs = check_modify_and_save_config(args, configs)

    # Tensorboard summary
    writer = init_summarywriter(args)

    # Load checkpoint, if one was requested and actually exists.
    model = configs[args.model]
    start_step = 0
    start_epoch = -1
    checkpoint_path = args.checkpoint
    if checkpoint_path is not None:
        if not os.path.exists(checkpoint_path):
            logging.warning('checkpoint {} do not exsist!'.format(checkpoint_path))
        else:
            state_dict = torch.load(checkpoint_path, map_location='cpu')
            # strict=False: keys that do not match the model are tolerated.
            model.load_state_dict(state_dict, strict=False)
            start_step = state_dict.get('step', start_step)
            start_epoch = state_dict.get('epoch', start_epoch)

    # Dispatch model from cpu to gpu
    model = wrap_cuda_model(args, model)

    # Optimizers and schedulers (the *_d pair may be None, see check below).
    model, optimizer, scheduler, optimizer_d, scheduler_d = \
        init_optimizer_and_scheduler(args, configs, model, gan)
    scheduler.set_step(start_step)
    if scheduler_d is not None:
        scheduler_d.set_step(start_step)

    # Snapshot of the starting state, saved under the name 'init'.
    info_dict = deepcopy(configs['train_conf'])
    info_dict['step'] = start_step
    info_dict['epoch'] = start_epoch
    save_model(model, 'init', info_dict)

    # Executor carries the running step/epoch counters.
    executor = Executor(gan=gan)
    executor.step = start_step

    # Gradient scaler is only created when mixed precision was requested.
    scaler = None
    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()
    print('start step {} start epoch {}'.format(start_step, start_epoch))

    # Start training loop
    join_timeout = datetime.timedelta(seconds=args.timeout)
    first_epoch = start_epoch + 1
    for epoch in range(first_epoch, info_dict['max_epoch']):
        executor.epoch = epoch
        train_dataset.set_epoch(epoch)
        dist.barrier()
        # A fresh gloo group per epoch, with the user-configured timeout,
        # is handed to the executor and torn down once the epoch is done.
        group_join = dist.new_group(backend="gloo", timeout=join_timeout)
        if gan:
            executor.train_one_epoc_gan(
                model, optimizer, scheduler, optimizer_d, scheduler_d,
                train_data_loader, cv_data_loader,
                writer, info_dict, scaler, group_join)
        else:
            executor.train_one_epoc(
                model, optimizer, scheduler,
                train_data_loader, cv_data_loader,
                writer, info_dict, scaler, group_join)
        dist.destroy_process_group(group_join)
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()