diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..283fbfa5f5b59c93d9e4e77879c65a409bbe2afc 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 98a49b56b246188ff059169ea200a675b3cb450a..213bdb447883f54623dee74c08504c2b37a2e1a3 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,94 @@
----
-title: HaWoR
-emoji: 👁
-colorFrom: green
-colorTo: indigo
-sdk: gradio
-sdk_version: 5.9.1
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+# HaWoR: World-Space Hand Motion Reconstruction from Egocentric Videos
+
+[Jinglei Zhang]()<sup>1</sup> &nbsp; [Jiankang Deng](https://jiankangdeng.github.io/)<sup>2</sup> &nbsp; [Chao Ma](https://scholar.google.com/citations?user=syoPhv8AAAAJ&hl=en)<sup>1</sup> &nbsp; [Rolandos Alexandros Potamias](https://rolpotamias.github.io)<sup>2</sup>
+
+<sup>1</sup>Shanghai Jiao Tong University, China &nbsp; <sup>2</sup>Imperial College London, UK
+
+This is the official implementation of **[HaWoR](https://hawor-project.github.io/)**, a hand motion reconstruction model in world coordinates:
+
+![teaser](assets/teaser.png)
+
+## Installation
+
+### Clone the repository
+```bash
+git clone --recursive https://github.com/ThunderVVV/HaWoR.git
+cd HaWoR
+```
+
+The code has been tested with PyTorch 1.13 and CUDA 11.7. We suggest using an Anaconda environment to install the required dependencies:
+```bash
+conda create --name hawor python=3.10
+conda activate hawor
+
+pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
+# Install requirements
+pip install -r requirements.txt
+pip install pytorch-lightning==2.2.4 --no-deps
+pip install lightning-utilities torchmetrics==1.4.0
+```
+
+### Install masked DROID-SLAM
+
+```bash
+cd thirdparty/DROID-SLAM
+python setup.py install
+```
+
+Download the official DROID-SLAM weights [droid.pth](https://drive.google.com/file/d/1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh/view?usp=sharing) and place them under `./weights/external/`.
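+
+Optionally, the weights can be fetched from the command line with [gdown](https://github.com/wkentaro/gdown) (not installed by the steps above; the file ID is taken from the Google Drive link):
+```bash
+pip install gdown
+mkdir -p ./weights/external
+gdown 1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh -O ./weights/external/droid.pth
+```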
+
+### Install Metric3D
+
+Download the official Metric3D weights [metric_depth_vit_large_800k.pth](https://drive.google.com/file/d/1eT2gG-kwsVzNy5nJrbm4KC-9DbNKyLnr/view?usp=drive_link) and place them under `thirdparty/Metric3D/weights/`.
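+
+As above, a command-line alternative with gdown (file ID taken from the Google Drive link):
+```bash
+mkdir -p thirdparty/Metric3D/weights
+gdown 1eT2gG-kwsVzNy5nJrbm4KC-9DbNKyLnr -O thirdparty/Metric3D/weights/metric_depth_vit_large_800k.pth
+```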
+
+### Download the model weights
+
+```bash
+wget https://huggingface.co/spaces/rolpotamias/WiLoR/resolve/main/pretrained_models/detector.pt -P ./weights/external/
+wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/checkpoints/hawor.ckpt -P ./weights/hawor/checkpoints/
+wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/checkpoints/infiller.pt -P ./weights/hawor/checkpoints/
+wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/model_config.yaml -P ./weights/hawor/
+```
+It is also required to download the MANO model from the [MANO website](https://mano.is.tue.mpg.de).
+Create an account by clicking Sign Up, download the models (mano_v*_*.zip), and unzip them. Place the hand models at `_DATA/data/mano/MANO_RIGHT.pkl` and `_DATA/data_left/mano_left/MANO_LEFT.pkl`.
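+
+For reference, the expected layout after unzipping is:
+```
+_DATA/
+├── data/
+│   └── mano/
+│       └── MANO_RIGHT.pkl
+└── data_left/
+    └── mano_left/
+        └── MANO_LEFT.pkl
+```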
+
+Note that the MANO model falls under the [MANO license](https://mano.is.tue.mpg.de/license.html).
+
+## Demo
+
+For visualization in the world view, run:
+```bash
+python demo.py --video_path ./example/video_0.mp4 --vis_mode world
+```
+
+For visualization in the camera view, run:
+```bash
+python demo.py --video_path ./example/video_0.mp4 --vis_mode cam
+```
+
+## Training
+The training code will be released soon.
+
+## Acknowledgements
+Parts of the code are taken or adapted from the following repos:
+- [HaMeR](https://github.com/geopavlakos/hamer/)
+- [WiLoR](https://github.com/rolpotamias/WiLoR)
+- [SLAHMR](https://github.com/vye16/slahmr)
+- [TRAM](https://github.com/yufu-wang/tram)
+- [CMIB](https://github.com/jihoonerd/Conditional-Motion-In-Betweening)
+
+
+## License
+HaWoR models fall under the [CC-BY-NC-ND License](./license.txt). This repository also depends on the [MANO Model](https://mano.is.tue.mpg.de/license.html), which falls under its own license. By using this repository, you must also comply with the terms of these external licenses.
+
+## Citing
+If you find HaWoR useful for your research, please consider citing our paper:
+
+```bibtex
+
+```
diff --git a/_DATA/data/mano/.gitkeep b/_DATA/data/mano/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/_DATA/data/mano/MANO_RIGHT.pkl b/_DATA/data/mano/MANO_RIGHT.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..8e7ac7faf64ad51096ec1da626ea13757ed7f665
--- /dev/null
+++ b/_DATA/data/mano/MANO_RIGHT.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45d60aa3b27ef9107a7afd4e00808f307fd91111e1cfa35afd5c4a62de264767
+size 3821356
diff --git a/_DATA/data/mano_mean_params.npz b/_DATA/data/mano_mean_params.npz
new file mode 100644
index 0000000000000000000000000000000000000000..dc294b01fb78a9cd6636c87a69b59cf82d28d15b
--- /dev/null
+++ b/_DATA/data/mano_mean_params.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efc0ec58e4a5cef78f3abfb4e8f91623b8950be9eff8b8e0dbb0d036ebc63988
+size 1178
diff --git a/_DATA/data_left/mano_left/.gitkeep b/_DATA/data_left/mano_left/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/_DATA/data_left/mano_left/MANO_LEFT.pkl b/_DATA/data_left/mano_left/MANO_LEFT.pkl
new file mode 100755
index 0000000000000000000000000000000000000000..32cdc533e2c01ed4995db2dc1302520d7d374c5a
--- /dev/null
+++ b/_DATA/data_left/mano_left/MANO_LEFT.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4022f7083f2ca7c78b2b3d595abbab52debd32b09d372b16923a801f0ea6a30
+size 3821391
diff --git a/assets/teaser.png b/assets/teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..1f3e4905bb6fc0d2d4f09966e379f78ba92c39b4
--- /dev/null
+++ b/assets/teaser.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b33d76e9a10f215f0777612dd32ac73a5ce3b0e8735813968e7048ecd1ed3a1
+size 1118621
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a437853015f84e74a059289f217b5bc12a00ed5
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,113 @@
+import argparse
+import sys
+import os
+
+import torch
+sys.path.insert(0, os.path.dirname(__file__))
+import numpy as np
+import joblib
+from scripts.scripts_test_video.detect_track_video import detect_track_video
+from scripts.scripts_test_video.hawor_video import hawor_motion_estimation, hawor_infiller
+from scripts.scripts_test_video.hawor_slam import hawor_slam
+from hawor.utils.process import get_mano_faces, run_mano, run_mano_left
+from lib.eval_utils.custom_utils import load_slam_cam
+from lib.vis.run_vis2 import run_vis2_on_video, run_vis2_on_video_cam
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_focal", type=float)
+ parser.add_argument("--video_path", type=str, default='example/video_0.mp4')
+ parser.add_argument("--input_type", type=str, default='file')
+ parser.add_argument("--checkpoint", type=str, default='./weights/hawor/checkpoints/hawor.ckpt')
+ parser.add_argument("--infiller_weight", type=str, default='./weights/hawor/checkpoints/infiller.pt')
+ parser.add_argument("--vis_mode", type=str, default='world', help='cam | world')
+ args = parser.parse_args()
+
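+    # Detect and track the hands throughout the input video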
+ start_idx, end_idx, seq_folder, imgfiles = detect_track_video(args)
+
+ frame_chunks_all, img_focal = hawor_motion_estimation(args, start_idx, end_idx, seq_folder)
+
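+    # Recover the camera trajectory with (masked) DROID-SLAM and load the world/camera transforms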
+ hawor_slam(args, start_idx, end_idx)
+ slam_path = os.path.join(seq_folder, f"SLAM/hawor_slam_w_scale_{start_idx}_{end_idx}.npz")
+ R_w2c_sla_all, t_w2c_sla_all, R_c2w_sla_all, t_c2w_sla_all = load_slam_cam(slam_path)
+
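+    # Infill missing frames and obtain per-frame MANO parameters (translation, rotation, hand pose, shape) for both hands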
+ pred_trans, pred_rot, pred_hand_pose, pred_betas, pred_valid = hawor_infiller(args, start_idx, end_idx, frame_chunks_all)
+
+ # vis sequence for this video
+ hand2idx = {
+ "right": 1,
+ "left": 0
+ }
+ vis_start = 0
+ vis_end = pred_trans.shape[1] - 1
+
+    # get MANO faces plus extra triangles that close the open wrist boundary of the mesh
+ faces = get_mano_faces()
+ faces_new = np.array([[92, 38, 234],
+ [234, 38, 239],
+ [38, 122, 239],
+ [239, 122, 279],
+ [122, 118, 279],
+ [279, 118, 215],
+ [118, 117, 215],
+ [215, 117, 214],
+ [117, 119, 214],
+ [214, 119, 121],
+ [119, 120, 121],
+ [121, 120, 78],
+ [120, 108, 78],
+ [78, 108, 79]])
+ faces_right = np.concatenate([faces, faces_new], axis=0)
+
+ # get right hand vertices
+ hand = 'right'
+ hand_idx = hand2idx[hand]
+ pred_glob_r = run_mano(pred_trans[hand_idx:hand_idx+1, vis_start:vis_end], pred_rot[hand_idx:hand_idx+1, vis_start:vis_end], pred_hand_pose[hand_idx:hand_idx+1, vis_start:vis_end], betas=pred_betas[hand_idx:hand_idx+1, vis_start:vis_end])
+ right_verts = pred_glob_r['vertices'][0]
+ right_dict = {
+ 'vertices': right_verts.unsqueeze(0),
+ 'faces': faces_right,
+ }
+
+ # get left hand vertices
+ faces_left = faces_right[:,[0,2,1]]
+ hand = 'left'
+ hand_idx = hand2idx[hand]
+ pred_glob_l = run_mano_left(pred_trans[hand_idx:hand_idx+1, vis_start:vis_end], pred_rot[hand_idx:hand_idx+1, vis_start:vis_end], pred_hand_pose[hand_idx:hand_idx+1, vis_start:vis_end], betas=pred_betas[hand_idx:hand_idx+1, vis_start:vis_end])
+ left_verts = pred_glob_l['vertices'][0]
+ left_dict = {
+ 'vertices': left_verts.unsqueeze(0),
+ 'faces': faces_left,
+ }
+
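+    # Rotate the world frame 180 degrees about the x-axis (flip y and z) so the reconstruction
+    # is upright in the viewer, then recompute the world-to-camera transforms as the inverse of
+    # the adjusted camera-to-world transforms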
+ R_x = torch.tensor([[1, 0, 0],
+ [0, -1, 0],
+ [0, 0, -1]]).float()
+ R_c2w_sla_all = torch.einsum('ij,njk->nik', R_x, R_c2w_sla_all)
+ t_c2w_sla_all = torch.einsum('ij,nj->ni', R_x, t_c2w_sla_all)
+ R_w2c_sla_all = R_c2w_sla_all.transpose(-1, -2)
+ t_w2c_sla_all = -torch.einsum("bij,bj->bi", R_w2c_sla_all, t_c2w_sla_all)
+ left_dict['vertices'] = torch.einsum('ij,btnj->btni', R_x, left_dict['vertices'].cpu())
+ right_dict['vertices'] = torch.einsum('ij,btnj->btni', R_x, right_dict['vertices'].cpu())
+
+    # Here we use aitviewer (https://github.com/eth-ait/aitviewer) for simple visualization.
+ if args.vis_mode == 'world':
+ output_pth = os.path.join(seq_folder, f"vis_{vis_start}_{vis_end}")
+ if not os.path.exists(output_pth):
+ os.makedirs(output_pth)
+ image_names = imgfiles[vis_start:vis_end]
+ print(f"vis {vis_start} to {vis_end}")
+ run_vis2_on_video(left_dict, right_dict, output_pth, img_focal, image_names, R_c2w=R_c2w_sla_all[vis_start:vis_end], t_c2w=t_c2w_sla_all[vis_start:vis_end])
+ elif args.vis_mode == 'cam':
+ output_pth = os.path.join(seq_folder, f"vis_{vis_start}_{vis_end}")
+ if not os.path.exists(output_pth):
+ os.makedirs(output_pth)
+ image_names = imgfiles[vis_start:vis_end]
+ print(f"vis {vis_start} to {vis_end}")
+ run_vis2_on_video_cam(left_dict, right_dict, output_pth, img_focal, image_names, R_w2c=R_w2c_sla_all[vis_start:vis_end], t_w2c=t_w2c_sla_all[vis_start:vis_end])
+
+ print("finish")
+
+
+
diff --git a/example/video_0.mp4 b/example/video_0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..69082314e091591656cfff9993853abc208a9568
--- /dev/null
+++ b/example/video_0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13ff124a68e4b48190e0c3f0ce9f38db59c5e3bb8a093b3c7fc9c67276be2062
+size 6515891
diff --git a/hawor/configs/__init__.py b/hawor/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..569bf1675619d2765c8135482bfdd9bebb032482
--- /dev/null
+++ b/hawor/configs/__init__.py
@@ -0,0 +1,120 @@
+import os
+from typing import Dict
+from yacs.config import CfgNode as CN
+
+CACHE_DIR_HAWOR = "./_DATA"
+
+def to_lower(x: Dict) -> Dict:
+ """
+ Convert all dictionary keys to lowercase
+ Args:
+ x (dict): Input dictionary
+ Returns:
+ dict: Output dictionary with all keys converted to lowercase
+ """
+ return {k.lower(): v for k, v in x.items()}
+
+_C = CN(new_allowed=True)
+
+_C.GENERAL = CN(new_allowed=True)
+_C.GENERAL.RESUME = True
+_C.GENERAL.TIME_TO_RUN = 3300
+_C.GENERAL.VAL_STEPS = 100
+_C.GENERAL.LOG_STEPS = 100
+_C.GENERAL.CHECKPOINT_STEPS = 20000
+_C.GENERAL.CHECKPOINT_DIR = "checkpoints"
+_C.GENERAL.SUMMARY_DIR = "tensorboard"
+_C.GENERAL.NUM_GPUS = 1
+_C.GENERAL.NUM_WORKERS = 4
+_C.GENERAL.MIXED_PRECISION = True
+_C.GENERAL.ALLOW_CUDA = True
+_C.GENERAL.PIN_MEMORY = False
+_C.GENERAL.DISTRIBUTED = False
+_C.GENERAL.LOCAL_RANK = 0
+_C.GENERAL.USE_SYNCBN = False
+_C.GENERAL.WORLD_SIZE = 1
+
+_C.TRAIN = CN(new_allowed=True)
+_C.TRAIN.NUM_EPOCHS = 100
+_C.TRAIN.BATCH_SIZE = 32
+_C.TRAIN.SHUFFLE = True
+_C.TRAIN.WARMUP = False
+_C.TRAIN.NORMALIZE_PER_IMAGE = False
+_C.TRAIN.CLIP_GRAD = False
+_C.TRAIN.CLIP_GRAD_VALUE = 1.0
+_C.LOSS_WEIGHTS = CN(new_allowed=True)
+
+_C.DATASETS = CN(new_allowed=True)
+
+_C.MODEL = CN(new_allowed=True)
+_C.MODEL.IMAGE_SIZE = 224
+
+_C.EXTRA = CN(new_allowed=True)
+_C.EXTRA.FOCAL_LENGTH = 5000
+
+_C.DATASETS.CONFIG = CN(new_allowed=True)
+_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
+_C.DATASETS.CONFIG.ROT_FACTOR = 30
+_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
+_C.DATASETS.CONFIG.COLOR_SCALE = 0.2
+_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
+_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
+_C.DATASETS.CONFIG.DO_FLIP = False
+_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
+_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
+
+def default_config() -> CN:
+ """
+ Get a yacs CfgNode object with the default config values.
+ """
+ # Return a clone so that the defaults will not be altered
+ # This is for the "local variable" use pattern
+ return _C.clone()
+
+def dataset_config() -> CN:
+ """
+ Get dataset config file
+ Returns:
+ CfgNode: Dataset config as a yacs CfgNode object.
+ """
+ cfg = CN(new_allowed=True)
+ config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_tar.yaml')
+ cfg.merge_from_file(config_file)
+ cfg.freeze()
+ return cfg
+
+def dataset_eval_config() -> CN:
+ cfg = CN(new_allowed=True)
+ config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_eval.yaml')
+ cfg.merge_from_file(config_file)
+ cfg.freeze()
+ return cfg
+
+def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
+ """
+ Read a config file and optionally merge it with the default config file.
+ Args:
+ config_file (str): Path to config file.
+ merge (bool): Whether to merge with the default config or not.
+ Returns:
+ CfgNode: Config as a yacs CfgNode object.
+ """
+ if merge:
+ cfg = default_config()
+ else:
+ cfg = CN(new_allowed=True)
+ cfg.merge_from_file(config_file)
+
+ if update_cachedir:
+ def update_path(path: str) -> str:
+ if os.path.basename(CACHE_DIR_HAWOR) in path:
+ return path
+ if os.path.isabs(path):
+ return path
+ return os.path.join(CACHE_DIR_HAWOR, path)
+
+ cfg.MANO.MODEL_PATH = update_path(cfg.MANO.MODEL_PATH)
+ cfg.MANO.MEAN_PARAMS = update_path(cfg.MANO.MEAN_PARAMS)
+
+ cfg.freeze()
+ return cfg
diff --git a/hawor/configs/__pycache__/__init__.cpython-310.pyc b/hawor/configs/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61f3434ffd66806e5e80fa5eab81122a1bbb807f
Binary files /dev/null and b/hawor/configs/__pycache__/__init__.cpython-310.pyc differ
diff --git a/hawor/utils/__pycache__/geometry.cpython-310.pyc b/hawor/utils/__pycache__/geometry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f059c3f3ff802d128d320ced5558abee3c1afdf
Binary files /dev/null and b/hawor/utils/__pycache__/geometry.cpython-310.pyc differ
diff --git a/hawor/utils/__pycache__/process.cpython-310.pyc b/hawor/utils/__pycache__/process.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9da2e4829b6ba7e66fc498b011525820bb4f7e25
Binary files /dev/null and b/hawor/utils/__pycache__/process.cpython-310.pyc differ
diff --git a/hawor/utils/__pycache__/pylogger.cpython-310.pyc b/hawor/utils/__pycache__/pylogger.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7ac89c1cd866f31319e154116338138699ab077
Binary files /dev/null and b/hawor/utils/__pycache__/pylogger.cpython-310.pyc differ
diff --git a/hawor/utils/__pycache__/render_openpose.cpython-310.pyc b/hawor/utils/__pycache__/render_openpose.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f151b9daaf3d603d87e22ee034475d5dbe7acf2
Binary files /dev/null and b/hawor/utils/__pycache__/render_openpose.cpython-310.pyc differ
diff --git a/hawor/utils/__pycache__/rotation.cpython-310.pyc b/hawor/utils/__pycache__/rotation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07c890f30559e00cf99c04ffaa821aedf48a0570
Binary files /dev/null and b/hawor/utils/__pycache__/rotation.cpython-310.pyc differ
diff --git a/hawor/utils/geometry.py b/hawor/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..1effe3d9aa66386aa3dd114f07646c6a96a5035e
--- /dev/null
+++ b/hawor/utils/geometry.py
@@ -0,0 +1,102 @@
+from typing import Optional
+import torch
+from torch.nn import functional as F
+
+def aa_to_rotmat(theta: torch.Tensor):
+ """
+ Convert axis-angle representation to rotation matrix.
+ Works by first converting it to a quaternion.
+ Args:
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
+ Returns:
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+ """
+ norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat_to_rotmat(quat)
+
+def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
+ """
+ Convert quaternion representation to rotation matrix.
+ Args:
+ quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
+ Returns:
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+ """
+ norm_quat = quat
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+
+def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
+ """
+ Convert 6D rotation representation to 3x3 rotation matrix.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Args:
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
+ Returns:
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
+ """
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.linalg.cross(b1, b2)
+ return torch.stack((b1, b2, b3), dim=-1)
+
+def perspective_projection(points: torch.Tensor,
+ translation: torch.Tensor,
+ focal_length: torch.Tensor,
+ camera_center: Optional[torch.Tensor] = None,
+ rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Computes the perspective projection of a set of 3D points.
+ Args:
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
+ focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
+ camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
+ Returns:
+ torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
+ """
+ batch_size = points.shape[0]
+ if rotation is None:
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
+ if camera_center is None:
+ camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
+ # Populate intrinsic camera matrix K.
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
+ K[:,0,0] = focal_length[:,0]
+ K[:,1,1] = focal_length[:,1]
+ K[:,2,2] = 1.
+ K[:,:-1, -1] = camera_center
+
+ # Transform points
+ points = torch.einsum('bij,bkj->bki', rotation, points)
+ points = points + translation.unsqueeze(1)
+
+ # Apply perspective distortion
+ projected_points = points / points[:,:,-1].unsqueeze(-1)
+
+ # Apply camera intrinsics
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
+
+ return projected_points[:, :, :-1]
\ No newline at end of file
diff --git a/hawor/utils/process.py b/hawor/utils/process.py
new file mode 100644
index 0000000000000000000000000000000000000000..89bfc33c069b78bd20e72844b3b42e06becf0b84
--- /dev/null
+++ b/hawor/utils/process.py
@@ -0,0 +1,198 @@
+import torch
+from lib.models.mano_wrapper import MANO
+from hawor.utils.geometry import aa_to_rotmat
+import numpy as np
+import sys
+import os
+
+def block_print():
+ sys.stdout = open(os.devnull, 'w')
+
+def enable_print():
+ sys.stdout = sys.__stdout__
+
+def get_mano_faces():
+ block_print()
+ MANO_cfg = {
+ 'DATA_DIR': '_DATA/data/',
+ 'MODEL_PATH': '_DATA/data/mano',
+ 'GENDER': 'neutral',
+ 'NUM_HAND_JOINTS': 15,
+ 'CREATE_BODY_POSE': False
+ }
+ mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
+ mano = MANO(**mano_cfg)
+ enable_print()
+ return mano.faces
+
+
+def run_mano(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True):
+ """
+    Forward pass of the MANO model; returns a dict with the posed 3D joints and vertices.
+
+    trans : B x T x 3
+    root_orient : B x T x 3
+    hand_pose : B x T x J*3
+ betas : (optional) B x D
+ """
+ block_print()
+ MANO_cfg = {
+ 'DATA_DIR': '_DATA/data/',
+ 'MODEL_PATH': '_DATA/data/mano',
+ 'GENDER': 'neutral',
+ 'NUM_HAND_JOINTS': 15,
+ 'CREATE_BODY_POSE': False
+ }
+ mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
+ mano = MANO(**mano_cfg)
+ if use_cuda:
+ mano = mano.cuda()
+
+ B, T, _ = root_orient.shape
+ NUM_JOINTS = 15
+ mano_params = {
+ 'global_orient': root_orient.reshape(B*T, -1),
+ 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
+ 'betas': betas.reshape(B*T, -1),
+ }
+ rotmat_mano_params = mano_params
+ rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
+ rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
+ rotmat_mano_params['transl'] = trans.reshape(B*T, 3)
+
+ if use_cuda:
+ mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
+ else:
+ mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)
+
+ faces_right = mano.faces
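+    # extra triangles to close the open wrist boundary of the MANO mesh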
+ faces_new = np.array([[92, 38, 234],
+ [234, 38, 239],
+ [38, 122, 239],
+ [239, 122, 279],
+ [122, 118, 279],
+ [279, 118, 215],
+ [118, 117, 215],
+ [215, 117, 214],
+ [117, 119, 214],
+ [214, 119, 121],
+ [119, 120, 121],
+ [121, 120, 78],
+ [120, 108, 78],
+ [78, 108, 79]])
+ faces_right = np.concatenate([faces_right, faces_new], axis=0)
+ faces_n = len(faces_right)
+ faces_left = faces_right[:,[0,2,1]]
+
+ outputs = {
+ "joints": mano_output.joints.reshape(B, T, -1, 3),
+ "vertices": mano_output.vertices.reshape(B, T, -1, 3),
+ }
+
+ if not is_right is None:
+ # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
+ # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
+ is_right = (is_right[:, :, 0].cpu().numpy() > 0)
+ faces_result = np.zeros((B, T, faces_n, 3))
+ faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
+ faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
+ faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
+ outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))
+
+
+ enable_print()
+ return outputs
+
+def run_mano_left(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True, fix_shapedirs=True):
+ """
+    Forward pass of the left-hand MANO model; returns a dict with the posed 3D joints and vertices.
+
+    trans : B x T x 3
+    root_orient : B x T x 3
+    hand_pose : B x T x J*3
+ betas : (optional) B x D
+ """
+ block_print()
+ MANO_cfg = {
+ 'DATA_DIR': '_DATA/data_left/',
+ 'MODEL_PATH': '_DATA/data_left/mano_left',
+ 'GENDER': 'neutral',
+ 'NUM_HAND_JOINTS': 15,
+ 'CREATE_BODY_POSE': False,
+ 'is_rhand': False
+ }
+ mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
+ mano = MANO(**mano_cfg)
+ if use_cuda:
+ mano = mano.cuda()
+
+ # fix MANO shapedirs of the left hand bug (https://github.com/vchoutas/smplx/issues/48)
+ if fix_shapedirs:
+ mano.shapedirs[:, 0, :] *= -1
+
+ B, T, _ = root_orient.shape
+ NUM_JOINTS = 15
+ mano_params = {
+ 'global_orient': root_orient.reshape(B*T, -1),
+ 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
+ 'betas': betas.reshape(B*T, -1),
+ }
+ rotmat_mano_params = mano_params
+ rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
+ rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
+ rotmat_mano_params['transl'] = trans.reshape(B*T, 3)
+
+ if use_cuda:
+ mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
+ else:
+ mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)
+
+ faces_right = mano.faces
+ faces_new = np.array([[92, 38, 234],
+ [234, 38, 239],
+ [38, 122, 239],
+ [239, 122, 279],
+ [122, 118, 279],
+ [279, 118, 215],
+ [118, 117, 215],
+ [215, 117, 214],
+ [117, 119, 214],
+ [214, 119, 121],
+ [119, 120, 121],
+ [121, 120, 78],
+ [120, 108, 78],
+ [78, 108, 79]])
+ faces_right = np.concatenate([faces_right, faces_new], axis=0)
+ faces_n = len(faces_right)
+ faces_left = faces_right[:,[0,2,1]]
+
+ outputs = {
+ "joints": mano_output.joints.reshape(B, T, -1, 3),
+ "vertices": mano_output.vertices.reshape(B, T, -1, 3),
+ }
+
+ if not is_right is None:
+ # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
+ # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
+ is_right = (is_right[:, :, 0].cpu().numpy() > 0)
+ faces_result = np.zeros((B, T, faces_n, 3))
+ faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
+ faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
+ faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
+ outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))
+
+
+ enable_print()
+ return outputs
+
+def run_mano_twohands(init_trans, init_rot, init_hand_pose, is_right, init_betas, use_cuda=True, fix_shapedirs=True):
+ outputs_left = run_mano_left(init_trans[0:1], init_rot[0:1], init_hand_pose[0:1], None, init_betas[0:1], use_cuda=use_cuda, fix_shapedirs=fix_shapedirs)
+ outputs_right = run_mano(init_trans[1:2], init_rot[1:2], init_hand_pose[1:2], None, init_betas[1:2], use_cuda=use_cuda)
+ outputs_two = {
+ "vertices": torch.cat((outputs_left["vertices"], outputs_right["vertices"]), dim=0),
+ "joints": torch.cat((outputs_left["joints"], outputs_right["joints"]), dim=0)
+
+ }
+ return outputs_two
\ No newline at end of file
diff --git a/hawor/utils/pylogger.py b/hawor/utils/pylogger.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ffa71893ec20acde65e44d899334a38d8d1333
--- /dev/null
+++ b/hawor/utils/pylogger.py
@@ -0,0 +1,17 @@
+import logging
+
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def get_pylogger(name=__name__) -> logging.Logger:
+ """Initializes multi-GPU-friendly python command line logger."""
+
+ logger = logging.getLogger(name)
+
+ # this ensures all logging levels get marked with the rank zero decorator
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+ logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+ for level in logging_levels:
+ setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+ return logger
diff --git a/hawor/utils/render_openpose.py b/hawor/utils/render_openpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ffcb5c6b52cdec2058f0f3d3b2ec5b705d5b2a9
--- /dev/null
+++ b/hawor/utils/render_openpose.py
@@ -0,0 +1,225 @@
+"""
+Render OpenPose keypoints.
+Code was ported to Python from the official C++ implementation https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/utilities/keypoint.cpp
+"""
+import cv2
+import math
+import numpy as np
+from typing import List, Tuple
+
+def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]:
+ """
+ Compute rectangle enclosing keypoints above the threshold.
+ Args:
+ keypoints (np.array): Keypoint array of shape (N, 3).
+ threshold (float): Confidence visualization threshold.
+ Returns:
+ Tuple[float, float, float]: Rectangle width, height and area.
+ """
+ valid_ind = keypoints[:, -1] > threshold
+ if valid_ind.sum() > 0:
+ valid_keypoints = keypoints[valid_ind][:, :-1]
+ max_x = valid_keypoints[:,0].max()
+ max_y = valid_keypoints[:,1].max()
+ min_x = valid_keypoints[:,0].min()
+ min_y = valid_keypoints[:,1].min()
+ width = max_x - min_x
+ height = max_y - min_y
+ area = width * height
+ return width, height, area
+ else:
+ return 0,0,0
+
+def render_keypoints(img: np.array,
+ keypoints: np.array,
+ pairs: List,
+ colors: List,
+ thickness_circle_ratio: float,
+ thickness_line_ratio_wrt_circle: float,
+ pose_scales: List,
+ threshold: float = 0.1,
+ alpha: float = 1.0) -> np.array:
+ """
+ Render keypoints on input image.
+ Args:
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
+ keypoints (np.array): Keypoint array of shape (N, 3).
+ pairs (List): List of keypoint pairs per limb.
+ colors: (List): List of colors per keypoint.
+ thickness_circle_ratio (float): Circle thickness ratio.
+ thickness_line_ratio_wrt_circle (float): Line thickness ratio wrt the circle.
+ pose_scales (List): List of pose scales.
+ threshold (float): Only visualize keypoints with confidence above the threshold.
+ Returns:
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
+ """
+ img_orig = img.copy()
+    width, height = img.shape[1], img.shape[0]  # img is (H, W, 3): width is shape[1], height is shape[0]
+ area = width * height
+
+ lineType = 8
+ shift = 0
+ numberColors = len(colors)
+ thresholdRectangle = 0.1
+
+ person_width, person_height, person_area = get_keypoints_rectangle(keypoints, thresholdRectangle)
+ if person_area > 0:
+ ratioAreas = min(1, max(person_width / width, person_height / height))
+ thicknessRatio = np.maximum(np.round(math.sqrt(area) * thickness_circle_ratio * ratioAreas), 2)
+ thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio))
+ thicknessLine = np.maximum(1, np.round(thicknessRatio * thickness_line_ratio_wrt_circle))
+ radius = thicknessRatio / 2
+
+ img = np.ascontiguousarray(img.copy())
+ for i, pair in enumerate(pairs):
+ index1, index2 = pair
+ if keypoints[index1, -1] > threshold and keypoints[index2, -1] > threshold:
+ thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * pose_scales[0]))
+ colorIndex = index2
+ color = colors[colorIndex % numberColors]
+ keypoint1 = keypoints[index1, :-1].astype(int)
+ keypoint2 = keypoints[index2, :-1].astype(int)
+ cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), thicknessLineScaled, lineType, shift)
+ for part in range(len(keypoints)):
+ faceIndex = part
+ if keypoints[faceIndex, -1] > threshold:
+ radiusScaled = int(round(radius[faceIndex] * pose_scales[0]))
+ thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * pose_scales[0]))
+ colorIndex = part
+ color = colors[colorIndex % numberColors]
+ center = keypoints[faceIndex, :-1].astype(int)
+ cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, lineType, shift)
+ return img
+
+def render_hand_keypoints(img, right_hand_keypoints, threshold=0.1, use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0):
+ if use_confidence and map_fn is not None:
+ #thicknessCircleRatioLeft = 1./50 * map_fn(left_hand_keypoints[:, -1])
+ thicknessCircleRatioRight = 1./50 * map_fn(right_hand_keypoints[:, -1])
+ else:
+ #thicknessCircleRatioLeft = 1./50 * np.ones(left_hand_keypoints.shape[0])
+ thicknessCircleRatioRight = 1./50 * np.ones(right_hand_keypoints.shape[0])
+ thicknessLineRatioWRTCircle = 0.75
+ pairs = [0,1, 1,2, 2,3, 3,4, 0,5, 5,6, 6,7, 7,8, 0,9, 9,10, 10,11, 11,12, 0,13, 13,14, 14,15, 15,16, 0,17, 17,18, 18,19, 19,20]
+ pairs = np.array(pairs).reshape(-1,2)
+
+ colors = [100., 100., 100.,
+ 100., 0., 0.,
+ 150., 0., 0.,
+ 200., 0., 0.,
+ 255., 0., 0.,
+ 100., 100., 0.,
+ 150., 150., 0.,
+ 200., 200., 0.,
+ 255., 255., 0.,
+ 0., 100., 50.,
+ 0., 150., 75.,
+ 0., 200., 100.,
+ 0., 255., 125.,
+ 0., 50., 100.,
+ 0., 75., 150.,
+ 0., 100., 200.,
+ 0., 125., 255.,
+ 100., 0., 100.,
+ 150., 0., 150.,
+ 200., 0., 200.,
+ 255., 0., 255.]
+ colors = np.array(colors).reshape(-1,3)
+ #colors = np.zeros_like(colors)
+ poseScales = [1]
+ #img = render_keypoints(img, left_hand_keypoints, pairs, colors, thicknessCircleRatioLeft, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
+ img = render_keypoints(img, right_hand_keypoints, pairs, colors, thicknessCircleRatioRight, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
+ #img = render_keypoints(img, right_hand_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
+ return img
+
+def render_hand_landmarks(img, right_hand_keypoints, threshold=0.1, use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0):
+ if use_confidence and map_fn is not None:
+ #thicknessCircleRatioLeft = 1./50 * map_fn(left_hand_keypoints[:, -1])
+ thicknessCircleRatioRight = 1./50 * map_fn(right_hand_keypoints[:, -1])
+ else:
+ #thicknessCircleRatioLeft = 1./50 * np.ones(left_hand_keypoints.shape[0])
+ thicknessCircleRatioRight = 1./50 * np.ones(right_hand_keypoints.shape[0])
+ thicknessLineRatioWRTCircle = 0.75
+ pairs = []
+ pairs = np.array(pairs).reshape(-1,2)
+
+ colors = [255, 0, 0]
+ colors = np.array(colors).reshape(-1,3)
+ #colors = np.zeros_like(colors)
+ poseScales = [1]
+ #img = render_keypoints(img, left_hand_keypoints, pairs, colors, thicknessCircleRatioLeft, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
+ img = render_keypoints(img, right_hand_keypoints, pairs, colors, thicknessCircleRatioRight * 0.1, thicknessLineRatioWRTCircle * 0.1, poseScales, threshold, alpha=alpha)
+ #img = render_keypoints(img, right_hand_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
+ return img
+
+def render_body_keypoints(img: np.array,
+ body_keypoints: np.array) -> np.array:
+ """
+ Render OpenPose body keypoints on input image.
+ Args:
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
+ body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
+ Returns:
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
+ """
+
+ thickness_circle_ratio = 1./75. * np.ones(body_keypoints.shape[0])
+ thickness_line_ratio_wrt_circle = 0.75
+ pairs = []
+ pairs = [1,8,1,2,1,5,2,3,3,4,5,6,6,7,8,9,9,10,10,11,8,12,12,13,13,14,1,0,0,15,15,17,0,16,16,18,14,19,19,20,14,21,11,22,22,23,11,24]
+ pairs = np.array(pairs).reshape(-1,2)
+ colors = [255., 0., 85.,
+ 255., 0., 0.,
+ 255., 85., 0.,
+ 255., 170., 0.,
+ 255., 255., 0.,
+ 170., 255., 0.,
+ 85., 255., 0.,
+ 0., 255., 0.,
+ 255., 0., 0.,
+ 0., 255., 85.,
+ 0., 255., 170.,
+ 0., 255., 255.,
+ 0., 170., 255.,
+ 0., 85., 255.,
+ 0., 0., 255.,
+ 255., 0., 170.,
+ 170., 0., 255.,
+ 255., 0., 255.,
+ 85., 0., 255.,
+ 0., 0., 255.,
+ 0., 0., 255.,
+ 0., 0., 255.,
+ 0., 255., 255.,
+ 0., 255., 255.,
+ 0., 255., 255.]
+ colors = np.array(colors).reshape(-1,3)
+ pose_scales = [1]
+ return render_keypoints(img, body_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
+
+def render_openpose(img: np.array,
+ hand_keypoints: np.array) -> np.array:
+ """
+ Render keypoints in the OpenPose format on input image.
+ Args:
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
+        hand_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
+ Returns:
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
+ """
+ #img = render_body_keypoints(img, body_keypoints)
+ img = render_hand_keypoints(img, hand_keypoints)
+ return img
+
+def render_openpose_landmarks(img: np.array,
+ hand_keypoints: np.array) -> np.array:
+ """
+ Render keypoints in the OpenPose format on input image.
+ Args:
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
+        hand_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
+ Returns:
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
+ """
+ #img = render_body_keypoints(img, body_keypoints)
+ img = render_hand_landmarks(img, hand_keypoints)
+ return img
diff --git a/hawor/utils/rotation.py b/hawor/utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7bf9dd99d4ec03802055c03177be5a1a634eac0
--- /dev/null
+++ b/hawor/utils/rotation.py
@@ -0,0 +1,293 @@
+import torch
+import numpy as np
+from torch.nn import functional as F
+
+
+def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
+ """
+ Taken from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py
+ Calculates the rotation matrices for a batch of rotation vectors
+ - param rot_vecs: torch.tensor (N, 3) array of N axis-angle vectors
+ - returns R: torch.tensor (N, 3, 3) rotation matrices
+ """
+ batch_size = rot_vecs.shape[0]
+ device = rot_vecs.device
+
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
+ rot_dir = rot_vecs / angle
+
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
+
+ # Bx1 arrays
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
+
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
+ (batch_size, 3, 3)
+ )
+
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
+ return rot_mat
+
+
+def quaternion_mul(q0, q1):
+ """
+ EXPECTS WXYZ
+ :param q0 (*, 4)
+ :param q1 (*, 4)
+ """
+ r0, r1 = q0[..., :1], q1[..., :1]
+ v0, v1 = q0[..., 1:], q1[..., 1:]
+ r = r0 * r1 - (v0 * v1).sum(dim=-1, keepdim=True)
+ v = r0 * v1 + r1 * v0 + torch.linalg.cross(v0, v1)
+ return torch.cat([r, v], dim=-1)
+
+
+def quaternion_inverse(q, eps=1e-8):
+ """
+ EXPECTS WXYZ
+ :param q (*, 4)
+ """
+ conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
+ mag = torch.square(q).sum(dim=-1, keepdim=True) + eps
+ return conj / mag
+
+
+def quaternion_slerp(t, q0, q1, eps=1e-8):
+ """
+ :param t (*, 1) must be between 0 and 1
+ :param q0 (*, 4)
+ :param q1 (*, 4)
+ """
+ dims = q0.shape[:-1]
+ t = t.view(*dims, 1)
+
+ q0 = F.normalize(q0, p=2, dim=-1)
+ q1 = F.normalize(q1, p=2, dim=-1)
+ dot = (q0 * q1).sum(dim=-1, keepdim=True)
+
+ # make sure we give the shortest rotation path (< 180d)
+ neg = dot < 0
+ q1 = torch.where(neg, -q1, q1)
+ dot = torch.where(neg, -dot, dot)
+ angle = torch.acos(dot)
+
+ # if angle is too small, just do linear interpolation
+ collin = torch.abs(dot) > 1 - eps
+ fac = 1 / torch.sin(angle)
+ w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac)
+ w1 = torch.where(collin, t, torch.sin(t * angle) * fac)
+ slerp = q0 * w0 + q1 * w1
+ return slerp
+
+
+def rotation_matrix_to_angle_axis(rotation_matrix):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert rotation matrix to Rodrigues vector
+ """
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
+ aa = quaternion_to_angle_axis(quaternion)
+ aa[torch.isnan(aa)] = 0.0
+ return aa
+
+
+def quaternion_to_angle_axis(quaternion):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert quaternion vector to angle axis of rotation.
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
+
+ :param quaternion (*, 4) expects WXYZ
+ :returns angle_axis (*, 3)
+ """
+ # unpack input and compute conversion
+ q1 = quaternion[..., 1]
+ q2 = quaternion[..., 2]
+ q3 = quaternion[..., 3]
+ sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3
+
+ sin_theta = torch.sqrt(sin_squared_theta)
+ cos_theta = quaternion[..., 0]
+ two_theta = 2.0 * torch.where(
+ cos_theta < 0.0,
+ torch.atan2(-sin_theta, -cos_theta),
+ torch.atan2(sin_theta, cos_theta),
+ )
+
+ k_pos = two_theta / sin_theta
+ k_neg = 2.0 * torch.ones_like(sin_theta)
+ k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
+
+ angle_axis = torch.zeros_like(quaternion)[..., :3]
+ angle_axis[..., 0] += q1 * k
+ angle_axis[..., 1] += q2 * k
+ angle_axis[..., 2] += q3 * k
+ return angle_axis
+
+
+def angle_axis_to_rotation_matrix(angle_axis):
+ """
+ :param angle_axis (*, 3)
+ return (*, 3, 3)
+ """
+ quat = angle_axis_to_quaternion(angle_axis)
+ return quaternion_to_rotation_matrix(quat)
+
+
+def quaternion_to_rotation_matrix(quaternion):
+ """
+ Convert a quaternion to a rotation matrix.
+ Taken from https://github.com/kornia/kornia, based on
+ https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
+ https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247
+ :param quaternion (N, 4) expects WXYZ order
+ returns rotation matrix (N, 3, 3)
+ """
+ # normalize the input quaternion
+ quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-12)
+ *dims, _ = quaternion_norm.shape
+
+ # unpack the normalized quaternion components
+ w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1)
+
+ # compute the actual conversion
+ tx = 2.0 * x
+ ty = 2.0 * y
+ tz = 2.0 * z
+ twx = tx * w
+ twy = ty * w
+ twz = tz * w
+ txx = tx * x
+ txy = ty * x
+ txz = tz * x
+ tyy = ty * y
+ tyz = tz * y
+ tzz = tz * z
+ one = torch.tensor(1.0)
+
+ matrix = torch.stack(
+ (
+ one - (tyy + tzz),
+ txy - twz,
+ txz + twy,
+ txy + twz,
+ one - (txx + tzz),
+ tyz - twx,
+ txz - twy,
+ tyz + twx,
+ one - (txx + tyy),
+ ),
+ dim=-1,
+ ).view(*dims, 3, 3)
+ return matrix
+
+
+def angle_axis_to_quaternion(angle_axis):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+ Convert angle axis to quaternion in WXYZ order
+ :param angle_axis (*, 3)
+ :returns quaternion (*, 4) WXYZ order
+ """
+ theta_sq = torch.sum(angle_axis**2, dim=-1, keepdim=True) # (*, 1)
+ # need to handle the zero rotation case
+ valid = theta_sq > 0
+ theta = torch.sqrt(theta_sq)
+ half_theta = 0.5 * theta
+ ones = torch.ones_like(half_theta)
+ # fill zero with the limit of sin ax / x -> a
+ k = torch.where(valid, torch.sin(half_theta) / theta, 0.5 * ones)
+ w = torch.where(valid, torch.cos(half_theta), ones)
+ quat = torch.cat([w, k * angle_axis], dim=-1)
+ return quat
+
+
+def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+ Convert rotation matrix to 4d quaternion vector
+ This algorithm is based on algorithm described in
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
+
+ :param rotation_matrix (N, 3, 3)
+ """
+ *dims, m, n = rotation_matrix.shape
+ rmat_t = torch.transpose(rotation_matrix.reshape(-1, m, n), -1, -2)
+
+ mask_d2 = rmat_t[:, 2, 2] < eps
+
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
+
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q0 = torch.stack(
+ [
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ t0,
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
+ ],
+ -1,
+ )
+ t0_rep = t0.repeat(4, 1).t()
+
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q1 = torch.stack(
+ [
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ t1,
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
+ ],
+ -1,
+ )
+ t1_rep = t1.repeat(4, 1).t()
+
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q2 = torch.stack(
+ [
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
+ t2,
+ ],
+ -1,
+ )
+ t2_rep = t2.repeat(4, 1).t()
+
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q3 = torch.stack(
+ [
+ t3,
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
+ ],
+ -1,
+ )
+ t3_rep = t3.repeat(4, 1).t()
+
+ mask_c0 = mask_d2 * mask_d0_d1
+ mask_c1 = mask_d2 * ~mask_d0_d1
+ mask_c2 = ~mask_d2 * mask_d0_nd1
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
+
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
+ q /= torch.sqrt(
+ t0_rep * mask_c0
+ + t1_rep * mask_c1
+ + t2_rep * mask_c2 # noqa
+ + t3_rep * mask_c3
+ ) # noqa
+ q *= 0.5
+ return q.reshape(*dims, 4)
diff --git a/imgui.ini b/imgui.ini
new file mode 100644
index 0000000000000000000000000000000000000000..3cf22d85b75378eea2a4e7e1df56f60199fd9708
--- /dev/null
+++ b/imgui.ini
@@ -0,0 +1,15 @@
+[Window][Debug##Default]
+Pos=60,60
+Size=400,400
+Collapsed=0
+
+[Window][Editor]
+Pos=50,50
+Size=250,700
+Collapsed=0
+
+[Window][Playback]
+Pos=50,800
+Size=400,175
+Collapsed=1
+
diff --git a/infiller/hand_utils/geometry.py b/infiller/hand_utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3d65066ce383c744407f8d800a7342cd324142b
--- /dev/null
+++ b/infiller/hand_utils/geometry.py
@@ -0,0 +1,412 @@
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+
+def perspective_projection(points, rotation, translation,
+ focal_length, camera_center, distortion=None):
+ """
+ This function computes the perspective projection of a set of points.
+ Input:
+ points (bs, N, 3): 3D points
+ rotation (bs, 3, 3): Camera rotation
+ translation (bs, 3): Camera translation
+ focal_length (bs,) or scalar: Focal length
+ camera_center (bs, 2): Camera center
+ """
+ batch_size = points.shape[0]
+
+ # Extrinsic
+ if rotation is not None:
+ points = torch.einsum('bij,bkj->bki', rotation, points)
+
+ if translation is not None:
+ points = points + translation.unsqueeze(1)
+
+ if distortion is not None:
+ kc = distortion
+ points = points[:,:,:2] / points[:,:,2:]
+
+ r2 = points[:,:,0]**2 + points[:,:,1]**2
+ dx = (2 * kc[:,[2]] * points[:,:,0] * points[:,:,1]
+ + kc[:,[3]] * (r2 + 2*points[:,:,0]**2))
+
+ dy = (2 * kc[:,[3]] * points[:,:,0] * points[:,:,1]
+ + kc[:,[2]] * (r2 + 2*points[:,:,1]**2))
+
+ x = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,0] + dx
+ y = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,1] + dy
+
+ points = torch.stack([x, y, torch.ones_like(x)], dim=-1)
+
+ # Intrinsic
+ K = torch.zeros([batch_size, 3, 3], device=points.device)
+ K[:,0,0] = focal_length
+ K[:,1,1] = focal_length
+ K[:,2,2] = 1.
+ K[:,:-1, -1] = camera_center
+
+    # Apply camera intrinsics
+ points = points / points[:,:,-1].unsqueeze(-1)
+ projected_points = torch.einsum('bij,bkj->bki', K, points)
+ projected_points = projected_points[:, :, :-1]
+
+ return projected_points
+
+
+def avg_rot(rot):
+ # input [B,...,3,3] --> output [...,3,3]
+ rot = rot.mean(dim=0)
+ U, _, V = torch.svd(rot)
+ rot = U @ V.transpose(-1, -2)
+ return rot
+
+
+def rot9d_to_rotmat(x):
+ """Convert 9D rotation representation to 3x3 rotation matrix.
+ Based on Levinson et al., "An Analysis of SVD for Deep Rotation Estimation"
+ Input:
+ (B,9) or (B,J*9) Batch of 9D rotation (interpreted as 3x3 est rotmat)
+ Output:
+ (B,3,3) or (B*J,3,3) Batch of corresponding rotation matrices
+ """
+ x = x.view(-1,3,3)
+ u, _, vh = torch.linalg.svd(x)
+
+ sig = torch.eye(3).expand(len(x), 3, 3).clone()
+ sig = sig.to(x.device)
+ sig[:, -1, -1] = (u @ vh).det()
+
+ R = u @ sig @ vh
+
+ return R
+
+
+"""
+Deprecated in favor of: rotation_conversions.py
+
+Useful geometric operations, e.g. differentiable Rodrigues formula
+Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
+"""
+def batch_rodrigues(theta):
+ """Convert axis-angle representation to rotation matrix.
+ Args:
+ theta: size = [B, 3]
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(l1norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat_to_rotmat(quat)
+
+def quat_to_rotmat(quat):
+ """Convert quaternion coefficients to rotation matrix.
+ Args:
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ norm_quat = quat
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+def rot6d_to_rotmat(x):
+ """Convert 6D rotation representation to 3x3 rotation matrix.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Input:
+ (B,6) Batch of 6-D rotation representations
+ Output:
+ (B,3,3) Batch of corresponding rotation matrices
+ """
+ x = x.view(-1,3,2)
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.cross(b1, b2)
+ return torch.stack((b1, b2, b3), dim=-1)
+
+def rot6d_to_rotmat_hmr2(x: torch.Tensor) -> torch.Tensor:
+ """
+ Convert 6D rotation representation to 3x3 rotation matrix.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Args:
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
+ Returns:
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
+ """
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.cross(b1, b2)
+ return torch.stack((b1, b2, b3), dim=-1)
+
+def rotmat_to_rot6d(rotmat):
+ """ Inverse function of the above.
+ Input:
+ (B,3,3) Batch of corresponding rotation matrices
+ Output:
+ (B,6) Batch of 6-D rotation representations
+ """
+ # rot6d = rotmat[:, :, :2]
+ rot6d = rotmat[...,:2]
+ rot6d = rot6d.reshape(rot6d.size(0), -1)
+ return rot6d
+
+
+def rotation_matrix_to_angle_axis(rotation_matrix):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert 3x4 rotation matrix to Rodrigues vector
+
+ Args:
+ rotation_matrix (Tensor): rotation matrix.
+
+ Returns:
+ Tensor: Rodrigues vector transformation.
+
+ Shape:
+ - Input: :math:`(N, 3, 4)`
+ - Output: :math:`(N, 3)`
+
+ Example:
+ >>> input = torch.rand(2, 3, 4) # Nx4x4
+ >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
+ """
+ if rotation_matrix.shape[1:] == (3,3):
+ rot_mat = rotation_matrix.reshape(-1, 3, 3)
+ hom = torch.tensor([0, 0, 1], dtype=torch.float32,
+ device=rotation_matrix.device).reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1)
+ rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
+
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
+ aa = quaternion_to_angle_axis(quaternion)
+ aa[torch.isnan(aa)] = 0.0
+ return aa
+
+
+def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert quaternion vector to angle axis of rotation.
+
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
+
+ Args:
+ quaternion (torch.Tensor): tensor with quaternions.
+
+ Return:
+ torch.Tensor: tensor with angle axis of rotation.
+
+ Shape:
+ - Input: :math:`(*, 4)` where `*` means, any number of dimensions
+ - Output: :math:`(*, 3)`
+
+ Example:
+ >>> quaternion = torch.rand(2, 4) # Nx4
+ >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
+ """
+ if not torch.is_tensor(quaternion):
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
+ type(quaternion)))
+
+ if not quaternion.shape[-1] == 4:
+ raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
+ .format(quaternion.shape))
+ # unpack input and compute conversion
+ q1: torch.Tensor = quaternion[..., 1]
+ q2: torch.Tensor = quaternion[..., 2]
+ q3: torch.Tensor = quaternion[..., 3]
+ sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
+
+ sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
+ cos_theta: torch.Tensor = quaternion[..., 0]
+ two_theta: torch.Tensor = 2.0 * torch.where(
+ cos_theta < 0.0,
+ torch.atan2(-sin_theta, -cos_theta),
+ torch.atan2(sin_theta, cos_theta))
+
+ k_pos: torch.Tensor = two_theta / sin_theta
+ k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
+ k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
+
+ angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
+ angle_axis[..., 0] += q1 * k
+ angle_axis[..., 1] += q2 * k
+ angle_axis[..., 2] += q3 * k
+ return angle_axis
+
+
+def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert 3x4 rotation matrix to 4d quaternion vector
+
+ This algorithm is based on algorithm described in
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
+
+ Args:
+ rotation_matrix (Tensor): the rotation matrix to convert.
+
+ Return:
+ Tensor: the rotation in quaternion
+
+ Shape:
+ - Input: :math:`(N, 3, 4)`
+ - Output: :math:`(N, 4)`
+
+ Example:
+ >>> input = torch.rand(4, 3, 4) # Nx3x4
+ >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
+ """
+ if not torch.is_tensor(rotation_matrix):
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
+ type(rotation_matrix)))
+
+ if len(rotation_matrix.shape) > 3:
+ raise ValueError(
+ "Input size must be a three dimensional tensor. Got {}".format(
+ rotation_matrix.shape))
+ if not rotation_matrix.shape[-2:] == (3, 4):
+ raise ValueError(
+ "Input size must be a N x 3 x 4 tensor. Got {}".format(
+ rotation_matrix.shape))
+
+ rmat_t = torch.transpose(rotation_matrix, 1, 2)
+
+ mask_d2 = rmat_t[:, 2, 2] < eps
+
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
+
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
+ t0_rep = t0.repeat(4, 1).t()
+
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
+ t1_rep = t1.repeat(4, 1).t()
+
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
+ t2_rep = t2.repeat(4, 1).t()
+
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
+ t3_rep = t3.repeat(4, 1).t()
+
+ mask_c0 = mask_d2 * mask_d0_d1
+ mask_c1 = mask_d2 * ~mask_d0_d1
+ mask_c2 = ~mask_d2 * mask_d0_nd1
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
+
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
+ q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
+ t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
+ q *= 0.5
+ return q
+
+
+def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000., img_size=224.):
+ """
+ This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py
+
+    Find the camera translation that brings the 3D joints S closest to their corresponding 2D joints joints_2d.
+    Input:
+        S: (25, 3) 3D joint locations
+        joints_2d: (25, 2) 2D joint locations
+        joints_conf: (25,) confidence of each 2D joint
+ Returns:
+ (3,) camera translation vector
+ """
+
+ num_joints = S.shape[0]
+ # focal length
+ f = np.array([focal_length,focal_length])
+ # optical center
+ center = np.array([img_size/2., img_size/2.])
+
+ # transformations
+ Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1)
+ XY = np.reshape(S[:,0:2],-1)
+ O = np.tile(center,num_joints)
+ F = np.tile(f,num_joints)
+ weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1)
+
+ # least squares
+ Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T
+ c = (np.reshape(joints_2d,-1)-O)*Z - F*XY
+
+ # weighted least squares
+ W = np.diagflat(weight2)
+ Q = np.dot(W,Q)
+ c = np.dot(W,c)
+
+ # square matrix
+ A = np.dot(Q.T,Q)
+ b = np.dot(Q.T,c)
+
+ # solution
+ trans = np.linalg.solve(A, b)
+
+ return trans
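+# Usage sketch (illustrative, not part of the original SPIN code): with synthetic
+# 2D detections generated by the same pinhole model, the weighted least squares
+# above recovers the translation exactly.
+#   >>> S = np.random.rand(25, 3) + np.array([0.0, 0.0, 2.0])    # joints in front of the camera
+#   >>> t_gt = np.array([0.05, -0.02, 0.3])
+#   >>> proj = 5000. * (S + t_gt)[:, :2] / (S + t_gt)[:, 2:3] + 112.
+#   >>> t_est = estimate_translation_np(S, proj, np.ones(25))
+#   >>> np.allclose(t_est, t_gt)                                  # True (up to numerical precision)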
+
+
+def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.):
+    """Find the camera translation that brings the 3D joints S closest to their corresponding 2D joints joints_2d.
+    Input:
+        S: (B, 49, 3) 3D joint locations
+        joints_2d: (B, 49, 3) 2D joint locations and confidence
+ Returns:
+ (B, 3) camera translation vectors
+ """
+
+ device = S.device
+ # Use only joints 25:49 (GT joints)
+ S = S[:, -24:, :3].cpu().numpy()
+ joints_2d = joints_2d[:, -24:, :].cpu().numpy()
+
+ joints_conf = joints_2d[:, :, -1]
+ joints_2d = joints_2d[:, :, :-1]
+ trans = np.zeros((S.shape[0], 3), dtype=np.float32)
+ # Find the translation for each example in the batch
+ for i in range(S.shape[0]):
+ S_i = S[i]
+ joints_i = joints_2d[i]
+ conf_i = joints_conf[i]
+ trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size)
+ return torch.from_numpy(trans).to(device)
+
+
diff --git a/infiller/hand_utils/geometry_utils.py b/infiller/hand_utils/geometry_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1effe3d9aa66386aa3dd114f07646c6a96a5035e
--- /dev/null
+++ b/infiller/hand_utils/geometry_utils.py
@@ -0,0 +1,102 @@
+from typing import Optional
+import torch
+from torch.nn import functional as F
+
+def aa_to_rotmat(theta: torch.Tensor):
+ """
+ Convert axis-angle representation to rotation matrix.
+ Works by first converting it to a quaternion.
+ Args:
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
+ Returns:
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+ """
+ norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat_to_rotmat(quat)
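+# Example (illustrative): a 90-degree rotation about the z-axis maps the x-axis
+# onto the y-axis.
+#   >>> theta = torch.tensor([[0.0, 0.0, 1.5707963]])   # (1, 3) axis-angle
+#   >>> R = aa_to_rotmat(theta)                          # (1, 3, 3)
+#   >>> R[0, :, 0]                                       # ~tensor([0., 1., 0.])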
+
+def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
+ """
+ Convert quaternion representation to rotation matrix.
+ Args:
+ quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
+ Returns:
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+ """
+ norm_quat = quat
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+
+def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
+ """
+ Convert 6D rotation representation to 3x3 rotation matrix.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Args:
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
+ Returns:
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
+ """
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.linalg.cross(b1, b2)
+ return torch.stack((b1, b2, b3), dim=-1)
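+# Example (illustrative): the 6D vector made of the first two columns of the
+# identity matrix maps back to the identity rotation.
+#   >>> x = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]])            # (1, 6)
+#   >>> torch.allclose(rot6d_to_rotmat(x), torch.eye(3).unsqueeze(0))  # True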
+
+def perspective_projection(points: torch.Tensor,
+ translation: torch.Tensor,
+ focal_length: torch.Tensor,
+ camera_center: Optional[torch.Tensor] = None,
+ rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Computes the perspective projection of a set of 3D points.
+ Args:
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
+ focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
+ camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
+ Returns:
+ torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
+ """
+ batch_size = points.shape[0]
+ if rotation is None:
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
+ if camera_center is None:
+ camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
+ # Populate intrinsic camera matrix K.
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
+ K[:,0,0] = focal_length[:,0]
+ K[:,1,1] = focal_length[:,1]
+ K[:,2,2] = 1.
+ K[:,:-1, -1] = camera_center
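+    # K now holds the standard pinhole intrinsics:
+    #     [[fx,  0, cx],
+    #      [ 0, fy, cy],
+    #      [ 0,  0,  1]]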
+
+ # Transform points
+ points = torch.einsum('bij,bkj->bki', rotation, points)
+ points = points + translation.unsqueeze(1)
+
+ # Apply perspective distortion
+ projected_points = points / points[:,:,-1].unsqueeze(-1)
+
+ # Apply camera intrinsics
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
+
+ return projected_points[:, :, :-1]
\ No newline at end of file
diff --git a/infiller/hand_utils/mano_wrapper.py b/infiller/hand_utils/mano_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4801c26e3410035db3d09f8dac70c637dd7b00f5
--- /dev/null
+++ b/infiller/hand_utils/mano_wrapper.py
@@ -0,0 +1,52 @@
+import torch
+import numpy as np
+import pickle
+from typing import Optional
+import smplx
+from smplx.lbs import vertices2joints
+from smplx.utils import MANOOutput, to_tensor
+from smplx.vertex_ids import vertex_ids
+
+
+class MANO(smplx.MANOLayer):
+ def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
+ """
+ Extension of the official MANO implementation to support more joints.
+ Args:
+ Same as MANOLayer.
+ joint_regressor_extra (str): Path to extra joint regressor.
+ """
+ super(MANO, self).__init__(*args, **kwargs)
+ mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
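+        # Reorders the 16 MANO joints plus the 5 fingertip vertices appended in
+        # forward() into the 21-joint OpenPose hand ordering
+        # (wrist, then thumb/index/middle/ring/pinky chains).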
+
+ #2, 3, 5, 4, 1
+ if joint_regressor_extra is not None:
+ self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
+ self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
+ self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))
+
+ def forward(self, *args, **kwargs) -> MANOOutput:
+ """
+        Run the forward pass. Same as the base MANO layer, but also appends the fingertip joints (and extra regressed joints when joint_regressor_extra is specified) and reorders them to the OpenPose convention.
+ """
+ mano_output = super(MANO, self).forward(*args, **kwargs)
+ extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs)
+ joints = torch.cat([mano_output.joints, extra_joints], dim=1)
+ joints = joints[:, self.joint_map, :]
+ if hasattr(self, 'joint_regressor_extra'):
+ extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices)
+ joints = torch.cat([joints, extra_joints], dim=1)
+ mano_output.joints = joints
+ return mano_output
+
+ def query(self, hmr_output):
+ batch_size = hmr_output['pred_rotmat'].shape[0]
+ pred_rotmat = hmr_output['pred_rotmat'].reshape(batch_size, -1, 3, 3)
+ pred_shape = hmr_output['pred_shape'].reshape(batch_size, 10)
+
+ mano_output = self(global_orient=pred_rotmat[:, [0]],
+ hand_pose = pred_rotmat[:, 1:],
+ betas = pred_shape,
+ pose2rot=False)
+
+ return mano_output
\ No newline at end of file
diff --git a/infiller/hand_utils/process.py b/infiller/hand_utils/process.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb40b94c964837733bc0403adf802f4f9ba50654
--- /dev/null
+++ b/infiller/hand_utils/process.py
@@ -0,0 +1,171 @@
+import torch
+from hand_utils.mano_wrapper import MANO
+from hand_utils.geometry_utils import aa_to_rotmat
+import numpy as np
+
+def run_mano(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True):
+ """
+    Forward pass of the MANO model; returns a dictionary with joints and vertices.
+
+ trans : B x T x 3
+ root_orient : B x T x 3
+    hand_pose : B x T x J*3
+    betas : (optional) B x T x D
+ """
+ MANO_cfg = {
+ 'DATA_DIR': '_DATA/data/',
+ 'MODEL_PATH': '_DATA/data/mano',
+ 'GENDER': 'neutral',
+ 'NUM_HAND_JOINTS': 15,
+ 'CREATE_BODY_POSE': False
+ }
+ mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
+ mano = MANO(**mano_cfg)
+ if use_cuda:
+ mano = mano.cuda()
+
+ B, T, _ = root_orient.shape
+ NUM_JOINTS = 15
+ mano_params = {
+ 'global_orient': root_orient.reshape(B*T, -1),
+ 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
+ 'betas': betas.reshape(B*T, -1),
+ }
+ rotmat_mano_params = mano_params
+ rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
+ rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
+ rotmat_mano_params['transl'] = trans.reshape(B*T, 3)
+
+ if use_cuda:
+ mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
+ else:
+ mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)
+
+ faces_right = mano.faces
+ faces_new = np.array([[92, 38, 234],
+ [234, 38, 239],
+ [38, 122, 239],
+ [239, 122, 279],
+ [122, 118, 279],
+ [279, 118, 215],
+ [118, 117, 215],
+ [215, 117, 214],
+ [117, 119, 214],
+ [214, 119, 121],
+ [119, 120, 121],
+ [121, 120, 78],
+ [120, 108, 78],
+ [78, 108, 79]])
+ faces_right = np.concatenate([faces_right, faces_new], axis=0)
+ faces_n = len(faces_right)
+ faces_left = faces_right[:,[0,2,1]]
+
+ outputs = {
+ "joints": mano_output.joints.reshape(B, T, -1, 3),
+ "vertices": mano_output.vertices.reshape(B, T, -1, 3),
+ }
+
+    if is_right is not None:
+ # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
+ # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
+ is_right = (is_right[:, :, 0].cpu().numpy() > 0)
+ faces_result = np.zeros((B, T, faces_n, 3))
+ faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
+ faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
+ faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
+ outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))
+
+
+ return outputs
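+# Usage sketch (illustrative; requires the MANO model files under _DATA/, see
+# MODEL_PATH above):
+#   >>> B, T = 1, 16
+#   >>> out = run_mano(torch.zeros(B, T, 3),      # translation
+#   ...                torch.zeros(B, T, 3),      # global orientation (axis-angle)
+#   ...                torch.zeros(B, T, 45),     # 15 hand joints x 3 (axis-angle)
+#   ...                betas=torch.zeros(B, T, 10),
+#   ...                use_cuda=False)
+#   >>> out["joints"].shape, out["vertices"].shape   # (1, 16, 21, 3), (1, 16, 778, 3)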
+
+def run_mano_left(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True, fix_shapedirs=True):
+ """
+    Forward pass of the left-hand MANO model; returns a dictionary with joints and vertices.
+
+ trans : B x T x 3
+ root_orient : B x T x 3
+    hand_pose : B x T x J*3
+    betas : (optional) B x T x D
+ """
+ MANO_cfg = {
+ 'DATA_DIR': '_DATA/data_left/',
+ 'MODEL_PATH': '_DATA/data_left/mano_left',
+ 'GENDER': 'neutral',
+ 'NUM_HAND_JOINTS': 15,
+ 'CREATE_BODY_POSE': False,
+ 'is_rhand': False
+ }
+ mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
+ mano = MANO(**mano_cfg)
+ if use_cuda:
+ mano = mano.cuda()
+
+ # fix MANO shapedirs of the left hand bug (https://github.com/vchoutas/smplx/issues/48)
+ if fix_shapedirs:
+ mano.shapedirs[:, 0, :] *= -1
+
+ B, T, _ = root_orient.shape
+ NUM_JOINTS = 15
+ mano_params = {
+ 'global_orient': root_orient.reshape(B*T, -1),
+ 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
+ 'betas': betas.reshape(B*T, -1),
+ }
+ rotmat_mano_params = mano_params
+ rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
+ rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
+ rotmat_mano_params['transl'] = trans.reshape(B*T, 3)
+
+ if use_cuda:
+ mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
+ else:
+ mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)
+
+ faces_right = mano.faces
+ faces_new = np.array([[92, 38, 234],
+ [234, 38, 239],
+ [38, 122, 239],
+ [239, 122, 279],
+ [122, 118, 279],
+ [279, 118, 215],
+ [118, 117, 215],
+ [215, 117, 214],
+ [117, 119, 214],
+ [214, 119, 121],
+ [119, 120, 121],
+ [121, 120, 78],
+ [120, 108, 78],
+ [78, 108, 79]])
+ faces_right = np.concatenate([faces_right, faces_new], axis=0)
+ faces_n = len(faces_right)
+ faces_left = faces_right[:,[0,2,1]]
+
+ outputs = {
+ "joints": mano_output.joints.reshape(B, T, -1, 3),
+ "vertices": mano_output.vertices.reshape(B, T, -1, 3),
+ }
+
+    if is_right is not None:
+ # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
+ # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
+ is_right = (is_right[:, :, 0].cpu().numpy() > 0)
+ faces_result = np.zeros((B, T, faces_n, 3))
+ faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
+ faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
+ faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
+ outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))
+
+
+ return outputs
+
+def run_mano_twohands(init_trans, init_rot, init_hand_pose, is_right, init_betas, use_cuda=True, fix_shapedirs=True):
+ outputs_left = run_mano_left(init_trans[0:1], init_rot[0:1], init_hand_pose[0:1], None, init_betas[0:1], use_cuda=use_cuda, fix_shapedirs=fix_shapedirs)
+ outputs_right = run_mano(init_trans[1:2], init_rot[1:2], init_hand_pose[1:2], None, init_betas[1:2], use_cuda=use_cuda)
+ outputs_two = {
+ "vertices": torch.cat((outputs_left["vertices"], outputs_right["vertices"]), dim=0),
+ "joints": torch.cat((outputs_left["joints"], outputs_right["joints"]), dim=0)
+
+ }
+ return outputs_two
\ No newline at end of file
diff --git a/infiller/hand_utils/rotation.py b/infiller/hand_utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7bf9dd99d4ec03802055c03177be5a1a634eac0
--- /dev/null
+++ b/infiller/hand_utils/rotation.py
@@ -0,0 +1,293 @@
+import torch
+import numpy as np
+from torch.nn import functional as F
+
+
+def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
+ """
+ Taken from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py
+ Calculates the rotation matrices for a batch of rotation vectors
+ - param rot_vecs: torch.tensor (N, 3) array of N axis-angle vectors
+ - returns R: torch.tensor (N, 3, 3) rotation matrices
+ """
+ batch_size = rot_vecs.shape[0]
+ device = rot_vecs.device
+
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
+ rot_dir = rot_vecs / angle
+
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
+
+ # Bx1 arrays
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
+
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
+ (batch_size, 3, 3)
+ )
+
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
+ return rot_mat
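+# Example (illustrative): a half-turn about the y-axis flips the x- and z-axes.
+#   >>> R = batch_rodrigues(torch.tensor([[0.0, 3.1415927, 0.0]]))
+#   >>> R[0].round()                                   # diag(-1., 1., -1.)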
+
+
+def quaternion_mul(q0, q1):
+ """
+ EXPECTS WXYZ
+ :param q0 (*, 4)
+ :param q1 (*, 4)
+ """
+ r0, r1 = q0[..., :1], q1[..., :1]
+ v0, v1 = q0[..., 1:], q1[..., 1:]
+ r = r0 * r1 - (v0 * v1).sum(dim=-1, keepdim=True)
+ v = r0 * v1 + r1 * v0 + torch.linalg.cross(v0, v1)
+ return torch.cat([r, v], dim=-1)
+
+
+def quaternion_inverse(q, eps=1e-8):
+ """
+ EXPECTS WXYZ
+ :param q (*, 4)
+ """
+ conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
+ mag = torch.square(q).sum(dim=-1, keepdim=True) + eps
+ return conj / mag
+
+
+def quaternion_slerp(t, q0, q1, eps=1e-8):
+ """
+ :param t (*, 1) must be between 0 and 1
+ :param q0 (*, 4)
+ :param q1 (*, 4)
+ """
+ dims = q0.shape[:-1]
+ t = t.view(*dims, 1)
+
+ q0 = F.normalize(q0, p=2, dim=-1)
+ q1 = F.normalize(q1, p=2, dim=-1)
+ dot = (q0 * q1).sum(dim=-1, keepdim=True)
+
+ # make sure we give the shortest rotation path (< 180d)
+ neg = dot < 0
+ q1 = torch.where(neg, -q1, q1)
+ dot = torch.where(neg, -dot, dot)
+ angle = torch.acos(dot)
+
+ # if angle is too small, just do linear interpolation
+ collin = torch.abs(dot) > 1 - eps
+ fac = 1 / torch.sin(angle)
+ w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac)
+ w1 = torch.where(collin, t, torch.sin(t * angle) * fac)
+ slerp = q0 * w0 + q1 * w1
+ return slerp
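+# Example (illustrative): interpolating halfway between the identity quaternion
+# and a 90-degree z-rotation (both WXYZ) gives a 45-degree z-rotation.
+#   >>> q0 = torch.tensor([1.0, 0.0, 0.0, 0.0])
+#   >>> q1 = torch.tensor([0.7071068, 0.0, 0.0, 0.7071068])
+#   >>> quaternion_slerp(torch.tensor([0.5]), q0, q1)   # ~tensor([0.9239, 0., 0., 0.3827])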
+
+
+def rotation_matrix_to_angle_axis(rotation_matrix):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert rotation matrix to Rodrigues vector
+ """
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
+ aa = quaternion_to_angle_axis(quaternion)
+ aa[torch.isnan(aa)] = 0.0
+ return aa
+
+
+def quaternion_to_angle_axis(quaternion):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert quaternion vector to angle axis of rotation.
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
+
+ :param quaternion (*, 4) expects WXYZ
+ :returns angle_axis (*, 3)
+ """
+ # unpack input and compute conversion
+ q1 = quaternion[..., 1]
+ q2 = quaternion[..., 2]
+ q3 = quaternion[..., 3]
+ sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3
+
+ sin_theta = torch.sqrt(sin_squared_theta)
+ cos_theta = quaternion[..., 0]
+ two_theta = 2.0 * torch.where(
+ cos_theta < 0.0,
+ torch.atan2(-sin_theta, -cos_theta),
+ torch.atan2(sin_theta, cos_theta),
+ )
+
+ k_pos = two_theta / sin_theta
+ k_neg = 2.0 * torch.ones_like(sin_theta)
+ k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
+
+ angle_axis = torch.zeros_like(quaternion)[..., :3]
+ angle_axis[..., 0] += q1 * k
+ angle_axis[..., 1] += q2 * k
+ angle_axis[..., 2] += q3 * k
+ return angle_axis
+
+
+def angle_axis_to_rotation_matrix(angle_axis):
+ """
+ :param angle_axis (*, 3)
+ return (*, 3, 3)
+ """
+ quat = angle_axis_to_quaternion(angle_axis)
+ return quaternion_to_rotation_matrix(quat)
+
+
+def quaternion_to_rotation_matrix(quaternion):
+ """
+ Convert a quaternion to a rotation matrix.
+ Taken from https://github.com/kornia/kornia, based on
+ https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
+ https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247
+ :param quaternion (N, 4) expects WXYZ order
+ returns rotation matrix (N, 3, 3)
+ """
+ # normalize the input quaternion
+ quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-12)
+ *dims, _ = quaternion_norm.shape
+
+ # unpack the normalized quaternion components
+ w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1)
+
+ # compute the actual conversion
+ tx = 2.0 * x
+ ty = 2.0 * y
+ tz = 2.0 * z
+ twx = tx * w
+ twy = ty * w
+ twz = tz * w
+ txx = tx * x
+ txy = ty * x
+ txz = tz * x
+ tyy = ty * y
+ tyz = tz * y
+ tzz = tz * z
+ one = torch.tensor(1.0)
+
+ matrix = torch.stack(
+ (
+ one - (tyy + tzz),
+ txy - twz,
+ txz + twy,
+ txy + twz,
+ one - (txx + tzz),
+ tyz - twx,
+ txz - twy,
+ tyz + twx,
+ one - (txx + tyy),
+ ),
+ dim=-1,
+ ).view(*dims, 3, 3)
+ return matrix
+
+
+def angle_axis_to_quaternion(angle_axis):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+ Convert angle axis to quaternion in WXYZ order
+ :param angle_axis (*, 3)
+ :returns quaternion (*, 4) WXYZ order
+ """
+ theta_sq = torch.sum(angle_axis**2, dim=-1, keepdim=True) # (*, 1)
+ # need to handle the zero rotation case
+ valid = theta_sq > 0
+ theta = torch.sqrt(theta_sq)
+ half_theta = 0.5 * theta
+ ones = torch.ones_like(half_theta)
+ # fill zero with the limit of sin ax / x -> a
+ k = torch.where(valid, torch.sin(half_theta) / theta, 0.5 * ones)
+ w = torch.where(valid, torch.cos(half_theta), ones)
+ quat = torch.cat([w, k * angle_axis], dim=-1)
+ return quat
+
+
+def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+ Convert rotation matrix to 4d quaternion vector
+ This algorithm is based on algorithm described in
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
+
+ :param rotation_matrix (N, 3, 3)
+ """
+ *dims, m, n = rotation_matrix.shape
+ rmat_t = torch.transpose(rotation_matrix.reshape(-1, m, n), -1, -2)
+
+ mask_d2 = rmat_t[:, 2, 2] < eps
+
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
+
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q0 = torch.stack(
+ [
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ t0,
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
+ ],
+ -1,
+ )
+ t0_rep = t0.repeat(4, 1).t()
+
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q1 = torch.stack(
+ [
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ t1,
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
+ ],
+ -1,
+ )
+ t1_rep = t1.repeat(4, 1).t()
+
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q2 = torch.stack(
+ [
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
+ t2,
+ ],
+ -1,
+ )
+ t2_rep = t2.repeat(4, 1).t()
+
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q3 = torch.stack(
+ [
+ t3,
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
+ ],
+ -1,
+ )
+ t3_rep = t3.repeat(4, 1).t()
+
+ mask_c0 = mask_d2 * mask_d0_d1
+ mask_c1 = mask_d2 * ~mask_d0_d1
+ mask_c2 = ~mask_d2 * mask_d0_nd1
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
+
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
+ q /= torch.sqrt(
+ t0_rep * mask_c0
+ + t1_rep * mask_c1
+ + t2_rep * mask_c2 # noqa
+ + t3_rep * mask_c3
+ ) # noqa
+ q *= 0.5
+ return q.reshape(*dims, 4)
diff --git a/infiller/lib/misc/sampler.py b/infiller/lib/misc/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..16d44f6682886c70f30749aae1ff890a2b1557af
--- /dev/null
+++ b/infiller/lib/misc/sampler.py
@@ -0,0 +1,79 @@
+import argparse
+import os
+from pathlib import Path
+
+import imageio
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+from sklearn.preprocessing import LabelEncoder
+
+from cmib.data.lafan1_dataset import LAFAN1Dataset
+from cmib.data.utils import write_json
+from cmib.lafan1.utils import quat_ik
+from cmib.model.network import TransformerModel
+from cmib.model.preprocess import (lerp_input_repr, replace_constant,
+ slerp_input_repr, vectorize_representation)
+from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, joint_names,
+ sk_parents)
+from cmib.vis.pose import plot_pose_with_stop
+
+
+def test(opt, device):
+
+ save_dir = Path(os.path.join('runs', 'train', opt.exp_name))
+ wdir = save_dir / 'weights'
+ weights = os.listdir(wdir)
+ weights_paths = [wdir / weight for weight in weights]
+    latest_weight = max(weights_paths, key=os.path.getctime)
+ ckpt = torch.load(latest_weight, map_location=device)
+ print(f"Loaded weight: {latest_weight}")
+
+ # Load Skeleton
+ skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device)
+ skeleton_mocap.remove_joints(sk_joints_to_remove)
+
+ # Load LAFAN Dataset
+ Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True)
+ lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device)
+ total_data = lafan_dataset.data['global_pos'].shape[0]
+
+ # Replace with noise to In-betweening Frames
+ from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx'] # default: 9-40, max: 48
+ horizon = ckpt['horizon']
+ print(f"HORIZON: {horizon}")
+
+    test_idx = list(range(total_data))
+
+ # Compare Input data, Prediction, GT
+ save_path = os.path.join(opt.save_path, 'sampler')
+ for i in range(len(test_idx)):
+ Path(save_path).mkdir(parents=True, exist_ok=True)
+
+ start_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
+ target_pose = lafan_dataset.data['global_pos'][test_idx[i], target_idx]
+ gt_stopover_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
+
+ gt_img_path = os.path.join(save_path)
+ plot_pose_with_stop(start_pose, target_pose, target_pose, gt_stopover_pose, i, skeleton_mocap, save_dir=gt_img_path, prefix='gt')
+ print(f"ID {test_idx[i]}: completed.")
+
+def parse_opt():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--project', default='runs/train', help='project/name')
+ parser.add_argument('--exp_name', default='slerp_40', help='experiment name')
+ parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path')
+ parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton')
+ parser.add_argument('--processed_data_dir', type=str, default='processed_data_original/', help='path to save pickled processed data')
+ parser.add_argument('--save_path', type=str, default='runs/test', help='path to save model')
+ parser.add_argument('--motion_type', type=str, default='jumps', help='motion type')
+ opt = parser.parse_args()
+ return opt
+
+if __name__ == "__main__":
+ opt = parse_opt()
+ device = torch.device("cpu")
+ test(opt, device)
diff --git a/infiller/lib/model/__pycache__/network.cpython-310.pyc b/infiller/lib/model/__pycache__/network.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b25cc9023b522dabca5d407fb3d10e7bd2a0c7a9
Binary files /dev/null and b/infiller/lib/model/__pycache__/network.cpython-310.pyc differ
diff --git a/infiller/lib/model/network.py b/infiller/lib/model/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f812e0b1ffd1d876d6466eb4f4e974e6f57bd01
--- /dev/null
+++ b/infiller/lib/model/network.py
@@ -0,0 +1,276 @@
+import math
+import numpy as np
+import torch
+from torch import nn, Tensor
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+# from cmib.model.positional_encoding import PositionalEmbedding
+
+class SinPositionalEncoding(nn.Module):
+ def __init__(self, d_model, dropout=0.1, max_len=100):
+ super(SinPositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0).transpose(0, 1)
+
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ # not used in the final model
+ x = x + self.pe[:x.shape[0], :]
+ return self.dropout(x)
+
+
+class MultiHeadedAttention(nn.Module):
+ def __init__(self, n_head, d_model, d_head, dropout=0.1,
+ pre_lnorm=True, bias=False):
+ """
+        Multi-headed self-attention. Hooks for relative positional encoding and
+        a memory mechanism exist but are disabled in this implementation.
+
+ Args:
+ n_head (int): Number of heads.
+ d_model (int): Input dimension.
+ d_head (int): Head dimension.
+ dropout (float, optional): Dropout value. Defaults to 0.1.
+ pre_lnorm (bool, optional):
+ Apply layer norm before rest of calculation. Defaults to True.
+ In original Transformer paper (pre_lnorm=False):
+ LayerNorm(x + Sublayer(x))
+ In tensor2tensor implementation (pre_lnorm=True):
+ x + Sublayer(LayerNorm(x))
+ bias (bool, optional):
+ Add bias to q, k, v and output projections. Defaults to False.
+
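+        Example (illustrative):
+            >>> att = MultiHeadedAttention(n_head=4, d_model=64, d_head=16)
+            >>> att(torch.randn(2, 10, 64)).shape
+            torch.Size([2, 10, 64])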
+ """
+ super(MultiHeadedAttention, self).__init__()
+
+ self.n_head = n_head
+ self.d_model = d_model
+ self.d_head = d_head
+ self.dropout = dropout
+ self.pre_lnorm = pre_lnorm
+ self.bias = bias
+ self.atten_scale = 1 / math.sqrt(self.d_model)
+
+ self.q_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
+ self.k_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
+ self.v_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
+ self.out_linear = nn.Linear(n_head * d_head, d_model, bias=bias)
+
+        self.dropout_layer = nn.Dropout(dropout)
+ self.atten_dropout_layer = nn.Dropout(dropout)
+
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ def forward(self, hidden, memory=None, mask=None,
+ extra_atten_score=None):
+ """
+ Args:
+ hidden (Tensor): Input embedding or hidden state of previous layer.
+ Shape: (batch, seq, dim)
+            memory (Tensor, optional): Memory tensor of the previous layer;
+                kept for interface compatibility but unused in this
+                implementation. Shape: (batch, mem_len, dim)
+ mask (BoolTensor, optional): Attention mask.
+ Set item value to True if you DO NOT want keep certain
+ attention score, otherwise False. Defaults to None.
+ Shape: (seq, seq+mem_len).
+ """
+ combined = hidden
+ # if memory is None:
+ # combined = hidden
+ # mem_len = 0
+ # else:
+ # combined = torch.cat([memory, hidden], dim=1)
+ # mem_len = memory.shape[1]
+
+ if self.pre_lnorm:
+ hidden = self.layer_norm(hidden)
+ combined = self.layer_norm(combined)
+
+ # shape: (batch, q/k/v_len, dim)
+ q = self.q_linear(hidden)
+ k = self.k_linear(combined)
+ v = self.v_linear(combined)
+
+ # reshape to (batch, q/k/v_len, n_head, d_head)
+ q = q.reshape(q.shape[0], q.shape[1], self.n_head, self.d_head)
+ k = k.reshape(k.shape[0], k.shape[1], self.n_head, self.d_head)
+ v = v.reshape(v.shape[0], v.shape[1], self.n_head, self.d_head)
+
+ # transpose to (batch, n_head, q/k/v_len, d_head)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ # add n_head dimension for relative positional embedding lookup table
+ # (batch, n_head, k/v_len*2-1, d_head)
+ # pos_emb = pos_emb[:, None]
+
+ # (batch, n_head, q_len, k_len)
+ atten_score = torch.matmul(q, k.transpose(-1, -2))
+
+ # qpos = torch.matmul(q, pos_emb.transpose(-1, -2))
+ # DEBUG
+ # ones = torch.zeros(q.shape)
+ # ones[:, :, :, 0] = 1.0
+ # qpos = torch.matmul(ones, pos_emb.transpose(-1, -2))
+ # atten_score = atten_score + self.skew(qpos, mem_len)
+ atten_score = atten_score * self.atten_scale
+
+ # if extra_atten_score is not None:
+ # atten_score = atten_score + extra_atten_score
+
+ if mask is not None:
+ # print(atten_score.shape)
+ # print(mask.shape)
+ # apply attention mask
+ atten_score = atten_score.masked_fill(mask, float("-inf"))
+ atten_score = atten_score.softmax(dim=-1)
+ atten_score = self.atten_dropout_layer(atten_score)
+
+ # (batch, n_head, q_len, d_head)
+ atten_vec = torch.matmul(atten_score, v)
+ # (batch, q_len, n_head*d_head)
+ atten_vec = atten_vec.transpose(1, 2).flatten(start_dim=-2)
+
+ # linear projection
+        output = self.dropout_layer(self.out_linear(atten_vec))
+
+ if self.pre_lnorm:
+ return hidden + output
+ else:
+ return self.layer_norm(hidden + output)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, d_model, d_inner, dropout=0.1, pre_lnorm=True):
+ """
+ Positionwise feed-forward network.
+
+ Args:
+ d_model(int): Dimension of the input and output.
+ d_inner (int): Dimension of the middle layer(bottleneck).
+ dropout (float, optional): Dropout value. Defaults to 0.1.
+ pre_lnorm (bool, optional):
+ Apply layer norm before rest of calculation. Defaults to True.
+ In original Transformer paper (pre_lnorm=False):
+ LayerNorm(x + Sublayer(x))
+ In tensor2tensor implementation (pre_lnorm=True):
+ x + Sublayer(LayerNorm(x))
+ """
+ super(FeedForward, self).__init__()
+ self.d_model = d_model
+ self.d_inner = d_inner
+ self.dropout = dropout
+ self.pre_lnorm = pre_lnorm
+
+ self.layer_norm = nn.LayerNorm(d_model)
+ self.network = nn.Sequential(
+ nn.Linear(d_model, d_inner),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.Linear(d_inner, d_model),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ if self.pre_lnorm:
+ return x + self.network(self.layer_norm(x))
+ else:
+ return self.layer_norm(x + self.network(x))
+
+
+class TransformerModel(nn.Module):
+ def __init__(
+ self,
+ seq_len: int,
+ input_dim: int,
+ d_model: int,
+ nhead: int,
+ d_hid: int,
+ nlayers: int,
+ dropout: float = 0.5,
+ out_dim=91,
+ masked_attention_stage=False,
+ ):
+ super().__init__()
+ self.model_type = "Transformer"
+ self.seq_len = seq_len
+ self.d_model = d_model
+ self.nhead = nhead
+ self.d_hid = d_hid
+ self.nlayers = nlayers
+ self.pos_embedding = SinPositionalEncoding(d_model=d_model, dropout=0.1, max_len=seq_len)
+ if masked_attention_stage:
+ self.input_layer = nn.Linear(input_dim+1, d_model)
+ # visible to invisible attention
+ self.att_layers = nn.ModuleList()
+ self.pff_layers = nn.ModuleList()
+ self.pre_lnorm = True
+ self.layer_norm = nn.LayerNorm(d_model)
+ for i in range(self.nlayers):
+ self.att_layers.append(
+ MultiHeadedAttention(
+ self.nhead, self.d_model,
+ self.d_model // self.nhead, dropout=dropout,
+ pre_lnorm=True,
+ bias=False
+ )
+ )
+
+ self.pff_layers.append(
+ FeedForward(
+ self.d_model, d_hid,
+ dropout=dropout,
+ pre_lnorm=True
+ )
+ )
+ else:
+ self.att_layers = None
+ self.input_layer = nn.Linear(input_dim, d_model)
+ encoder_layers = TransformerEncoderLayer(
+ d_model, nhead, d_hid, dropout, activation="gelu"
+ )
+ self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
+ self.decoder = nn.Linear(d_model, out_dim)
+
+ self.init_weights()
+
+ def init_weights(self) -> None:
+ initrange = 0.1
+ self.decoder.bias.data.zero_()
+ self.decoder.weight.data.uniform_(-initrange, initrange)
+
+ def forward(self, src: Tensor, src_mask: Tensor, data_mask=None, atten_mask=None) -> Tensor:
+ """
+ Args:
+            src: Tensor, shape [seq_len, batch_size, input_dim]
+            src_mask: Tensor, shape [seq_len, seq_len] (unused in this implementation)
+            data_mask: optional visibility mask concatenated to the input features
+            atten_mask: boolean attention mask, required when the masked attention stage is enabled
+
+        Returns:
+            output Tensor of shape [seq_len, batch_size, out_dim]
+ """
+        if data_mask is not None:
+ src = torch.cat([src, data_mask.expand(*src.shape[:-1], data_mask.shape[-1])], dim=-1)
+ src = self.input_layer(src)
+ output = self.pos_embedding(src)
+ # output = src
+ if self.att_layers:
+            assert atten_mask is not None
+ output = output.permute(1, 0, 2)
+ for i in range(self.nlayers):
+ output = self.att_layers[i](output, mask=atten_mask)
+ output = self.pff_layers[i](output)
+ if self.pre_lnorm:
+ output = self.layer_norm(output)
+ output = output.permute(1, 0, 2)
+ output = self.transformer_encoder(output)
+ output = self.decoder(output)
+ return output
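+
+
+# Usage sketch (illustrative; dimensions are arbitrary): the model maps
+# (seq_len, batch, input_dim) inputs to (seq_len, batch, out_dim) outputs.
+#   >>> model = TransformerModel(seq_len=32, input_dim=91, d_model=96, nhead=8,
+#   ...                          d_hid=256, nlayers=2, out_dim=91)
+#   >>> model(torch.randn(32, 4, 91), src_mask=None).shape
+#   torch.Size([32, 4, 91])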
diff --git a/infiller/lib/model/positional_encoding.py b/infiller/lib/model/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..a40cfeea570bc9f52af1240dc8529baf4b38d7a0
--- /dev/null
+++ b/infiller/lib/model/positional_encoding.py
@@ -0,0 +1,42 @@
+import torch
+from torch import nn, Tensor
+import math
+
+
+class PositionalEmbedding(nn.Module):
+ def __init__(self, seq_len: int = 32, d_model: int = 96):
+ super().__init__()
+ self.pos_emb = nn.Embedding(seq_len + 1, d_model)
+
+ def forward(self, inputs):
+ positions = (
+ torch.arange(inputs.size(0), device=inputs.device)
+ .expand(inputs.size(1), inputs.size(0))
+ .contiguous()
+ + 1
+ )
+ outputs = inputs + self.pos_emb(positions).permute(1, 0, 2)
+ return outputs
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+ )
+ pe = torch.zeros(max_len, 1, d_model)
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """
+ Args:
+ x: Tensor, shape [seq_len, batch_size, embedding_dim]
+ """
+ x = x + self.pe[: x.size(0)]
+ return self.dropout(x)
diff --git a/infiller/lib/model/preprocess.py b/infiller/lib/model/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..44cdae00f464fbf55177edf83a1dea72a5c9090a
--- /dev/null
+++ b/infiller/lib/model/preprocess.py
@@ -0,0 +1,189 @@
+import torch
+
+
+def replace_constant(minibatch_pose_input, mask_start_frame):
+
+ seq_len = minibatch_pose_input.size(1)
+ interpolated = (
+ torch.ones_like(minibatch_pose_input, device=minibatch_pose_input.device) * 0.1
+ )
+
+ if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
+ interpolate_start = minibatch_pose_input[:, 0, :]
+ interpolate_end = minibatch_pose_input[:, seq_len - 1, :]
+
+ interpolated[:, 0, :] = interpolate_start
+ interpolated[:, seq_len - 1, :] = interpolate_end
+
+ assert torch.allclose(interpolated[:, 0, :], interpolate_start)
+ assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end)
+
+ else:
+ interpolate_start1 = minibatch_pose_input[:, 0, :]
+ interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :]
+
+ interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :]
+ interpolate_end2 = minibatch_pose_input[:, seq_len - 1, :]
+
+ interpolated[:, 0, :] = interpolate_start1
+ interpolated[:, mask_start_frame, :] = interpolate_end1
+
+ interpolated[:, mask_start_frame, :] = interpolate_start2
+ interpolated[:, seq_len - 1, :] = interpolate_end2
+
+ assert torch.allclose(interpolated[:, 0, :], interpolate_start1)
+ assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1)
+
+ assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2)
+ assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end2)
+ return interpolated
+
+
+def slerp(x, y, a):
+ """
+    Performs spherical linear interpolation (SLERP) between x and y, with proportion a
+
+ :param x: quaternion tensor
+ :param y: quaternion tensor
+ :param a: indicator (between 0 and 1) of completion of the interpolation.
+ :return: tensor of interpolation results
+ """
+ device = x.device
+    dot = torch.sum(x * y, dim=-1)
+
+    neg = dot < 0.0
+    dot[neg] = -dot[neg]
+    y[neg] = -y[neg]
+
+    a = torch.zeros_like(x[..., 0]) + a
+    amount0 = torch.zeros(a.shape, device=device)
+    amount1 = torch.zeros(a.shape, device=device)
+
+    linear = (1.0 - dot) < 0.01
+    omegas = torch.arccos(dot[~linear])
+ sinoms = torch.sin(omegas)
+
+ amount0[linear] = 1.0 - a[linear]
+ amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms
+
+ amount1[linear] = a[linear]
+ amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms
+ # res = amount0[..., np.newaxis] * x + amount1[..., np.newaxis] * y
+ res = amount0.unsqueeze(3) * x + amount1.unsqueeze(3) * y
+
+ return res
+
+
+def slerp_input_repr(minibatch_pose_input, mask_start_frame):
+ seq_len = minibatch_pose_input.size(1)
+ minibatch_pose_input = minibatch_pose_input.reshape(
+ minibatch_pose_input.size(0), seq_len, -1, 4
+ )
+ interpolated = torch.zeros_like(
+ minibatch_pose_input, device=minibatch_pose_input.device
+ )
+
+ if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
+ interpolate_start = minibatch_pose_input[:, 0:1]
+ interpolate_end = minibatch_pose_input[:, seq_len - 1 :]
+
+ for i in range(seq_len):
+ dt = 1 / (seq_len - 1)
+ interpolated[:, i : i + 1, :] = slerp(
+ interpolate_start, interpolate_end, dt * i
+ )
+
+ assert torch.allclose(interpolated[:, 0:1], interpolate_start)
+ assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end)
+ else:
+ interpolate_start1 = minibatch_pose_input[:, 0:1]
+ interpolate_end1 = minibatch_pose_input[
+ :, mask_start_frame : mask_start_frame + 1
+ ]
+
+ interpolate_start2 = minibatch_pose_input[
+ :, mask_start_frame : mask_start_frame + 1
+ ]
+ interpolate_end2 = minibatch_pose_input[:, seq_len - 1 :]
+
+ for i in range(mask_start_frame + 1):
+ dt = 1 / mask_start_frame
+ interpolated[:, i : i + 1, :] = slerp(
+ interpolate_start1, interpolate_end1, dt * i
+ )
+
+ assert torch.allclose(interpolated[:, 0:1], interpolate_start1)
+ assert torch.allclose(
+ interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_end1
+ )
+
+ for i in range(mask_start_frame, seq_len):
+ dt = 1 / (seq_len - mask_start_frame - 1)
+ interpolated[:, i : i + 1, :] = slerp(
+ interpolate_start2, interpolate_end2, dt * (i - mask_start_frame)
+ )
+
+ assert torch.allclose(
+ interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_start2
+ )
+ assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end2)
+
+ interpolated = torch.nn.functional.normalize(interpolated, p=2.0, dim=3)
+ return interpolated.reshape(minibatch_pose_input.size(0), seq_len, -1)
+
+
+def lerp_input_repr(minibatch_pose_input, mask_start_frame):
+ seq_len = minibatch_pose_input.size(1)
+ interpolated = torch.zeros_like(
+ minibatch_pose_input, device=minibatch_pose_input.device
+ )
+
+ if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
+ interpolate_start = minibatch_pose_input[:, 0, :]
+ interpolate_end = minibatch_pose_input[:, seq_len - 1, :]
+
+ for i in range(seq_len):
+ dt = 1 / (seq_len - 1)
+ interpolated[:, i, :] = torch.lerp(
+ interpolate_start, interpolate_end, dt * i
+ )
+
+ assert torch.allclose(interpolated[:, 0, :], interpolate_start)
+ assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end)
+ else:
+ interpolate_start1 = minibatch_pose_input[:, 0, :]
+ interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :]
+
+ interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :]
+ interpolate_end2 = minibatch_pose_input[:, -1, :]
+
+ for i in range(mask_start_frame + 1):
+ dt = 1 / mask_start_frame
+ interpolated[:, i, :] = torch.lerp(
+ interpolate_start1, interpolate_end1, dt * i
+ )
+
+ assert torch.allclose(interpolated[:, 0, :], interpolate_start1)
+ assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1)
+
+ for i in range(mask_start_frame, seq_len):
+ dt = 1 / (seq_len - mask_start_frame - 1)
+ interpolated[:, i, :] = torch.lerp(
+ interpolate_start2, interpolate_end2, dt * (i - mask_start_frame)
+ )
+
+ assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2)
+ assert torch.allclose(interpolated[:, -1, :], interpolate_end2)
+ return interpolated
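+# Example (illustrative): with keyframes at positions 0, 2 and 4, the masked
+# frames are linearly interpolated towards and away from the middle keyframe.
+#   >>> x = torch.tensor([[[0.0], [99.0], [4.0], [99.0], [8.0]]])   # (1, 5, 1)
+#   >>> lerp_input_repr(x, 2).flatten()                             # tensor([0., 2., 4., 6., 8.])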
+
+
+def vectorize_representation(global_position, global_rotation):
+
+ batch_size = global_position.shape[0]
+ seq_len = global_position.shape[1]
+
+ global_pos_vec = global_position.reshape(batch_size, seq_len, -1).contiguous()
+ global_rot_vec = global_rotation.reshape(batch_size, seq_len, -1).contiguous()
+
+ global_pose_vec_gt = torch.cat([global_pos_vec, global_rot_vec], dim=2)
+ return global_pose_vec_gt
diff --git a/infiller/lib/model/skeleton.py b/infiller/lib/model/skeleton.py
new file mode 100644
index 0000000000000000000000000000000000000000..d69797b4af0672c54c3635b5e4a39d3966e90e9b
--- /dev/null
+++ b/infiller/lib/model/skeleton.py
@@ -0,0 +1,349 @@
+import torch
+import numpy as np
+from cmib.data.quaternion import qmul, qrot
+import torch.nn as nn
+
+amass_offsets = [
+ [0.0, 0.0, 0.0],
+
+ [0.058581, -0.082280, -0.017664],
+ [0.043451, -0.386469, 0.008037],
+ [-0.014790, -0.426874, -0.037428],
+ [0.041054, -0.060286, 0.122042],
+ [0.0, 0.0, 0.0],
+
+ [-0.060310, -0.090513, -0.013543],
+ [-0.043257, -0.383688, -0.004843],
+ [0.019056, -0.420046, -0.034562],
+ [-0.034840, -0.062106, 0.130323],
+ [0.0, 0.0, 0.0],
+
+ [0.004439, 0.124404, -0.038385],
+ [0.004488, 0.137956, 0.026820],
+ [-0.002265, 0.056032, 0.002855],
+ [-0.013390, 0.211636, -0.033468],
+ [0.010113, 0.088937, 0.050410],
+ [0.0, 0.0, 0.0],
+
+ [0.071702, 0.114000, -0.018898],
+ [0.122921, 0.045205, -0.019046],
+ [0.255332, -0.015649, -0.022946],
+ [0.265709, 0.012698, -0.007375],
+ [0.0, 0.0, 0.0],
+
+ [-0.082954, 0.112472, -0.023707],
+ [-0.113228, 0.046853, -0.008472],
+ [-0.260127, -0.014369, -0.031269],
+ [-0.269108, 0.006794, -0.006027],
+ [0.0, 0.0, 0.0]
+]
+
+sk_offsets = [
+ [-42.198200, 91.614723, -40.067841],
+
+ [0.103456, 1.857829, 10.548506],
+ [43.499992, -0.000038, -0.000002],
+ [42.372192, 0.000015, -0.000007],
+ [17.299999, -0.000002, 0.000003],
+ [0.000000, 0.000000, 0.000000],
+
+ [0.103457, 1.857829, -10.548503],
+ [43.500042, -0.000027, 0.000008],
+ [42.372257, -0.000008, 0.000014],
+ [17.299992, -0.000005, 0.000004],
+ [0.000000, 0.000000, 0.000000],
+
+ [6.901968, -2.603733, -0.000001],
+ [12.588099, 0.000002, 0.000000],
+ [12.343206, 0.000000, -0.000001],
+ [25.832886, -0.000004, 0.000003],
+ [11.766620, 0.000005, -0.000001],
+ [0.000000, 0.000000, 0.000000],
+
+ [19.745899, -1.480370, 6.000108],
+ [11.284125, -0.000009, -0.000018],
+ [33.000050, 0.000004, 0.000032],
+ [25.200008, 0.000015, 0.000008],
+ [0.000000, 0.000000, 0.000000],
+
+ [19.746099, -1.480375, -6.000073],
+ [11.284138, -0.000015, -0.000012],
+ [33.000092, 0.000017, 0.000013],
+ [25.199780, 0.000135, 0.000422],
+ [0.000000, 0.000000, 0.000000],
+]
+
+sk_parents = [
+ -1,
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 0,
+ 6,
+ 7,
+ 8,
+ 9,
+ 0,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 13,
+ 17,
+ 18,
+ 19,
+ 20,
+ 13,
+ 22,
+ 23,
+ 24,
+ 25,
+]
+
+sk_joints_to_remove = [5, 10, 16, 21, 26]
+
+joint_names = [
+ "Hips",
+ "LeftUpLeg",
+ "LeftLeg",
+ "LeftFoot",
+ "LeftToe",
+ "RightUpLeg",
+ "RightLeg",
+ "RightFoot",
+ "RightToe",
+ "Spine",
+ "Spine1",
+ "Spine2",
+ "Neck",
+ "Head",
+ "LeftShoulder",
+ "LeftArm",
+ "LeftForeArm",
+ "LeftHand",
+ "RightShoulder",
+ "RightArm",
+ "RightForeArm",
+ "RightHand",
+]
+
+
+class Skeleton:
+ def __init__(
+ self,
+ offsets,
+ parents,
+ joints_left=None,
+ joints_right=None,
+ bone_length=None,
+ device=None,
+ ):
+ assert len(offsets) == len(parents)
+
+ self._offsets = torch.Tensor(offsets).to(device)
+ self._parents = np.array(parents)
+ self._joints_left = joints_left
+ self._joints_right = joints_right
+ self._compute_metadata()
+
+ def num_joints(self):
+ return self._offsets.shape[0]
+
+ def offsets(self):
+ return self._offsets
+
+ def parents(self):
+ return self._parents
+
+ def has_children(self):
+ return self._has_children
+
+ def children(self):
+ return self._children
+
+ def convert_to_global_pos(self, unit_vec_rerp):
+ """
+ Convert the unit offset matrix to global position.
+        The first row (root) keeps its absolute position in global coordinates.
+ """
+ bone_length = self.get_bone_length_weight()
+ batch_size = unit_vec_rerp.size(0)
+ seq_len = unit_vec_rerp.size(1)
+ unit_vec_table = unit_vec_rerp.reshape(batch_size, seq_len, 22, 3)
+ global_position = torch.zeros_like(unit_vec_table, device=unit_vec_table.device)
+
+ for i, parent in enumerate(self._parents):
+ if parent == -1: # if root
+ global_position[:, :, i] = unit_vec_table[:, :, i]
+
+ else:
+ global_position[:, :, i] = global_position[:, :, parent] + (
+ nn.functional.normalize(unit_vec_table[:, :, i], p=2.0, dim=-1)
+ * bone_length[i]
+ )
+
+ return global_position
+
+ def convert_to_unit_offset_mat(self, global_position):
+ """
+ Convert the global position of the skeleton to a unit offset matrix.
+        The first row (root) keeps its absolute position in global coordinates.
+ """
+
+ bone_length = self.get_bone_length_weight()
+ unit_offset_mat = torch.zeros_like(
+ global_position, device=global_position.device
+ )
+
+ for i, parent in enumerate(self._parents):
+
+ if parent == -1: # if root
+ unit_offset_mat[:, :, i] = global_position[:, :, i]
+ else:
+ unit_offset_mat[:, :, i] = (
+ global_position[:, :, i] - global_position[:, :, parent]
+ ) / bone_length[i]
+
+ return unit_offset_mat
+
+ def remove_joints(self, joints_to_remove):
+ """
+ Remove the joints specified in 'joints_to_remove', both from the
+ skeleton definition and from the dataset (which is modified in place).
+ The rotations of removed joints are propagated along the kinematic chain.
+ """
+ valid_joints = []
+ for joint in range(len(self._parents)):
+ if joint not in joints_to_remove:
+ valid_joints.append(joint)
+
+ index_offsets = np.zeros(len(self._parents), dtype=int)
+ new_parents = []
+ for i, parent in enumerate(self._parents):
+ if i not in joints_to_remove:
+ new_parents.append(parent - index_offsets[parent])
+ else:
+ index_offsets[i:] += 1
+ self._parents = np.array(new_parents)
+
+ self._offsets = self._offsets[valid_joints]
+ self._compute_metadata()
+
+ def forward_kinematics(self, rotations, root_positions):
+ """
+ Perform forward kinematics using the given trajectory and local rotations.
+ Arguments (where N = batch size, L = sequence length, J = number of joints):
+ -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint.
+ -- root_positions: (N, L, 3) tensor describing the root joint positions.
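+        Example (illustrative):
+            >>> sk = Skeleton(offsets=sk_offsets, parents=sk_parents, device="cpu")
+            >>> sk.remove_joints(sk_joints_to_remove)
+            >>> quats = torch.zeros(1, 1, sk.num_joints(), 4)
+            >>> quats[..., 0] = 1.0                      # identity rotations (w, x, y, z)
+            >>> sk.forward_kinematics(quats, torch.zeros(1, 1, 3)).shape
+            torch.Size([1, 1, 22, 3])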
+ """
+ assert len(rotations.shape) == 4
+ assert rotations.shape[-1] == 4
+
+ positions_world = []
+ rotations_world = []
+
+ expanded_offsets = self._offsets.expand(
+ rotations.shape[0],
+ rotations.shape[1],
+ self._offsets.shape[0],
+ self._offsets.shape[1],
+ )
+
+ # Parallelize along the batch and time dimensions
+ for i in range(self._offsets.shape[0]):
+ if self._parents[i] == -1:
+ positions_world.append(root_positions)
+ rotations_world.append(rotations[:, :, 0])
+ else:
+ positions_world.append(
+ qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i])
+ + positions_world[self._parents[i]]
+ )
+ if self._has_children[i]:
+ rotations_world.append(
+ qmul(rotations_world[self._parents[i]], rotations[:, :, i])
+ )
+ else:
+ # This joint is a terminal node -> it would be useless to compute the transformation
+ rotations_world.append(None)
+
+ return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2)
+
+ def forward_kinematics_with_rotation(self, rotations, root_positions):
+ """
+ Perform forward kinematics using the given trajectory and local rotations.
+ Arguments (where N = batch size, L = sequence length, J = number of joints):
+ -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint.
+ -- root_positions: (N, L, 3) tensor describing the root joint positions.
+ """
+ assert len(rotations.shape) == 4
+ assert rotations.shape[-1] == 4
+
+ positions_world = []
+ rotations_world = []
+
+ expanded_offsets = self._offsets.expand(
+ rotations.shape[0],
+ rotations.shape[1],
+ self._offsets.shape[0],
+ self._offsets.shape[1],
+ )
+
+ # Parallelize along the batch and time dimensions
+ for i in range(self._offsets.shape[0]):
+ if self._parents[i] == -1:
+ positions_world.append(root_positions)
+ rotations_world.append(rotations[:, :, 0])
+ else:
+ positions_world.append(
+ qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i])
+ + positions_world[self._parents[i]]
+ )
+ if self._has_children[i]:
+ rotations_world.append(
+ qmul(rotations_world[self._parents[i]], rotations[:, :, i])
+ )
+ else:
+ # This joint is a terminal node -> it would be useless to compute the transformation
+ rotations_world.append(
+ torch.Tensor([1, 0, 0, 0])
+ .expand(rotations.shape[0], rotations.shape[1], 4)
+ .to(rotations.device)
+ )
+
+ return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2), torch.stack(
+ rotations_world, dim=3
+ ).permute(0, 1, 3, 2)
+
+ def get_bone_length_weight(self):
+ bone_length = []
+ for i, parent in enumerate(self._parents):
+ if parent == -1:
+ bone_length.append(1)
+ else:
+ bone_length.append(
+ torch.linalg.norm(self._offsets[i : i + 1], ord="fro").item()
+ )
+ return torch.Tensor(bone_length)
+
+ def joints_left(self):
+ return self._joints_left
+
+ def joints_right(self):
+ return self._joints_right
+
+ def _compute_metadata(self):
+ self._has_children = np.zeros(len(self._parents)).astype(bool)
+ for i, parent in enumerate(self._parents):
+ if parent != -1:
+ self._has_children[parent] = True
+
+ self._children = []
+ for i, parent in enumerate(self._parents):
+ self._children.append([])
+ for i, parent in enumerate(self._parents):
+ if parent != -1:
+ self._children[parent].append(i)
diff --git a/infiller/lib/vis/pose.py b/infiller/lib/vis/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..767edd1090b93044e4584624e7eb580c0285cb3b
--- /dev/null
+++ b/infiller/lib/vis/pose.py
@@ -0,0 +1,248 @@
+import os
+import pathlib
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def project_root_position(position_arr: np.ndarray, file_name: str):
+    """
+    Take a batch of root-joint trajectories and project them onto the 2D (x, z) plane.
+
+ N: samples
+ L: trajectory length
+ J: joints
+
+ position_arr: [N,L,J,3]
+ """
+
+ root_joints = position_arr[:, :, 0]
+
+ x_pos = root_joints[:, :, 0]
+ y_pos = root_joints[:, :, 2]
+
+ fig = plt.figure()
+
+ for i in range(x_pos.shape[1]):
+
+ if i == 0:
+ plt.scatter(x_pos[:, i], y_pos[:, i], c="b")
+ elif i == x_pos.shape[1] - 1:
+ plt.scatter(x_pos[:, i], y_pos[:, i], c="r")
+ else:
+ plt.scatter(x_pos[:, i], y_pos[:, i], c="k", marker="*", s=1)
+
+ plt.title(f"Root Position: {file_name}")
+ plt.xlabel("X Axis")
+ plt.ylabel("Y Axis")
+ plt.xlim((-300, 300))
+ plt.ylim((-300, 300))
+ plt.grid()
+ plt.savefig(f"{file_name}.png", dpi=200)
+
+
+def plot_single_pose(
+ pose,
+ frame_idx,
+ skeleton,
+ save_dir,
+ prefix,
+):
+ fig = plt.figure()
+ ax = fig.add_subplot(111, projection="3d")
+
+ parent_idx = skeleton.parents()
+
+ for i, p in enumerate(parent_idx):
+ if i > 0:
+ ax.plot(
+ [pose[i, 0], pose[p, 0]],
+ [pose[i, 2], pose[p, 2]],
+ [pose[i, 1], pose[p, 1]],
+ c="k",
+ )
+
+ x_min = pose[:, 0].min()
+ x_max = pose[:, 0].max()
+
+ y_min = pose[:, 1].min()
+ y_max = pose[:, 1].max()
+
+ z_min = pose[:, 2].min()
+ z_max = pose[:, 2].max()
+
+ ax.set_xlim(x_min, x_max)
+ ax.set_xlabel("$X$ Axis")
+
+ ax.set_ylim(z_min, z_max)
+ ax.set_ylabel("$Y$ Axis")
+
+ ax.set_zlim(y_min, y_max)
+ ax.set_zlabel("$Z$ Axis")
+
+ plt.draw()
+
+ title = f"{prefix}: {frame_idx}"
+ plt.title(title)
+ prefix = prefix
+ pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
+ plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
+ plt.close()
+
+
+def plot_pose(
+ start_pose,
+ inbetween_pose,
+ target_pose,
+ frame_idx,
+ skeleton,
+ save_dir,
+ prefix,
+):
+ fig = plt.figure()
+ ax = fig.add_subplot(111, projection="3d")
+
+ parent_idx = skeleton.parents()
+
+ for i, p in enumerate(parent_idx):
+ if i > 0:
+ ax.plot(
+ [start_pose[i, 0], start_pose[p, 0]],
+ [start_pose[i, 2], start_pose[p, 2]],
+ [start_pose[i, 1], start_pose[p, 1]],
+ c="b",
+ )
+ ax.plot(
+ [inbetween_pose[i, 0], inbetween_pose[p, 0]],
+ [inbetween_pose[i, 2], inbetween_pose[p, 2]],
+ [inbetween_pose[i, 1], inbetween_pose[p, 1]],
+ c="k",
+ )
+ ax.plot(
+ [target_pose[i, 0], target_pose[p, 0]],
+ [target_pose[i, 2], target_pose[p, 2]],
+ [target_pose[i, 1], target_pose[p, 1]],
+ c="r",
+ )
+
+ x_min = np.min(
+ [start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()]
+ )
+ x_max = np.max(
+ [start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()]
+ )
+
+ y_min = np.min(
+ [start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()]
+ )
+ y_max = np.max(
+ [start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()]
+ )
+
+ z_min = np.min(
+ [start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()]
+ )
+ z_max = np.max(
+ [start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()]
+ )
+
+ ax.set_xlim(x_min, x_max)
+ ax.set_xlabel("$X$ Axis")
+
+ ax.set_ylim(z_min, z_max)
+ ax.set_ylabel("$Y$ Axis")
+
+ ax.set_zlim(y_min, y_max)
+ ax.set_zlabel("$Z$ Axis")
+
+ plt.draw()
+
+ title = f"{prefix}: {frame_idx}"
+ plt.title(title)
+ pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
+ plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
+ plt.close()
+
+
+def plot_pose_with_stop(
+ start_pose,
+ inbetween_pose,
+ target_pose,
+ stopover,
+ frame_idx,
+ skeleton,
+ save_dir,
+ prefix,
+):
+ fig = plt.figure()
+ ax = fig.add_subplot(111, projection="3d")
+
+ parent_idx = skeleton.parents()
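+ # same colour scheme as plot_pose, with the stopover pose drawn in indigo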
+
+ for i, p in enumerate(parent_idx):
+ if i > 0:
+ ax.plot(
+ [start_pose[i, 0], start_pose[p, 0]],
+ [start_pose[i, 2], start_pose[p, 2]],
+ [start_pose[i, 1], start_pose[p, 1]],
+ c="b",
+ )
+ ax.plot(
+ [inbetween_pose[i, 0], inbetween_pose[p, 0]],
+ [inbetween_pose[i, 2], inbetween_pose[p, 2]],
+ [inbetween_pose[i, 1], inbetween_pose[p, 1]],
+ c="k",
+ )
+ ax.plot(
+ [target_pose[i, 0], target_pose[p, 0]],
+ [target_pose[i, 2], target_pose[p, 2]],
+ [target_pose[i, 1], target_pose[p, 1]],
+ c="r",
+ )
+
+ ax.plot(
+ [stopover[i, 0], stopover[p, 0]],
+ [stopover[i, 2], stopover[p, 2]],
+ [stopover[i, 1], stopover[p, 1]],
+ c="indigo",
+ )
+
+ x_min = np.min(
+ [start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()]
+ )
+ x_max = np.max(
+ [start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()]
+ )
+
+ y_min = np.min(
+ [start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()]
+ )
+ y_max = np.max(
+ [start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()]
+ )
+
+ z_min = np.min(
+ [start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()]
+ )
+ z_max = np.max(
+ [start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()]
+ )
+
+ ax.set_xlim(x_min, x_max)
+ ax.set_xlabel("$X$ Axis")
+
+ ax.set_ylim(z_min, z_max)
+ ax.set_ylabel("$Y$ Axis")
+
+ ax.set_zlim(y_min, y_max)
+ ax.set_zlabel("$Z$ Axis")
+
+ plt.draw()
+
+ title = f"{prefix}: {frame_idx}"
+ plt.title(title)
+ pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
+ plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
+ plt.close()
diff --git a/lib/core/__pycache__/constants.cpython-310.pyc b/lib/core/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3d76e2ff1c3a5d84ff35a861b049b49923bada5
Binary files /dev/null and b/lib/core/__pycache__/constants.cpython-310.pyc differ
diff --git a/lib/core/constants.py b/lib/core/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..818df466acf09a3952ba49075fe8b66902e19d73
--- /dev/null
+++ b/lib/core/constants.py
@@ -0,0 +1,78 @@
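+# Focal length (in pixels) assumed by the perspective camera model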
+FOCAL_LENGTH = 5000.
+
+# Mean and standard deviation for normalizing input image
+IMG_NORM_MEAN = [0.485, 0.456, 0.406]
+IMG_NORM_STD = [0.229, 0.224, 0.225]
+
+"""
+We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides.
+We keep a superset of 24 joints such that we include all joints from every dataset.
+If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
+The joints used here are the following:
+"""
+JOINT_NAMES = [
+'OP Nose', 'OP Neck', 'OP RShoulder', #0,1,2
+'OP RElbow', 'OP RWrist', 'OP LShoulder', #3,4,5
+'OP LElbow', 'OP LWrist', 'OP MidHip', #6, 7,8
+'OP RHip', 'OP RKnee', 'OP RAnkle', #9,10,11
+'OP LHip', 'OP LKnee', 'OP LAnkle', #12,13,14
+'OP REye', 'OP LEye', 'OP REar', #15,16,17
+'OP LEar', 'OP LBigToe', 'OP LSmallToe', #18,19,20
+'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', #21, 22, 23, 24 ##Total 25 joints for openpose
+'Right Ankle', 'Right Knee', 'Right Hip', #0,1,2
+'Left Hip', 'Left Knee', 'Left Ankle', #3, 4, 5
+'Right Wrist', 'Right Elbow', 'Right Shoulder', #6
+'Left Shoulder', 'Left Elbow', 'Left Wrist', #9
+'Neck (LSP)', 'Top of Head (LSP)', #12, 13
+'Pelvis (MPII)', 'Thorax (MPII)', #14, 15
+'Spine (H36M)', 'Jaw (H36M)', #16, 17
+'Head (H36M)', 'Nose', 'Left Eye', #18, 19, 20
+'Right Eye', 'Left Ear', 'Right Ear' #21,22,23 (Total 24 joints)
+]
+
+# Dict containing the joints in numerical order
+JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}
+
+# Map joints to SMPL joints
+JOINT_MAP = {
+'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
+'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
+'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
+'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
+'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
+'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
+'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
+'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
+'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
+'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
+'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
+'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
+'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
+'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
+'Spine (H36M)': 51, 'Jaw (H36M)': 52,
+'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
+'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
+}
+
+# Joint selectors
+# Indices to get the 14 LSP joints from the 17 H36M joints
+H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
+H36M_TO_J14 = H36M_TO_J17[:14]
+# Indices to get the 14 LSP joints from the ground truth joints
+J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17]
+J24_TO_J14 = J24_TO_J17[:14]
+
+# Permutation of SMPL pose parameters when flipping the shape
+SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22]
+SMPL_POSE_FLIP_PERM = []
+for i in SMPL_JOINTS_FLIP_PERM:
+ SMPL_POSE_FLIP_PERM.append(3*i)
+ SMPL_POSE_FLIP_PERM.append(3*i+1)
+ SMPL_POSE_FLIP_PERM.append(3*i+2)
+# Permutation indices for the 24 ground truth joints
+J24_FLIP_PERM = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
+# Permutation indices for the full set of 49 joints
+J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\
+ + [25+i for i in J24_FLIP_PERM]
+
+
diff --git a/lib/datasets/__pycache__/track_dataset.cpython-310.pyc b/lib/datasets/__pycache__/track_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..238cee6d17b180d75b563f6a34b95d85a20866bd
Binary files /dev/null and b/lib/datasets/__pycache__/track_dataset.cpython-310.pyc differ
diff --git a/lib/datasets/track_dataset.py b/lib/datasets/track_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..076cf8897f87fd03f65eaea7d0fb9d8001eface0
--- /dev/null
+++ b/lib/datasets/track_dataset.py
@@ -0,0 +1,78 @@
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms import Normalize, ToTensor, Compose
+import numpy as np
+import cv2
+
+from lib.core import constants
+from lib.utils.imutils import crop, boxes_2_cs
+
+
+class TrackDatasetEval(Dataset):
+ """
+ Track Dataset Class - Load images/crops of the tracked boxes.
+ """
+ def __init__(self, imgfiles, boxes,
+ crop_size=256, dilate=1.0,
+ img_focal=None, img_center=None, normalization=True,
+ item_idx=0, do_flip=False):
+ super(TrackDatasetEval, self).__init__()
+
+ self.imgfiles = imgfiles
+ self.crop_size = crop_size
+ self.normalization = normalization
+ self.normalize_img = Compose([
+ ToTensor(),
+ Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)
+ ])
+
+ self.boxes = boxes
+ self.box_dilate = dilate
+ self.centers, self.scales = boxes_2_cs(boxes)
+
+ self.img_focal = img_focal
+ self.img_center = img_center
+ self.item_idx = item_idx
+ self.do_flip = do_flip
+
+ def __len__(self):
+ return len(self.imgfiles)
+
+
+ def __getitem__(self, index):
+ item = {}
+ imgfile = self.imgfiles[index]
+ scale = self.scales[index] * self.box_dilate
+ center = self.centers[index]
+
+ img_focal = self.img_focal
+ img_center = self.img_center
+
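+ # OpenCV reads images as BGR; reverse the channel order to get RGB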
+ img = cv2.imread(imgfile)[:,:,::-1]
+ if self.do_flip:
+ img = img[:, ::-1, :]
+ img_width = img.shape[1]
+ center[0] = img_width - center[0] - 1
+ img_crop = crop(img, center, scale,
+ [self.crop_size, self.crop_size],
+ rot=0).astype('uint8')
+ # cv2.imwrite('debug_crop.png', img_crop[:,:,::-1])
+
+ if self.normalization:
+ img_crop = self.normalize_img(img_crop)
+ else:
+ img_crop = torch.from_numpy(img_crop)
+ item['img'] = img_crop
+
+ if self.do_flip:
+ # center[0] = img_width - center[0] - 1
+ item['do_flip'] = torch.tensor(1).float()
+ item['img_idx'] = torch.tensor(index).long()
+ item['scale'] = torch.tensor(scale).float()
+ item['center'] = torch.tensor(center).float()
+ item['img_focal'] = torch.tensor(img_focal).float()
+ item['img_center'] = torch.tensor(img_center).float()
+
+
+ return item
+
diff --git a/lib/eval_utils/__pycache__/custom_utils.cpython-310.pyc b/lib/eval_utils/__pycache__/custom_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..434627c371a53973a3f661331ef3b8c1920ee8ac
Binary files /dev/null and b/lib/eval_utils/__pycache__/custom_utils.cpython-310.pyc differ
diff --git a/lib/eval_utils/__pycache__/filling_utils.cpython-310.pyc b/lib/eval_utils/__pycache__/filling_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd71e1bcff744702fd95d00968a8b5955282a967
Binary files /dev/null and b/lib/eval_utils/__pycache__/filling_utils.cpython-310.pyc differ
diff --git a/lib/eval_utils/custom_utils.py b/lib/eval_utils/custom_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fcae184117ef0186ded0e661d0e73060e158529
--- /dev/null
+++ b/lib/eval_utils/custom_utils.py
@@ -0,0 +1,99 @@
+import copy
+import numpy as np
+import torch
+
+from hawor.utils.process import run_mano, run_mano_left
+from hawor.utils.rotation import angle_axis_to_quaternion, rotation_matrix_to_angle_axis
+from scipy.interpolate import interp1d
+
+
+def cam2world_convert(R_c2w_sla, t_c2w_sla, data_out, handedness):
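+ # Transform MANO hand trajectories from the camera frame to the world frame using
+ # per-frame camera-to-world rotations R_c2w_sla and translations t_c2w_sla.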
+ init_rot_mat = copy.deepcopy(data_out["init_root_orient"])
+ init_rot_mat = torch.einsum("tij,btjk->btik", R_c2w_sla, init_rot_mat)
+ init_rot = rotation_matrix_to_angle_axis(init_rot_mat)
+ init_rot_quat = angle_axis_to_quaternion(init_rot)
+ # data_out["init_root_orient"] = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
+ # data_out["init_hand_pose"] = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
+ data_out_init_root_orient = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
+ data_out_init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
+
+ init_trans = data_out["init_trans"] # (B, T, 3)
+ if handedness == "right":
+ outputs = run_mano(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
+ elif handedness == "left":
+ outputs = run_mano_left(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
+ root_loc = outputs["joints"][..., 0, :].cpu() # (B, T, 3)
+ offset = init_trans - root_loc # It is a constant, no matter what the rotation is.
+ init_trans = (
+ torch.einsum("tij,btj->bti", R_c2w_sla, root_loc)
+ + t_c2w_sla[None, :]
+ + offset
+ )
+
+ data_world = {
+ "init_root_orient": init_rot, # (B, T, 3)
+ "init_hand_pose": data_out_init_hand_pose, # (B, T, 15, 3)
+ "init_trans": init_trans, # (B, T, 3)
+ "init_betas": data_out["init_betas"] # (B, T, 10)
+ }
+
+ return data_world
+
+def quaternion_to_matrix(quaternions):
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+def load_slam_cam(fpath):
+ print(f"Loading cameras from {fpath}...")
+ pred_cam = dict(np.load(fpath, allow_pickle=True))
+ pred_traj = pred_cam['traj']
+ t_c2w_sla = torch.tensor(pred_traj[:, :3]) * pred_cam['scale']
+ pred_camq = torch.tensor(pred_traj[:, 3:])
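+ # the stored quaternions are (x, y, z, w); reorder to (w, x, y, z) for quaternion_to_matrix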
+ R_c2w_sla = quaternion_to_matrix(pred_camq[:,[3,0,1,2]])
+ R_w2c_sla = R_c2w_sla.transpose(-1, -2)
+ t_w2c_sla = -torch.einsum("bij,bj->bi", R_w2c_sla, t_c2w_sla)
+ return R_w2c_sla, t_w2c_sla, R_c2w_sla, t_c2w_sla
+
+
+def interpolate_bboxes(bboxes):
+ T = bboxes.shape[0]
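+ # boxes that are all zeros are treated as missing detections and filled by
+ # linearly interpolating each of the 5 box entries over time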
+
+ zero_indices = np.where(np.all(bboxes == 0, axis=1))[0]
+
+ non_zero_indices = np.where(np.any(bboxes != 0, axis=1))[0]
+
+ if len(zero_indices) == 0:
+ return bboxes
+
+ interpolated_bboxes = bboxes.copy()
+ for i in range(5):
+ interp_func = interp1d(non_zero_indices, bboxes[non_zero_indices, i], kind='linear', fill_value="extrapolate")
+ interpolated_bboxes[zero_indices, i] = interp_func(zero_indices)
+
+ return interpolated_bboxes
\ No newline at end of file
diff --git a/lib/eval_utils/filling_utils.py b/lib/eval_utils/filling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..672d98753d2c2c088d011c2d8b2d00bac16c2257
--- /dev/null
+++ b/lib/eval_utils/filling_utils.py
@@ -0,0 +1,306 @@
+import copy
+import os
+import joblib
+import numpy as np
+from scipy.spatial.transform import Slerp, Rotation
+import torch
+
+from hawor.utils.process import run_mano, run_mano_left
+from hawor.utils.rotation import angle_axis_to_quaternion, angle_axis_to_rotation_matrix, quaternion_to_rotation_matrix, rotation_matrix_to_angle_axis
+from lib.utils.geometry import rotmat_to_rot6d
+from lib.utils.geometry import rot6d_to_rotmat
+
+def slerp_interpolation_aa(pos, valid):
+
+ B, T, N, _ = pos.shape # B: batch size, T: sequence length, N: number of joints, last dim: axis-angle (3)
+ pos_interp = pos.copy() # copy to store the interpolated result
+
+ for b in range(B):
+ for n in range(N):
+ quat_b_n = pos[b, :, n, :]
+ valid_b_n = valid[b, :]
+
+ invalid_idxs = np.where(~valid_b_n)[0]
+ valid_idxs = np.where(valid_b_n)[0]
+
+ if len(invalid_idxs) == 0:
+ continue
+
+ if len(valid_idxs) > 1:
+ valid_times = valid_idxs # valid timesteps
+ valid_rots = Rotation.from_rotvec(quat_b_n[valid_idxs]) # valid rotations (axis-angle)
+
+ slerp = Slerp(valid_times, valid_rots)
+
+ for idx in invalid_idxs:
+ if idx < valid_idxs[0]: # before the first valid timestep: extrapolate
+ pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[0]] # copy the first valid rotation
+ elif idx > valid_idxs[-1]: # after the last valid timestep: extrapolate
+ pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[-1]] # copy the last valid rotation
+ else:
+ interp_rot = slerp([idx])
+ pos_interp[b, idx, n, :] = interp_rot.as_rotvec()[0]
+ # print("#######")
+ # if N > 1:
+ # print(pos[1,0,11])
+ # print(pos_interp[1,0,11])
+
+ return pos_interp
+
+def slerp_interpolation_quat(pos, valid):
+
+ # wxyz to xyzw
+ pos = pos[:, :, :, [1, 2, 3, 0]]
+
+ B, T, N, _ = pos.shape # B: batch size, T: sequence length, N: number of joints, 4: quaternion dimension
+ pos_interp = pos.copy() # copy to store the interpolated result
+
+ for b in range(B):
+ for n in range(N):
+ quat_b_n = pos[b, :, n, :]
+ valid_b_n = valid[b, :]
+
+ invalid_idxs = np.where(~valid_b_n)[0]
+ valid_idxs = np.where(valid_b_n)[0]
+
+ if len(invalid_idxs) == 0:
+ continue
+
+ if len(valid_idxs) > 1:
+ valid_times = valid_idxs # valid timesteps
+ valid_rots = Rotation.from_quat(quat_b_n[valid_idxs]) # valid quaternions
+
+ slerp = Slerp(valid_times, valid_rots)
+
+ for idx in invalid_idxs:
+ if idx < valid_idxs[0]: # before the first valid timestep: extrapolate
+ pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[0]] # copy the first valid quaternion
+ elif idx > valid_idxs[-1]: # after the last valid timestep: extrapolate
+ pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[-1]] # copy the last valid quaternion
+ else:
+ interp_rot = slerp([idx])
+ pos_interp[b, idx, n, :] = interp_rot.as_quat()[0]
+
+ # xyzw to wxyz
+ pos_interp = pos_interp[:, :, :, [3, 0, 1, 2]]
+ return pos_interp
+
+
+def linear_interpolation_nd(pos, valid):
+ B, T = pos.shape[:2] # batch size B and sequence length T
+ feature_dim = pos.shape[2] # arbitrary feature dimension
+ pos_interp = pos.copy() # copy to store the interpolated result
+
+ for b in range(B):
+ for idx in range(feature_dim): # loop over every feature dimension
+ pos_b_idx = pos[b, :, idx] # time series of this feature for batch b
+ valid_b = valid[b, :] # validity flags for this batch
+
+ # find the invalid (False) and valid indices
+ invalid_idxs = np.where(~valid_b)[0]
+ valid_idxs = np.where(valid_b)[0]
+
+ if len(invalid_idxs) == 0:
+ continue
+
+ # linearly interpolate the invalid entries
+ if len(valid_idxs) > 1: # need at least two valid points to interpolate
+ pos_b_idx[invalid_idxs] = np.interp(invalid_idxs, valid_idxs, pos_b_idx[valid_idxs])
+ pos_interp[b, :, idx] = pos_b_idx # store the interpolated result
+
+ return pos_interp
+
+def world2canonical_convert(R_c2w_sla, t_c2w_sla, data_out, handedness):
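+ # Apply a rigid transform (R, t) to a hand trajectory; despite the argument names,
+ # this is used both for world-to-canonical and canonical-to-world conversions.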
+ init_rot_mat = copy.deepcopy(data_out["init_root_orient"])
+ init_rot_mat = torch.einsum("tij,btjk->btik", R_c2w_sla, init_rot_mat)
+ init_rot = rotation_matrix_to_angle_axis(init_rot_mat)
+ init_rot_quat = angle_axis_to_quaternion(init_rot)
+ # data_out["init_root_orient"] = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
+ # data_out["init_hand_pose"] = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
+ data_out_init_root_orient = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
+ data_out_init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
+
+ init_trans = data_out["init_trans"] # (B, T, 3)
+ if handedness == "left":
+ outputs = run_mano_left(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
+
+ elif handedness == "right":
+ outputs = run_mano(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
+ root_loc = outputs["joints"][..., 0, :].cpu() # (B, T, 3)
+ offset = init_trans - root_loc # It is a constant, no matter what the rotation is.
+ init_trans = (
+ torch.einsum("tij,btj->bti", R_c2w_sla, root_loc)
+ + t_c2w_sla[None, :]
+ + offset
+ )
+
+ data_world = {
+ "init_root_orient": init_rot, # (B, T, 3)
+ "init_hand_pose": data_out_init_hand_pose, # (B, T, 15, 3)
+ "init_trans": init_trans, # (B, T, 3)
+ "init_betas": data_out["init_betas"] # (B, T, 10)
+ }
+
+ return data_world
+
+def filling_preprocess(item):
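+ # Canonicalize the left/right hand world trajectories w.r.t. their first-frame root
+ # orientation, fill missing frames by lerp/slerp, and pack everything into a
+ # per-frame rot6d feature vector for the motion infiller.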
+
+ num_joints = 15
+
+ global_trans = item['trans'] # (2, seq_len, 3)
+ global_rot = item['rot'] #(2, seq_len, 3)
+ hand_pose = item['hand_pose'] # (2, seq_len, 45)
+ betas = item['betas'] # (2, seq_len, 10)
+ valid = item['valid'] # (2, seq_len)
+
+ N, T, _ = global_trans.shape
+ R_canonical2world_left_aa = torch.from_numpy(global_rot[0, 0])
+ R_canonical2world_right_aa = torch.from_numpy(global_rot[1, 0])
+ R_world2canonical_left = angle_axis_to_rotation_matrix(R_canonical2world_left_aa).t()
+ R_world2canonical_right = angle_axis_to_rotation_matrix(R_canonical2world_right_aa).t()
+
+
+ # transform left hand to canonical
+ hand_pose = hand_pose.reshape(N, T, num_joints, 3)
+ data_world_left = {
+ "init_trans": torch.from_numpy(global_trans[0:1]),
+ "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[0:1])),
+ "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[0:1])),
+ "init_betas": torch.from_numpy(betas[0:1]),
+ }
+
+ data_left_init_root_orient = rotation_matrix_to_angle_axis(data_world_left["init_root_orient"])
+ data_left_init_hand_pose = rotation_matrix_to_angle_axis(data_world_left["init_hand_pose"])
+ outputs = run_mano_left(data_world_left["init_trans"], data_left_init_root_orient, data_left_init_hand_pose, betas=data_world_left["init_betas"])
+ init_trans = data_world_left["init_trans"][0, 0] # (3,)
+ root_loc = outputs["joints"][0, 0, 0, :].cpu() # (3,)
+ offset = init_trans - root_loc # It is a constant, no matter what the rotation is.
+ t_world2canonical_left = -torch.einsum("ij,j->i", R_world2canonical_left, root_loc) - offset
+
+ R_world2canonical_left = R_world2canonical_left.repeat(T, 1, 1)
+ t_world2canonical_left = t_world2canonical_left.repeat(T, 1)
+ data_canonical_left = world2canonical_convert(R_world2canonical_left, t_world2canonical_left, data_world_left, "left")
+
+ # transform right hand to canonical
+ data_world_right = {
+ "init_trans": torch.from_numpy(global_trans[1:2]),
+ "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[1:2])),
+ "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[1:2])),
+ "init_betas": torch.from_numpy(betas[1:2]),
+ }
+
+ data_right_init_root_orient = rotation_matrix_to_angle_axis(data_world_right["init_root_orient"])
+ data_right_init_hand_pose = rotation_matrix_to_angle_axis(data_world_right["init_hand_pose"])
+ outputs = run_mano(data_world_right["init_trans"], data_right_init_root_orient, data_right_init_hand_pose, betas=data_world_right["init_betas"])
+ init_trans = data_world_right["init_trans"][0, 0] # (3,)
+ root_loc = outputs["joints"][0, 0, 0, :].cpu() # (3,)
+ offset = init_trans - root_loc # It is a constant, no matter what the rotation is.
+ t_world2canonical_right = -torch.einsum("ij,j->i", R_world2canonical_right, root_loc) - offset
+
+ R_world2canonical_right = R_world2canonical_right.repeat(T, 1, 1)
+ t_world2canonical_right = t_world2canonical_right.repeat(T, 1)
+ data_canonical_right = world2canonical_convert(R_world2canonical_right, t_world2canonical_right, data_world_right, "right")
+
+ # merge left and right canonical data
+ global_rot = torch.cat((data_canonical_left['init_root_orient'], data_canonical_right['init_root_orient']))
+ global_trans = torch.cat((data_canonical_left['init_trans'], data_canonical_right['init_trans'])).numpy()
+
+ # global_rot = angle_axis_to_quaternion(global_rot).numpy().reshape(N, T, 1, 4)
+ global_rot = global_rot.reshape(N, T, 1, 3).numpy()
+
+ hand_pose = hand_pose.reshape(N, T, 15, 3)
+ # hand_pose = angle_axis_to_quaternion(torch.from_numpy(hand_pose)).numpy()
+
+ # lerp and slerp
+ global_trans_lerped = linear_interpolation_nd(global_trans, valid)
+ betas_lerped = linear_interpolation_nd(betas, valid)
+ global_rot_slerped = slerp_interpolation_aa(global_rot, valid)
+ hand_pose_slerped = slerp_interpolation_aa(hand_pose, valid)
+
+
+ # convert to rot6d
+
+ global_rot_slerped_mat = angle_axis_to_rotation_matrix(torch.from_numpy(global_rot_slerped.reshape(N*T, -1)))
+ # global_rot_slerped_mat = quaternion_to_rotation_matrix(torch.from_numpy(global_rot_slerped.reshape(N*T, -1)))
+ global_rot_slerped_rot6d = rotmat_to_rot6d(global_rot_slerped_mat).reshape(N, T, -1).numpy()
+ hand_pose_slerped_mat = angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose_slerped.reshape(N*T*num_joints, -1)))
+ # hand_pose_slerped_mat = quaternion_to_rotation_matrix(torch.from_numpy(hand_pose_slerped.reshape(N*T*num_joints, -1)))
+ hand_pose_slerped_rot6d = rotmat_to_rot6d(hand_pose_slerped_mat).reshape(N, T, -1).numpy()
+
+
+ # concat to (T, concat_dim)
+ global_pose_vec_input = np.concatenate((global_trans_lerped, betas_lerped, global_rot_slerped_rot6d, hand_pose_slerped_rot6d), axis=-1).transpose(1, 0, 2).reshape(T, -1)
+
+ R_canon2w_left = R_world2canonical_left.transpose(-1, -2)
+ t_canon2w_left = -torch.einsum("tij,tj->ti", R_canon2w_left, t_world2canonical_left)
+ R_canon2w_right = R_world2canonical_right.transpose(-1, -2)
+ t_canon2w_right = -torch.einsum("tij,tj->ti", R_canon2w_right, t_world2canonical_right)
+
+ transform_w_canon = {
+ "R_w2canon_left": R_world2canonical_left,
+ "t_w2canon_left": t_world2canonical_left,
+ "R_canon2w_left": R_canon2w_left,
+ "t_canon2w_left": t_canon2w_left,
+
+ "R_w2canon_right": R_world2canonical_right,
+ "t_w2canon_right": t_world2canonical_right,
+ "R_canon2w_right": R_canon2w_right,
+ "t_canon2w_right": t_canon2w_right,
+ }
+
+ return global_pose_vec_input, transform_w_canon
+
+def custom_rot6d_to_rotmat(rot6d):
+ original_shape = rot6d.shape[:-1]
+ rot6d = rot6d.reshape(-1, 6)
+ mat = rot6d_to_rotmat(rot6d)
+ mat = mat.reshape(*original_shape, 3, 3)
+ return mat
+
+def filling_postprocess(output, transform_w_canon):
+ # output = output.numpy()
+ output = output.permute(1, 0, 2) # (2, T, -1)
+ N, T, _ = output.shape
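+ # per-frame, per-hand feature layout: translation (3) + betas (10) + root rot6d (6) + 15 joint rotations in rot6d (90)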
+ canon_trans = output[:, :, :3]
+ betas = output[:, :, 3:13]
+ canon_rot_rot6d = output[:, :, 13:19]
+ hand_pose_rot6d = output[:, :, 19:109].reshape(N, T, 15, 6)
+
+ canon_rot_mat = custom_rot6d_to_rotmat(canon_rot_rot6d)
+ hand_pose_mat = custom_rot6d_to_rotmat(hand_pose_rot6d)
+
+ data_canonical_left = {
+ "init_trans": canon_trans[[0], :, :],
+ "init_root_orient": canon_rot_mat[[0], :, :, :],
+ "init_hand_pose": hand_pose_mat[[0], :, :, :, :],
+ "init_betas": betas[[0], :, :]
+ }
+
+ data_canonical_right = {
+ "init_trans": canon_trans[[1], :, :],
+ "init_root_orient": canon_rot_mat[[1], :, :, :],
+ "init_hand_pose": hand_pose_mat[[1], :, :, :, :],
+ "init_betas": betas[[1], :, :]
+ }
+
+ R_canon2w_left = transform_w_canon['R_canon2w_left']
+ t_canon2w_left = transform_w_canon['t_canon2w_left']
+ R_canon2w_right = transform_w_canon['R_canon2w_right']
+ t_canon2w_right = transform_w_canon['t_canon2w_right']
+
+
+ world_left = world2canonical_convert(R_canon2w_left, t_canon2w_left, data_canonical_left, "left")
+ world_right = world2canonical_convert(R_canon2w_right, t_canon2w_right, data_canonical_right, "right")
+
+ global_rot = torch.cat((world_left['init_root_orient'], world_right['init_root_orient'])).numpy()
+ global_trans = torch.cat((world_left['init_trans'], world_right['init_trans'])).numpy()
+
+ pred_data = {
+ "trans": global_trans, # (2, T, 3)
+ "rot": global_rot, # (2, T, 3)
+ "hand_pose": rotation_matrix_to_angle_axis(hand_pose_mat).flatten(-2).numpy(), # (2, T, 45)
+ "betas": betas.numpy(), # (2, T, 10)
+ }
+
+ return pred_data
+
diff --git a/lib/eval_utils/video_utils.py b/lib/eval_utils/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fd29f9c3088d6a7556b0566a1953950e8f639a8
--- /dev/null
+++ b/lib/eval_utils/video_utils.py
@@ -0,0 +1,85 @@
+import cv2
+import os
+import subprocess
+
+def make_video_grid_2x2(out_path, vid_paths, overwrite=False):
+ """
+ Tile four videos into a 2x2 grid at their original resolutions.
+
+ :param out_path: path of the output video.
+ :param vid_paths: list of input video paths (must contain exactly 4).
+ :param overwrite: if True, overwrite an existing output file.
+ """
+ if os.path.isfile(out_path) and not overwrite:
+ print(f"{out_path} already exists, skipping.")
+ return
+
+ if any(not os.path.isfile(v) for v in vid_paths):
+ print("Not all inputs exist!", vid_paths)
+ return
+
+ # exactly 4 video paths are required
+ if len(vid_paths) != 4:
+ print("Error: Exactly 4 video paths are required!")
+ return
+
+ # unpack the four input paths
+ v1, v2, v3, v4 = vid_paths
+
+ # ffmpeg xstack command: tile the inputs into a 2x2 grid without resizing
+ cmd = (
+ f"ffmpeg -i {v1} -i {v2} -i {v3} -i {v4} "
+ f"-filter_complex '[0:v][1:v][2:v][3:v]xstack=inputs=4:layout=0_0|w0_0|0_h0|w0_h0[v]' "
+ f"-map '[v]' {out_path} -y"
+ )
+
+ print(cmd)
+ subprocess.call(cmd, shell=True, stdin=subprocess.PIPE)
+
+def create_video_from_images(image_list, output_path, fps=15, target_resolution=(540, 540)):
+ """
+ Compose a list of images into an MP4 video.
+
+ :param image_list: list of image paths.
+ :param output_path: output video file path (e.g. output.mp4).
+ :param fps: frame rate of the video (default 15 FPS).
+ :param target_resolution: maximum output resolution; the input aspect ratio is preserved.
+ """
+ if not image_list:
+ print("Image list is empty!")
+ return
+
+ # read the first image to get the width and height
+ first_image = cv2.imread(image_list[0])
+ if first_image is None:
+ print(f"Failed to read image: {image_list[0]}")
+ return
+
+ height, width, _ = first_image.shape
+ if height != width:
+ if height < width:
+ vis_w = target_resolution[0]
+ vis_h = int(target_resolution[0] / width * height)
+ elif height > width:
+ vis_h = target_resolution[0]
+ vis_w = int(target_resolution[0] / height * width)
+ else:
+ vis_h = target_resolution[0]
+ vis_w = target_resolution[0]
+ target_resolution = (vis_w, vis_h)
+
+ # define the video codec and writer
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # use the mp4v codec
+ video_writer = cv2.VideoWriter(output_path, fourcc, fps, target_resolution)
+
+ # iterate over the images and write each frame to the video
+ for image_path in image_list:
+ frame = cv2.imread(image_path)
+ if frame is None:
+ print(f"Failed to read image: {image_path}")
+ continue
+ frame_resized = cv2.resize(frame, target_resolution)
+ video_writer.write(frame_resized)
+
+ # release the video writer
+ video_writer.release()
+ print(f"Video saved to: {output_path}")
\ No newline at end of file
diff --git a/lib/models/__pycache__/hawor.cpython-310.pyc b/lib/models/__pycache__/hawor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea55f4816cf240a4b2f3786f8f560f0e794a830b
Binary files /dev/null and b/lib/models/__pycache__/hawor.cpython-310.pyc differ
diff --git a/lib/models/__pycache__/mano_wrapper.cpython-310.pyc b/lib/models/__pycache__/mano_wrapper.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d945c4f2c9b28af8dd870fe910e7575fa184b649
Binary files /dev/null and b/lib/models/__pycache__/mano_wrapper.cpython-310.pyc differ
diff --git a/lib/models/__pycache__/modules.cpython-310.pyc b/lib/models/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab0189e3d153407e46e5e07fd8d5685639c73364
Binary files /dev/null and b/lib/models/__pycache__/modules.cpython-310.pyc differ
diff --git a/lib/models/backbones/__init__.py b/lib/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8288cfd84555de1720b81129fa8accf76948cb88
--- /dev/null
+++ b/lib/models/backbones/__init__.py
@@ -0,0 +1,8 @@
+from .vit import vit
+
+
+def create_backbone(cfg):
+ if cfg.MODEL.BACKBONE.TYPE == 'vit':
+ return vit(cfg)
+ else:
+ raise NotImplementedError('Backbone type is not implemented')
diff --git a/lib/models/backbones/__pycache__/__init__.cpython-310.pyc b/lib/models/backbones/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5b426f6ccab9cf5c36e1444934e14c0bbcc8eec
Binary files /dev/null and b/lib/models/backbones/__pycache__/__init__.cpython-310.pyc differ
diff --git a/lib/models/backbones/__pycache__/vit.cpython-310.pyc b/lib/models/backbones/__pycache__/vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59f170f6599832d8098e7142b3778fd7a83529b2
Binary files /dev/null and b/lib/models/backbones/__pycache__/vit.cpython-310.pyc differ
diff --git a/lib/models/backbones/vit.py b/lib/models/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c56c71889cd441294f57ad687d0678d2443d1eed
--- /dev/null
+++ b/lib/models/backbones/vit.py
@@ -0,0 +1,348 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
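+# Factory for the ViT-Huge/16 backbone: 1280-dim tokens, 32 blocks, 16 heads, 256x192 input crops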
+def vit(cfg):
+ return ViT(
+ img_size=(256, 192),
+ patch_size=16,
+ embed_dim=1280,
+ depth=32,
+ num_heads=16,
+ ratio=1,
+ use_checkpoint=False,
+ mlp_ratio=4,
+ qkv_bias=True,
+ drop_path_rate=0.55,
+ )
+
+def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
+ """
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
+ dimension for the original embeddings.
+ Args:
+ abs_pos (Tensor): absolute positional embeddings with shape (1, num_position, C).
+ has_cls_token (bool): if True, abs_pos contains one extra embedding for the cls token.
+ h, w (int): target size of the image token grid.
+ ori_h, ori_w (int): original size of the positional embedding grid.
+
+ Returns:
+ Absolute positional embeddings after processing, with shape (1, h*w (+1 for the cls token), C).
+ """
+ cls_token = None
+ B, L, C = abs_pos.shape
+ if has_cls_token:
+ cls_token = abs_pos[:, 0:1]
+ abs_pos = abs_pos[:, 1:]
+
+ if ori_h != h or ori_w != w:
+ new_abs_pos = F.interpolate(
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
+ size=(h, w),
+ mode="bicubic",
+ align_corners=False,
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
+
+ else:
+ new_abs_pos = abs_pos
+
+ if cls_token is not None:
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
+ return new_abs_pos
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self):
+ return 'p={}'.format(self.drop_prob)
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None,):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.dim = dim
+
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm, attn_head_dim=None
+ ):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
+ )
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ x = self.proj(x)
+ Hp, Wp = x.shape[2], x.shape[3]
+
+ x = x.flatten(2).transpose(1, 2)
+ return x, (Hp, Wp)
+
+
+class HybridEmbed(nn.Module):
+ """ CNN Feature Map Embedding
+ Extract feature map from CNN, flatten, project to embedding dim.
+ """
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+ super().__init__()
+ assert isinstance(backbone, nn.Module)
+ img_size = to_2tuple(img_size)
+ self.img_size = img_size
+ self.backbone = backbone
+ if feature_size is None:
+ with torch.no_grad():
+ training = backbone.training
+ if training:
+ backbone.eval()
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+ feature_size = o.shape[-2:]
+ feature_dim = o.shape[1]
+ backbone.train(training)
+ else:
+ feature_size = to_2tuple(feature_size)
+ feature_dim = self.backbone.feature_info.channels()[-1]
+ self.num_patches = feature_size[0] * feature_size[1]
+ self.proj = nn.Linear(feature_dim, embed_dim)
+
+ def forward(self, x):
+ x = self.backbone(x)[-1]
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+class ViT(nn.Module):
+
+ def __init__(self,
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
+ frozen_stages=-1, ratio=1, last_norm=True,
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
+ ):
+ # Protect mutable default arguments
+ super(ViT, self).__init__()
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.frozen_stages = frozen_stages
+ self.use_checkpoint = use_checkpoint
+ self.patch_padding = patch_padding
+ self.freeze_attn = freeze_attn
+ self.freeze_ffn = freeze_ffn
+ self.depth = depth
+
+ if hybrid_backbone is not None:
+ self.patch_embed = HybridEmbed(
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+ else:
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
+ num_patches = self.patch_embed.num_patches
+
+ # since the pretraining model has class token
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ )
+ for i in range(depth)])
+
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ """Freeze parameters."""
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = self.blocks[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ if self.freeze_attn:
+ for i in range(0, self.depth):
+ m = self.blocks[i]
+ m.attn.eval()
+ m.norm1.eval()
+ for param in m.attn.parameters():
+ param.requires_grad = False
+ for param in m.norm1.parameters():
+ param.requires_grad = False
+
+ if self.freeze_ffn:
+ self.pos_embed.requires_grad = False
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+ for i in range(0, self.depth):
+ m = self.blocks[i]
+ m.mlp.eval()
+ m.norm2.eval()
+ for param in m.mlp.parameters():
+ param.requires_grad = False
+ for param in m.norm2.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ """Initialize the weights in backbone.
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ self.apply(_init_weights)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward_features(self, x):
+ B, C, H, W = x.shape
+ x, (Hp, Wp) = self.patch_embed(x)
+
+ if self.pos_embed is not None:
+ # fit for multiple GPU training
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+
+ x = self.last_norm(x)
+
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
+
+ return xp
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return x
+
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ self._freeze_stages()
diff --git a/lib/models/components/__init__.py b/lib/models/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lib/models/components/__pycache__/__init__.cpython-310.pyc b/lib/models/components/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..305210c4ee998e41e6ecfd36a2321e54e24d6578
Binary files /dev/null and b/lib/models/components/__pycache__/__init__.cpython-310.pyc differ
diff --git a/lib/models/components/__pycache__/pose_transformer.cpython-310.pyc b/lib/models/components/__pycache__/pose_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50a53a44f68fd4da61ddefec37603449f76a1dcc
Binary files /dev/null and b/lib/models/components/__pycache__/pose_transformer.cpython-310.pyc differ
diff --git a/lib/models/components/__pycache__/t_cond_mlp.cpython-310.pyc b/lib/models/components/__pycache__/t_cond_mlp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d244f3e7957c1d3eb97a77ece03adf3cc0e94fe7
Binary files /dev/null and b/lib/models/components/__pycache__/t_cond_mlp.cpython-310.pyc differ
diff --git a/lib/models/components/pose_transformer.py b/lib/models/components/pose_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac04971407cb59637490cc4842f048b9bc4758be
--- /dev/null
+++ b/lib/models/components/pose_transformer.py
@@ -0,0 +1,358 @@
+from inspect import isfunction
+from typing import Callable, Optional
+
+import torch
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import nn
+
+from .t_cond_mlp import (
+ AdaptiveLayerNorm1D,
+ FrequencyEmbedder,
+ normalization_layer,
+)
+# from .vit import Attention, FeedForward
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+class PreNorm(nn.Module):
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
+ super().__init__()
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
+ self.fn = fn
+
+ def forward(self, x: torch.Tensor, *args, **kwargs):
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
+ return self.fn(self.norm(x, *args), **kwargs)
+ else:
+ return self.fn(self.norm(x), **kwargs)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout=0.0):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ self.dropout = nn.Dropout(dropout)
+
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+ self.to_out = (
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+ if project_out
+ else nn.Identity()
+ )
+
+ def forward(self, x):
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+ attn = self.attend(dots)
+ attn = self.dropout(attn)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ self.dropout = nn.Dropout(dropout)
+
+ context_dim = default(context_dim, dim)
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+
+ self.to_out = (
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+ if project_out
+ else nn.Identity()
+ )
+
+ def forward(self, x, context=None):
+ context = default(context, x)
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+ q = self.to_q(x)
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+ attn = self.attend(dots)
+ attn = self.dropout(attn)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ heads: int,
+ dim_head: int,
+ mlp_dim: int,
+ dropout: float = 0.0,
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor, *args):
+ for attn, ff in self.layers:
+ x = attn(x, *args) + x
+ x = ff(x, *args) + x
+ return x
+
+
+class TransformerCrossAttn(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ heads: int,
+ dim_head: int,
+ mlp_dim: int,
+ dropout: float = 0.0,
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ context_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+ ca = CrossAttention(
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
+ )
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
+ if context_list is None:
+ context_list = [context] * len(self.layers)
+ if len(context_list) != len(self.layers):
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
+
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
+ x = self_attn(x, *args) + x
+ x = cross_attn(x, *args, context=context_list[i]) + x
+ x = ff(x, *args) + x
+ return x
+
+
+class DropTokenDropout(nn.Module):
+ def __init__(self, p: float = 0.1):
+ super().__init__()
+ if p < 0 or p > 1:
+ raise ValueError(
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
+ )
+ self.p = p
+
+ def forward(self, x: torch.Tensor):
+ # x: (batch_size, seq_len, dim)
+ if self.training and self.p > 0:
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
+ # TODO: permutation idx for each batch using torch.argsort
+ if zero_mask.any():
+ x = x[:, ~zero_mask, :]
+ return x
+
+
+class ZeroTokenDropout(nn.Module):
+ def __init__(self, p: float = 0.1):
+ super().__init__()
+ if p < 0 or p > 1:
+ raise ValueError(
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
+ )
+ self.p = p
+
+ def forward(self, x: torch.Tensor):
+ # x: (batch_size, seq_len, dim)
+ if self.training and self.p > 0:
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
+ # Zero-out the masked tokens
+ x[zero_mask, :] = 0
+ return x
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self,
+ num_tokens: int,
+ token_dim: int,
+ dim: int,
+ depth: int,
+ heads: int,
+ mlp_dim: int,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ emb_dropout: float = 0.0,
+ emb_dropout_type: str = "drop",
+ emb_dropout_loc: str = "token",
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ token_pe_numfreq: int = -1,
+ ):
+ super().__init__()
+ if token_pe_numfreq > 0:
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
+ self.to_token_embedding = nn.Sequential(
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
+ nn.Linear(token_dim_new, dim),
+ )
+ else:
+ self.to_token_embedding = nn.Linear(token_dim, dim)
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
+ if emb_dropout_type == "drop":
+ self.dropout = DropTokenDropout(emb_dropout)
+ elif emb_dropout_type == "zero":
+ self.dropout = ZeroTokenDropout(emb_dropout)
+ else:
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
+ self.emb_dropout_loc = emb_dropout_loc
+
+ self.transformer = Transformer(
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
+ )
+
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
+ x = inp
+
+ if self.emb_dropout_loc == "input":
+ x = self.dropout(x)
+ x = self.to_token_embedding(x)
+
+ if self.emb_dropout_loc == "token":
+ x = self.dropout(x)
+ b, n, _ = x.shape
+ x += self.pos_embedding[:, :n]
+
+ if self.emb_dropout_loc == "token_afterpos":
+ x = self.dropout(x)
+ x = self.transformer(x, *args)
+ return x
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(
+ self,
+ num_tokens: int,
+ token_dim: int,
+ dim: int,
+ depth: int,
+ heads: int,
+ mlp_dim: int,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ emb_dropout: float = 0.0,
+ emb_dropout_type: str = 'drop',
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ context_dim: Optional[int] = None,
+ skip_token_embedding: bool = False,
+ ):
+ super().__init__()
+ if not skip_token_embedding:
+ self.to_token_embedding = nn.Linear(token_dim, dim)
+ else:
+ self.to_token_embedding = nn.Identity()
+ if token_dim != dim:
+ raise ValueError(
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
+ )
+
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
+ if emb_dropout_type == "drop":
+ self.dropout = DropTokenDropout(emb_dropout)
+ elif emb_dropout_type == "zero":
+ self.dropout = ZeroTokenDropout(emb_dropout)
+ elif emb_dropout_type == "normal":
+ self.dropout = nn.Dropout(emb_dropout)
+
+ self.transformer = TransformerCrossAttn(
+ dim,
+ depth,
+ heads,
+ dim_head,
+ mlp_dim,
+ dropout,
+ norm=norm,
+ norm_cond_dim=norm_cond_dim,
+ context_dim=context_dim,
+ )
+
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
+ x = self.to_token_embedding(inp)
+ b, n, _ = x.shape
+
+ x = self.dropout(x)
+ x += self.pos_embedding[:, :n]
+
+ x = self.transformer(x, *args, context=context, context_list=context_list)
+ return x
+
diff --git a/lib/models/components/t_cond_mlp.py b/lib/models/components/t_cond_mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d5a09bf54f67712a69953039b7b5af41c3f029
--- /dev/null
+++ b/lib/models/components/t_cond_mlp.py
@@ -0,0 +1,199 @@
+import copy
+from typing import List, Optional
+
+import torch
+
+
+class AdaptiveLayerNorm1D(torch.nn.Module):
+ def __init__(self, data_dim: int, norm_cond_dim: int):
+ super().__init__()
+ if data_dim <= 0:
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
+ if norm_cond_dim <= 0:
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
+ self.norm = torch.nn.LayerNorm(
+ data_dim
+ ) # TODO: Check if elementwise_affine=True is correct
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
+ torch.nn.init.zeros_(self.linear.weight)
+ torch.nn.init.zeros_(self.linear.bias)
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ # x: (batch, ..., data_dim)
+ # t: (batch, norm_cond_dim)
+ # return: (batch, data_dim)
+ x = self.norm(x)
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
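+ # FiLM-style conditioning: the normalized features are scaled and shifted as x * (1 + alpha) + beta, with alpha and beta predicted from the conditioning vector t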
+
+ # Add singleton dimensions to alpha and beta
+ if x.dim() > 2:
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
+
+ return x * (1 + alpha) + beta
+
+
+class SequentialCond(torch.nn.Sequential):
+ def forward(self, input, *args, **kwargs):
+ for module in self:
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
+ # print(f'Passing on args to {module}', [a.shape for a in args])
+ input = module(input, *args, **kwargs)
+ else:
+ # print(f'Skipping passing args to {module}', [a.shape for a in args])
+ input = module(input)
+ return input
+
+
+def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
+ if norm == "batch":
+ return torch.nn.BatchNorm1d(dim)
+ elif norm == "layer":
+ return torch.nn.LayerNorm(dim)
+ elif norm == "ada":
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
+ elif norm is None:
+ return torch.nn.Identity()
+ else:
+ raise ValueError(f"Unknown norm: {norm}")
+
+
+def linear_norm_activ_dropout(
+ input_dim: int,
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ norm_cond_dim: int = -1,
+) -> SequentialCond:
+ layers = []
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
+ if norm is not None:
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
+ layers.append(copy.deepcopy(activation))
+ if dropout > 0.0:
+ layers.append(torch.nn.Dropout(dropout))
+ return SequentialCond(*layers)
+
+
+def create_simple_mlp(
+ input_dim: int,
+ hidden_dims: List[int],
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ norm_cond_dim: int = -1,
+) -> SequentialCond:
+ layers = []
+ prev_dim = input_dim
+ for hidden_dim in hidden_dims:
+ layers.extend(
+ linear_norm_activ_dropout(
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
+ )
+ )
+ prev_dim = hidden_dim
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
+ return SequentialCond(*layers)
+
+
+class ResidualMLPBlock(torch.nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ num_hidden_layers: int,
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ norm_cond_dim: int = -1,
+ ):
+ super().__init__()
+ if not (input_dim == output_dim == hidden_dim):
+ raise NotImplementedError(
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
+ )
+
+ layers = []
+ prev_dim = input_dim
+ for i in range(num_hidden_layers):
+ layers.append(
+ linear_norm_activ_dropout(
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
+ )
+ )
+ prev_dim = hidden_dim
+ self.model = SequentialCond(*layers)
+ self.skip = torch.nn.Identity()
+
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ return x + self.model(x, *args, **kwargs)
+
+
+class ResidualMLP(torch.nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ num_hidden_layers: int,
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ num_blocks: int = 1,
+ norm_cond_dim: int = -1,
+ ):
+ super().__init__()
+ self.input_dim = input_dim
+ self.model = SequentialCond(
+ linear_norm_activ_dropout(
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
+ ),
+ *[
+ ResidualMLPBlock(
+ hidden_dim,
+ hidden_dim,
+ num_hidden_layers,
+ hidden_dim,
+ activation,
+ bias,
+ norm,
+ dropout,
+ norm_cond_dim,
+ )
+ for _ in range(num_blocks)
+ ],
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
+ )
+
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ return self.model(x, *args, **kwargs)
+
+
+class FrequencyEmbedder(torch.nn.Module):
+ def __init__(self, num_frequencies, max_freq_log2):
+ super().__init__()
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
+ self.register_buffer("frequencies", frequencies)
+
+ def forward(self, x):
+ # x should be of size (N,) or (N, D)
+ N = x.size(0)
+ if x.dim() == 1: # (N,)
+ x = x.unsqueeze(1) # (N, D) where D=1
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
+ s = torch.sin(scaled)
+ c = torch.cos(scaled)
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
+ N, -1
+ ) # (N, D * 2 * num_frequencies + D)
+ return embedded
+
diff --git a/lib/models/hawor.py b/lib/models/hawor.py
new file mode 100644
index 0000000000000000000000000000000000000000..34a081cd63d0a3e3be894fda21375683f3e61d51
--- /dev/null
+++ b/lib/models/hawor.py
@@ -0,0 +1,527 @@
+import einops
+import numpy as np
+import torch
+import pytorch_lightning as pl
+from typing import Dict
+from torchvision.utils import make_grid
+
+from tqdm import tqdm
+from yacs.config import CfgNode
+
+from lib.datasets.track_dataset import TrackDatasetEval
+from lib.models.modules import MANOTransformerDecoderHead, temporal_attention
+from hawor.utils.pylogger import get_pylogger
+from hawor.utils.render_openpose import render_openpose
+from lib.utils.geometry import rot6d_to_rotmat_hmr2 as rot6d_to_rotmat
+from lib.utils.geometry import perspective_projection
+from hawor.utils.rotation import angle_axis_to_rotation_matrix
+from torch.utils.data import default_collate
+
+from .backbones import create_backbone
+from .mano_wrapper import MANO
+
+
+log = get_pylogger(__name__)
+idx = 0
+
+class HAWOR(pl.LightningModule):
+
+ def __init__(self, cfg: CfgNode):
+ """
+ Setup HAWOR model
+ Args:
+ cfg (CfgNode): Config file as a yacs CfgNode
+ """
+ super().__init__()
+
+ # Save hyperparameters
+ self.save_hyperparameters(logger=False, ignore=['init_renderer'])
+
+ self.cfg = cfg
+ self.crop_size = cfg.MODEL.IMAGE_SIZE
+ self.seq_len = 16
+ self.pose_num = 16
+ self.pose_dim = 6 # rot6d representation
+ self.box_info_dim = 3
+
+ # Create backbone feature extractor
+ self.backbone = create_backbone(cfg)
+ try:
+ if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
+ whole_state_dict = torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict']
+ backbone_state_dict = {}
+ for key in whole_state_dict:
+ if key[:9] == 'backbone.':
+ backbone_state_dict[key[9:]] = whole_state_dict[key]
+ self.backbone.load_state_dict(backbone_state_dict)
+ print(f'Loaded backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
+ for param in self.backbone.parameters():
+ param.requires_grad = False
+ else:
+ print('WARNING: init backbone from scratch !!!')
+ except:
+ print('WARNING: init backbone from scratch !!!')
+
+ # Space-time memory
+ if cfg.MODEL.ST_MODULE:
+ hdim = cfg.MODEL.ST_HDIM
+ nlayer = cfg.MODEL.ST_NLAYER
+ self.st_module = temporal_attention(in_dim=1280+3,
+ out_dim=1280,
+ hdim=hdim,
+ nlayer=nlayer,
+ residual=True)
+ print(f'Using Temporal Attention space-time: {nlayer} layers {hdim} dim.')
+ else:
+ self.st_module = None
+
+ # Motion memory
+ if cfg.MODEL.MOTION_MODULE:
+ hdim = cfg.MODEL.MOTION_HDIM
+ nlayer = cfg.MODEL.MOTION_NLAYER
+
+ self.motion_module = temporal_attention(in_dim=self.pose_num * self.pose_dim + self.box_info_dim,
+ out_dim=self.pose_num * self.pose_dim,
+ hdim=hdim,
+ nlayer=nlayer,
+ residual=False)
+ print(f'Using Temporal Attention motion layer: {nlayer} layers {hdim} dim.')
+ else:
+ self.motion_module = None
+
+ # Create MANO head
+ # self.mano_head = build_mano_head(cfg)
+ self.mano_head = MANOTransformerDecoderHead(cfg)
+
+
+ # default open torch compile
+ if cfg.MODEL.BACKBONE.get('TORCH_COMPILE', 0):
+ log.info("Model will use torch.compile")
+ self.backbone = torch.compile(self.backbone)
+ self.mano_head = torch.compile(self.mano_head)
+
+ # Define loss functions
+ # self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
+ # self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
+ # self.mano_parameter_loss = ParameterLoss()
+
+ # Instantiate MANO model
+ mano_cfg = {k.lower(): v for k,v in dict(cfg.MANO).items()}
+ self.mano = MANO(**mano_cfg)
+
+ # Buffer that shows whether we need to initialize ActNorm layers
+ self.register_buffer('initialized', torch.tensor(False))
+
+ # Disable automatic optimization since we use adversarial training
+ self.automatic_optimization = False
+
+ if cfg.MODEL.get('LOAD_WEIGHTS', None):
+ whole_state_dict = torch.load(cfg.MODEL.LOAD_WEIGHTS, map_location='cpu')['state_dict']
+ self.load_state_dict(whole_state_dict, strict=True)
+ print(f"load {cfg.MODEL.LOAD_WEIGHTS}")
+
+ def get_parameters(self):
+ all_params = list(self.mano_head.parameters())
+ if not self.st_module is None:
+ all_params += list(self.st_module.parameters())
+ if not self.motion_module is None:
+ all_params += list(self.motion_module.parameters())
+ all_params += list(self.backbone.parameters())
+ return all_params
+
+ def configure_optimizers(self) -> torch.optim.Optimizer:
+ """
+ Setup the model optimizer.
+ Returns:
+ torch.optim.Optimizer: Model optimizer
+ """
+ param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]
+
+ optimizer = torch.optim.AdamW(params=param_groups,
+ # lr=self.cfg.TRAIN.LR,
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
+ return optimizer
+
+ def forward_step(self, batch: Dict, train: bool = False) -> Dict:
+ """
+ Run a forward step of the network
+ Args:
+ batch (Dict): Dictionary containing batch data
+ train (bool): Flag indicating whether it is training or validation mode
+ Returns:
+ Dict: Dictionary containing the regression output
+ """
+
+ image = batch['img'].flatten(0, 1)
+ center = batch['center'].flatten(0, 1)
+ scale = batch['scale'].flatten(0, 1)
+ img_focal = batch['img_focal'].flatten(0, 1)
+ img_center = batch['img_center'].flatten(0, 1)
+ bn = len(image)
+
+ # compute the CLIFF-style bbox feature from crop center/scale and camera intrinsics
+ bbox_info = self.bbox_est(center, scale, img_focal, img_center)
+
+ # backbone
+ feature = self.backbone(image[:,:,:,32:-32])
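+ # Note: assuming a 256x256 crop (cfg.MODEL.IMAGE_SIZE), trimming 32 px on each side
+ # gives a 256x192 input, matching the 16x12 token grid used in the rearrange below.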
+ feature = feature.float()
+
+ # space-time module
+ if self.st_module is not None:
+ bb = einops.repeat(bbox_info, 'b c -> b c h w', h=16, w=12)
+ feature = torch.cat([feature, bb], dim=1)
+
+ feature = einops.rearrange(feature, '(b t) c h w -> (b h w) t c', t=16)
+ feature = self.st_module(feature)
+ feature = einops.rearrange(feature, '(b h w) t c -> (b t) c h w', h=16, w=12)
+
+ # mano_head: transformer decoder + MANO parameter regression
+ # pred_mano_params, pred_cam, pred_mano_params_list = self.mano_head(feature)
+ # pred_shape = pred_mano_params_list['pred_shape']
+ # pred_pose = pred_mano_params_list['pred_pose']
+ pred_pose, pred_shape, pred_cam = self.mano_head(feature)
+ pred_rotmat_0 = rot6d_to_rotmat(pred_pose).reshape(-1, self.pose_num, 3, 3)
+
+ # smpl motion module
+ if self.motion_module is not None:
+ bb = einops.rearrange(bbox_info, '(b t) c -> b t c', t=16)
+ pred_pose = einops.rearrange(pred_pose, '(b t) c -> b t c', t=16)
+ pred_pose = torch.cat([pred_pose, bb], dim=2)
+
+ pred_pose = self.motion_module(pred_pose)
+ pred_pose = einops.rearrange(pred_pose, 'b t c -> (b t) c')
+
+ out = {}
+ if 'do_flip' in batch:
+ pred_cam[..., 1] *= -1
+ center[..., 0] = img_center[..., 0]*2 - center[..., 0] - 1
+ out['pred_cam'] = pred_cam
+ out['pred_pose'] = pred_pose
+ out['pred_shape'] = pred_shape
+ out['pred_rotmat'] = rot6d_to_rotmat(out['pred_pose']).reshape(-1, self.pose_num, 3, 3)
+ out['pred_rotmat_0'] = pred_rotmat_0
+
+ s_out = self.mano.query(out)
+ j3d = s_out.joints
+ j2d = self.project(j3d, out['pred_cam'], center, scale, img_focal, img_center)
+ j2d = j2d / self.crop_size - 0.5 # norm to [-0.5, 0.5]
+
+ trans_full = self.get_trans(out['pred_cam'], center, scale, img_focal, img_center)
+ out['trans_full'] = trans_full
+
+ output = {
+ 'pred_mano_params': {
+ 'global_orient': out['pred_rotmat'][:, :1].clone(),
+ 'hand_pose': out['pred_rotmat'][:, 1:].clone(),
+ 'betas': out['pred_shape'].clone(),
+ },
+ 'pred_keypoints_3d': j3d.clone(),
+ 'pred_keypoints_2d': j2d.clone(),
+ 'out': out,
+ }
+ # print(output)
+ # output['gt_project_j2d'] = self.project(batch['gt_j3d_wo_trans'].clone().flatten(0,1), out['pred_cam'], center, scale, img_focal, img_center)
+ # output['gt_project_j2d'] = output['gt_project_j2d'] / self.crop_size - 0.5
+
+
+ return output
+
+ def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
+ """
+ Compute losses given the input batch and the regression output
+ Args:
+ batch (Dict): Dictionary containing batch data
+ output (Dict): Dictionary containing the regression output
+ train (bool): Flag indicating whether it is training or validation mode
+ Returns:
+ torch.Tensor : Total loss for current batch
+ """
+
+ pred_mano_params = output['pred_mano_params']
+ pred_keypoints_2d = output['pred_keypoints_2d']
+ pred_keypoints_3d = output['pred_keypoints_3d']
+
+
+ batch_size = pred_mano_params['hand_pose'].shape[0]
+ device = pred_mano_params['hand_pose'].device
+ dtype = pred_mano_params['hand_pose'].dtype
+
+ # Get annotations
+ gt_keypoints_2d = batch['gt_cam_j2d'].flatten(0, 1)
+ gt_keypoints_2d = torch.cat([gt_keypoints_2d, torch.ones(*gt_keypoints_2d.shape[:-1], 1, device=gt_keypoints_2d.device)], dim=-1)
+ gt_keypoints_3d = batch['gt_j3d_wo_trans'].flatten(0, 1)
+ gt_keypoints_3d = torch.cat([gt_keypoints_3d, torch.ones(*gt_keypoints_3d.shape[:-1], 1, device=gt_keypoints_3d.device)], dim=-1)
+ pose_gt = batch['gt_cam_full_pose'].flatten(0, 1).reshape(-1, 16, 3)
+ rotmat_gt = angle_axis_to_rotation_matrix(pose_gt)
+ gt_mano_params = {
+ 'global_orient': rotmat_gt[:, :1],
+ 'hand_pose': rotmat_gt[:, 1:],
+ 'betas': batch['gt_cam_betas'],
+ }
+
+ # Compute 3D keypoint loss
+ loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
+ loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
+
+ # to avoid nan
+ loss_keypoints_2d = torch.nan_to_num(loss_keypoints_2d)
+
+ # Compute loss on MANO parameters
+ loss_mano_params = {}
+ for k, pred in pred_mano_params.items():
+ gt = gt_mano_params[k].view(batch_size, -1)
+ loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1))
+
+ loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d+\
+ self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d+\
+ sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])
+
+ losses = dict(loss=loss.detach(),
+ loss_keypoints_2d=loss_keypoints_2d.detach() * self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'],
+ loss_keypoints_3d=loss_keypoints_3d.detach() * self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'])
+
+ for k, v in loss_mano_params.items():
+ losses['loss_' + k] = v.detach() * self.cfg.LOSS_WEIGHTS[k.upper()]
+
+ output['losses'] = losses
+
+ return loss
+
+ # Tensorboard logging should run from the first rank only
+ @pl.utilities.rank_zero.rank_zero_only
+ def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True, render_log: bool = True) -> None:
+ """
+ Log results to Tensorboard
+ Args:
+ batch (Dict): Dictionary containing batch data
+ output (Dict): Dictionary containing the regression output
+ step_count (int): Global training step count
+ train (bool): Flag indicating whether it is training or validation mode
+ """
+
+ mode = 'train' if train else 'val'
+ batch_size = output['pred_keypoints_2d'].shape[0]
+ images = batch['img'].flatten(0,1)
+ images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1,3,1,1)
+ images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1)
+
+ losses = output['losses']
+ if write_to_summary_writer:
+ summary_writer = self.logger.experiment
+ for loss_name, val in losses.items():
+ summary_writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count)
+
+ if render_log:
+ gt_keypoints_2d = batch['gt_cam_j2d'].flatten(0,1).clone()
+ pred_keypoints_2d = output['pred_keypoints_2d'].clone().detach().reshape(batch_size, -1, 2)
+ gt_project_j2d = pred_keypoints_2d
+ # gt_project_j2d = output['gt_project_j2d'].clone().detach().reshape(batch_size, -1, 2)
+
+ num_images = 4
+ skip=16
+
+ predictions = self.visualize_tensorboard(images[:num_images*skip:skip].cpu().numpy(),
+ pred_keypoints_2d[:num_images*skip:skip].cpu().numpy(),
+ gt_project_j2d[:num_images*skip:skip].cpu().numpy(),
+ gt_keypoints_2d[:num_images*skip:skip].cpu().numpy(),
+ )
+ summary_writer.add_image('%s/predictions' % mode, predictions, step_count)
+
+
+ def forward(self, batch: Dict) -> Dict:
+ """
+ Run a forward step of the network in val mode
+ Args:
+ batch (Dict): Dictionary containing batch data
+ Returns:
+ Dict: Dictionary containing the regression output
+ """
+ return self.forward_step(batch, train=False)
+
+ def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
+ """
+ Run a full training step
+ Args:
+ joint_batch (Dict): Dictionary containing image and mocap batch data
+ batch_idx (int): Unused.
+ Returns:
+ Dict: Dictionary containing regression output.
+ """
+ batch = joint_batch['img']
+ optimizer = self.optimizers(use_pl_optimizer=True)
+
+ batch_size = batch['img'].shape[0]
+ output = self.forward_step(batch, train=True)
+ # pred_mano_params = output['pred_mano_params']
+ loss = self.compute_loss(batch, output, train=True)
+
+ # Error if Nan
+ if torch.isnan(loss):
+ raise ValueError('Loss is NaN')
+
+ optimizer.zero_grad()
+ self.manual_backward(loss)
+ # Clip gradient
+ if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
+ gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
+ self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
+ optimizer.step()
+
+ # if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
+ if self.global_step > 0 and self.global_step % 100 == 0:
+ self.tensorboard_logging(batch, output, self.global_step, train=True, render_log=self.cfg.TRAIN.get("RENDER_LOG", True))
+
+ self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=batch_size)
+
+ return output
+
+ def inference(self, imgfiles, boxes, img_focal, img_center, device='cuda', do_flip=False):
+ db = TrackDatasetEval(imgfiles, boxes, img_focal=img_focal,
+ img_center=img_center, normalization=True, dilate=1.2, do_flip=do_flip)
+
+ # Results
+ pred_cam = []
+ pred_pose = []
+ pred_shape = []
+ pred_rotmat = []
+ pred_trans = []
+
+ # TODO: efficient batched implementation
+ items = []
+ for i in tqdm(range(len(db))):
+ item = db[i]
+ items.append(item)
+
+ # padding to 16
+ if i == len(db) - 1 and len(db) % 16 != 0:
+ pad = 16 - len(db) % 16
+ for _ in range(pad):
+ items.append(item)
+
+ if len(items) < 16:
+ continue
+ elif len(items) == 16:
+ batch = default_collate(items)
+ items = []
+ else:
+ raise NotImplementedError
+
+ with torch.no_grad():
+ batch = {k: v.to(device).unsqueeze(0) for k, v in batch.items() if type(v)==torch.Tensor}
+ # for image_i in range(16):
+ # hawor_input_cv2 = vis_tensor_cv2(batch['img'][:, image_i])
+ # cv2.imwrite(f'debug_vis_model.png', hawor_input_cv2)
+ # print("vis")
+ output = self.forward(batch)
+ out = output['out']
+
+ if i == len(db) - 1 and len(db) % 16 != 0:
+ out = {k:v[:len(db) % 16] for k,v in out.items()}
+ else:
+ out = {k:v for k,v in out.items()}
+
+ pred_cam.append(out['pred_cam'].cpu())
+ pred_pose.append(out['pred_pose'].cpu())
+ pred_shape.append(out['pred_shape'].cpu())
+ pred_rotmat.append(out['pred_rotmat'].cpu())
+ pred_trans.append(out['trans_full'].cpu())
+
+
+ results = {'pred_cam': torch.cat(pred_cam),
+ 'pred_pose': torch.cat(pred_pose),
+ 'pred_shape': torch.cat(pred_shape),
+ 'pred_rotmat': torch.cat(pred_rotmat),
+ 'pred_trans': torch.cat(pred_trans),
+ 'img_focal': img_focal,
+ 'img_center': img_center}
+
+ return results
+
+ def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
+ """
+ Run a validation step and log to Tensorboard
+ Args:
+ batch (Dict): Dictionary containing batch data
+ batch_idx (int): Unused.
+ Returns:
+ Dict: Dictionary containing regression output.
+ """
+ # batch_size = batch['img'].shape[0]
+ output = self.forward_step(batch, train=False)
+ loss = self.compute_loss(batch, output, train=False)
+ output['loss'] = loss
+ self.tensorboard_logging(batch, output, self.global_step, train=False)
+
+ return output
+
+ def visualize_tensorboard(self, images, pred_keypoints, gt_project_j2d, gt_keypoints):
+ pred_keypoints = 256 * (pred_keypoints + 0.5)
+ gt_keypoints = 256 * (gt_keypoints + 0.5)
+ gt_project_j2d = 256 * (gt_project_j2d + 0.5)
+ pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
+ gt_keypoints = np.concatenate((gt_keypoints, np.ones_like(gt_keypoints)[:, :, [0]]), axis=-1)
+ gt_project_j2d = np.concatenate((gt_project_j2d, np.ones_like(gt_project_j2d)[:, :, [0]]), axis=-1)
+ images_np = np.transpose(images, (0,2,3,1))
+ rend_imgs = []
+ for i in range(images_np.shape[0]):
+ pred_keypoints_img = render_openpose(255 * images_np[i].copy(), pred_keypoints[i]) / 255
+ gt_project_j2d_img = render_openpose(255 * images_np[i].copy(), gt_project_j2d[i]) / 255
+ gt_keypoints_img = render_openpose(255*images_np[i].copy(), gt_keypoints[i]) / 255
+ rend_imgs.append(torch.from_numpy(images[i]))
+ rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2,0,1))
+ rend_imgs.append(torch.from_numpy(gt_project_j2d_img).permute(2,0,1))
+ rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2,0,1))
+ rend_imgs = make_grid(rend_imgs, nrow=4, padding=2)
+ return rend_imgs
+
+ def project(self, points, pred_cam, center, scale, img_focal, img_center, return_full=False):
+
+ trans_full = self.get_trans(pred_cam, center, scale, img_focal, img_center)
+
+ # Projection in full frame image coordinate
+ points = points + trans_full
+ points2d_full = perspective_projection(points, rotation=None, translation=None,
+ focal_length=img_focal, camera_center=img_center)
+
+ # Adjust projected points to crop image coordinate
+ # (s.t. 1. we can calculate loss in crop image easily
+ # 2. we can query its pixel in the crop
+ # )
+ b = scale * 200
+ points2d = points2d_full - (center - b[:,None]/2)[:,None,:]
+ points2d = points2d * (self.crop_size / b)[:,None,None]
+
+ if return_full:
+ return points2d_full, points2d
+ else:
+ return points2d
+
+ def get_trans(self, pred_cam, center, scale, img_focal, img_center):
+ b = scale * 200
+ cx, cy = center[:,0], center[:,1] # center of crop
+ s, tx, ty = pred_cam.unbind(-1)
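+ # pred_cam is a weak-perspective camera in the crop: scale s and 2D offset (tx, ty).
+ # It is converted to a full-image translation: depth tz = 2*f/(b*s), with (tx, ty)
+ # shifted by the crop-center offset from the image center, normalized by b*s.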
+
+ img_cx, img_cy = img_center[:,0], img_center[:,1] # center of original image
+
+ bs = b*s
+ tx_full = tx + 2*(cx-img_cx)/bs
+ ty_full = ty + 2*(cy-img_cy)/bs
+ tz_full = 2*img_focal/bs
+
+ trans_full = torch.stack([tx_full, ty_full, tz_full], dim=-1)
+ trans_full = trans_full.unsqueeze(1)
+
+ return trans_full
+
+ def bbox_est(self, center, scale, img_focal, img_center):
+ # Original image center
+ img_cx, img_cy = img_center[:,0], img_center[:,1]
+
+ # Implement CLIFF (Li et al.) bbox feature
+ cx, cy, b = center[:, 0], center[:, 1], scale * 200
+ bbox_info = torch.stack([cx - img_cx, cy - img_cy, b], dim=-1)
+ bbox_info[:, :2] = bbox_info[:, :2] / img_focal.unsqueeze(-1) * 2.8
+ bbox_info[:, 2] = (bbox_info[:, 2] - 0.24 * img_focal) / (0.06 * img_focal)
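+ # The constants follow the CLIFF normalization: center offsets and box size are
+ # expressed relative to the focal length so the feature lies roughly in [-1, 1].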
+
+ return bbox_info
diff --git a/lib/models/mano_wrapper.py b/lib/models/mano_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4801c26e3410035db3d09f8dac70c637dd7b00f5
--- /dev/null
+++ b/lib/models/mano_wrapper.py
@@ -0,0 +1,52 @@
+import torch
+import numpy as np
+import pickle
+from typing import Optional
+import smplx
+from smplx.lbs import vertices2joints
+from smplx.utils import MANOOutput, to_tensor
+from smplx.vertex_ids import vertex_ids
+
+
+class MANO(smplx.MANOLayer):
+ def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
+ """
+ Extension of the official MANO implementation to support more joints.
+ Args:
+ Same as MANOLayer.
+ joint_regressor_extra (str): Path to extra joint regressor.
+ """
+ super(MANO, self).__init__(*args, **kwargs)
+ mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
+
+ #2, 3, 5, 4, 1
+ if joint_regressor_extra is not None:
+ self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
+ self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
+ self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))
+
+ def forward(self, *args, **kwargs) -> MANOOutput:
+ """
+ Run a forward pass. Same as the base MANO layer, but also appends an extra set of joints if joint_regressor_extra is specified.
+ """
+ mano_output = super(MANO, self).forward(*args, **kwargs)
+ extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs)
+ joints = torch.cat([mano_output.joints, extra_joints], dim=1)
+ joints = joints[:, self.joint_map, :]
+ if hasattr(self, 'joint_regressor_extra'):
+ extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices)
+ joints = torch.cat([joints, extra_joints], dim=1)
+ mano_output.joints = joints
+ return mano_output
+
+ def query(self, hmr_output):
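+ # Build a MANO output from the network predictions: of the 16 predicted rotation
+ # matrices, the first is the global orientation and the remaining 15 are the hand
+ # articulation; pose2rot=False because the inputs are already rotation matrices.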
+ batch_size = hmr_output['pred_rotmat'].shape[0]
+ pred_rotmat = hmr_output['pred_rotmat'].reshape(batch_size, -1, 3, 3)
+ pred_shape = hmr_output['pred_shape'].reshape(batch_size, 10)
+
+ mano_output = self(global_orient=pred_rotmat[:, [0]],
+ hand_pose = pred_rotmat[:, 1:],
+ betas = pred_shape,
+ pose2rot=False)
+
+ return mano_output
\ No newline at end of file
diff --git a/lib/models/modules.py b/lib/models/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0adbef1ddfb6f15e8cf0df6b99f924f9c409d6f
--- /dev/null
+++ b/lib/models/modules.py
@@ -0,0 +1,133 @@
+import numpy as np
+import einops
+import torch
+import torch.nn as nn
+from .components.pose_transformer import TransformerDecoder
+
+if torch.cuda.is_available():
+ autocast = torch.cuda.amp.autocast
+ # print('Using autocast')
+else:
+ # dummy autocast context manager for CPU-only setups (or PyTorch < 1.6)
+ class autocast:
+ def __init__(self, enabled=True):
+ pass
+ def __enter__(self):
+ pass
+ def __exit__(self, *args):
+ pass
+
+class MANOTransformerDecoderHead(nn.Module):
+ """ HMR2 Cross-attention based SMPL Transformer decoder
+ """
+ def __init__(self, cfg):
+ super().__init__()
+ transformer_args = dict(
+ depth = 6, # originally 6
+ heads = 8,
+ mlp_dim = 1024,
+ dim_head = 64,
+ dropout = 0.0,
+ emb_dropout = 0.0,
+ norm = "layer",
+ context_dim = 1280,
+ num_tokens = 1,
+ token_dim = 1,
+ dim = 1024
+ )
+ self.transformer = TransformerDecoder(**transformer_args)
+
+ dim = 1024
+ npose = 16*6
+ self.decpose = nn.Linear(dim, npose)
+ self.decshape = nn.Linear(dim, 10)
+ self.deccam = nn.Linear(dim, 3)
+ nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
+ nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
+ nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
+
+ mean_params = np.load(cfg.MANO.MEAN_PARAMS)
+ init_hand_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
+ init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
+ init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
+ self.register_buffer('init_hand_pose', init_hand_pose)
+ self.register_buffer('init_betas', init_betas)
+ self.register_buffer('init_cam', init_cam)
+
+
+ def forward(self, x, **kwargs):
+
+ batch_size = x.shape[0]
+ # vit pretrained backbone is channel-first. Change to token-first
+ x = einops.rearrange(x, 'b c h w -> b (h w) c')
+
+ init_hand_pose = self.init_hand_pose.expand(batch_size, -1)
+ init_betas = self.init_betas.expand(batch_size, -1)
+ init_cam = self.init_cam.expand(batch_size, -1)
+
+ # Pass through transformer
+ token = torch.zeros(batch_size, 1, 1).to(x.device)
+ token_out = self.transformer(token, context=x)
+ token_out = token_out.squeeze(1) # (B, C)
+
+ # Readout from token_out
+ pred_pose = self.decpose(token_out) + init_hand_pose
+ pred_shape = self.decshape(token_out) + init_betas
+ pred_cam = self.deccam(token_out) + init_cam
+
+ return pred_pose, pred_shape, pred_cam
+
+
+
+class temporal_attention(nn.Module):
+ def __init__(self, in_dim=1280, out_dim=1280, hdim=512, nlayer=6, nhead=4, residual=False):
+ super(temporal_attention, self).__init__()
+ self.hdim = hdim
+ self.out_dim = out_dim
+ self.residual = residual
+ self.l1 = nn.Linear(in_dim, hdim)
+ self.l2 = nn.Linear(hdim, out_dim)
+
+ self.pos_embedding = PositionalEncoding(hdim, dropout=0.1)
+ TranLayer = nn.TransformerEncoderLayer(d_model=hdim, nhead=nhead, dim_feedforward=1024,
+ dropout=0.1, activation='gelu')
+ self.trans = nn.TransformerEncoder(TranLayer, num_layers=nlayer)
+
+ nn.init.xavier_uniform_(self.l1.weight, gain=0.01)
+ nn.init.xavier_uniform_(self.l2.weight, gain=0.01)
+
+ def forward(self, x):
+ x = x.permute(1,0,2) # (b,t,c) -> (t,b,c)
+
+ h = self.l1(x)
+ h = self.pos_embedding(h)
+ h = self.trans(h)
+ h = self.l2(h)
+
+ if self.residual:
+ x = x[..., :self.out_dim] + h
+ else:
+ x = h
+ x = x.permute(1,0,2)
+
+ return x
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, dropout=0.1, max_len=100):
+ super(PositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0).transpose(0, 1)
+
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ # not used in the final model
+ x = x + self.pe[:x.shape[0], :]
+ return self.dropout(x)
diff --git a/lib/pipeline/__init__.py b/lib/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lib/pipeline/__pycache__/__init__.cpython-310.pyc b/lib/pipeline/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b0ca428c1966c571d69b2796cd2c187dc66446f
Binary files /dev/null and b/lib/pipeline/__pycache__/__init__.cpython-310.pyc differ
diff --git a/lib/pipeline/__pycache__/est_scale.cpython-310.pyc b/lib/pipeline/__pycache__/est_scale.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc53b52d4d22b1a1a17175a5f62c04b706cf276a
Binary files /dev/null and b/lib/pipeline/__pycache__/est_scale.cpython-310.pyc differ
diff --git a/lib/pipeline/__pycache__/masked_droid_slam.cpython-310.pyc b/lib/pipeline/__pycache__/masked_droid_slam.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0f93b76370783d2774c468d83be5d415b749115
Binary files /dev/null and b/lib/pipeline/__pycache__/masked_droid_slam.cpython-310.pyc differ
diff --git a/lib/pipeline/__pycache__/tools.cpython-310.pyc b/lib/pipeline/__pycache__/tools.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43b0b6082146b5ba984551e79ef8fe27a66f91f6
Binary files /dev/null and b/lib/pipeline/__pycache__/tools.cpython-310.pyc differ
diff --git a/lib/pipeline/est_scale.py b/lib/pipeline/est_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..780d7a1e2dd4d7b2680e49ba13e3a7feb8cf99c5
--- /dev/null
+++ b/lib/pipeline/est_scale.py
@@ -0,0 +1,195 @@
+import numpy as np
+import cv2
+import torch
+from torchmin import minimize
+
+
+def est_scale_iterative(slam_depth, pred_depth, iters=10, msk=None):
+ """ Simple depth-align by iterative median and thresholding """
+ s = pred_depth / slam_depth
+
+ if msk is None:
+ msk = np.zeros_like(pred_depth)
+ else:
+ msk = cv2.resize(msk, (pred_depth.shape[1], pred_depth.shape[0]))
+
+ robust = (msk<0.5) * (0 4:
+ image = cv2.undistort(image, K, calib[4:])
+
+ h0, w0, _ = image.shape
+ h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0)))
+ w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0)))
+
+ image = cv2.resize(image, (w1, h1))
+ image = image[:h1-h1%8, :w1-w1%8]
+ image = torch.as_tensor(image).permute(2, 0, 1)
+
+ intrinsics = torch.as_tensor([fx, fy, cx, cy])
+ intrinsics[0::2] *= (w1 / w0)
+ intrinsics[1::2] *= (h1 / h0)
+
+ yield t, image[None], intrinsics
+
+
+def run_slam(imagedir, masks, calib=None, depth=None, stride=1,
+ filter_thresh=2.4, disable_vis=True):
+ """ Maksed DROID-SLAM """
+ droid = None
+ depth = None
+ args.filter_thresh = filter_thresh
+ args.disable_vis = disable_vis
+ masks = masks[::stride]
+
+ img_msks, conf_msks = preprocess_masks(imagedir, masks)
+ if calib is None:
+ calib = est_calib(imagedir)
+
+ for (t, image, intrinsics) in tqdm(image_stream(imagedir, calib, stride)):
+
+ if droid is None:
+ args.image_size = [image.shape[2], image.shape[3]]
+ droid = Droid(args)
+
+ img_msk = img_msks[t]
+ conf_msk = conf_msks[t]
+ image = image * (img_msk < 0.5)
+ # cv2.imwrite('debug.png', image[0].permute(1, 2, 0).numpy())
+
+ droid.track(t, image, intrinsics=intrinsics, depth=depth, mask=conf_msk)
+
+ traj = droid.terminate(image_stream(imagedir, calib, stride))
+
+ return droid, traj
+
+def run_droid_slam(imagedir, calib=None, depth=None, stride=1,
+ filter_thresh=2.4, disable_vis=True):
+ """ Maksed DROID-SLAM """
+ droid = None
+ depth = None
+ args.filter_thresh = filter_thresh
+ args.disable_vis = disable_vis
+
+ if calib is None:
+ calib = est_calib(imagedir)
+
+ for (t, image, intrinsics) in tqdm(image_stream(imagedir, calib, stride)):
+
+ if droid is None:
+ args.image_size = [image.shape[2], image.shape[3]]
+ droid = Droid(args)
+
+ droid.track(t, image, intrinsics=intrinsics, depth=depth)
+
+ traj = droid.terminate(image_stream(imagedir, calib, stride))
+
+ return droid, traj
+
+
+def eval_slam(traj_est, cam_t, cam_q, return_traj=True, correct_scale=False, align=True, align_origin=False):
+ """ Evaluation for SLAM """
+ tstamps = np.array([i for i in range(len(traj_est))], dtype=np.float32)
+
+ traj_est = PoseTrajectory3D(
+ positions_xyz=traj_est[:,:3],
+ orientations_quat_wxyz=traj_est[:,3:],
+ timestamps=tstamps)
+
+ traj_ref = PoseTrajectory3D(
+ positions_xyz=cam_t.copy(),
+ orientations_quat_wxyz=cam_q.copy(),
+ timestamps=tstamps)
+
+ traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est)
+ result = main_ape.ape(traj_ref, traj_est, est_name='traj',
+ pose_relation=PoseRelation.translation_part, align=align, align_origin=align_origin,
+ correct_scale=correct_scale)
+
+ stats = result.stats
+
+ if return_traj:
+ return stats, traj_ref, traj_est
+
+ return stats
+
+
+def test_slam(imagedir, img_msks, conf_msks, calib, stride=10, max_frame=50):
+ """ Shorter SLAM step to test reprojection error """
+ args = parser.parse_args([])
+ args.stereo = False
+ args.upsample = False
+ args.disable_vis = True
+ args.frontend_window = 10
+ args.frontend_thresh = 10
+ droid = None
+
+ for (t, image, intrinsics) in image_stream(imagedir, calib, stride, max_frame):
+ if droid is None:
+ args.image_size = [image.shape[2], image.shape[3]]
+ droid = Droid(args)
+
+ img_msk = img_msks[t]
+ conf_msk = conf_msks[t]
+ image = image * (img_msk < 0.5)
+ droid.track(t, image, intrinsics=intrinsics, mask=conf_msk)
+
+ reprojection_error = droid.compute_error()
+ del droid
+
+ return reprojection_error
+
+
+def search_focal_length(img_folder, masks, stride=10, max_frame=50,
+ low=500, high=1500, step=100):
+ """ Search for a good focal length by SLAM reprojection error """
+ masks = masks[::stride]
+ masks = masks[:max_frame]
+ img_msks, conf_msks = preprocess_masks(img_folder, masks)
+
+ # default estimate
+ calib = np.array(est_calib(img_folder))
+ best_focal = calib[0]
+ best_err = test_slam(img_folder, img_msks, conf_msks,
+ stride=stride, calib=calib, max_frame=max_frame)
+
+ # search based on slam reprojection error
+ for focal in range(low, high, step):
+ calib[:2] = focal
+ err = test_slam(img_folder, img_msks, conf_msks,
+ stride=stride, calib=calib, max_frame=max_frame)
+
+ if err < best_err:
+ best_err = err
+ best_focal = focal
+
+ print('Best focal length:', best_focal)
+
+ return best_focal
+
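+# Rough usage sketch (assumed, not from the original pipeline code): the returned focal
+# length can be written back into the calibration before the full SLAM run, e.g.
+#   calib = np.array(est_calib(img_folder)); calib[:2] = search_focal_length(img_folder, masks)
+#   droid, traj = run_slam(img_folder, masks, calib=calib)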
+
+def preprocess_masks(img_folder, masks):
+ """ Resize masks for masked droid """
+ H, W = get_dimention(img_folder)
+ resize_1 = Resize((H, W), antialias=True)
+ resize_2 = Resize((H//8, W//8), antialias=True)
+
+ img_msks = []
+ for i in range(0, len(masks), 500):
+ m = resize_1(masks[i:i+500])
+ img_msks.append(m)
+ img_msks = torch.cat(img_msks)
+
+ conf_msks = []
+ for i in range(0, len(masks), 500):
+ m = resize_2(masks[i:i+500])
+ conf_msks.append(m)
+ conf_msks = torch.cat(conf_msks)
+
+ return img_msks, conf_msks
+
+
+
+
+
diff --git a/lib/pipeline/tools.py b/lib/pipeline/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..1869faf35b8532b26b25fcb1971fa9046e9b3b15
--- /dev/null
+++ b/lib/pipeline/tools.py
@@ -0,0 +1,129 @@
+import cv2
+from tqdm import tqdm
+import numpy as np
+import torch
+
+from ultralytics import YOLO
+
+
+if torch.cuda.is_available():
+ autocast = torch.cuda.amp.autocast
+else:
+ class autocast:
+ def __init__(self, enabled=True):
+ pass
+ def __enter__(self):
+ pass
+ def __exit__(self, *args):
+ pass
+
+
+def detect_track(imgfiles, thresh=0.5):
+
+ hand_det_model = YOLO('./weights/external/detector.pt')
+
+ # Run
+ boxes_ = []
+ tracks = {}
+ for t, imgpath in enumerate(tqdm(imgfiles)):
+ img_cv2 = cv2.imread(imgpath)
+
+ ### --- Detection ---
+ with torch.no_grad():
+ with autocast():
+ results = hand_det_model.track(img_cv2, conf=thresh, persist=True, verbose=False)
+
+ boxes = results[0].boxes.xyxy.cpu().numpy()
+ confs = results[0].boxes.conf.cpu().numpy()
+ handedness = results[0].boxes.cls.cpu().numpy()
+ if not results[0].boxes.id is None:
+ track_id = results[0].boxes.id.cpu().numpy()
+ else:
+ track_id = [-1] * len(boxes)
+
+ boxes = np.hstack([boxes, confs[:, None]])
+ find_right = False
+ find_left = False
+ for idx, box in enumerate(boxes):
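+ # If the tracker did not assign an ID, fall back to a fixed per-hand ID
+ # (10000 for a right-hand detection, 5000 for a left-hand one) so each hand
+ # still accumulates into a single track.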
+ if track_id[idx] == -1:
+ if handedness[[idx]] > 0:
+ id = int(10000)
+ else:
+ id = int(5000)
+ else:
+ id = track_id[idx]
+ subj = dict()
+ subj['frame'] = t
+ subj['det'] = True
+ subj['det_box'] = boxes[[idx]]
+ subj['det_handedness'] = handedness[[idx]]
+
+
+ if (not find_right and handedness[[idx]] > 0) or (not find_left and handedness[[idx]]==0):
+ if id in tracks:
+ tracks[id].append(subj)
+ else:
+ tracks[id] = [subj]
+
+ if handedness[[idx]] > 0:
+ find_right = True
+ elif handedness[[idx]] == 0:
+ find_left = True
+ tracks = np.array(tracks, dtype=object)
+ boxes_ = np.array(boxes_, dtype=object)
+
+ return boxes_, tracks
+
+
+def parse_chunks(frame, boxes, min_len=16):
+ """ If a track disappear in the middle,
+ we separate it to different segments to estimate the HPS independently.
+ If a segment is less than 16 frames, we get rid of it for now.
+ """
+ frame_chunks = []
+ boxes_chunks = []
+ step = frame[1:] - frame[:-1]
+ step = np.concatenate([[0], step])
+ breaks = np.where(step != 1)[0]
+
+ start = 0
+ for bk in breaks:
+ f_chunk = frame[start:bk]
+ b_chunk = boxes[start:bk]
+ start = bk
+ if len(f_chunk)>=min_len:
+ frame_chunks.append(f_chunk)
+ boxes_chunks.append(b_chunk)
+
+ if bk==breaks[-1]: # last chunk
+ f_chunk = frame[bk:]
+ b_chunk = boxes[bk:]
+ if len(f_chunk)>=min_len:
+ frame_chunks.append(f_chunk)
+ boxes_chunks.append(b_chunk)
+
+ return frame_chunks, boxes_chunks
+
+def parse_chunks_hand_frame(frame):
+ """ If a track disappear in the middle,
+ we separate it to different segments to estimate the HPS independently.
+ If a segment is less than 16 frames, we get rid of it for now.
+ """
+ frame_chunks = []
+ step = frame[1:] - frame[:-1]
+ step = np.concatenate([[0], step])
+ breaks = np.where(step != 1)[0]
+
+ start = 0
+ for bk in breaks:
+ f_chunk = frame[start:bk]
+ start = bk
+ if len(f_chunk) > 0:
+ frame_chunks.append(f_chunk)
+
+ if bk==breaks[-1]: # last chunk
+ f_chunk = frame[bk:]
+ if len(f_chunk) > 0:
+ frame_chunks.append(f_chunk)
+
+ return frame_chunks
diff --git a/lib/utils/__pycache__/geometry.cpython-310.pyc b/lib/utils/__pycache__/geometry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9c90a145f832d48366222f7963cf1bc1f0cdf04
Binary files /dev/null and b/lib/utils/__pycache__/geometry.cpython-310.pyc differ
diff --git a/lib/utils/__pycache__/imutils.cpython-310.pyc b/lib/utils/__pycache__/imutils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b6a7494b10d0336eecadabbd494530d3b956ca4
Binary files /dev/null and b/lib/utils/__pycache__/imutils.cpython-310.pyc differ
diff --git a/lib/utils/geometry.py b/lib/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3d65066ce383c744407f8d800a7342cd324142b
--- /dev/null
+++ b/lib/utils/geometry.py
@@ -0,0 +1,412 @@
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+
+def perspective_projection(points, rotation, translation,
+ focal_length, camera_center, distortion=None):
+ """
+ This function computes the perspective projection of a set of points.
+ Input:
+ points (bs, N, 3): 3D points
+ rotation (bs, 3, 3): Camera rotation
+ translation (bs, 3): Camera translation
+ focal_length (bs,) or scalar: Focal length
+ camera_center (bs, 2): Camera center
+ """
+ batch_size = points.shape[0]
+
+ # Extrinsic
+ if rotation is not None:
+ points = torch.einsum('bij,bkj->bki', rotation, points)
+
+ if translation is not None:
+ points = points + translation.unsqueeze(1)
+
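+ # Optional lens distortion: kc appears to follow the OpenCV convention
+ # [k1, k2, p1, p2, k3] (radial k1, k2, k3 and tangential p1, p2).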
+ if distortion is not None:
+ kc = distortion
+ points = points[:,:,:2] / points[:,:,2:]
+
+ r2 = points[:,:,0]**2 + points[:,:,1]**2
+ dx = (2 * kc[:,[2]] * points[:,:,0] * points[:,:,1]
+ + kc[:,[3]] * (r2 + 2*points[:,:,0]**2))
+
+ dy = (2 * kc[:,[3]] * points[:,:,0] * points[:,:,1]
+ + kc[:,[2]] * (r2 + 2*points[:,:,1]**2))
+
+ x = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,0] + dx
+ y = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,1] + dy
+
+ points = torch.stack([x, y, torch.ones_like(x)], dim=-1)
+
+ # Intrinsic
+ K = torch.zeros([batch_size, 3, 3], device=points.device)
+ K[:,0,0] = focal_length
+ K[:,1,1] = focal_length
+ K[:,2,2] = 1.
+ K[:,:-1, -1] = camera_center
+
+ # Apply camera intrinsics
+ points = points / points[:,:,-1].unsqueeze(-1)
+ projected_points = torch.einsum('bij,bkj->bki', K, points)
+ projected_points = projected_points[:, :, :-1]
+
+ return projected_points
+
+
+def avg_rot(rot):
+ # input [B,...,3,3] --> output [...,3,3]
+ rot = rot.mean(dim=0)
+ U, _, V = torch.svd(rot)
+ rot = U @ V.transpose(-1, -2)
+ return rot
+
+
+def rot9d_to_rotmat(x):
+ """Convert 9D rotation representation to 3x3 rotation matrix.
+ Based on Levinson et al., "An Analysis of SVD for Deep Rotation Estimation"
+ Input:
+ (B,9) or (B,J*9) Batch of 9D rotation (interpreted as 3x3 est rotmat)
+ Output:
+ (B,3,3) or (B*J,3,3) Batch of corresponding rotation matrices
+ """
+ x = x.view(-1,3,3)
+ u, _, vh = torch.linalg.svd(x)
+
+ sig = torch.eye(3).expand(len(x), 3, 3).clone()
+ sig = sig.to(x.device)
+ sig[:, -1, -1] = (u @ vh).det()
+
+ R = u @ sig @ vh
+
+ return R
+
+
+"""
+Deprecated in favor of: rotation_conversions.py
+
+Useful geometric operations, e.g. differentiable Rodrigues formula
+Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
+"""
+def batch_rodrigues(theta):
+ """Convert axis-angle representation to rotation matrix.
+ Args:
+ theta: size = [B, 3]
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(l1norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat_to_rotmat(quat)
+
+def quat_to_rotmat(quat):
+ """Convert quaternion coefficients to rotation matrix.
+ Args:
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ norm_quat = quat
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+def rot6d_to_rotmat(x):
+ """Convert 6D rotation representation to 3x3 rotation matrix.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Input:
+ (B,6) Batch of 6-D rotation representations
+ Output:
+ (B,3,3) Batch of corresponding rotation matrices
+ """
+ x = x.view(-1,3,2)
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.cross(b1, b2)
+ return torch.stack((b1, b2, b3), dim=-1)
+
+def rot6d_to_rotmat_hmr2(x: torch.Tensor) -> torch.Tensor:
+ """
+ Convert 6D rotation representation to 3x3 rotation matrix.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Args:
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
+ Returns:
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
+ """
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.cross(b1, b2)
+ return torch.stack((b1, b2, b3), dim=-1)
+
+def rotmat_to_rot6d(rotmat):
+ """ Inverse function of the above.
+ Input:
+ (B,3,3) Batch of corresponding rotation matrices
+ Output:
+ (B,6) Batch of 6-D rotation representations
+ """
+ # rot6d = rotmat[:, :, :2]
+ rot6d = rotmat[...,:2]
+ rot6d = rot6d.reshape(rot6d.size(0), -1)
+ return rot6d
+
+
+def rotation_matrix_to_angle_axis(rotation_matrix):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert 3x4 rotation matrix to Rodrigues vector
+
+ Args:
+ rotation_matrix (Tensor): rotation matrix.
+
+ Returns:
+ Tensor: Rodrigues vector transformation.
+
+ Shape:
+ - Input: :math:`(N, 3, 4)`
+ - Output: :math:`(N, 3)`
+
+ Example:
+ >>> input = torch.rand(2, 3, 4) # Nx4x4
+ >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
+ """
+ if rotation_matrix.shape[1:] == (3,3):
+ rot_mat = rotation_matrix.reshape(-1, 3, 3)
+ hom = torch.tensor([0, 0, 1], dtype=torch.float32,
+ device=rotation_matrix.device).reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1)
+ rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
+
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
+ aa = quaternion_to_angle_axis(quaternion)
+ aa[torch.isnan(aa)] = 0.0
+ return aa
+
+
+def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert quaternion vector to angle axis of rotation.
+
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
+
+ Args:
+ quaternion (torch.Tensor): tensor with quaternions.
+
+ Return:
+ torch.Tensor: tensor with angle axis of rotation.
+
+ Shape:
+ - Input: :math:`(*, 4)` where `*` means, any number of dimensions
+ - Output: :math:`(*, 3)`
+
+ Example:
+ >>> quaternion = torch.rand(2, 4) # Nx4
+ >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
+ """
+ if not torch.is_tensor(quaternion):
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
+ type(quaternion)))
+
+ if not quaternion.shape[-1] == 4:
+ raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
+ .format(quaternion.shape))
+ # unpack input and compute conversion
+ q1: torch.Tensor = quaternion[..., 1]
+ q2: torch.Tensor = quaternion[..., 2]
+ q3: torch.Tensor = quaternion[..., 3]
+ sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
+
+ sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
+ cos_theta: torch.Tensor = quaternion[..., 0]
+ two_theta: torch.Tensor = 2.0 * torch.where(
+ cos_theta < 0.0,
+ torch.atan2(-sin_theta, -cos_theta),
+ torch.atan2(sin_theta, cos_theta))
+
+ k_pos: torch.Tensor = two_theta / sin_theta
+ k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
+ k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
+
+ angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
+ angle_axis[..., 0] += q1 * k
+ angle_axis[..., 1] += q2 * k
+ angle_axis[..., 2] += q3 * k
+ return angle_axis
+
+
+def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert 3x4 rotation matrix to 4d quaternion vector
+
+ This algorithm is based on algorithm described in
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
+
+ Args:
+ rotation_matrix (Tensor): the rotation matrix to convert.
+
+ Return:
+ Tensor: the rotation in quaternion
+
+ Shape:
+ - Input: :math:`(N, 3, 4)`
+ - Output: :math:`(N, 4)`
+
+ Example:
+ >>> input = torch.rand(4, 3, 4) # Nx3x4
+ >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
+ """
+ if not torch.is_tensor(rotation_matrix):
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
+ type(rotation_matrix)))
+
+ if len(rotation_matrix.shape) > 3:
+ raise ValueError(
+ "Input size must be a three dimensional tensor. Got {}".format(
+ rotation_matrix.shape))
+ if not rotation_matrix.shape[-2:] == (3, 4):
+ raise ValueError(
+ "Input size must be a N x 3 x 4 tensor. Got {}".format(
+ rotation_matrix.shape))
+
+ rmat_t = torch.transpose(rotation_matrix, 1, 2)
+
+ mask_d2 = rmat_t[:, 2, 2] < eps
+
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
+
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
+ t0_rep = t0.repeat(4, 1).t()
+
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
+ t1_rep = t1.repeat(4, 1).t()
+
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
+ t2_rep = t2.repeat(4, 1).t()
+
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
+ t3_rep = t3.repeat(4, 1).t()
+
+ mask_c0 = mask_d2 * mask_d0_d1
+ mask_c1 = mask_d2 * ~mask_d0_d1
+ mask_c2 = ~mask_d2 * mask_d0_nd1
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
+
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
+ q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
+ t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
+ q *= 0.5
+ return q
+
+
+def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000., img_size=224.):
+ """
+ This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py
+
+ Find the camera translation that brings the 3D joints S closest to the corresponding 2D joints.
+ Input:
+ S: (25, 3) 3D joint locations
+ joints_2d: (25, 2) 2D joint locations; joints_conf: (25,) confidences
+ Returns:
+ (3,) camera translation vector
+ """
+
+ num_joints = S.shape[0]
+ # focal length
+ f = np.array([focal_length,focal_length])
+ # optical center
+ center = np.array([img_size/2., img_size/2.])
+
+ # transformations
+ Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1)
+ XY = np.reshape(S[:,0:2],-1)
+ O = np.tile(center,num_joints)
+ F = np.tile(f,num_joints)
+ weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1)
+
+ # least squares
+ Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T
+ c = (np.reshape(joints_2d,-1)-O)*Z - F*XY
+
+ # weighted least squares
+ W = np.diagflat(weight2)
+ Q = np.dot(W,Q)
+ c = np.dot(W,c)
+
+ # square matrix
+ A = np.dot(Q.T,Q)
+ b = np.dot(Q.T,c)
+
+ # solution
+ trans = np.linalg.solve(A, b)
+
+ return trans
+
+
+def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.):
+ """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
+ Input:
+ S: (B, 49, 3) 3D joint locations
+ joints: (B, 49, 3) 2D joint locations and confidence
+ Returns:
+ (B, 3) camera translation vectors
+ """
+
+ device = S.device
+ # Use only joints 25:49 (GT joints)
+ S = S[:, -24:, :3].cpu().numpy()
+ joints_2d = joints_2d[:, -24:, :].cpu().numpy()
+
+ joints_conf = joints_2d[:, :, -1]
+ joints_2d = joints_2d[:, :, :-1]
+ trans = np.zeros((S.shape[0], 3), dtype=np.float32)
+ # Find the translation for each example in the batch
+ for i in range(S.shape[0]):
+ S_i = S[i]
+ joints_i = joints_2d[i]
+ conf_i = joints_conf[i]
+ trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size)
+ return torch.from_numpy(trans).to(device)
+
+
diff --git a/lib/utils/imutils.py b/lib/utils/imutils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d38db0bdd4446245f4bfdafb43643d92125f9d8d
--- /dev/null
+++ b/lib/utils/imutils.py
@@ -0,0 +1,286 @@
+"""
+This file contains functions that are used to perform data augmentation.
+"""
+import torch
+import numpy as np
+from skimage.transform import rotate, resize
+import cv2
+from torchvision.transforms import Normalize, ToTensor, Compose
+
+from lib.core import constants
+
+def get_normalization():
+ normalize_img = Compose([ToTensor(),
+ Normalize(mean=constants.IMG_NORM_MEAN,
+ std=constants.IMG_NORM_STD)
+ ])
+ return normalize_img
+
+def get_transform(center, scale, res, rot=0):
+ """Generate transformation matrix."""
+ h = 200 * scale + 1e-6
+ t = np.zeros((3, 3))
+ t[0, 0] = float(res[1]) / h
+ t[1, 1] = float(res[0]) / h
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
+ t[2, 2] = 1
+ if not rot == 0:
+ rot = -rot # To match direction of rotation from cropping
+ rot_mat = np.zeros((3,3))
+ rot_rad = rot * np.pi / 180
+ sn,cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0,:2] = [cs, -sn]
+ rot_mat[1,:2] = [sn, cs]
+ rot_mat[2,2] = 1
+ # Need to rotate around center
+ t_mat = np.eye(3)
+ t_mat[0,2] = -res[1]/2
+ t_mat[1,2] = -res[0]/2
+ t_inv = t_mat.copy()
+ t_inv[:2,2] *= -1
+ t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
+ return t
+
+def transform(pt, center, scale, res, invert=0, rot=0, asint=True):
+ """Transform pixel location to different reference."""
+ t = get_transform(center, scale, res, rot=rot)
+ if invert:
+ t = np.linalg.inv(t)
+ new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T
+ new_pt = np.dot(t, new_pt)
+
+ if asint:
+ return new_pt[:2].astype(int)+1
+ else:
+ return new_pt[:2]+1
+
+def transform_pts(pts, center, scale, res, invert=0, rot=0, asint=True):
+ """Transform pixel location to different reference."""
+ t = get_transform(center, scale, res, rot=rot)
+ if invert:
+ t = np.linalg.inv(t)
+ pts = np.concatenate((pts, np.ones_like(pts)[:, [0]]), axis=-1)
+ new_pt = pts.T
+ new_pt = np.dot(t, new_pt)
+
+ if asint:
+ return new_pt[:2, :].T.astype(int)
+ else:
+ return new_pt[:2, :].T
+
+def crop(img, center, scale, res, rot=0):
+ """Crop image according to the supplied bounding box."""
+ # Upper left point
+ ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
+ # Bottom right point
+ br = np.array(transform([res[0]+1,
+ res[1]+1], center, scale, res, invert=1))-1
+
+ # Padding so that when rotated proper amount of context is included
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
+ if not rot == 0:
+ ul -= pad
+ br += pad
+
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
+ if len(img.shape) > 2:
+ new_shape += [img.shape[2]]
+ new_img = np.zeros(new_shape)
+
+
+ # Range to fill new array
+ new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
+ new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
+ # Range to sample from original image
+ old_x = max(0, ul[0]), min(len(img[0]), br[0])
+ old_y = max(0, ul[1]), min(len(img), br[1])
+ try:
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
+ old_x[0]:old_x[1]]
+ except:
+ print("invlid bbox, fill with 0")
+
+ if not rot == 0:
+ # Remove padding
+ new_img = rotate(new_img, rot)
+ new_img = new_img[pad:-pad, pad:-pad]
+
+ new_img = resize(new_img, res)
+ return new_img
+
+def crop_j2d(j2d, center, scale, res, rot=0):
+ """Crop image according to the supplied bounding box."""
+ # Upper left point
+ # crop_j2d = np.array(transform_pts(j2d, center, scale, res, invert=0))
+ b = scale * 200
+ points2d = j2d - (center - b/2)
+ points2d = points2d * (res[0] / b)
+
+ return points2d
+
+
+def crop_crop(img, center, scale, res, rot=0):
+ """Crop image according to the supplied bounding box."""
+ # Upper left point
+ ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
+ # Bottom right point
+ br = np.array(transform([res[0]+1,
+ res[1]+1], center, scale, res, invert=1))-1
+
+ # Padding so that when rotated proper amount of context is included
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
+ if not rot == 0:
+ ul -= pad
+ br += pad
+
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
+ if len(img.shape) > 2:
+ new_shape += [img.shape[2]]
+ new_img = np.zeros(new_shape)
+
+
+ if new_img.shape[0] > img.shape[0]:
+ p = (new_img.shape[0] - img.shape[0]) / 2
+ p = int(p)
+ new_img = cv2.copyMakeBorder(img, p, p, p, p, cv2.BORDER_REPLICATE)
+
+ # Range to fill new array
+ new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
+ new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
+ # Range to sample from original image
+ old_x = max(0, ul[0]), min(len(img[0]), br[0])
+ old_y = max(0, ul[1]), min(len(img), br[1])
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
+ old_x[0]:old_x[1]]
+
+ if not rot == 0:
+ # Remove padding
+ new_img = rotate(new_img, rot)
+ new_img = new_img[pad:-pad, pad:-pad]
+
+ new_img = resize(new_img, res)
+ return new_img
+
+def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
+ """'Undo' the image cropping/resizing.
+ This function is used when evaluating mask/part segmentation.
+ """
+ res = img.shape[:2]
+ # Upper left point
+ ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
+ # Bottom right point
+ br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1
+ # size of cropped image
+ crop_shape = [br[1] - ul[1], br[0] - ul[0]]
+
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
+ if len(img.shape) > 2:
+ new_shape += [img.shape[2]]
+ new_img = np.zeros(orig_shape, dtype=np.uint8)
+ # Range to fill new array
+ new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
+ new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
+ # Range to sample from original image
+ old_x = max(0, ul[0]), min(orig_shape[1], br[0])
+ old_y = max(0, ul[1]), min(orig_shape[0], br[1])
+ img = resize(img, crop_shape, order=0, preserve_range=True)  # nearest-neighbor; skimage.transform.resize has no 'interp' argument
+ new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
+ return new_img
+
+def rot_aa(aa, rot):
+ """Rotate axis angle parameters."""
+ # pose parameters
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
+ [0, 0, 1]])
+ # find the rotation of the body in camera frame
+ per_rdg, _ = cv2.Rodrigues(aa)
+ # apply the global rotation to the global orientation
+ resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
+ aa = (resrot.T)[0]
+ return aa
+
+def flip_img(img):
+ """Flip rgb images or masks.
+ channels come last, e.g. (256,256,3).
+ """
+ img = np.fliplr(img)
+ return img
+
+def flip_kp(kp):
+ """Flip keypoints."""
+ if len(kp) == 24:
+ flipped_parts = constants.J24_FLIP_PERM
+ elif len(kp) == 49:
+ flipped_parts = constants.J49_FLIP_PERM
+ kp = kp[flipped_parts]
+ kp[:,0] = - kp[:,0]
+ return kp
+
+def flip_pose(pose):
+ """Flip pose.
+ The flipping is based on SMPL parameters.
+ """
+ flipped_parts = constants.SMPL_POSE_FLIP_PERM
+ pose = pose[flipped_parts]
+ # we also negate the second and the third dimension of the axis-angle
+ pose[1::3] = -pose[1::3]
+ pose[2::3] = -pose[2::3]
+ return pose
+
+
+def crop_img(img, center, scale, res, val=255):
+ """Crop image according to the supplied bounding box."""
+ # Upper left point
+ ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
+ # Bottom right point
+ br = np.array(transform([res[0]+1,
+ res[1]+1], center, scale, res, invert=1))-1
+
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
+ if len(img.shape) > 2:
+ new_shape += [img.shape[2]]
+ new_img = np.ones(new_shape) * val
+
+ # Range to fill new array
+ new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
+ new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
+ # Range to sample from original image
+ old_x = max(0, ul[0]), min(len(img[0]), br[0])
+ old_y = max(0, ul[1]), min(len(img), br[1])
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
+ old_x[0]:old_x[1]]
+ new_img = resize(new_img, res)
+ return new_img
+
+
+def boxes_2_cs(boxes):
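+    # boxes: (N, 4) array of (x1, y1, x2, y2) pixel coordinates.
+    # The returned scale follows the size / 200 convention used by the crop
+    # helpers above, so scale = 1.0 corresponds to a 200 px square crop.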
+ x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
+ w, h = x2-x1, y2-y1
+ cx, cy = x1+w/2, y1+h/2
+ size = np.stack([w, h]).max(axis=0)
+
+ centers = np.stack([cx, cy], axis=1)
+ scales = size / 200
+ return centers, scales
+
+
+def box_2_cs(box):
+ x1,y1,x2,y2 = box[:4].int().tolist()
+
+ w, h = x2-x1, y2-y1
+ cx, cy = x1+w/2, y1+h/2
+ size = max(w, h)
+
+ center = [cx, cy]
+ scale = size / 200
+ return center, scale
+
+
+def est_intrinsics(img_shape):
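+    # Rough pinhole intrinsics when no calibration is available: principal
+    # point at the image centre, focal length set to the image diagonal.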
+ h, w, c = img_shape
+ img_center = torch.tensor([w/2., h/2.]).float()
+ img_focal = torch.tensor(np.sqrt(h**2 + w**2)).float()
+ return img_center, img_focal
+
diff --git a/lib/vis/__pycache__/renderer.cpython-310.pyc b/lib/vis/__pycache__/renderer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd3fcdddc6426fe08a0300b4ec618c1c97c4dad3
Binary files /dev/null and b/lib/vis/__pycache__/renderer.cpython-310.pyc differ
diff --git a/lib/vis/__pycache__/run_vis2.cpython-310.pyc b/lib/vis/__pycache__/run_vis2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f8f39d1ffb22e7a23459d38fc903eb10802f76a
Binary files /dev/null and b/lib/vis/__pycache__/run_vis2.cpython-310.pyc differ
diff --git a/lib/vis/__pycache__/tools.cpython-310.pyc b/lib/vis/__pycache__/tools.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b96ce951765fead7570b9951a63e48b0e9ebb0b
Binary files /dev/null and b/lib/vis/__pycache__/tools.cpython-310.pyc differ
diff --git a/lib/vis/__pycache__/viewer.cpython-310.pyc b/lib/vis/__pycache__/viewer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5679447fbf1a890041998ef990ec13444f1d734
Binary files /dev/null and b/lib/vis/__pycache__/viewer.cpython-310.pyc differ
diff --git a/lib/vis/renderer.py b/lib/vis/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..efe7e6ab2c89cdc5d95d85dfb62ca0c2d643a0c5
--- /dev/null
+++ b/lib/vis/renderer.py
@@ -0,0 +1,356 @@
+# Useful rendering functions adapted from WHAM (with some modifications)
+
+import cv2
+import torch
+import numpy as np
+
+from pytorch3d.renderer import (
+ PerspectiveCameras,
+ TexturesVertex,
+ PointLights,
+ Materials,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRasterizer,
+ SoftPhongShader,
+)
+from pytorch3d.structures import Meshes
+from pytorch3d.structures.meshes import join_meshes_as_scene
+from pytorch3d.renderer.cameras import look_at_rotation
+from pytorch3d.renderer.camera_conversions import _cameras_from_opencv_projection
+
+from .tools import get_colors, checkerboard_geometry
+
+
+def overlay_image_onto_background(image, mask, bbox, background):
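+    # Paste a rendered crop back into the full-resolution frame: only pixels
+    # covered by the render mask overwrite the background inside the bbox.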
+ if isinstance(image, torch.Tensor):
+ image = image.detach().cpu().numpy()
+ if isinstance(mask, torch.Tensor):
+ mask = mask.detach().cpu().numpy()
+
+ out_image = background.copy()
+ bbox = bbox[0].int().cpu().numpy().copy()
+ roi_image = out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]]
+
+ roi_image[mask] = image[mask]
+ out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]] = roi_image
+
+ return out_image
+
+
+def update_intrinsics_from_bbox(K_org, bbox):
+ device, dtype = K_org.device, K_org.dtype
+
+ K = torch.zeros((K_org.shape[0], 4, 4)
+ ).to(device=device, dtype=dtype)
+ K[:, :3, :3] = K_org.clone()
+ K[:, 2, 2] = 0
+ K[:, 2, -1] = 1
+ K[:, -1, 2] = 1
+
+ image_sizes = []
+    for idx, bbox_i in enumerate(bbox):
+        left, upper, right, lower = bbox_i
+ cx, cy = K[idx, 0, 2], K[idx, 1, 2]
+
+ new_cx = cx - left
+ new_cy = cy - upper
+ new_height = max(lower - upper, 1)
+ new_width = max(right - left, 1)
+ new_cx = new_width - new_cx
+ new_cy = new_height - new_cy
+
+ K[idx, 0, 2] = new_cx
+ K[idx, 1, 2] = new_cy
+ image_sizes.append((int(new_height), int(new_width)))
+
+ return K, image_sizes
+
+
+def perspective_projection(x3d, K, R=None, T=None):
+    if R is not None:
+ x3d = torch.matmul(R, x3d.transpose(1, 2)).transpose(1, 2)
+    if T is not None:
+ x3d = x3d + T.transpose(1, 2)
+
+ x2d = torch.div(x3d, x3d[..., 2:])
+ x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
+ return x2d
+
+
+def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2):
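+    # Tight bounding box around the projected points, enlarged by scaleFactor
+    # and clamped to the image bounds.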
+ left = torch.clamp(X.min(1)[0][:, 0], min=0, max=img_w)
+ right = torch.clamp(X.max(1)[0][:, 0], min=0, max=img_w)
+ top = torch.clamp(X.min(1)[0][:, 1], min=0, max=img_h)
+ bottom = torch.clamp(X.max(1)[0][:, 1], min=0, max=img_h)
+
+ cx = (left + right) / 2
+ cy = (top + bottom) / 2
+ width = (right - left)
+ height = (bottom - top)
+
+ new_left = torch.clamp(cx - width/2 * scaleFactor, min=0, max=img_w-1)
+ new_right = torch.clamp(cx + width/2 * scaleFactor, min=1, max=img_w)
+ new_top = torch.clamp(cy - height / 2 * scaleFactor, min=0, max=img_h-1)
+ new_bottom = torch.clamp(cy + height / 2 * scaleFactor, min=1, max=img_h)
+
+ bbox = torch.stack((new_left.detach(), new_top.detach(),
+ new_right.detach(), new_bottom.detach())).int().float().T
+
+ return bbox
+
+
+class Renderer():
+ def __init__(self, width, height, focal_length, device,
+ bin_size=None, max_faces_per_bin=None):
+
+ self.width = width
+ self.height = height
+ self.focal_length = focal_length
+
+ self.device = device
+
+        # Store rasterization settings so the renderer can be rebuilt
+        # (e.g. by update_bbox / reset_bbox) without re-passing them.
+        self.bin_size = bin_size
+        self.max_faces_per_bin = max_faces_per_bin
+
+        self.initialize_camera_params()
+        self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]])
+        self.create_renderer()
+
+    def create_renderer(self):
+        self.renderer = MeshRenderer(
+            rasterizer=MeshRasterizer(
+                raster_settings=RasterizationSettings(
+                    image_size=self.image_sizes[0],
+                    blur_radius=1e-5, bin_size=self.bin_size,
+                    max_faces_per_bin=self.max_faces_per_bin),
+ ),
+ shader=SoftPhongShader(
+ device=self.device,
+ lights=self.lights,
+ )
+ )
+
+ def initialize_camera_params(self):
+        """Initialize default camera parameters.
+        The values below are hard-coded; TODO: make them configurable."""
+
+ # Extrinsics
+ self.R = torch.diag(
+ torch.tensor([1, 1, 1])
+ ).float().to(self.device).unsqueeze(0)
+
+ self.T = torch.tensor(
+ [0, 0, 0]
+ ).unsqueeze(0).float().to(self.device)
+
+ # Intrinsics
+ self.K = torch.tensor(
+ [[self.focal_length, 0, self.width/2],
+ [0, self.focal_length, self.height/2],
+ [0, 0, 1]]
+ ).unsqueeze(0).float().to(self.device)
+ self.bboxes = torch.tensor([[0, 0, self.width, self.height]]).float()
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes)
+
+ # self.K_full = self.K # test
+ self.cameras = self.create_camera()
+
+ def create_camera(self, R=None, T=None):
+ if R is not None:
+ self.R = R.clone().view(1, 3, 3).to(self.device)
+ if T is not None:
+ self.T = T.clone().view(1, 3).to(self.device)
+
+ return PerspectiveCameras(
+ device=self.device,
+ R=self.R, #.mT,
+ T=self.T,
+ K=self.K_full,
+ image_size=self.image_sizes,
+ in_ndc=False)
+
+ def create_camera_from_cv(self, R, T, K=None, image_size=None):
+ # R: [1, 3, 3] Tensor
+ # T: [1, 3] Tensor
+ # K: [1, 3, 3] Tensor
+ # image_size: [1, 2] Tensor in HW
+ if K is None:
+ K = self.K
+
+ if image_size is None:
+ image_size = torch.tensor(self.image_sizes)
+
+ cameras = _cameras_from_opencv_projection(R, T, K, image_size)
+ lights = PointLights(device=K.device, location=T)
+
+ return cameras, lights
+
+ def set_ground(self, length, center_x, center_z):
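+        # Checkerboard ground plane in the x-z plane (up = +y), centred at
+        # (center_x, center_z) and spanning `length` units on a side.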
+ device = self.device
+ v, f, vc, fc = map(torch.from_numpy, checkerboard_geometry(length=length, c1=center_x, c2=center_z, up="y"))
+ v, f, vc = v.to(device), f.to(device), vc.to(device)
+ self.ground_geometry = [v, f, vc]
+
+
+ def update_bbox(self, x3d, scale=2.0, mask=None):
+ """ Update bbox of cameras from the given 3d points
+
+ x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3)
+ """
+
+ if x3d.size(-1) != 3:
+ x2d = x3d.unsqueeze(0)
+ else:
+ x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1))
+
+ if mask is not None:
+ x2d = x2d[:, ~mask]
+
+ bbox = compute_bbox_from_points(x2d, self.width, self.height, scale)
+ self.bboxes = bbox
+
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
+ self.cameras = self.create_camera()
+ self.create_renderer()
+
+ def reset_bbox(self,):
+ bbox = torch.zeros((1, 4)).float().to(self.device)
+ bbox[0, 2] = self.width
+ bbox[0, 3] = self.height
+ self.bboxes = bbox
+
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
+ self.cameras = self.create_camera()
+ self.create_renderer()
+
+ def render_mesh(self, vertices, background, colors=[0.8, 0.8, 0.8]):
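+        # NOTE: `self.faces` is not set in __init__; it is expected to be
+        # assigned by the caller (e.g. the MANO face indices) before
+        # render_mesh is called.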
+ self.update_bbox(vertices[::50], scale=1.2)
+ vertices = vertices.unsqueeze(0)
+
+ if colors[0] > 1: colors = [c / 255. for c in colors]
+ verts_features = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype)
+ verts_features = verts_features.repeat(1, vertices.shape[1], 1)
+ textures = TexturesVertex(verts_features=verts_features)
+
+ mesh = Meshes(verts=vertices,
+ faces=self.faces,
+ textures=textures,)
+
+ materials = Materials(
+ device=self.device,
+ specular_color=(colors, ),
+ shininess=0
+ )
+
+ results = torch.flip(
+ self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights),
+ [1, 2]
+ )
+ image = results[0, ..., :3] * 255
+ mask = results[0, ..., -1] > 1e-3
+
+ image = overlay_image_onto_background(image, mask, self.bboxes, background.copy())
+ self.reset_bbox()
+ return image
+
+
+ def render_with_ground(self, verts, faces, colors, cameras, lights):
+ """
+ :param verts (B, V, 3)
+ :param faces (F, 3)
+ :param colors (B, 3)
+ """
+
+ # (B, V, 3), (B, F, 3), (B, V, 3)
+ verts, faces, colors = prep_shared_geometry(verts, faces, colors)
+ # (V, 3), (F, 3), (V, 3)
+ gv, gf, gc = self.ground_geometry
+ verts = list(torch.unbind(verts, dim=0)) + [gv]
+ faces = list(torch.unbind(faces, dim=0)) + [gf]
+ colors = list(torch.unbind(colors, dim=0)) + [gc[..., :3]]
+ mesh = create_meshes(verts, faces, colors)
+
+ materials = Materials(
+ device=self.device,
+ shininess=0
+ )
+
+ results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials)
+ image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
+
+ return image
+
+ def render_multiple(self, verts_list, faces, colors_list, cameras, lights):
+ """
+ :param verts (B, V, 3)
+ :param faces (F, 3)
+ :param colors (B, 3)
+ """
+ # (B, V, 3), (B, F, 3), (B, V, 3)
+ verts_, faces_, colors_ = [], [], []
+ for i, verts in enumerate(verts_list):
+ colors = colors_list[[i]]
+ verts_i, faces_i, colors_i = prep_shared_geometry(verts, faces, colors)
+ if i == 0:
+ verts_ = list(torch.unbind(verts_i, dim=0))
+ faces_ = list(torch.unbind(faces_i, dim=0))
+ colors_ = list(torch.unbind(colors_i, dim=0))
+ else:
+ verts_ += list(torch.unbind(verts_i, dim=0))
+ faces_ += list(torch.unbind(faces_i, dim=0))
+ colors_ += list(torch.unbind(colors_i, dim=0))
+
+ # # (V, 3), (F, 3), (V, 3)
+ # gv, gf, gc = self.ground_geometry
+ # verts_ += [gv]
+ # faces_ += [gf]
+ # colors_ += [gc[..., :3]]
+ mesh = create_meshes(verts_, faces_, colors_)
+
+ materials = Materials(
+ device=self.device,
+ shininess=0
+ )
+ results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials)
+ image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
+ mask = results[0, ..., -1].cpu().numpy() > 0
+ return image, mask
+
+
+def prep_shared_geometry(verts, faces, colors):
+ """
+ :param verts (B, V, 3)
+ :param faces (F, 3)
+ :param colors (B, 4)
+ """
+ B, V, _ = verts.shape
+ F, _ = faces.shape
+ colors = colors.unsqueeze(1).expand(B, V, -1)[..., :3]
+ faces = faces.unsqueeze(0).expand(B, F, -1)
+ return verts, faces, colors
+
+
+def create_meshes(verts, faces, colors):
+ """
+ :param verts (B, V, 3)
+ :param faces (B, F, 3)
+ :param colors (B, V, 3)
+ """
+ textures = TexturesVertex(verts_features=colors)
+ meshes = Meshes(verts=verts, faces=faces, textures=textures)
+ return join_meshes_as_scene(meshes)
+
+
+def get_global_cameras(verts, device, distance=5, position=(-5.0, 5.0, 0.0)):
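+    # Build a fixed side-view camera per frame: the camera sits `distance`
+    # units from each frame's mesh centroid, along the bearing from the given
+    # `position`, and the function returns world-to-camera R, t plus a light.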
+ positions = torch.tensor([position]).repeat(len(verts), 1)
+ targets = verts.mean(1)
+
+ directions = targets - positions
+ directions = directions / torch.norm(directions, dim=-1).unsqueeze(-1) * distance
+ positions = targets - directions
+
+ rotation = look_at_rotation(positions, targets, ).mT
+ translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1)
+
+ lights = PointLights(device=device, location=[position])
+ return rotation, translation, lights
+
+
diff --git a/lib/vis/run_vis2.py b/lib/vis/run_vis2.py
new file mode 100644
index 0000000000000000000000000000000000000000..390a1dfe2da0949d14f2a1283fbf30f280d4295d
--- /dev/null
+++ b/lib/vis/run_vis2.py
@@ -0,0 +1,250 @@
+import os
+import cv2
+import numpy as np
+import torch
+import trimesh
+
+import lib.vis.viewer as viewer_utils
+from lib.vis.wham_tools.tools import checkerboard_geometry
+
+def camera_marker_geometry(radius, height):
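+    # Pyramid-shaped camera marker: a square base of side 2 * radius with the
+    # apex at -height along the local z (viewing) axis.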
+ vertices = np.array(
+ [
+ [-radius, -radius, 0],
+ [radius, -radius, 0],
+ [radius, radius, 0],
+ [-radius, radius, 0],
+ [0, 0, - height],
+ ]
+ )
+
+
+ faces = np.array(
+ [[0, 1, 2], [0, 2, 3], [1, 0, 4], [2, 1, 4], [3, 2, 4], [0, 3, 4],]
+ )
+
+ face_colors = np.array(
+ [
+ [0.5, 0.5, 0.5, 1.0],
+ [0.5, 0.5, 0.5, 1.0],
+ [0.0, 1.0, 0.0, 1.0],
+ [1.0, 0.0, 0.0, 1.0],
+ [0.0, 1.0, 0.0, 1.0],
+ [1.0, 0.0, 0.0, 1.0],
+ ]
+ )
+ return vertices, faces, face_colors
+
+
+def run_vis2_on_video(res_dict, res_dict2, output_pth, focal_length, image_names, R_c2w=None, t_c2w=None):
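+    # World-view visualization: render both hands in world coordinates in the
+    # aitviewer, together with a checkerboard ground plane and a camera marker
+    # animated by the camera-to-world trajectory (R_c2w, t_c2w).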
+
+ img0 = cv2.imread(image_names[0])
+ height, width, _ = img0.shape
+
+ world_mano = {}
+ world_mano['vertices'] = res_dict['vertices']
+ world_mano['faces'] = res_dict['faces']
+
+ world_mano2 = {}
+ world_mano2['vertices'] = res_dict2['vertices']
+ world_mano2['faces'] = res_dict2['faces']
+
+ vis_dict = {}
+ color_idx = 0
+ world_mano['vertices'] = world_mano['vertices']
+ for _id, _verts in enumerate(world_mano['vertices']):
+ verts = _verts.cpu().numpy() # T, N, 3
+ body_faces = world_mano['faces']
+ body_meshes = {
+ "v3d": verts,
+ "f3d": body_faces,
+ "vc": None,
+ "name": f"hand_{_id}",
+ # "color": "pace-green",
+ "color": "director-purple",
+ }
+ vis_dict[f"hand_{_id}"] = body_meshes
+ color_idx += 1
+
+ world_mano2['vertices'] = world_mano2['vertices']
+ for _id, _verts in enumerate(world_mano2['vertices']):
+ verts = _verts.cpu().numpy() # T, N, 3
+ body_faces = world_mano2['faces']
+ body_meshes = {
+ "v3d": verts,
+ "f3d": body_faces,
+ "vc": None,
+ "name": f"hand2_{_id}",
+ # "color": "pace-blue",
+ "color": "director-blue",
+ }
+ vis_dict[f"hand2_{_id}"] = body_meshes
+ color_idx += 1
+
+ v, f, vc, fc = checkerboard_geometry(length=100, c1=0, c2=0, up="z")
+ v[:, 2] -= 2 # z plane
+    ground_meshes = {
+ "v3d": v,
+ "f3d": f,
+ "vc": vc,
+ "name": "ground",
+ "fc": fc,
+ "color": -1,
+ }
+    vis_dict["ground"] = ground_meshes
+
+ num_frames = len(world_mano['vertices'][_id])
+ Rt = np.zeros((num_frames, 3, 4))
+ Rt[:, :3, :3] = R_c2w[:num_frames]
+ Rt[:, :3, 3] = t_c2w[:num_frames]
+
+ verts, faces, face_colors = camera_marker_geometry(0.05, 0.1)
+ verts = np.einsum("tij,nj->tni", Rt[:, :3, :3], verts) + Rt[:, None, :3, 3]
+ camera_meshes = {
+ "v3d": verts,
+ "f3d": faces,
+ "vc": None,
+ "name": "camera",
+ "fc": face_colors,
+ "color": -1,
+ }
+ vis_dict["camera"] = camera_meshes
+
+ side_source = torch.tensor([0.463, -0.478, 2.456])
+ side_target = torch.tensor([0.026, -0.481, -3.184])
+ up = torch.tensor([1.0, 0.0, 0.0])
+ view_camera = lookat_matrix(side_source, side_target, up)
+ viewer_Rt = np.tile(view_camera[:3, :4], (num_frames, 1, 1))
+
+ meshes = viewer_utils.construct_viewer_meshes(
+ vis_dict, draw_edges=False, flat_shading=False
+ )
+
+ vis_h, vis_w = (1000, 1000)
+ K = np.array(
+ [
+ [1000, 0, vis_w / 2],
+ [0, 1000, vis_h / 2],
+ [0, 0, 1]
+ ]
+ )
+
+ data = viewer_utils.ViewerData(viewer_Rt, K, vis_w, vis_h)
+ batch = (meshes, data)
+
+ viewer = viewer_utils.ARCTICViewer(interactive=True, size=(vis_w, vis_h))
+ viewer.render_seq(batch, out_folder=os.path.join(output_pth, 'aitviewer'))
+
+def run_vis2_on_video_cam(res_dict, res_dict2, output_pth, focal_length, image_names, R_w2c=None, t_w2c=None):
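+    # Camera-view visualization: meshes are viewed through the estimated
+    # world-to-camera poses (R_w2c, t_w2c) and composited over the original
+    # frames via a billboard, so no ground plane or camera marker is added.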
+
+ img0 = cv2.imread(image_names[0])
+ height, width, _ = img0.shape
+
+ world_mano = {}
+ world_mano['vertices'] = res_dict['vertices']
+ world_mano['faces'] = res_dict['faces']
+
+ world_mano2 = {}
+ world_mano2['vertices'] = res_dict2['vertices']
+ world_mano2['faces'] = res_dict2['faces']
+
+ vis_dict = {}
+ color_idx = 0
+ world_mano['vertices'] = world_mano['vertices']
+ for _id, _verts in enumerate(world_mano['vertices']):
+ verts = _verts.cpu().numpy() # T, N, 3
+ body_faces = world_mano['faces']
+ body_meshes = {
+ "v3d": verts,
+ "f3d": body_faces,
+ "vc": None,
+ "name": f"hand_{_id}",
+ # "color": "pace-green",
+ "color": "director-purple",
+ }
+ vis_dict[f"hand_{_id}"] = body_meshes
+ color_idx += 1
+
+ world_mano2['vertices'] = world_mano2['vertices']
+ for _id, _verts in enumerate(world_mano2['vertices']):
+ verts = _verts.cpu().numpy() # T, N, 3
+ body_faces = world_mano2['faces']
+ body_meshes = {
+ "v3d": verts,
+ "f3d": body_faces,
+ "vc": None,
+ "name": f"hand2_{_id}",
+ # "color": "pace-blue",
+ "color": "director-blue",
+ }
+ vis_dict[f"hand2_{_id}"] = body_meshes
+ color_idx += 1
+
+ meshes = viewer_utils.construct_viewer_meshes(
+ vis_dict, draw_edges=False, flat_shading=False
+ )
+
+ num_frames = len(world_mano['vertices'][_id])
+ Rt = np.zeros((num_frames, 3, 4))
+ Rt[:, :3, :3] = R_w2c[:num_frames]
+ Rt[:, :3, 3] = t_w2c[:num_frames]
+
+ cols, rows = (width, height)
+ K = np.array(
+ [
+ [focal_length, 0, width / 2],
+ [0, focal_length, height / 2],
+ [0, 0, 1]
+ ]
+ )
+ vis_h = height
+ vis_w = width
+
+ data = viewer_utils.ViewerData(Rt, K, cols, rows, imgnames=image_names)
+ batch = (meshes, data)
+
+ viewer = viewer_utils.ARCTICViewer(interactive=True, size=(vis_w, vis_h))
+ viewer.render_seq(batch, out_folder=os.path.join(output_pth, 'aitviewer'))
+
+def lookat_matrix(source_pos, target_pos, up):
+ """
+ IMPORTANT: USES RIGHT UP BACK XYZ CONVENTION
+ :param source_pos (*, 3)
+ :param target_pos (*, 3)
+ :param up (3,)
+ """
+ *dims, _ = source_pos.shape
+ up = up.reshape(*(1,) * len(dims), 3)
+ up = up / torch.linalg.norm(up, dim=-1, keepdim=True)
+ back = normalize(target_pos - source_pos)
+ right = normalize(torch.linalg.cross(up, back))
+ up = normalize(torch.linalg.cross(back, right))
+ R = torch.stack([right, up, back], dim=-1)
+ return make_4x4_pose(R, source_pos)
+
+def make_4x4_pose(R, t):
+ """
+ :param R (*, 3, 3)
+ :param t (*, 3)
+ return (*, 4, 4)
+ """
+ dims = R.shape[:-2]
+ pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1)
+ bottom = (
+ torch.tensor([0, 0, 0, 1], device=R.device)
+ .reshape(*(1,) * len(dims), 1, 4)
+ .expand(*dims, 1, 4)
+ )
+ return torch.cat([pose_3x4, bottom], dim=-2)
+
+def normalize(x):
+ return x / torch.linalg.norm(x, dim=-1, keepdim=True)
+
+def save_mesh_to_obj(vertices, faces, file_path):
+    # Create a Trimesh object
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+
+    # Export as an .obj file
+ mesh.export(file_path)
+ print(f"Mesh saved to {file_path}")
+
diff --git a/lib/vis/tools.py b/lib/vis/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..da8695e6391633b6143adebfb154566ee6ebf697
--- /dev/null
+++ b/lib/vis/tools.py
@@ -0,0 +1,825 @@
+# Useful visualization functions from SLAHMR and WHAM
+
+import os
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+
+def read_image(path, scale=1):
+ im = Image.open(path)
+ if scale == 1:
+ return np.array(im)
+ W, H = im.size
+ w, h = int(scale * W), int(scale * H)
+    return np.array(im.resize((w, h), Image.LANCZOS))
+
+
+def transform_torch3d(T_c2w):
+ """
+ :param T_c2w (*, 4, 4)
+ returns (*, 3, 3), (*, 3)
+ """
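+    # Flip the x- and y-axes of the rotation and translation to move the
+    # camera pose from the OpenCV convention to the PyTorch3D one
+    # (+X left, +Y up, +Z forward).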
+ R1 = torch.tensor(
+ [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0],], device=T_c2w.device,
+ )
+ R2 = torch.tensor(
+ [[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0],], device=T_c2w.device,
+ )
+ cam_R, cam_t = T_c2w[..., :3, :3], T_c2w[..., :3, 3]
+ cam_R = torch.einsum("...ij,jk->...ik", cam_R, R1)
+ cam_t = torch.einsum("ij,...j->...i", R2, cam_t)
+ return cam_R, cam_t
+
+
+def transform_pyrender(T_c2w):
+ """
+ :param T_c2w (*, 4, 4)
+ """
+ T_vis = torch.tensor(
+ [
+ [1.0, 0.0, 0.0, 0.0],
+ [0.0, -1.0, 0.0, 0.0],
+ [0.0, 0.0, -1.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ],
+ device=T_c2w.device,
+ )
+ return torch.einsum(
+ "...ij,jk->...ik", torch.einsum("ij,...jk->...ik", T_vis, T_c2w), T_vis
+ )
+
+
+def smpl_to_geometry(verts, faces, vis_mask=None, track_ids=None):
+ """
+ :param verts (B, T, V, 3)
+ :param faces (F, 3)
+ :param vis_mask (optional) (B, T) visibility of each person
+ :param track_ids (optional) (B,)
+ returns list of T verts (B, V, 3), faces (F, 3), colors (B, 3)
+ where B is different depending on the visibility of the people
+ """
+ B, T = verts.shape[:2]
+ device = verts.device
+
+ # (B, 3)
+ colors = (
+ track_to_colors(track_ids)
+ if track_ids is not None
+        else torch.ones(B, 3, device=device) * 0.5
+ )
+
+ # list T (B, V, 3), T (B, 3), T (F, 3)
+ return filter_visible_meshes(verts, colors, faces, vis_mask)
+
+
+def filter_visible_meshes(verts, colors, faces, vis_mask=None, vis_opacity=False):
+ """
+ :param verts (B, T, V, 3)
+ :param colors (B, 3)
+ :param faces (F, 3)
+ :param vis_mask (optional tensor, default None) (B, T) ternary mask
+ -1 if not in frame
+ 0 if temporarily occluded
+ 1 if visible
+ :param vis_opacity (optional bool, default False)
+ if True, make occluded people alpha=0.5, otherwise alpha=1
+ returns a list of T lists verts (Bi, V, 3), colors (Bi, 4), faces (F, 3)
+ """
+ # import ipdb; ipdb.set_trace()
+ B, T = verts.shape[:2]
+ faces = [faces for t in range(T)]
+ if vis_mask is None:
+ verts = [verts[:, t] for t in range(T)]
+ colors = [colors for t in range(T)]
+ return verts, colors, faces
+
+ # render occluded and visible, but not removed
+ vis_mask = vis_mask >= 0
+ if vis_opacity:
+ alpha = 0.5 * (vis_mask[..., None] + 1)
+ else:
+ alpha = (vis_mask[..., None] >= 0).float()
+ vert_list = [verts[vis_mask[:, t], t] for t in range(T)]
+ colors = [
+ torch.cat([colors[vis_mask[:, t]], alpha[vis_mask[:, t], t]], dim=-1)
+ for t in range(T)
+ ]
+ bounds = get_bboxes(verts, vis_mask)
+ return vert_list, colors, faces, bounds
+
+
+def get_bboxes(verts, vis_mask):
+ """
+ return bb_min, bb_max, and mean for each track (B, 3) over entire trajectory
+ :param verts (B, T, V, 3)
+ :param vis_mask (B, T)
+ """
+ B, T, *_ = verts.shape
+ bb_min, bb_max, mean = [], [], []
+ for b in range(B):
+ v = verts[b, vis_mask[b, :T]] # (Tb, V, 3)
+ bb_min.append(v.amin(dim=(0, 1)))
+ bb_max.append(v.amax(dim=(0, 1)))
+ mean.append(v.mean(dim=(0, 1)))
+ bb_min = torch.stack(bb_min, dim=0)
+ bb_max = torch.stack(bb_max, dim=0)
+ mean = torch.stack(mean, dim=0)
+ # point to a track that's long and close to the camera
+ zs = mean[:, 2]
+ counts = vis_mask[:, :T].sum(dim=-1) # (B,)
+ mask = counts < 0.8 * T
+ zs[mask] = torch.inf
+ sel = torch.argmin(zs)
+ return bb_min.amin(dim=0), bb_max.amax(dim=0), mean[sel]
+
+
+def track_to_colors(track_ids):
+ """
+ :param track_ids (B)
+ """
+ color_map = torch.from_numpy(get_colors()).to(track_ids)
+ return color_map[track_ids] / 255 # (B, 3)
+
+
+def get_colors():
+ # color_file = os.path.abspath(os.path.join(__file__, "../colors_phalp.txt"))
+ color_file = os.path.abspath(os.path.join(__file__, "../colors.txt"))
+ RGB_tuples = np.vstack(
+ [
+ np.loadtxt(color_file, skiprows=0),
+ # np.loadtxt(color_file, skiprows=1),
+ np.random.uniform(0, 255, size=(10000, 3)),
+ [[0, 0, 0]],
+ ]
+ )
+ b = np.where(RGB_tuples == 0)
+ RGB_tuples[b] = 1
+ return RGB_tuples.astype(np.float32)
+
+
+def checkerboard_geometry(
+ length=12.0,
+ color0=[0.8, 0.9, 0.9],
+ color1=[0.6, 0.7, 0.7],
+ tile_width=0.5,
+ alpha=1.0,
+ up="y",
+ c1=0.0,
+ c2=0.0,
+):
+ assert up == "y" or up == "z"
+ color0 = np.array(color0 + [alpha])
+ color1 = np.array(color1 + [alpha])
+ radius = length / 2.0
+ num_rows = num_cols = max(2, int(length / tile_width))
+ vertices = []
+ vert_colors = []
+ faces = []
+ face_colors = []
+ for i in range(num_rows):
+ for j in range(num_cols):
+ u0, v0 = j * tile_width - radius, i * tile_width - radius
+ us = np.array([u0, u0, u0 + tile_width, u0 + tile_width])
+ vs = np.array([v0, v0 + tile_width, v0 + tile_width, v0])
+ zs = np.zeros(4)
+ if up == "y":
+ cur_verts = np.stack([us, zs, vs], axis=-1) # (4, 3)
+ cur_verts[:, 0] += c1
+ cur_verts[:, 2] += c2
+ else:
+ cur_verts = np.stack([us, vs, zs], axis=-1) # (4, 3)
+ cur_verts[:, 0] += c1
+ cur_verts[:, 1] += c2
+
+ cur_faces = np.array(
+ [[0, 1, 3], [1, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=np.int64
+ )
+ cur_faces += 4 * (i * num_cols + j) # the number of previously added verts
+ use_color0 = (i % 2 == 0 and j % 2 == 0) or (i % 2 == 1 and j % 2 == 1)
+ cur_color = color0 if use_color0 else color1
+ cur_colors = np.array([cur_color, cur_color, cur_color, cur_color])
+
+ vertices.append(cur_verts)
+ faces.append(cur_faces)
+ vert_colors.append(cur_colors)
+ face_colors.append(cur_colors)
+
+ vertices = np.concatenate(vertices, axis=0).astype(np.float32)
+ vert_colors = np.concatenate(vert_colors, axis=0).astype(np.float32)
+ faces = np.concatenate(faces, axis=0).astype(np.float32)
+ face_colors = np.concatenate(face_colors, axis=0).astype(np.float32)
+
+ return vertices, faces, vert_colors, face_colors
+
+
+def camera_marker_geometry(radius, height, up):
+ assert up == "y" or up == "z"
+ if up == "y":
+ vertices = np.array(
+ [
+ [-radius, -radius, 0],
+ [radius, -radius, 0],
+ [radius, radius, 0],
+ [-radius, radius, 0],
+ [0, 0, height],
+ ]
+ )
+ else:
+ vertices = np.array(
+ [
+ [-radius, 0, -radius],
+ [radius, 0, -radius],
+ [radius, 0, radius],
+ [-radius, 0, radius],
+ [0, -height, 0],
+ ]
+ )
+
+ faces = np.array(
+ [[0, 3, 1], [1, 3, 2], [0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4],]
+ )
+
+ face_colors = np.array(
+ [
+ [1.0, 1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0, 1.0],
+ [0.0, 1.0, 0.0, 1.0],
+ [1.0, 0.0, 0.0, 1.0],
+ [0.0, 1.0, 0.0, 1.0],
+ [1.0, 0.0, 0.0, 1.0],
+ ]
+ )
+ return vertices, faces, face_colors
+
+
+def vis_keypoints(
+ keypts_list,
+ img_size,
+ radius=6,
+ thickness=3,
+ kpt_score_thr=0.3,
+ dataset="TopDownCocoDataset",
+):
+ """
+ Visualize keypoints
+ From ViTPose/mmpose/apis/inference.py
+ """
+ palette = np.array(
+ [
+ [255, 128, 0],
+ [255, 153, 51],
+ [255, 178, 102],
+ [230, 230, 0],
+ [255, 153, 255],
+ [153, 204, 255],
+ [255, 102, 255],
+ [255, 51, 255],
+ [102, 178, 255],
+ [51, 153, 255],
+ [255, 153, 153],
+ [255, 102, 102],
+ [255, 51, 51],
+ [153, 255, 153],
+ [102, 255, 102],
+ [51, 255, 51],
+ [0, 255, 0],
+ [0, 0, 255],
+ [255, 0, 0],
+ [255, 255, 255],
+ ]
+ )
+
+ if dataset in (
+ "TopDownCocoDataset",
+ "BottomUpCocoDataset",
+ "TopDownOCHumanDataset",
+ "AnimalMacaqueDataset",
+ ):
+ # show the results
+ skeleton = [
+ [15, 13],
+ [13, 11],
+ [16, 14],
+ [14, 12],
+ [11, 12],
+ [5, 11],
+ [6, 12],
+ [5, 6],
+ [5, 7],
+ [6, 8],
+ [7, 9],
+ [8, 10],
+ [1, 2],
+ [0, 1],
+ [0, 2],
+ [1, 3],
+ [2, 4],
+ [3, 5],
+ [4, 6],
+ ]
+
+ pose_link_color = palette[
+ [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]
+ ]
+ pose_kpt_color = palette[
+ [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]
+ ]
+
+ elif dataset == "TopDownCocoWholeBodyDataset":
+ # show the results
+ skeleton = [
+ [15, 13],
+ [13, 11],
+ [16, 14],
+ [14, 12],
+ [11, 12],
+ [5, 11],
+ [6, 12],
+ [5, 6],
+ [5, 7],
+ [6, 8],
+ [7, 9],
+ [8, 10],
+ [1, 2],
+ [0, 1],
+ [0, 2],
+ [1, 3],
+ [2, 4],
+ [3, 5],
+ [4, 6],
+ [15, 17],
+ [15, 18],
+ [15, 19],
+ [16, 20],
+ [16, 21],
+ [16, 22],
+ [91, 92],
+ [92, 93],
+ [93, 94],
+ [94, 95],
+ [91, 96],
+ [96, 97],
+ [97, 98],
+ [98, 99],
+ [91, 100],
+ [100, 101],
+ [101, 102],
+ [102, 103],
+ [91, 104],
+ [104, 105],
+ [105, 106],
+ [106, 107],
+ [91, 108],
+ [108, 109],
+ [109, 110],
+ [110, 111],
+ [112, 113],
+ [113, 114],
+ [114, 115],
+ [115, 116],
+ [112, 117],
+ [117, 118],
+ [118, 119],
+ [119, 120],
+ [112, 121],
+ [121, 122],
+ [122, 123],
+ [123, 124],
+ [112, 125],
+ [125, 126],
+ [126, 127],
+ [127, 128],
+ [112, 129],
+ [129, 130],
+ [130, 131],
+ [131, 132],
+ ]
+
+ pose_link_color = palette[
+ [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]
+ + [16, 16, 16, 16, 16, 16]
+ + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
+ + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
+ ]
+ pose_kpt_color = palette[
+ [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]
+ + [0, 0, 0, 0, 0, 0]
+ + [19] * (68 + 42)
+ ]
+
+ elif dataset == "TopDownAicDataset":
+ skeleton = [
+ [2, 1],
+ [1, 0],
+ [0, 13],
+ [13, 3],
+ [3, 4],
+ [4, 5],
+ [8, 7],
+ [7, 6],
+ [6, 9],
+ [9, 10],
+ [10, 11],
+ [12, 13],
+ [0, 6],
+ [3, 9],
+ ]
+
+ pose_link_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 0, 7, 7]]
+ pose_kpt_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 0, 0]]
+
+ elif dataset == "TopDownMpiiDataset":
+ skeleton = [
+ [0, 1],
+ [1, 2],
+ [2, 6],
+ [6, 3],
+ [3, 4],
+ [4, 5],
+ [6, 7],
+ [7, 8],
+ [8, 9],
+ [8, 12],
+ [12, 11],
+ [11, 10],
+ [8, 13],
+ [13, 14],
+ [14, 15],
+ ]
+
+ pose_link_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 9, 9, 9, 9, 9, 9]]
+ pose_kpt_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 0, 9, 9, 9, 9, 9, 9]]
+
+ elif dataset == "TopDownMpiiTrbDataset":
+ skeleton = [
+ [12, 13],
+ [13, 0],
+ [13, 1],
+ [0, 2],
+ [1, 3],
+ [2, 4],
+ [3, 5],
+ [0, 6],
+ [1, 7],
+ [6, 7],
+ [6, 8],
+ [7, 9],
+ [8, 10],
+ [9, 11],
+ [14, 15],
+ [16, 17],
+ [18, 19],
+ [20, 21],
+ [22, 23],
+ [24, 25],
+ [26, 27],
+ [28, 29],
+ [30, 31],
+ [32, 33],
+ [34, 35],
+ [36, 37],
+ [38, 39],
+ ]
+
+ pose_link_color = palette[[16] * 14 + [19] * 13]
+ pose_kpt_color = palette[[16] * 14 + [0] * 26]
+
+ elif dataset in ("OneHand10KDataset", "FreiHandDataset", "PanopticDataset"):
+ skeleton = [
+ [0, 1],
+ [1, 2],
+ [2, 3],
+ [3, 4],
+ [0, 5],
+ [5, 6],
+ [6, 7],
+ [7, 8],
+ [0, 9],
+ [9, 10],
+ [10, 11],
+ [11, 12],
+ [0, 13],
+ [13, 14],
+ [14, 15],
+ [15, 16],
+ [0, 17],
+ [17, 18],
+ [18, 19],
+ [19, 20],
+ ]
+
+ pose_link_color = palette[
+ [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
+ ]
+ pose_kpt_color = palette[
+ [0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
+ ]
+
+ elif dataset == "InterHand2DDataset":
+ skeleton = [
+ [0, 1],
+ [1, 2],
+ [2, 3],
+ [4, 5],
+ [5, 6],
+ [6, 7],
+ [8, 9],
+ [9, 10],
+ [10, 11],
+ [12, 13],
+ [13, 14],
+ [14, 15],
+ [16, 17],
+ [17, 18],
+ [18, 19],
+ [3, 20],
+ [7, 20],
+ [11, 20],
+ [15, 20],
+ [19, 20],
+ ]
+
+ pose_link_color = palette[
+ [0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, 16]
+ ]
+ pose_kpt_color = palette[
+ [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 0]
+ ]
+
+ elif dataset == "Face300WDataset":
+ # show the results
+ skeleton = []
+
+ pose_link_color = palette[[]]
+ pose_kpt_color = palette[[19] * 68]
+ kpt_score_thr = 0
+
+ elif dataset == "FaceAFLWDataset":
+ # show the results
+ skeleton = []
+
+ pose_link_color = palette[[]]
+ pose_kpt_color = palette[[19] * 19]
+ kpt_score_thr = 0
+
+ elif dataset == "FaceCOFWDataset":
+ # show the results
+ skeleton = []
+
+ pose_link_color = palette[[]]
+ pose_kpt_color = palette[[19] * 29]
+ kpt_score_thr = 0
+
+ elif dataset == "FaceWFLWDataset":
+ # show the results
+ skeleton = []
+
+ pose_link_color = palette[[]]
+ pose_kpt_color = palette[[19] * 98]
+ kpt_score_thr = 0
+
+ elif dataset == "AnimalHorse10Dataset":
+ skeleton = [
+ [0, 1],
+ [1, 12],
+ [12, 16],
+ [16, 21],
+ [21, 17],
+ [17, 11],
+ [11, 10],
+ [10, 8],
+ [8, 9],
+ [9, 12],
+ [2, 3],
+ [3, 4],
+ [5, 6],
+ [6, 7],
+ [13, 14],
+ [14, 15],
+ [18, 19],
+ [19, 20],
+ ]
+
+ pose_link_color = palette[[4] * 10 + [6] * 2 + [6] * 2 + [7] * 2 + [7] * 2]
+ pose_kpt_color = palette[
+ [4, 4, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 4, 7, 7, 7, 4, 4, 7, 7, 7, 4]
+ ]
+
+ elif dataset == "AnimalFlyDataset":
+ skeleton = [
+ [1, 0],
+ [2, 0],
+ [3, 0],
+ [4, 3],
+ [5, 4],
+ [7, 6],
+ [8, 7],
+ [9, 8],
+ [11, 10],
+ [12, 11],
+ [13, 12],
+ [15, 14],
+ [16, 15],
+ [17, 16],
+ [19, 18],
+ [20, 19],
+ [21, 20],
+ [23, 22],
+ [24, 23],
+ [25, 24],
+ [27, 26],
+ [28, 27],
+ [29, 28],
+ [30, 3],
+ [31, 3],
+ ]
+
+ pose_link_color = palette[[0] * 25]
+ pose_kpt_color = palette[[0] * 32]
+
+ elif dataset == "AnimalLocustDataset":
+ skeleton = [
+ [1, 0],
+ [2, 1],
+ [3, 2],
+ [4, 3],
+ [6, 5],
+ [7, 6],
+ [9, 8],
+ [10, 9],
+ [11, 10],
+ [13, 12],
+ [14, 13],
+ [15, 14],
+ [17, 16],
+ [18, 17],
+ [19, 18],
+ [21, 20],
+ [22, 21],
+ [24, 23],
+ [25, 24],
+ [26, 25],
+ [28, 27],
+ [29, 28],
+ [30, 29],
+ [32, 31],
+ [33, 32],
+ [34, 33],
+ ]
+
+ pose_link_color = palette[[0] * 26]
+ pose_kpt_color = palette[[0] * 35]
+
+ elif dataset == "AnimalZebraDataset":
+ skeleton = [[1, 0], [2, 1], [3, 2], [4, 2], [5, 7], [6, 7], [7, 2], [8, 7]]
+
+ pose_link_color = palette[[0] * 8]
+ pose_kpt_color = palette[[0] * 9]
+
+    elif dataset == "AnimalPoseDataset":
+ skeleton = [
+ [0, 1],
+ [0, 2],
+ [1, 3],
+ [0, 4],
+ [1, 4],
+ [4, 5],
+ [5, 7],
+ [6, 7],
+ [5, 8],
+ [8, 12],
+ [12, 16],
+ [5, 9],
+ [9, 13],
+ [13, 17],
+ [6, 10],
+ [10, 14],
+ [14, 18],
+ [6, 11],
+ [11, 15],
+ [15, 19],
+ ]
+
+ pose_link_color = palette[[0] * 20]
+ pose_kpt_color = palette[[0] * 20]
+ else:
+        raise NotImplementedError()
+
+ img_w, img_h = img_size
+ img = 255 * np.ones((img_h, img_w, 3), dtype=np.uint8)
+ img = imshow_keypoints(
+ img,
+ keypts_list,
+ skeleton,
+ kpt_score_thr,
+ pose_kpt_color,
+ pose_link_color,
+ radius,
+ thickness,
+ )
+ alpha = 255 * (img != 255).any(axis=-1, keepdims=True).astype(np.uint8)
+ return np.concatenate([img, alpha], axis=-1)
+
+
+def imshow_keypoints(
+ img,
+ pose_result,
+ skeleton=None,
+ kpt_score_thr=0.3,
+ pose_kpt_color=None,
+ pose_link_color=None,
+ radius=4,
+ thickness=1,
+ show_keypoint_weight=False,
+):
+ """Draw keypoints and links on an image.
+ From ViTPose/mmpose/core/visualization/image.py
+
+ Args:
+ img (H, W, 3) array
+ pose_result (list[kpts]): The poses to draw. Each element kpts is
+ a set of K keypoints as an Kx3 numpy.ndarray, where each
+ keypoint is represented as x, y, score.
+ kpt_score_thr (float, optional): Minimum score of keypoints
+ to be shown. Default: 0.3.
+ pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
+ the keypoint will not be drawn.
+ pose_link_color (np.array[Mx3]): Color of M links. If None, the
+ links will not be drawn.
+ thickness (int): Thickness of lines.
+ show_keypoint_weight (bool): If True, opacity indicates keypoint score
+ """
+ import math
+ img_h, img_w, _ = img.shape
+ idcs = [0, 16, 15, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11]
+ for kpts in pose_result:
+ kpts = np.array(kpts, copy=False)[idcs]
+
+ # draw each point on image
+ if pose_kpt_color is not None:
+ assert len(pose_kpt_color) == len(kpts)
+ for kid, kpt in enumerate(kpts):
+ x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
+ if kpt_score > kpt_score_thr:
+ color = tuple(int(c) for c in pose_kpt_color[kid])
+ if show_keypoint_weight:
+ img_copy = img.copy()
+ cv2.circle(
+ img_copy, (int(x_coord), int(y_coord)), radius, color, -1
+ )
+ transparency = max(0, min(1, kpt_score))
+ cv2.addWeighted(
+ img_copy, transparency, img, 1 - transparency, 0, dst=img
+ )
+ else:
+ cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
+
+ # draw links
+ if skeleton is not None and pose_link_color is not None:
+ assert len(pose_link_color) == len(skeleton)
+ for sk_id, sk in enumerate(skeleton):
+ pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
+ pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
+ if (
+ pos1[0] > 0
+ and pos1[0] < img_w
+ and pos1[1] > 0
+ and pos1[1] < img_h
+ and pos2[0] > 0
+ and pos2[0] < img_w
+ and pos2[1] > 0
+ and pos2[1] < img_h
+ and kpts[sk[0], 2] > kpt_score_thr
+ and kpts[sk[1], 2] > kpt_score_thr
+ ):
+ color = tuple(int(c) for c in pose_link_color[sk_id])
+ if show_keypoint_weight:
+ img_copy = img.copy()
+ X = (pos1[0], pos2[0])
+ Y = (pos1[1], pos2[1])
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
+ stickwidth = 2
+ polygon = cv2.ellipse2Poly(
+ (int(mX), int(mY)),
+ (int(length / 2), int(stickwidth)),
+ int(angle),
+ 0,
+ 360,
+ 1,
+ )
+ cv2.fillConvexPoly(img_copy, polygon, color)
+ transparency = max(
+ 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))
+ )
+ cv2.addWeighted(
+ img_copy, transparency, img, 1 - transparency, 0, dst=img
+ )
+ else:
+ cv2.line(img, pos1, pos2, color, thickness=thickness)
+
+ return img
\ No newline at end of file
diff --git a/lib/vis/viewer.py b/lib/vis/viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f0f2ba200c3066bda4c6193a9645f39f27e20a9
--- /dev/null
+++ b/lib/vis/viewer.py
@@ -0,0 +1,308 @@
+import os
+import os.path as op
+import re
+from abc import abstractmethod
+
+import matplotlib.cm as cm
+import numpy as np
+from aitviewer.headless import HeadlessRenderer
+from aitviewer.renderables.billboard import Billboard
+from aitviewer.renderables.meshes import Meshes
+from aitviewer.scene.camera import OpenCVCamera
+from aitviewer.scene.material import Material
+from aitviewer.utils.so3 import aa2rot_numpy
+from aitviewer.viewer import Viewer
+from easydict import EasyDict as edict
+from loguru import logger
+from PIL import Image
+from tqdm import tqdm
+import random
+
+OBJ_ID = 100
+SMPLX_ID = 150
+LEFT_ID = 200
+RIGHT_ID = 250
+SEGM_IDS = {"object": OBJ_ID, "smplx": SMPLX_ID, "left": LEFT_ID, "right": RIGHT_ID}
+
+cmap = cm.get_cmap("plasma")
+materials = {
+ "white": Material(color=(1.0, 1.0, 1.0, 1.0), ambient=0.2),
+ "green": Material(color=(0.0, 1.0, 0.0, 1.0), ambient=0.2),
+ "blue": Material(color=(0.0, 0.0, 1.0, 1.0), ambient=0.2),
+ "red": Material(color=(0.969, 0.106, 0.059, 1.0), ambient=0.2),
+ "cyan": Material(color=(0.051, 0.659, 0.051, 1.0), ambient=0.2),
+ "light-blue": Material(color=(0.588, 0.5647, 0.9725, 1.0), ambient=0.2),
+ "cyan-light": Material(color=(0.051, 0.659, 0.051, 1.0), ambient=0.2),
+ "dark-light": Material(color=(0.404, 0.278, 0.278, 1.0), ambient=0.2),
+ "rice": Material(color=(0.922, 0.922, 0.102, 1.0), ambient=0.2),
+ "whac-whac": Material(color=(167/255, 193/255, 203/255, 1.0), ambient=0.2),
+ "whac-wham": Material(color=(165/255, 153/255, 174/255, 1.0), ambient=0.2),
+ "pace-blue": Material(color=(0.584, 0.902, 0.976, 1.0), ambient=0.2),
+ "pace-green": Material(color=(0.631, 1.0, 0.753, 1.0), ambient=0.2),
+ "director-purple": Material(color=(0.804, 0.6, 0.820, 1.0), ambient=0.2),
+ "director-blue": Material(color=(0.207, 0.596, 0.792, 1.0), ambient=0.2),
+ "none": None,
+}
+color_list = list(materials.keys())
+
+def random_material():
+ return Material(color=(random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1), 1), ambient=0.2)
+
+
+class ViewerData(edict):
+ """
+ Interface to standardize viewer data.
+ """
+
+ def __init__(self, Rt, K, cols, rows, imgnames=None):
+ self.imgnames = imgnames
+ self.Rt = Rt
+ self.K = K
+ self.num_frames = Rt.shape[0]
+ self.cols = cols
+ self.rows = rows
+ self.validate_format()
+
+ def validate_format(self):
+ assert len(self.Rt.shape) == 3
+ assert self.Rt.shape[0] == self.num_frames
+ assert self.Rt.shape[1] == 3
+ assert self.Rt.shape[2] == 4
+
+ assert len(self.K.shape) == 2
+ assert self.K.shape[0] == 3
+ assert self.K.shape[1] == 3
+ if self.imgnames is not None:
+ assert self.num_frames == len(self.imgnames)
+ assert self.num_frames > 0
+ im_p = self.imgnames[0]
+ assert op.exists(im_p), f"Image path {im_p} does not exist"
+
+
+class ARCTICViewer:
+ def __init__(
+ self,
+ render_types=["rgb", "depth", "mask"],
+ interactive=True,
+ size=(2024, 2024),
+ ):
+ if not interactive:
+ v = HeadlessRenderer()
+ else:
+ v = Viewer(size=size)
+
+ self.v = v
+ self.interactive = interactive
+ # self.layers = layers
+ self.render_types = render_types
+
+ def view_interactive(self):
+ self.v.run()
+
+ def view_fn_headless(self, num_iter, out_folder):
+ v = self.v
+
+ v._init_scene()
+
+ logger.info("Rendering to video")
+ if "video" in self.render_types:
+ vid_p = op.join(out_folder, "video.mp4")
+ v.save_video(video_dir=vid_p)
+
+ pbar = tqdm(range(num_iter))
+ for fidx in pbar:
+ out_rgb = op.join(out_folder, "images", f"rgb/{fidx:04d}.png")
+ out_mask = op.join(out_folder, "images", f"mask/{fidx:04d}.png")
+ out_depth = op.join(out_folder, "images", f"depth/{fidx:04d}.npy")
+
+ # render RGB, depth, segmentation masks
+ if "rgb" in self.render_types:
+ v.export_frame(out_rgb)
+ if "depth" in self.render_types:
+ os.makedirs(op.dirname(out_depth), exist_ok=True)
+ render_depth(v, out_depth)
+ if "mask" in self.render_types:
+ os.makedirs(op.dirname(out_mask), exist_ok=True)
+ render_mask(v, out_mask)
+ v.scene.next_frame()
+ logger.info(f"Exported to {out_folder}")
+
+ @abstractmethod
+ def load_data(self):
+ pass
+
+ def check_format(self, batch):
+ meshes_all, data = batch
+ assert isinstance(meshes_all, dict)
+ assert len(meshes_all) > 0
+ for mesh in meshes_all.values():
+ assert isinstance(mesh, Meshes)
+ assert isinstance(data, ViewerData)
+
+ def render_seq(self, batch, out_folder="./render_out", floor_y=0):
+ meshes_all, data = batch
+ self.setup_viewer(data, floor_y)
+ for mesh in meshes_all.values():
+ self.v.scene.add(mesh)
+ if self.interactive:
+ self.view_interactive()
+ else:
+ num_iter = data["num_frames"]
+ self.view_fn_headless(num_iter, out_folder)
+
+ def setup_viewer(self, data, floor_y):
+ v = self.v
+ fps = 30
+ if "imgnames" in data:
+ setup_billboard(data, v)
+
+ # camera.show_path()
+        v.run_animations = False  # do not autoplay by default
+ v.playback_fps = fps
+ v.scene.fps = fps
+ v.scene.origin.enabled = False
+ v.scene.floor.enabled = False
+ v.auto_set_floor = False
+ # v.scene.camera.position = np.array((0.0, 0.0, 0))
+ self.v = v
+
+
+def dist2vc(dist_ro, dist_lo, dist_o, _cmap, tf_fn=None):
+ if tf_fn is not None:
+ exp_map = tf_fn
+ else:
+ exp_map = small_exp_map
+ dist_ro = exp_map(dist_ro)
+ dist_lo = exp_map(dist_lo)
+ dist_o = exp_map(dist_o)
+
+ vc_ro = _cmap(dist_ro)
+ vc_lo = _cmap(dist_lo)
+ vc_o = _cmap(dist_o)
+ return vc_ro, vc_lo, vc_o
+
+
+def small_exp_map(_dist):
+ dist = np.copy(_dist)
+ # dist = 1.0 - np.clip(dist, 0, 0.1) / 0.1
+ dist = np.exp(-20.0 * dist)
+ return dist
+
+
+def construct_viewer_meshes(data, draw_edges=False, flat_shading=True):
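+    # Wrap each entry of the vis_dict into an aitviewer Meshes renderable,
+    # resolving its material from the named color table (or a random material
+    # when "random" is requested).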
+ rotation_flip = aa2rot_numpy(np.array([1, 0, 0]) * np.pi)
+ meshes = {}
+ for key, val in data.items():
+        draw_edges = 'single' in key
+        # flat shading is currently disabled for every mesh type
+        flat_shading = False
+ v3d = val["v3d"]
+ if not isinstance(val["color"], str):
+ val["color"] = color_list[val["color"]]
+ if val["color"] == "random":
+ mesh_material = random_material()
+ else:
+ mesh_material = materials[val["color"]]
+ meshes[key] = Meshes(
+ v3d,
+ val["f3d"],
+ vertex_colors=val["vc"],
+ face_colors=val["fc"] if "fc" in val else None,
+ name=val["name"],
+ flat_shading=flat_shading,
+ draw_edges=draw_edges,
+ material=mesh_material,
+ # rotation=rotation_flip,
+ )
+ return meshes
+
+
+def setup_viewer(
+ v, shared_folder_p, video, images_path, data, flag, seq_name, side_angle
+):
+ fps = 10
+ cols, rows = 224, 224
+ focal = 1000.0
+
+ # setup image paths
+ regex = re.compile(r"(\d*)$")
+
+ def sort_key(x):
+ name = os.path.splitext(x)[0]
+ return int(regex.search(name).group(0))
+
+ # setup billboard
+ images_path = op.join(shared_folder_p, "images")
+ images_paths = [
+ os.path.join(images_path, f)
+ for f in sorted(os.listdir(images_path), key=sort_key)
+ ]
+ assert len(images_paths) > 0
+
+ cam_t = data[f"{flag}.object.cam_t"]
+ num_frames = min(cam_t.shape[0], len(images_paths))
+ cam_t = cam_t[:num_frames]
+ # setup camera
+    K = np.array([[focal, 0, cols / 2.0], [0, focal, rows / 2.0], [0, 0, 1]])
+ Rt = np.zeros((num_frames, 3, 4))
+ Rt[:, :, 3] = cam_t
+ Rt[:, :3, :3] = np.eye(3)
+ Rt[:, 1:3, :3] *= -1.0
+
+ camera = OpenCVCamera(K, Rt, cols, rows, viewer=v)
+ if side_angle is None:
+ billboard = Billboard.from_camera_and_distance(
+ camera, 10.0, cols, rows, images_paths
+ )
+ v.scene.add(billboard)
+ v.scene.add(camera)
+ v.run_animations = True # autoplay
+ v.playback_fps = fps
+ v.scene.fps = fps
+ v.scene.origin.enabled = False
+ v.scene.floor.enabled = False
+ v.auto_set_floor = False
+ v.scene.floor.position[1] = -3
+ v.set_temp_camera(camera)
+ # v.scene.camera.position = np.array((0.0, 0.0, 0))
+ return v
+
+
+def render_depth(v, depth_p):
+ depth = np.array(v.get_depth()).astype(np.float16)
+ np.save(depth_p, depth)
+
+
+def render_mask(v, mask_p):
+ nodes_uid = {node.name: node.uid for node in v.scene.collect_nodes()}
+ my_cmap = {
+ uid: [SEGM_IDS[name], SEGM_IDS[name], SEGM_IDS[name]]
+ for name, uid in nodes_uid.items()
+ if name in SEGM_IDS.keys()
+ }
+ mask = np.array(v.get_mask(color_map=my_cmap)).astype(np.uint8)
+ mask = Image.fromarray(mask)
+ mask.save(mask_p)
+
+
+def setup_billboard(data, v):
+ images_paths = data.imgnames
+ K = data.K
+ Rt = data.Rt
+ rows = data.rows
+ cols = data.cols
+ camera = OpenCVCamera(K, Rt, cols, rows, viewer=v)
+ if images_paths is not None:
+ billboard = Billboard.from_camera_and_distance(
+ camera, 10.0, cols, rows, images_paths
+ )
+ v.scene.add(billboard)
+ v.scene.add(camera)
+ v.scene.camera.load_cam()
+ v.set_temp_camera(camera)
diff --git a/lib/vis/wham_tools/__pycache__/tools.cpython-310.pyc b/lib/vis/wham_tools/__pycache__/tools.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df1f556d13adac6ae0fe0c8b0cdd1b1c48662637
Binary files /dev/null and b/lib/vis/wham_tools/__pycache__/tools.cpython-310.pyc differ
diff --git a/lib/vis/wham_tools/tools.py b/lib/vis/wham_tools/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9e7f053f49a11e12a7f01a76dbca9059c359f89
--- /dev/null
+++ b/lib/vis/wham_tools/tools.py
@@ -0,0 +1,56 @@
+import numpy as np
+
+
+def checkerboard_geometry(
+ length=12.0,
+ color0=[172/255, 172/255, 172/255],
+ color1=[215/255, 215/255, 215/255],
+ tile_width=0.5,
+ alpha=1.0,
+ up="y",
+ c1=0.0,
+ c2=0.0,
+):
+ assert up == "y" or up == "z"
+ color0 = np.array(color0 + [alpha])
+ color1 = np.array(color1 + [alpha])
+ radius = length / 2.0
+ num_rows = num_cols = max(2, int(length / tile_width))
+ vertices = []
+ vert_colors = []
+ faces = []
+ face_colors = []
+ for i in range(num_rows):
+ for j in range(num_cols):
+ u0, v0 = j * tile_width - radius, i * tile_width - radius
+ us = np.array([u0, u0, u0 + tile_width, u0 + tile_width])
+ vs = np.array([v0, v0 + tile_width, v0 + tile_width, v0])
+ zs = np.zeros(4)
+ if up == "y":
+ cur_verts = np.stack([us, zs, vs], axis=-1) # (4, 3)
+ cur_verts[:, 0] += c1
+ cur_verts[:, 2] += c2
+ else:
+ cur_verts = np.stack([us, vs, zs], axis=-1) # (4, 3)
+ cur_verts[:, 0] += c1
+ cur_verts[:, 1] += c2
+
+ cur_faces = np.array(
+ [[0, 1, 3], [1, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=np.int64
+ )
+ cur_faces += 4 * (i * num_cols + j) # the number of previously added verts
+ use_color0 = (i % 2 == 0 and j % 2 == 0) or (i % 2 == 1 and j % 2 == 1)
+ cur_color = color0 if use_color0 else color1
+ cur_colors = np.array([cur_color, cur_color, cur_color, cur_color])
+
+ vertices.append(cur_verts)
+ faces.append(cur_faces)
+ vert_colors.append(cur_colors)
+ face_colors.append(cur_colors)
+
+ vertices = np.concatenate(vertices, axis=0).astype(np.float32)
+ vert_colors = np.concatenate(vert_colors, axis=0).astype(np.float32)
+ faces = np.concatenate(faces, axis=0).astype(np.float32)
+ face_colors = np.concatenate(face_colors, axis=0).astype(np.float32)
+
+ return vertices, faces, vert_colors, face_colors
\ No newline at end of file
diff --git a/license.txt b/license.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3811ab9a1eb86022f72d9c871edb3c05907a158f
--- /dev/null
+++ b/license.txt
@@ -0,0 +1,402 @@
+Attribution-NonCommercial-NoDerivatives 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial-NoDerivatives 4.0
+International Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial-NoDerivatives 4.0 International Public
+License ("Public License"). To the extent this Public License may be
+interpreted as a contract, You are granted the Licensed Rights in
+consideration of Your acceptance of these terms and conditions, and the
+Licensor grants You such rights in consideration of benefits the
+Licensor receives from making the Licensed Material available under
+these terms and conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+ c. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ d. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ e. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ f. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ g. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ h. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ i. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ j. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ k. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce and reproduce, but not Share, Adapted Material
+ for NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material, You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ For the avoidance of doubt, You do not have permission under
+ this Public License to Share Adapted Material.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only and provided You do not Share Adapted Material;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the "Licensor." The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index bcb91bb3e0308f3fdadc2c88a746919e194f3426..66530b53768d7236f7f9dc8fae910d2718218e52 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,37 @@
-git+https://github.com/princeton-vl/DROID-SLAM.git
\ No newline at end of file
+numpy==1.26.4
+opencv-python
+pyrender
+scikit-image
+smplx==0.1.28
+yacs
+mmcv==1.3.9
+timm
+einops
+xtcocotools
+pandas
+hydra-core
+hydra-submitit-launcher
+hydra-colorlog
+pyrootutils
+rich
+webdataset
+ultralytics
+pulp
+supervision
+pycocotools
+joblib
+natsort
+git+https://github.com/facebookresearch/pytorch3d.git@stable
+torch-scatter==2.1.2
+evo
+pytorch-minimize
+mmengine==0.10.4
+HTML4Vision
+plyfile
+aitviewer
+easydict
+loguru
+dill
+lapx
+chumpy@git+https://github.com/mattloper/chumpy
+moderngl-window==2.4.6
\ No newline at end of file
diff --git a/scripts/scripts_test_video/detect_track_video.py b/scripts/scripts_test_video/detect_track_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..308dedbaf7a5f58f9a042dc1076fbb3a0668c8a0
--- /dev/null
+++ b/scripts/scripts_test_video/detect_track_video.py
@@ -0,0 +1,72 @@
+import sys
+import os
+sys.path.insert(0, os.path.dirname(__file__) + '/../..')
+
+import argparse
+import numpy as np
+from glob import glob
+from lib.pipeline.tools import detect_track
+from natsort import natsorted
+import subprocess
+
+
+def extract_frames(video_path, output_folder):
+ if not os.path.exists(output_folder):
+ os.makedirs(output_folder)
+
+ command = [
+ 'ffmpeg',
+ '-i', video_path,
+ '-vf', 'fps=30',
+ '-start_number', '0',
+ os.path.join(output_folder, '%04d.jpg')
+ ]
+
+ subprocess.run(command, check=True)
+
+
+def detect_track_video(args):
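+    """Extract frames from the input video (30 fps via ffmpeg), run hand detection and tracking,
+    cache the resulting boxes/tracks as .npy under the sequence folder, and return
+    (start_idx, end_idx, seq_folder, imgfiles)."""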
+ file = args.video_path
+ root = os.path.dirname(file)
+ seq = os.path.basename(file).split('.')[0]
+
+ seq_folder = f'{root}/{seq}'
+ img_folder = f'{seq_folder}/extracted_images'
+ os.makedirs(seq_folder, exist_ok=True)
+ os.makedirs(img_folder, exist_ok=True)
+ print(f'Running detect_track on {file} ...')
+
+ ##### Extract Frames #####
+ imgfiles = natsorted(glob(f'{img_folder}/*.jpg'))
+ # print(imgfiles[:10])
+ if len(imgfiles) > 0:
+ print("Skip extracting frames")
+ else:
+ _ = extract_frames(file, img_folder)
+ imgfiles = natsorted(glob(f'{img_folder}/*.jpg'))
+
+ ##### Detection + Track #####
+ print('Detect and Track ...')
+
+ start_idx = 0
+ end_idx = len(imgfiles)
+
+ if os.path.exists(f'{seq_folder}/tracks_{start_idx}_{end_idx}/model_boxes.npy'):
+ print(f"skip track for {start_idx}_{end_idx}")
+ return start_idx, end_idx, seq_folder, imgfiles
+ os.makedirs(f"{seq_folder}/tracks_{start_idx}_{end_idx}", exist_ok=True)
+ boxes_, tracks_ = detect_track(imgfiles, thresh=0.2)
+ np.save(f'{seq_folder}/tracks_{start_idx}_{end_idx}/model_boxes.npy', boxes_)
+ np.save(f'{seq_folder}/tracks_{start_idx}_{end_idx}/model_tracks.npy', tracks_)
+
+ return start_idx, end_idx, seq_folder, imgfiles
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_focal", type=float)
+ parser.add_argument("--video_path", type=str, default='')
+ parser.add_argument("--input_type", type=str, default='file')
+ args = parser.parse_args()
+
+ detect_track_video(args)
\ No newline at end of file
diff --git a/scripts/scripts_test_video/hawor_slam.py b/scripts/scripts_test_video/hawor_slam.py
new file mode 100644
index 0000000000000000000000000000000000000000..a557bd102cca778bc77eb424b9b92317f326b6cf
--- /dev/null
+++ b/scripts/scripts_test_video/hawor_slam.py
@@ -0,0 +1,141 @@
+import sys
+import os
+
+from natsort import natsorted
+
+sys.path.insert(0, os.path.dirname(__file__) + '/../..')
+
+import argparse
+from tqdm import tqdm
+import numpy as np
+import torch
+import cv2
+from PIL import Image
+from glob import glob
+from pycocotools import mask as masktool
+from lib.pipeline.masked_droid_slam import *
+from lib.pipeline.est_scale import *
+from hawor.utils.process import block_print, enable_print
+
+sys.path.insert(0, os.path.dirname(__file__) + '/../../thirdparty/Metric3D')
+from metric import Metric3D
+
+
+def get_all_mp4_files(folder_path):
+ # Ensure the folder path is absolute
+ folder_path = os.path.abspath(folder_path)
+
+ # Recursively search for all .mp4 files in the folder and its subfolders
+ mp4_files = glob(os.path.join(folder_path, '**', '*.mp4'), recursive=True)
+
+ return mp4_files
+
+def split_list_by_interval(lst, interval=1000):
+ start_indices = []
+ end_indices = []
+ split_lists = []
+
+ for i in range(0, len(lst), interval):
+ start_indices.append(i)
+ end_indices.append(min(i + interval, len(lst)))
+ split_lists.append(lst[i:i + interval])
+
+ return start_indices, end_indices, split_lists
+
+def hawor_slam(args, start_idx, end_idx):
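+    """Run masked DROID-SLAM on the extracted frames, recover metric scale by comparing the SLAM
+    depth with Metric3D depth predictions on keyframes, and save the scaled camera trajectory
+    under <seq_folder>/SLAM."""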
+ # File and folders
+ file = args.video_path
+ video_root = os.path.dirname(file)
+ video = os.path.basename(file).split('.')[0]
+ seq_folder = os.path.join(video_root, video)
+ os.makedirs(seq_folder, exist_ok=True)
+ video_folder = os.path.join(video_root, video)
+
+ img_folder = f'{video_folder}/extracted_images'
+ imgfiles = natsorted(glob(f'{img_folder}/*.jpg'))
+
+ first_img = cv2.imread(imgfiles[0])
+ height, width, _ = first_img.shape
+
+ print(f'Running slam on {video_folder} ...')
+
+ ##### Run SLAM #####
+ # Use Masking
+ masks = np.load(f'{video_folder}/tracks_{start_idx}_{end_idx}/model_masks.npy', allow_pickle=True)
+ masks = torch.from_numpy(masks)
+ print(masks.shape)
+
+ # Camera calibration (intrinsics) for SLAM
+ focal = args.img_focal
+ if focal is None:
+        try:
+            with open(os.path.join(video_folder, 'est_focal.txt'), 'r') as f:
+                focal = float(f.read())
+        except (OSError, ValueError):
+            print('No focal length provided, using default 600')
+            focal = 600
+            with open(os.path.join(video_folder, 'est_focal.txt'), 'w') as f:
+                f.write(str(focal))
+ calib = np.array(est_calib(imgfiles)) # [focal, focal, cx, cy]
+ center = calib[2:]
+ calib[:2] = focal
+
+ # Droid-slam with masking
+ droid, traj = run_slam(imgfiles, masks=masks, calib=calib)
+ n = droid.video.counter.value
+ tstamp = droid.video.tstamp.cpu().int().numpy()[:n]
+ disps = droid.video.disps_up.cpu().numpy()[:n]
+ print('DBA errors:', droid.backend.errors)
+
+ del droid
+ torch.cuda.empty_cache()
+
+ # Estimate scale
+ block_print()
+ metric = Metric3D('thirdparty/Metric3D/weights/metric_depth_vit_large_800k.pth')
+ enable_print()
+ min_threshold = 0.4
+ max_threshold = 0.7
+
+ print('Predicting Metric Depth ...')
+ pred_depths = []
+ H, W = get_dimention(imgfiles)
+ for t in tqdm(tstamp):
+ pred_depth = metric(imgfiles[t], calib)
+ pred_depth = cv2.resize(pred_depth, (W, H))
+ pred_depths.append(pred_depth)
+
+ ##### Estimate Metric Scale #####
+ print('Estimating Metric Scale ...')
+ scales_ = []
+ n = len(tstamp) # for each keyframe
+ for i in tqdm(range(n)):
+ t = tstamp[i]
+ disp = disps[i]
+ pred_depth = pred_depths[i]
+ slam_depth = 1/disp
+
+ # Estimate scene scale
+ msk = masks[t].numpy().astype(np.uint8)
+ scale = est_scale_hybrid(slam_depth, pred_depth, sigma=0.5, msk=msk, near_thresh=min_threshold, far_thresh=max_threshold)
+ scales_.append(scale)
+
+ median_s = np.median(scales_)
+ print(f"estimated scale: {median_s}")
+
+ # Save results
+ os.makedirs(f"{seq_folder}/SLAM", exist_ok=True)
+ save_path = f'{seq_folder}/SLAM/hawor_slam_w_scale_{start_idx}_{end_idx}.npz'
+ np.savez(save_path,
+ tstamp=tstamp, disps=disps, traj=traj,
+ img_focal=focal, img_center=calib[-2:],
+ scale=median_s)
diff --git a/scripts/scripts_test_video/hawor_video.py b/scripts/scripts_test_video/hawor_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..eed5bcc6f0470cb3000737ead187226691000b18
--- /dev/null
+++ b/scripts/scripts_test_video/hawor_video.py
@@ -0,0 +1,361 @@
+from collections import defaultdict
+
+import json
+import os
+import joblib
+import numpy as np
+import torch
+import cv2
+from tqdm import tqdm
+from glob import glob
+from natsort import natsorted
+
+from lib.pipeline.tools import parse_chunks, parse_chunks_hand_frame
+from lib.models.hawor import HAWOR
+from lib.eval_utils.custom_utils import cam2world_convert, load_slam_cam
+from lib.eval_utils.custom_utils import interpolate_bboxes
+from lib.eval_utils.filling_utils import filling_postprocess, filling_preprocess
+from lib.vis.renderer import Renderer
+from hawor.utils.process import get_mano_faces, run_mano, run_mano_left
+from hawor.utils.rotation import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis
+from infiller.lib.model.network import TransformerModel
+
+def load_hawor(checkpoint_path):
+ from pathlib import Path
+ from hawor.configs import get_config
+ model_cfg = str(Path(checkpoint_path).parent.parent / 'model_config.yaml')
+ model_cfg = get_config(model_cfg, update_cachedir=True)
+
+ # Override some config values, to crop bbox correctly
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'vit') and ('BBOX_SHAPE' not in model_cfg.MODEL):
+ model_cfg.defrost()
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
+ model_cfg.MODEL.BBOX_SHAPE = [192,256]
+ model_cfg.freeze()
+
+ model = HAWOR.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg)
+ return model, model_cfg
+
+
+
+def hawor_motion_estimation(args, start_idx, end_idx, seq_folder):
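+    """Run HaWOR on each tracked hand chunk: save the per-chunk camera-space MANO parameters as
+    JSON, render hand masks that are later used to mask DROID-SLAM, and return
+    (frame_chunks_all, img_focal)."""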
+ model, model_cfg = load_hawor(args.checkpoint)
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ model = model.to(device)
+ model.eval()
+
+ file = args.video_path
+ video_root = os.path.dirname(file)
+ video = os.path.basename(file).split('.')[0]
+ img_folder = f"{video_root}/{video}/extracted_images"
+ imgfiles = np.array(natsorted(glob(f'{img_folder}/*.jpg')))
+
+ tracks = np.load(f'{seq_folder}/tracks_{start_idx}_{end_idx}/model_tracks.npy', allow_pickle=True).item()
+ img_focal = args.img_focal
+ if img_focal is None:
+ try:
+ with open(os.path.join(seq_folder, 'est_focal.txt'), 'r') as file:
+ img_focal = file.read()
+ img_focal = float(img_focal)
+        except (OSError, ValueError):
+            img_focal = 600
+            print(f'No focal length provided, using default {img_focal}')
+ with open(os.path.join(seq_folder, 'est_focal.txt'), 'w') as file:
+ file.write(str(img_focal))
+
+ tid = np.array([tr for tr in tracks])
+
+ print(f'Running hawor on {video} ...')
+
+ left_trk = []
+ right_trk = []
+ for k, idx in enumerate(tid):
+ trk = tracks[idx]
+
+ valid = np.array([t['det'] for t in trk])
+ is_right = np.concatenate([t['det_handedness'] for t in trk])[valid]
+
+ if is_right.sum() / len(is_right) < 0.5:
+ left_trk.extend(trk)
+ else:
+ right_trk.extend(trk)
+ left_trk = sorted(left_trk, key=lambda x: x['frame'])
+ right_trk = sorted(right_trk, key=lambda x: x['frame'])
+ final_tracks = {
+ 0: left_trk,
+ 1: right_trk
+ }
+ tid = [0, 1]
+
+ img = cv2.imread(imgfiles[0])
+    img_center = [img.shape[1] / 2, img.shape[0] / 2]  # [w/2, h/2]
+ H, W = img.shape[:2]
+ model_masks = np.zeros((len(imgfiles), H, W))
+
+ bin_size = 128
+ max_faces_per_bin = 20000
+ renderer = Renderer(img.shape[1], img.shape[0], img_focal, 'cuda',
+ bin_size=bin_size, max_faces_per_bin=max_faces_per_bin)
+ # get faces
+ faces = get_mano_faces()
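+    # extra triangles appended to the MANO topology (presumably to close the wrist opening),
+    # so the rendered silhouette used for the SLAM masks has no hole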
+ faces_new = np.array([[92, 38, 234],
+ [234, 38, 239],
+ [38, 122, 239],
+ [239, 122, 279],
+ [122, 118, 279],
+ [279, 118, 215],
+ [118, 117, 215],
+ [215, 117, 214],
+ [117, 119, 214],
+ [214, 119, 121],
+ [119, 120, 121],
+ [121, 120, 78],
+ [120, 108, 78],
+ [78, 108, 79]])
+ faces_right = np.concatenate([faces, faces_new], axis=0)
+ faces_left = faces_right[:,[0,2,1]]
+
+ frame_chunks_all = defaultdict(list)
+ for idx in tid:
+ print(f"tracklet {idx}:")
+ trk = final_tracks[idx]
+
+ # interp bboxes
+ valid = np.array([t['det'] for t in trk])
+ if valid.sum() < 2:
+ continue
+ boxes = np.concatenate([t['det_box'] for t in trk])
+ non_zero_indices = np.where(np.any(boxes != 0, axis=1))[0]
+ first_non_zero = non_zero_indices[0]
+ last_non_zero = non_zero_indices[-1]
+ boxes[first_non_zero:last_non_zero+1] = interpolate_bboxes(boxes[first_non_zero:last_non_zero+1])
+ valid[first_non_zero:last_non_zero+1] = True
+
+
+ boxes = boxes[first_non_zero:last_non_zero+1]
+ is_right = np.concatenate([t['det_handedness'] for t in trk])[valid]
+ frame = np.array([t['frame'] for t in trk])[valid]
+
+ if is_right.sum() / len(is_right) < 0.5:
+ is_right = np.zeros((len(boxes), 1))
+ else:
+ is_right = np.ones((len(boxes), 1))
+
+ frame_chunks, boxes_chunks = parse_chunks(frame, boxes, min_len=1)
+ frame_chunks_all[idx] = frame_chunks
+
+ if len(frame_chunks) == 0:
+ continue
+
+ for frame_ck, boxes_ck in zip(frame_chunks, boxes_chunks):
+ print(f"inference from frame {frame_ck[0]} to {frame_ck[-1]}")
+ img_ck = imgfiles[frame_ck]
+ if is_right[0] > 0:
+ do_flip = False
+ else:
+ do_flip = True
+
+ results = model.inference(img_ck, boxes_ck, img_focal=img_focal, img_center=img_center, do_flip=do_flip)
+
+ data_out = {
+ "init_root_orient": results["pred_rotmat"][None, :, 0], # (B, T, 3, 3)
+ "init_hand_pose": results["pred_rotmat"][None, :, 1:], # (B, T, 15, 3, 3)
+ "init_trans": results["pred_trans"][None, :, 0], # (B, T, 3)
+ "init_betas": results["pred_shape"][None, :] # (B, T, 10)
+ }
+
+ # flip left hand
+ init_root = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
+ init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
+ if do_flip:
+ init_root[..., 1] *= -1
+ init_root[..., 2] *= -1
+ init_hand_pose[..., 1] *= -1
+ init_hand_pose[..., 2] *= -1
+ data_out["init_root_orient"] = angle_axis_to_rotation_matrix(init_root)
+ data_out["init_hand_pose"] = angle_axis_to_rotation_matrix(init_hand_pose)
+
+ # save camera-space results
+ pred_dict={
+ k:v.tolist() for k, v in data_out.items()
+ }
+ pred_path = os.path.join(seq_folder, 'cam_space', str(idx), f"{frame_ck[0]}_{frame_ck[-1]}.json")
+ if not os.path.exists(os.path.join(seq_folder, 'cam_space', str(idx))):
+ os.makedirs(os.path.join(seq_folder, 'cam_space', str(idx)))
+ with open(pred_path, "w") as f:
+ json.dump(pred_dict, f, indent=1)
+
+
+ # get hand mask
+ data_out["init_root_orient"] = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
+ data_out["init_hand_pose"] = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
+ if do_flip: # left
+ outputs = run_mano_left(data_out["init_trans"], data_out["init_root_orient"], data_out["init_hand_pose"], betas=data_out["init_betas"])
+ else: # right
+ outputs = run_mano(data_out["init_trans"], data_out["init_root_orient"], data_out["init_hand_pose"], betas=data_out["init_betas"])
+
+ vertices = outputs["vertices"][0].cpu() # (T, N, 3)
+ for img_i, _ in enumerate(img_ck):
+ if do_flip:
+ faces = torch.from_numpy(faces_left).cuda()
+ else:
+ faces = torch.from_numpy(faces_right).cuda()
+ cam_R = torch.eye(3).unsqueeze(0).cuda()
+ cam_T = torch.zeros(1, 3).cuda()
+ cameras, lights = renderer.create_camera_from_cv(cam_R, cam_T)
+ verts_color = torch.tensor([0, 0, 255, 255]) / 255
+ vertices_i = vertices[[img_i]]
+ rend, mask = renderer.render_multiple(vertices_i.unsqueeze(0).cuda(), faces, verts_color.unsqueeze(0).cuda(), cameras, lights)
+
+ model_masks[frame_ck[img_i]] += mask
+
+ model_masks = model_masks > 0 # bool
+ np.save(f'{seq_folder}/tracks_{start_idx}_{end_idx}/model_masks.npy', model_masks)
+ return frame_chunks_all, img_focal
+
+def hawor_infiller(args, start_idx, end_idx, frame_chunks_all):
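+    """Lift the per-chunk camera-space predictions to world space using the SLAM camera poses,
+    fill missing frames with the transformer infiller, and save/return the world-space hand
+    motion (translation, rotation, hand pose, betas, validity)."""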
+ # load infiller
+ weight_path = args.infiller_weight
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ ckpt = torch.load(weight_path, map_location=device)
+ pos_dim = 3
+ shape_dim = 10
+ num_joints = 15
+ rot_dim = (num_joints + 1) * 6 # rot6d
+    repr_dim = 2 * (pos_dim + shape_dim + rot_dim)  # two hands: 2 * (3 + 10 + 96) = 218
+    nhead = 8
+ horizon = 120
+ filling_model = TransformerModel(seq_len=horizon, input_dim=repr_dim, d_model=384, nhead=nhead, d_hid=2048, nlayers=8, dropout=0.05, out_dim=repr_dim, masked_attention_stage=True)
+ filling_model.to(device)
+ filling_model.load_state_dict(ckpt['transformer_encoder_state_dict'])
+ filling_model.eval()
+
+ file = args.video_path
+ video_root = os.path.dirname(file)
+ video = os.path.basename(file).split('.')[0]
+ seq_folder = os.path.join(video_root, video)
+ img_folder = f"{video_root}/{video}/extracted_images"
+
+ # Previous steps
+ imgfiles = np.array(natsorted(glob(f'{img_folder}/*.jpg')))
+
+ idx2hand = ['left', 'right']
+ filling_length = 120
+
+ fpath = os.path.join(seq_folder, f"SLAM/hawor_slam_w_scale_{start_idx}_{end_idx}.npz")
+ R_w2c_sla_all, t_w2c_sla_all, R_c2w_sla_all, t_c2w_sla_all = load_slam_cam(fpath)
+
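+    # world-space MANO parameters for both hands (index 0: left, 1: right), filled frame by frame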
+ pred_trans = torch.zeros(2, len(imgfiles), 3)
+ pred_rot = torch.zeros(2, len(imgfiles), 3)
+ pred_hand_pose = torch.zeros(2, len(imgfiles), 45)
+ pred_betas = torch.zeros(2, len(imgfiles), 10)
+ pred_valid = torch.zeros((2, pred_betas.size(1)))
+
+ # camera space to world space
+ tid = [0, 1]
+ for k, idx in enumerate(tid):
+ frame_chunks = frame_chunks_all[idx]
+
+ if len(frame_chunks) == 0:
+ continue
+
+ for frame_ck in frame_chunks:
+ print(f"from frame {frame_ck[0]} to {frame_ck[-1]}")
+ pred_path = os.path.join(seq_folder, 'cam_space', str(idx), f"{frame_ck[0]}_{frame_ck[-1]}.json")
+ with open(pred_path, "r") as f:
+ pred_dict = json.load(f)
+ data_out = {
+ k:torch.tensor(v) for k, v in pred_dict.items()
+ }
+
+ R_c2w_sla = R_c2w_sla_all[frame_ck]
+ t_c2w_sla = t_c2w_sla_all[frame_ck]
+
+ data_world = cam2world_convert(R_c2w_sla, t_c2w_sla, data_out, 'right' if idx > 0 else 'left')
+
+ pred_trans[[idx], frame_ck] = data_world["init_trans"]
+ pred_rot[[idx], frame_ck] = data_world["init_root_orient"]
+ pred_hand_pose[[idx], frame_ck] = data_world["init_hand_pose"].flatten(-2)
+ pred_betas[[idx], frame_ck] = data_world["init_betas"]
+ pred_valid[[idx], frame_ck] = 1
+
+
+    # run the filling (infiller) network for this video
+ frame_list = torch.tensor(list(range(pred_trans.size(1))))
+ pred_valid = (pred_valid > 0).numpy()
+ for k, idx in enumerate([1, 0]):
+ missing = ~pred_valid[idx]
+
+ frame = frame_list[missing]
+ frame_chunks = parse_chunks_hand_frame(frame)
+
+ print(f"run infiller on {idx2hand[idx]} hand ...")
+ for frame_ck in tqdm(frame_chunks):
+ start_shift = -1
+ while frame_ck[0] + start_shift >= 0 and pred_valid[:, frame_ck[0] + start_shift].sum() != 2:
+ start_shift -= 1 # Shift to find the previous valid frame as start
+ print(f"run infiller on frame {frame_ck[0] + start_shift} to frame {min(len(imgfiles)-1, frame_ck[0] + start_shift + filling_length)}")
+
+ frame_start = frame_ck[0]
+ filling_net_start = max(0, frame_start + start_shift)
+ filling_net_end = min(len(imgfiles)-1, filling_net_start + filling_length)
+ seq_valid = pred_valid[:, filling_net_start:filling_net_end]
+ filling_seq = {}
+ filling_seq['trans'] = pred_trans[:, filling_net_start:filling_net_end].numpy()
+ filling_seq['rot'] = pred_rot[:, filling_net_start:filling_net_end].numpy()
+ filling_seq['hand_pose'] = pred_hand_pose[:, filling_net_start:filling_net_end].numpy()
+ filling_seq['betas'] = pred_betas[:, filling_net_start:filling_net_end].numpy()
+ filling_seq['valid'] = seq_valid
+ # preprocess (convert to canonical + slerp)
+ filling_input, transform_w_canon = filling_preprocess(filling_seq)
+ src_mask = torch.zeros((filling_length, filling_length), device=device).type(torch.bool)
+ src_mask = src_mask.to(device)
+ filling_input = torch.from_numpy(filling_input).unsqueeze(0).to(device).permute(1,0,2) # (seq_len, B, in_dim)
+ T_original = len(filling_input)
+ filling_length = 120
+ if T_original < filling_length:
+ pad_length = filling_length - T_original
+ last_time_step = filling_input[-1, :, :]
+ padding = last_time_step.unsqueeze(0).repeat(pad_length, 1, 1)
+ filling_input = torch.cat([filling_input, padding], dim=0)
+ seq_valid_padding = np.ones((2, filling_length - T_original))
+ seq_valid_padding = np.concatenate([seq_valid, seq_valid_padding], axis=1)
+ else:
+ seq_valid_padding = seq_valid
+
+
+ T, B, _ = filling_input.shape
+
+ valid = torch.from_numpy(seq_valid_padding).unsqueeze(0).all(dim=1).permute(1, 0) # (T,B)
+ valid_atten = torch.from_numpy(seq_valid_padding).unsqueeze(0).all(dim=1).unsqueeze(1) # (B,1,T)
+ data_mask = torch.zeros((horizon, B, 1), device=device, dtype=filling_input.dtype)
+ data_mask[valid] = 1
+ atten_mask = torch.ones((B, 1, horizon),
+ device=device, dtype=torch.bool)
+ atten_mask[valid_atten] = False
+ atten_mask = atten_mask.unsqueeze(2).repeat(1, 1, T, 1) # (B,1,T,T)
+
+ output_ck = filling_model(filling_input, src_mask, data_mask, atten_mask)
+
+ output_ck = output_ck.permute(1,0,2).reshape(T, 2, -1).cpu().detach() # two hands
+
+ output_ck = output_ck[:T_original]
+
+ filling_output = filling_postprocess(output_ck, transform_w_canon)
+
+            # replace the missing predictions with the infiller output
+ filling_seq['trans'][~seq_valid] = filling_output['trans'][~seq_valid]
+ filling_seq['rot'][~seq_valid] = filling_output['rot'][~seq_valid]
+ filling_seq['hand_pose'][~seq_valid] = filling_output['hand_pose'][~seq_valid]
+ filling_seq['betas'][~seq_valid] = filling_output['betas'][~seq_valid]
+
+ pred_trans[:, filling_net_start:filling_net_end] = torch.from_numpy(filling_seq['trans'][:])
+ pred_rot[:, filling_net_start:filling_net_end] = torch.from_numpy(filling_seq['rot'][:])
+ pred_hand_pose[:, filling_net_start:filling_net_end] = torch.from_numpy(filling_seq['hand_pose'][:])
+ pred_betas[:, filling_net_start:filling_net_end] = torch.from_numpy(filling_seq['betas'][:])
+ pred_valid[:, filling_net_start:filling_net_end] = 1
+ save_path = os.path.join(seq_folder, "world_space_res.pth")
+ joblib.dump([pred_trans, pred_rot, pred_hand_pose, pred_betas, pred_valid], save_path)
+ return pred_trans, pred_rot, pred_hand_pose, pred_betas, pred_valid
+
+
\ No newline at end of file