# -*- coding: utf-8 -*- # # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), # acting on behalf of its Max Planck Institute for Intelligent Systems and the # Max Planck Institute for Biological Cybernetics. All rights reserved. # # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights # on this computer program. You can only use this computer program if you have closed a license agreement # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. # Any use of the computer program without a valid license is prohibited and liable to prosecution. # Contact: ps-license@tuebingen.mpg.de # # # If you use this code in a research publication please consider citing the following: # # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image # # # Code Developed by: # Nima Ghorbani # # 2020.12.12 # from pytorch_lightning import Trainer import glob import os import os.path as osp from datetime import datetime as dt from pytorch_lightning.plugins import DDPPlugin import numpy as np import pytorch_lightning as pl import torch from human_body_prior.body_model.body_model import BodyModel from human_body_prior.data.dataloader import VPoserDS from human_body_prior.data.prepare_data import dataset_exists from human_body_prior.data.prepare_data import prepare_vposer_datasets from human_body_prior.models.vposer_model import VPoser from human_body_prior.tools.angle_continuous_repres import geodesic_loss_R from human_body_prior.tools.configurations import load_config, dump_config from human_body_prior.tools.omni_tools import copy2cpu as c2c from human_body_prior.tools.omni_tools import get_support_data_dir from human_body_prior.tools.omni_tools import log2file from human_body_prior.tools.omni_tools import make_deterministic from human_body_prior.tools.omni_tools import makepath from human_body_prior.tools.rotation_tools import aa2matrot from human_body_prior.visualizations.training_visualization import vposer_trainer_renderer from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.core import LightningModule from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities import rank_zero_only from torch import optim as optim_module from torch.optim import lr_scheduler as lr_sched_module from torch.utils.data import DataLoader class VPoserTrainer(LightningModule): """ It includes all data loading and train / val logic., and it is used for both training and testing models. """ def __init__(self, _config): super(VPoserTrainer, self).__init__() _support_data_dir = get_support_data_dir() vp_ps = load_config(**_config) make_deterministic(vp_ps.general.rnd_seed) self.expr_id = vp_ps.general.expr_id self.dataset_id = vp_ps.general.dataset_id self.work_dir = vp_ps.logging.work_dir = makepath(vp_ps.general.work_basedir, self.expr_id) self.dataset_dir = vp_ps.logging.dataset_dir = osp.join(vp_ps.general.dataset_basedir, vp_ps.general.dataset_id) self._log_prefix = '[{}]'.format(self.expr_id) self.text_logger = log2file(prefix=self._log_prefix) self.seq_len = vp_ps.data_parms.num_timeseq_frames self.vp_model = VPoser(vp_ps) with torch.no_grad(): self.bm_train = BodyModel(vp_ps.body_model.bm_fname) if vp_ps.logging.render_during_training: self.renderer = vposer_trainer_renderer(self.bm_train, vp_ps.logging.num_bodies_to_display) else: self.renderer = None self.example_input_array = {'pose_body':torch.ones(vp_ps.train_parms.batch_size, 63),} self.vp_ps = vp_ps def forward(self, pose_body): return self.vp_model(pose_body) def _get_data(self, split_name): assert split_name in ('train', 'vald', 'test') split_name = split_name.replace('vald', 'vald') assert dataset_exists(self.dataset_dir), FileNotFoundError('Dataset does not exist dataset_dir = {}'.format(self.dataset_dir)) dataset = VPoserDS(osp.join(self.dataset_dir, split_name), data_fields = ['pose_body']) assert len(dataset) != 0, ValueError('Dataset has nothing in it!') return DataLoader(dataset, batch_size=self.vp_ps.train_parms.batch_size, shuffle=True if split_name == 'train' else False, num_workers=self.vp_ps.data_parms.num_workers, pin_memory=True) @rank_zero_only def on_train_start(self): if self.global_rank != 0: return self.train_starttime = dt.now().replace(microsecond=0) ######## make a backup of vposer git_repo_dir = os.path.abspath(__file__).split('/') git_repo_dir = '/'.join(git_repo_dir[:git_repo_dir.index('human_body_prior') + 1]) starttime = dt.strftime(self.train_starttime, '%Y_%m_%d_%H_%M_%S') archive_path = makepath(self.work_dir, 'code', 'vposer_{}.tar.gz'.format(starttime), isfile=True) cmd = 'cd %s && git ls-files -z | xargs -0 tar -czf %s' % (git_repo_dir, archive_path) os.system(cmd) ######## self.text_logger('Created a git archive backup at {}'.format(archive_path)) dump_config(self.vp_ps, osp.join(self.work_dir, '{}.yaml'.format(self.expr_id))) def train_dataloader(self): return self._get_data('train') def val_dataloader(self): return self._get_data('vald') def configure_optimizers(self): params_count = lambda params: sum(p.numel() for p in params if p.requires_grad) gen_params = [a[1] for a in self.vp_model.named_parameters() if a[1].requires_grad] gen_optimizer_class = getattr(optim_module, self.vp_ps.train_parms.gen_optimizer.type) gen_optimizer = gen_optimizer_class(gen_params, **self.vp_ps.train_parms.gen_optimizer.args) self.text_logger('Total Trainable Parameters Count in vp_model is %2.2f M.' % (params_count(gen_params) * 1e-6)) lr_sched_class = getattr(lr_sched_module, self.vp_ps.train_parms.lr_scheduler.type) gen_lr_scheduler = lr_sched_class(gen_optimizer, **self.vp_ps.train_parms.lr_scheduler.args) schedulers = [ { 'scheduler': gen_lr_scheduler, 'monitor': 'val_loss', 'interval': 'epoch', 'frequency': 1 }, ] return [gen_optimizer], schedulers def _compute_loss(self, dorig, drec): l1_loss = torch.nn.L1Loss(reduction='mean') geodesic_loss = geodesic_loss_R(reduction='mean') bs, latentD = drec['poZ_body_mean'].shape device = drec['poZ_body_mean'].device loss_kl_wt = self.vp_ps.train_parms.loss_weights.loss_kl_wt loss_rec_wt = self.vp_ps.train_parms.loss_weights.loss_rec_wt loss_matrot_wt = self.vp_ps.train_parms.loss_weights.loss_matrot_wt loss_jtr_wt = self.vp_ps.train_parms.loss_weights.loss_jtr_wt # q_z = torch.distributions.normal.Normal(drec['mean'], drec['std']) q_z = drec['q_z'] # dorig['fullpose'] = torch.cat([dorig['root_orient'], dorig['pose_body']], dim=-1) # Reconstruction loss - L1 on the output mesh with torch.no_grad(): bm_orig = self.bm_train(pose_body=dorig['pose_body']) bm_rec = self.bm_train(pose_body=drec['pose_body'].contiguous().view(bs, -1)) v2v = l1_loss(bm_rec.v, bm_orig.v) # KL loss p_z = torch.distributions.normal.Normal( loc=torch.zeros((bs, latentD), device=device, requires_grad=False), scale=torch.ones((bs, latentD), device=device, requires_grad=False)) weighted_loss_dict = { 'loss_kl':loss_kl_wt * torch.mean(torch.sum(torch.distributions.kl.kl_divergence(q_z, p_z), dim=[1])), 'loss_mesh_rec': loss_rec_wt * v2v } if (self.current_epoch < self.vp_ps.train_parms.keep_extra_loss_terms_until_epoch): # breakpoint() weighted_loss_dict['matrot'] = loss_matrot_wt * geodesic_loss(drec['pose_body_matrot'].view(-1,3,3), aa2matrot(dorig['pose_body'].view(-1, 3))) weighted_loss_dict['jtr'] = loss_jtr_wt * l1_loss(bm_rec.Jtr, bm_orig.Jtr) weighted_loss_dict['loss_total'] = torch.stack(list(weighted_loss_dict.values())).sum() with torch.no_grad(): unweighted_loss_dict = {'v2v': torch.sqrt(torch.pow(bm_rec.v-bm_orig.v, 2).sum(-1)).mean()} unweighted_loss_dict['loss_total'] = torch.cat( list({k: v.view(-1) for k, v in unweighted_loss_dict.items()}.values()), dim=-1).sum().view(1) return {'weighted_loss': weighted_loss_dict, 'unweighted_loss': unweighted_loss_dict} def training_step(self, batch, batch_idx, optimizer_idx=None): drec = self(batch['pose_body'].view(-1, 63)) loss = self._compute_loss(batch, drec) train_loss = loss['weighted_loss']['loss_total'] tensorboard_logs = {'train_loss': train_loss} progress_bar = {k: c2c(v) for k, v in loss['weighted_loss'].items()} return {'loss': train_loss, 'progress_bar':progress_bar, 'log': tensorboard_logs} def validation_step(self, batch, batch_idx): drec = self(batch['pose_body'].view(-1, 63)) loss = self._compute_loss(batch, drec) val_loss = loss['unweighted_loss']['loss_total'] if self.renderer is not None and self.global_rank == 0 and batch_idx % 500==0 and np.random.rand()>0.5: out_fname = makepath(self.work_dir, 'renders/vald_rec_E{:03d}_It{:04d}_val_loss_{:.2f}.png'.format(self.current_epoch, batch_idx, val_loss.item()), isfile=True) self.renderer([batch, drec], out_fname = out_fname) dgen = self.vp_model.sample_poses(self.vp_ps.logging.num_bodies_to_display) out_fname = makepath(self.work_dir, 'renders/vald_gen_E{:03d}_I{:04d}.png'.format(self.current_epoch, batch_idx), isfile=True) self.renderer([dgen], out_fname = out_fname) progress_bar = {'v2v': val_loss} return {'val_loss': c2c(val_loss), 'progress_bar': progress_bar, 'log': progress_bar} def validation_epoch_end(self, outputs): metrics = {'val_loss': np.nanmean(np.concatenate([v['val_loss'] for v in outputs])) } if self.global_rank == 0: self.text_logger('Epoch {}: {}'.format(self.current_epoch, ', '.join('{}:{:.2f}'.format(k, v) for k, v in metrics.items()))) self.text_logger('lr is {}'.format([pg['lr'] for opt in self.trainer.optimizers for pg in opt.param_groups])) metrics = {k: torch.as_tensor(v) for k, v in metrics.items()} return {'val_loss': metrics['val_loss'], 'log': metrics} @rank_zero_only def on_train_end(self): self.train_endtime = dt.now().replace(microsecond=0) endtime = dt.strftime(self.train_endtime, '%Y_%m_%d_%H_%M_%S') elapsedtime = self.train_endtime - self.train_starttime self.vp_ps.logging.best_model_fname = self.trainer.checkpoint_callback.best_model_path self.text_logger('Epoch {} - Finished training at {} after {}'.format(self.current_epoch, endtime, elapsedtime)) self.text_logger('best_model_fname: {}'.format(self.vp_ps.logging.best_model_fname)) dump_config(self.vp_ps, osp.join(self.work_dir, '{}_{}.yaml'.format(self.expr_id, self.dataset_id))) self.hparams = self.vp_ps.toDict() @rank_zero_only def prepare_data(self): '''' Similar to standard AMASS dataset preparation pipeline: Donwload npz file, corresponding to body data from https://amass.is.tue.mpg.de/ and place them under amass_dir ''' self.text_logger = log2file(makepath(self.work_dir, '{}.log'.format(self.expr_id), isfile=True), prefix=self._log_prefix) prepare_vposer_datasets(self.dataset_dir, self.vp_ps.data_parms.amass_splits, self.vp_ps.data_parms.amass_dir, logger=self.text_logger) def create_expr_message(ps): expr_msg = '[{}] batch_size = {}.'.format(ps.general.expr_id, ps.train_parms.batch_size) return expr_msg def train_vposer_once(_config): resume_training_if_possible = True model = VPoserTrainer(_config) model.vp_ps.logging.expr_msg = create_expr_message(model.vp_ps) # model.text_logger(model.vp_ps.logging.expr_msg.replace(". ", '.\n')) dump_config(model.vp_ps, osp.join(model.work_dir, '{}.yaml'.format(model.expr_id))) logger = TensorBoardLogger(model.work_dir, name='tensorboard') lr_monitor = LearningRateMonitor() snapshots_dir = osp.join(model.work_dir, 'snapshots') checkpoint_callback = ModelCheckpoint( dirpath=makepath(snapshots_dir, isfile=True), filename="%s_{epoch:02d}_{val_loss:.2f}" % model.expr_id, save_top_k=1, verbose=True, monitor='val_loss', mode='min', ) early_stop_callback = EarlyStopping(**model.vp_ps.train_parms.early_stopping) resume_from_checkpoint = None if resume_training_if_possible: available_ckpts = sorted(glob.glob(osp.join(snapshots_dir, '*.ckpt')), key=os.path.getmtime) if len(available_ckpts)>0: resume_from_checkpoint = available_ckpts[-1] model.text_logger('Resuming the training from {}'.format(resume_from_checkpoint)) trainer = pl.Trainer(gpus=1, weights_summary='top', distributed_backend = 'ddp', # replace_sampler_ddp=False, # accumulate_grad_batches=4, # profiler=False, # overfit_batches=0.05, # fast_dev_run = True, # limit_train_batches=0.02, # limit_val_batches=0.02, # num_sanity_val_steps=2, plugins=[DDPPlugin(find_unused_parameters=False)], callbacks=[lr_monitor, early_stop_callback, checkpoint_callback], max_epochs=model.vp_ps.train_parms.num_epochs, logger=logger, resume_from_checkpoint=resume_from_checkpoint ) trainer.fit(model)