Leitifel's picture
Upload 165 files
899324d verified
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]
import time
from typing import Optional
from torch import Tensor
import smplx
from .base import Datastruct, dataclass, Transform
from .rots2rfeats import Rots2Rfeats, Globalvelandy
from .rots2joints import Rots2Joints, SMPLH
from .joints2jfeats import Joints2Jfeats
class SMPLTransform(Transform):
    """Bundle of the converters used by :class:`SMPLDatastruct`.

    Any converter that is not supplied is replaced by a default:

    * ``rots2rfeats``: a ``Globalvelandy`` loaded from
      ``./data_loaders/amass/transforms/rots2rfeats/globalvelandy/rot6d/babel-amass/separate_pairs``
      with ``pose_rep='rot6d'``, normalization, canonicalization and offset on.
    * ``rots2joints``: a male ``SMPLH`` body model from ``./body_models/smplh``
      (``jointstype='smplnh'``, ``input_pose_rep='matrix'``).
    * ``joints2jfeats``: no default exists; it stays ``None``.

    Args:
        batch_size: batch size forwarded to the default ``SMPLH`` model.
        rots2rfeats: rotations <-> rotation-features converter.
        rots2joints: rotations -> joints/vertices converter.
        joints2jfeats: joints -> joint-features converter (may be ``None``).
        **kwargs: accepted for interface compatibility and ignored.
    """

    def __init__(self, batch_size=16, rots2rfeats: Optional[Rots2Rfeats] = None,
                 rots2joints: Optional[Rots2Joints] = None,
                 joints2jfeats: Optional[Joints2Jfeats] = None,
                 **kwargs):
        if rots2rfeats is None:
            rots2rfeats = Globalvelandy(path='./data_loaders/amass/transforms/rots2rfeats/globalvelandy/rot6d/babel-amass/separate_pairs',
                                        normalization=True,
                                        pose_rep='rot6d',
                                        canonicalize=True,
                                        offset=True,
                                        name='Globalvelandy')
        if rots2joints is None:
            rots2joints = SMPLH(path='./body_models/smplh',
                                jointstype='smplnh',
                                input_pose_rep='matrix',
                                batch_size=batch_size,
                                gender='male',
                                name='SMPLH')
        # No default joints -> jfeats converter is built; joints2jfeats may
        # legitimately stay None (original FIXME: probably unused).
        self.rots2rfeats = rots2rfeats
        self.rots2joints = rots2joints
        self.joints2jfeats = joints2jfeats

    def Datastruct(self, **kwargs):
        """Build an :class:`SMPLDatastruct` wired to this transform's converters."""
        return SMPLDatastruct(_rots2rfeats=self.rots2rfeats,
                              _rots2joints=self.rots2joints,
                              _joints2jfeats=self.joints2jfeats,
                              transforms=self,
                              **kwargs)

    def __repr__(self):
        return "SMPLTransform()"
