import functools
import os
import sys

import numpy as np
import blobfile as bf
import torch
from torch.optim import AdamW

from diffusion import logger
from diffusion.fp16_util import MixedPrecisionTrainer
from diffusion.resample import LossAwareSampler, UniformSampler
from diffusion.resample import create_named_schedule_sampler
from tqdm import tqdm

for p in ['../process', '../../ubisoft-laforge-ZeroEGGS-main', '../mydiffusion_zeggs']:
    sys.path.append(p)
from generate.generate import WavEncoder
from process_zeggs_bvh import pose2bvh

# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0


class TrainLoop:
    def __init__(self, args, model, diffusion, device, data=None):
        self.args = args
        self.data = data
        self.model = model
        self.diffusion = diffusion
        self.cond_mode = model.cond_mode
        self.batch_size = args.batch_size
        self.microbatch = args.batch_size  # deprecating this option
        self.lr = args.lr
        self.log_interval = args.log_interval
        # self.save_interval = args.save_interval
        # self.resume_checkpoint = args.resume_checkpoint
        self.use_fp16 = False  # deprecating this option
        self.fp16_scale_growth = 1e-3  # deprecating this option
        self.weight_decay = args.weight_decay
        self.lr_anneal_steps = args.lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size  # * dist.get_world_size()
        # self.num_steps = args.num_steps
        self.num_epochs = 40000
        self.n_seed = 8

        self.sync_cuda = torch.cuda.is_available()

        # self._load_and_sync_parameters()
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=self.fp16_scale_growth,
        )

        self.save_dir = args.save_dir
        self.device = device

        if args.audio_feat == "wav encoder":
            self.WavEncoder = WavEncoder().to(self.device)
            self.opt = AdamW([
                {'params': self.mp_trainer.master_params, 'lr': self.lr, 'weight_decay': self.weight_decay},
                {'params': self.WavEncoder.parameters(), 'lr': self.lr},
            ])
        elif args.audio_feat == "mfcc" or args.audio_feat == 'wavlm':
            self.opt = AdamW([
                {'params': self.mp_trainer.master_params, 'lr': self.lr, 'weight_decay': self.weight_decay},
            ])

        # if self.resume_step:
        #     self._load_optimizer_state()
        #     # Model was resumed, either due to a restart or a checkpoint
        #     # being specified at the command line.
        self.schedule_sampler_type = 'uniform'
        self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, diffusion)
        self.eval_wrapper, self.eval_data, self.eval_gt_data = None, None, None
        # if args.dataset in ['kit', 'humanml'] and args.eval_during_training:
        #     mm_num_samples = 0  # mm is super slow hence we won't run it during training
        #     mm_num_repeats = 0  # mm is super slow hence we won't run it during training
        #     gen_loader = get_dataset_loader(name=args.dataset, batch_size=args.eval_batch_size, num_frames=None,
        #                                     split=args.eval_split,
        #                                     hml_mode='eval')
        #
        #     self.eval_gt_data = get_dataset_loader(name=args.dataset, batch_size=args.eval_batch_size, num_frames=None,
        #                                            split=args.eval_split,
        #                                            hml_mode='gt')
        #     self.eval_wrapper = EvaluatorMDMWrapper(args.dataset, self.device)
        #     self.eval_data = {
        #         'test': lambda: eval_humanml.get_mdm_loader(
        #             model, diffusion, args.eval_batch_size,
        #             gen_loader, mm_num_samples, mm_num_repeats, gen_loader.dataset.opt.max_motion_length,
        #             args.eval_num_samples, scale=1.,
        #         )
        #     }
        self.use_ddp = False
        self.ddp_model = self.model

        self.mask_train = (torch.zeros([self.batch_size, 1, 1, args.n_poses]) < 1).to(self.device)
        self.mask_test = (torch.zeros([1, 1, 1, args.n_poses]) < 1).to(self.device)
        # self.tmp_audio = torch.from_numpy(np.load('tmp_audio.npy')).unsqueeze(0).to(self.device)
        # self.tmp_mfcc = torch.from_numpy(np.load('10_kieks_0_9_16.npz')['mfcc'][:args.n_poses]).to(torch.float32).unsqueeze(0).to(self.device)
        self.mask_local_train = torch.ones(self.batch_size, args.n_poses).bool().to(self.device)
        self.mask_local_test = torch.ones(1, args.n_poses).bool().to(self.device)

    # def _load_and_sync_parameters(self):
    #     resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
    #
    #     if resume_checkpoint:
    #         self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
    #         logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
    #         self.model.load_state_dict(
    #             dist_util.load_state_dict(
    #                 resume_checkpoint, map_location=self.device
    #             )
    #         )

    # def _load_optimizer_state(self):
    #     main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
    #     opt_checkpoint = bf.join(
    #         bf.dirname(main_checkpoint), f"opt{self.resume_step:09}.pt"
    #     )
    #     if bf.exists(opt_checkpoint):
    #         logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
    #         state_dict = dist_util.load_state_dict(
    #             opt_checkpoint, map_location=self.device
    #         )
    #         self.opt.load_state_dict(state_dict)

    def run_loop(self):
        for epoch in range(self.num_epochs):
            # print(f'Starting epoch {epoch}')
            # for _ in tqdm(range(10)):      # 4 steps, batch size, chmod 777
            for batch in tqdm(self.data):
                if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps):
                    break

                cond_ = {'y': {}}
                # cond_['y']['text'] = ['A person turns left with medium speed.', 'A human goes slowly about 1.5 meters forward.']
                # motion = torch.rand(2, 135, 1, 80).to(self.device)

                # pose_seq, _, style, audio, mfcc, wavlm = batch    # (batch, 240, 135), (batch, 30), (batch, 64000)
                # pose_seq, _, style, _, _, wavlm = batch
                pose_seq, style, wavlm = batch
                motion = pose_seq.permute(0, 2, 1).unsqueeze(2).to(self.device)
                cond_['y']['seed'] = motion[..., 0:self.n_seed]
                # motion = motion[..., self.n_seed:]
                cond_['y']['style'] = style.to(self.device)
                cond_['y']['mask_local'] = self.mask_local_train

                if self.args.audio_feat == 'wav encoder':
                    # cond_['y']['audio'] = torch.rand(240, 2, 32).to(self.device)
                    # NOTE: `audio` is only defined when the full batch unpacking above is used.
                    cond_['y']['audio'] = self.WavEncoder(audio.to(self.device)).permute(1, 0, 2)  # (batch, 240, 32)
                elif self.args.audio_feat == 'mfcc':
                    # cond_['y']['audio'] = torch.rand(80, 2, 13).to(self.device)
                    # NOTE: `mfcc` is only defined when the full batch unpacking above is used.
                    cond_['y']['audio'] = mfcc.to(torch.float32).to(self.device).permute(1, 0, 2)  # [self.n_seed:, ...]     # (batch, 80, 13)
                elif self.args.audio_feat == 'wavlm':
                    cond_['y']['audio'] = wavlm.to(torch.float32).to(self.device)
                cond_['y']['mask'] = self.mask_train  # [..., self.n_seed:]

                self.run_step(motion, cond_)
                if self.step % self.log_interval == 0:
                    for k, v in logger.get_current().name2val.items():
                        if k == 'loss':
                            print('step[{}]: loss[{:0.5f}]'.format(self.step + self.resume_step, v))

                # if self.step % 10000 == 0:
                #     sample_fn = self.diffusion.p_sample_loop
                #
                #     model_kwargs_ = {'y': {}}
                #     model_kwargs_['y']['mask'] = self.mask_test  # [..., self.n_seed:]
                #     model_kwargs_['y']['seed'] = torch.zeros([1, 1141, 1, self.n_seed]).to(self.device)
                #     model_kwargs_['y']['style'] = torch.zeros([1, 6]).to(self.device)
                #     model_kwargs_['y']['mask_local'] = self.mask_local_test
                #     if self.args.audio_feat == 'wav encoder':
                #         model_kwargs_['y']['audio'] = self.WavEncoder(self.tmp_audio).permute(1, 0, 2)
                #         # model_kwargs_['y']['audio'] = torch.rand(240, 1, 32).to(self.device)
                #     elif self.args.audio_feat == 'mfcc':
                #         model_kwargs_['y']['audio'] = self.tmp_mfcc.permute(1, 0, 2)  # [self.n_seed:, ...]
                #         # model_kwargs_['y']['audio'] = torch.rand(80, 1, 13).to(self.device)
                #     elif self.args.audio_feat == 'wavlm':
                #         model_kwargs_['y']['audio'] = torch.randn(1, 1, 1024).to(self.device)
                #
                #     sample = sample_fn(
                #         self.model,
                #         (1, 1141, 1, self.args.n_poses),  # - self.n_seed
                #         clip_denoised=False,
                #         model_kwargs=model_kwargs_,
                #         skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
                #         init_image=None,
                #         progress=True,
                #         dump_steps=None,
                #         noise=None,
                #         const_noise=False,
                #     )  # (1, 135, 1, 240)
                #
                #     sampled_seq = sample.squeeze(0).permute(1, 2, 0)
                #     data_mean_ = np.load("../../ubisoft-laforge-ZeroEGGS-main/Data/processed_v1/processed/mean.npz")['mean']
                #     data_std_ = np.load("../../ubisoft-laforge-ZeroEGGS-main/Data/processed_v1/processed/std.npz")['std']
                #
                #     data_mean = np.array(data_mean_).squeeze()
                #     data_std = np.array(data_std_).squeeze()
                #     std = np.clip(data_std, a_min=0.01, a_max=None)
                #     out_poses = np.multiply(np.array(sampled_seq[0].detach().cpu()), std) + data_mean
                #
                #     pipeline_path = '../../../My/process/resource/data_pipe_20_rotation.sav'
                #     save_path = 'inference_zeggs_mymodel3_wavlm'
                #     prefix = str(datetime.now().strftime('%Y%m%d_%H%M%S'))
                #     if not os.path.exists(save_path):
                #         os.mkdir(save_path)
                #     # make_bvh_GENEA2020_BT(save_path, prefix, out_poses, smoothing=False, pipeline_path=pipeline_path)
                #     pose2bvh(out_poses, os.path.join(save_path, prefix + '.bvh'), length=self.args.n_poses)

                if self.step % 50000 == 0:
                    self.save()
                    # self.model.eval()
                    # self.evaluate()
                    # self.model.train()

                    # Run for a finite amount of time in integration tests.
                    if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                        return
                self.step += 1
            if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps):
                break
        # Save the last checkpoint if it wasn't already saved.
        # if (self.step - 1) % 50000 != 0:
        #     self.save()
        #     # self.evaluate()

    def run_step(self, batch, cond):
        # batch: torch.Size([64, 251, 1, 196]); cond['y'].keys(): dict_keys(['mask', 'lengths', 'text', 'tokens'])
        self.forward_backward(batch, cond)
        self.mp_trainer.optimize(self.opt)
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, cond):
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            # Eliminates the microbatch feature
            assert i == 0
            assert self.microbatch == self.batch_size
            micro = batch
            micro_cond = cond
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], self.device)

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,  # [bs, ch, image_size, image_size]      # x_start, (2, 135, 1, 240)
                t,      # [bs](int) sampled timesteps
                model_kwargs=micro_cond,
                dataset='kit',
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            self.mp_trainer.backward(loss)

    def _anneal_lr(self):
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)

    def ckpt_file_name(self):
        return f"model{(self.step+self.resume_step):09d}.pt"

    def save(self):
        def save_checkpoint(params):
            state_dict = self.mp_trainer.master_params_to_state_dict(params)

            # Do not save CLIP weights
            clip_weights = [e for e in state_dict.keys() if e.startswith('clip_model.')]
            for e in clip_weights:
                del state_dict[e]

            logger.log("saving model...")
            filename = self.ckpt_file_name()
            with bf.BlobFile(bf.join(self.save_dir, filename), "wb") as f:
                torch.save(state_dict, f)

        save_checkpoint(self.mp_trainer.master_params)

        with bf.BlobFile(
            bf.join(self.save_dir, f"opt{(self.step+self.resume_step):09d}.pt"),
            "wb",
        ) as f:
            torch.save(self.opt.state_dict(), f)


def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps.
    """
    split = filename.split("model")
    if len(split) < 2:
        return 0
    split1 = split[-1].split(".")[0]
    try:
        return int(split1)
    except ValueError:
        return 0


def get_blob_logdir():
    # You can change this to be a separate path to save checkpoints to
    # a blobstore or some external drive.
    return logger.get_dir()


def find_resume_checkpoint():
    # On your infrastructure, you may want to override this to automatically
    # discover the latest checkpoint on your blob storage, etc.
    return None


def log_loss_dict(diffusion, ts, losses):
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        # Log the quantiles (four quartiles, in particular).
        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
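
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training code).
# This block shows, with dummy tensors in place of a real data loader, how the
# conditioning dictionary built inside TrainLoop.run_loop() is shaped for the
# 'wavlm' audio-feature path. All shapes are assumptions inferred from the
# comments above (pose_seq: (batch, n_poses, 135), style: (batch, 6),
# wavlm: (batch, 1, 1024)); the real loader and model may expect different ones.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, n_poses, pose_dim = 2, 240, 135
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Dummy batch standing in for one item yielded by `self.data`.
    pose_seq = torch.randn(batch_size, n_poses, pose_dim)
    style = torch.zeros(batch_size, 6)        # assumed one-hot style vector
    wavlm = torch.randn(batch_size, 1, 1024)  # assumed pre-extracted WavLM features

    # Same reshaping as in run_loop(): (batch, n_poses, pose_dim) -> (batch, pose_dim, 1, n_poses)
    motion = pose_seq.permute(0, 2, 1).unsqueeze(2).to(device)

    n_seed = 8
    cond_ = {'y': {
        'seed': motion[..., 0:n_seed],                                    # first n_seed frames as seed
        'style': style.to(device),
        'mask_local': torch.ones(batch_size, n_poses).bool().to(device),
        'audio': wavlm.to(torch.float32).to(device),
        'mask': (torch.zeros([batch_size, 1, 1, n_poses]) < 1).to(device),
    }}
    print({k: tuple(v.shape) for k, v in cond_['y'].items()})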