HMR2.0 / hmr2 /models /smpl_wrapper.py
brjathu
Adding HF files
29a229f
raw
history blame
No virus
1.88 kB
import torch
import numpy as np
import pickle
from typing import Optional
import smplx
from smplx.lbs import vertices2joints
from smplx.utils import SMPLOutput
class SMPL(smplx.SMPLLayer):
def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs):
"""
Extension of the official SMPL implementation to support more joints.
Args:
Same as SMPLLayer.
joint_regressor_extra (str): Path to extra joint regressor.
"""
super(SMPL, self).__init__(*args, **kwargs)
smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
if joint_regressor_extra is not None:
self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
self.register_buffer('joint_map', torch.tensor(smpl_to_openpose, dtype=torch.long))
self.update_hips = update_hips
def forward(self, *args, **kwargs) -> SMPLOutput:
"""
Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified.
"""
smpl_output = super(SMPL, self).forward(*args, **kwargs)
joints = smpl_output.joints[:, self.joint_map, :]
if self.update_hips:
joints[:,[9,12]] = joints[:,[9,12]] + \
0.25*(joints[:,[9,12]]-joints[:,[12,9]]) + \
0.5*(joints[:,[8]] - 0.5*(joints[:,[9,12]] + joints[:,[12,9]]))
if hasattr(self, 'joint_regressor_extra'):
extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices)
joints = torch.cat([joints, extra_joints], dim=1)
smpl_output.joints = joints
return smpl_output