class SlimSMPLTransform(Transform):
    """Converters used by :class:`SlimSMPLDatastruct` (no joint-feature step).

    Any converter that is not supplied is replaced by a default:

    * ``rots2rfeats``: a ``Globalvelandy`` loaded from
      ``./data_loaders/amass/transforms/rots2rfeats/globalvelandy/rot6d/babel-amass/separate_pairs``
      with ``pose_rep='rot6d'``, normalization and offset on. Canonicalization
      is taken from ``kwargs['canonicalize']`` (default ``True``).
    * ``rots2joints``: a male ``SMPLH`` body model from ``./body_models/smplh``
      (``jointstype='smplnh'``, ``input_pose_rep='matrix'``).

    Args:
        batch_size: batch size forwarded to the default ``SMPLH`` model.
        rots2rfeats: rotations <-> rotation-features converter.
        rots2joints: rotations -> joints/vertices converter.
        **kwargs: only ``canonicalize`` is read; everything else is ignored.
    """

    def __init__(self, batch_size=16, rots2rfeats: Optional[Rots2Rfeats] = None,
                 rots2joints: Optional[Rots2Joints] = None,
                 **kwargs):
        if rots2rfeats is None:
            rots2rfeats = Globalvelandy(path='./data_loaders/amass/transforms/rots2rfeats/globalvelandy/rot6d/babel-amass/separate_pairs',
                                        normalization=True,
                                        pose_rep='rot6d',
                                        canonicalize=kwargs.get("canonicalize", True),
                                        offset=True,
                                        name='Globalvelandy')
        if rots2joints is None:
            rots2joints = SMPLH(path='./body_models/smplh',
                                jointstype='smplnh',
                                input_pose_rep='matrix',
                                batch_size=batch_size,
                                gender='male',
                                name='SMPLH')
        self.rots2rfeats = rots2rfeats
        self.rots2joints = rots2joints

    def SlimDatastruct(self, **kwargs):
        """Build a :class:`SlimSMPLDatastruct` wired to this transform's converters."""
        return SlimSMPLDatastruct(_rots2rfeats=self.rots2rfeats,
                                  _rots2joints=self.rots2joints,
                                  transforms=self,
                                  **kwargs)

    def __repr__(self):
        return "SlimSMPLTransform()"
class RotIdentityTransform(Transform):
    """Pass-through transform: stores rotations and translations unchanged.

    It holds no converters; :meth:`Datastruct` simply forwards its keyword
    arguments to :class:`RotTransDatastruct`.
    """

    def __init__(self, **kwargs):
        # Nothing to configure; keyword arguments are accepted and dropped.
        pass

    def Datastruct(self, **kwargs):
        """Wrap the given fields in a :class:`RotTransDatastruct`."""
        datastruct = RotTransDatastruct(**kwargs)
        return datastruct

    def __repr__(self):
        return "RotIdentityTransform()"
@dataclass
class RotTransDatastruct(Datastruct):
    """Plain container pairing a rotation tensor with a translation tensor."""

    rots: Tensor  # rotations; representation is decided by the producer
    trans: Tensor  # translations
    # Shared, stateless default transform instance.
    transforms: RotIdentityTransform = RotIdentityTransform()

    def __post_init__(self):
        # Field names that hold the data (presumably consumed by the
        # Datastruct base class; confirm in .base).
        keys = ["rots", "trans"]
        self.datakeys = keys

    def __len__(self):
        # Length is defined by the leading dimension of ``rots``.
        n_items = len(self.rots)
        return n_items
