import torch


def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; pass NumPy objects through.

    Args:
        tensor: A ``torch.Tensor``, or any object whose type is defined in
            the ``numpy`` module (e.g. ``np.ndarray``).

    Returns:
        For tensor input, a NumPy array (sharing memory with the tensor when
        it already lives on the CPU). NumPy input is returned unchanged.

    Raises:
        ValueError: If ``tensor`` is neither a torch tensor nor a NumPy object.
    """
    if torch.is_tensor(tensor):
        # detach() first: Tensor.numpy() refuses tensors that require grad.
        return tensor.detach().cpu().numpy()
    if type(tensor).__module__ != 'numpy':
        raise ValueError(f"Cannot convert {type(tensor)} to numpy array")
    return tensor


def to_torch(ndarray):
    """Convert a NumPy array to a torch tensor; pass tensors through.

    Args:
        ndarray: An object whose type is defined in the ``numpy`` module
            (e.g. ``np.ndarray``), or a ``torch.Tensor``.

    Returns:
        For NumPy input, the result of ``torch.from_numpy`` (which shares
        memory with the array). Tensor input is returned unchanged.

    Raises:
        ValueError: If ``ndarray`` is neither a NumPy object nor a tensor.
    """
    if type(ndarray).__module__ == 'numpy':
        # NOTE(review): NumPy scalar types (e.g. np.float64) also land here,
        # but torch.from_numpy only accepts np.ndarray — confirm callers
        # never pass scalars.
        return torch.from_numpy(ndarray)
    if not torch.is_tensor(ndarray):
        raise ValueError(f"Cannot convert {type(ndarray)} to torch tensor")
    return ndarray


def cleanexit():
    """Terminate the process immediately with exit status 0. Never returns.

    Uses ``os._exit``, which ends the process without running cleanup
    handlers (``atexit``, pending ``finally`` blocks, interpreter shutdown).
    Python-level stdout/stderr are flushed first so that buffered output is
    not silently lost.
    """
    import os
    import sys

    # os._exit does not flush stdio buffers, so flush them here. A stream may
    # be None (e.g. under pythonw) or already closed; we exit regardless.
    for stream in (sys.stdout, sys.stderr):
        if stream is None:
            continue
        try:
            stream.flush()
        except (OSError, ValueError):
            pass  # best effort only: the process is exiting anyway
    os._exit(0)


def freeze_joints(x, joints_to_freeze):
    """Hold selected joint *rotations* fixed at their first-frame values.

    Args:
        x: 4-D tensor laid out as
            ``[bs, root + n_joints, joint_dim (6), seqlen]``.
        joints_to_freeze: Index into dim 1 selecting the joints to freeze
            (an int, list of ints, slice, boolean mask or index tensor).

    Returns:
        A detached copy of ``x`` in which every frame of the selected joints
        equals that joint's frame 0. ``x`` itself is not modified.
    """
    frozen = x.detach().clone()
    # Materialize the first frame separately. With a slice index the source
    # would otherwise be a view that partially overlaps the destination,
    # which PyTorch rejects during the in-place copy.
    first_frame = frozen[:, joints_to_freeze, :, :1].clone()
    # The trailing size-1 frame axis broadcasts over all seqlen frames.
    frozen[:, joints_to_freeze, :, :] = first_frame
    return frozen