import os
from pathlib import Path
import torch.nn.functional as F
from omegaconf import OmegaConf
import torch
import torchaudio
from tqdm.auto import tqdm
from dataset import DiffusionCollater, DiffusionDataset
from ldm.util import instantiate_from_config
from ttts.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from ttts.utils.utils import clean_checkpoints, plot_spectrogram_to_numpy, summarize
from accelerate import Accelerator
from vocos import Vocos
from ttts.AA_diffusion.cldm.cldm import denormalize_tacotron_mel
from torch.utils.data import DataLoader
from torch.optim import AdamW
from datetime import datetime
from ttts.utils.infer_utils import load_model
# import utils
from torch.utils.tensorboard import SummaryWriter
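# Training entry point for the AA_diffusion mel-spectrogram diffusion model:
# a frozen GPT provides conditioning latents, the diffusion model is trained on
# target mels, and Vocos is used to vocode sampled mels for periodic logging.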
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_grad_norm(model):
    # Global L2 norm of all parameter gradients; parameters that received no
    # gradient this step are reported and skipped instead of raising.
    total_norm = 0
    for name, p in model.named_parameters():
        if p.grad is None:
            print(name)
            continue
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm
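# Yield batches forever so the training loop can simply call next() on it.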
def cycle(dl):
while True:
for data in dl:
yield data
def create_model(config_path):
config = OmegaConf.load(config_path)
model = instantiate_from_config(config.model).cpu()
print(f'Loaded model config from [{config_path}]')
return model
def get_state_dict(d):
return d.get('state_dict', d)
def load_state_dict(ckpt_path, location='cpu'):
_, extension = os.path.splitext(ckpt_path)
if extension.lower() == ".safetensors":
import safetensors.torch
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
else:
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
state_dict = get_state_dict(state_dict)
print(f'Loaded state_dict from [{ckpt_path}]')
return state_dict
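# Trainer ties together model construction, the diffusion dataset/dataloader,
# the AdamW optimizer, periodic sampling/checkpointing, and multi-process
# training via HuggingFace Accelerate.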
class Trainer(object):
def __init__(
self,
cfg_path = 'ttts/AA_diffusion/config.yaml',
):
super().__init__()
self.cfg = OmegaConf.load(cfg_path)
self.accelerator = Accelerator()
# model
self.vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
self.gpt = load_model('gpt',self.cfg['dataset']['gpt_path'],'ttts/gpt/config.json','cuda')
self.model = create_model(cfg_path)
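        # Each GPT mel code is assumed to span mel_length_compression mel frames;
        # used below when passing output lengths to the GPT latent extraction.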
self.mel_length_compression = 4
print("model params:", count_parameters(self.model))
# sampling and training hyperparameters
self.save_and_sample_every = self.cfg['train']['save_and_sample_every']
self.gradient_accumulate_every = self.cfg['train']['gradient_accumulate_every']
self.train_num_steps = self.cfg['train']['train_num_steps']
# dataset and dataloader
self.dataset = DiffusionDataset(self.cfg)
dl = DataLoader(self.dataset, **self.cfg['dataloader'], collate_fn=DiffusionCollater())
dl = self.accelerator.prepare(dl)
self.dl = cycle(dl)
# optimizer
self.opt = AdamW(self.model.parameters(), lr = self.cfg['train']['train_lr'], betas = self.cfg['train']['adam_betas'])
# for logging results in a folder periodically
if self.accelerator.is_main_process:
# eval_ds = TestDataset(self.cfg['data']['val_files'], self.cfg, self.vocos)
# self.eval_dl = DataLoader(eval_ds, batch_size = 1, shuffle = False, num_workers = self.cfg['train']['num_workers'])
# self.eval_dl = iter(cycle(self.eval_dl))
now = datetime.now()
self.logs_folder = Path(self.cfg['train']['logs_folder']+'/'+now.strftime("%Y-%m-%d-%H-%M-%S"))
self.logs_folder.mkdir(exist_ok = True, parents=True)
# step counter state
self.step = 0
# prepare model, dataloader, optimizer with accelerator
self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
@property
def device(self):
return self.accelerator.device
def save(self, milestone):
if not self.accelerator.is_local_main_process:
return
data = {
'step': self.step,
'model': self.accelerator.get_state_dict(self.model),
}
torch.save(data, str(self.logs_folder / f'model-{milestone}.pt'))
def load(self, model_path):
accelerator = self.accelerator
device = accelerator.device
data = torch.load(model_path, map_location=device)
self.step = data['step']
saved_state_dict = data['model']
model = self.accelerator.unwrap_model(self.model)
# del saved_state_dict['cond_stage_model.visual.positional_embedding']
# del saved_state_dict['cond_stage_model.visual.conv1.weight']
model.load_state_dict(saved_state_dict)
def train(self):
# print(1)
accelerator = self.accelerator
device = accelerator.device
if accelerator.is_main_process:
writer = SummaryWriter(log_dir=self.logs_folder)
writer_eval = SummaryWriter(log_dir=os.path.join(self.logs_folder, "eval"))
with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:
while self.step < self.train_num_steps:
# with torch.autograd.detect_anomaly():
for _ in range(self.gradient_accumulate_every):
data = next(self.dl)
data = {k: v.to(self.device) for k, v in data.items()}
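                    # The frozen GPT turns reference mel, text tokens and mel codes into
                    # per-frame conditioning latents; no gradients are needed for it.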
with torch.no_grad():
latent = self.gpt(data['padded_mel_refer'], data['padded_text'],
torch.tensor([data['padded_text'].shape[-1]], device=device), data['padded_mel_code'],
torch.tensor([data['padded_mel_code'].shape[-1]*self.mel_length_compression], device=device),
return_latent=True, clip_inputs=False).transpose(1,2)
latent = F.interpolate(latent, size=data['padded_mel'].shape[-1], mode='nearest')
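                    # Batch keys follow the ControlNet-style cldm interface:
                    # 'jpg' = target mel, 'txt' = reference mel, 'hint' = GPT conditioning latent.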
data_ = dict(jpg=data['padded_mel'], txt=data['padded_mel_refer'], hint=latent)
with self.accelerator.autocast():
loss = accelerator.unwrap_model(self.model).training_step(data_)
model = accelerator.unwrap_model(self.model)
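                    # Touch parameters that take no part in this forward pass with a
                    # zero-weighted term so DDP's gradient synchronization does not complain.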
unused_params =[]
# unused_params.extend(list(model.refer_model.out.parameters()))
unused_params.extend(list(model.cond_stage_model.visual.proj))
# unused_params.extend(list(model.refer_model.output_blocks.parameters()))
# unused_params.extend(list(model.refer_model.output_blocks.parameters()))
unused_params.extend(list(model.unconditioned_embedding))
unused_params.extend(list(model.unconditioned_cat_embedding))
extraneous_addition = 0
for p in unused_params:
extraneous_addition = extraneous_addition + p.mean()
loss = loss + 0*extraneous_addition
loss = loss / self.gradient_accumulate_every
self.accelerator.backward(loss)
grad_norm = get_grad_norm(self.model)
accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
pbar.set_description(f'loss: {loss:.4f}')
accelerator.wait_for_everyone()
if (self.step+1)%self.gradient_accumulate_every==0:
self.opt.step()
self.opt.zero_grad()
accelerator.wait_for_everyone()
############################logging#############################################
if accelerator.is_main_process and self.step % 100 == 0:
scalar_dict = {"loss/diff": loss, "loss/grad": grad_norm}
summarize(
writer=writer,
global_step=self.step,
scalars=scalar_dict
)
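                ############################sampling & checkpointing############################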
if accelerator.is_main_process:
if self.step % self.save_and_sample_every == 0:
                        data = {k: v.to(self.device) for k, v in data.items()}
with torch.no_grad():
latent = self.gpt(data['padded_mel_refer'], data['padded_text'],
torch.tensor([data['padded_text'].shape[-1]], device=device), data['padded_mel_code'],
torch.tensor([data['padded_mel_code'].shape[-1]*self.mel_length_compression], device=device),
return_latent=True, clip_inputs=False).transpose(1,2)
latent = F.interpolate(latent, size=data['padded_mel'].shape[-1], mode='nearest')
data_ = dict(jpg=data['padded_mel'], txt=data['padded_mel_refer'], hint=latent)
with torch.no_grad():
model = accelerator.unwrap_model(self.model)
model.eval()
milestone = self.step // self.save_and_sample_every
log = model.log_images(data_)
mel = log['samples'].detach().cpu()
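                            # Undo the mel normalization applied on the training data before vocoding.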
mel = denormalize_tacotron_mel(mel)
model.train()
gen = self.vocos.decode(mel)
torchaudio.save(str(self.logs_folder / f'sample-{milestone}.wav'), gen, 24000)
audio_dict = {}
audio_dict.update({
f"gen/audio": gen,
})
image_dict = {
f"gt/mel": plot_spectrogram_to_numpy(data['padded_mel'][0, :, :].detach().unsqueeze(-1).cpu()),
f"gen/mel": plot_spectrogram_to_numpy(mel[0, :, :].detach().unsqueeze(-1).cpu()),
}
summarize(
writer=writer_eval,
global_step=self.step,
audios=audio_dict,
images=image_dict,
audio_sampling_rate=24000
)
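                            # Prune old checkpoints so at most keep_ckpts remain, then save the new milestone.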
keep_ckpts = self.cfg['train']['keep_ckpts']
if keep_ckpts > 0:
clean_checkpoints(path_to_models=self.logs_folder, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
self.save(milestone)
self.step += 1
pbar.update(1)
accelerator.print('training complete')
# example
if __name__ == '__main__':
trainer = Trainer()
# trainer.load('/home/hyc/tortoise_plus_zh/ttts/AA_diffusion/logs/2023-12-30-18-46-48/model-121.pt')
trainer.train()