@dataclass
class SMPLDatastruct(Datastruct):
    """Motion container that converts between representations on demand.

    Each representation is computed from the ones it depends on the first
    time it is accessed and cached in the matching ``*_`` field:

    * ``rfeats``  <-> ``rots``  via ``_rots2rfeats`` / ``_rots2rfeats.inverse``
    * ``joints``   <- ``rots``  via ``_rots2joints``
    * ``vertices`` <- ``rots``  via ``_rots2joints(..., jointstype="vertices")``
    * ``jfeats``   <- ``joints`` via ``_joints2jfeats`` (only if configured)

    If ``features`` is given and ``rfeats_`` is not, ``features`` is used as
    the rotation features.
    """

    transforms: SMPLTransform
    _rots2rfeats: Rots2Rfeats
    _rots2joints: Rots2Joints
    _joints2jfeats: Joints2Jfeats  # may be None (SMPLTransform default)
    features: Optional[Tensor] = None
    rots_: Optional[RotTransDatastruct] = None
    rfeats_: Optional[Tensor] = None
    joints_: Optional[Tensor] = None
    jfeats_: Optional[Tensor] = None
    vertices_: Optional[Tensor] = None

    def __post_init__(self):
        # Data-carrying fields (presumably iterated by the Datastruct base
        # class; confirm in .base).
        self.datakeys = ['features', 'rots_', 'rfeats_',
                         'joints_', 'jfeats_', 'vertices_']
        # starting point: raw features are the rotation features
        if self.features is not None and self.rfeats_ is None:
            self.rfeats_ = self.features

    @property
    def rots(self):
        """Rotations; recovered from ``rfeats`` via the inverse converter if not cached."""
        # Cached value
        if self.rots_ is not None:
            return self.rots_
        # self.rfeats_ should be defined
        assert self.rfeats_ is not None
        # Keep the converter on the same device as its input.
        self._rots2rfeats.to(self.rfeats.device)
        self.rots_ = self._rots2rfeats.inverse(self.rfeats)
        return self.rots_

    @property
    def rfeats(self):
        """Rotation features; computed from ``rots`` if not cached."""
        # Cached value
        if self.rfeats_ is not None:
            return self.rfeats_
        # self.rots_ should be defined
        assert self.rots_ is not None
        self._rots2rfeats.to(self.rots.device)
        self.rfeats_ = self._rots2rfeats(self.rots)
        return self.rfeats_

    @property
    def joints(self):
        """Joint positions; computed from ``rots`` with ``_rots2joints`` if not cached."""
        # Cached value
        if self.joints_ is not None:
            return self.joints_
        self._rots2joints.to(self.rots.device)
        self.joints_ = self._rots2joints(self.rots)
        return self.joints_

    @property
    def jfeats(self):
        """Joint features; computed from ``joints`` if not cached.

        Raises:
            AttributeError: if no ``joints2jfeats`` converter was configured
                (the default in ``SMPLTransform``) and nothing is cached.
        """
        # Cached value
        if self.jfeats_ is not None:
            return self.jfeats_
        if self._joints2jfeats is None:
            # Previously this failed with "'NoneType' object has no attribute
            # 'to'"; keep the exception type but say what is actually wrong.
            raise AttributeError(
                "jfeats cannot be computed: no joints2jfeats converter was "
                "configured for this SMPLDatastruct")
        self._joints2jfeats.to(self.joints.device)
        self.jfeats_ = self._joints2jfeats(self.joints)
        return self.jfeats_

    @property
    def vertices(self):
        """Mesh vertices; computed from ``rots`` with ``jointstype="vertices"`` if not cached."""
        # Cached value
        if self.vertices_ is not None:
            return self.vertices_
        self._rots2joints.to(self.rots.device)
        self.vertices_ = self._rots2joints(self.rots, jointstype="vertices")
        return self.vertices_

    def __len__(self):
        # Length follows the rotation features (computed from rots if needed).
        return len(self.rfeats)
