ThunderVVV committed
Commit 5f028d6 · 1 Parent(s): 014faee
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full list.
Files changed (50)
  1. .gitattributes +2 -0
  2. README.md +94 -12
  3. _DATA/data/mano/.gitkeep +0 -0
  4. _DATA/data/mano/MANO_RIGHT.pkl +3 -0
  5. _DATA/data/mano_mean_params.npz +3 -0
  6. _DATA/data_left/mano_left/.gitkeep +0 -0
  7. _DATA/data_left/mano_left/MANO_LEFT.pkl +3 -0
  8. assets/teaser.png +3 -0
  9. demo.py +113 -0
  10. example/video_0.mp4 +3 -0
  11. hawor/configs/__init__.py +120 -0
  12. hawor/configs/__pycache__/__init__.cpython-310.pyc +0 -0
  13. hawor/utils/__pycache__/geometry.cpython-310.pyc +0 -0
  14. hawor/utils/__pycache__/process.cpython-310.pyc +0 -0
  15. hawor/utils/__pycache__/pylogger.cpython-310.pyc +0 -0
  16. hawor/utils/__pycache__/render_openpose.cpython-310.pyc +0 -0
  17. hawor/utils/__pycache__/rotation.cpython-310.pyc +0 -0
  18. hawor/utils/geometry.py +102 -0
  19. hawor/utils/process.py +198 -0
  20. hawor/utils/pylogger.py +17 -0
  21. hawor/utils/render_openpose.py +225 -0
  22. hawor/utils/rotation.py +293 -0
  23. imgui.ini +15 -0
  24. infiller/hand_utils/geometry.py +412 -0
  25. infiller/hand_utils/geometry_utils.py +102 -0
  26. infiller/hand_utils/mano_wrapper.py +52 -0
  27. infiller/hand_utils/process.py +171 -0
  28. infiller/hand_utils/rotation.py +293 -0
  29. infiller/lib/misc/sampler.py +79 -0
  30. infiller/lib/model/__pycache__/network.cpython-310.pyc +0 -0
  31. infiller/lib/model/network.py +276 -0
  32. infiller/lib/model/positional_encoding.py +42 -0
  33. infiller/lib/model/preprocess.py +189 -0
  34. infiller/lib/model/skeleton.py +349 -0
  35. infiller/lib/vis/pose.py +248 -0
  36. lib/core/__pycache__/constants.cpython-310.pyc +0 -0
  37. lib/core/constants.py +78 -0
  38. lib/datasets/__pycache__/track_dataset.cpython-310.pyc +0 -0
  39. lib/datasets/track_dataset.py +78 -0
  40. lib/eval_utils/__pycache__/custom_utils.cpython-310.pyc +0 -0
  41. lib/eval_utils/__pycache__/filling_utils.cpython-310.pyc +0 -0
  42. lib/eval_utils/custom_utils.py +99 -0
  43. lib/eval_utils/filling_utils.py +306 -0
  44. lib/eval_utils/video_utils.py +85 -0
  45. lib/models/__pycache__/hawor.cpython-310.pyc +0 -0
  46. lib/models/__pycache__/mano_wrapper.cpython-310.pyc +0 -0
  47. lib/models/__pycache__/modules.cpython-310.pyc +0 -0
  48. lib/models/backbones/__init__.py +8 -0
  49. lib/models/backbones/__pycache__/__init__.cpython-310.pyc +0 -0
  50. lib/models/backbones/__pycache__/vit.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,94 @@
1
- ---
2
- title: HaWoR
3
- emoji: 👁
4
- colorFrom: green
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.9.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <div align="center">
2
+
3
+ # HaWoR: World-Space Hand Motion Reconstruction from Egocentric Videos
4
+
5
+ [Jinglei Zhang]()<sup>1</sup> &emsp; [Jiankang Deng](https://jiankangdeng.github.io/)<sup>2</sup> &emsp; [Chao Ma](https://scholar.google.com/citations?user=syoPhv8AAAAJ&hl=en)<sup>1</sup> &emsp; [Rolandos Alexandros Potamias](https://rolpotamias.github.io)<sup>2</sup> &emsp;
6
+
7
+ <sup>1</sup>Shanghai Jiao Tong University, China
8
+ <sup>2</sup>Imperial College London, UK <br>
9
+
10
+ <a href='https://hawor-project.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
11
+ <a href='https://arxiv.org/abs/'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
12
+ </div>
13
+
14
+ This is the official implementation of **[HaWoR](https://hawor-project.github.io/)**, a model for world-space hand motion reconstruction from egocentric videos:
15
+
16
+ ![teaser](assets/teaser.png)
17
+
18
+ ## Installation
19
+
20
+ ### Clone the repository
21
+ ```
22
+ git clone --recursive https://github.com/ThunderVVV/HaWoR.git
23
+ cd HaWoR
24
+ ```
25
+
26
+ The code has been tested with PyTorch 1.13 and CUDA 11.7. It is suggested to use an Anaconda environment to install the required dependencies:
27
+ ```bash
28
+ conda create --name hawor python=3.10
29
+ conda activate hawor
30
+
31
+ pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
32
+ # Install requirements
33
+ pip install -r requirements.txt
34
+ pip install pytorch-lightning==2.2.4 --no-deps
35
+ pip install lightning-utilities torchmetrics==1.4.0
36
+ ```
37
+
38
+ ### Install masked DROID-SLAM
39
+
40
+ ```
41
+ cd thirdparty/DROID-SLAM
42
+ python setup.py install
43
+ ```
44
+
45
+ Download the official DROID-SLAM weights [droid.pth](https://drive.google.com/file/d/1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh/view?usp=sharing) and put them under `./weights/external/` (a download sketch follows the Metric3D step below).
46
+
47
+ ### Install Metric3D
48
+
49
+ Download the official Metric3D weights [metric_depth_vit_large_800k.pth](https://drive.google.com/file/d/1eT2gG-kwsVzNy5nJrbm4KC-9DbNKyLnr/view?usp=drive_link) and put them under `thirdparty/Metric3D/weights`.
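Both checkpoints above are hosted on Google Drive. A minimal download sketch, assuming the `gdown` package is installed (it is not listed in the requirements shown here) and reusing the file IDs from the links above:

```python
# Sketch only: fetch the DROID-SLAM and Metric3D checkpoints with gdown.
# Assumes `pip install gdown`; the file IDs come from the Google Drive links above.
import os
import gdown

checkpoints = {
    "weights/external/droid.pth": "1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh",
    "thirdparty/Metric3D/weights/metric_depth_vit_large_800k.pth": "1eT2gG-kwsVzNy5nJrbm4KC-9DbNKyLnr",
}

for path, file_id in checkpoints.items():
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if not os.path.exists(path):
        gdown.download(id=file_id, output=path, quiet=False)
```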
50
+
51
+ ### Download the model weights
52
+
53
+ ```bash
54
+ wget https://huggingface.co/spaces/rolpotamias/WiLoR/resolve/main/pretrained_models/detector.pt -P ./weights/external/
55
+ wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/checkpoints/hawor.ckpt -P ./weights/hawor/checkpoints/
56
+ wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/checkpoints/infiller.pt -P ./weights/hawor/checkpoints/
57
+ wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/model_config.yaml -P ./weights/hawor/
58
+ ```
59
+ You also need to download the MANO models from the [MANO website](https://mano.is.tue.mpg.de).
60
+ Create an account by clicking Sign Up and download the models (mano_v*_*.zip). Unzip the archive and place the hand models at `_DATA/data/mano/MANO_RIGHT.pkl` and `_DATA/data_left/mano_left/MANO_LEFT.pkl`.
61
+
62
+ Note that the MANO model falls under the [MANO license](https://mano.is.tue.mpg.de/license.html).
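A quick sanity-check sketch, assuming the weight and MANO paths listed above, that can be run from the repository root before trying the demo (this helper is illustrative and not part of the repository):

```python
# Sketch only: verify the MANO models and downloaded weights are in place.
import os

expected = [
    "_DATA/data/mano/MANO_RIGHT.pkl",
    "_DATA/data_left/mano_left/MANO_LEFT.pkl",
    "weights/external/droid.pth",
    "weights/external/detector.pt",
    "weights/hawor/checkpoints/hawor.ckpt",
    "weights/hawor/checkpoints/infiller.pt",
    "weights/hawor/model_config.yaml",
    "thirdparty/Metric3D/weights/metric_depth_vit_large_800k.pth",
]

missing = [p for p in expected if not os.path.exists(p)]
if missing:
    raise FileNotFoundError(f"Missing files: {missing}")
print("All model files found.")
```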
63
+ ## Demo
64
+
65
+ For visualization in the world view, run:
66
+ ```bash
67
+ python demo.py --video_path ./example/video_0.mp4 --vis_mode world
68
+ ```
69
+
70
+ For visualization in the camera view, run:
71
+ ```bash
72
+ python demo.py --video_path ./example/video_0.mp4 --vis_mode cam
73
+ ```
74
+
75
+ ## Training
76
+ The training code will be released soon.
77
+
78
+ ## Acknowledgements
79
+ Parts of the code are taken or adapted from the following repos:
80
+ - [HaMeR](https://github.com/geopavlakos/hamer/)
81
+ - [WiLoR](https://github.com/rolpotamias/WiLoR)
82
+ - [SLAHMR](https://github.com/vye16/slahmr)
83
+ - [TRAM](https://github.com/yufu-wang/tram)
84
+ - [CMIB](https://github.com/jihoonerd/Conditional-Motion-In-Betweening)
85
+
86
+
87
+ ## License
88
+ HaWoR models fall under the [CC-BY-NC-ND License](./license.txt). This repository also depends on the [MANO model](https://mano.is.tue.mpg.de/license.html), which falls under its own license. By using this repository, you must also comply with the terms of these external licenses.
89
+ ## Citing
90
+ If you find HaWoR useful for your research, please consider citing our paper:
91
+
92
+ ```bibtex
93
+
94
+ ```
_DATA/data/mano/.gitkeep ADDED
File without changes
_DATA/data/mano/MANO_RIGHT.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45d60aa3b27ef9107a7afd4e00808f307fd91111e1cfa35afd5c4a62de264767
3
+ size 3821356
_DATA/data/mano_mean_params.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efc0ec58e4a5cef78f3abfb4e8f91623b8950be9eff8b8e0dbb0d036ebc63988
3
+ size 1178
_DATA/data_left/mano_left/.gitkeep ADDED
File without changes
_DATA/data_left/mano_left/MANO_LEFT.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4022f7083f2ca7c78b2b3d595abbab52debd32b09d372b16923a801f0ea6a30
3
+ size 3821391
assets/teaser.png ADDED

Git LFS Details

  • SHA256: 6b33d76e9a10f215f0777612dd32ac73a5ce3b0e8735813968e7048ecd1ed3a1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
demo.py ADDED
@@ -0,0 +1,113 @@
1
+ import argparse
2
+ import sys
3
+ import os
4
+
5
+ import torch
6
+ sys.path.insert(0, os.path.dirname(__file__))
7
+ import numpy as np
8
+ import joblib
9
+ from scripts.scripts_test_video.detect_track_video import detect_track_video
10
+ from scripts.scripts_test_video.hawor_video import hawor_motion_estimation, hawor_infiller
11
+ from scripts.scripts_test_video.hawor_slam import hawor_slam
12
+ from hawor.utils.process import get_mano_faces, run_mano, run_mano_left
13
+ from lib.eval_utils.custom_utils import load_slam_cam
14
+ from lib.vis.run_vis2 import run_vis2_on_video, run_vis2_on_video_cam
15
+
16
+
17
+ if __name__ == '__main__':
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("--img_focal", type=float)
20
+ parser.add_argument("--video_path", type=str, default='example/video_0.mp4')
21
+ parser.add_argument("--input_type", type=str, default='file')
22
+ parser.add_argument("--checkpoint", type=str, default='./weights/hawor/checkpoints/hawor.ckpt')
23
+ parser.add_argument("--infiller_weight", type=str, default='./weights/hawor/checkpoints/infiller.pt')
24
+ parser.add_argument("--vis_mode", type=str, default='world', help='cam | world')
25
+ args = parser.parse_args()
26
+
27
+ start_idx, end_idx, seq_folder, imgfiles = detect_track_video(args)
28
+
29
+ frame_chunks_all, img_focal = hawor_motion_estimation(args, start_idx, end_idx, seq_folder)
30
+
31
+ hawor_slam(args, start_idx, end_idx)
32
+ slam_path = os.path.join(seq_folder, f"SLAM/hawor_slam_w_scale_{start_idx}_{end_idx}.npz")
33
+ R_w2c_sla_all, t_w2c_sla_all, R_c2w_sla_all, t_c2w_sla_all = load_slam_cam(slam_path)
34
+
35
+ pred_trans, pred_rot, pred_hand_pose, pred_betas, pred_valid = hawor_infiller(args, start_idx, end_idx, frame_chunks_all)
36
+
37
+ # vis sequence for this video
38
+ hand2idx = {
39
+ "right": 1,
40
+ "left": 0
41
+ }
42
+ vis_start = 0
43
+ vis_end = pred_trans.shape[1] - 1
44
+
45
+ # get faces
46
+ faces = get_mano_faces()
47
+ faces_new = np.array([[92, 38, 234],
48
+ [234, 38, 239],
49
+ [38, 122, 239],
50
+ [239, 122, 279],
51
+ [122, 118, 279],
52
+ [279, 118, 215],
53
+ [118, 117, 215],
54
+ [215, 117, 214],
55
+ [117, 119, 214],
56
+ [214, 119, 121],
57
+ [119, 120, 121],
58
+ [121, 120, 78],
59
+ [120, 108, 78],
60
+ [78, 108, 79]])
61
+ faces_right = np.concatenate([faces, faces_new], axis=0)
62
+
63
+ # get right hand vertices
64
+ hand = 'right'
65
+ hand_idx = hand2idx[hand]
66
+ pred_glob_r = run_mano(pred_trans[hand_idx:hand_idx+1, vis_start:vis_end], pred_rot[hand_idx:hand_idx+1, vis_start:vis_end], pred_hand_pose[hand_idx:hand_idx+1, vis_start:vis_end], betas=pred_betas[hand_idx:hand_idx+1, vis_start:vis_end])
67
+ right_verts = pred_glob_r['vertices'][0]
68
+ right_dict = {
69
+ 'vertices': right_verts.unsqueeze(0),
70
+ 'faces': faces_right,
71
+ }
72
+
73
+ # get left hand vertices
74
+ faces_left = faces_right[:,[0,2,1]]
75
+ hand = 'left'
76
+ hand_idx = hand2idx[hand]
77
+ pred_glob_l = run_mano_left(pred_trans[hand_idx:hand_idx+1, vis_start:vis_end], pred_rot[hand_idx:hand_idx+1, vis_start:vis_end], pred_hand_pose[hand_idx:hand_idx+1, vis_start:vis_end], betas=pred_betas[hand_idx:hand_idx+1, vis_start:vis_end])
78
+ left_verts = pred_glob_l['vertices'][0]
79
+ left_dict = {
80
+ 'vertices': left_verts.unsqueeze(0),
81
+ 'faces': faces_left,
82
+ }
83
+
84
+ R_x = torch.tensor([[1, 0, 0],
85
+ [0, -1, 0],
86
+ [0, 0, -1]]).float()
87
+ R_c2w_sla_all = torch.einsum('ij,njk->nik', R_x, R_c2w_sla_all)
88
+ t_c2w_sla_all = torch.einsum('ij,nj->ni', R_x, t_c2w_sla_all)
89
+ R_w2c_sla_all = R_c2w_sla_all.transpose(-1, -2)
90
+ t_w2c_sla_all = -torch.einsum("bij,bj->bi", R_w2c_sla_all, t_c2w_sla_all)
91
+ left_dict['vertices'] = torch.einsum('ij,btnj->btni', R_x, left_dict['vertices'].cpu())
92
+ right_dict['vertices'] = torch.einsum('ij,btnj->btni', R_x, right_dict['vertices'].cpu())
93
+
94
+ # Here we use aitviewer(https://github.com/eth-ait/aitviewer) for simple visualization.
95
+ if args.vis_mode == 'world':
96
+ output_pth = os.path.join(seq_folder, f"vis_{vis_start}_{vis_end}")
97
+ if not os.path.exists(output_pth):
98
+ os.makedirs(output_pth)
99
+ image_names = imgfiles[vis_start:vis_end]
100
+ print(f"vis {vis_start} to {vis_end}")
101
+ run_vis2_on_video(left_dict, right_dict, output_pth, img_focal, image_names, R_c2w=R_c2w_sla_all[vis_start:vis_end], t_c2w=t_c2w_sla_all[vis_start:vis_end])
102
+ elif args.vis_mode == 'cam':
103
+ output_pth = os.path.join(seq_folder, f"vis_{vis_start}_{vis_end}")
104
+ if not os.path.exists(output_pth):
105
+ os.makedirs(output_pth)
106
+ image_names = imgfiles[vis_start:vis_end]
107
+ print(f"vis {vis_start} to {vis_end}")
108
+ run_vis2_on_video_cam(left_dict, right_dict, output_pth, img_focal, image_names, R_w2c=R_w2c_sla_all[vis_start:vis_end], t_w2c=t_w2c_sla_all[vis_start:vis_end])
109
+
110
+ print("finish")
111
+
112
+
113
+
example/video_0.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13ff124a68e4b48190e0c3f0ce9f38db59c5e3bb8a093b3c7fc9c67276be2062
3
+ size 6515891
hawor/configs/__init__.py ADDED
@@ -0,0 +1,120 @@
1
+ import os
2
+ from typing import Dict
3
+ from yacs.config import CfgNode as CN
4
+
5
+ CACHE_DIR_HAWOR = "./_DATA"
6
+
7
+ def to_lower(x: Dict) -> Dict:
8
+ """
9
+ Convert all dictionary keys to lowercase
10
+ Args:
11
+ x (dict): Input dictionary
12
+ Returns:
13
+ dict: Output dictionary with all keys converted to lowercase
14
+ """
15
+ return {k.lower(): v for k, v in x.items()}
16
+
17
+ _C = CN(new_allowed=True)
18
+
19
+ _C.GENERAL = CN(new_allowed=True)
20
+ _C.GENERAL.RESUME = True
21
+ _C.GENERAL.TIME_TO_RUN = 3300
22
+ _C.GENERAL.VAL_STEPS = 100
23
+ _C.GENERAL.LOG_STEPS = 100
24
+ _C.GENERAL.CHECKPOINT_STEPS = 20000
25
+ _C.GENERAL.CHECKPOINT_DIR = "checkpoints"
26
+ _C.GENERAL.SUMMARY_DIR = "tensorboard"
27
+ _C.GENERAL.NUM_GPUS = 1
28
+ _C.GENERAL.NUM_WORKERS = 4
29
+ _C.GENERAL.MIXED_PRECISION = True
30
+ _C.GENERAL.ALLOW_CUDA = True
31
+ _C.GENERAL.PIN_MEMORY = False
32
+ _C.GENERAL.DISTRIBUTED = False
33
+ _C.GENERAL.LOCAL_RANK = 0
34
+ _C.GENERAL.USE_SYNCBN = False
35
+ _C.GENERAL.WORLD_SIZE = 1
36
+
37
+ _C.TRAIN = CN(new_allowed=True)
38
+ _C.TRAIN.NUM_EPOCHS = 100
39
+ _C.TRAIN.BATCH_SIZE = 32
40
+ _C.TRAIN.SHUFFLE = True
41
+ _C.TRAIN.WARMUP = False
42
+ _C.TRAIN.NORMALIZE_PER_IMAGE = False
43
+ _C.TRAIN.CLIP_GRAD = False
44
+ _C.TRAIN.CLIP_GRAD_VALUE = 1.0
45
+ _C.LOSS_WEIGHTS = CN(new_allowed=True)
46
+
47
+ _C.DATASETS = CN(new_allowed=True)
48
+
49
+ _C.MODEL = CN(new_allowed=True)
50
+ _C.MODEL.IMAGE_SIZE = 224
51
+
52
+ _C.EXTRA = CN(new_allowed=True)
53
+ _C.EXTRA.FOCAL_LENGTH = 5000
54
+
55
+ _C.DATASETS.CONFIG = CN(new_allowed=True)
56
+ _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
57
+ _C.DATASETS.CONFIG.ROT_FACTOR = 30
58
+ _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
59
+ _C.DATASETS.CONFIG.COLOR_SCALE = 0.2
60
+ _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
61
+ _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
62
+ _C.DATASETS.CONFIG.DO_FLIP = False
63
+ _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
64
+ _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
65
+
66
+ def default_config() -> CN:
67
+ """
68
+ Get a yacs CfgNode object with the default config values.
69
+ """
70
+ # Return a clone so that the defaults will not be altered
71
+ # This is for the "local variable" use pattern
72
+ return _C.clone()
73
+
74
+ def dataset_config() -> CN:
75
+ """
76
+ Get dataset config file
77
+ Returns:
78
+ CfgNode: Dataset config as a yacs CfgNode object.
79
+ """
80
+ cfg = CN(new_allowed=True)
81
+ config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_tar.yaml')
82
+ cfg.merge_from_file(config_file)
83
+ cfg.freeze()
84
+ return cfg
85
+
86
+ def dataset_eval_config() -> CN:
87
+ cfg = CN(new_allowed=True)
88
+ config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_eval.yaml')
89
+ cfg.merge_from_file(config_file)
90
+ cfg.freeze()
91
+ return cfg
92
+
93
+ def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
94
+ """
95
+ Read a config file and optionally merge it with the default config file.
96
+ Args:
97
+ config_file (str): Path to config file.
98
+ merge (bool): Whether to merge with the default config or not.
99
+ Returns:
100
+ CfgNode: Config as a yacs CfgNode object.
101
+ """
102
+ if merge:
103
+ cfg = default_config()
104
+ else:
105
+ cfg = CN(new_allowed=True)
106
+ cfg.merge_from_file(config_file)
107
+
108
+ if update_cachedir:
109
+ def update_path(path: str) -> str:
110
+ if os.path.basename(CACHE_DIR_HAWOR) in path:
111
+ return path
112
+ if os.path.isabs(path):
113
+ return path
114
+ return os.path.join(CACHE_DIR_HAWOR, path)
115
+
116
+ cfg.MANO.MODEL_PATH = update_path(cfg.MANO.MODEL_PATH)
117
+ cfg.MANO.MEAN_PARAMS = update_path(cfg.MANO.MEAN_PARAMS)
118
+
119
+ cfg.freeze()
120
+ return cfg
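A brief usage sketch for the config helpers above, assuming the repository root is on `PYTHONPATH` and that `model_config.yaml` has been downloaded to `./weights/hawor/` as in the README; the keys inside that file are an assumption here:

```python
# Sketch only: load HaWoR configs (paths follow the README downloads).
from hawor.configs import default_config, get_config

cfg = default_config()          # clone of the built-in defaults
print(cfg.MODEL.IMAGE_SIZE)     # 224

# Merge a model config on top of the defaults. With update_cachedir=True the
# relative MANO paths are rewritten to live under ./_DATA (this assumes the
# file defines MANO.MODEL_PATH and MANO.MEAN_PARAMS).
cfg = get_config("./weights/hawor/model_config.yaml", merge=True, update_cachedir=True)
print(cfg.MANO.MODEL_PATH)
```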
hawor/configs/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.55 kB).
 
hawor/utils/__pycache__/geometry.cpython-310.pyc ADDED
Binary file (4.09 kB).
 
hawor/utils/__pycache__/process.cpython-310.pyc ADDED
Binary file (5.54 kB).
 
hawor/utils/__pycache__/pylogger.cpython-310.pyc ADDED
Binary file (655 Bytes).
 
hawor/utils/__pycache__/render_openpose.cpython-310.pyc ADDED
Binary file (7.24 kB).
 
hawor/utils/__pycache__/rotation.cpython-310.pyc ADDED
Binary file (7.65 kB).
 
hawor/utils/geometry.py ADDED
@@ -0,0 +1,102 @@
1
+ from typing import Optional
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+ def aa_to_rotmat(theta: torch.Tensor):
6
+ """
7
+ Convert axis-angle representation to rotation matrix.
8
+ Works by first converting it to a quaternion.
9
+ Args:
10
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
11
+ Returns:
12
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
13
+ """
14
+ norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
15
+ angle = torch.unsqueeze(norm, -1)
16
+ normalized = torch.div(theta, angle)
17
+ angle = angle * 0.5
18
+ v_cos = torch.cos(angle)
19
+ v_sin = torch.sin(angle)
20
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
21
+ return quat_to_rotmat(quat)
22
+
23
+ def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
24
+ """
25
+ Convert quaternion representation to rotation matrix.
26
+ Args:
27
+ quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
28
+ Returns:
29
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
30
+ """
31
+ norm_quat = quat
32
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
33
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
34
+
35
+ B = quat.size(0)
36
+
37
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
38
+ wx, wy, wz = w*x, w*y, w*z
39
+ xy, xz, yz = x*y, x*z, y*z
40
+
41
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
42
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
43
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
44
+ return rotMat
45
+
46
+
47
+ def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Convert 6D rotation representation to 3x3 rotation matrix.
50
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
51
+ Args:
52
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
53
+ Returns:
54
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
55
+ """
56
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
57
+ a1 = x[:, :, 0]
58
+ a2 = x[:, :, 1]
59
+ b1 = F.normalize(a1)
60
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
61
+ b3 = torch.linalg.cross(b1, b2)
62
+ return torch.stack((b1, b2, b3), dim=-1)
63
+
64
+ def perspective_projection(points: torch.Tensor,
65
+ translation: torch.Tensor,
66
+ focal_length: torch.Tensor,
67
+ camera_center: Optional[torch.Tensor] = None,
68
+ rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
69
+ """
70
+ Computes the perspective projection of a set of 3D points.
71
+ Args:
72
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
73
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
74
+ focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
75
+ camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
76
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
77
+ Returns:
78
+ torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
79
+ """
80
+ batch_size = points.shape[0]
81
+ if rotation is None:
82
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
83
+ if camera_center is None:
84
+ camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
85
+ # Populate intrinsic camera matrix K.
86
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
87
+ K[:,0,0] = focal_length[:,0]
88
+ K[:,1,1] = focal_length[:,1]
89
+ K[:,2,2] = 1.
90
+ K[:,:-1, -1] = camera_center
91
+
92
+ # Transform points
93
+ points = torch.einsum('bij,bkj->bki', rotation, points)
94
+ points = points + translation.unsqueeze(1)
95
+
96
+ # Apply perspective distortion
97
+ projected_points = points / points[:,:,-1].unsqueeze(-1)
98
+
99
+ # Apply camera intrinsics
100
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
101
+
102
+ return projected_points[:, :, :-1]
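A usage sketch for the geometry utilities above, assuming the repository root is on `PYTHONPATH`; it only exercises tensor shapes with random inputs:

```python
# Sketch only: shape check for hawor.utils.geometry with random inputs.
import torch
from hawor.utils.geometry import aa_to_rotmat, rot6d_to_rotmat, perspective_projection

B, N = 4, 21
R_aa = aa_to_rotmat(torch.randn(B, 3))        # (B, 3, 3) from axis-angle
R_6d = rot6d_to_rotmat(torch.randn(B, 6))     # (B, 3, 3) from 6D rotations

points = torch.rand(B, N, 3)                  # points in front of the camera
translation = torch.tensor([0.0, 0.0, 2.0]).repeat(B, 1)
focal = torch.full((B, 2), 1000.0)
center = torch.full((B, 2), 128.0)
uv = perspective_projection(points, translation, focal, camera_center=center)
print(R_aa.shape, R_6d.shape, uv.shape)       # (4, 3, 3) (4, 3, 3) (4, 21, 2)
```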
hawor/utils/process.py ADDED
@@ -0,0 +1,198 @@
1
+ import torch
2
+ from lib.models.mano_wrapper import MANO
3
+ from hawor.utils.geometry import aa_to_rotmat
4
+ import numpy as np
5
+ import sys
6
+ import os
7
+
8
+ def block_print():
9
+ sys.stdout = open(os.devnull, 'w')
10
+
11
+ def enable_print():
12
+ sys.stdout = sys.__stdout__
13
+
14
+ def get_mano_faces():
15
+ block_print()
16
+ MANO_cfg = {
17
+ 'DATA_DIR': '_DATA/data/',
18
+ 'MODEL_PATH': '_DATA/data/mano',
19
+ 'GENDER': 'neutral',
20
+ 'NUM_HAND_JOINTS': 15,
21
+ 'CREATE_BODY_POSE': False
22
+ }
23
+ mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
24
+ mano = MANO(**mano_cfg)
25
+ enable_print()
26
+ return mano.faces
27
+
28
+
29
+ def run_mano(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True):
30
+ """
31
+ Forward pass of the SMPL model and populates pred_data accordingly with
32
+ joints3d, verts3d, points3d.
33
+
34
+ trans : B x T x 3
35
+ root_orient : B x T x 3
36
+ body_pose : B x T x J*3
37
+ betas : (optional) B x D
38
+ """
39
+ block_print()
40
+ MANO_cfg = {
41
+ 'DATA_DIR': '_DATA/data/',
42
+ 'MODEL_PATH': '_DATA/data/mano',
43
+ 'GENDER': 'neutral',
44
+ 'NUM_HAND_JOINTS': 15,
45
+ 'CREATE_BODY_POSE': False
46
+ }
47
+ mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
48
+ mano = MANO(**mano_cfg)
49
+ if use_cuda:
50
+ mano = mano.cuda()
51
+
52
+ B, T, _ = root_orient.shape
53
+ NUM_JOINTS = 15
54
+ mano_params = {
55
+ 'global_orient': root_orient.reshape(B*T, -1),
56
+ 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
57
+ 'betas': betas.reshape(B*T, -1),
58
+ }
59
+ rotmat_mano_params = mano_params
60
+ rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
61
+ rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
62
+ rotmat_mano_params['transl'] = trans.reshape(B*T, 3)
63
+
64
+ if use_cuda:
65
+ mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
66
+ else:
67
+ mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)
68
+
69
+ faces_right = mano.faces
70
+ faces_new = np.array([[92, 38, 234],
71
+ [234, 38, 239],
72
+ [38, 122, 239],
73
+ [239, 122, 279],
74
+ [122, 118, 279],
75
+ [279, 118, 215],
76
+ [118, 117, 215],
77
+ [215, 117, 214],
78
+ [117, 119, 214],
79
+ [214, 119, 121],
80
+ [119, 120, 121],
81
+ [121, 120, 78],
82
+ [120, 108, 78],
83
+ [78, 108, 79]])
84
+ faces_right = np.concatenate([faces_right, faces_new], axis=0)
85
+ faces_n = len(faces_right)
86
+ faces_left = faces_right[:,[0,2,1]]
87
+
88
+ outputs = {
89
+ "joints": mano_output.joints.reshape(B, T, -1, 3),
90
+ "vertices": mano_output.vertices.reshape(B, T, -1, 3),
91
+ }
92
+
93
+ if not is_right is None:
94
+ # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
95
+ # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
96
+ is_right = (is_right[:, :, 0].cpu().numpy() > 0)
97
+ faces_result = np.zeros((B, T, faces_n, 3))
98
+ faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
99
+ faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
100
+ faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
101
+ outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))
102
+
103
+
104
+ enable_print()
105
+ return outputs
106
+
107
+ def run_mano_left(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True, fix_shapedirs=True):
108
+ """
109
+ Forward pass of the SMPL model and populates pred_data accordingly with
110
+ joints3d, verts3d, points3d.
111
+
112
+ trans : B x T x 3
113
+ root_orient : B x T x 3
114
+ body_pose : B x T x J*3
115
+ betas : (optional) B x D
116
+ """
117
+ block_print()
118
+ MANO_cfg = {
119
+ 'DATA_DIR': '_DATA/data_left/',
120
+ 'MODEL_PATH': '_DATA/data_left/mano_left',
121
+ 'GENDER': 'neutral',
122
+ 'NUM_HAND_JOINTS': 15,
123
+ 'CREATE_BODY_POSE': False,
124
+ 'is_rhand': False
125
+ }
126
+ mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
127
+ mano = MANO(**mano_cfg)
128
+ if use_cuda:
129
+ mano = mano.cuda()
130
+
131
+ # fix MANO shapedirs of the left hand bug (https://github.com/vchoutas/smplx/issues/48)
132
+ if fix_shapedirs:
133
+ mano.shapedirs[:, 0, :] *= -1
134
+
135
+ B, T, _ = root_orient.shape
136
+ NUM_JOINTS = 15
137
+ mano_params = {
138
+ 'global_orient': root_orient.reshape(B*T, -1),
139
+ 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
140
+ 'betas': betas.reshape(B*T, -1),
141
+ }
142
+ rotmat_mano_params = mano_params
143
+ rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
144
+ rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
145
+ rotmat_mano_params['transl'] = trans.reshape(B*T, 3)
146
+
147
+ if use_cuda:
148
+ mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
149
+ else:
150
+ mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)
151
+
152
+ faces_right = mano.faces
153
+ faces_new = np.array([[92, 38, 234],
154
+ [234, 38, 239],
155
+ [38, 122, 239],
156
+ [239, 122, 279],
157
+ [122, 118, 279],
158
+ [279, 118, 215],
159
+ [118, 117, 215],
160
+ [215, 117, 214],
161
+ [117, 119, 214],
162
+ [214, 119, 121],
163
+ [119, 120, 121],
164
+ [121, 120, 78],
165
+ [120, 108, 78],
166
+ [78, 108, 79]])
167
+ faces_right = np.concatenate([faces_right, faces_new], axis=0)
168
+ faces_n = len(faces_right)
169
+ faces_left = faces_right[:,[0,2,1]]
170
+
171
+ outputs = {
172
+ "joints": mano_output.joints.reshape(B, T, -1, 3),
173
+ "vertices": mano_output.vertices.reshape(B, T, -1, 3),
174
+ }
175
+
176
+ if not is_right is None:
177
+ # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
178
+ # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
179
+ is_right = (is_right[:, :, 0].cpu().numpy() > 0)
180
+ faces_result = np.zeros((B, T, faces_n, 3))
181
+ faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
182
+ faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
183
+ faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
184
+ outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))
185
+
186
+
187
+ enable_print()
188
+ return outputs
189
+
190
+ def run_mano_twohands(init_trans, init_rot, init_hand_pose, is_right, init_betas, use_cuda=True, fix_shapedirs=True):
191
+ outputs_left = run_mano_left(init_trans[0:1], init_rot[0:1], init_hand_pose[0:1], None, init_betas[0:1], use_cuda=use_cuda, fix_shapedirs=fix_shapedirs)
192
+ outputs_right = run_mano(init_trans[1:2], init_rot[1:2], init_hand_pose[1:2], None, init_betas[1:2], use_cuda=use_cuda)
193
+ outputs_two = {
194
+ "vertices": torch.cat((outputs_left["vertices"], outputs_right["vertices"]), dim=0),
195
+ "joints": torch.cat((outputs_left["joints"], outputs_right["joints"]), dim=0)
196
+
197
+ }
198
+ return outputs_two
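A sketch of calling `run_mano` on dummy parameters; it assumes the MANO files are in `_DATA/` as described in the README and that the repository root is on `PYTHONPATH` (CPU is used to keep the example minimal):

```python
# Sketch only: run MANO on zero pose/shape; requires the MANO files in _DATA/.
import torch
from hawor.utils.process import run_mano

B, T = 1, 8                             # one right-hand track, 8 frames
trans = torch.zeros(B, T, 3)            # global translation
root_orient = torch.zeros(B, T, 3)      # global orientation (axis-angle)
hand_pose = torch.zeros(B, T, 15 * 3)   # 15 hand joints, axis-angle
betas = torch.zeros(B, T, 10)           # MANO shape coefficients

out = run_mano(trans, root_orient, hand_pose, betas=betas, use_cuda=False)
print(out["joints"].shape)              # (B, T, num_joints, 3)
print(out["vertices"].shape)            # (B, T, num_vertices, 3)
```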
hawor/utils/pylogger.py ADDED
@@ -0,0 +1,17 @@
1
+ import logging
2
+
3
+ from pytorch_lightning.utilities import rank_zero_only
4
+
5
+
6
+ def get_pylogger(name=__name__) -> logging.Logger:
7
+ """Initializes multi-GPU-friendly python command line logger."""
8
+
9
+ logger = logging.getLogger(name)
10
+
11
+ # this ensures all logging levels get marked with the rank zero decorator
12
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
13
+ logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
14
+ for level in logging_levels:
15
+ setattr(logger, level, rank_zero_only(getattr(logger, level)))
16
+
17
+ return logger
hawor/utils/render_openpose.py ADDED
@@ -0,0 +1,225 @@
1
+ """
2
+ Render OpenPose keypoints.
3
+ Code was ported to Python from the official C++ implementation https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/utilities/keypoint.cpp
4
+ """
5
+ import cv2
6
+ import math
7
+ import numpy as np
8
+ from typing import List, Tuple
9
+
10
+ def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]:
11
+ """
12
+ Compute rectangle enclosing keypoints above the threshold.
13
+ Args:
14
+ keypoints (np.array): Keypoint array of shape (N, 3).
15
+ threshold (float): Confidence visualization threshold.
16
+ Returns:
17
+ Tuple[float, float, float]: Rectangle width, height and area.
18
+ """
19
+ valid_ind = keypoints[:, -1] > threshold
20
+ if valid_ind.sum() > 0:
21
+ valid_keypoints = keypoints[valid_ind][:, :-1]
22
+ max_x = valid_keypoints[:,0].max()
23
+ max_y = valid_keypoints[:,1].max()
24
+ min_x = valid_keypoints[:,0].min()
25
+ min_y = valid_keypoints[:,1].min()
26
+ width = max_x - min_x
27
+ height = max_y - min_y
28
+ area = width * height
29
+ return width, height, area
30
+ else:
31
+ return 0,0,0
32
+
33
+ def render_keypoints(img: np.array,
34
+ keypoints: np.array,
35
+ pairs: List,
36
+ colors: List,
37
+ thickness_circle_ratio: float,
38
+ thickness_line_ratio_wrt_circle: float,
39
+ pose_scales: List,
40
+ threshold: float = 0.1,
41
+ alpha: float = 1.0) -> np.array:
42
+ """
43
+ Render keypoints on input image.
44
+ Args:
45
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
46
+ keypoints (np.array): Keypoint array of shape (N, 3).
47
+ pairs (List): List of keypoint pairs per limb.
48
+ colors: (List): List of colors per keypoint.
49
+ thickness_circle_ratio (float): Circle thickness ratio.
50
+ thickness_line_ratio_wrt_circle (float): Line thickness ratio wrt the circle.
51
+ pose_scales (List): List of pose scales.
52
+ threshold (float): Only visualize keypoints with confidence above the threshold.
53
+ Returns:
54
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
55
+ """
56
+ img_orig = img.copy()
57
+ width, height = img.shape[1], img.shape[2]
58
+ area = width * height
59
+
60
+ lineType = 8
61
+ shift = 0
62
+ numberColors = len(colors)
63
+ thresholdRectangle = 0.1
64
+
65
+ person_width, person_height, person_area = get_keypoints_rectangle(keypoints, thresholdRectangle)
66
+ if person_area > 0:
67
+ ratioAreas = min(1, max(person_width / width, person_height / height))
68
+ thicknessRatio = np.maximum(np.round(math.sqrt(area) * thickness_circle_ratio * ratioAreas), 2)
69
+ thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio))
70
+ thicknessLine = np.maximum(1, np.round(thicknessRatio * thickness_line_ratio_wrt_circle))
71
+ radius = thicknessRatio / 2
72
+
73
+ img = np.ascontiguousarray(img.copy())
74
+ for i, pair in enumerate(pairs):
75
+ index1, index2 = pair
76
+ if keypoints[index1, -1] > threshold and keypoints[index2, -1] > threshold:
77
+ thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * pose_scales[0]))
78
+ colorIndex = index2
79
+ color = colors[colorIndex % numberColors]
80
+ keypoint1 = keypoints[index1, :-1].astype(int)
81
+ keypoint2 = keypoints[index2, :-1].astype(int)
82
+ cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), thicknessLineScaled, lineType, shift)
83
+ for part in range(len(keypoints)):
84
+ faceIndex = part
85
+ if keypoints[faceIndex, -1] > threshold:
86
+ radiusScaled = int(round(radius[faceIndex] * pose_scales[0]))
87
+ thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * pose_scales[0]))
88
+ colorIndex = part
89
+ color = colors[colorIndex % numberColors]
90
+ center = keypoints[faceIndex, :-1].astype(int)
91
+ cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, lineType, shift)
92
+ return img
93
+
94
+ def render_hand_keypoints(img, right_hand_keypoints, threshold=0.1, use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0):
95
+ if use_confidence and map_fn is not None:
96
+ #thicknessCircleRatioLeft = 1./50 * map_fn(left_hand_keypoints[:, -1])
97
+ thicknessCircleRatioRight = 1./50 * map_fn(right_hand_keypoints[:, -1])
98
+ else:
99
+ #thicknessCircleRatioLeft = 1./50 * np.ones(left_hand_keypoints.shape[0])
100
+ thicknessCircleRatioRight = 1./50 * np.ones(right_hand_keypoints.shape[0])
101
+ thicknessLineRatioWRTCircle = 0.75
102
+ pairs = [0,1, 1,2, 2,3, 3,4, 0,5, 5,6, 6,7, 7,8, 0,9, 9,10, 10,11, 11,12, 0,13, 13,14, 14,15, 15,16, 0,17, 17,18, 18,19, 19,20]
103
+ pairs = np.array(pairs).reshape(-1,2)
104
+
105
+ colors = [100., 100., 100.,
106
+ 100., 0., 0.,
107
+ 150., 0., 0.,
108
+ 200., 0., 0.,
109
+ 255., 0., 0.,
110
+ 100., 100., 0.,
111
+ 150., 150., 0.,
112
+ 200., 200., 0.,
113
+ 255., 255., 0.,
114
+ 0., 100., 50.,
115
+ 0., 150., 75.,
116
+ 0., 200., 100.,
117
+ 0., 255., 125.,
118
+ 0., 50., 100.,
119
+ 0., 75., 150.,
120
+ 0., 100., 200.,
121
+ 0., 125., 255.,
122
+ 100., 0., 100.,
123
+ 150., 0., 150.,
124
+ 200., 0., 200.,
125
+ 255., 0., 255.]
126
+ colors = np.array(colors).reshape(-1,3)
127
+ #colors = np.zeros_like(colors)
128
+ poseScales = [1]
129
+ #img = render_keypoints(img, left_hand_keypoints, pairs, colors, thicknessCircleRatioLeft, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
130
+ img = render_keypoints(img, right_hand_keypoints, pairs, colors, thicknessCircleRatioRight, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
131
+ #img = render_keypoints(img, right_hand_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
132
+ return img
133
+
134
+ def render_hand_landmarks(img, right_hand_keypoints, threshold=0.1, use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0):
135
+ if use_confidence and map_fn is not None:
136
+ #thicknessCircleRatioLeft = 1./50 * map_fn(left_hand_keypoints[:, -1])
137
+ thicknessCircleRatioRight = 1./50 * map_fn(right_hand_keypoints[:, -1])
138
+ else:
139
+ #thicknessCircleRatioLeft = 1./50 * np.ones(left_hand_keypoints.shape[0])
140
+ thicknessCircleRatioRight = 1./50 * np.ones(right_hand_keypoints.shape[0])
141
+ thicknessLineRatioWRTCircle = 0.75
142
+ pairs = []
143
+ pairs = np.array(pairs).reshape(-1,2)
144
+
145
+ colors = [255, 0, 0]
146
+ colors = np.array(colors).reshape(-1,3)
147
+ #colors = np.zeros_like(colors)
148
+ poseScales = [1]
149
+ #img = render_keypoints(img, left_hand_keypoints, pairs, colors, thicknessCircleRatioLeft, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
150
+ img = render_keypoints(img, right_hand_keypoints, pairs, colors, thicknessCircleRatioRight * 0.1, thicknessLineRatioWRTCircle * 0.1, poseScales, threshold, alpha=alpha)
151
+ #img = render_keypoints(img, right_hand_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
152
+ return img
153
+
154
+ def render_body_keypoints(img: np.array,
155
+ body_keypoints: np.array) -> np.array:
156
+ """
157
+ Render OpenPose body keypoints on input image.
158
+ Args:
159
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
160
+ body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
161
+ Returns:
162
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
163
+ """
164
+
165
+ thickness_circle_ratio = 1./75. * np.ones(body_keypoints.shape[0])
166
+ thickness_line_ratio_wrt_circle = 0.75
167
+ pairs = []
168
+ pairs = [1,8,1,2,1,5,2,3,3,4,5,6,6,7,8,9,9,10,10,11,8,12,12,13,13,14,1,0,0,15,15,17,0,16,16,18,14,19,19,20,14,21,11,22,22,23,11,24]
169
+ pairs = np.array(pairs).reshape(-1,2)
170
+ colors = [255., 0., 85.,
171
+ 255., 0., 0.,
172
+ 255., 85., 0.,
173
+ 255., 170., 0.,
174
+ 255., 255., 0.,
175
+ 170., 255., 0.,
176
+ 85., 255., 0.,
177
+ 0., 255., 0.,
178
+ 255., 0., 0.,
179
+ 0., 255., 85.,
180
+ 0., 255., 170.,
181
+ 0., 255., 255.,
182
+ 0., 170., 255.,
183
+ 0., 85., 255.,
184
+ 0., 0., 255.,
185
+ 255., 0., 170.,
186
+ 170., 0., 255.,
187
+ 255., 0., 255.,
188
+ 85., 0., 255.,
189
+ 0., 0., 255.,
190
+ 0., 0., 255.,
191
+ 0., 0., 255.,
192
+ 0., 255., 255.,
193
+ 0., 255., 255.,
194
+ 0., 255., 255.]
195
+ colors = np.array(colors).reshape(-1,3)
196
+ pose_scales = [1]
197
+ return render_keypoints(img, body_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
198
+
199
+ def render_openpose(img: np.array,
200
+ hand_keypoints: np.array) -> np.array:
201
+ """
202
+ Render keypoints in the OpenPose format on input image.
203
+ Args:
204
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
205
+ body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
206
+ Returns:
207
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
208
+ """
209
+ #img = render_body_keypoints(img, body_keypoints)
210
+ img = render_hand_keypoints(img, hand_keypoints)
211
+ return img
212
+
213
+ def render_openpose_landmarks(img: np.array,
214
+ hand_keypoints: np.array) -> np.array:
215
+ """
216
+ Render keypoints in the OpenPose format on input image.
217
+ Args:
218
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
219
+ body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
220
+ Returns:
221
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
222
+ """
223
+ #img = render_body_keypoints(img, body_keypoints)
224
+ img = render_hand_landmarks(img, hand_keypoints)
225
+ return img
hawor/utils/rotation.py ADDED
@@ -0,0 +1,293 @@
1
+ import torch
2
+ import numpy as np
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
7
+ """
8
+ Taken from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py
9
+ Calculates the rotation matrices for a batch of rotation vectors
10
+ - param rot_vecs: torch.tensor (N, 3) array of N axis-angle vectors
11
+ - returns R: torch.tensor (N, 3, 3) rotation matrices
12
+ """
13
+ batch_size = rot_vecs.shape[0]
14
+ device = rot_vecs.device
15
+
16
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
17
+ rot_dir = rot_vecs / angle
18
+
19
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
20
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
21
+
22
+ # Bx1 arrays
23
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
24
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
25
+
26
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
27
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
28
+ (batch_size, 3, 3)
29
+ )
30
+
31
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
32
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
33
+ return rot_mat
34
+
35
+
36
+ def quaternion_mul(q0, q1):
37
+ """
38
+ EXPECTS WXYZ
39
+ :param q0 (*, 4)
40
+ :param q1 (*, 4)
41
+ """
42
+ r0, r1 = q0[..., :1], q1[..., :1]
43
+ v0, v1 = q0[..., 1:], q1[..., 1:]
44
+ r = r0 * r1 - (v0 * v1).sum(dim=-1, keepdim=True)
45
+ v = r0 * v1 + r1 * v0 + torch.linalg.cross(v0, v1)
46
+ return torch.cat([r, v], dim=-1)
47
+
48
+
49
+ def quaternion_inverse(q, eps=1e-8):
50
+ """
51
+ EXPECTS WXYZ
52
+ :param q (*, 4)
53
+ """
54
+ conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
55
+ mag = torch.square(q).sum(dim=-1, keepdim=True) + eps
56
+ return conj / mag
57
+
58
+
59
+ def quaternion_slerp(t, q0, q1, eps=1e-8):
60
+ """
61
+ :param t (*, 1) must be between 0 and 1
62
+ :param q0 (*, 4)
63
+ :param q1 (*, 4)
64
+ """
65
+ dims = q0.shape[:-1]
66
+ t = t.view(*dims, 1)
67
+
68
+ q0 = F.normalize(q0, p=2, dim=-1)
69
+ q1 = F.normalize(q1, p=2, dim=-1)
70
+ dot = (q0 * q1).sum(dim=-1, keepdim=True)
71
+
72
+ # make sure we give the shortest rotation path (< 180d)
73
+ neg = dot < 0
74
+ q1 = torch.where(neg, -q1, q1)
75
+ dot = torch.where(neg, -dot, dot)
76
+ angle = torch.acos(dot)
77
+
78
+ # if angle is too small, just do linear interpolation
79
+ collin = torch.abs(dot) > 1 - eps
80
+ fac = 1 / torch.sin(angle)
81
+ w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac)
82
+ w1 = torch.where(collin, t, torch.sin(t * angle) * fac)
83
+ slerp = q0 * w0 + q1 * w1
84
+ return slerp
85
+
86
+
87
+ def rotation_matrix_to_angle_axis(rotation_matrix):
88
+ """
89
+ This function is borrowed from https://github.com/kornia/kornia
90
+
91
+ Convert rotation matrix to Rodrigues vector
92
+ """
93
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
94
+ aa = quaternion_to_angle_axis(quaternion)
95
+ aa[torch.isnan(aa)] = 0.0
96
+ return aa
97
+
98
+
99
+ def quaternion_to_angle_axis(quaternion):
100
+ """
101
+ This function is borrowed from https://github.com/kornia/kornia
102
+
103
+ Convert quaternion vector to angle axis of rotation.
104
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
105
+
106
+ :param quaternion (*, 4) expects WXYZ
107
+ :returns angle_axis (*, 3)
108
+ """
109
+ # unpack input and compute conversion
110
+ q1 = quaternion[..., 1]
111
+ q2 = quaternion[..., 2]
112
+ q3 = quaternion[..., 3]
113
+ sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3
114
+
115
+ sin_theta = torch.sqrt(sin_squared_theta)
116
+ cos_theta = quaternion[..., 0]
117
+ two_theta = 2.0 * torch.where(
118
+ cos_theta < 0.0,
119
+ torch.atan2(-sin_theta, -cos_theta),
120
+ torch.atan2(sin_theta, cos_theta),
121
+ )
122
+
123
+ k_pos = two_theta / sin_theta
124
+ k_neg = 2.0 * torch.ones_like(sin_theta)
125
+ k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
126
+
127
+ angle_axis = torch.zeros_like(quaternion)[..., :3]
128
+ angle_axis[..., 0] += q1 * k
129
+ angle_axis[..., 1] += q2 * k
130
+ angle_axis[..., 2] += q3 * k
131
+ return angle_axis
132
+
133
+
134
+ def angle_axis_to_rotation_matrix(angle_axis):
135
+ """
136
+ :param angle_axis (*, 3)
137
+ return (*, 3, 3)
138
+ """
139
+ quat = angle_axis_to_quaternion(angle_axis)
140
+ return quaternion_to_rotation_matrix(quat)
141
+
142
+
143
+ def quaternion_to_rotation_matrix(quaternion):
144
+ """
145
+ Convert a quaternion to a rotation matrix.
146
+ Taken from https://github.com/kornia/kornia, based on
147
+ https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
148
+ https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247
149
+ :param quaternion (N, 4) expects WXYZ order
150
+ returns rotation matrix (N, 3, 3)
151
+ """
152
+ # normalize the input quaternion
153
+ quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-12)
154
+ *dims, _ = quaternion_norm.shape
155
+
156
+ # unpack the normalized quaternion components
157
+ w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1)
158
+
159
+ # compute the actual conversion
160
+ tx = 2.0 * x
161
+ ty = 2.0 * y
162
+ tz = 2.0 * z
163
+ twx = tx * w
164
+ twy = ty * w
165
+ twz = tz * w
166
+ txx = tx * x
167
+ txy = ty * x
168
+ txz = tz * x
169
+ tyy = ty * y
170
+ tyz = tz * y
171
+ tzz = tz * z
172
+ one = torch.tensor(1.0)
173
+
174
+ matrix = torch.stack(
175
+ (
176
+ one - (tyy + tzz),
177
+ txy - twz,
178
+ txz + twy,
179
+ txy + twz,
180
+ one - (txx + tzz),
181
+ tyz - twx,
182
+ txz - twy,
183
+ tyz + twx,
184
+ one - (txx + tyy),
185
+ ),
186
+ dim=-1,
187
+ ).view(*dims, 3, 3)
188
+ return matrix
189
+
190
+
191
+ def angle_axis_to_quaternion(angle_axis):
192
+ """
193
+ This function is borrowed from https://github.com/kornia/kornia
194
+ Convert angle axis to quaternion in WXYZ order
195
+ :param angle_axis (*, 3)
196
+ :returns quaternion (*, 4) WXYZ order
197
+ """
198
+ theta_sq = torch.sum(angle_axis**2, dim=-1, keepdim=True) # (*, 1)
199
+ # need to handle the zero rotation case
200
+ valid = theta_sq > 0
201
+ theta = torch.sqrt(theta_sq)
202
+ half_theta = 0.5 * theta
203
+ ones = torch.ones_like(half_theta)
204
+ # fill zero with the limit of sin ax / x -> a
205
+ k = torch.where(valid, torch.sin(half_theta) / theta, 0.5 * ones)
206
+ w = torch.where(valid, torch.cos(half_theta), ones)
207
+ quat = torch.cat([w, k * angle_axis], dim=-1)
208
+ return quat
209
+
210
+
211
+ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
212
+ """
213
+ This function is borrowed from https://github.com/kornia/kornia
214
+ Convert rotation matrix to 4d quaternion vector
215
+ This algorithm is based on algorithm described in
216
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
217
+
218
+ :param rotation_matrix (N, 3, 3)
219
+ """
220
+ *dims, m, n = rotation_matrix.shape
221
+ rmat_t = torch.transpose(rotation_matrix.reshape(-1, m, n), -1, -2)
222
+
223
+ mask_d2 = rmat_t[:, 2, 2] < eps
224
+
225
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
226
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
227
+
228
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
229
+ q0 = torch.stack(
230
+ [
231
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
232
+ t0,
233
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
234
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
235
+ ],
236
+ -1,
237
+ )
238
+ t0_rep = t0.repeat(4, 1).t()
239
+
240
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
241
+ q1 = torch.stack(
242
+ [
243
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
244
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
245
+ t1,
246
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
247
+ ],
248
+ -1,
249
+ )
250
+ t1_rep = t1.repeat(4, 1).t()
251
+
252
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
253
+ q2 = torch.stack(
254
+ [
255
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
256
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
257
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
258
+ t2,
259
+ ],
260
+ -1,
261
+ )
262
+ t2_rep = t2.repeat(4, 1).t()
263
+
264
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
265
+ q3 = torch.stack(
266
+ [
267
+ t3,
268
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
269
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
270
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
271
+ ],
272
+ -1,
273
+ )
274
+ t3_rep = t3.repeat(4, 1).t()
275
+
276
+ mask_c0 = mask_d2 * mask_d0_d1
277
+ mask_c1 = mask_d2 * ~mask_d0_d1
278
+ mask_c2 = ~mask_d2 * mask_d0_nd1
279
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
280
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
281
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
282
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
283
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
284
+
285
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
286
+ q /= torch.sqrt(
287
+ t0_rep * mask_c0
288
+ + t1_rep * mask_c1
289
+ + t2_rep * mask_c2 # noqa
290
+ + t3_rep * mask_c3
291
+ ) # noqa
292
+ q *= 0.5
293
+ return q.reshape(*dims, 4)
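A round-trip sketch for the conversions above, assuming the repository root is on `PYTHONPATH`:

```python
# Sketch only: axis-angle -> rotation matrix -> axis-angle round trip.
import torch
from hawor.utils.rotation import (
    angle_axis_to_rotation_matrix,
    rotation_matrix_to_angle_axis,
)

aa = 0.5 * torch.randn(16, 3)                  # small random axis-angle vectors
R = angle_axis_to_rotation_matrix(aa)          # (16, 3, 3)
aa_back = rotation_matrix_to_angle_axis(R)     # (16, 3)
print(torch.allclose(aa, aa_back, atol=1e-4))  # expected True for angles below pi
```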
imgui.ini ADDED
@@ -0,0 +1,15 @@
1
+ [Window][Debug##Default]
2
+ Pos=60,60
3
+ Size=400,400
4
+ Collapsed=0
5
+
6
+ [Window][Editor]
7
+ Pos=50,50
8
+ Size=250,700
9
+ Collapsed=0
10
+
11
+ [Window][Playback]
12
+ Pos=50,800
13
+ Size=400,175
14
+ Collapsed=1
15
+
infiller/hand_utils/geometry.py ADDED
@@ -0,0 +1,412 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def perspective_projection(points, rotation, translation,
7
+ focal_length, camera_center, distortion=None):
8
+ """
9
+ This function computes the perspective projection of a set of points.
10
+ Input:
11
+ points (bs, N, 3): 3D points
12
+ rotation (bs, 3, 3): Camera rotation
13
+ translation (bs, 3): Camera translation
14
+ focal_length (bs,) or scalar: Focal length
15
+ camera_center (bs, 2): Camera center
16
+ """
17
+ batch_size = points.shape[0]
18
+
19
+ # Extrinsic
20
+ if rotation is not None:
21
+ points = torch.einsum('bij,bkj->bki', rotation, points)
22
+
23
+ if translation is not None:
24
+ points = points + translation.unsqueeze(1)
25
+
26
+ if distortion is not None:
27
+ kc = distortion
28
+ points = points[:,:,:2] / points[:,:,2:]
29
+
30
+ r2 = points[:,:,0]**2 + points[:,:,1]**2
31
+ dx = (2 * kc[:,[2]] * points[:,:,0] * points[:,:,1]
32
+ + kc[:,[3]] * (r2 + 2*points[:,:,0]**2))
33
+
34
+ dy = (2 * kc[:,[3]] * points[:,:,0] * points[:,:,1]
35
+ + kc[:,[2]] * (r2 + 2*points[:,:,1]**2))
36
+
37
+ x = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,0] + dx
38
+ y = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,1] + dy
39
+
40
+ points = torch.stack([x, y, torch.ones_like(x)], dim=-1)
41
+
42
+ # Intrinsic
43
+ K = torch.zeros([batch_size, 3, 3], device=points.device)
44
+ K[:,0,0] = focal_length
45
+ K[:,1,1] = focal_length
46
+ K[:,2,2] = 1.
47
+ K[:,:-1, -1] = camera_center
48
+
49
+ # Apply camera intrinsics
50
+ points = points / points[:,:,-1].unsqueeze(-1)
51
+ projected_points = torch.einsum('bij,bkj->bki', K, points)
52
+ projected_points = projected_points[:, :, :-1]
53
+
54
+ return projected_points
55
+
56
+
57
+ def avg_rot(rot):
58
+ # input [B,...,3,3] --> output [...,3,3]
59
+ rot = rot.mean(dim=0)
60
+ U, _, V = torch.svd(rot)
61
+ rot = U @ V.transpose(-1, -2)
62
+ return rot
63
+
64
+
65
+ def rot9d_to_rotmat(x):
66
+ """Convert 9D rotation representation to 3x3 rotation matrix.
67
+ Based on Levinson et al., "An Analysis of SVD for Deep Rotation Estimation"
68
+ Input:
69
+ (B,9) or (B,J*9) Batch of 9D rotation (interpreted as 3x3 est rotmat)
70
+ Output:
71
+ (B,3,3) or (B*J,3,3) Batch of corresponding rotation matrices
72
+ """
73
+ x = x.view(-1,3,3)
74
+ u, _, vh = torch.linalg.svd(x)
75
+
76
+ sig = torch.eye(3).expand(len(x), 3, 3).clone()
77
+ sig = sig.to(x.device)
78
+ sig[:, -1, -1] = (u @ vh).det()
79
+
80
+ R = u @ sig @ vh
81
+
82
+ return R
83
+
84
+
85
+ """
86
+ Deprecated in favor of: rotation_conversions.py
87
+
88
+ Useful geometric operations, e.g. differentiable Rodrigues formula
89
+ Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
90
+ """
91
+ def batch_rodrigues(theta):
92
+ """Convert axis-angle representation to rotation matrix.
93
+ Args:
94
+ theta: size = [B, 3]
95
+ Returns:
96
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
97
+ """
98
+ l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
99
+ angle = torch.unsqueeze(l1norm, -1)
100
+ normalized = torch.div(theta, angle)
101
+ angle = angle * 0.5
102
+ v_cos = torch.cos(angle)
103
+ v_sin = torch.sin(angle)
104
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
105
+ return quat_to_rotmat(quat)
106
+
107
+ def quat_to_rotmat(quat):
108
+ """Convert quaternion coefficients to rotation matrix.
109
+ Args:
110
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
111
+ Returns:
112
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
113
+ """
114
+ norm_quat = quat
115
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
116
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
117
+
118
+ B = quat.size(0)
119
+
120
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
121
+ wx, wy, wz = w*x, w*y, w*z
122
+ xy, xz, yz = x*y, x*z, y*z
123
+
124
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
125
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
126
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
127
+ return rotMat
128
+
129
+ def rot6d_to_rotmat(x):
130
+ """Convert 6D rotation representation to 3x3 rotation matrix.
131
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
132
+ Input:
133
+ (B,6) Batch of 6-D rotation representations
134
+ Output:
135
+ (B,3,3) Batch of corresponding rotation matrices
136
+ """
137
+ x = x.view(-1,3,2)
138
+ a1 = x[:, :, 0]
139
+ a2 = x[:, :, 1]
140
+ b1 = F.normalize(a1)
141
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
142
+ b3 = torch.cross(b1, b2)
143
+ return torch.stack((b1, b2, b3), dim=-1)
144
+
145
+ def rot6d_to_rotmat_hmr2(x: torch.Tensor) -> torch.Tensor:
146
+ """
147
+ Convert 6D rotation representation to 3x3 rotation matrix.
148
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
149
+ Args:
150
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
151
+ Returns:
152
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
153
+ """
154
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
155
+ a1 = x[:, :, 0]
156
+ a2 = x[:, :, 1]
157
+ b1 = F.normalize(a1)
158
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
159
+ b3 = torch.cross(b1, b2)
160
+ return torch.stack((b1, b2, b3), dim=-1)
161
+
162
+ def rotmat_to_rot6d(rotmat):
163
+ """ Inverse function of the above.
164
+ Input:
165
+ (B,3,3) Batch of corresponding rotation matrices
166
+ Output:
167
+ (B,6) Batch of 6-D rotation representations
168
+ """
169
+ # rot6d = rotmat[:, :, :2]
170
+ rot6d = rotmat[...,:2]
171
+ rot6d = rot6d.reshape(rot6d.size(0), -1)
172
+ return rot6d
173
+
174
+
175
+ def rotation_matrix_to_angle_axis(rotation_matrix):
176
+ """
177
+ This function is borrowed from https://github.com/kornia/kornia
178
+
179
+ Convert 3x4 rotation matrix to Rodrigues vector
180
+
181
+ Args:
182
+ rotation_matrix (Tensor): rotation matrix.
183
+
184
+ Returns:
185
+ Tensor: Rodrigues vector transformation.
186
+
187
+ Shape:
188
+ - Input: :math:`(N, 3, 4)`
189
+ - Output: :math:`(N, 3)`
190
+
191
+ Example:
192
+ >>> input = torch.rand(2, 3, 4) # Nx4x4
193
+ >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
194
+ """
195
+ if rotation_matrix.shape[1:] == (3,3):
196
+ rot_mat = rotation_matrix.reshape(-1, 3, 3)
197
+ hom = torch.tensor([0, 0, 1], dtype=torch.float32,
198
+ device=rotation_matrix.device).reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1)
199
+ rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
200
+
201
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
202
+ aa = quaternion_to_angle_axis(quaternion)
203
+ aa[torch.isnan(aa)] = 0.0
204
+ return aa
205
+
206
+
207
+ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
208
+ """
209
+ This function is borrowed from https://github.com/kornia/kornia
210
+
211
+ Convert quaternion vector to angle axis of rotation.
212
+
213
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
214
+
215
+ Args:
216
+ quaternion (torch.Tensor): tensor with quaternions.
217
+
218
+ Return:
219
+ torch.Tensor: tensor with angle axis of rotation.
220
+
221
+ Shape:
222
+ - Input: :math:`(*, 4)` where `*` means, any number of dimensions
223
+ - Output: :math:`(*, 3)`
224
+
225
+ Example:
226
+ >>> quaternion = torch.rand(2, 4) # Nx4
227
+ >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
228
+ """
229
+ if not torch.is_tensor(quaternion):
230
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
231
+ type(quaternion)))
232
+
233
+ if not quaternion.shape[-1] == 4:
234
+ raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
235
+ .format(quaternion.shape))
236
+ # unpack input and compute conversion
237
+ q1: torch.Tensor = quaternion[..., 1]
238
+ q2: torch.Tensor = quaternion[..., 2]
239
+ q3: torch.Tensor = quaternion[..., 3]
240
+ sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
241
+
242
+ sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
243
+ cos_theta: torch.Tensor = quaternion[..., 0]
244
+ two_theta: torch.Tensor = 2.0 * torch.where(
245
+ cos_theta < 0.0,
246
+ torch.atan2(-sin_theta, -cos_theta),
247
+ torch.atan2(sin_theta, cos_theta))
248
+
249
+ k_pos: torch.Tensor = two_theta / sin_theta
250
+ k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
251
+ k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
252
+
253
+ angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
254
+ angle_axis[..., 0] += q1 * k
255
+ angle_axis[..., 1] += q2 * k
256
+ angle_axis[..., 2] += q3 * k
257
+ return angle_axis
258
+
259
+
260
+ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
261
+ """
262
+ This function is borrowed from https://github.com/kornia/kornia
263
+
264
+ Convert 3x4 rotation matrix to 4d quaternion vector
265
+
266
+ This algorithm is based on algorithm described in
267
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
268
+
269
+ Args:
270
+ rotation_matrix (Tensor): the rotation matrix to convert.
271
+
272
+ Return:
273
+ Tensor: the rotation in quaternion
274
+
275
+ Shape:
276
+ - Input: :math:`(N, 3, 4)`
277
+ - Output: :math:`(N, 4)`
278
+
279
+ Example:
280
+ >>> input = torch.rand(4, 3, 4) # Nx3x4
281
+ >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
282
+ """
283
+ if not torch.is_tensor(rotation_matrix):
284
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
285
+ type(rotation_matrix)))
286
+
287
+ if len(rotation_matrix.shape) > 3:
288
+ raise ValueError(
289
+ "Input size must be a three dimensional tensor. Got {}".format(
290
+ rotation_matrix.shape))
291
+ if not rotation_matrix.shape[-2:] == (3, 4):
292
+ raise ValueError(
293
+ "Input size must be a N x 3 x 4 tensor. Got {}".format(
294
+ rotation_matrix.shape))
295
+
296
+ rmat_t = torch.transpose(rotation_matrix, 1, 2)
297
+
298
+ mask_d2 = rmat_t[:, 2, 2] < eps
299
+
300
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
301
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
302
+
303
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
304
+ q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
305
+ t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
306
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
307
+ t0_rep = t0.repeat(4, 1).t()
308
+
309
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
310
+ q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
311
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
312
+ t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
313
+ t1_rep = t1.repeat(4, 1).t()
314
+
315
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
316
+ q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
317
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
318
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
319
+ t2_rep = t2.repeat(4, 1).t()
320
+
321
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
322
+ q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
323
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
324
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
325
+ t3_rep = t3.repeat(4, 1).t()
326
+
327
+ mask_c0 = mask_d2 * mask_d0_d1
328
+ mask_c1 = mask_d2 * ~mask_d0_d1
329
+ mask_c2 = ~mask_d2 * mask_d0_nd1
330
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
331
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
332
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
333
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
334
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
335
+
336
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
337
+ q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
338
+ t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
339
+ q *= 0.5
340
+ return q
341
+
342
+
343
+ def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000., img_size=224.):
344
+ """
345
+ This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py
346
+
347
+ Find the camera translation that brings the 3D joints S closest to the corresponding 2D joints joints_2d.
348
+ Input:
349
+ S: (25, 3) 3D joint locations
350
+ joints: (25, 3) 2D joint locations and confidence
351
+ Returns:
352
+ (3,) camera translation vector
353
+ """
354
+
355
+ num_joints = S.shape[0]
356
+ # focal length
357
+ f = np.array([focal_length,focal_length])
358
+ # optical center
359
+ center = np.array([img_size/2., img_size/2.])
360
+
361
+ # transformations
362
+ Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1)
363
+ XY = np.reshape(S[:,0:2],-1)
364
+ O = np.tile(center,num_joints)
365
+ F = np.tile(f,num_joints)
366
+ weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1)
367
+
368
+ # least squares
369
+ Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T
370
+ c = (np.reshape(joints_2d,-1)-O)*Z - F*XY
371
+
372
+ # weighted least squares
373
+ W = np.diagflat(weight2)
374
+ Q = np.dot(W,Q)
375
+ c = np.dot(W,c)
376
+
377
+ # square matrix
378
+ A = np.dot(Q.T,Q)
379
+ b = np.dot(Q.T,c)
380
+
381
+ # solution
382
+ trans = np.linalg.solve(A, b)
383
+
384
+ return trans
385
+
386
+
387
+ def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.):
388
+ """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
389
+ Input:
390
+ S: (B, 49, 3) 3D joint locations
391
+ joints: (B, 49, 3) 2D joint locations and confidence
392
+ Returns:
393
+ (B, 3) camera translation vectors
394
+ """
395
+
396
+ device = S.device
397
+ # Use only joints 25:49 (GT joints)
398
+ S = S[:, -24:, :3].cpu().numpy()
399
+ joints_2d = joints_2d[:, -24:, :].cpu().numpy()
400
+
401
+ joints_conf = joints_2d[:, :, -1]
402
+ joints_2d = joints_2d[:, :, :-1]
403
+ trans = np.zeros((S.shape[0], 3), dtype=np.float32)
404
+ # Find the translation for each example in the batch
405
+ for i in range(S.shape[0]):
406
+ S_i = S[i]
407
+ joints_i = joints_2d[i]
408
+ conf_i = joints_conf[i]
409
+ trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size)
410
+ return torch.from_numpy(trans).to(device)
411
+
412
+
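As a quick consistency check of the rotation helpers in this file, the sketch below round-trips random rotations through the 6D representation (a minimal sketch; the import path infiller.hand_utils.geometry is an assumption based on the file location):

import torch
from infiller.hand_utils.geometry import batch_rodrigues, rot6d_to_rotmat, rotmat_to_rot6d

aa = torch.randn(4, 3) * 0.3                      # random axis-angle vectors
R = batch_rodrigues(aa)                           # (4, 3, 3) rotation matrices
R_rec = rot6d_to_rotmat(rotmat_to_rot6d(R))       # rotmat -> 6D -> rotmat round trip
assert torch.allclose(R, R_rec, atol=1e-5)        # should recover the input rotations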
infiller/hand_utils/geometry_utils.py ADDED
@@ -0,0 +1,102 @@
1
+ from typing import Optional
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+ def aa_to_rotmat(theta: torch.Tensor):
6
+ """
7
+ Convert axis-angle representation to rotation matrix.
8
+ Works by first converting it to a quaternion.
9
+ Args:
10
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
11
+ Returns:
12
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
13
+ """
14
+ norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
15
+ angle = torch.unsqueeze(norm, -1)
16
+ normalized = torch.div(theta, angle)
17
+ angle = angle * 0.5
18
+ v_cos = torch.cos(angle)
19
+ v_sin = torch.sin(angle)
20
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
21
+ return quat_to_rotmat(quat)
22
+
23
+ def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
24
+ """
25
+ Convert quaternion representation to rotation matrix.
26
+ Args:
27
+ quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
28
+ Returns:
29
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
30
+ """
31
+ norm_quat = quat
32
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
33
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
34
+
35
+ B = quat.size(0)
36
+
37
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
38
+ wx, wy, wz = w*x, w*y, w*z
39
+ xy, xz, yz = x*y, x*z, y*z
40
+
41
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
42
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
43
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
44
+ return rotMat
45
+
46
+
47
+ def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Convert 6D rotation representation to 3x3 rotation matrix.
50
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
51
+ Args:
52
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
53
+ Returns:
54
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
55
+ """
56
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
57
+ a1 = x[:, :, 0]
58
+ a2 = x[:, :, 1]
59
+ b1 = F.normalize(a1)
60
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
61
+ b3 = torch.linalg.cross(b1, b2)
62
+ return torch.stack((b1, b2, b3), dim=-1)
63
+
64
+ def perspective_projection(points: torch.Tensor,
65
+ translation: torch.Tensor,
66
+ focal_length: torch.Tensor,
67
+ camera_center: Optional[torch.Tensor] = None,
68
+ rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
69
+ """
70
+ Computes the perspective projection of a set of 3D points.
71
+ Args:
72
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
73
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
74
+ focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
75
+ camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
76
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
77
+ Returns:
78
+ torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
79
+ """
80
+ batch_size = points.shape[0]
81
+ if rotation is None:
82
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
83
+ if camera_center is None:
84
+ camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
85
+ # Populate intrinsic camera matrix K.
86
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
87
+ K[:,0,0] = focal_length[:,0]
88
+ K[:,1,1] = focal_length[:,1]
89
+ K[:,2,2] = 1.
90
+ K[:,:-1, -1] = camera_center
91
+
92
+ # Transform points
93
+ points = torch.einsum('bij,bkj->bki', rotation, points)
94
+ points = points + translation.unsqueeze(1)
95
+
96
+ # Apply perspective distortion
97
+ projected_points = points / points[:,:,-1].unsqueeze(-1)
98
+
99
+ # Apply camera intrinsics
100
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
101
+
102
+ return projected_points[:, :, :-1]
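A minimal shape-level sketch of calling the perspective_projection utility above (the import path, intrinsics values, and point counts are illustrative assumptions):

import torch
from infiller.hand_utils.geometry_utils import perspective_projection

B, N = 2, 21                                              # e.g. 21 hand joints per sample
points = torch.randn(B, N, 3) * 0.1                       # 3D points near the origin
translation = torch.tensor([0.0, 0.0, 2.0]).expand(B, 3)  # push them in front of the camera
focal_length = torch.full((B, 2), 1000.0)                 # (fx, fy) in pixels
camera_center = torch.full((B, 2), 112.0)                 # principal point of a 224x224 crop
uv = perspective_projection(points, translation, focal_length, camera_center=camera_center)
print(uv.shape)                                            # torch.Size([2, 21, 2])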
infiller/hand_utils/mano_wrapper.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+ import numpy as np
3
+ import pickle
4
+ from typing import Optional
5
+ import smplx
6
+ from smplx.lbs import vertices2joints
7
+ from smplx.utils import MANOOutput, to_tensor
8
+ from smplx.vertex_ids import vertex_ids
9
+
10
+
11
+ class MANO(smplx.MANOLayer):
12
+ def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
13
+ """
14
+ Extension of the official MANO implementation to support more joints.
15
+ Args:
16
+ Same as MANOLayer.
17
+ joint_regressor_extra (str): Path to extra joint regressor.
18
+ """
19
+ super(MANO, self).__init__(*args, **kwargs)
20
+ mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
21
+
22
+ #2, 3, 5, 4, 1
23
+ if joint_regressor_extra is not None:
24
+ self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
25
+ self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
26
+ self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))
27
+
28
+ def forward(self, *args, **kwargs) -> MANOOutput:
29
+ """
30
+ Run forward pass. Same as MANO and also append an extra set of joints if joint_regressor_extra is specified.
31
+ """
32
+ mano_output = super(MANO, self).forward(*args, **kwargs)
33
+ extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs)
34
+ joints = torch.cat([mano_output.joints, extra_joints], dim=1)
35
+ joints = joints[:, self.joint_map, :]
36
+ if hasattr(self, 'joint_regressor_extra'):
37
+ extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices)
38
+ joints = torch.cat([joints, extra_joints], dim=1)
39
+ mano_output.joints = joints
40
+ return mano_output
41
+
42
+ def query(self, hmr_output):
43
+ batch_size = hmr_output['pred_rotmat'].shape[0]
44
+ pred_rotmat = hmr_output['pred_rotmat'].reshape(batch_size, -1, 3, 3)
45
+ pred_shape = hmr_output['pred_shape'].reshape(batch_size, 10)
46
+
47
+ mano_output = self(global_orient=pred_rotmat[:, [0]],
48
+ hand_pose = pred_rotmat[:, 1:],
49
+ betas = pred_shape,
50
+ pose2rot=False)
51
+
52
+ return mano_output
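A hedged sketch of instantiating the MANO wrapper above, mirroring the configuration used by infiller/hand_utils/process.py below; it assumes the MANO model files are present under _DATA/ and that the module is importable from this path:

import torch
from infiller.hand_utils.mano_wrapper import MANO

mano = MANO(data_dir='_DATA/data/', model_path='_DATA/data/mano',
            gender='neutral', num_hand_joints=15, create_body_pose=False)
betas = torch.zeros(1, 10)
global_orient = torch.eye(3).reshape(1, 1, 3, 3)          # identity root rotation
hand_pose = torch.eye(3).repeat(1, 15, 1, 1)              # flat hand, rotation-matrix form
out = mano(global_orient=global_orient, hand_pose=hand_pose, betas=betas, pose2rot=False)
print(out.vertices.shape, out.joints.shape)               # (1, 778, 3), (1, 21, 3)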
infiller/hand_utils/process.py ADDED
@@ -0,0 +1,171 @@
1
+ import torch
2
+ from hand_utils.mano_wrapper import MANO
3
+ from hand_utils.geometry_utils import aa_to_rotmat
4
+ import numpy as np
5
+
6
+ def run_mano(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True):
7
+ """
8
+ Forward pass of the MANO model; returns a dict with
9
+ 3D joints and vertices (and per-frame faces when is_right is given).
10
+
11
+ trans : B x T x 3
12
+ root_orient : B x T x 3
13
+ hand_pose : B x T x J*3
14
+ betas : (optional) B x D
15
+ """
16
+ MANO_cfg = {
17
+ 'DATA_DIR': '_DATA/data/',
18
+ 'MODEL_PATH': '_DATA/data/mano',
19
+ 'GENDER': 'neutral',
20
+ 'NUM_HAND_JOINTS': 15,
21
+ 'CREATE_BODY_POSE': False
22
+ }
23
+ mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
24
+ mano = MANO(**mano_cfg)
25
+ if use_cuda:
26
+ mano = mano.cuda()
27
+
28
+ B, T, _ = root_orient.shape
29
+ NUM_JOINTS = 15
30
+ mano_params = {
31
+ 'global_orient': root_orient.reshape(B*T, -1),
32
+ 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
33
+ 'betas': betas.reshape(B*T, -1),
34
+ }
35
+ rotmat_mano_params = mano_params
36
+ rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
37
+ rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
38
+ rotmat_mano_params['transl'] = trans.reshape(B*T, 3)
39
+
40
+ if use_cuda:
41
+ mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
42
+ else:
43
+ mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)
44
+
45
+ faces_right = mano.faces
46
+ faces_new = np.array([[92, 38, 234],
47
+ [234, 38, 239],
48
+ [38, 122, 239],
49
+ [239, 122, 279],
50
+ [122, 118, 279],
51
+ [279, 118, 215],
52
+ [118, 117, 215],
53
+ [215, 117, 214],
54
+ [117, 119, 214],
55
+ [214, 119, 121],
56
+ [119, 120, 121],
57
+ [121, 120, 78],
58
+ [120, 108, 78],
59
+ [78, 108, 79]])
60
+ faces_right = np.concatenate([faces_right, faces_new], axis=0)
61
+ faces_n = len(faces_right)
62
+ faces_left = faces_right[:,[0,2,1]]
63
+
64
+ outputs = {
65
+ "joints": mano_output.joints.reshape(B, T, -1, 3),
66
+ "vertices": mano_output.vertices.reshape(B, T, -1, 3),
67
+ }
68
+
69
+ if not is_right is None:
70
+ # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
71
+ # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
72
+ is_right = (is_right[:, :, 0].cpu().numpy() > 0)
73
+ faces_result = np.zeros((B, T, faces_n, 3))
74
+ faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
75
+ faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
76
+ faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
77
+ outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))
78
+
79
+
80
+ return outputs
81
+
82
+ def run_mano_left(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True, fix_shapedirs=True):
83
+ """
84
+ Forward pass of the left-hand MANO model; returns a dict with
85
+ 3D joints and vertices (and per-frame faces when is_right is given).
86
+
87
+ trans : B x T x 3
88
+ root_orient : B x T x 3
89
+ hand_pose : B x T x J*3
90
+ betas : (optional) B x D
91
+ """
92
+ MANO_cfg = {
93
+ 'DATA_DIR': '_DATA/data_left/',
94
+ 'MODEL_PATH': '_DATA/data_left/mano_left',
95
+ 'GENDER': 'neutral',
96
+ 'NUM_HAND_JOINTS': 15,
97
+ 'CREATE_BODY_POSE': False,
98
+ 'is_rhand': False
99
+ }
100
+ mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
101
+ mano = MANO(**mano_cfg)
102
+ if use_cuda:
103
+ mano = mano.cuda()
104
+
105
+ # fix MANO shapedirs of the left hand bug (https://github.com/vchoutas/smplx/issues/48)
106
+ if fix_shapedirs:
107
+ mano.shapedirs[:, 0, :] *= -1
108
+
109
+ B, T, _ = root_orient.shape
110
+ NUM_JOINTS = 15
111
+ mano_params = {
112
+ 'global_orient': root_orient.reshape(B*T, -1),
113
+ 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
114
+ 'betas': betas.reshape(B*T, -1),
115
+ }
116
+ rotmat_mano_params = mano_params
117
+ rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
118
+ rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
119
+ rotmat_mano_params['transl'] = trans.reshape(B*T, 3)
120
+
121
+ if use_cuda:
122
+ mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
123
+ else:
124
+ mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)
125
+
126
+ faces_right = mano.faces
127
+ faces_new = np.array([[92, 38, 234],
128
+ [234, 38, 239],
129
+ [38, 122, 239],
130
+ [239, 122, 279],
131
+ [122, 118, 279],
132
+ [279, 118, 215],
133
+ [118, 117, 215],
134
+ [215, 117, 214],
135
+ [117, 119, 214],
136
+ [214, 119, 121],
137
+ [119, 120, 121],
138
+ [121, 120, 78],
139
+ [120, 108, 78],
140
+ [78, 108, 79]])
141
+ faces_right = np.concatenate([faces_right, faces_new], axis=0)
142
+ faces_n = len(faces_right)
143
+ faces_left = faces_right[:,[0,2,1]]
144
+
145
+ outputs = {
146
+ "joints": mano_output.joints.reshape(B, T, -1, 3),
147
+ "vertices": mano_output.vertices.reshape(B, T, -1, 3),
148
+ }
149
+
150
+ if not is_right is None:
151
+ # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
152
+ # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
153
+ is_right = (is_right[:, :, 0].cpu().numpy() > 0)
154
+ faces_result = np.zeros((B, T, faces_n, 3))
155
+ faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
156
+ faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
157
+ faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
158
+ outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))
159
+
160
+
161
+ return outputs
162
+
163
+ def run_mano_twohands(init_trans, init_rot, init_hand_pose, is_right, init_betas, use_cuda=True, fix_shapedirs=True):
164
+ outputs_left = run_mano_left(init_trans[0:1], init_rot[0:1], init_hand_pose[0:1], None, init_betas[0:1], use_cuda=use_cuda, fix_shapedirs=fix_shapedirs)
165
+ outputs_right = run_mano(init_trans[1:2], init_rot[1:2], init_hand_pose[1:2], None, init_betas[1:2], use_cuda=use_cuda)
166
+ outputs_two = {
167
+ "vertices": torch.cat((outputs_left["vertices"], outputs_right["vertices"]), dim=0),
168
+ "joints": torch.cat((outputs_left["joints"], outputs_right["joints"]), dim=0)
169
+
170
+ }
171
+ return outputs_two
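A usage sketch for run_mano_twohands above, under the assumption that the MANO files under _DATA/ are available; the leading dimension stacks the left hand (index 0) and the right hand (index 1):

import torch
from infiller.hand_utils.process import run_mano_twohands

T = 16
trans = torch.zeros(2, T, 3)                  # [left, right] x frames x 3
root_orient = torch.zeros(2, T, 3)            # axis-angle global orientation
hand_pose = torch.zeros(2, T, 15 * 3)         # axis-angle pose of the 15 MANO joints
is_right = torch.tensor([0.0, 1.0]).view(2, 1, 1).expand(2, T, 1)
betas = torch.zeros(2, T, 10)
out = run_mano_twohands(trans, root_orient, hand_pose, is_right, betas, use_cuda=False)
print(out["vertices"].shape, out["joints"].shape)   # (2, 16, 778, 3), (2, 16, 21, 3)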
infiller/hand_utils/rotation.py ADDED
@@ -0,0 +1,293 @@
1
+ import torch
2
+ import numpy as np
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
7
+ """
8
+ Taken from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py
9
+ Calculates the rotation matrices for a batch of rotation vectors
10
+ - param rot_vecs: torch.tensor (N, 3) array of N axis-angle vectors
11
+ - returns R: torch.tensor (N, 3, 3) rotation matrices
12
+ """
13
+ batch_size = rot_vecs.shape[0]
14
+ device = rot_vecs.device
15
+
16
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
17
+ rot_dir = rot_vecs / angle
18
+
19
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
20
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
21
+
22
+ # Bx1 arrays
23
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
24
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
25
+
26
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
27
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
28
+ (batch_size, 3, 3)
29
+ )
30
+
31
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
32
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
33
+ return rot_mat
34
+
35
+
36
+ def quaternion_mul(q0, q1):
37
+ """
38
+ EXPECTS WXYZ
39
+ :param q0 (*, 4)
40
+ :param q1 (*, 4)
41
+ """
42
+ r0, r1 = q0[..., :1], q1[..., :1]
43
+ v0, v1 = q0[..., 1:], q1[..., 1:]
44
+ r = r0 * r1 - (v0 * v1).sum(dim=-1, keepdim=True)
45
+ v = r0 * v1 + r1 * v0 + torch.linalg.cross(v0, v1)
46
+ return torch.cat([r, v], dim=-1)
47
+
48
+
49
+ def quaternion_inverse(q, eps=1e-8):
50
+ """
51
+ EXPECTS WXYZ
52
+ :param q (*, 4)
53
+ """
54
+ conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
55
+ mag = torch.square(q).sum(dim=-1, keepdim=True) + eps
56
+ return conj / mag
57
+
58
+
59
+ def quaternion_slerp(t, q0, q1, eps=1e-8):
60
+ """
61
+ :param t (*, 1) must be between 0 and 1
62
+ :param q0 (*, 4)
63
+ :param q1 (*, 4)
64
+ """
65
+ dims = q0.shape[:-1]
66
+ t = t.view(*dims, 1)
67
+
68
+ q0 = F.normalize(q0, p=2, dim=-1)
69
+ q1 = F.normalize(q1, p=2, dim=-1)
70
+ dot = (q0 * q1).sum(dim=-1, keepdim=True)
71
+
72
+ # make sure we give the shortest rotation path (< 180d)
73
+ neg = dot < 0
74
+ q1 = torch.where(neg, -q1, q1)
75
+ dot = torch.where(neg, -dot, dot)
76
+ angle = torch.acos(dot)
77
+
78
+ # if angle is too small, just do linear interpolation
79
+ collin = torch.abs(dot) > 1 - eps
80
+ fac = 1 / torch.sin(angle)
81
+ w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac)
82
+ w1 = torch.where(collin, t, torch.sin(t * angle) * fac)
83
+ slerp = q0 * w0 + q1 * w1
84
+ return slerp
85
+
86
+
87
+ def rotation_matrix_to_angle_axis(rotation_matrix):
88
+ """
89
+ This function is borrowed from https://github.com/kornia/kornia
90
+
91
+ Convert rotation matrix to Rodrigues vector
92
+ """
93
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
94
+ aa = quaternion_to_angle_axis(quaternion)
95
+ aa[torch.isnan(aa)] = 0.0
96
+ return aa
97
+
98
+
99
+ def quaternion_to_angle_axis(quaternion):
100
+ """
101
+ This function is borrowed from https://github.com/kornia/kornia
102
+
103
+ Convert quaternion vector to angle axis of rotation.
104
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
105
+
106
+ :param quaternion (*, 4) expects WXYZ
107
+ :returns angle_axis (*, 3)
108
+ """
109
+ # unpack input and compute conversion
110
+ q1 = quaternion[..., 1]
111
+ q2 = quaternion[..., 2]
112
+ q3 = quaternion[..., 3]
113
+ sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3
114
+
115
+ sin_theta = torch.sqrt(sin_squared_theta)
116
+ cos_theta = quaternion[..., 0]
117
+ two_theta = 2.0 * torch.where(
118
+ cos_theta < 0.0,
119
+ torch.atan2(-sin_theta, -cos_theta),
120
+ torch.atan2(sin_theta, cos_theta),
121
+ )
122
+
123
+ k_pos = two_theta / sin_theta
124
+ k_neg = 2.0 * torch.ones_like(sin_theta)
125
+ k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
126
+
127
+ angle_axis = torch.zeros_like(quaternion)[..., :3]
128
+ angle_axis[..., 0] += q1 * k
129
+ angle_axis[..., 1] += q2 * k
130
+ angle_axis[..., 2] += q3 * k
131
+ return angle_axis
132
+
133
+
134
+ def angle_axis_to_rotation_matrix(angle_axis):
135
+ """
136
+ :param angle_axis (*, 3)
137
+ return (*, 3, 3)
138
+ """
139
+ quat = angle_axis_to_quaternion(angle_axis)
140
+ return quaternion_to_rotation_matrix(quat)
141
+
142
+
143
+ def quaternion_to_rotation_matrix(quaternion):
144
+ """
145
+ Convert a quaternion to a rotation matrix.
146
+ Taken from https://github.com/kornia/kornia, based on
147
+ https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
148
+ https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247
149
+ :param quaternion (N, 4) expects WXYZ order
150
+ returns rotation matrix (N, 3, 3)
151
+ """
152
+ # normalize the input quaternion
153
+ quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-12)
154
+ *dims, _ = quaternion_norm.shape
155
+
156
+ # unpack the normalized quaternion components
157
+ w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1)
158
+
159
+ # compute the actual conversion
160
+ tx = 2.0 * x
161
+ ty = 2.0 * y
162
+ tz = 2.0 * z
163
+ twx = tx * w
164
+ twy = ty * w
165
+ twz = tz * w
166
+ txx = tx * x
167
+ txy = ty * x
168
+ txz = tz * x
169
+ tyy = ty * y
170
+ tyz = tz * y
171
+ tzz = tz * z
172
+ one = torch.tensor(1.0)
173
+
174
+ matrix = torch.stack(
175
+ (
176
+ one - (tyy + tzz),
177
+ txy - twz,
178
+ txz + twy,
179
+ txy + twz,
180
+ one - (txx + tzz),
181
+ tyz - twx,
182
+ txz - twy,
183
+ tyz + twx,
184
+ one - (txx + tyy),
185
+ ),
186
+ dim=-1,
187
+ ).view(*dims, 3, 3)
188
+ return matrix
189
+
190
+
191
+ def angle_axis_to_quaternion(angle_axis):
192
+ """
193
+ This function is borrowed from https://github.com/kornia/kornia
194
+ Convert angle axis to quaternion in WXYZ order
195
+ :param angle_axis (*, 3)
196
+ :returns quaternion (*, 4) WXYZ order
197
+ """
198
+ theta_sq = torch.sum(angle_axis**2, dim=-1, keepdim=True) # (*, 1)
199
+ # need to handle the zero rotation case
200
+ valid = theta_sq > 0
201
+ theta = torch.sqrt(theta_sq)
202
+ half_theta = 0.5 * theta
203
+ ones = torch.ones_like(half_theta)
204
+ # fill zero with the limit of sin ax / x -> a
205
+ k = torch.where(valid, torch.sin(half_theta) / theta, 0.5 * ones)
206
+ w = torch.where(valid, torch.cos(half_theta), ones)
207
+ quat = torch.cat([w, k * angle_axis], dim=-1)
208
+ return quat
209
+
210
+
211
+ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
212
+ """
213
+ This function is borrowed from https://github.com/kornia/kornia
214
+ Convert rotation matrix to 4d quaternion vector
215
+ This algorithm is based on algorithm described in
216
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
217
+
218
+ :param rotation_matrix (N, 3, 3)
219
+ """
220
+ *dims, m, n = rotation_matrix.shape
221
+ rmat_t = torch.transpose(rotation_matrix.reshape(-1, m, n), -1, -2)
222
+
223
+ mask_d2 = rmat_t[:, 2, 2] < eps
224
+
225
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
226
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
227
+
228
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
229
+ q0 = torch.stack(
230
+ [
231
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
232
+ t0,
233
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
234
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
235
+ ],
236
+ -1,
237
+ )
238
+ t0_rep = t0.repeat(4, 1).t()
239
+
240
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
241
+ q1 = torch.stack(
242
+ [
243
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
244
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
245
+ t1,
246
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
247
+ ],
248
+ -1,
249
+ )
250
+ t1_rep = t1.repeat(4, 1).t()
251
+
252
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
253
+ q2 = torch.stack(
254
+ [
255
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
256
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
257
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
258
+ t2,
259
+ ],
260
+ -1,
261
+ )
262
+ t2_rep = t2.repeat(4, 1).t()
263
+
264
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
265
+ q3 = torch.stack(
266
+ [
267
+ t3,
268
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
269
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
270
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
271
+ ],
272
+ -1,
273
+ )
274
+ t3_rep = t3.repeat(4, 1).t()
275
+
276
+ mask_c0 = mask_d2 * mask_d0_d1
277
+ mask_c1 = mask_d2 * ~mask_d0_d1
278
+ mask_c2 = ~mask_d2 * mask_d0_nd1
279
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
280
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
281
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
282
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
283
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
284
+
285
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
286
+ q /= torch.sqrt(
287
+ t0_rep * mask_c0
288
+ + t1_rep * mask_c1
289
+ + t2_rep * mask_c2 # noqa
290
+ + t3_rep * mask_c3
291
+ ) # noqa
292
+ q *= 0.5
293
+ return q.reshape(*dims, 4)
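A small consistency sketch for the conversions above: axis-angle -> quaternion -> rotation matrix -> axis-angle should recover the input up to numerical precision (import path assumed from the file location):

import torch
from infiller.hand_utils.rotation import (angle_axis_to_quaternion,
                                           quaternion_to_rotation_matrix,
                                           rotation_matrix_to_angle_axis)

aa = torch.randn(8, 3) * 0.5                         # random axis-angle vectors (angle < pi)
R = quaternion_to_rotation_matrix(angle_axis_to_quaternion(aa))
aa_rec = rotation_matrix_to_angle_axis(R)
assert torch.allclose(aa, aa_rec, atol=1e-4)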
infiller/lib/misc/sampler.py ADDED
@@ -0,0 +1,79 @@
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import imageio
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from PIL import Image
10
+ from sklearn.preprocessing import LabelEncoder
11
+
12
+ from cmib.data.lafan1_dataset import LAFAN1Dataset
13
+ from cmib.data.utils import write_json
14
+ from cmib.lafan1.utils import quat_ik
15
+ from cmib.model.network import TransformerModel
16
+ from cmib.model.preprocess import (lerp_input_repr, replace_constant,
17
+ slerp_input_repr, vectorize_representation)
18
+ from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, joint_names,
19
+ sk_parents)
20
+ from cmib.vis.pose import plot_pose_with_stop
21
+
22
+
23
+ def test(opt, device):
24
+
25
+ save_dir = Path(os.path.join('runs', 'train', opt.exp_name))
26
+ wdir = save_dir / 'weights'
27
+ weights = os.listdir(wdir)
28
+ weights_paths = [wdir / weight for weight in weights]
29
+ latest_weight = max(weights_paths , key = os.path.getctime)
30
+ ckpt = torch.load(latest_weight, map_location=device)
31
+ print(f"Loaded weight: {latest_weight}")
32
+
33
+ # Load Skeleton
34
+ skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device)
35
+ skeleton_mocap.remove_joints(sk_joints_to_remove)
36
+
37
+ # Load LAFAN Dataset
38
+ Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True)
39
+ lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device)
40
+ total_data = lafan_dataset.data['global_pos'].shape[0]
41
+
42
+ # Replace with noise to In-betweening Frames
43
+ from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx'] # default: 9-40, max: 48
44
+ horizon = ckpt['horizon']
45
+ print(f"HORIZON: {horizon}")
46
+
47
+ test_idx = []
48
+ for i in range(total_data):
49
+ test_idx.append(i)
50
+
51
+ # Compare Input data, Prediction, GT
52
+ save_path = os.path.join(opt.save_path, 'sampler')
53
+ for i in range(len(test_idx)):
54
+ Path(save_path).mkdir(parents=True, exist_ok=True)
55
+
56
+ start_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
57
+ target_pose = lafan_dataset.data['global_pos'][test_idx[i], target_idx]
58
+ gt_stopover_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
59
+
60
+ gt_img_path = os.path.join(save_path)
61
+ plot_pose_with_stop(start_pose, target_pose, target_pose, gt_stopover_pose, i, skeleton_mocap, save_dir=gt_img_path, prefix='gt')
62
+ print(f"ID {test_idx[i]}: completed.")
63
+
64
+ def parse_opt():
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument('--project', default='runs/train', help='project/name')
67
+ parser.add_argument('--exp_name', default='slerp_40', help='experiment name')
68
+ parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path')
69
+ parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton')
70
+ parser.add_argument('--processed_data_dir', type=str, default='processed_data_original/', help='path to save pickled processed data')
71
+ parser.add_argument('--save_path', type=str, default='runs/test', help='path to save model')
72
+ parser.add_argument('--motion_type', type=str, default='jumps', help='motion type')
73
+ opt = parser.parse_args()
74
+ return opt
75
+
76
+ if __name__ == "__main__":
77
+ opt = parse_opt()
78
+ device = torch.device("cpu")
79
+ test(opt, device)
infiller/lib/model/__pycache__/network.cpython-310.pyc ADDED
Binary file (7.82 kB).
 
infiller/lib/model/network.py ADDED
@@ -0,0 +1,276 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn, Tensor
5
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
6
+ # from cmib.model.positional_encoding import PositionalEmbedding
7
+
8
+ class SinPositionalEncoding(nn.Module):
9
+ def __init__(self, d_model, dropout=0.1, max_len=100):
10
+ super(SinPositionalEncoding, self).__init__()
11
+ self.dropout = nn.Dropout(p=dropout)
12
+
13
+ pe = torch.zeros(max_len, d_model)
14
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
15
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
16
+ pe[:, 0::2] = torch.sin(position * div_term)
17
+ pe[:, 1::2] = torch.cos(position * div_term)
18
+ pe = pe.unsqueeze(0).transpose(0, 1)
19
+
20
+ self.register_buffer('pe', pe)
21
+
22
+ def forward(self, x):
23
+ # not used in the final model
24
+ x = x + self.pe[:x.shape[0], :]
25
+ return self.dropout(x)
26
+
27
+
28
+ class MultiHeadedAttention(nn.Module):
29
+ def __init__(self, n_head, d_model, d_head, dropout=0.1,
30
+ pre_lnorm=True, bias=False):
31
+ """
32
+ Multi-headed attention. The relative positional encoding and memory
33
+ mechanism are commented out and not used in this version.
34
+
35
+ Args:
36
+ n_head (int): Number of heads.
37
+ d_model (int): Input dimension.
38
+ d_head (int): Head dimension.
39
+ dropout (float, optional): Dropout value. Defaults to 0.1.
40
+ pre_lnorm (bool, optional):
41
+ Apply layer norm before rest of calculation. Defaults to True.
42
+ In original Transformer paper (pre_lnorm=False):
43
+ LayerNorm(x + Sublayer(x))
44
+ In tensor2tensor implementation (pre_lnorm=True):
45
+ x + Sublayer(LayerNorm(x))
46
+ bias (bool, optional):
47
+ Add bias to q, k, v and output projections. Defaults to False.
48
+
49
+ """
50
+ super(MultiHeadedAttention, self).__init__()
51
+
52
+ self.n_head = n_head
53
+ self.d_model = d_model
54
+ self.d_head = d_head
55
+ self.dropout = dropout
56
+ self.pre_lnorm = pre_lnorm
57
+ self.bias = bias
58
+ self.atten_scale = 1 / math.sqrt(self.d_model)
59
+
60
+ self.q_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
61
+ self.k_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
62
+ self.v_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
63
+ self.out_linear = nn.Linear(n_head * d_head, d_model, bias=bias)
64
+
65
+ self.droput_layer = nn.Dropout(dropout)
66
+ self.atten_dropout_layer = nn.Dropout(dropout)
67
+
68
+ self.layer_norm = nn.LayerNorm(d_model)
69
+
70
+ def forward(self, hidden, memory=None, mask=None,
71
+ extra_atten_score=None):
72
+ """
73
+ Args:
74
+ hidden (Tensor): Input embedding or hidden state of previous layer.
75
+ Shape: (batch, seq, dim)
76
+ pos_emb (Tensor): Relative positional embedding lookup table.
77
+ Shape: (batch, (seq+mem_len)*2-1, d_head)
78
+ pos_emb[:, seq+mem_len]
79
+
80
+ memory (Tensor): Memory tensor of previous layer.
81
+ Shape: (batch, mem_len, dim)
82
+ mask (BoolTensor, optional): Attention mask.
83
+ Set item value to True if you DO NOT want to keep a certain
84
+ attention score, otherwise False. Defaults to None.
85
+ Shape: (seq, seq+mem_len).
86
+ """
87
+ combined = hidden
88
+ # if memory is None:
89
+ # combined = hidden
90
+ # mem_len = 0
91
+ # else:
92
+ # combined = torch.cat([memory, hidden], dim=1)
93
+ # mem_len = memory.shape[1]
94
+
95
+ if self.pre_lnorm:
96
+ hidden = self.layer_norm(hidden)
97
+ combined = self.layer_norm(combined)
98
+
99
+ # shape: (batch, q/k/v_len, dim)
100
+ q = self.q_linear(hidden)
101
+ k = self.k_linear(combined)
102
+ v = self.v_linear(combined)
103
+
104
+ # reshape to (batch, q/k/v_len, n_head, d_head)
105
+ q = q.reshape(q.shape[0], q.shape[1], self.n_head, self.d_head)
106
+ k = k.reshape(k.shape[0], k.shape[1], self.n_head, self.d_head)
107
+ v = v.reshape(v.shape[0], v.shape[1], self.n_head, self.d_head)
108
+
109
+ # transpose to (batch, n_head, q/k/v_len, d_head)
110
+ q = q.transpose(1, 2)
111
+ k = k.transpose(1, 2)
112
+ v = v.transpose(1, 2)
113
+
114
+ # add n_head dimension for relative positional embedding lookup table
115
+ # (batch, n_head, k/v_len*2-1, d_head)
116
+ # pos_emb = pos_emb[:, None]
117
+
118
+ # (batch, n_head, q_len, k_len)
119
+ atten_score = torch.matmul(q, k.transpose(-1, -2))
120
+
121
+ # qpos = torch.matmul(q, pos_emb.transpose(-1, -2))
122
+ # DEBUG
123
+ # ones = torch.zeros(q.shape)
124
+ # ones[:, :, :, 0] = 1.0
125
+ # qpos = torch.matmul(ones, pos_emb.transpose(-1, -2))
126
+ # atten_score = atten_score + self.skew(qpos, mem_len)
127
+ atten_score = atten_score * self.atten_scale
128
+
129
+ # if extra_atten_score is not None:
130
+ # atten_score = atten_score + extra_atten_score
131
+
132
+ if mask is not None:
133
+ # print(atten_score.shape)
134
+ # print(mask.shape)
135
+ # apply attention mask
136
+ atten_score = atten_score.masked_fill(mask, float("-inf"))
137
+ atten_score = atten_score.softmax(dim=-1)
138
+ atten_score = self.atten_dropout_layer(atten_score)
139
+
140
+ # (batch, n_head, q_len, d_head)
141
+ atten_vec = torch.matmul(atten_score, v)
142
+ # (batch, q_len, n_head*d_head)
143
+ atten_vec = atten_vec.transpose(1, 2).flatten(start_dim=-2)
144
+
145
+ # linear projection
146
+ output = self.droput_layer(self.out_linear(atten_vec))
147
+
148
+ if self.pre_lnorm:
149
+ return hidden + output
150
+ else:
151
+ return self.layer_norm(hidden + output)
152
+
153
+
154
+ class FeedForward(nn.Module):
155
+ def __init__(self, d_model, d_inner, dropout=0.1, pre_lnorm=True):
156
+ """
157
+ Positionwise feed-forward network.
158
+
159
+ Args:
160
+ d_model(int): Dimension of the input and output.
161
+ d_inner (int): Dimension of the middle layer(bottleneck).
162
+ dropout (float, optional): Dropout value. Defaults to 0.1.
163
+ pre_lnorm (bool, optional):
164
+ Apply layer norm before rest of calculation. Defaults to True.
165
+ In original Transformer paper (pre_lnorm=False):
166
+ LayerNorm(x + Sublayer(x))
167
+ In tensor2tensor implementation (pre_lnorm=True):
168
+ x + Sublayer(LayerNorm(x))
169
+ """
170
+ super(FeedForward, self).__init__()
171
+ self.d_model = d_model
172
+ self.d_inner = d_inner
173
+ self.dropout = dropout
174
+ self.pre_lnorm = pre_lnorm
175
+
176
+ self.layer_norm = nn.LayerNorm(d_model)
177
+ self.network = nn.Sequential(
178
+ nn.Linear(d_model, d_inner),
179
+ nn.ReLU(),
180
+ nn.Dropout(dropout),
181
+ nn.Linear(d_inner, d_model),
182
+ nn.Dropout(dropout),
183
+ )
184
+
185
+ def forward(self, x):
186
+ if self.pre_lnorm:
187
+ return x + self.network(self.layer_norm(x))
188
+ else:
189
+ return self.layer_norm(x + self.network(x))
190
+ class TransformerModel(nn.Module):
191
+ def __init__(
192
+ self,
193
+ seq_len: int,
194
+ input_dim: int,
195
+ d_model: int,
196
+ nhead: int,
197
+ d_hid: int,
198
+ nlayers: int,
199
+ dropout: float = 0.5,
200
+ out_dim=91,
201
+ masked_attention_stage=False,
202
+ ):
203
+ super().__init__()
204
+ self.model_type = "Transformer"
205
+ self.seq_len = seq_len
206
+ self.d_model = d_model
207
+ self.nhead = nhead
208
+ self.d_hid = d_hid
209
+ self.nlayers = nlayers
210
+ self.pos_embedding = SinPositionalEncoding(d_model=d_model, dropout=0.1, max_len=seq_len)
211
+ if masked_attention_stage:
212
+ self.input_layer = nn.Linear(input_dim+1, d_model)
213
+ # visible to invisible attention
214
+ self.att_layers = nn.ModuleList()
215
+ self.pff_layers = nn.ModuleList()
216
+ self.pre_lnorm = True
217
+ self.layer_norm = nn.LayerNorm(d_model)
218
+ for i in range(self.nlayers):
219
+ self.att_layers.append(
220
+ MultiHeadedAttention(
221
+ self.nhead, self.d_model,
222
+ self.d_model // self.nhead, dropout=dropout,
223
+ pre_lnorm=True,
224
+ bias=False
225
+ )
226
+ )
227
+
228
+ self.pff_layers.append(
229
+ FeedForward(
230
+ self.d_model, d_hid,
231
+ dropout=dropout,
232
+ pre_lnorm=True
233
+ )
234
+ )
235
+ else:
236
+ self.att_layers = None
237
+ self.input_layer = nn.Linear(input_dim, d_model)
238
+ encoder_layers = TransformerEncoderLayer(
239
+ d_model, nhead, d_hid, dropout, activation="gelu"
240
+ )
241
+ self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
242
+ self.decoder = nn.Linear(d_model, out_dim)
243
+
244
+ self.init_weights()
245
+
246
+ def init_weights(self) -> None:
247
+ initrange = 0.1
248
+ self.decoder.bias.data.zero_()
249
+ self.decoder.weight.data.uniform_(-initrange, initrange)
250
+
251
+ def forward(self, src: Tensor, src_mask: Tensor, data_mask=None, atten_mask=None) -> Tensor:
252
+ """
253
+ Args:
254
+ src: Tensor, shape [seq_len, batch_size, embedding_dim]
255
+ src_mask: Tensor, shape [seq_len, seq_len]
256
+
257
+ Returns:
258
+ output Tensor of shape [seq_len, batch_size, embedding_dim]
259
+ """
260
+ if not data_mask is None:
261
+ src = torch.cat([src, data_mask.expand(*src.shape[:-1], data_mask.shape[-1])], dim=-1)
262
+ src = self.input_layer(src)
263
+ output = self.pos_embedding(src)
264
+ # output = src
265
+ if self.att_layers:
266
+ assert not atten_mask is None
267
+ output = output.permute(1, 0, 2)
268
+ for i in range(self.nlayers):
269
+ output = self.att_layers[i](output, mask=atten_mask)
270
+ output = self.pff_layers[i](output)
271
+ if self.pre_lnorm:
272
+ output = self.layer_norm(output)
273
+ output = output.permute(1, 0, 2)
274
+ output = self.transformer_encoder(output)
275
+ output = self.decoder(output)
276
+ return output
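A shape-level sketch of driving the infiller TransformerModel above; the hyper-parameters and feature dimension here are illustrative assumptions, not the values of any released checkpoint:

import torch
from infiller.lib.model.network import TransformerModel

seq_len, batch, feat = 120, 2, 288                      # illustrative sizes
model = TransformerModel(seq_len=seq_len, input_dim=feat, d_model=256, nhead=8,
                         d_hid=512, nlayers=4, dropout=0.1, out_dim=feat,
                         masked_attention_stage=True)
src = torch.randn(seq_len, batch, feat)                 # [seq_len, batch, feature]
data_mask = torch.ones(seq_len, batch, 1)               # 1 = frame is observed, 0 = missing
atten_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)  # False = keep attention
out = model(src, src_mask=None, data_mask=data_mask, atten_mask=atten_mask)
print(out.shape)                                         # torch.Size([120, 2, 288])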
infiller/lib/model/positional_encoding.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import math
4
+
5
+
6
+ class PositionalEmbedding(nn.Module):
7
+ def __init__(self, seq_len: int = 32, d_model: int = 96):
8
+ super().__init__()
9
+ self.pos_emb = nn.Embedding(seq_len + 1, d_model)
10
+
11
+ def forward(self, inputs):
12
+ positions = (
13
+ torch.arange(inputs.size(0), device=inputs.device)
14
+ .expand(inputs.size(1), inputs.size(0))
15
+ .contiguous()
16
+ + 1
17
+ )
18
+ outputs = inputs + self.pos_emb(positions).permute(1, 0, 2)
19
+ return outputs
20
+
21
+
22
+ class PositionalEncoding(nn.Module):
23
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
24
+ super().__init__()
25
+ self.dropout = nn.Dropout(p=dropout)
26
+
27
+ position = torch.arange(max_len).unsqueeze(1)
28
+ div_term = torch.exp(
29
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
30
+ )
31
+ pe = torch.zeros(max_len, 1, d_model)
32
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
33
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
34
+ self.register_buffer("pe", pe)
35
+
36
+ def forward(self, x: Tensor) -> Tensor:
37
+ """
38
+ Args:
39
+ x: Tensor, shape [seq_len, batch_size, embedding_dim]
40
+ """
41
+ x = x + self.pe[: x.size(0)]
42
+ return self.dropout(x)
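A quick sketch of the standard sinusoidal PositionalEncoding above (import path assumed from the file location):

import torch
from infiller.lib.model.positional_encoding import PositionalEncoding

pe = PositionalEncoding(d_model=64, dropout=0.0, max_len=500)
x = torch.zeros(30, 4, 64)                 # [seq_len, batch, d_model]
print(pe(x).shape)                         # torch.Size([30, 4, 64])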
infiller/lib/model/preprocess.py ADDED
@@ -0,0 +1,189 @@
1
+ import torch
2
+
3
+
4
+ def replace_constant(minibatch_pose_input, mask_start_frame):
5
+
6
+ seq_len = minibatch_pose_input.size(1)
7
+ interpolated = (
8
+ torch.ones_like(minibatch_pose_input, device=minibatch_pose_input.device) * 0.1
9
+ )
10
+
11
+ if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
12
+ interpolate_start = minibatch_pose_input[:, 0, :]
13
+ interpolate_end = minibatch_pose_input[:, seq_len - 1, :]
14
+
15
+ interpolated[:, 0, :] = interpolate_start
16
+ interpolated[:, seq_len - 1, :] = interpolate_end
17
+
18
+ assert torch.allclose(interpolated[:, 0, :], interpolate_start)
19
+ assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end)
20
+
21
+ else:
22
+ interpolate_start1 = minibatch_pose_input[:, 0, :]
23
+ interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :]
24
+
25
+ interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :]
26
+ interpolate_end2 = minibatch_pose_input[:, seq_len - 1, :]
27
+
28
+ interpolated[:, 0, :] = interpolate_start1
29
+ interpolated[:, mask_start_frame, :] = interpolate_end1
30
+
31
+ interpolated[:, mask_start_frame, :] = interpolate_start2
32
+ interpolated[:, seq_len - 1, :] = interpolate_end2
33
+
34
+ assert torch.allclose(interpolated[:, 0, :], interpolate_start1)
35
+ assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1)
36
+
37
+ assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2)
38
+ assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end2)
39
+ return interpolated
40
+
41
+
42
+ def slerp(x, y, a):
43
+ """
44
+ Performs spherical linear interpolation (SLERP) between x and y, with proportion a
45
+
46
+ :param x: quaternion tensor
47
+ :param y: quaternion tensor
48
+ :param a: indicator (between 0 and 1) of completion of the interpolation.
49
+ :return: tensor of interpolation results
50
+ """
51
+ device = x.device
52
+ len = torch.sum(x * y, dim=-1)
53
+
54
+ neg = len < 0.0
55
+ len[neg] = -len[neg]
56
+ y[neg] = -y[neg]
57
+
58
+ a = torch.zeros_like(x[..., 0]) + a
59
+ amount0 = torch.zeros(a.shape, device=device)
60
+ amount1 = torch.zeros(a.shape, device=device)
61
+
62
+ linear = (1.0 - len) < 0.01
63
+ omegas = torch.arccos(len[~linear])
64
+ sinoms = torch.sin(omegas)
65
+
66
+ amount0[linear] = 1.0 - a[linear]
67
+ amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms
68
+
69
+ amount1[linear] = a[linear]
70
+ amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms
71
+ # res = amount0[..., np.newaxis] * x + amount1[..., np.newaxis] * y
72
+ res = amount0.unsqueeze(3) * x + amount1.unsqueeze(3) * y
73
+
74
+ return res
75
+
76
+
77
+ def slerp_input_repr(minibatch_pose_input, mask_start_frame):
78
+ seq_len = minibatch_pose_input.size(1)
79
+ minibatch_pose_input = minibatch_pose_input.reshape(
80
+ minibatch_pose_input.size(0), seq_len, -1, 4
81
+ )
82
+ interpolated = torch.zeros_like(
83
+ minibatch_pose_input, device=minibatch_pose_input.device
84
+ )
85
+
86
+ if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
87
+ interpolate_start = minibatch_pose_input[:, 0:1]
88
+ interpolate_end = minibatch_pose_input[:, seq_len - 1 :]
89
+
90
+ for i in range(seq_len):
91
+ dt = 1 / (seq_len - 1)
92
+ interpolated[:, i : i + 1, :] = slerp(
93
+ interpolate_start, interpolate_end, dt * i
94
+ )
95
+
96
+ assert torch.allclose(interpolated[:, 0:1], interpolate_start)
97
+ assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end)
98
+ else:
99
+ interpolate_start1 = minibatch_pose_input[:, 0:1]
100
+ interpolate_end1 = minibatch_pose_input[
101
+ :, mask_start_frame : mask_start_frame + 1
102
+ ]
103
+
104
+ interpolate_start2 = minibatch_pose_input[
105
+ :, mask_start_frame : mask_start_frame + 1
106
+ ]
107
+ interpolate_end2 = minibatch_pose_input[:, seq_len - 1 :]
108
+
109
+ for i in range(mask_start_frame + 1):
110
+ dt = 1 / mask_start_frame
111
+ interpolated[:, i : i + 1, :] = slerp(
112
+ interpolate_start1, interpolate_end1, dt * i
113
+ )
114
+
115
+ assert torch.allclose(interpolated[:, 0:1], interpolate_start1)
116
+ assert torch.allclose(
117
+ interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_end1
118
+ )
119
+
120
+ for i in range(mask_start_frame, seq_len):
121
+ dt = 1 / (seq_len - mask_start_frame - 1)
122
+ interpolated[:, i : i + 1, :] = slerp(
123
+ interpolate_start2, interpolate_end2, dt * (i - mask_start_frame)
124
+ )
125
+
126
+ assert torch.allclose(
127
+ interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_start2
128
+ )
129
+ assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end2)
130
+
131
+ interpolated = torch.nn.functional.normalize(interpolated, p=2.0, dim=3)
132
+ return interpolated.reshape(minibatch_pose_input.size(0), seq_len, -1)
133
+
134
+
135
+ def lerp_input_repr(minibatch_pose_input, mask_start_frame):
136
+ seq_len = minibatch_pose_input.size(1)
137
+ interpolated = torch.zeros_like(
138
+ minibatch_pose_input, device=minibatch_pose_input.device
139
+ )
140
+
141
+ if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
142
+ interpolate_start = minibatch_pose_input[:, 0, :]
143
+ interpolate_end = minibatch_pose_input[:, seq_len - 1, :]
144
+
145
+ for i in range(seq_len):
146
+ dt = 1 / (seq_len - 1)
147
+ interpolated[:, i, :] = torch.lerp(
148
+ interpolate_start, interpolate_end, dt * i
149
+ )
150
+
151
+ assert torch.allclose(interpolated[:, 0, :], interpolate_start)
152
+ assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end)
153
+ else:
154
+ interpolate_start1 = minibatch_pose_input[:, 0, :]
155
+ interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :]
156
+
157
+ interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :]
158
+ interpolate_end2 = minibatch_pose_input[:, -1, :]
159
+
160
+ for i in range(mask_start_frame + 1):
161
+ dt = 1 / mask_start_frame
162
+ interpolated[:, i, :] = torch.lerp(
163
+ interpolate_start1, interpolate_end1, dt * i
164
+ )
165
+
166
+ assert torch.allclose(interpolated[:, 0, :], interpolate_start1)
167
+ assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1)
168
+
169
+ for i in range(mask_start_frame, seq_len):
170
+ dt = 1 / (seq_len - mask_start_frame - 1)
171
+ interpolated[:, i, :] = torch.lerp(
172
+ interpolate_start2, interpolate_end2, dt * (i - mask_start_frame)
173
+ )
174
+
175
+ assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2)
176
+ assert torch.allclose(interpolated[:, -1, :], interpolate_end2)
177
+ return interpolated
178
+
179
+
180
+ def vectorize_representation(global_position, global_rotation):
181
+
182
+ batch_size = global_position.shape[0]
183
+ seq_len = global_position.shape[1]
184
+
185
+ global_pos_vec = global_position.reshape(batch_size, seq_len, -1).contiguous()
186
+ global_rot_vec = global_rotation.reshape(batch_size, seq_len, -1).contiguous()
187
+
188
+ global_pose_vec_gt = torch.cat([global_pos_vec, global_rot_vec], dim=2)
189
+ return global_pose_vec_gt
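A small sketch of the interpolation baselines above; lerp_input_repr keeps frame 0, the keyframe at mask_start_frame, and the last frame fixed and linearly interpolates in between (import path assumed from the file location):

import torch
from infiller.lib.model.preprocess import lerp_input_repr

pose = torch.randn(2, 10, 6)               # (batch, seq_len, feature)
interp = lerp_input_repr(pose, mask_start_frame=4)
assert torch.allclose(interp[:, 0], pose[:, 0])    # endpoints and keyframe are kept exactly
assert torch.allclose(interp[:, 4], pose[:, 4])
assert torch.allclose(interp[:, -1], pose[:, -1])
print(interp.shape)                         # torch.Size([2, 10, 6])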
infiller/lib/model/skeleton.py ADDED
@@ -0,0 +1,349 @@
1
+ import torch
2
+ import numpy as np
3
+ from cmib.data.quaternion import qmul, qrot
4
+ import torch.nn as nn
5
+
6
+ amass_offsets = [
7
+ [0.0, 0.0, 0.0],
8
+
9
+ [0.058581, -0.082280, -0.017664],
10
+ [0.043451, -0.386469, 0.008037],
11
+ [-0.014790, -0.426874, -0.037428],
12
+ [0.041054, -0.060286, 0.122042],
13
+ [0.0, 0.0, 0.0],
14
+
15
+ [-0.060310, -0.090513, -0.013543],
16
+ [-0.043257, -0.383688, -0.004843],
17
+ [0.019056, -0.420046, -0.034562],
18
+ [-0.034840, -0.062106, 0.130323],
19
+ [0.0, 0.0, 0.0],
20
+
21
+ [0.004439, 0.124404, -0.038385],
22
+ [0.004488, 0.137956, 0.026820],
23
+ [-0.002265, 0.056032, 0.002855],
24
+ [-0.013390, 0.211636, -0.033468],
25
+ [0.010113, 0.088937, 0.050410],
26
+ [0.0, 0.0, 0.0],
27
+
28
+ [0.071702, 0.114000, -0.018898],
29
+ [0.122921, 0.045205, -0.019046],
30
+ [0.255332, -0.015649, -0.022946],
31
+ [0.265709, 0.012698, -0.007375],
32
+ [0.0, 0.0, 0.0],
33
+
34
+ [-0.082954, 0.112472, -0.023707],
35
+ [-0.113228, 0.046853, -0.008472],
36
+ [-0.260127, -0.014369, -0.031269],
37
+ [-0.269108, 0.006794, -0.006027],
38
+ [0.0, 0.0, 0.0]
39
+ ]
40
+
41
+ sk_offsets = [
42
+ [-42.198200, 91.614723, -40.067841],
43
+
44
+ [0.103456, 1.857829, 10.548506],
45
+ [43.499992, -0.000038, -0.000002],
46
+ [42.372192, 0.000015, -0.000007],
47
+ [17.299999, -0.000002, 0.000003],
48
+ [0.000000, 0.000000, 0.000000],
49
+
50
+ [0.103457, 1.857829, -10.548503],
51
+ [43.500042, -0.000027, 0.000008],
52
+ [42.372257, -0.000008, 0.000014],
53
+ [17.299992, -0.000005, 0.000004],
54
+ [0.000000, 0.000000, 0.000000],
55
+
56
+ [6.901968, -2.603733, -0.000001],
57
+ [12.588099, 0.000002, 0.000000],
58
+ [12.343206, 0.000000, -0.000001],
59
+ [25.832886, -0.000004, 0.000003],
60
+ [11.766620, 0.000005, -0.000001],
61
+ [0.000000, 0.000000, 0.000000],
62
+
63
+ [19.745899, -1.480370, 6.000108],
64
+ [11.284125, -0.000009, -0.000018],
65
+ [33.000050, 0.000004, 0.000032],
66
+ [25.200008, 0.000015, 0.000008],
67
+ [0.000000, 0.000000, 0.000000],
68
+
69
+ [19.746099, -1.480375, -6.000073],
70
+ [11.284138, -0.000015, -0.000012],
71
+ [33.000092, 0.000017, 0.000013],
72
+ [25.199780, 0.000135, 0.000422],
73
+ [0.000000, 0.000000, 0.000000],
74
+ ]
75
+
76
+ sk_parents = [
77
+ -1,
78
+ 0,
79
+ 1,
80
+ 2,
81
+ 3,
82
+ 4,
83
+ 0,
84
+ 6,
85
+ 7,
86
+ 8,
87
+ 9,
88
+ 0,
89
+ 11,
90
+ 12,
91
+ 13,
92
+ 14,
93
+ 15,
94
+ 13,
95
+ 17,
96
+ 18,
97
+ 19,
98
+ 20,
99
+ 13,
100
+ 22,
101
+ 23,
102
+ 24,
103
+ 25,
104
+ ]
105
+
106
+ sk_joints_to_remove = [5, 10, 16, 21, 26]
107
+
108
+ joint_names = [
109
+ "Hips",
110
+ "LeftUpLeg",
111
+ "LeftLeg",
112
+ "LeftFoot",
113
+ "LeftToe",
114
+ "RightUpLeg",
115
+ "RightLeg",
116
+ "RightFoot",
117
+ "RightToe",
118
+ "Spine",
119
+ "Spine1",
120
+ "Spine2",
121
+ "Neck",
122
+ "Head",
123
+ "LeftShoulder",
124
+ "LeftArm",
125
+ "LeftForeArm",
126
+ "LeftHand",
127
+ "RightShoulder",
128
+ "RightArm",
129
+ "RightForeArm",
130
+ "RightHand",
131
+ ]
132
+
133
+
134
+ class Skeleton:
135
+ def __init__(
136
+ self,
137
+ offsets,
138
+ parents,
139
+ joints_left=None,
140
+ joints_right=None,
141
+ bone_length=None,
142
+ device=None,
143
+ ):
144
+ assert len(offsets) == len(parents)
145
+
146
+ self._offsets = torch.Tensor(offsets).to(device)
147
+ self._parents = np.array(parents)
148
+ self._joints_left = joints_left
149
+ self._joints_right = joints_right
150
+ self._compute_metadata()
151
+
152
+ def num_joints(self):
153
+ return self._offsets.shape[0]
154
+
155
+ def offsets(self):
156
+ return self._offsets
157
+
158
+ def parents(self):
159
+ return self._parents
160
+
161
+ def has_children(self):
162
+ return self._has_children
163
+
164
+ def children(self):
165
+ return self._children
166
+
167
+ def convert_to_global_pos(self, unit_vec_rerp):
168
+ """
169
+ Convert the unit offset matrix to global position.
170
+ The first row (root) keeps its absolute position in global coordinates.
171
+ """
172
+ bone_length = self.get_bone_length_weight()
173
+ batch_size = unit_vec_rerp.size(0)
174
+ seq_len = unit_vec_rerp.size(1)
175
+ unit_vec_table = unit_vec_rerp.reshape(batch_size, seq_len, 22, 3)
176
+ global_position = torch.zeros_like(unit_vec_table, device=unit_vec_table.device)
177
+
178
+ for i, parent in enumerate(self._parents):
179
+ if parent == -1: # if root
180
+ global_position[:, :, i] = unit_vec_table[:, :, i]
181
+
182
+ else:
183
+ global_position[:, :, i] = global_position[:, :, parent] + (
184
+ nn.functional.normalize(unit_vec_table[:, :, i], p=2.0, dim=-1)
185
+ * bone_length[i]
186
+ )
187
+
188
+ return global_position
189
+
190
+ def convert_to_unit_offset_mat(self, global_position):
191
+ """
192
+ Convert the global position of the skeleton to a unit offset matrix.
193
+ The first row (root) keeps its absolute position in global coordinates.
194
+ """
195
+
196
+ bone_length = self.get_bone_length_weight()
197
+ unit_offset_mat = torch.zeros_like(
198
+ global_position, device=global_position.device
199
+ )
200
+
201
+ for i, parent in enumerate(self._parents):
202
+
203
+ if parent == -1: # if root
204
+ unit_offset_mat[:, :, i] = global_position[:, :, i]
205
+ else:
206
+ unit_offset_mat[:, :, i] = (
207
+ global_position[:, :, i] - global_position[:, :, parent]
208
+ ) / bone_length[i]
209
+
210
+ return unit_offset_mat
211
+
212
+ def remove_joints(self, joints_to_remove):
213
+ """
214
+ Remove the joints specified in 'joints_to_remove', both from the
215
+ skeleton definition and from the dataset (which is modified in place).
216
+ The rotations of removed joints are propagated along the kinematic chain.
217
+ """
218
+ valid_joints = []
219
+ for joint in range(len(self._parents)):
220
+ if joint not in joints_to_remove:
221
+ valid_joints.append(joint)
222
+
223
+ index_offsets = np.zeros(len(self._parents), dtype=int)
224
+ new_parents = []
225
+ for i, parent in enumerate(self._parents):
226
+ if i not in joints_to_remove:
227
+ new_parents.append(parent - index_offsets[parent])
228
+ else:
229
+ index_offsets[i:] += 1
230
+ self._parents = np.array(new_parents)
231
+
232
+ self._offsets = self._offsets[valid_joints]
233
+ self._compute_metadata()
234
+
235
+ def forward_kinematics(self, rotations, root_positions):
236
+ """
237
+ Perform forward kinematics using the given trajectory and local rotations.
238
+ Arguments (where N = batch size, L = sequence length, J = number of joints):
239
+ -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint.
240
+ -- root_positions: (N, L, 3) tensor describing the root joint positions.
241
+ """
242
+ assert len(rotations.shape) == 4
243
+ assert rotations.shape[-1] == 4
244
+
245
+ positions_world = []
246
+ rotations_world = []
247
+
248
+ expanded_offsets = self._offsets.expand(
249
+ rotations.shape[0],
250
+ rotations.shape[1],
251
+ self._offsets.shape[0],
252
+ self._offsets.shape[1],
253
+ )
254
+
255
+ # Parallelize along the batch and time dimensions
256
+ for i in range(self._offsets.shape[0]):
257
+ if self._parents[i] == -1:
258
+ positions_world.append(root_positions)
259
+ rotations_world.append(rotations[:, :, 0])
260
+ else:
261
+ positions_world.append(
262
+ qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i])
263
+ + positions_world[self._parents[i]]
264
+ )
265
+ if self._has_children[i]:
266
+ rotations_world.append(
267
+ qmul(rotations_world[self._parents[i]], rotations[:, :, i])
268
+ )
269
+ else:
270
+ # This joint is a terminal node -> it would be useless to compute the transformation
271
+ rotations_world.append(None)
272
+
273
+ return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2)
274
+
275
+ def forward_kinematics_with_rotation(self, rotations, root_positions):
276
+ """
277
+ Perform forward kinematics using the given trajectory and local rotations.
278
+ Arguments (where N = batch size, L = sequence length, J = number of joints):
279
+ -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint.
280
+ -- root_positions: (N, L, 3) tensor describing the root joint positions.
281
+ """
282
+ assert len(rotations.shape) == 4
283
+ assert rotations.shape[-1] == 4
284
+
285
+ positions_world = []
286
+ rotations_world = []
287
+
288
+ expanded_offsets = self._offsets.expand(
289
+ rotations.shape[0],
290
+ rotations.shape[1],
291
+ self._offsets.shape[0],
292
+ self._offsets.shape[1],
293
+ )
294
+
295
+ # Parallelize along the batch and time dimensions
296
+ for i in range(self._offsets.shape[0]):
297
+ if self._parents[i] == -1:
298
+ positions_world.append(root_positions)
299
+ rotations_world.append(rotations[:, :, 0])
300
+ else:
301
+ positions_world.append(
302
+ qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i])
303
+ + positions_world[self._parents[i]]
304
+ )
305
+ if self._has_children[i]:
306
+ rotations_world.append(
307
+ qmul(rotations_world[self._parents[i]], rotations[:, :, i])
308
+ )
309
+ else:
310
+ # This joint is a terminal node -> it would be useless to compute the transformation
311
+ rotations_world.append(
312
+ torch.Tensor([1, 0, 0, 0])
313
+ .expand(rotations.shape[0], rotations.shape[1], 4)
314
+ .to(rotations.device)
315
+ )
316
+
317
+ return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2), torch.stack(
318
+ rotations_world, dim=3
319
+ ).permute(0, 1, 3, 2)
320
+
321
+ def get_bone_length_weight(self):
322
+ bone_length = []
323
+ for i, parent in enumerate(self._parents):
324
+ if parent == -1:
325
+ bone_length.append(1)
326
+ else:
327
+ bone_length.append(
328
+ torch.linalg.norm(self._offsets[i : i + 1], ord="fro").item()
329
+ )
330
+ return torch.Tensor(bone_length)
331
+
332
+ def joints_left(self):
333
+ return self._joints_left
334
+
335
+ def joints_right(self):
336
+ return self._joints_right
337
+
338
+ def _compute_metadata(self):
339
+ self._has_children = np.zeros(len(self._parents)).astype(bool)
340
+ for i, parent in enumerate(self._parents):
341
+ if parent != -1:
342
+ self._has_children[parent] = True
343
+
344
+ self._children = []
345
+ for i, parent in enumerate(self._parents):
346
+ self._children.append([])
347
+ for i, parent in enumerate(self._parents):
348
+ if parent != -1:
349
+ self._children[parent].append(i)
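
A hedged sketch of driving the `Skeleton` class above with the `sk_offsets`/`sk_parents` tables defined in this file; it is not part of the commit and the import path is an assumption.

```python
import torch
# from infiller.lib.model.skeleton import Skeleton, sk_offsets, sk_parents, sk_joints_to_remove  # assumed path

skel = Skeleton(offsets=sk_offsets, parents=sk_parents, device="cpu")
skel.remove_joints(sk_joints_to_remove)          # 27 joints -> 22 joints

N, L, J = 2, 30, skel.num_joints()
rotations = torch.zeros(N, L, J, 4)              # local joint rotations as quaternions
rotations[..., 0] = 1.0                          # identity rotations, (w, x, y, z) convention
root_positions = torch.zeros(N, L, 3)            # root trajectory

joints = skel.forward_kinematics(rotations, root_positions)   # (N, L, J, 3) world positions
```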
infiller/lib/vis/pose.py ADDED
@@ -0,0 +1,248 @@
1
+ import os
2
+ import pathlib
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+
7
+
8
+ def project_root_position(position_arr: np.array, file_name: str):
9
+ """
10
+ Take batch of root arrays and porject it on 2D plane
11
+
12
+ N: samples
13
+ L: trajectory length
14
+ J: joints
15
+
16
+ position_arr: [N,L,J,3]
17
+ """
18
+
19
+ root_joints = position_arr[:, :, 0]
20
+
21
+ x_pos = root_joints[:, :, 0]
22
+ y_pos = root_joints[:, :, 2]
23
+
24
+ fig = plt.figure()
25
+
26
+ for i in range(x_pos.shape[1]):
27
+
28
+ if i == 0:
29
+ plt.scatter(x_pos[:, i], y_pos[:, i], c="b")
30
+ elif i == x_pos.shape[1] - 1:
31
+ plt.scatter(x_pos[:, i], y_pos[:, i], c="r")
32
+ else:
33
+ plt.scatter(x_pos[:, i], y_pos[:, i], c="k", marker="*", s=1)
34
+
35
+ plt.title(f"Root Position: {file_name}")
36
+ plt.xlabel("X Axis")
37
+ plt.ylabel("Y Axis")
38
+ plt.xlim((-300, 300))
39
+ plt.ylim((-300, 300))
40
+ plt.grid()
41
+ plt.savefig(f"{file_name}.png", dpi=200)
42
+
43
+
44
+ def plot_single_pose(
45
+ pose,
46
+ frame_idx,
47
+ skeleton,
48
+ save_dir,
49
+ prefix,
50
+ ):
51
+ fig = plt.figure()
52
+ ax = fig.add_subplot(111, projection="3d")
53
+
54
+ parent_idx = skeleton.parents()
55
+
56
+ for i, p in enumerate(parent_idx):
57
+ if i > 0:
58
+ ax.plot(
59
+ [pose[i, 0], pose[p, 0]],
60
+ [pose[i, 2], pose[p, 2]],
61
+ [pose[i, 1], pose[p, 1]],
62
+ c="k",
63
+ )
64
+
65
+ x_min = pose[:, 0].min()
66
+ x_max = pose[:, 0].max()
67
+
68
+ y_min = pose[:, 1].min()
69
+ y_max = pose[:, 1].max()
70
+
71
+ z_min = pose[:, 2].min()
72
+ z_max = pose[:, 2].max()
73
+
74
+ ax.set_xlim(x_min, x_max)
75
+ ax.set_xlabel("$X$ Axis")
76
+
77
+ ax.set_ylim(z_min, z_max)
78
+ ax.set_ylabel("$Y$ Axis")
79
+
80
+ ax.set_zlim(y_min, y_max)
81
+ ax.set_zlabel("$Z$ Axis")
82
+
83
+ plt.draw()
84
+
85
+ title = f"{prefix}: {frame_idx}"
86
+ plt.title(title)
87
+ prefix = prefix
88
+ pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
89
+ plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
90
+ plt.close()
91
+
92
+
93
+ def plot_pose(
94
+ start_pose,
95
+ inbetween_pose,
96
+ target_pose,
97
+ frame_idx,
98
+ skeleton,
99
+ save_dir,
100
+ prefix,
101
+ ):
102
+ fig = plt.figure()
103
+ ax = fig.add_subplot(111, projection="3d")
104
+
105
+ parent_idx = skeleton.parents()
106
+
107
+ for i, p in enumerate(parent_idx):
108
+ if i > 0:
109
+ ax.plot(
110
+ [start_pose[i, 0], start_pose[p, 0]],
111
+ [start_pose[i, 2], start_pose[p, 2]],
112
+ [start_pose[i, 1], start_pose[p, 1]],
113
+ c="b",
114
+ )
115
+ ax.plot(
116
+ [inbetween_pose[i, 0], inbetween_pose[p, 0]],
117
+ [inbetween_pose[i, 2], inbetween_pose[p, 2]],
118
+ [inbetween_pose[i, 1], inbetween_pose[p, 1]],
119
+ c="k",
120
+ )
121
+ ax.plot(
122
+ [target_pose[i, 0], target_pose[p, 0]],
123
+ [target_pose[i, 2], target_pose[p, 2]],
124
+ [target_pose[i, 1], target_pose[p, 1]],
125
+ c="r",
126
+ )
127
+
128
+ x_min = np.min(
129
+ [start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()]
130
+ )
131
+ x_max = np.max(
132
+ [start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()]
133
+ )
134
+
135
+ y_min = np.min(
136
+ [start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()]
137
+ )
138
+ y_max = np.max(
139
+ [start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()]
140
+ )
141
+
142
+ z_min = np.min(
143
+ [start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()]
144
+ )
145
+ z_max = np.max(
146
+ [start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()]
147
+ )
148
+
149
+ ax.set_xlim(x_min, x_max)
150
+ ax.set_xlabel("$X$ Axis")
151
+
152
+ ax.set_ylim(z_min, z_max)
153
+ ax.set_ylabel("$Y$ Axis")
154
+
155
+ ax.set_zlim(y_min, y_max)
156
+ ax.set_zlabel("$Z$ Axis")
157
+
158
+ plt.draw()
159
+
160
+ title = f"{prefix}: {frame_idx}"
161
+ plt.title(title)
162
+ prefix = prefix
163
+ pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
164
+ plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
165
+ plt.close()
166
+
167
+
168
+ def plot_pose_with_stop(
169
+ start_pose,
170
+ inbetween_pose,
171
+ target_pose,
172
+ stopover,
173
+ frame_idx,
174
+ skeleton,
175
+ save_dir,
176
+ prefix,
177
+ ):
178
+ fig = plt.figure()
179
+ ax = fig.add_subplot(111, projection="3d")
180
+
181
+ parent_idx = skeleton.parents()
182
+
183
+ for i, p in enumerate(parent_idx):
184
+ if i > 0:
185
+ ax.plot(
186
+ [start_pose[i, 0], start_pose[p, 0]],
187
+ [start_pose[i, 2], start_pose[p, 2]],
188
+ [start_pose[i, 1], start_pose[p, 1]],
189
+ c="b",
190
+ )
191
+ ax.plot(
192
+ [inbetween_pose[i, 0], inbetween_pose[p, 0]],
193
+ [inbetween_pose[i, 2], inbetween_pose[p, 2]],
194
+ [inbetween_pose[i, 1], inbetween_pose[p, 1]],
195
+ c="k",
196
+ )
197
+ ax.plot(
198
+ [target_pose[i, 0], target_pose[p, 0]],
199
+ [target_pose[i, 2], target_pose[p, 2]],
200
+ [target_pose[i, 1], target_pose[p, 1]],
201
+ c="r",
202
+ )
203
+
204
+ ax.plot(
205
+ [stopover[i, 0], stopover[p, 0]],
206
+ [stopover[i, 2], stopover[p, 2]],
207
+ [stopover[i, 1], stopover[p, 1]],
208
+ c="indigo",
209
+ )
210
+
211
+ x_min = np.min(
212
+ [start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()]
213
+ )
214
+ x_max = np.max(
215
+ [start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()]
216
+ )
217
+
218
+ y_min = np.min(
219
+ [start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()]
220
+ )
221
+ y_max = np.max(
222
+ [start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()]
223
+ )
224
+
225
+ z_min = np.min(
226
+ [start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()]
227
+ )
228
+ z_max = np.max(
229
+ [start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()]
230
+ )
231
+
232
+ ax.set_xlim(x_min, x_max)
233
+ ax.set_xlabel("$X$ Axis")
234
+
235
+ ax.set_ylim(z_min, z_max)
236
+ ax.set_ylabel("$Y$ Axis")
237
+
238
+ ax.set_zlim(y_min, y_max)
239
+ ax.set_zlabel("$Z$ Axis")
240
+
241
+ plt.draw()
242
+
243
+ title = f"{prefix}: {frame_idx}"
244
+ plt.title(title)
245
+ prefix = prefix
246
+ pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
247
+ plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
248
+ plt.close()
lib/core/__pycache__/constants.cpython-310.pyc ADDED
Binary file (2.87 kB). View file
 
lib/core/constants.py ADDED
@@ -0,0 +1,78 @@
1
+ FOCAL_LENGTH = 5000.
2
+
3
+ # Mean and standard deviation for normalizing input image
4
+ IMG_NORM_MEAN = [0.485, 0.456, 0.406]
5
+ IMG_NORM_STD = [0.229, 0.224, 0.225]
6
+
7
+ """
8
+ We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides.
9
+ We keep a superset of 24 joints such that we include all joints from every dataset.
10
+ If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
11
+ The joints used here are the following:
12
+ """
13
+ JOINT_NAMES = [
14
+ 'OP Nose', 'OP Neck', 'OP RShoulder', #0,1,2
15
+ 'OP RElbow', 'OP RWrist', 'OP LShoulder', #3,4,5
16
+ 'OP LElbow', 'OP LWrist', 'OP MidHip', #6, 7,8
17
+ 'OP RHip', 'OP RKnee', 'OP RAnkle', #9,10,11
18
+ 'OP LHip', 'OP LKnee', 'OP LAnkle', #12,13,14
19
+ 'OP REye', 'OP LEye', 'OP REar', #15,16,17
20
+ 'OP LEar', 'OP LBigToe', 'OP LSmallToe', #18,19,20
21
+ 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', #21, 22, 23, 24 ##Total 25 joints for openpose
22
+ 'Right Ankle', 'Right Knee', 'Right Hip', #0,1,2
23
+ 'Left Hip', 'Left Knee', 'Left Ankle', #3, 4, 5
24
+ 'Right Wrist', 'Right Elbow', 'Right Shoulder', #6
25
+ 'Left Shoulder', 'Left Elbow', 'Left Wrist', #9
26
+ 'Neck (LSP)', 'Top of Head (LSP)', #12, 13
27
+ 'Pelvis (MPII)', 'Thorax (MPII)', #14, 15
28
+ 'Spine (H36M)', 'Jaw (H36M)', #16, 17
29
+ 'Head (H36M)', 'Nose', 'Left Eye', #18, 19, 20
30
+ 'Right Eye', 'Left Ear', 'Right Ear' #21,22,23 (Total 24 joints)
31
+ ]
32
+
33
+ # Dict containing the joints in numerical order
34
+ JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}
35
+
36
+ # Map joints to SMPL joints
37
+ JOINT_MAP = {
38
+ 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
39
+ 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
40
+ 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
41
+ 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
42
+ 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
43
+ 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
44
+ 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
45
+ 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
46
+ 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
47
+ 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
48
+ 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
49
+ 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
50
+ 'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
51
+ 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
52
+ 'Spine (H36M)': 51, 'Jaw (H36M)': 52,
53
+ 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
54
+ 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
55
+ }
56
+
57
+ # Joint selectors
58
+ # Indices to get the 14 LSP joints from the 17 H36M joints
59
+ H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
60
+ H36M_TO_J14 = H36M_TO_J17[:14]
61
+ # Indices to get the 14 LSP joints from the ground truth joints
62
+ J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17]
63
+ J24_TO_J14 = J24_TO_J17[:14]
64
+
65
+ # Permutation of SMPL pose parameters when flipping the shape
66
+ SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22]
67
+ SMPL_POSE_FLIP_PERM = []
68
+ for i in SMPL_JOINTS_FLIP_PERM:
69
+ SMPL_POSE_FLIP_PERM.append(3*i)
70
+ SMPL_POSE_FLIP_PERM.append(3*i+1)
71
+ SMPL_POSE_FLIP_PERM.append(3*i+2)
72
+ # Permutation indices for the 24 ground truth joints
73
+ J24_FLIP_PERM = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
74
+ # Permutation indices for the full set of 49 joints
75
+ J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\
76
+ + [25+i for i in J24_FLIP_PERM]
77
+
78
+
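
For reference, a small usage sketch of the joint tables above (not part of the commit; the normalized-coordinate mirroring convention in the last line is an assumption):

```python
import numpy as np
# from lib.core import constants  # assumed import path

idx = constants.JOINT_IDS['OP RWrist']           # -> 4
keypoints = np.random.rand(24, 3)                # (x, y, confidence) for the 24 ground-truth joints
flipped = keypoints[constants.J24_FLIP_PERM]     # swap left/right joints
flipped[:, 0] = 1.0 - flipped[:, 0]              # mirror x, assuming coordinates normalized to [0, 1]
```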
lib/datasets/__pycache__/track_dataset.cpython-310.pyc ADDED
Binary file (2.28 kB). View file
 
lib/datasets/track_dataset.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ from torchvision.transforms import Normalize, ToTensor, Compose
4
+ import numpy as np
5
+ import cv2
6
+
7
+ from lib.core import constants
8
+ from lib.utils.imutils import crop, boxes_2_cs
9
+
10
+
11
+ class TrackDatasetEval(Dataset):
12
+ """
13
+ Track Dataset Class - Load images/crops of the tracked boxes.
14
+ """
15
+ def __init__(self, imgfiles, boxes,
16
+ crop_size=256, dilate=1.0,
17
+ img_focal=None, img_center=None, normalization=True,
18
+ item_idx=0, do_flip=False):
19
+ super(TrackDatasetEval, self).__init__()
20
+
21
+ self.imgfiles = imgfiles
22
+ self.crop_size = crop_size
23
+ self.normalization = normalization
24
+ self.normalize_img = Compose([
25
+ ToTensor(),
26
+ Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)
27
+ ])
28
+
29
+ self.boxes = boxes
30
+ self.box_dilate = dilate
31
+ self.centers, self.scales = boxes_2_cs(boxes)
32
+
33
+ self.img_focal = img_focal
34
+ self.img_center = img_center
35
+ self.item_idx = item_idx
36
+ self.do_flip = do_flip
37
+
38
+ def __len__(self):
39
+ return len(self.imgfiles)
40
+
41
+
42
+ def __getitem__(self, index):
43
+ item = {}
44
+ imgfile = self.imgfiles[index]
45
+ scale = self.scales[index] * self.box_dilate
46
+ center = self.centers[index]
47
+
48
+ img_focal = self.img_focal
49
+ img_center = self.img_center
50
+
51
+ img = cv2.imread(imgfile)[:,:,::-1]
52
+ if self.do_flip:
53
+ img = img[:, ::-1, :]
54
+ img_width = img.shape[1]
55
+ center[0] = img_width - center[0] - 1
56
+ img_crop = crop(img, center, scale,
57
+ [self.crop_size, self.crop_size],
58
+ rot=0).astype('uint8')
59
+ # cv2.imwrite('debug_crop.png', img_crop[:,:,::-1])
60
+
61
+ if self.normalization:
62
+ img_crop = self.normalize_img(img_crop)
63
+ else:
64
+ img_crop = torch.from_numpy(img_crop)
65
+ item['img'] = img_crop
66
+
67
+ if self.do_flip:
68
+ # center[0] = img_width - center[0] - 1
69
+ item['do_flip'] = torch.tensor(1).float()
70
+ item['img_idx'] = torch.tensor(index).long()
71
+ item['scale'] = torch.tensor(scale).float()
72
+ item['center'] = torch.tensor(center).float()
73
+ item['img_focal'] = torch.tensor(img_focal).float()
74
+ item['img_center'] = torch.tensor(img_center).float()
75
+
76
+
77
+ return item
78
+
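
A hedged sketch of wrapping `TrackDatasetEval` in a `DataLoader` (not part of the commit; the frame paths, intrinsics, and the box layout expected by `boxes_2_cs` are placeholders/assumptions):

```python
import numpy as np
from torch.utils.data import DataLoader
# from lib.datasets.track_dataset import TrackDatasetEval  # assumed import path

imgfiles = [f"frames/{i:05d}.jpg" for i in range(100)]                 # placeholder paths
boxes = np.tile([[200.0, 150.0, 450.0, 400.0, 1.0]], (100, 1))         # assumed (x1, y1, x2, y2, score)

dataset = TrackDatasetEval(imgfiles, boxes, dilate=1.2,
                           img_focal=600.0, img_center=(320.0, 240.0))
loader = DataLoader(dataset, batch_size=16, num_workers=2)

for batch in loader:
    crops = batch['img']        # (B, 3, 256, 256) normalized crops
    centers = batch['center']   # (B, 2) crop centers in the full image
```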
lib/eval_utils/__pycache__/custom_utils.cpython-310.pyc ADDED
Binary file (2.96 kB). View file
 
lib/eval_utils/__pycache__/filling_utils.cpython-310.pyc ADDED
Binary file (6.88 kB). View file
 
lib/eval_utils/custom_utils.py ADDED
@@ -0,0 +1,99 @@
1
+ import copy
2
+ import numpy as np
3
+ import torch
4
+
5
+ from hawor.utils.process import run_mano, run_mano_left
6
+ from hawor.utils.rotation import angle_axis_to_quaternion, rotation_matrix_to_angle_axis
7
+ from scipy.interpolate import interp1d
8
+
9
+
10
+ def cam2world_convert(R_c2w_sla, t_c2w_sla, data_out, handedness):
11
+ init_rot_mat = copy.deepcopy(data_out["init_root_orient"])
12
+ init_rot_mat = torch.einsum("tij,btjk->btik", R_c2w_sla, init_rot_mat)
13
+ init_rot = rotation_matrix_to_angle_axis(init_rot_mat)
14
+ init_rot_quat = angle_axis_to_quaternion(init_rot)
15
+ # data_out["init_root_orient"] = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
16
+ # data_out["init_hand_pose"] = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
17
+ data_out_init_root_orient = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
18
+ data_out_init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
19
+
20
+ init_trans = data_out["init_trans"] # (B, T, 3)
21
+ if handedness == "right":
22
+ outputs = run_mano(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
23
+ elif handedness == "left":
24
+ outputs = run_mano_left(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
25
+ root_loc = outputs["joints"][..., 0, :].cpu() # (B, T, 3)
26
+ offset = init_trans - root_loc # It is a constant, no matter what the rotation is.
27
+ init_trans = (
28
+ torch.einsum("tij,btj->bti", R_c2w_sla, root_loc)
29
+ + t_c2w_sla[None, :]
30
+ + offset
31
+ )
32
+
33
+ data_world = {
34
+ "init_root_orient": init_rot, # (B, T, 3)
35
+ "init_hand_pose": data_out_init_hand_pose, # (B, T, 15, 3)
36
+ "init_trans": init_trans, # (B, T, 3)
37
+ "init_betas": data_out["init_betas"] # (B, T, 10)
38
+ }
39
+
40
+ return data_world
41
+
42
+ def quaternion_to_matrix(quaternions):
43
+ """
44
+ Convert rotations given as quaternions to rotation matrices.
45
+
46
+ Args:
47
+ quaternions: quaternions with real part first,
48
+ as tensor of shape (..., 4).
49
+
50
+ Returns:
51
+ Rotation matrices as tensor of shape (..., 3, 3).
52
+ """
53
+ r, i, j, k = torch.unbind(quaternions, -1)
54
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
55
+
56
+ o = torch.stack(
57
+ (
58
+ 1 - two_s * (j * j + k * k),
59
+ two_s * (i * j - k * r),
60
+ two_s * (i * k + j * r),
61
+ two_s * (i * j + k * r),
62
+ 1 - two_s * (i * i + k * k),
63
+ two_s * (j * k - i * r),
64
+ two_s * (i * k - j * r),
65
+ two_s * (j * k + i * r),
66
+ 1 - two_s * (i * i + j * j),
67
+ ),
68
+ -1,
69
+ )
70
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
71
+
72
+ def load_slam_cam(fpath):
73
+ print(f"Loading cameras from {fpath}...")
74
+ pred_cam = dict(np.load(fpath, allow_pickle=True))
75
+ pred_traj = pred_cam['traj']
76
+ t_c2w_sla = torch.tensor(pred_traj[:, :3]) * pred_cam['scale']
77
+ pred_camq = torch.tensor(pred_traj[:, 3:])
78
+ R_c2w_sla = quaternion_to_matrix(pred_camq[:,[3,0,1,2]])
79
+ R_w2c_sla = R_c2w_sla.transpose(-1, -2)
80
+ t_w2c_sla = -torch.einsum("bij,bj->bi", R_w2c_sla, t_c2w_sla)
81
+ return R_w2c_sla, t_w2c_sla, R_c2w_sla, t_c2w_sla
82
+
83
+
84
+ def interpolate_bboxes(bboxes):
85
+ T = bboxes.shape[0]
86
+
87
+ zero_indices = np.where(np.all(bboxes == 0, axis=1))[0]
88
+
89
+ non_zero_indices = np.where(np.any(bboxes != 0, axis=1))[0]
90
+
91
+ if len(zero_indices) == 0:
92
+ return bboxes
93
+
94
+ interpolated_bboxes = bboxes.copy()
95
+ for i in range(5):
96
+ interp_func = interp1d(non_zero_indices, bboxes[non_zero_indices, i], kind='linear', fill_value="extrapolate")
97
+ interpolated_bboxes[zero_indices, i] = interp_func(zero_indices)
98
+
99
+ return interpolated_bboxes
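
A quick illustration of `interpolate_bboxes` on a short track with detection gaps (a sketch, not part of the commit; the 5-column box layout is assumed):

```python
import numpy as np
# from lib.eval_utils.custom_utils import interpolate_bboxes  # assumed import path

bboxes = np.zeros((5, 5))                         # (T, 5); all-zero rows mark missed detections
bboxes[0] = [100, 100, 200, 200, 1.0]
bboxes[4] = [140, 100, 240, 200, 1.0]

filled = interpolate_bboxes(bboxes)               # rows 1-3 are linearly interpolated per column
```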
lib/eval_utils/filling_utils.py ADDED
@@ -0,0 +1,306 @@
1
+ import copy
2
+ import os
3
+ import joblib
4
+ import numpy as np
5
+ from scipy.spatial.transform import Slerp, Rotation
6
+ import torch
7
+
8
+ from hawor.utils.process import run_mano, run_mano_left
9
+ from hawor.utils.rotation import angle_axis_to_quaternion, angle_axis_to_rotation_matrix, quaternion_to_rotation_matrix, rotation_matrix_to_angle_axis
10
+ from lib.utils.geometry import rotmat_to_rot6d
11
+ from lib.utils.geometry import rot6d_to_rotmat
12
+
13
+ def slerp_interpolation_aa(pos, valid):
14
+
15
+ B, T, N, _ = pos.shape # B: batch size, T: number of frames, N: number of joints, last dim: axis-angle (3)
16
+ pos_interp = pos.copy() # copy that stores the interpolated result
17
+
18
+ for b in range(B):
19
+ for n in range(N):
20
+ quat_b_n = pos[b, :, n, :]
21
+ valid_b_n = valid[b, :]
22
+
23
+ invalid_idxs = np.where(~valid_b_n)[0]
24
+ valid_idxs = np.where(valid_b_n)[0]
25
+
26
+ if len(invalid_idxs) == 0:
27
+ continue
28
+
29
+ if len(valid_idxs) > 1:
30
+ valid_times = valid_idxs # valid time steps
31
+ valid_rots = Rotation.from_rotvec(quat_b_n[valid_idxs]) # valid rotations (axis-angle)
32
+
33
+ slerp = Slerp(valid_times, valid_rots)
34
+
35
+ for idx in invalid_idxs:
36
+ if idx < valid_idxs[0]: # before the first valid frame: extrapolate
37
+ pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[0]] # copy the first valid rotation
38
+ elif idx > valid_idxs[-1]: # after the last valid frame: extrapolate
39
+ pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[-1]] # copy the last valid rotation
40
+ else:
41
+ interp_rot = slerp([idx])
42
+ pos_interp[b, idx, n, :] = interp_rot.as_rotvec()[0]
43
+ # print("#######")
44
+ # if N > 1:
45
+ # print(pos[1,0,11])
46
+ # print(pos_interp[1,0,11])
47
+
48
+ return pos_interp
49
+
50
+ def slerp_interpolation_quat(pos, valid):
51
+
52
+ # wxyz to xyzw
53
+ pos = pos[:, :, :, [1, 2, 3, 0]]
54
+
55
+ B, T, N, _ = pos.shape # B: batch size, T: number of frames, N: number of joints, 4: quaternion dim
56
+ pos_interp = pos.copy() # copy that stores the interpolated result
57
+
58
+ for b in range(B):
59
+ for n in range(N):
60
+ quat_b_n = pos[b, :, n, :]
61
+ valid_b_n = valid[b, :]
62
+
63
+ invalid_idxs = np.where(~valid_b_n)[0]
64
+ valid_idxs = np.where(valid_b_n)[0]
65
+
66
+ if len(invalid_idxs) == 0:
67
+ continue
68
+
69
+ if len(valid_idxs) > 1:
70
+ valid_times = valid_idxs # valid time steps
71
+ valid_rots = Rotation.from_quat(quat_b_n[valid_idxs]) # valid quaternions
72
+
73
+ slerp = Slerp(valid_times, valid_rots)
74
+
75
+ for idx in invalid_idxs:
76
+ if idx < valid_idxs[0]: # before the first valid frame: extrapolate
77
+ pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[0]] # copy the first valid quaternion
78
+ elif idx > valid_idxs[-1]: # after the last valid frame: extrapolate
79
+ pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[-1]] # copy the last valid quaternion
80
+ else:
81
+ interp_rot = slerp([idx])
82
+ pos_interp[b, idx, n, :] = interp_rot.as_quat()[0]
83
+
84
+ # xyzw to wxyz
85
+ pos_interp = pos_interp[:, :, :, [3, 0, 1, 2]]
86
+ return pos_interp
87
+
88
+
89
+ def linear_interpolation_nd(pos, valid):
90
+ B, T = pos.shape[:2] # batch size B and number of frames T
91
+ feature_dim = pos.shape[2] # size of the remaining feature dimension
92
+ pos_interp = pos.copy() # copy that stores the interpolated result
93
+
94
+ for b in range(B):
95
+ for idx in range(feature_dim): # iterate over each feature dimension
96
+ pos_b_idx = pos[b, :, idx] # time series of this feature for batch element b
97
+ valid_b = valid[b, :] # validity flags for this batch element
98
+
99
+ # find the invalid (False) indices
100
+ invalid_idxs = np.where(~valid_b)[0]
101
+ valid_idxs = np.where(valid_b)[0]
102
+
103
+ if len(invalid_idxs) == 0:
104
+ continue
105
+
106
+ # linearly interpolate the invalid entries
107
+ if len(valid_idxs) > 1: # need at least two valid points to interpolate
108
+ pos_b_idx[invalid_idxs] = np.interp(invalid_idxs, valid_idxs, pos_b_idx[valid_idxs])
109
+ pos_interp[b, :, idx] = pos_b_idx # store the interpolated series
110
+
111
+ return pos_interp
112
+
113
+ def world2canonical_convert(R_c2w_sla, t_c2w_sla, data_out, handedness):
114
+ init_rot_mat = copy.deepcopy(data_out["init_root_orient"])
115
+ init_rot_mat = torch.einsum("tij,btjk->btik", R_c2w_sla, init_rot_mat)
116
+ init_rot = rotation_matrix_to_angle_axis(init_rot_mat)
117
+ init_rot_quat = angle_axis_to_quaternion(init_rot)
118
+ # data_out["init_root_orient"] = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
119
+ # data_out["init_hand_pose"] = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
120
+ data_out_init_root_orient = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
121
+ data_out_init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
122
+
123
+ init_trans = data_out["init_trans"] # (B, T, 3)
124
+ if handedness == "left":
125
+ outputs = run_mano_left(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
126
+
127
+ elif handedness == "right":
128
+ outputs = run_mano(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
129
+ root_loc = outputs["joints"][..., 0, :].cpu() # (B, T, 3)
130
+ offset = init_trans - root_loc # It is a constant, no matter what the rotation is.
131
+ init_trans = (
132
+ torch.einsum("tij,btj->bti", R_c2w_sla, root_loc)
133
+ + t_c2w_sla[None, :]
134
+ + offset
135
+ )
136
+
137
+ data_world = {
138
+ "init_root_orient": init_rot, # (B, T, 3)
139
+ "init_hand_pose": data_out_init_hand_pose, # (B, T, 15, 3)
140
+ "init_trans": init_trans, # (B, T, 3)
141
+ "init_betas": data_out["init_betas"] # (B, T, 10)
142
+ }
143
+
144
+ return data_world
145
+
146
+ def filling_preprocess(item):
147
+
148
+ num_joints = 15
149
+
150
+ global_trans = item['trans'] # (2, seq_len, 3)
151
+ global_rot = item['rot'] #(2, seq_len, 3)
152
+ hand_pose = item['hand_pose'] # (2, seq_len, 45)
153
+ betas = item['betas'] # (2, seq_len, 10)
154
+ valid = item['valid'] # (2, seq_len)
155
+
156
+ N, T, _ = global_trans.shape
157
+ R_canonical2world_left_aa = torch.from_numpy(global_rot[0, 0])
158
+ R_canonical2world_right_aa = torch.from_numpy(global_rot[1, 0])
159
+ R_world2canonical_left = angle_axis_to_rotation_matrix(R_canonical2world_left_aa).t()
160
+ R_world2canonical_right = angle_axis_to_rotation_matrix(R_canonical2world_right_aa).t()
161
+
162
+
163
+ # transform left hand to canonical
164
+ hand_pose = hand_pose.reshape(N, T, num_joints, 3)
165
+ data_world_left = {
166
+ "init_trans": torch.from_numpy(global_trans[0:1]),
167
+ "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[0:1])),
168
+ "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[0:1])),
169
+ "init_betas": torch.from_numpy(betas[0:1]),
170
+ }
171
+
172
+ data_left_init_root_orient = rotation_matrix_to_angle_axis(data_world_left["init_root_orient"])
173
+ data_left_init_hand_pose = rotation_matrix_to_angle_axis(data_world_left["init_hand_pose"])
174
+ outputs = run_mano_left(data_world_left["init_trans"], data_left_init_root_orient, data_left_init_hand_pose, betas=data_world_left["init_betas"])
175
+ init_trans = data_world_left["init_trans"][0, 0] # (3,)
176
+ root_loc = outputs["joints"][0, 0, 0, :].cpu() # (3,)
177
+ offset = init_trans - root_loc # It is a constant, no matter what the rotation is.
178
+ t_world2canonical_left = -torch.einsum("ij,j->i", R_world2canonical_left, root_loc) - offset
179
+
180
+ R_world2canonical_left = R_world2canonical_left.repeat(T, 1, 1)
181
+ t_world2canonical_left = t_world2canonical_left.repeat(T, 1)
182
+ data_canonical_left = world2canonical_convert(R_world2canonical_left, t_world2canonical_left, data_world_left, "left")
183
+
184
+ # transform right hand to canonical
185
+ data_world_right = {
186
+ "init_trans": torch.from_numpy(global_trans[1:2]),
187
+ "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[1:2])),
188
+ "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[1:2])),
189
+ "init_betas": torch.from_numpy(betas[1:2]),
190
+ }
191
+
192
+ data_right_init_root_orient = rotation_matrix_to_angle_axis(data_world_right["init_root_orient"])
193
+ data_right_init_hand_pose = rotation_matrix_to_angle_axis(data_world_right["init_hand_pose"])
194
+ outputs = run_mano(data_world_right["init_trans"], data_right_init_root_orient, data_right_init_hand_pose, betas=data_world_right["init_betas"])
195
+ init_trans = data_world_right["init_trans"][0, 0] # (3,)
196
+ root_loc = outputs["joints"][0, 0, 0, :].cpu() # (3,)
197
+ offset = init_trans - root_loc # It is a constant, no matter what the rotation is.
198
+ t_world2canonical_right = -torch.einsum("ij,j->i", R_world2canonical_right, root_loc) - offset
199
+
200
+ R_world2canonical_right = R_world2canonical_right.repeat(T, 1, 1)
201
+ t_world2canonical_right = t_world2canonical_right.repeat(T, 1)
202
+ data_canonical_right = world2canonical_convert(R_world2canonical_right, t_world2canonical_right, data_world_right, "right")
203
+
204
+ # merge left and right canonical data
205
+ global_rot = torch.cat((data_canonical_left['init_root_orient'], data_canonical_right['init_root_orient']))
206
+ global_trans = torch.cat((data_canonical_left['init_trans'], data_canonical_right['init_trans'])).numpy()
207
+
208
+ # global_rot = angle_axis_to_quaternion(global_rot).numpy().reshape(N, T, 1, 4)
209
+ global_rot = global_rot.reshape(N, T, 1, 3).numpy()
210
+
211
+ hand_pose = hand_pose.reshape(N, T, 15, 3)
212
+ # hand_pose = angle_axis_to_quaternion(torch.from_numpy(hand_pose)).numpy()
213
+
214
+ # lerp and slerp
215
+ global_trans_lerped = linear_interpolation_nd(global_trans, valid)
216
+ betas_lerped = linear_interpolation_nd(betas, valid)
217
+ global_rot_slerped = slerp_interpolation_aa(global_rot, valid)
218
+ hand_pose_slerped = slerp_interpolation_aa(hand_pose, valid)
219
+
220
+
221
+ # convert to rot6d
222
+
223
+ global_rot_slerped_mat = angle_axis_to_rotation_matrix(torch.from_numpy(global_rot_slerped.reshape(N*T, -1)))
224
+ # global_rot_slerped_mat = quaternion_to_rotation_matrix(torch.from_numpy(global_rot_slerped.reshape(N*T, -1)))
225
+ global_rot_slerped_rot6d = rotmat_to_rot6d(global_rot_slerped_mat).reshape(N, T, -1).numpy()
226
+ hand_pose_slerped_mat = angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose_slerped.reshape(N*T*num_joints, -1)))
227
+ # hand_pose_slerped_mat = quaternion_to_rotation_matrix(torch.from_numpy(hand_pose_slerped.reshape(N*T*num_joints, -1)))
228
+ hand_pose_slerped_rot6d = rotmat_to_rot6d(hand_pose_slerped_mat).reshape(N, T, -1).numpy()
229
+
230
+
231
+ # concat to (T, concat_dim)
232
+ global_pose_vec_input = np.concatenate((global_trans_lerped, betas_lerped, global_rot_slerped_rot6d, hand_pose_slerped_rot6d), axis=-1).transpose(1, 0, 2).reshape(T, -1)
233
+
234
+ R_canon2w_left = R_world2canonical_left.transpose(-1, -2)
235
+ t_canon2w_left = -torch.einsum("tij,tj->ti", R_canon2w_left, t_world2canonical_left)
236
+ R_canon2w_right = R_world2canonical_right.transpose(-1, -2)
237
+ t_canon2w_right = -torch.einsum("tij,tj->ti", R_canon2w_right, t_world2canonical_right)
238
+
239
+ transform_w_canon = {
240
+ "R_w2canon_left": R_world2canonical_left,
241
+ "t_w2canon_left": t_world2canonical_left,
242
+ "R_canon2w_left": R_canon2w_left,
243
+ "t_canon2w_left": t_canon2w_left,
244
+
245
+ "R_w2canon_right": R_world2canonical_right,
246
+ "t_w2canon_right": t_world2canonical_right,
247
+ "R_canon2w_right": R_canon2w_right,
248
+ "t_canon2w_right": t_canon2w_right,
249
+ }
250
+
251
+ return global_pose_vec_input, transform_w_canon
252
+
253
+ def custom_rot6d_to_rotmat(rot6d):
254
+ original_shape = rot6d.shape[:-1]
255
+ rot6d = rot6d.reshape(-1, 6)
256
+ mat = rot6d_to_rotmat(rot6d)
257
+ mat = mat.reshape(*original_shape, 3, 3)
258
+ return mat
259
+
260
+ def filling_postprocess(output, transform_w_canon):
261
+ # output = output.numpy()
262
+ output = output.permute(1, 0, 2) # (2, T, -1)
263
+ N, T, _ = output.shape
264
+ canon_trans = output[:, :, :3]
265
+ betas = output[:, :, 3:13]
266
+ canon_rot_rot6d = output[:, :, 13:19]
267
+ hand_pose_rot6d = output[:, :, 19:109].reshape(N, T, 15, 6)
268
+
269
+ canon_rot_mat = custom_rot6d_to_rotmat(canon_rot_rot6d)
270
+ hand_pose_mat = custom_rot6d_to_rotmat(hand_pose_rot6d)
271
+
272
+ data_canonical_left = {
273
+ "init_trans": canon_trans[[0], :, :],
274
+ "init_root_orient": canon_rot_mat[[0], :, :, :],
275
+ "init_hand_pose": hand_pose_mat[[0], :, :, :, :],
276
+ "init_betas": betas[[0], :, :]
277
+ }
278
+
279
+ data_canonical_right = {
280
+ "init_trans": canon_trans[[1], :, :],
281
+ "init_root_orient": canon_rot_mat[[1], :, :, :],
282
+ "init_hand_pose": hand_pose_mat[[1], :, :, :, :],
283
+ "init_betas": betas[[1], :, :]
284
+ }
285
+
286
+ R_canon2w_left = transform_w_canon['R_canon2w_left']
287
+ t_canon2w_left = transform_w_canon['t_canon2w_left']
288
+ R_canon2w_right = transform_w_canon['R_canon2w_right']
289
+ t_canon2w_right = transform_w_canon['t_canon2w_right']
290
+
291
+
292
+ world_left = world2canonical_convert(R_canon2w_left, t_canon2w_left, data_canonical_left, "left")
293
+ world_right = world2canonical_convert(R_canon2w_right, t_canon2w_right, data_canonical_right, "right")
294
+
295
+ global_rot = torch.cat((world_left['init_root_orient'], world_right['init_root_orient'])).numpy()
296
+ global_trans = torch.cat((world_left['init_trans'], world_right['init_trans'])).numpy()
297
+
298
+ pred_data = {
299
+ "trans": global_trans, # (2, T, 3)
300
+ "rot": global_rot, # (2, T, 3)
301
+ "hand_pose": rotation_matrix_to_angle_axis(hand_pose_mat).flatten(-2).numpy(), # (2, T, 45)
302
+ "betas": betas.numpy(), # (2, T, 10)
303
+ }
304
+
305
+ return pred_data
306
+
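
The SLERP/LERP helpers at the top of this file can also be exercised in isolation; the sketch below fills a block of untracked frames (illustrative only, not part of the commit; the import path is an assumption):

```python
import numpy as np
# from lib.eval_utils.filling_utils import slerp_interpolation_aa, linear_interpolation_nd  # assumed path

B, T, N = 2, 10, 16                      # hands, frames, joints (root + 15)
rot_aa = np.random.randn(B, T, N, 3) * 0.1
trans = np.random.randn(B, T, 3)
valid = np.ones((B, T), dtype=bool)
valid[:, 4:7] = False                    # pretend frames 4-6 were not tracked

rot_filled = slerp_interpolation_aa(rot_aa, valid)     # frames 4-6 filled by SLERP
trans_filled = linear_interpolation_nd(trans, valid)   # frames 4-6 filled linearly
```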
lib/eval_utils/video_utils.py ADDED
@@ -0,0 +1,85 @@
1
+ import cv2
2
+ import os
3
+ import subprocess
4
+
5
+ def make_video_grid_2x2(out_path, vid_paths, overwrite=False):
6
+ """
7
+ Tile four videos into a 2x2 grid at their original resolution.
8
+
9
+ :param out_path: Path of the output video.
10
+ :param vid_paths: List of input video paths (must contain exactly 4).
11
+ :param overwrite: If True, overwrite an existing output file.
12
+ """
13
+ if os.path.isfile(out_path) and not overwrite:
14
+ print(f"{out_path} already exists, skipping.")
15
+ return
16
+
17
+ if any(not os.path.isfile(v) for v in vid_paths):
18
+ print("Not all inputs exist!", vid_paths)
19
+ return
20
+
21
+ # Make sure exactly 4 video paths were provided
22
+ if len(vid_paths) != 4:
23
+ print("Error: Exactly 4 video paths are required!")
24
+ return
25
+
26
+ # Unpack the video paths
27
+ v1, v2, v3, v4 = vid_paths
28
+
29
+ # ffmpeg xstack command: tile the inputs directly without resizing
30
+ cmd = (
31
+ f"ffmpeg -i {v1} -i {v2} -i {v3} -i {v4} "
32
+ f"-filter_complex '[0:v][1:v][2:v][3:v]xstack=inputs=4:layout=0_0|w0_0|0_h0|w0_h0[v]' "
33
+ f"-map '[v]' {out_path} -y"
34
+ )
35
+
36
+ print(cmd)
37
+ subprocess.call(cmd, shell=True, stdin=subprocess.PIPE)
38
+
39
+ def create_video_from_images(image_list, output_path, fps=15, target_resolution=(540, 540)):
40
+ """
41
+ Assemble a list of images into an MP4 video.
42
+
43
+ :param image_list: List of image paths.
44
+ :param output_path: Output video file path (e.g. output.mp4).
45
+ :param fps: Frame rate of the output video (default 15 FPS).
46
+ """
47
+ # if not image_list:
48
+ # print("Image list is empty!")
49
+ # return
50
+
51
+ # Read the first image to get the frame width and height
52
+ first_image = cv2.imread(image_list[0])
53
+ if first_image is None:
54
+ print(f"Failed to read image: {image_list[0]}")
55
+ return
56
+
57
+ height, width, _ = first_image.shape
58
+ if height != width:
59
+ if height < width:
60
+ vis_w = target_resolution[0]
61
+ vis_h = int(target_resolution[0] / width * height)
62
+ elif height > width:
63
+ vis_h = target_resolution[0]
64
+ vis_w = int(target_resolution[0] / height * width)
65
+ else:
66
+ vis_h = target_resolution[0]
67
+ vis_w = target_resolution[0]
68
+ target_resolution = (vis_w, vis_h)
69
+
70
+ # Set up the video writer and output parameters
71
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # use the mp4v codec
72
+ video_writer = cv2.VideoWriter(output_path, fourcc, fps, target_resolution)
73
+
74
+ # Iterate over the image list and write each frame to the video
75
+ for image_path in image_list:
76
+ frame = cv2.imread(image_path)
77
+ if frame is None:
78
+ print(f"Failed to read image: {image_path}")
79
+ continue
80
+ frame_resized = cv2.resize(frame, target_resolution)
81
+ video_writer.write(frame_resized)
82
+
83
+ # Release the video writer
84
+ video_writer.release()
85
+ print(f"Video saved to: {output_path}")
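
A short usage sketch for the two helpers above (not part of the commit; paths are placeholders and `make_video_grid_2x2` requires `ffmpeg` with the `xstack` filter on the PATH):

```python
# from lib.eval_utils.video_utils import create_video_from_images, make_video_grid_2x2  # assumed path

frames = [f"render/{i:04d}.png" for i in range(90)]          # placeholder frame paths
create_video_from_images(frames, "render.mp4", fps=15)

make_video_grid_2x2(
    "comparison.mp4",
    ["input.mp4", "render.mp4", "left_hand.mp4", "right_hand.mp4"],   # placeholder clips
    overwrite=True,
)
```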
lib/models/__pycache__/hawor.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
lib/models/__pycache__/mano_wrapper.cpython-310.pyc ADDED
Binary file (2.43 kB). View file
 
lib/models/__pycache__/modules.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
lib/models/backbones/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .vit import vit
2
+
3
+
4
+ def create_backbone(cfg):
5
+ if cfg.MODEL.BACKBONE.TYPE == 'vit':
6
+ return vit(cfg)
7
+ else:
8
+ raise NotImplementedError('Backbone type is not implemented')
lib/models/backbones/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (445 Bytes). View file
 
lib/models/backbones/__pycache__/vit.cpython-310.pyc ADDED
Binary file (11.3 kB). View file