# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12

import glob
import os
import os.path as osp
from datetime import datetime as dt

import numpy as np
import pytorch_lightning as pl
import torch
from human_body_prior.body_model.body_model import BodyModel
from human_body_prior.data.dataloader import VPoserDS
from human_body_prior.data.prepare_data import dataset_exists
from human_body_prior.data.prepare_data import prepare_vposer_datasets
from human_body_prior.models.vposer_model import VPoser
from human_body_prior.tools.angle_continuous_repres import geodesic_loss_R
from human_body_prior.tools.configurations import load_config, dump_config
from human_body_prior.tools.omni_tools import copy2cpu as c2c
from human_body_prior.tools.omni_tools import get_support_data_dir
from human_body_prior.tools.omni_tools import log2file
from human_body_prior.tools.omni_tools import make_deterministic
from human_body_prior.tools.omni_tools import makepath
from human_body_prior.tools.rotation_tools import aa2matrot
from human_body_prior.visualizations.training_visualization import vposer_trainer_renderer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.core import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.utilities import rank_zero_only
from torch import optim as optim_module
from torch.optim import lr_scheduler as lr_sched_module
from torch.utils.data import DataLoader


class VPoserTrainer(LightningModule):
    """

    It includes all data loading and train / val logic., and it is used for both training and testing models.
    """

    def __init__(self, _config):
        super(VPoserTrainer, self).__init__()

        _support_data_dir = get_support_data_dir()

        vp_ps = load_config(**_config)

        make_deterministic(vp_ps.general.rnd_seed)

        self.expr_id = vp_ps.general.expr_id
        self.dataset_id = vp_ps.general.dataset_id

        self.work_dir = vp_ps.logging.work_dir = makepath(vp_ps.general.work_basedir, self.expr_id)
        self.dataset_dir = vp_ps.logging.dataset_dir = osp.join(vp_ps.general.dataset_basedir, vp_ps.general.dataset_id)

        self._log_prefix = '[{}]'.format(self.expr_id)
        self.text_logger = log2file(prefix=self._log_prefix)

        self.seq_len = vp_ps.data_parms.num_timeseq_frames

        self.vp_model = VPoser(vp_ps)

        with torch.no_grad():
            self.bm_train = BodyModel(vp_ps.body_model.bm_fname)

        if vp_ps.logging.render_during_training:
            self.renderer = vposer_trainer_renderer(self.bm_train, vp_ps.logging.num_bodies_to_display)
        else:
            self.renderer = None

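        # example_input_array is used by PyTorch Lightning to build the model summary and log the computational graph;
        # 63 = 21 SMPL body joints x 3 axis-angle components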
        self.example_input_array = {'pose_body': torch.ones(vp_ps.train_parms.batch_size, 63)}
        self.vp_ps = vp_ps

    def forward(self, pose_body):

        return self.vp_model(pose_body)

    def _get_data(self, split_name):

        assert split_name in ('train', 'vald', 'test')

        if not dataset_exists(self.dataset_dir):
            raise FileNotFoundError('Dataset does not exist: dataset_dir = {}'.format(self.dataset_dir))
        dataset = VPoserDS(osp.join(self.dataset_dir, split_name), data_fields=['pose_body'])

        if len(dataset) == 0:
            raise ValueError('Dataset has nothing in it!')

        return DataLoader(dataset,
                          batch_size=self.vp_ps.train_parms.batch_size,
                          shuffle=(split_name == 'train'),
                          num_workers=self.vp_ps.data_parms.num_workers,
                          pin_memory=True)

    @rank_zero_only
    def on_train_start(self):
        self.train_starttime = dt.now().replace(microsecond=0)

        ######## make a backup of vposer
        git_repo_dir = os.path.abspath(__file__).split('/')
        git_repo_dir = '/'.join(git_repo_dir[:git_repo_dir.index('human_body_prior') + 1])
        starttime = dt.strftime(self.train_starttime, '%Y_%m_%d_%H_%M_%S')
        archive_path = makepath(self.work_dir, 'code', 'vposer_{}.tar.gz'.format(starttime), isfile=True)
        cmd = 'cd %s && git ls-files -z | xargs -0 tar -czf %s' % (git_repo_dir, archive_path)
        os.system(cmd)
        ########
        self.text_logger('Created a git archive backup at {}'.format(archive_path))
        dump_config(self.vp_ps, osp.join(self.work_dir, '{}.yaml'.format(self.expr_id)))

    def train_dataloader(self):
        return self._get_data('train')

    def val_dataloader(self):
        return self._get_data('vald')

    def configure_optimizers(self):
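        """Build the optimizer and learning-rate scheduler by name from the config,
        resolved via getattr on torch.optim and torch.optim.lr_scheduler."""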
        def params_count(params):
            return sum(p.numel() for p in params if p.requires_grad)

        gen_params = [a[1] for a in self.vp_model.named_parameters() if a[1].requires_grad]
        gen_optimizer_class = getattr(optim_module, self.vp_ps.train_parms.gen_optimizer.type)
        gen_optimizer = gen_optimizer_class(gen_params, **self.vp_ps.train_parms.gen_optimizer.args)

        self.text_logger('Total trainable parameter count in vp_model is %2.2f M.' % (params_count(gen_params) * 1e-6))

        lr_sched_class = getattr(lr_sched_module, self.vp_ps.train_parms.lr_scheduler.type)

        gen_lr_scheduler = lr_sched_class(gen_optimizer, **self.vp_ps.train_parms.lr_scheduler.args)

        schedulers = [
            {
                'scheduler': gen_lr_scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'frequency': 1
            },
        ]
        return [gen_optimizer], schedulers

    def _compute_loss(self, dorig, drec):
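        """Compute the weighted training losses and unweighted evaluation metrics.

        dorig: batch dict with the ground-truth 'pose_body' (axis-angle, 63 values per body).
        drec: VPoser output dict with 'pose_body', 'pose_body_matrot', 'poZ_body_mean'
            and the approximate posterior distribution 'q_z'.
        """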
        l1_loss = torch.nn.L1Loss(reduction='mean')
        geodesic_loss = geodesic_loss_R(reduction='mean')

        bs, latentD = drec['poZ_body_mean'].shape
        device = drec['poZ_body_mean'].device

        loss_kl_wt = self.vp_ps.train_parms.loss_weights.loss_kl_wt
        loss_rec_wt = self.vp_ps.train_parms.loss_weights.loss_rec_wt
        loss_matrot_wt = self.vp_ps.train_parms.loss_weights.loss_matrot_wt
        loss_jtr_wt = self.vp_ps.train_parms.loss_weights.loss_jtr_wt

        # q_z = torch.distributions.normal.Normal(drec['mean'], drec['std'])
        q_z = drec['q_z']
        # dorig['fullpose'] = torch.cat([dorig['root_orient'], dorig['pose_body']], dim=-1)

        # Reconstruction loss - L1 on the output mesh
        with torch.no_grad():
            bm_orig = self.bm_train(pose_body=dorig['pose_body'])

        bm_rec = self.bm_train(pose_body=drec['pose_body'].contiguous().view(bs, -1))

        v2v = l1_loss(bm_rec.v, bm_orig.v)

        # KL loss
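        # KL(q_z || p_z) against a standard normal prior, summed over latent dimensions and averaged over the batch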
        p_z = torch.distributions.normal.Normal(
            loc=torch.zeros((bs, latentD), device=device, requires_grad=False),
            scale=torch.ones((bs, latentD), device=device, requires_grad=False))
        weighted_loss_dict = {
            'loss_kl': loss_kl_wt * torch.mean(torch.sum(torch.distributions.kl.kl_divergence(q_z, p_z), dim=[1])),
            'loss_mesh_rec': loss_rec_wt * v2v
        }

        if self.current_epoch < self.vp_ps.train_parms.keep_extra_loss_terms_until_epoch:
            weighted_loss_dict['matrot'] = loss_matrot_wt * geodesic_loss(drec['pose_body_matrot'].view(-1, 3, 3), aa2matrot(dorig['pose_body'].view(-1, 3)))
            weighted_loss_dict['jtr'] = loss_jtr_wt * l1_loss(bm_rec.Jtr, bm_orig.Jtr)

        weighted_loss_dict['loss_total'] = torch.stack(list(weighted_loss_dict.values())).sum()

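        # Unweighted metrics for logging: v2v is the mean per-vertex Euclidean distance between reconstructed and original meshes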
        with torch.no_grad():
            unweighted_loss_dict = {'v2v': torch.sqrt(torch.pow(bm_rec.v-bm_orig.v, 2).sum(-1)).mean()}
            unweighted_loss_dict['loss_total'] = torch.cat(
                list({k: v.view(-1) for k, v in unweighted_loss_dict.items()}.values()), dim=-1).sum().view(1)

        return {'weighted_loss': weighted_loss_dict, 'unweighted_loss': unweighted_loss_dict}

    def training_step(self, batch, batch_idx, optimizer_idx=None):

        drec = self(batch['pose_body'].view(-1, 63))

        loss = self._compute_loss(batch, drec)

        train_loss = loss['weighted_loss']['loss_total']

        tensorboard_logs = {'train_loss': train_loss}
        progress_bar = {k: c2c(v) for k, v in loss['weighted_loss'].items()}
        return {'loss': train_loss, 'progress_bar': progress_bar, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):

        drec = self(batch['pose_body'].view(-1, 63))

        loss = self._compute_loss(batch, drec)
        val_loss = loss['unweighted_loss']['loss_total']

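        # Occasionally dump renders on rank 0: reconstructions of the current batch and samples drawn from the prior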
        if self.renderer is not None and self.global_rank == 0 and batch_idx % 500 == 0 and np.random.rand() > 0.5:
            out_fname = makepath(self.work_dir, 'renders/vald_rec_E{:03d}_It{:04d}_val_loss_{:.2f}.png'.format(self.current_epoch, batch_idx, val_loss.item()), isfile=True)
            self.renderer([batch, drec], out_fname=out_fname)
            dgen = self.vp_model.sample_poses(self.vp_ps.logging.num_bodies_to_display)
            out_fname = makepath(self.work_dir, 'renders/vald_gen_E{:03d}_I{:04d}.png'.format(self.current_epoch, batch_idx), isfile=True)
            self.renderer([dgen], out_fname=out_fname)

        progress_bar = {'v2v': val_loss}
        return {'val_loss': c2c(val_loss), 'progress_bar': progress_bar, 'log': progress_bar}

    def validation_epoch_end(self, outputs):
        metrics = {'val_loss': np.nanmean(np.concatenate([v['val_loss'] for v in outputs])) }

        if self.global_rank == 0:

            self.text_logger('Epoch {}: {}'.format(self.current_epoch, ', '.join('{}:{:.2f}'.format(k, v) for k, v in metrics.items())))
            self.text_logger('lr is {}'.format([pg['lr'] for opt in self.trainer.optimizers for pg in opt.param_groups]))

        metrics = {k: torch.as_tensor(v) for k, v in metrics.items()}

        return {'val_loss': metrics['val_loss'], 'log': metrics}
    @rank_zero_only
    def on_train_end(self):

        self.train_endtime = dt.now().replace(microsecond=0)
        endtime = dt.strftime(self.train_endtime, '%Y_%m_%d_%H_%M_%S')
        elapsedtime = self.train_endtime - self.train_starttime
        self.vp_ps.logging.best_model_fname = self.trainer.checkpoint_callback.best_model_path

        self.text_logger('Epoch {} - Finished training at {} after {}'.format(self.current_epoch, endtime, elapsedtime))
        self.text_logger('best_model_fname: {}'.format(self.vp_ps.logging.best_model_fname))

        dump_config(self.vp_ps, osp.join(self.work_dir, '{}_{}.yaml'.format(self.expr_id, self.dataset_id)))
        self.hparams = self.vp_ps.toDict()

    @rank_zero_only
    def prepare_data(self):
        '''Similar to the standard AMASS dataset preparation pipeline:
        download the body-data npz files from https://amass.is.tue.mpg.de/ and place them under amass_dir.
        '''
        self.text_logger = log2file(makepath(self.work_dir, '{}.log'.format(self.expr_id), isfile=True), prefix=self._log_prefix)

        prepare_vposer_datasets(self.dataset_dir, self.vp_ps.data_parms.amass_splits, self.vp_ps.data_parms.amass_dir, logger=self.text_logger)


def create_expr_message(ps):
    expr_msg = '[{}] batch_size = {}.'.format(ps.general.expr_id, ps.train_parms.batch_size)

    return expr_msg


def train_vposer_once(_config):
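    """Train a single VPoser model described by _config, resuming from the newest
    checkpoint in the experiment's snapshots directory if one exists."""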

    resume_training_if_possible = True

    model = VPoserTrainer(_config)
    model.vp_ps.logging.expr_msg = create_expr_message(model.vp_ps)
    # model.text_logger(model.vp_ps.logging.expr_msg.replace(". ", '.\n'))
    dump_config(model.vp_ps, osp.join(model.work_dir, '{}.yaml'.format(model.expr_id)))

    logger = TensorBoardLogger(model.work_dir, name='tensorboard')
    lr_monitor = LearningRateMonitor()

    snapshots_dir = osp.join(model.work_dir, 'snapshots')
    checkpoint_callback = ModelCheckpoint(
        dirpath=makepath(snapshots_dir),  # snapshots_dir is a directory, so let makepath create it (isfile defaults to False)
        filename='%s_{epoch:02d}_{val_loss:.2f}' % model.expr_id,
        save_top_k=1,
        verbose=True,
        monitor='val_loss',
        mode='min',
    )
    early_stop_callback = EarlyStopping(**model.vp_ps.train_parms.early_stopping)

    resume_from_checkpoint = None
    if resume_training_if_possible:
        available_ckpts = sorted(glob.glob(osp.join(snapshots_dir, '*.ckpt')), key=os.path.getmtime)
        if len(available_ckpts)>0:
            resume_from_checkpoint = available_ckpts[-1]
            model.text_logger('Resuming the training from {}'.format(resume_from_checkpoint))

    trainer = pl.Trainer(gpus=1,
                         weights_summary='top',
                         distributed_backend = 'ddp',
                         # replace_sampler_ddp=False,
                         # accumulate_grad_batches=4,
                         # profiler=False,
                         # overfit_batches=0.05,
                         # fast_dev_run = True,
                         # limit_train_batches=0.02,
                         # limit_val_batches=0.02,
                         # num_sanity_val_steps=2,
                         plugins=[DDPPlugin(find_unused_parameters=False)],
                         callbacks=[lr_monitor, early_stop_callback, checkpoint_callback],
                         max_epochs=model.vp_ps.train_parms.num_epochs,
                         logger=logger,
                         resume_from_checkpoint=resume_from_checkpoint
                         )

    trainer.fit(model)
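

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original module): train_vposer_once expects a
# dict that load_config(**_config) can resolve into the keys accessed above,
# e.g. general.expr_id, general.dataset_id, body_model.bm_fname, data_parms.*
# and train_parms.*. The argument name and file name below are assumptions;
# consult human_body_prior.tools.configurations for the exact load_config API.
#
#     train_vposer_once({'default_ps_fname': 'vposer_defaults.yaml'})
# -----------------------------------------------------------------------------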