@dataclass
class SlimSMPLDatastruct(Datastruct):
    """Lighter variant of the SMPL datastruct without joint features.

    Representations are computed lazily on first access and cached:

    * ``rfeats``  <-> ``rots``  via ``_rots2rfeats`` / ``_rots2rfeats.inverse``
    * ``joints``   <- ``rots``  via ``_rots2joints``
    * ``vertices`` <- ``rots``  via ``_rots2joints(..., jointstype="vertices")``

    If ``features`` is given and ``rfeats_`` is not, ``features`` is used as
    the rotation features.
    """

    transforms: SlimSMPLTransform
    _rots2rfeats: Rots2Rfeats
    _rots2joints: Rots2Joints
    features: Optional[Tensor] = None
    rots_: Optional[RotTransDatastruct] = None
    rfeats_: Optional[Tensor] = None
    joints_: Optional[Tensor] = None
    vertices_: Optional[Tensor] = None

    def __post_init__(self):
        # Data-carrying fields (presumably iterated by the Datastruct base
        # class; confirm in .base). Note: vertices_ is not listed here.
        self.datakeys = ['features', 'rots_', 'joints_', 'rfeats_']
        # starting point: raw features are the rotation features
        if self.features is not None and self.rfeats_ is None:
            self.rfeats_ = self.features

    @property
    def rots(self):
        """Rotations; recovered from ``rfeats`` via the inverse converter if not cached."""
        # Cached value
        if self.rots_ is not None:
            return self.rots_
        # self.rfeats_ should be defined
        assert self.rfeats_ is not None
        # Keep the converter on the same device as its input.
        self._rots2rfeats.to(self.rfeats.device)
        self.rots_ = self._rots2rfeats.inverse(self.rfeats)
        return self.rots_

    @property
    def rfeats(self):
        """Rotation features; computed from ``rots`` if not cached."""
        # Cached value
        if self.rfeats_ is not None:
            return self.rfeats_
        # self.rots_ should be defined
        assert self.rots_ is not None
        self._rots2rfeats.to(self.rots.device)
        self.rfeats_ = self._rots2rfeats(self.rots)
        return self.rfeats_

    @property
    def joints(self):
        """Joint positions; computed from ``rots`` with ``_rots2joints`` if not cached."""
        # Cached value
        if self.joints_ is not None:
            return self.joints_
        self._rots2joints.to(self.rots.device)
        self.joints_ = self._rots2joints(self.rots)
        return self.joints_

    @property
    def vertices(self):
        """Mesh vertices; computed from ``rots`` with ``jointstype="vertices"`` if not cached."""
        # Cached value
        if self.vertices_ is not None:
            return self.vertices_
        self._rots2joints.to(self.rots.device)
        self.vertices_ = self._rots2joints(self.rots, jointstype="vertices")
        return self.vertices_

    @property
    def faces(self):
        """Mesh faces, taken directly from the ``_rots2joints`` converter."""
        return self._rots2joints.faces

    def __len__(self):
        # Length follows the rotation features (computed from rots if needed).
        return len(self.rfeats)
def get_body_model(model_type, gender, batch_size, device='cpu', ext='pkl'):
    """Create an ``smplx`` body model from the ``body_models/`` directory.

    The model file is ``body_models/{model_type}/{MODEL_TYPE}_{GENDER}.{ext}``.
    For the neutral gender the extension is always ``npz``, regardless of
    ``ext``.

    Args:
        model_type: body model family, e.g. ``'smpl'``, ``'smplh'``,
            ``'smplx'`` (see the smplx documentation for supported types).
        gender: ``'male'``, ``'female'`` or ``'neutral'``. A non-``str`` value
            is converted with ``str(gender.astype(str))`` first (e.g. a
            NumPy string array).
        batch_size: a positive integer batch size for the model parameters.
        device: target device. ``'cpu'`` (default) or ``None`` returns the
            model as created; anything else (``'cuda'``, ``'cuda:1'``, ...)
            moves it with ``.to(device)``.
        ext: model file extension for non-neutral genders.

    Returns:
        The model built by ``smplx.create``, placed on ``device``.
    """
    mtype = model_type.upper()
    # Normalise non-str genders before the neutral check so that a numpy
    # 'neutral' is handled too (it previously crashed on ``.upper()``).
    if not isinstance(gender, str):
        gender = str(gender.astype(str))
    if gender == 'neutral':
        # The neutral model is always loaded from an .npz file.
        ext = 'npz'
    gender = gender.upper()
    body_model_path = f'body_models/{model_type}/{mtype}_{gender}.{ext}'
    # Pass the requested model type (the original passed the builtin ``type``).
    body_model = smplx.create(body_model_path, model_type=model_type,
                              gender=gender, ext=ext,
                              use_pca=False,
                              num_pca_comps=12,
                              create_global_orient=True,
                              create_body_pose=True,
                              create_betas=True,
                              create_left_hand_pose=True,
                              create_right_hand_pose=True,
                              create_expression=True,
                              create_jaw_pose=True,
                              create_leye_pose=True,
                              create_reye_pose=True,
                              create_transl=True,
                              batch_size=batch_size)
    if device is None or device == 'cpu':
        return body_model
    # .to('cuda') is equivalent to .cuda() and also accepts 'cuda:N' etc.
    return body_model.to(device)