import os
from pathlib import Path
import torch.nn.functional as F
from omegaconf import OmegaConf
import torch
import torchaudio
from tqdm.auto import tqdm
from dataset import DiffusionCollater, DiffusionDataset
from ldm.util import instantiate_from_config
from ttts.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from ttts.utils.utils import clean_checkpoints, plot_spectrogram_to_numpy, summarize
from accelerate import Accelerator
from vocos import Vocos
from ttts.AA_diffusion.cldm.cldm import denormalize_tacotron_mel
from torch.utils.data import DataLoader
from torch.optim import AdamW
from datetime import datetime
from ttts.utils.infer_utils import load_model
# import utils
from torch.utils.tensorboard import SummaryWriter
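
# Count only trainable parameters; frozen weights and buffers are excluded.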
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_grad_norm(model):
    total_norm = 0.0
    for name, p in model.named_parameters():
        if p.grad is None:
            # Skip parameters that received no gradient this step
            # (e.g. the deliberately unused branches touched below).
            continue
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    return total_norm ** 0.5
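
# Turn a finite DataLoader into an endless stream, so training is driven
# by a step count rather than by epochs.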
def cycle(dl):
    while True:
        for data in dl:
            yield data
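
# Instantiate the diffusion model from an OmegaConf config on CPU;
# weights are loaded separately via load_state_dict / Trainer.load.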
def create_model(config_path):
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    print(f'Loaded model config from [{config_path}]')
    return model
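
# Checkpoints may be a raw state dict or wrapped under a 'state_dict' key;
# get_state_dict unwraps the latter and is a no-op for the former.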
def get_state_dict(d):
    return d.get('state_dict', d)

def load_state_dict(ckpt_path, location='cpu'):
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        import safetensors.torch
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = torch.load(ckpt_path, map_location=torch.device(location))
    state_dict = get_state_dict(state_dict)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict
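
# Trainer wires together the frozen GPT latent extractor, the diffusion model
# and the Vocos vocoder, and drives training with HuggingFace Accelerate.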
class Trainer(object):
    def __init__(
        self,
        cfg_path='ttts/AA_diffusion/config.yaml',
    ):
        super().__init__()
        self.cfg = OmegaConf.load(cfg_path)
        self.accelerator = Accelerator()
        # models: the frozen GPT produces conditioning latents, Vocos turns mels into audio
        self.vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
        self.gpt = load_model('gpt', self.cfg['dataset']['gpt_path'], 'ttts/gpt/config.json', 'cuda')
        self.model = create_model(cfg_path)
        self.mel_length_compression = 4
        print("model params:", count_parameters(self.model))
        # sampling and training hyperparameters
        self.save_and_sample_every = self.cfg['train']['save_and_sample_every']
        self.gradient_accumulate_every = self.cfg['train']['gradient_accumulate_every']
        self.train_num_steps = self.cfg['train']['train_num_steps']
        # dataset and dataloader
        self.dataset = DiffusionDataset(self.cfg)
        dl = DataLoader(self.dataset, **self.cfg['dataloader'], collate_fn=DiffusionCollater())
        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)
        # optimizer
        self.opt = AdamW(self.model.parameters(), lr=self.cfg['train']['train_lr'], betas=self.cfg['train']['adam_betas'])
        # for logging results in a timestamped folder periodically
        if self.accelerator.is_main_process:
            # eval_ds = TestDataset(self.cfg['data']['val_files'], self.cfg, self.vocos)
            # self.eval_dl = DataLoader(eval_ds, batch_size=1, shuffle=False, num_workers=self.cfg['train']['num_workers'])
            # self.eval_dl = iter(cycle(self.eval_dl))
            now = datetime.now()
            self.logs_folder = Path(self.cfg['train']['logs_folder'] + '/' + now.strftime("%Y-%m-%d-%H-%M-%S"))
            self.logs_folder.mkdir(exist_ok=True, parents=True)
        # step counter state
        self.step = 0
        # prepare model and optimizer with accelerator (the dataloader was prepared above)
        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
    # A property, so self.device can be passed straight to tensor.to() calls below.
    @property
    def device(self):
        return self.accelerator.device
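
    # Checkpoints hold only the step counter and the unwrapped model weights;
    # optimizer state is not saved.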
    def save(self, milestone):
        if not self.accelerator.is_local_main_process:
            return
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
        }
        torch.save(data, str(self.logs_folder / f'model-{milestone}.pt'))
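
    # Loading restores the step counter so training resumes where it left off;
    # the optimizer starts fresh since its state is not checkpointed.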
    def load(self, model_path):
        accelerator = self.accelerator
        device = accelerator.device
        data = torch.load(model_path, map_location=device)
        self.step = data['step']
        saved_state_dict = data['model']
        model = self.accelerator.unwrap_model(self.model)
        # del saved_state_dict['cond_stage_model.visual.positional_embedding']
        # del saved_state_dict['cond_stage_model.visual.conv1.weight']
        model.load_state_dict(saved_state_dict)
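
    # One training step: pull `gradient_accumulate_every` batches, extract GPT
    # conditioning latents for each, run the diffusion training step, then clip
    # and apply the accumulated gradients.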
    def train(self):
        accelerator = self.accelerator
        device = accelerator.device
        if accelerator.is_main_process:
            writer = SummaryWriter(log_dir=self.logs_folder)
            writer_eval = SummaryWriter(log_dir=os.path.join(self.logs_folder, "eval"))
        with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:
            while self.step < self.train_num_steps:
                # with torch.autograd.detect_anomaly():
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl)
                    data = {k: v.to(self.device) for k, v in data.items()}
                    with torch.no_grad():
                        # The GPT acts as a frozen feature extractor: its latents
                        # condition the diffusion model.
                        latent = self.gpt(data['padded_mel_refer'], data['padded_text'],
                                          torch.tensor([data['padded_text'].shape[-1]], device=device), data['padded_mel_code'],
                                          torch.tensor([data['padded_mel_code'].shape[-1] * self.mel_length_compression], device=device),
                                          return_latent=True, clip_inputs=False).transpose(1, 2)
                    # Stretch the latents to the mel length so they serve as a frame-level hint.
                    latent = F.interpolate(latent, size=data['padded_mel'].shape[-1], mode='nearest')
                    data_ = dict(jpg=data['padded_mel'], txt=data['padded_mel_refer'], hint=latent)
                    with self.accelerator.autocast():
                        loss = accelerator.unwrap_model(self.model).training_step(data_)
                    # Touch parameters that do not contribute to the loss with a
                    # zero-weighted term, so DDP does not complain about unused parameters.
                    model = accelerator.unwrap_model(self.model)
                    unused_params = []
                    # unused_params.extend(list(model.refer_model.out.parameters()))
                    unused_params.extend(list(model.cond_stage_model.visual.proj))
                    # unused_params.extend(list(model.refer_model.output_blocks.parameters()))
                    unused_params.extend(list(model.unconditioned_embedding))
                    unused_params.extend(list(model.unconditioned_cat_embedding))
                    extraneous_addition = 0
                    for p in unused_params:
                        extraneous_addition = extraneous_addition + p.mean()
                    loss = loss + 0 * extraneous_addition
                    loss = loss / self.gradient_accumulate_every
                    self.accelerator.backward(loss)
                grad_norm = get_grad_norm(self.model)
                accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
                pbar.set_description(f'loss: {loss:.4f}')
                accelerator.wait_for_everyone()
                # The inner loop already accumulated `gradient_accumulate_every`
                # backward passes, so the optimizer steps once per outer iteration.
                self.opt.step()
                self.opt.zero_grad()
                accelerator.wait_for_everyone()
                ############################ logging ############################
                if accelerator.is_main_process and self.step % 100 == 0:
                    scalar_dict = {"loss/diff": loss, "loss/grad": grad_norm}
                    summarize(
                        writer=writer,
                        global_step=self.step,
                        scalars=scalar_dict
                    )
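                # Periodically sample audio with the current weights and write a checkpoint.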
                if accelerator.is_main_process:
                    if self.step % self.save_and_sample_every == 0:
                        # Reuse the last training batch (already on device) for sampling.
                        with torch.no_grad():
                            latent = self.gpt(data['padded_mel_refer'], data['padded_text'],
                                              torch.tensor([data['padded_text'].shape[-1]], device=device), data['padded_mel_code'],
                                              torch.tensor([data['padded_mel_code'].shape[-1] * self.mel_length_compression], device=device),
                                              return_latent=True, clip_inputs=False).transpose(1, 2)
                        latent = F.interpolate(latent, size=data['padded_mel'].shape[-1], mode='nearest')
                        data_ = dict(jpg=data['padded_mel'], txt=data['padded_mel_refer'], hint=latent)
                        with torch.no_grad():
                            model = accelerator.unwrap_model(self.model)
                            model.eval()
                            milestone = self.step // self.save_and_sample_every
                            log = model.log_images(data_)
                            mel = log['samples'].detach().cpu()
                            mel = denormalize_tacotron_mel(mel)
                            model.train()
                        gen = self.vocos.decode(mel)
                        torchaudio.save(str(self.logs_folder / f'sample-{milestone}.wav'), gen, 24000)
                        audio_dict = {"gen/audio": gen}
                        image_dict = {
                            "gt/mel": plot_spectrogram_to_numpy(data['padded_mel'][0, :, :].detach().unsqueeze(-1).cpu()),
                            "gen/mel": plot_spectrogram_to_numpy(mel[0, :, :].detach().unsqueeze(-1).cpu()),
                        }
                        summarize(
                            writer=writer_eval,
                            global_step=self.step,
                            audios=audio_dict,
                            images=image_dict,
                            audio_sampling_rate=24000
                        )
                        keep_ckpts = self.cfg['train']['keep_ckpts']
                        if keep_ckpts > 0:
                            clean_checkpoints(path_to_models=self.logs_folder, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
                        self.save(milestone)
                self.step += 1
                pbar.update(1)
        accelerator.print('training complete')
# example
if __name__ == '__main__':
    trainer = Trainer()
    # trainer.load('/home/hyc/tortoise_plus_zh/ttts/AA_diffusion/logs/2023-12-30-18-46-48/model-121.pt')
    trainer.train()