Commit 5f028d6 · ThunderVVV committed · Parent(s): 014faee
update

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- .gitattributes +2 -0
- README.md +94 -12
- _DATA/data/mano/.gitkeep +0 -0
- _DATA/data/mano/MANO_RIGHT.pkl +3 -0
- _DATA/data/mano_mean_params.npz +3 -0
- _DATA/data_left/mano_left/.gitkeep +0 -0
- _DATA/data_left/mano_left/MANO_LEFT.pkl +3 -0
- assets/teaser.png +3 -0
- demo.py +113 -0
- example/video_0.mp4 +3 -0
- hawor/configs/__init__.py +120 -0
- hawor/configs/__pycache__/__init__.cpython-310.pyc +0 -0
- hawor/utils/__pycache__/geometry.cpython-310.pyc +0 -0
- hawor/utils/__pycache__/process.cpython-310.pyc +0 -0
- hawor/utils/__pycache__/pylogger.cpython-310.pyc +0 -0
- hawor/utils/__pycache__/render_openpose.cpython-310.pyc +0 -0
- hawor/utils/__pycache__/rotation.cpython-310.pyc +0 -0
- hawor/utils/geometry.py +102 -0
- hawor/utils/process.py +198 -0
- hawor/utils/pylogger.py +17 -0
- hawor/utils/render_openpose.py +225 -0
- hawor/utils/rotation.py +293 -0
- imgui.ini +15 -0
- infiller/hand_utils/geometry.py +412 -0
- infiller/hand_utils/geometry_utils.py +102 -0
- infiller/hand_utils/mano_wrapper.py +52 -0
- infiller/hand_utils/process.py +171 -0
- infiller/hand_utils/rotation.py +293 -0
- infiller/lib/misc/sampler.py +79 -0
- infiller/lib/model/__pycache__/network.cpython-310.pyc +0 -0
- infiller/lib/model/network.py +276 -0
- infiller/lib/model/positional_encoding.py +42 -0
- infiller/lib/model/preprocess.py +189 -0
- infiller/lib/model/skeleton.py +349 -0
- infiller/lib/vis/pose.py +248 -0
- lib/core/__pycache__/constants.cpython-310.pyc +0 -0
- lib/core/constants.py +78 -0
- lib/datasets/__pycache__/track_dataset.cpython-310.pyc +0 -0
- lib/datasets/track_dataset.py +78 -0
- lib/eval_utils/__pycache__/custom_utils.cpython-310.pyc +0 -0
- lib/eval_utils/__pycache__/filling_utils.cpython-310.pyc +0 -0
- lib/eval_utils/custom_utils.py +99 -0
- lib/eval_utils/filling_utils.py +306 -0
- lib/eval_utils/video_utils.py +85 -0
- lib/models/__pycache__/hawor.cpython-310.pyc +0 -0
- lib/models/__pycache__/mano_wrapper.cpython-310.pyc +0 -0
- lib/models/__pycache__/modules.cpython-310.pyc +0 -0
- lib/models/backbones/__init__.py +8 -0
- lib/models/backbones/__pycache__/__init__.cpython-310.pyc +0 -0
- lib/models/backbones/__pycache__/vit.cpython-310.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,12 +1,94 @@
<div align="center">

# HaWoR: World-Space Hand Motion Reconstruction from Egocentric Videos

[Jinglei Zhang]()<sup>1</sup>   [Jiankang Deng](https://jiankangdeng.github.io/)<sup>2</sup>   [Chao Ma](https://scholar.google.com/citations?user=syoPhv8AAAAJ&hl=en)<sup>1</sup>   [Rolandos Alexandros Potamias](https://rolpotamias.github.io)<sup>2</sup>

<sup>1</sup>Shanghai Jiao Tong University, China
<sup>2</sup>Imperial College London, UK <br>

<a href='https://hawor-project.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
<a href='https://arxiv.org/abs/'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
</div>

This is the official implementation of **[HaWoR](https://hawor-project.github.io/)**, a hand motion reconstruction model in world coordinates:

![teaser](assets/teaser.png)

## Installation

### Clone the repository
```
git clone --recursive https://github.com/ThunderVVV/HaWoR.git
cd HaWoR
```

The code has been tested with PyTorch 1.13 and CUDA 11.7. It is suggested to use an anaconda environment to install the required dependencies:
```bash
conda create --name hawor python=3.10
conda activate hawor

pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
# Install requirements
pip install -r requirements.txt
pip install pytorch-lightning==2.2.4 --no-deps
pip install lightning-utilities torchmetrics==1.4.0
```

### Install masked DROID-SLAM

```
cd thirdparty/DROID-SLAM
python setup.py install
```

Download the official DROID-SLAM weights [droid.pth](https://drive.google.com/file/d/1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh/view?usp=sharing) and put them under `./weights/external/`.

### Install Metric3D

Download the official Metric3D weights [metric_depth_vit_large_800k.pth](https://drive.google.com/file/d/1eT2gG-kwsVzNy5nJrbm4KC-9DbNKyLnr/view?usp=drive_link) and put them under `thirdparty/Metric3D/weights`.

### Download the model weights

```bash
wget https://huggingface.co/spaces/rolpotamias/WiLoR/resolve/main/pretrained_models/detector.pt -P ./weights/external/
wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/checkpoints/hawor.ckpt -P ./weights/hawor/checkpoints/
wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/checkpoints/infiller.pt -P ./weights/hawor/checkpoints/
wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/model_config.yaml -P ./weights/hawor/
```
It is also required to download the MANO model from the [MANO website](https://mano.is.tue.mpg.de).
Create an account by clicking Sign Up and download the models (mano_v*_*.zip). Unzip them and place the hand models at `_DATA/data/mano/MANO_RIGHT.pkl` and `_DATA/data_left/mano_left/MANO_LEFT.pkl`.

Note that the MANO model falls under the [MANO license](https://mano.is.tue.mpg.de/license.html).
## Demo

For visualization in world view, run:
```bash
python demo.py --video_path ./example/video_0.mp4 --vis_mode world
```

For visualization in camera view, run:
```bash
python demo.py --video_path ./example/video_0.mp4 --vis_mode cam
```

## Training
The training code will be released soon.

## Acknowledgements
Parts of the code are taken or adapted from the following repos:
- [HaMeR](https://github.com/geopavlakos/hamer/)
- [WiLoR](https://github.com/rolpotamias/WiLoR)
- [SLAHMR](https://github.com/vye16/slahmr)
- [TRAM](https://github.com/yufu-wang/tram)
- [CMIB](https://github.com/jihoonerd/Conditional-Motion-In-Betweening)


## License
HaWoR models fall under the [CC-BY-NC-ND License](./license.txt). This repository also depends on the [MANO Model](https://mano.is.tue.mpg.de/license.html), which falls under its own license. By using this repository, you must also comply with the terms of these external licenses.
## Citing
If you find HaWoR useful for your research, please consider citing our paper:

```bibtex

```
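To confirm that the weights and MANO files described in the README above ended up where `demo.py` expects them, a minimal check like the following can be run from the repository root. This is a sketch, not part of the commit; every path is taken verbatim from the installation steps.

```python
# Minimal sanity check: confirm the asset paths used by demo.py exist.
import os

expected = [
    "./weights/external/droid.pth",
    "./weights/external/detector.pt",
    "./weights/hawor/checkpoints/hawor.ckpt",
    "./weights/hawor/checkpoints/infiller.pt",
    "./weights/hawor/model_config.yaml",
    "_DATA/data/mano/MANO_RIGHT.pkl",
    "_DATA/data_left/mano_left/MANO_LEFT.pkl",
]
for path in expected:
    print(("ok  " if os.path.exists(path) else "MISS") + " " + path)
```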
_DATA/data/mano/.gitkeep
ADDED
File without changes
_DATA/data/mano/MANO_RIGHT.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:45d60aa3b27ef9107a7afd4e00808f307fd91111e1cfa35afd5c4a62de264767
size 3821356
_DATA/data/mano_mean_params.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:efc0ec58e4a5cef78f3abfb4e8f91623b8950be9eff8b8e0dbb0d036ebc63988
size 1178
_DATA/data_left/mano_left/.gitkeep
ADDED
File without changes
_DATA/data_left/mano_left/MANO_LEFT.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c4022f7083f2ca7c78b2b3d595abbab52debd32b09d372b16923a801f0ea6a30
size 3821391
assets/teaser.png
ADDED
Git LFS Details
demo.py
ADDED
@@ -0,0 +1,113 @@
import argparse
import sys
import os

import torch
sys.path.insert(0, os.path.dirname(__file__))
import numpy as np
import joblib
from scripts.scripts_test_video.detect_track_video import detect_track_video
from scripts.scripts_test_video.hawor_video import hawor_motion_estimation, hawor_infiller
from scripts.scripts_test_video.hawor_slam import hawor_slam
from hawor.utils.process import get_mano_faces, run_mano, run_mano_left
from lib.eval_utils.custom_utils import load_slam_cam
from lib.vis.run_vis2 import run_vis2_on_video, run_vis2_on_video_cam


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_focal", type=float)
    parser.add_argument("--video_path", type=str, default='example/video_0.mp4')
    parser.add_argument("--input_type", type=str, default='file')
    parser.add_argument("--checkpoint", type=str, default='./weights/hawor/checkpoints/hawor.ckpt')
    parser.add_argument("--infiller_weight", type=str, default='./weights/hawor/checkpoints/infiller.pt')
    parser.add_argument("--vis_mode", type=str, default='world', help='cam | world')
    args = parser.parse_args()

    start_idx, end_idx, seq_folder, imgfiles = detect_track_video(args)

    frame_chunks_all, img_focal = hawor_motion_estimation(args, start_idx, end_idx, seq_folder)

    hawor_slam(args, start_idx, end_idx)
    slam_path = os.path.join(seq_folder, f"SLAM/hawor_slam_w_scale_{start_idx}_{end_idx}.npz")
    R_w2c_sla_all, t_w2c_sla_all, R_c2w_sla_all, t_c2w_sla_all = load_slam_cam(slam_path)

    pred_trans, pred_rot, pred_hand_pose, pred_betas, pred_valid = hawor_infiller(args, start_idx, end_idx, frame_chunks_all)

    # vis sequence for this video
    hand2idx = {
        "right": 1,
        "left": 0
    }
    vis_start = 0
    vis_end = pred_trans.shape[1] - 1

    # get faces
    faces = get_mano_faces()
    faces_new = np.array([[92, 38, 234],
                          [234, 38, 239],
                          [38, 122, 239],
                          [239, 122, 279],
                          [122, 118, 279],
                          [279, 118, 215],
                          [118, 117, 215],
                          [215, 117, 214],
                          [117, 119, 214],
                          [214, 119, 121],
                          [119, 120, 121],
                          [121, 120, 78],
                          [120, 108, 78],
                          [78, 108, 79]])
    faces_right = np.concatenate([faces, faces_new], axis=0)

    # get right hand vertices
    hand = 'right'
    hand_idx = hand2idx[hand]
    pred_glob_r = run_mano(pred_trans[hand_idx:hand_idx+1, vis_start:vis_end], pred_rot[hand_idx:hand_idx+1, vis_start:vis_end], pred_hand_pose[hand_idx:hand_idx+1, vis_start:vis_end], betas=pred_betas[hand_idx:hand_idx+1, vis_start:vis_end])
    right_verts = pred_glob_r['vertices'][0]
    right_dict = {
        'vertices': right_verts.unsqueeze(0),
        'faces': faces_right,
    }

    # get left hand vertices
    faces_left = faces_right[:,[0,2,1]]
    hand = 'left'
    hand_idx = hand2idx[hand]
    pred_glob_l = run_mano_left(pred_trans[hand_idx:hand_idx+1, vis_start:vis_end], pred_rot[hand_idx:hand_idx+1, vis_start:vis_end], pred_hand_pose[hand_idx:hand_idx+1, vis_start:vis_end], betas=pred_betas[hand_idx:hand_idx+1, vis_start:vis_end])
    left_verts = pred_glob_l['vertices'][0]
    left_dict = {
        'vertices': left_verts.unsqueeze(0),
        'faces': faces_left,
    }

    R_x = torch.tensor([[1, 0, 0],
                        [0, -1, 0],
                        [0, 0, -1]]).float()
    R_c2w_sla_all = torch.einsum('ij,njk->nik', R_x, R_c2w_sla_all)
    t_c2w_sla_all = torch.einsum('ij,nj->ni', R_x, t_c2w_sla_all)
    R_w2c_sla_all = R_c2w_sla_all.transpose(-1, -2)
    t_w2c_sla_all = -torch.einsum("bij,bj->bi", R_w2c_sla_all, t_c2w_sla_all)
    left_dict['vertices'] = torch.einsum('ij,btnj->btni', R_x, left_dict['vertices'].cpu())
    right_dict['vertices'] = torch.einsum('ij,btnj->btni', R_x, right_dict['vertices'].cpu())

    # Here we use aitviewer (https://github.com/eth-ait/aitviewer) for simple visualization.
    if args.vis_mode == 'world':
        output_pth = os.path.join(seq_folder, f"vis_{vis_start}_{vis_end}")
        if not os.path.exists(output_pth):
            os.makedirs(output_pth)
        image_names = imgfiles[vis_start:vis_end]
        print(f"vis {vis_start} to {vis_end}")
        run_vis2_on_video(left_dict, right_dict, output_pth, img_focal, image_names, R_c2w=R_c2w_sla_all[vis_start:vis_end], t_c2w=t_c2w_sla_all[vis_start:vis_end])
    elif args.vis_mode == 'cam':
        output_pth = os.path.join(seq_folder, f"vis_{vis_start}_{vis_end}")
        if not os.path.exists(output_pth):
            os.makedirs(output_pth)
        image_names = imgfiles[vis_start:vis_end]
        print(f"vis {vis_start} to {vis_end}")
        run_vis2_on_video_cam(left_dict, right_dict, output_pth, img_focal, image_names, R_w2c=R_w2c_sla_all[vis_start:vis_end], t_w2c=t_w2c_sla_all[vis_start:vis_end])

    print("finish")
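`demo.py` above flips the SLAM world frame with `R_x` and then derives world-to-camera extrinsics from the camera-to-world ones. The following standalone sketch (illustrative only, random inputs) restates that relationship and its round trip:

```python
# Sketch of the extrinsics inversion used in demo.py (illustrative only).
import torch

n = 4
R_c2w = torch.linalg.qr(torch.randn(n, 3, 3)).Q        # random orthonormal camera-to-world rotations
t_c2w = torch.randn(n, 3)                              # camera positions in world coordinates

R_w2c = R_c2w.transpose(-1, -2)                        # inverse of an orthonormal matrix is its transpose
t_w2c = -torch.einsum("bij,bj->bi", R_w2c, t_c2w)      # so that x_cam = R_w2c @ x_world + t_w2c

# Round trip: a world point mapped to camera coordinates and back is unchanged.
x_world = torch.randn(n, 3)
x_cam = torch.einsum("bij,bj->bi", R_w2c, x_world) + t_w2c
x_back = torch.einsum("bij,bj->bi", R_c2w, x_cam) + t_c2w
print(torch.allclose(x_world, x_back, atol=1e-5))      # expected: True
```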
example/video_0.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:13ff124a68e4b48190e0c3f0ce9f38db59c5e3bb8a093b3c7fc9c67276be2062
size 6515891
hawor/configs/__init__.py
ADDED
@@ -0,0 +1,120 @@
import os
from typing import Dict
from yacs.config import CfgNode as CN

CACHE_DIR_HAWOR = "./_DATA"

def to_lower(x: Dict) -> Dict:
    """
    Convert all dictionary keys to lowercase
    Args:
      x (dict): Input dictionary
    Returns:
      dict: Output dictionary with all keys converted to lowercase
    """
    return {k.lower(): v for k, v in x.items()}

_C = CN(new_allowed=True)

_C.GENERAL = CN(new_allowed=True)
_C.GENERAL.RESUME = True
_C.GENERAL.TIME_TO_RUN = 3300
_C.GENERAL.VAL_STEPS = 100
_C.GENERAL.LOG_STEPS = 100
_C.GENERAL.CHECKPOINT_STEPS = 20000
_C.GENERAL.CHECKPOINT_DIR = "checkpoints"
_C.GENERAL.SUMMARY_DIR = "tensorboard"
_C.GENERAL.NUM_GPUS = 1
_C.GENERAL.NUM_WORKERS = 4
_C.GENERAL.MIXED_PRECISION = True
_C.GENERAL.ALLOW_CUDA = True
_C.GENERAL.PIN_MEMORY = False
_C.GENERAL.DISTRIBUTED = False
_C.GENERAL.LOCAL_RANK = 0
_C.GENERAL.USE_SYNCBN = False
_C.GENERAL.WORLD_SIZE = 1

_C.TRAIN = CN(new_allowed=True)
_C.TRAIN.NUM_EPOCHS = 100
_C.TRAIN.BATCH_SIZE = 32
_C.TRAIN.SHUFFLE = True
_C.TRAIN.WARMUP = False
_C.TRAIN.NORMALIZE_PER_IMAGE = False
_C.TRAIN.CLIP_GRAD = False
_C.TRAIN.CLIP_GRAD_VALUE = 1.0
_C.LOSS_WEIGHTS = CN(new_allowed=True)

_C.DATASETS = CN(new_allowed=True)

_C.MODEL = CN(new_allowed=True)
_C.MODEL.IMAGE_SIZE = 224

_C.EXTRA = CN(new_allowed=True)
_C.EXTRA.FOCAL_LENGTH = 5000

_C.DATASETS.CONFIG = CN(new_allowed=True)
_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
_C.DATASETS.CONFIG.ROT_FACTOR = 30
_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
_C.DATASETS.CONFIG.COLOR_SCALE = 0.2
_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
_C.DATASETS.CONFIG.DO_FLIP = False
_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10

def default_config() -> CN:
    """
    Get a yacs CfgNode object with the default config values.
    """
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _C.clone()

def dataset_config() -> CN:
    """
    Get dataset config file
    Returns:
      CfgNode: Dataset config as a yacs CfgNode object.
    """
    cfg = CN(new_allowed=True)
    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_tar.yaml')
    cfg.merge_from_file(config_file)
    cfg.freeze()
    return cfg

def dataset_eval_config() -> CN:
    cfg = CN(new_allowed=True)
    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_eval.yaml')
    cfg.merge_from_file(config_file)
    cfg.freeze()
    return cfg

def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
    """
    Read a config file and optionally merge it with the default config file.
    Args:
      config_file (str): Path to config file.
      merge (bool): Whether to merge with the default config or not.
    Returns:
      CfgNode: Config as a yacs CfgNode object.
    """
    if merge:
        cfg = default_config()
    else:
        cfg = CN(new_allowed=True)
    cfg.merge_from_file(config_file)

    if update_cachedir:
        def update_path(path: str) -> str:
            if os.path.basename(CACHE_DIR_HAWOR) in path:
                return path
            if os.path.isabs(path):
                return path
            return os.path.join(CACHE_DIR_HAWOR, path)

        cfg.MANO.MODEL_PATH = update_path(cfg.MANO.MODEL_PATH)
        cfg.MANO.MEAN_PARAMS = update_path(cfg.MANO.MEAN_PARAMS)

    cfg.freeze()
    return cfg
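For illustration, the helpers above could be used to load the model config downloaded in the README. The contents of `model_config.yaml` are not part of this view, so the `MODEL.*` and `MANO.*` accesses below are assumptions based only on the defaults and on `update_path()`:

```python
# Illustrative use of the config helpers (assumes ./weights/hawor/model_config.yaml from the README).
from hawor.configs import get_config

cfg = get_config("./weights/hawor/model_config.yaml", merge=True, update_cachedir=True)
print(cfg.MODEL.IMAGE_SIZE)   # 224 unless overridden by the file
print(cfg.MANO.MODEL_PATH)    # relative paths are rewritten to live under ./_DATA
```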
hawor/configs/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (3.55 kB)
hawor/utils/__pycache__/geometry.cpython-310.pyc
ADDED
Binary file (4.09 kB)
hawor/utils/__pycache__/process.cpython-310.pyc
ADDED
Binary file (5.54 kB)
hawor/utils/__pycache__/pylogger.cpython-310.pyc
ADDED
Binary file (655 Bytes)
hawor/utils/__pycache__/render_openpose.cpython-310.pyc
ADDED
Binary file (7.24 kB)
hawor/utils/__pycache__/rotation.cpython-310.pyc
ADDED
Binary file (7.65 kB)
hawor/utils/geometry.py
ADDED
@@ -0,0 +1,102 @@
from typing import Optional
import torch
from torch.nn import functional as F

def aa_to_rotmat(theta: torch.Tensor):
    """
    Convert axis-angle representation to rotation matrix.
    Works by first converting it to a quaternion.
    Args:
        theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
    angle = torch.unsqueeze(norm, -1)
    normalized = torch.div(theta, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
    return quat_to_rotmat(quat)

def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternion representation to rotation matrix.
    Args:
        quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    norm_quat = quat
    norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w*x, w*y, w*z
    xy, xz, yz = x*y, x*z, y*z

    rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
                          2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
                          2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
    return rotMat


def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
    """
    Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Args:
        x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
    Returns:
        torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
    """
    x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
    b3 = torch.linalg.cross(b1, b2)
    return torch.stack((b1, b2, b3), dim=-1)

def perspective_projection(points: torch.Tensor,
                           translation: torch.Tensor,
                           focal_length: torch.Tensor,
                           camera_center: Optional[torch.Tensor] = None,
                           rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Computes the perspective projection of a set of 3D points.
    Args:
        points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
        translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
        focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
        camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
        rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
    Returns:
        torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
    """
    batch_size = points.shape[0]
    if rotation is None:
        rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
    # Populate intrinsic camera matrix K.
    K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
    K[:,0,0] = focal_length[:,0]
    K[:,1,1] = focal_length[:,1]
    K[:,2,2] = 1.
    K[:,:-1, -1] = camera_center

    # Transform points
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Apply perspective distortion
    projected_points = points / points[:,:,-1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)

    return projected_points[:, :, :-1]
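A shape-level illustration of `perspective_projection` and `aa_to_rotmat` above, using random data only (the 778 points stand in for a MANO vertex set; focal length and camera center values are arbitrary):

```python
# Illustrative shapes for the geometry helpers (random data, not repo outputs).
import torch
from hawor.utils.geometry import aa_to_rotmat, perspective_projection

B, N = 2, 778
points = torch.randn(B, N, 3) * 0.1
translation = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.2]])   # keep points in front of the camera
focal_length = torch.full((B, 2), 1000.0)                        # pixels
camera_center = torch.full((B, 2), 112.0)                        # e.g. half of a 224x224 crop

proj = perspective_projection(points, translation, focal_length, camera_center=camera_center)
print(proj.shape)                     # torch.Size([2, 778, 2])

R = aa_to_rotmat(torch.zeros(B, 3))   # zero axis-angle -> (near-)identity rotations
print(R.shape)                        # torch.Size([2, 3, 3])
```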
hawor/utils/process.py
ADDED
@@ -0,0 +1,198 @@
import torch
from lib.models.mano_wrapper import MANO
from hawor.utils.geometry import aa_to_rotmat
import numpy as np
import sys
import os

def block_print():
    sys.stdout = open(os.devnull, 'w')

def enable_print():
    sys.stdout = sys.__stdout__

def get_mano_faces():
    block_print()
    MANO_cfg = {
        'DATA_DIR': '_DATA/data/',
        'MODEL_PATH': '_DATA/data/mano',
        'GENDER': 'neutral',
        'NUM_HAND_JOINTS': 15,
        'CREATE_BODY_POSE': False
    }
    mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
    mano = MANO(**mano_cfg)
    enable_print()
    return mano.faces


def run_mano(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True):
    """
    Forward pass of the MANO model and populates pred_data accordingly with
    joints3d, verts3d, points3d.

    trans : B x T x 3
    root_orient : B x T x 3
    hand_pose : B x T x J*3
    betas : (optional) B x D
    """
    block_print()
    MANO_cfg = {
        'DATA_DIR': '_DATA/data/',
        'MODEL_PATH': '_DATA/data/mano',
        'GENDER': 'neutral',
        'NUM_HAND_JOINTS': 15,
        'CREATE_BODY_POSE': False
    }
    mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
    mano = MANO(**mano_cfg)
    if use_cuda:
        mano = mano.cuda()

    B, T, _ = root_orient.shape
    NUM_JOINTS = 15
    mano_params = {
        'global_orient': root_orient.reshape(B*T, -1),
        'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
        'betas': betas.reshape(B*T, -1),
    }
    rotmat_mano_params = mano_params
    rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
    rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
    rotmat_mano_params['transl'] = trans.reshape(B*T, 3)

    if use_cuda:
        mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
    else:
        mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)

    faces_right = mano.faces
    faces_new = np.array([[92, 38, 234],
                          [234, 38, 239],
                          [38, 122, 239],
                          [239, 122, 279],
                          [122, 118, 279],
                          [279, 118, 215],
                          [118, 117, 215],
                          [215, 117, 214],
                          [117, 119, 214],
                          [214, 119, 121],
                          [119, 120, 121],
                          [121, 120, 78],
                          [120, 108, 78],
                          [78, 108, 79]])
    faces_right = np.concatenate([faces_right, faces_new], axis=0)
    faces_n = len(faces_right)
    faces_left = faces_right[:,[0,2,1]]

    outputs = {
        "joints": mano_output.joints.reshape(B, T, -1, 3),
        "vertices": mano_output.vertices.reshape(B, T, -1, 3),
    }

    if not is_right is None:
        # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
        # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
        is_right = (is_right[:, :, 0].cpu().numpy() > 0)
        faces_result = np.zeros((B, T, faces_n, 3))
        faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
        faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
        faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
        outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))

    enable_print()
    return outputs

def run_mano_left(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True, fix_shapedirs=True):
    """
    Forward pass of the MANO model and populates pred_data accordingly with
    joints3d, verts3d, points3d.

    trans : B x T x 3
    root_orient : B x T x 3
    hand_pose : B x T x J*3
    betas : (optional) B x D
    """
    block_print()
    MANO_cfg = {
        'DATA_DIR': '_DATA/data_left/',
        'MODEL_PATH': '_DATA/data_left/mano_left',
        'GENDER': 'neutral',
        'NUM_HAND_JOINTS': 15,
        'CREATE_BODY_POSE': False,
        'is_rhand': False
    }
    mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
    mano = MANO(**mano_cfg)
    if use_cuda:
        mano = mano.cuda()

    # fix MANO shapedirs of the left hand bug (https://github.com/vchoutas/smplx/issues/48)
    if fix_shapedirs:
        mano.shapedirs[:, 0, :] *= -1

    B, T, _ = root_orient.shape
    NUM_JOINTS = 15
    mano_params = {
        'global_orient': root_orient.reshape(B*T, -1),
        'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
        'betas': betas.reshape(B*T, -1),
    }
    rotmat_mano_params = mano_params
    rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
    rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
    rotmat_mano_params['transl'] = trans.reshape(B*T, 3)

    if use_cuda:
        mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
    else:
        mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)

    faces_right = mano.faces
    faces_new = np.array([[92, 38, 234],
                          [234, 38, 239],
                          [38, 122, 239],
                          [239, 122, 279],
                          [122, 118, 279],
                          [279, 118, 215],
                          [118, 117, 215],
                          [215, 117, 214],
                          [117, 119, 214],
                          [214, 119, 121],
                          [119, 120, 121],
                          [121, 120, 78],
                          [120, 108, 78],
                          [78, 108, 79]])
    faces_right = np.concatenate([faces_right, faces_new], axis=0)
    faces_n = len(faces_right)
    faces_left = faces_right[:,[0,2,1]]

    outputs = {
        "joints": mano_output.joints.reshape(B, T, -1, 3),
        "vertices": mano_output.vertices.reshape(B, T, -1, 3),
    }

    if not is_right is None:
        # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
        # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
        is_right = (is_right[:, :, 0].cpu().numpy() > 0)
        faces_result = np.zeros((B, T, faces_n, 3))
        faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
        faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
        faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
        outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))

    enable_print()
    return outputs

def run_mano_twohands(init_trans, init_rot, init_hand_pose, is_right, init_betas, use_cuda=True, fix_shapedirs=True):
    outputs_left = run_mano_left(init_trans[0:1], init_rot[0:1], init_hand_pose[0:1], None, init_betas[0:1], use_cuda=use_cuda, fix_shapedirs=fix_shapedirs)
    outputs_right = run_mano(init_trans[1:2], init_rot[1:2], init_hand_pose[1:2], None, init_betas[1:2], use_cuda=use_cuda)
    outputs_two = {
        "vertices": torch.cat((outputs_left["vertices"], outputs_right["vertices"]), dim=0),
        "joints": torch.cat((outputs_left["joints"], outputs_right["joints"]), dim=0)
    }
    return outputs_two
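An illustrative call of `run_mano_twohands` above with dummy parameters. It requires the MANO files under `_DATA/` described in the README; the shapes follow the docstrings (index 0 = left hand, index 1 = right hand, T frames), and the 10 shape coefficients and 21 output joints are assumptions based on the standard MANO setup rather than anything shown in this commit:

```python
# Illustrative two-hand MANO forward pass (dummy zero parameters).
import torch
from hawor.utils.process import run_mano_twohands

T = 8
trans = torch.zeros(2, T, 3)            # wrist translations
rot = torch.zeros(2, T, 3)              # global orientation, axis-angle
hand_pose = torch.zeros(2, T, 15 * 3)   # 15 joints per hand, axis-angle
betas = torch.zeros(2, T, 10)           # MANO shape parameters (assumed 10-dim)

out = run_mano_twohands(trans, rot, hand_pose, None, betas, use_cuda=False, fix_shapedirs=True)
print(out["vertices"].shape)            # (2, T, 778, 3)
print(out["joints"].shape)              # joint count depends on the MANO wrapper
```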
hawor/utils/pylogger.py
ADDED
@@ -0,0 +1,17 @@
import logging

from pytorch_lightning.utilities import rank_zero_only


def get_pylogger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python command line logger."""

    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger
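Typical use of the helper above: a logger whose calls only emit on rank 0 in multi-GPU runs.

```python
# Minimal usage of get_pylogger.
from hawor.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
log.info("initialized HaWoR logger")
```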
hawor/utils/render_openpose.py
ADDED
@@ -0,0 +1,225 @@
"""
Render OpenPose keypoints.
Code was ported to Python from the official C++ implementation https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/utilities/keypoint.cpp
"""
import cv2
import math
import numpy as np
from typing import List, Tuple

def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]:
    """
    Compute rectangle enclosing keypoints above the threshold.
    Args:
        keypoints (np.array): Keypoint array of shape (N, 3).
        threshold (float): Confidence visualization threshold.
    Returns:
        Tuple[float, float, float]: Rectangle width, height and area.
    """
    valid_ind = keypoints[:, -1] > threshold
    if valid_ind.sum() > 0:
        valid_keypoints = keypoints[valid_ind][:, :-1]
        max_x = valid_keypoints[:,0].max()
        max_y = valid_keypoints[:,1].max()
        min_x = valid_keypoints[:,0].min()
        min_y = valid_keypoints[:,1].min()
        width = max_x - min_x
        height = max_y - min_y
        area = width * height
        return width, height, area
    else:
        return 0,0,0

def render_keypoints(img: np.array,
                     keypoints: np.array,
                     pairs: List,
                     colors: List,
                     thickness_circle_ratio: float,
                     thickness_line_ratio_wrt_circle: float,
                     pose_scales: List,
                     threshold: float = 0.1,
                     alpha: float = 1.0) -> np.array:
    """
    Render keypoints on input image.
    Args:
        img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
        keypoints (np.array): Keypoint array of shape (N, 3).
        pairs (List): List of keypoint pairs per limb.
        colors: (List): List of colors per keypoint.
        thickness_circle_ratio (float): Circle thickness ratio.
        thickness_line_ratio_wrt_circle (float): Line thickness ratio wrt the circle.
        pose_scales (List): List of pose scales.
        threshold (float): Only visualize keypoints with confidence above the threshold.
    Returns:
        (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
    """
    img_orig = img.copy()
    width, height = img.shape[1], img.shape[2]
    area = width * height

    lineType = 8
    shift = 0
    numberColors = len(colors)
    thresholdRectangle = 0.1

    person_width, person_height, person_area = get_keypoints_rectangle(keypoints, thresholdRectangle)
    if person_area > 0:
        ratioAreas = min(1, max(person_width / width, person_height / height))
        thicknessRatio = np.maximum(np.round(math.sqrt(area) * thickness_circle_ratio * ratioAreas), 2)
        thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio))
        thicknessLine = np.maximum(1, np.round(thicknessRatio * thickness_line_ratio_wrt_circle))
        radius = thicknessRatio / 2

        img = np.ascontiguousarray(img.copy())
        for i, pair in enumerate(pairs):
            index1, index2 = pair
            if keypoints[index1, -1] > threshold and keypoints[index2, -1] > threshold:
                thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * pose_scales[0]))
                colorIndex = index2
                color = colors[colorIndex % numberColors]
                keypoint1 = keypoints[index1, :-1].astype(int)
                keypoint2 = keypoints[index2, :-1].astype(int)
                cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), thicknessLineScaled, lineType, shift)
        for part in range(len(keypoints)):
            faceIndex = part
            if keypoints[faceIndex, -1] > threshold:
                radiusScaled = int(round(radius[faceIndex] * pose_scales[0]))
                thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * pose_scales[0]))
                colorIndex = part
                color = colors[colorIndex % numberColors]
                center = keypoints[faceIndex, :-1].astype(int)
                cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, lineType, shift)
    return img

def render_hand_keypoints(img, right_hand_keypoints, threshold=0.1, use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0):
    if use_confidence and map_fn is not None:
        #thicknessCircleRatioLeft = 1./50 * map_fn(left_hand_keypoints[:, -1])
        thicknessCircleRatioRight = 1./50 * map_fn(right_hand_keypoints[:, -1])
    else:
        #thicknessCircleRatioLeft = 1./50 * np.ones(left_hand_keypoints.shape[0])
        thicknessCircleRatioRight = 1./50 * np.ones(right_hand_keypoints.shape[0])
    thicknessLineRatioWRTCircle = 0.75
    pairs = [0,1, 1,2, 2,3, 3,4, 0,5, 5,6, 6,7, 7,8, 0,9, 9,10, 10,11, 11,12, 0,13, 13,14, 14,15, 15,16, 0,17, 17,18, 18,19, 19,20]
    pairs = np.array(pairs).reshape(-1,2)

    colors = [100., 100., 100.,
              100., 0., 0.,
              150., 0., 0.,
              200., 0., 0.,
              255., 0., 0.,
              100., 100., 0.,
              150., 150., 0.,
              200., 200., 0.,
              255., 255., 0.,
              0., 100., 50.,
              0., 150., 75.,
              0., 200., 100.,
              0., 255., 125.,
              0., 50., 100.,
              0., 75., 150.,
              0., 100., 200.,
              0., 125., 255.,
              100., 0., 100.,
              150., 0., 150.,
              200., 0., 200.,
              255., 0., 255.]
    colors = np.array(colors).reshape(-1,3)
    #colors = np.zeros_like(colors)
    poseScales = [1]
    #img = render_keypoints(img, left_hand_keypoints, pairs, colors, thicknessCircleRatioLeft, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
    img = render_keypoints(img, right_hand_keypoints, pairs, colors, thicknessCircleRatioRight, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
    #img = render_keypoints(img, right_hand_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
    return img

def render_hand_landmarks(img, right_hand_keypoints, threshold=0.1, use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0):
    if use_confidence and map_fn is not None:
        #thicknessCircleRatioLeft = 1./50 * map_fn(left_hand_keypoints[:, -1])
        thicknessCircleRatioRight = 1./50 * map_fn(right_hand_keypoints[:, -1])
    else:
        #thicknessCircleRatioLeft = 1./50 * np.ones(left_hand_keypoints.shape[0])
        thicknessCircleRatioRight = 1./50 * np.ones(right_hand_keypoints.shape[0])
    thicknessLineRatioWRTCircle = 0.75
    pairs = []
    pairs = np.array(pairs).reshape(-1,2)

    colors = [255, 0, 0]
    colors = np.array(colors).reshape(-1,3)
    #colors = np.zeros_like(colors)
    poseScales = [1]
    #img = render_keypoints(img, left_hand_keypoints, pairs, colors, thicknessCircleRatioLeft, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
    img = render_keypoints(img, right_hand_keypoints, pairs, colors, thicknessCircleRatioRight * 0.1, thicknessLineRatioWRTCircle * 0.1, poseScales, threshold, alpha=alpha)
    #img = render_keypoints(img, right_hand_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
    return img

def render_body_keypoints(img: np.array,
                          body_keypoints: np.array) -> np.array:
    """
    Render OpenPose body keypoints on input image.
    Args:
        img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
        body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
    Returns:
        (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
    """

    thickness_circle_ratio = 1./75. * np.ones(body_keypoints.shape[0])
    thickness_line_ratio_wrt_circle = 0.75
    pairs = []
    pairs = [1,8,1,2,1,5,2,3,3,4,5,6,6,7,8,9,9,10,10,11,8,12,12,13,13,14,1,0,0,15,15,17,0,16,16,18,14,19,19,20,14,21,11,22,22,23,11,24]
    pairs = np.array(pairs).reshape(-1,2)
    colors = [255., 0., 85.,
              255., 0., 0.,
              255., 85., 0.,
              255., 170., 0.,
              255., 255., 0.,
              170., 255., 0.,
              85., 255., 0.,
              0., 255., 0.,
              255., 0., 0.,
              0., 255., 85.,
              0., 255., 170.,
              0., 255., 255.,
              0., 170., 255.,
              0., 85., 255.,
              0., 0., 255.,
              255., 0., 170.,
              170., 0., 255.,
              255., 0., 255.,
              85., 0., 255.,
              0., 0., 255.,
              0., 0., 255.,
              0., 0., 255.,
              0., 255., 255.,
              0., 255., 255.,
              0., 255., 255.]
    colors = np.array(colors).reshape(-1,3)
    pose_scales = [1]
    return render_keypoints(img, body_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)

def render_openpose(img: np.array,
                    hand_keypoints: np.array) -> np.array:
    """
    Render keypoints in the OpenPose format on input image.
    Args:
        img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
        hand_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
    Returns:
        (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
    """
    #img = render_body_keypoints(img, body_keypoints)
    img = render_hand_keypoints(img, hand_keypoints)
    return img

def render_openpose_landmarks(img: np.array,
                              hand_keypoints: np.array) -> np.array:
    """
    Render keypoints in the OpenPose format on input image.
    Args:
        img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
        hand_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
    Returns:
        (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
    """
    #img = render_body_keypoints(img, body_keypoints)
    img = render_hand_landmarks(img, hand_keypoints)
    return img
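An illustrative call of `render_openpose` above, drawing 21 hand keypoints on a blank image. The keypoint values are random and only demonstrate the expected (21, 3) layout of (x, y, confidence):

```python
# Illustrative hand keypoint rendering on a blank image (random keypoints).
import numpy as np
from hawor.utils.render_openpose import render_openpose

img = np.zeros((256, 256, 3), dtype=np.uint8)
kps = np.concatenate([np.random.uniform(50, 200, (21, 2)),   # x, y in pixels
                      np.ones((21, 1))], axis=1)             # confidence
vis = render_openpose(img, kps)
print(vis.shape)  # (256, 256, 3)
```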
hawor/utils/rotation.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
|
7 |
+
"""
|
8 |
+
Taken from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py
|
9 |
+
Calculates the rotation matrices for a batch of rotation vectors
|
10 |
+
- param rot_vecs: torch.tensor (N, 3) array of N axis-angle vectors
|
11 |
+
- returns R: torch.tensor (N, 3, 3) rotation matrices
|
12 |
+
"""
|
13 |
+
batch_size = rot_vecs.shape[0]
|
14 |
+
device = rot_vecs.device
|
15 |
+
|
16 |
+
angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
|
17 |
+
rot_dir = rot_vecs / angle
|
18 |
+
|
19 |
+
cos = torch.unsqueeze(torch.cos(angle), dim=1)
|
20 |
+
sin = torch.unsqueeze(torch.sin(angle), dim=1)
|
21 |
+
|
22 |
+
# Bx1 arrays
|
23 |
+
rx, ry, rz = torch.split(rot_dir, 1, dim=1)
|
24 |
+
K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
|
25 |
+
|
26 |
+
zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
|
27 |
+
K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
|
28 |
+
(batch_size, 3, 3)
|
29 |
+
)
|
30 |
+
|
31 |
+
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
|
32 |
+
rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
|
33 |
+
return rot_mat
|
34 |
+
|
35 |
+
|
36 |
+
def quaternion_mul(q0, q1):
|
37 |
+
"""
|
38 |
+
EXPECTS WXYZ
|
39 |
+
:param q0 (*, 4)
|
40 |
+
:param q1 (*, 4)
|
41 |
+
"""
|
42 |
+
r0, r1 = q0[..., :1], q1[..., :1]
|
43 |
+
v0, v1 = q0[..., 1:], q1[..., 1:]
|
44 |
+
r = r0 * r1 - (v0 * v1).sum(dim=-1, keepdim=True)
|
45 |
+
v = r0 * v1 + r1 * v0 + torch.linalg.cross(v0, v1)
|
46 |
+
return torch.cat([r, v], dim=-1)
|
47 |
+
|
48 |
+
|
49 |
+
def quaternion_inverse(q, eps=1e-8):
|
50 |
+
"""
|
51 |
+
EXPECTS WXYZ
|
52 |
+
:param q (*, 4)
|
53 |
+
"""
|
54 |
+
conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
|
55 |
+
mag = torch.square(q).sum(dim=-1, keepdim=True) + eps
|
56 |
+
return conj / mag
|
57 |
+
|
58 |
+
|
59 |
+
def quaternion_slerp(t, q0, q1, eps=1e-8):
|
60 |
+
"""
|
61 |
+
:param t (*, 1) must be between 0 and 1
|
62 |
+
:param q0 (*, 4)
|
63 |
+
:param q1 (*, 4)
|
64 |
+
"""
|
65 |
+
dims = q0.shape[:-1]
|
66 |
+
t = t.view(*dims, 1)
|
67 |
+
|
68 |
+
q0 = F.normalize(q0, p=2, dim=-1)
|
69 |
+
q1 = F.normalize(q1, p=2, dim=-1)
|
70 |
+
dot = (q0 * q1).sum(dim=-1, keepdim=True)
|
71 |
+
|
72 |
+
# make sure we give the shortest rotation path (< 180d)
|
73 |
+
neg = dot < 0
|
74 |
+
q1 = torch.where(neg, -q1, q1)
|
75 |
+
dot = torch.where(neg, -dot, dot)
|
76 |
+
angle = torch.acos(dot)
|
77 |
+
|
78 |
+
# if angle is too small, just do linear interpolation
|
79 |
+
collin = torch.abs(dot) > 1 - eps
|
80 |
+
fac = 1 / torch.sin(angle)
|
81 |
+
w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac)
|
82 |
+
w1 = torch.where(collin, t, torch.sin(t * angle) * fac)
|
83 |
+
slerp = q0 * w0 + q1 * w1
|
84 |
+
return slerp
|
85 |
+
|
86 |
+
|
87 |
+
def rotation_matrix_to_angle_axis(rotation_matrix):
|
88 |
+
"""
|
89 |
+
This function is borrowed from https://github.com/kornia/kornia
|
90 |
+
|
91 |
+
Convert rotation matrix to Rodrigues vector
|
92 |
+
"""
|
93 |
+
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
|
94 |
+
aa = quaternion_to_angle_axis(quaternion)
|
95 |
+
aa[torch.isnan(aa)] = 0.0
|
96 |
+
return aa
|
97 |
+
|
98 |
+
|
99 |
+
def quaternion_to_angle_axis(quaternion):
|
100 |
+
"""
|
101 |
+
This function is borrowed from https://github.com/kornia/kornia
|
102 |
+
|
103 |
+
Convert quaternion vector to angle axis of rotation.
|
104 |
+
Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
|
105 |
+
|
106 |
+
:param quaternion (*, 4) expects WXYZ
|
107 |
+
:returns angle_axis (*, 3)
|
108 |
+
"""
|
109 |
+
# unpack input and compute conversion
|
110 |
+
q1 = quaternion[..., 1]
|
111 |
+
q2 = quaternion[..., 2]
|
112 |
+
q3 = quaternion[..., 3]
|
113 |
+
sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3
|
114 |
+
|
115 |
+
sin_theta = torch.sqrt(sin_squared_theta)
|
116 |
+
cos_theta = quaternion[..., 0]
|
117 |
+
two_theta = 2.0 * torch.where(
|
118 |
+
cos_theta < 0.0,
|
119 |
+
torch.atan2(-sin_theta, -cos_theta),
|
120 |
+
torch.atan2(sin_theta, cos_theta),
|
121 |
+
)
|
122 |
+
|
123 |
+
k_pos = two_theta / sin_theta
|
124 |
+
k_neg = 2.0 * torch.ones_like(sin_theta)
|
125 |
+
k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
|
126 |
+
|
127 |
+
angle_axis = torch.zeros_like(quaternion)[..., :3]
|
128 |
+
angle_axis[..., 0] += q1 * k
|
129 |
+
angle_axis[..., 1] += q2 * k
|
130 |
+
angle_axis[..., 2] += q3 * k
|
131 |
+
return angle_axis
|
132 |
+
|
133 |
+
|
134 |
+
def angle_axis_to_rotation_matrix(angle_axis):
|
135 |
+
"""
|
136 |
+
:param angle_axis (*, 3)
|
137 |
+
return (*, 3, 3)
|
138 |
+
"""
|
139 |
+
quat = angle_axis_to_quaternion(angle_axis)
|
140 |
+
return quaternion_to_rotation_matrix(quat)
|
141 |
+
|
142 |
+
|
143 |
+
def quaternion_to_rotation_matrix(quaternion):
|
144 |
+
"""
|
145 |
+
Convert a quaternion to a rotation matrix.
|
146 |
+
Taken from https://github.com/kornia/kornia, based on
|
147 |
+
https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
|
148 |
+
https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247
|
149 |
+
:param quaternion (N, 4) expects WXYZ order
|
150 |
+
returns rotation matrix (N, 3, 3)
|
151 |
+
"""
|
152 |
+
# normalize the input quaternion
|
153 |
+
quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-12)
|
154 |
+
*dims, _ = quaternion_norm.shape
|
155 |
+
|
156 |
+
# unpack the normalized quaternion components
|
157 |
+
w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1)
|
158 |
+
|
159 |
+
# compute the actual conversion
|
160 |
+
tx = 2.0 * x
|
161 |
+
ty = 2.0 * y
|
162 |
+
tz = 2.0 * z
|
163 |
+
twx = tx * w
|
164 |
+
twy = ty * w
|
165 |
+
twz = tz * w
|
166 |
+
txx = tx * x
|
167 |
+
txy = ty * x
|
168 |
+
txz = tz * x
|
169 |
+
tyy = ty * y
|
170 |
+
tyz = tz * y
|
171 |
+
tzz = tz * z
|
172 |
+
one = torch.tensor(1.0)
|
173 |
+
|
174 |
+
matrix = torch.stack(
|
175 |
+
(
|
176 |
+
one - (tyy + tzz),
|
177 |
+
txy - twz,
|
178 |
+
txz + twy,
|
179 |
+
txy + twz,
|
180 |
+
one - (txx + tzz),
|
181 |
+
tyz - twx,
|
182 |
+
txz - twy,
|
183 |
+
tyz + twx,
|
184 |
+
one - (txx + tyy),
|
185 |
+
),
|
186 |
+
dim=-1,
|
187 |
+
).view(*dims, 3, 3)
|
188 |
+
return matrix
|
189 |
+
|
190 |
+
|
191 |
+
def angle_axis_to_quaternion(angle_axis):
|
192 |
+
"""
|
193 |
+
This function is borrowed from https://github.com/kornia/kornia
|
194 |
+
Convert angle axis to quaternion in WXYZ order
|
195 |
+
:param angle_axis (*, 3)
|
196 |
+
:returns quaternion (*, 4) WXYZ order
|
197 |
+
"""
|
198 |
+
theta_sq = torch.sum(angle_axis**2, dim=-1, keepdim=True) # (*, 1)
|
199 |
+
# need to handle the zero rotation case
|
200 |
+
valid = theta_sq > 0
|
201 |
+
theta = torch.sqrt(theta_sq)
|
202 |
+
half_theta = 0.5 * theta
|
203 |
+
ones = torch.ones_like(half_theta)
|
204 |
+
# fill zero with the limit of sin ax / x -> a
|
205 |
+
k = torch.where(valid, torch.sin(half_theta) / theta, 0.5 * ones)
|
206 |
+
w = torch.where(valid, torch.cos(half_theta), ones)
|
207 |
+
quat = torch.cat([w, k * angle_axis], dim=-1)
|
208 |
+
return quat
|
209 |
+
|
210 |
+
|
211 |
+
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
|
212 |
+
"""
|
213 |
+
This function is borrowed from https://github.com/kornia/kornia
|
214 |
+
Convert rotation matrix to 4d quaternion vector
|
215 |
+
This algorithm is based on algorithm described in
|
216 |
+
https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
|
217 |
+
|
218 |
+
:param rotation_matrix (N, 3, 3)
|
219 |
+
"""
|
220 |
+
*dims, m, n = rotation_matrix.shape
|
221 |
+
rmat_t = torch.transpose(rotation_matrix.reshape(-1, m, n), -1, -2)
|
222 |
+
|
223 |
+
mask_d2 = rmat_t[:, 2, 2] < eps
|
224 |
+
|
225 |
+
mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
|
226 |
+
mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
|
227 |
+
|
228 |
+
t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
229 |
+
q0 = torch.stack(
|
230 |
+
[
|
231 |
+
rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
232 |
+
t0,
|
233 |
+
rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
234 |
+
rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
|
235 |
+
],
|
236 |
+
-1,
|
237 |
+
)
|
238 |
+
t0_rep = t0.repeat(4, 1).t()
|
239 |
+
|
240 |
+
t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
241 |
+
q1 = torch.stack(
|
242 |
+
[
|
243 |
+
rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
|
244 |
+
rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
245 |
+
t1,
|
246 |
+
rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
|
247 |
+
],
|
248 |
+
-1,
|
249 |
+
)
|
250 |
+
t1_rep = t1.repeat(4, 1).t()
|
251 |
+
|
252 |
+
t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
253 |
+
q2 = torch.stack(
|
254 |
+
[
|
255 |
+
rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
|
256 |
+
rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
|
257 |
+
rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
|
258 |
+
t2,
|
259 |
+
],
|
260 |
+
-1,
|
261 |
+
)
|
262 |
+
t2_rep = t2.repeat(4, 1).t()
|
263 |
+
|
264 |
+
t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
265 |
+
q3 = torch.stack(
|
266 |
+
[
|
267 |
+
t3,
|
268 |
+
rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
269 |
+
rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
|
270 |
+
rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
|
271 |
+
],
|
272 |
+
-1,
|
273 |
+
)
|
274 |
+
t3_rep = t3.repeat(4, 1).t()
|
275 |
+
|
276 |
+
mask_c0 = mask_d2 * mask_d0_d1
|
277 |
+
mask_c1 = mask_d2 * ~mask_d0_d1
|
278 |
+
mask_c2 = ~mask_d2 * mask_d0_nd1
|
279 |
+
mask_c3 = ~mask_d2 * ~mask_d0_nd1
|
280 |
+
mask_c0 = mask_c0.view(-1, 1).type_as(q0)
|
281 |
+
mask_c1 = mask_c1.view(-1, 1).type_as(q1)
|
282 |
+
mask_c2 = mask_c2.view(-1, 1).type_as(q2)
|
283 |
+
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
|
284 |
+
|
285 |
+
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
|
286 |
+
q /= torch.sqrt(
|
287 |
+
t0_rep * mask_c0
|
288 |
+
+ t1_rep * mask_c1
|
289 |
+
+ t2_rep * mask_c2 # noqa
|
290 |
+
+ t3_rep * mask_c3
|
291 |
+
) # noqa
|
292 |
+
q *= 0.5
|
293 |
+
return q.reshape(*dims, 4)
|
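Not part of the commit: a quick round-trip sanity check for the conversion helpers above (angle_axis_to_quaternion, quaternion_to_rotation_matrix, rotation_matrix_to_quaternion), assuming they are imported from this rotation module.

import torch

aa = torch.tensor([[0.3, -0.2, 0.1]])           # (N, 3) axis-angle
quat = angle_axis_to_quaternion(aa)             # (N, 4), WXYZ order
R = quaternion_to_rotation_matrix(quat)         # (N, 3, 3)
quat_back = rotation_matrix_to_quaternion(R)    # (N, 4), WXYZ order
# q and -q encode the same rotation, so compare up to sign
assert torch.allclose(quat.abs(), quat_back.abs(), atol=1e-4)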
imgui.ini
ADDED
@@ -0,0 +1,15 @@
[Window][Debug##Default]
Pos=60,60
Size=400,400
Collapsed=0

[Window][Editor]
Pos=50,50
Size=250,700
Collapsed=0

[Window][Playback]
Pos=50,800
Size=400,175
Collapsed=1
infiller/hand_utils/geometry.py
ADDED
@@ -0,0 +1,412 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def perspective_projection(points, rotation, translation,
|
7 |
+
focal_length, camera_center, distortion=None):
|
8 |
+
"""
|
9 |
+
This function computes the perspective projection of a set of points.
|
10 |
+
Input:
|
11 |
+
points (bs, N, 3): 3D points
|
12 |
+
rotation (bs, 3, 3): Camera rotation
|
13 |
+
translation (bs, 3): Camera translation
|
14 |
+
focal_length (bs,) or scalar: Focal length
|
15 |
+
camera_center (bs, 2): Camera center
|
16 |
+
"""
|
17 |
+
batch_size = points.shape[0]
|
18 |
+
|
19 |
+
# Extrinsic
|
20 |
+
if rotation is not None:
|
21 |
+
points = torch.einsum('bij,bkj->bki', rotation, points)
|
22 |
+
|
23 |
+
if translation is not None:
|
24 |
+
points = points + translation.unsqueeze(1)
|
25 |
+
|
26 |
+
if distortion is not None:
|
27 |
+
kc = distortion
|
28 |
+
points = points[:,:,:2] / points[:,:,2:]
|
29 |
+
|
30 |
+
r2 = points[:,:,0]**2 + points[:,:,1]**2
|
31 |
+
dx = (2 * kc[:,[2]] * points[:,:,0] * points[:,:,1]
|
32 |
+
+ kc[:,[3]] * (r2 + 2*points[:,:,0]**2))
|
33 |
+
|
34 |
+
dy = (2 * kc[:,[3]] * points[:,:,0] * points[:,:,1]
|
35 |
+
+ kc[:,[2]] * (r2 + 2*points[:,:,1]**2))
|
36 |
+
|
37 |
+
x = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,0] + dx
|
38 |
+
y = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,1] + dy
|
39 |
+
|
40 |
+
points = torch.stack([x, y, torch.ones_like(x)], dim=-1)
|
41 |
+
|
42 |
+
# Intrinsic
|
43 |
+
K = torch.zeros([batch_size, 3, 3], device=points.device)
|
44 |
+
K[:,0,0] = focal_length
|
45 |
+
K[:,1,1] = focal_length
|
46 |
+
K[:,2,2] = 1.
|
47 |
+
K[:,:-1, -1] = camera_center
|
48 |
+
|
49 |
+
# Apply camera intrinsics
|
50 |
+
points = points / points[:,:,-1].unsqueeze(-1)
|
51 |
+
projected_points = torch.einsum('bij,bkj->bki', K, points)
|
52 |
+
projected_points = projected_points[:, :, :-1]
|
53 |
+
|
54 |
+
return projected_points
|
55 |
+
|
56 |
+
|
57 |
+
def avg_rot(rot):
|
58 |
+
# input [B,...,3,3] --> output [...,3,3]
|
59 |
+
rot = rot.mean(dim=0)
|
60 |
+
U, _, V = torch.svd(rot)
|
61 |
+
rot = U @ V.transpose(-1, -2)
|
62 |
+
return rot
|
63 |
+
|
64 |
+
|
65 |
+
def rot9d_to_rotmat(x):
|
66 |
+
"""Convert 9D rotation representation to 3x3 rotation matrix.
|
67 |
+
Based on Levinson et al., "An Analysis of SVD for Deep Rotation Estimation"
|
68 |
+
Input:
|
69 |
+
(B,9) or (B,J*9) Batch of 9D rotation (interpreted as 3x3 est rotmat)
|
70 |
+
Output:
|
71 |
+
(B,3,3) or (B*J,3,3) Batch of corresponding rotation matrices
|
72 |
+
"""
|
73 |
+
x = x.view(-1,3,3)
|
74 |
+
u, _, vh = torch.linalg.svd(x)
|
75 |
+
|
76 |
+
sig = torch.eye(3).expand(len(x), 3, 3).clone()
|
77 |
+
sig = sig.to(x.device)
|
78 |
+
sig[:, -1, -1] = (u @ vh).det()
|
79 |
+
|
80 |
+
R = u @ sig @ vh
|
81 |
+
|
82 |
+
return R
|
83 |
+
|
84 |
+
|
85 |
+
"""
|
86 |
+
Deprecated in favor of: rotation_conversions.py
|
87 |
+
|
88 |
+
Useful geometric operations, e.g. differentiable Rodrigues formula
|
89 |
+
Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
|
90 |
+
"""
|
91 |
+
def batch_rodrigues(theta):
|
92 |
+
"""Convert axis-angle representation to rotation matrix.
|
93 |
+
Args:
|
94 |
+
theta: size = [B, 3]
|
95 |
+
Returns:
|
96 |
+
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
|
97 |
+
"""
|
98 |
+
l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
|
99 |
+
angle = torch.unsqueeze(l1norm, -1)
|
100 |
+
normalized = torch.div(theta, angle)
|
101 |
+
angle = angle * 0.5
|
102 |
+
v_cos = torch.cos(angle)
|
103 |
+
v_sin = torch.sin(angle)
|
104 |
+
quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
|
105 |
+
return quat_to_rotmat(quat)
|
106 |
+
|
107 |
+
def quat_to_rotmat(quat):
|
108 |
+
"""Convert quaternion coefficients to rotation matrix.
|
109 |
+
Args:
|
110 |
+
quat: size = [B, 4] 4 <===>(w, x, y, z)
|
111 |
+
Returns:
|
112 |
+
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
|
113 |
+
"""
|
114 |
+
norm_quat = quat
|
115 |
+
norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
|
116 |
+
w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
|
117 |
+
|
118 |
+
B = quat.size(0)
|
119 |
+
|
120 |
+
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
|
121 |
+
wx, wy, wz = w*x, w*y, w*z
|
122 |
+
xy, xz, yz = x*y, x*z, y*z
|
123 |
+
|
124 |
+
rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
|
125 |
+
2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
|
126 |
+
2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
|
127 |
+
return rotMat
|
128 |
+
|
129 |
+
def rot6d_to_rotmat(x):
|
130 |
+
"""Convert 6D rotation representation to 3x3 rotation matrix.
|
131 |
+
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
|
132 |
+
Input:
|
133 |
+
(B,6) Batch of 6-D rotation representations
|
134 |
+
Output:
|
135 |
+
(B,3,3) Batch of corresponding rotation matrices
|
136 |
+
"""
|
137 |
+
x = x.view(-1,3,2)
|
138 |
+
a1 = x[:, :, 0]
|
139 |
+
a2 = x[:, :, 1]
|
140 |
+
b1 = F.normalize(a1)
|
141 |
+
b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
|
142 |
+
b3 = torch.cross(b1, b2)
|
143 |
+
return torch.stack((b1, b2, b3), dim=-1)
|
144 |
+
|
145 |
+
def rot6d_to_rotmat_hmr2(x: torch.Tensor) -> torch.Tensor:
|
146 |
+
"""
|
147 |
+
Convert 6D rotation representation to 3x3 rotation matrix.
|
148 |
+
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
|
149 |
+
Args:
|
150 |
+
x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
|
151 |
+
Returns:
|
152 |
+
torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
|
153 |
+
"""
|
154 |
+
x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
|
155 |
+
a1 = x[:, :, 0]
|
156 |
+
a2 = x[:, :, 1]
|
157 |
+
b1 = F.normalize(a1)
|
158 |
+
b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
|
159 |
+
b3 = torch.cross(b1, b2)
|
160 |
+
return torch.stack((b1, b2, b3), dim=-1)
|
161 |
+
|
162 |
+
def rotmat_to_rot6d(rotmat):
|
163 |
+
""" Inverse function of the above.
|
164 |
+
Input:
|
165 |
+
(B,3,3) Batch of corresponding rotation matrices
|
166 |
+
Output:
|
167 |
+
(B,6) Batch of 6-D rotation representations
|
168 |
+
"""
|
169 |
+
# rot6d = rotmat[:, :, :2]
|
170 |
+
rot6d = rotmat[...,:2]
|
171 |
+
rot6d = rot6d.reshape(rot6d.size(0), -1)
|
172 |
+
return rot6d
|
173 |
+
|
174 |
+
|
175 |
+
def rotation_matrix_to_angle_axis(rotation_matrix):
|
176 |
+
"""
|
177 |
+
This function is borrowed from https://github.com/kornia/kornia
|
178 |
+
|
179 |
+
Convert 3x4 rotation matrix to Rodrigues vector
|
180 |
+
|
181 |
+
Args:
|
182 |
+
rotation_matrix (Tensor): rotation matrix.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
Tensor: Rodrigues vector transformation.
|
186 |
+
|
187 |
+
Shape:
|
188 |
+
- Input: :math:`(N, 3, 4)`
|
189 |
+
- Output: :math:`(N, 3)`
|
190 |
+
|
191 |
+
Example:
|
192 |
+
>>> input = torch.rand(2, 3, 4) # Nx4x4
|
193 |
+
>>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
|
194 |
+
"""
|
195 |
+
if rotation_matrix.shape[1:] == (3,3):
|
196 |
+
rot_mat = rotation_matrix.reshape(-1, 3, 3)
|
197 |
+
hom = torch.tensor([0, 0, 1], dtype=torch.float32,
|
198 |
+
device=rotation_matrix.device).reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1)
|
199 |
+
rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
|
200 |
+
|
201 |
+
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
|
202 |
+
aa = quaternion_to_angle_axis(quaternion)
|
203 |
+
aa[torch.isnan(aa)] = 0.0
|
204 |
+
return aa
|
205 |
+
|
206 |
+
|
207 |
+
def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
|
208 |
+
"""
|
209 |
+
This function is borrowed from https://github.com/kornia/kornia
|
210 |
+
|
211 |
+
Convert quaternion vector to angle axis of rotation.
|
212 |
+
|
213 |
+
Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
|
214 |
+
|
215 |
+
Args:
|
216 |
+
quaternion (torch.Tensor): tensor with quaternions.
|
217 |
+
|
218 |
+
Return:
|
219 |
+
torch.Tensor: tensor with angle axis of rotation.
|
220 |
+
|
221 |
+
Shape:
|
222 |
+
- Input: :math:`(*, 4)` where `*` means, any number of dimensions
|
223 |
+
- Output: :math:`(*, 3)`
|
224 |
+
|
225 |
+
Example:
|
226 |
+
>>> quaternion = torch.rand(2, 4) # Nx4
|
227 |
+
>>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
|
228 |
+
"""
|
229 |
+
if not torch.is_tensor(quaternion):
|
230 |
+
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
|
231 |
+
type(quaternion)))
|
232 |
+
|
233 |
+
if not quaternion.shape[-1] == 4:
|
234 |
+
raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
|
235 |
+
.format(quaternion.shape))
|
236 |
+
# unpack input and compute conversion
|
237 |
+
q1: torch.Tensor = quaternion[..., 1]
|
238 |
+
q2: torch.Tensor = quaternion[..., 2]
|
239 |
+
q3: torch.Tensor = quaternion[..., 3]
|
240 |
+
sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
|
241 |
+
|
242 |
+
sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
|
243 |
+
cos_theta: torch.Tensor = quaternion[..., 0]
|
244 |
+
two_theta: torch.Tensor = 2.0 * torch.where(
|
245 |
+
cos_theta < 0.0,
|
246 |
+
torch.atan2(-sin_theta, -cos_theta),
|
247 |
+
torch.atan2(sin_theta, cos_theta))
|
248 |
+
|
249 |
+
k_pos: torch.Tensor = two_theta / sin_theta
|
250 |
+
k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
|
251 |
+
k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
|
252 |
+
|
253 |
+
angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
|
254 |
+
angle_axis[..., 0] += q1 * k
|
255 |
+
angle_axis[..., 1] += q2 * k
|
256 |
+
angle_axis[..., 2] += q3 * k
|
257 |
+
return angle_axis
|
258 |
+
|
259 |
+
|
260 |
+
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
|
261 |
+
"""
|
262 |
+
This function is borrowed from https://github.com/kornia/kornia
|
263 |
+
|
264 |
+
Convert 3x4 rotation matrix to 4d quaternion vector
|
265 |
+
|
266 |
+
This algorithm is based on algorithm described in
|
267 |
+
https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
|
268 |
+
|
269 |
+
Args:
|
270 |
+
rotation_matrix (Tensor): the rotation matrix to convert.
|
271 |
+
|
272 |
+
Return:
|
273 |
+
Tensor: the rotation in quaternion
|
274 |
+
|
275 |
+
Shape:
|
276 |
+
- Input: :math:`(N, 3, 4)`
|
277 |
+
- Output: :math:`(N, 4)`
|
278 |
+
|
279 |
+
Example:
|
280 |
+
>>> input = torch.rand(4, 3, 4) # Nx3x4
|
281 |
+
>>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
|
282 |
+
"""
|
283 |
+
if not torch.is_tensor(rotation_matrix):
|
284 |
+
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
|
285 |
+
type(rotation_matrix)))
|
286 |
+
|
287 |
+
if len(rotation_matrix.shape) > 3:
|
288 |
+
raise ValueError(
|
289 |
+
"Input size must be a three dimensional tensor. Got {}".format(
|
290 |
+
rotation_matrix.shape))
|
291 |
+
if not rotation_matrix.shape[-2:] == (3, 4):
|
292 |
+
raise ValueError(
|
293 |
+
"Input size must be a N x 3 x 4 tensor. Got {}".format(
|
294 |
+
rotation_matrix.shape))
|
295 |
+
|
296 |
+
rmat_t = torch.transpose(rotation_matrix, 1, 2)
|
297 |
+
|
298 |
+
mask_d2 = rmat_t[:, 2, 2] < eps
|
299 |
+
|
300 |
+
mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
|
301 |
+
mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
|
302 |
+
|
303 |
+
t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
304 |
+
q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
305 |
+
t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
306 |
+
rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
|
307 |
+
t0_rep = t0.repeat(4, 1).t()
|
308 |
+
|
309 |
+
t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
310 |
+
q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
|
311 |
+
rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
312 |
+
t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
|
313 |
+
t1_rep = t1.repeat(4, 1).t()
|
314 |
+
|
315 |
+
t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
316 |
+
q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
|
317 |
+
rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
|
318 |
+
rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
|
319 |
+
t2_rep = t2.repeat(4, 1).t()
|
320 |
+
|
321 |
+
t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
322 |
+
q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
323 |
+
rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
|
324 |
+
rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
|
325 |
+
t3_rep = t3.repeat(4, 1).t()
|
326 |
+
|
327 |
+
mask_c0 = mask_d2 * mask_d0_d1
|
328 |
+
mask_c1 = mask_d2 * ~mask_d0_d1
|
329 |
+
mask_c2 = ~mask_d2 * mask_d0_nd1
|
330 |
+
mask_c3 = ~mask_d2 * ~mask_d0_nd1
|
331 |
+
mask_c0 = mask_c0.view(-1, 1).type_as(q0)
|
332 |
+
mask_c1 = mask_c1.view(-1, 1).type_as(q1)
|
333 |
+
mask_c2 = mask_c2.view(-1, 1).type_as(q2)
|
334 |
+
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
|
335 |
+
|
336 |
+
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
|
337 |
+
q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
|
338 |
+
t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
|
339 |
+
q *= 0.5
|
340 |
+
return q
|
341 |
+
|
342 |
+
|
343 |
+
def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000., img_size=224.):
|
344 |
+
"""
|
345 |
+
This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py
|
346 |
+
|
347 |
+
Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
|
348 |
+
Input:
|
349 |
+
S: (25, 3) 3D joint locations
|
350 |
+
joints: (25, 3) 2D joint locations and confidence
|
351 |
+
Returns:
|
352 |
+
(3,) camera translation vector
|
353 |
+
"""
|
354 |
+
|
355 |
+
num_joints = S.shape[0]
|
356 |
+
# focal length
|
357 |
+
f = np.array([focal_length,focal_length])
|
358 |
+
# optical center
|
359 |
+
center = np.array([img_size/2., img_size/2.])
|
360 |
+
|
361 |
+
# transformations
|
362 |
+
Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1)
|
363 |
+
XY = np.reshape(S[:,0:2],-1)
|
364 |
+
O = np.tile(center,num_joints)
|
365 |
+
F = np.tile(f,num_joints)
|
366 |
+
weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1)
|
367 |
+
|
368 |
+
# least squares
|
369 |
+
Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T
|
370 |
+
c = (np.reshape(joints_2d,-1)-O)*Z - F*XY
|
371 |
+
|
372 |
+
# weighted least squares
|
373 |
+
W = np.diagflat(weight2)
|
374 |
+
Q = np.dot(W,Q)
|
375 |
+
c = np.dot(W,c)
|
376 |
+
|
377 |
+
# square matrix
|
378 |
+
A = np.dot(Q.T,Q)
|
379 |
+
b = np.dot(Q.T,c)
|
380 |
+
|
381 |
+
# solution
|
382 |
+
trans = np.linalg.solve(A, b)
|
383 |
+
|
384 |
+
return trans
|
385 |
+
|
386 |
+
|
387 |
+
def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.):
|
388 |
+
"""Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
|
389 |
+
Input:
|
390 |
+
S: (B, 49, 3) 3D joint locations
|
391 |
+
joints: (B, 49, 3) 2D joint locations and confidence
|
392 |
+
Returns:
|
393 |
+
(B, 3) camera translation vectors
|
394 |
+
"""
|
395 |
+
|
396 |
+
device = S.device
|
397 |
+
# Use only joints 25:49 (GT joints)
|
398 |
+
S = S[:, -24:, :3].cpu().numpy()
|
399 |
+
joints_2d = joints_2d[:, -24:, :].cpu().numpy()
|
400 |
+
|
401 |
+
joints_conf = joints_2d[:, :, -1]
|
402 |
+
joints_2d = joints_2d[:, :, :-1]
|
403 |
+
trans = np.zeros((S.shape[0], 3), dtype=np.float32)
|
404 |
+
# Find the translation for each example in the batch
|
405 |
+
for i in range(S.shape[0]):
|
406 |
+
S_i = S[i]
|
407 |
+
joints_i = joints_2d[i]
|
408 |
+
conf_i = joints_conf[i]
|
409 |
+
trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size)
|
410 |
+
return torch.from_numpy(trans).to(device)
|
411 |
+
|
412 |
+
|
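Not part of the commit: an illustrative call of perspective_projection above with made-up camera values (identity rotation, zero translation, a 224x224 crop centred at 112 px).

import torch

B, N = 1, 21
points = torch.rand(B, N, 3) + torch.tensor([0., 0., 2.])   # keep z > 0 (in front of the camera)
rotation = torch.eye(3).expand(B, 3, 3)
translation = torch.zeros(B, 3)
camera_center = torch.full((B, 2), 112.)
proj = perspective_projection(points, rotation, translation,
                              focal_length=1000., camera_center=camera_center)
print(proj.shape)   # (1, 21, 2) pixel coordinates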
infiller/hand_utils/geometry_utils.py
ADDED
@@ -0,0 +1,102 @@
1 |
+
from typing import Optional
|
2 |
+
import torch
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
def aa_to_rotmat(theta: torch.Tensor):
|
6 |
+
"""
|
7 |
+
Convert axis-angle representation to rotation matrix.
|
8 |
+
Works by first converting it to a quaternion.
|
9 |
+
Args:
|
10 |
+
theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
|
11 |
+
Returns:
|
12 |
+
torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
|
13 |
+
"""
|
14 |
+
norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
|
15 |
+
angle = torch.unsqueeze(norm, -1)
|
16 |
+
normalized = torch.div(theta, angle)
|
17 |
+
angle = angle * 0.5
|
18 |
+
v_cos = torch.cos(angle)
|
19 |
+
v_sin = torch.sin(angle)
|
20 |
+
quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
|
21 |
+
return quat_to_rotmat(quat)
|
22 |
+
|
23 |
+
def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
|
24 |
+
"""
|
25 |
+
Convert quaternion representation to rotation matrix.
|
26 |
+
Args:
|
27 |
+
quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
|
28 |
+
Returns:
|
29 |
+
torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
|
30 |
+
"""
|
31 |
+
norm_quat = quat
|
32 |
+
norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
|
33 |
+
w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
|
34 |
+
|
35 |
+
B = quat.size(0)
|
36 |
+
|
37 |
+
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
|
38 |
+
wx, wy, wz = w*x, w*y, w*z
|
39 |
+
xy, xz, yz = x*y, x*z, y*z
|
40 |
+
|
41 |
+
rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
|
42 |
+
2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
|
43 |
+
2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
|
44 |
+
return rotMat
|
45 |
+
|
46 |
+
|
47 |
+
def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
|
48 |
+
"""
|
49 |
+
Convert 6D rotation representation to 3x3 rotation matrix.
|
50 |
+
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
|
51 |
+
Args:
|
52 |
+
x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
|
53 |
+
Returns:
|
54 |
+
torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
|
55 |
+
"""
|
56 |
+
x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
|
57 |
+
a1 = x[:, :, 0]
|
58 |
+
a2 = x[:, :, 1]
|
59 |
+
b1 = F.normalize(a1)
|
60 |
+
b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
|
61 |
+
b3 = torch.linalg.cross(b1, b2)
|
62 |
+
return torch.stack((b1, b2, b3), dim=-1)
|
63 |
+
|
64 |
+
def perspective_projection(points: torch.Tensor,
|
65 |
+
translation: torch.Tensor,
|
66 |
+
focal_length: torch.Tensor,
|
67 |
+
camera_center: Optional[torch.Tensor] = None,
|
68 |
+
rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
|
69 |
+
"""
|
70 |
+
Computes the perspective projection of a set of 3D points.
|
71 |
+
Args:
|
72 |
+
points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
|
73 |
+
translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
|
74 |
+
focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
|
75 |
+
camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
|
76 |
+
rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
|
77 |
+
Returns:
|
78 |
+
torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
|
79 |
+
"""
|
80 |
+
batch_size = points.shape[0]
|
81 |
+
if rotation is None:
|
82 |
+
rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
|
83 |
+
if camera_center is None:
|
84 |
+
camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
|
85 |
+
# Populate intrinsic camera matrix K.
|
86 |
+
K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
|
87 |
+
K[:,0,0] = focal_length[:,0]
|
88 |
+
K[:,1,1] = focal_length[:,1]
|
89 |
+
K[:,2,2] = 1.
|
90 |
+
K[:,:-1, -1] = camera_center
|
91 |
+
|
92 |
+
# Transform points
|
93 |
+
points = torch.einsum('bij,bkj->bki', rotation, points)
|
94 |
+
points = points + translation.unsqueeze(1)
|
95 |
+
|
96 |
+
# Apply perspective distortion
|
97 |
+
projected_points = points / points[:,:,-1].unsqueeze(-1)
|
98 |
+
|
99 |
+
# Apply camera intrinsics
|
100 |
+
projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
|
101 |
+
|
102 |
+
return projected_points[:, :, :-1]
|
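Not part of the commit: a hedged sketch of the 6D rotation convention used by rot6d_to_rotmat above; the six numbers are the first two columns of the target matrix, which the function re-orthonormalizes (Gram-Schmidt plus a cross product for the third column).

import torch

R = aa_to_rotmat(torch.tensor([[0.2, 0.1, -0.3]]))       # (1, 3, 3) reference rotation
rot6d = torch.cat([R[:, :, 0], R[:, :, 1]], dim=-1)      # (1, 6): columns 0 and 1
R_back = rot6d_to_rotmat(rot6d)                           # (1, 3, 3)
assert torch.allclose(R, R_back, atol=1e-4)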
infiller/hand_utils/mano_wrapper.py
ADDED
@@ -0,0 +1,52 @@
import torch
import numpy as np
import pickle
from typing import Optional
import smplx
from smplx.lbs import vertices2joints
from smplx.utils import MANOOutput, to_tensor
from smplx.vertex_ids import vertex_ids


class MANO(smplx.MANOLayer):
    def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
        """
        Extension of the official MANO implementation to support more joints.
        Args:
            Same as MANOLayer.
            joint_regressor_extra (str): Path to extra joint regressor.
        """
        super(MANO, self).__init__(*args, **kwargs)
        mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]

        # 2, 3, 5, 4, 1
        if joint_regressor_extra is not None:
            self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
        self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
        self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))

    def forward(self, *args, **kwargs) -> MANOOutput:
        """
        Run forward pass. Same as MANO and also append an extra set of joints if joint_regressor_extra is specified.
        """
        mano_output = super(MANO, self).forward(*args, **kwargs)
        extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs)
        joints = torch.cat([mano_output.joints, extra_joints], dim=1)
        joints = joints[:, self.joint_map, :]
        if hasattr(self, 'joint_regressor_extra'):
            extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)
        mano_output.joints = joints
        return mano_output

    def query(self, hmr_output):
        batch_size = hmr_output['pred_rotmat'].shape[0]
        pred_rotmat = hmr_output['pred_rotmat'].reshape(batch_size, -1, 3, 3)
        pred_shape = hmr_output['pred_shape'].reshape(batch_size, 10)

        mano_output = self(global_orient=pred_rotmat[:, [0]],
                           hand_pose=pred_rotmat[:, 1:],
                           betas=pred_shape,
                           pose2rot=False)

        return mano_output
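Not part of the commit: a minimal usage sketch for the wrapper above, mirroring the config that process.py (below) passes to it, and assuming the MANO pickle under _DATA/data/mano is available.

import torch

MANO_cfg = {'DATA_DIR': '_DATA/data/', 'MODEL_PATH': '_DATA/data/mano',
            'GENDER': 'neutral', 'NUM_HAND_JOINTS': 15, 'CREATE_BODY_POSE': False}
mano = MANO(**{k.lower(): v for k, v in MANO_cfg.items()})

B = 1
out = mano(global_orient=torch.eye(3).repeat(B, 1, 1, 1),   # identity root rotation, (B, 1, 3, 3)
           hand_pose=torch.eye(3).repeat(B, 15, 1, 1),      # flat hand, (B, 15, 3, 3)
           betas=torch.zeros(B, 10),
           pose2rot=False)
print(out.vertices.shape, out.joints.shape)   # (B, 778, 3), (B, 21, 3)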
infiller/hand_utils/process.py
ADDED
@@ -0,0 +1,171 @@
1 |
+
import torch
|
2 |
+
from hand_utils.mano_wrapper import MANO
|
3 |
+
from hand_utils.geometry_utils import aa_to_rotmat
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
def run_mano(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True):
|
7 |
+
"""
|
8 |
+
Forward pass of the MANO model and populates pred_data accordingly with
|
9 |
+
joints3d, verts3d, points3d.
|
10 |
+
|
11 |
+
trans : B x T x 3
|
12 |
+
root_orient : B x T x 3
|
13 |
+
hand_pose : B x T x J*3
|
14 |
+
betas : (optional) B x D
|
15 |
+
"""
|
16 |
+
MANO_cfg = {
|
17 |
+
'DATA_DIR': '_DATA/data/',
|
18 |
+
'MODEL_PATH': '_DATA/data/mano',
|
19 |
+
'GENDER': 'neutral',
|
20 |
+
'NUM_HAND_JOINTS': 15,
|
21 |
+
'CREATE_BODY_POSE': False
|
22 |
+
}
|
23 |
+
mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
|
24 |
+
mano = MANO(**mano_cfg)
|
25 |
+
if use_cuda:
|
26 |
+
mano = mano.cuda()
|
27 |
+
|
28 |
+
B, T, _ = root_orient.shape
|
29 |
+
NUM_JOINTS = 15
|
30 |
+
mano_params = {
|
31 |
+
'global_orient': root_orient.reshape(B*T, -1),
|
32 |
+
'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
|
33 |
+
'betas': betas.reshape(B*T, -1),
|
34 |
+
}
|
35 |
+
rotmat_mano_params = mano_params
|
36 |
+
rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
|
37 |
+
rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
|
38 |
+
rotmat_mano_params['transl'] = trans.reshape(B*T, 3)
|
39 |
+
|
40 |
+
if use_cuda:
|
41 |
+
mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
|
42 |
+
else:
|
43 |
+
mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)
|
44 |
+
|
45 |
+
faces_right = mano.faces
|
46 |
+
faces_new = np.array([[92, 38, 234],
|
47 |
+
[234, 38, 239],
|
48 |
+
[38, 122, 239],
|
49 |
+
[239, 122, 279],
|
50 |
+
[122, 118, 279],
|
51 |
+
[279, 118, 215],
|
52 |
+
[118, 117, 215],
|
53 |
+
[215, 117, 214],
|
54 |
+
[117, 119, 214],
|
55 |
+
[214, 119, 121],
|
56 |
+
[119, 120, 121],
|
57 |
+
[121, 120, 78],
|
58 |
+
[120, 108, 78],
|
59 |
+
[78, 108, 79]])
|
60 |
+
faces_right = np.concatenate([faces_right, faces_new], axis=0)
|
61 |
+
faces_n = len(faces_right)
|
62 |
+
faces_left = faces_right[:,[0,2,1]]
|
63 |
+
|
64 |
+
outputs = {
|
65 |
+
"joints": mano_output.joints.reshape(B, T, -1, 3),
|
66 |
+
"vertices": mano_output.vertices.reshape(B, T, -1, 3),
|
67 |
+
}
|
68 |
+
|
69 |
+
if not is_right is None:
|
70 |
+
# outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
|
71 |
+
# outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
|
72 |
+
is_right = (is_right[:, :, 0].cpu().numpy() > 0)
|
73 |
+
faces_result = np.zeros((B, T, faces_n, 3))
|
74 |
+
faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
|
75 |
+
faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
|
76 |
+
faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
|
77 |
+
outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))
|
78 |
+
|
79 |
+
|
80 |
+
return outputs
|
81 |
+
|
82 |
+
def run_mano_left(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True, fix_shapedirs=True):
|
83 |
+
"""
|
84 |
+
Forward pass of the MANO model and populates pred_data accordingly with
|
85 |
+
joints3d, verts3d, points3d.
|
86 |
+
|
87 |
+
trans : B x T x 3
|
88 |
+
root_orient : B x T x 3
|
89 |
+
hand_pose : B x T x J*3
|
90 |
+
betas : (optional) B x D
|
91 |
+
"""
|
92 |
+
MANO_cfg = {
|
93 |
+
'DATA_DIR': '_DATA/data_left/',
|
94 |
+
'MODEL_PATH': '_DATA/data_left/mano_left',
|
95 |
+
'GENDER': 'neutral',
|
96 |
+
'NUM_HAND_JOINTS': 15,
|
97 |
+
'CREATE_BODY_POSE': False,
|
98 |
+
'is_rhand': False
|
99 |
+
}
|
100 |
+
mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()}
|
101 |
+
mano = MANO(**mano_cfg)
|
102 |
+
if use_cuda:
|
103 |
+
mano = mano.cuda()
|
104 |
+
|
105 |
+
# fix MANO shapedirs of the left hand bug (https://github.com/vchoutas/smplx/issues/48)
|
106 |
+
if fix_shapedirs:
|
107 |
+
mano.shapedirs[:, 0, :] *= -1
|
108 |
+
|
109 |
+
B, T, _ = root_orient.shape
|
110 |
+
NUM_JOINTS = 15
|
111 |
+
mano_params = {
|
112 |
+
'global_orient': root_orient.reshape(B*T, -1),
|
113 |
+
'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3),
|
114 |
+
'betas': betas.reshape(B*T, -1),
|
115 |
+
}
|
116 |
+
rotmat_mano_params = mano_params
|
117 |
+
rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3)
|
118 |
+
rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3)
|
119 |
+
rotmat_mano_params['transl'] = trans.reshape(B*T, 3)
|
120 |
+
|
121 |
+
if use_cuda:
|
122 |
+
mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False)
|
123 |
+
else:
|
124 |
+
mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False)
|
125 |
+
|
126 |
+
faces_right = mano.faces
|
127 |
+
faces_new = np.array([[92, 38, 234],
|
128 |
+
[234, 38, 239],
|
129 |
+
[38, 122, 239],
|
130 |
+
[239, 122, 279],
|
131 |
+
[122, 118, 279],
|
132 |
+
[279, 118, 215],
|
133 |
+
[118, 117, 215],
|
134 |
+
[215, 117, 214],
|
135 |
+
[117, 119, 214],
|
136 |
+
[214, 119, 121],
|
137 |
+
[119, 120, 121],
|
138 |
+
[121, 120, 78],
|
139 |
+
[120, 108, 78],
|
140 |
+
[78, 108, 79]])
|
141 |
+
faces_right = np.concatenate([faces_right, faces_new], axis=0)
|
142 |
+
faces_n = len(faces_right)
|
143 |
+
faces_left = faces_right[:,[0,2,1]]
|
144 |
+
|
145 |
+
outputs = {
|
146 |
+
"joints": mano_output.joints.reshape(B, T, -1, 3),
|
147 |
+
"vertices": mano_output.vertices.reshape(B, T, -1, 3),
|
148 |
+
}
|
149 |
+
|
150 |
+
if not is_right is None:
|
151 |
+
# outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0]
|
152 |
+
# outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0]
|
153 |
+
is_right = (is_right[:, :, 0].cpu().numpy() > 0)
|
154 |
+
faces_result = np.zeros((B, T, faces_n, 3))
|
155 |
+
faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0)
|
156 |
+
faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0)
|
157 |
+
faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded)
|
158 |
+
outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32))
|
159 |
+
|
160 |
+
|
161 |
+
return outputs
|
162 |
+
|
163 |
+
def run_mano_twohands(init_trans, init_rot, init_hand_pose, is_right, init_betas, use_cuda=True, fix_shapedirs=True):
|
164 |
+
outputs_left = run_mano_left(init_trans[0:1], init_rot[0:1], init_hand_pose[0:1], None, init_betas[0:1], use_cuda=use_cuda, fix_shapedirs=fix_shapedirs)
|
165 |
+
outputs_right = run_mano(init_trans[1:2], init_rot[1:2], init_hand_pose[1:2], None, init_betas[1:2], use_cuda=use_cuda)
|
166 |
+
outputs_two = {
|
167 |
+
"vertices": torch.cat((outputs_left["vertices"], outputs_right["vertices"]), dim=0),
|
168 |
+
"joints": torch.cat((outputs_left["joints"], outputs_right["joints"]), dim=0)
|
169 |
+
|
170 |
+
}
|
171 |
+
return outputs_two
|
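Not part of the commit: shape conventions for run_mano_twohands above, assuming the MANO model files under _DATA/ are in place (batch dimension 2 = [left, right], T frames, axis-angle inputs).

import torch

T = 16
init_trans = torch.zeros(2, T, 3)        # per-hand world translation
init_rot = torch.zeros(2, T, 3)          # global orientation, axis-angle
init_hand_pose = torch.zeros(2, T, 45)   # 15 joints x 3, axis-angle
init_betas = torch.zeros(2, T, 10)       # MANO shape parameters
out = run_mano_twohands(init_trans, init_rot, init_hand_pose,
                        None, init_betas, use_cuda=False)
print(out["vertices"].shape)   # (2, T, 778, 3): index 0 = left hand, 1 = right hand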
infiller/hand_utils/rotation.py
ADDED
@@ -0,0 +1,293 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
|
7 |
+
"""
|
8 |
+
Taken from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py
|
9 |
+
Calculates the rotation matrices for a batch of rotation vectors
|
10 |
+
- param rot_vecs: torch.tensor (N, 3) array of N axis-angle vectors
|
11 |
+
- returns R: torch.tensor (N, 3, 3) rotation matrices
|
12 |
+
"""
|
13 |
+
batch_size = rot_vecs.shape[0]
|
14 |
+
device = rot_vecs.device
|
15 |
+
|
16 |
+
angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
|
17 |
+
rot_dir = rot_vecs / angle
|
18 |
+
|
19 |
+
cos = torch.unsqueeze(torch.cos(angle), dim=1)
|
20 |
+
sin = torch.unsqueeze(torch.sin(angle), dim=1)
|
21 |
+
|
22 |
+
# Bx1 arrays
|
23 |
+
rx, ry, rz = torch.split(rot_dir, 1, dim=1)
|
24 |
+
K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
|
25 |
+
|
26 |
+
zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
|
27 |
+
K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
|
28 |
+
(batch_size, 3, 3)
|
29 |
+
)
|
30 |
+
|
31 |
+
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
|
32 |
+
rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
|
33 |
+
return rot_mat
|
34 |
+
|
35 |
+
|
36 |
+
def quaternion_mul(q0, q1):
|
37 |
+
"""
|
38 |
+
EXPECTS WXYZ
|
39 |
+
:param q0 (*, 4)
|
40 |
+
:param q1 (*, 4)
|
41 |
+
"""
|
42 |
+
r0, r1 = q0[..., :1], q1[..., :1]
|
43 |
+
v0, v1 = q0[..., 1:], q1[..., 1:]
|
44 |
+
r = r0 * r1 - (v0 * v1).sum(dim=-1, keepdim=True)
|
45 |
+
v = r0 * v1 + r1 * v0 + torch.linalg.cross(v0, v1)
|
46 |
+
return torch.cat([r, v], dim=-1)
|
47 |
+
|
48 |
+
|
49 |
+
def quaternion_inverse(q, eps=1e-8):
|
50 |
+
"""
|
51 |
+
EXPECTS WXYZ
|
52 |
+
:param q (*, 4)
|
53 |
+
"""
|
54 |
+
conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
|
55 |
+
mag = torch.square(q).sum(dim=-1, keepdim=True) + eps
|
56 |
+
return conj / mag
|
57 |
+
|
58 |
+
|
59 |
+
def quaternion_slerp(t, q0, q1, eps=1e-8):
|
60 |
+
"""
|
61 |
+
:param t (*, 1) must be between 0 and 1
|
62 |
+
:param q0 (*, 4)
|
63 |
+
:param q1 (*, 4)
|
64 |
+
"""
|
65 |
+
dims = q0.shape[:-1]
|
66 |
+
t = t.view(*dims, 1)
|
67 |
+
|
68 |
+
q0 = F.normalize(q0, p=2, dim=-1)
|
69 |
+
q1 = F.normalize(q1, p=2, dim=-1)
|
70 |
+
dot = (q0 * q1).sum(dim=-1, keepdim=True)
|
71 |
+
|
72 |
+
# make sure we give the shortest rotation path (< 180d)
|
73 |
+
neg = dot < 0
|
74 |
+
q1 = torch.where(neg, -q1, q1)
|
75 |
+
dot = torch.where(neg, -dot, dot)
|
76 |
+
angle = torch.acos(dot)
|
77 |
+
|
78 |
+
# if angle is too small, just do linear interpolation
|
79 |
+
collin = torch.abs(dot) > 1 - eps
|
80 |
+
fac = 1 / torch.sin(angle)
|
81 |
+
w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac)
|
82 |
+
w1 = torch.where(collin, t, torch.sin(t * angle) * fac)
|
83 |
+
slerp = q0 * w0 + q1 * w1
|
84 |
+
return slerp
|
85 |
+
|
86 |
+
|
87 |
+
def rotation_matrix_to_angle_axis(rotation_matrix):
|
88 |
+
"""
|
89 |
+
This function is borrowed from https://github.com/kornia/kornia
|
90 |
+
|
91 |
+
Convert rotation matrix to Rodrigues vector
|
92 |
+
"""
|
93 |
+
quaternion = rotation_matrix_to_quaternion(rotation_matrix)
|
94 |
+
aa = quaternion_to_angle_axis(quaternion)
|
95 |
+
aa[torch.isnan(aa)] = 0.0
|
96 |
+
return aa
|
97 |
+
|
98 |
+
|
99 |
+
def quaternion_to_angle_axis(quaternion):
|
100 |
+
"""
|
101 |
+
This function is borrowed from https://github.com/kornia/kornia
|
102 |
+
|
103 |
+
Convert quaternion vector to angle axis of rotation.
|
104 |
+
Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
|
105 |
+
|
106 |
+
:param quaternion (*, 4) expects WXYZ
|
107 |
+
:returns angle_axis (*, 3)
|
108 |
+
"""
|
109 |
+
# unpack input and compute conversion
|
110 |
+
q1 = quaternion[..., 1]
|
111 |
+
q2 = quaternion[..., 2]
|
112 |
+
q3 = quaternion[..., 3]
|
113 |
+
sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3
|
114 |
+
|
115 |
+
sin_theta = torch.sqrt(sin_squared_theta)
|
116 |
+
cos_theta = quaternion[..., 0]
|
117 |
+
two_theta = 2.0 * torch.where(
|
118 |
+
cos_theta < 0.0,
|
119 |
+
torch.atan2(-sin_theta, -cos_theta),
|
120 |
+
torch.atan2(sin_theta, cos_theta),
|
121 |
+
)
|
122 |
+
|
123 |
+
k_pos = two_theta / sin_theta
|
124 |
+
k_neg = 2.0 * torch.ones_like(sin_theta)
|
125 |
+
k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
|
126 |
+
|
127 |
+
angle_axis = torch.zeros_like(quaternion)[..., :3]
|
128 |
+
angle_axis[..., 0] += q1 * k
|
129 |
+
angle_axis[..., 1] += q2 * k
|
130 |
+
angle_axis[..., 2] += q3 * k
|
131 |
+
return angle_axis
|
132 |
+
|
133 |
+
|
134 |
+
def angle_axis_to_rotation_matrix(angle_axis):
|
135 |
+
"""
|
136 |
+
:param angle_axis (*, 3)
|
137 |
+
return (*, 3, 3)
|
138 |
+
"""
|
139 |
+
quat = angle_axis_to_quaternion(angle_axis)
|
140 |
+
return quaternion_to_rotation_matrix(quat)
|
141 |
+
|
142 |
+
|
143 |
+
def quaternion_to_rotation_matrix(quaternion):
|
144 |
+
"""
|
145 |
+
Convert a quaternion to a rotation matrix.
|
146 |
+
Taken from https://github.com/kornia/kornia, based on
|
147 |
+
https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
|
148 |
+
https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247
|
149 |
+
:param quaternion (N, 4) expects WXYZ order
|
150 |
+
returns rotation matrix (N, 3, 3)
|
151 |
+
"""
|
152 |
+
# normalize the input quaternion
|
153 |
+
quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-12)
|
154 |
+
*dims, _ = quaternion_norm.shape
|
155 |
+
|
156 |
+
# unpack the normalized quaternion components
|
157 |
+
w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1)
|
158 |
+
|
159 |
+
# compute the actual conversion
|
160 |
+
tx = 2.0 * x
|
161 |
+
ty = 2.0 * y
|
162 |
+
tz = 2.0 * z
|
163 |
+
twx = tx * w
|
164 |
+
twy = ty * w
|
165 |
+
twz = tz * w
|
166 |
+
txx = tx * x
|
167 |
+
txy = ty * x
|
168 |
+
txz = tz * x
|
169 |
+
tyy = ty * y
|
170 |
+
tyz = tz * y
|
171 |
+
tzz = tz * z
|
172 |
+
one = torch.tensor(1.0)
|
173 |
+
|
174 |
+
matrix = torch.stack(
|
175 |
+
(
|
176 |
+
one - (tyy + tzz),
|
177 |
+
txy - twz,
|
178 |
+
txz + twy,
|
179 |
+
txy + twz,
|
180 |
+
one - (txx + tzz),
|
181 |
+
tyz - twx,
|
182 |
+
txz - twy,
|
183 |
+
tyz + twx,
|
184 |
+
one - (txx + tyy),
|
185 |
+
),
|
186 |
+
dim=-1,
|
187 |
+
).view(*dims, 3, 3)
|
188 |
+
return matrix
|
189 |
+
|
190 |
+
|
191 |
+
def angle_axis_to_quaternion(angle_axis):
|
192 |
+
"""
|
193 |
+
This function is borrowed from https://github.com/kornia/kornia
|
194 |
+
Convert angle axis to quaternion in WXYZ order
|
195 |
+
:param angle_axis (*, 3)
|
196 |
+
:returns quaternion (*, 4) WXYZ order
|
197 |
+
"""
|
198 |
+
theta_sq = torch.sum(angle_axis**2, dim=-1, keepdim=True) # (*, 1)
|
199 |
+
# need to handle the zero rotation case
|
200 |
+
valid = theta_sq > 0
|
201 |
+
theta = torch.sqrt(theta_sq)
|
202 |
+
half_theta = 0.5 * theta
|
203 |
+
ones = torch.ones_like(half_theta)
|
204 |
+
# fill zero with the limit of sin ax / x -> a
|
205 |
+
k = torch.where(valid, torch.sin(half_theta) / theta, 0.5 * ones)
|
206 |
+
w = torch.where(valid, torch.cos(half_theta), ones)
|
207 |
+
quat = torch.cat([w, k * angle_axis], dim=-1)
|
208 |
+
return quat
|
209 |
+
|
210 |
+
|
211 |
+
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
|
212 |
+
"""
|
213 |
+
This function is borrowed from https://github.com/kornia/kornia
|
214 |
+
Convert rotation matrix to 4d quaternion vector
|
215 |
+
This algorithm is based on algorithm described in
|
216 |
+
https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
|
217 |
+
|
218 |
+
:param rotation_matrix (N, 3, 3)
|
219 |
+
"""
|
220 |
+
*dims, m, n = rotation_matrix.shape
|
221 |
+
rmat_t = torch.transpose(rotation_matrix.reshape(-1, m, n), -1, -2)
|
222 |
+
|
223 |
+
mask_d2 = rmat_t[:, 2, 2] < eps
|
224 |
+
|
225 |
+
mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
|
226 |
+
mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
|
227 |
+
|
228 |
+
t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
229 |
+
q0 = torch.stack(
|
230 |
+
[
|
231 |
+
rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
232 |
+
t0,
|
233 |
+
rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
234 |
+
rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
|
235 |
+
],
|
236 |
+
-1,
|
237 |
+
)
|
238 |
+
t0_rep = t0.repeat(4, 1).t()
|
239 |
+
|
240 |
+
t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
241 |
+
q1 = torch.stack(
|
242 |
+
[
|
243 |
+
rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
|
244 |
+
rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
|
245 |
+
t1,
|
246 |
+
rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
|
247 |
+
],
|
248 |
+
-1,
|
249 |
+
)
|
250 |
+
t1_rep = t1.repeat(4, 1).t()
|
251 |
+
|
252 |
+
t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
253 |
+
q2 = torch.stack(
|
254 |
+
[
|
255 |
+
rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
|
256 |
+
rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
|
257 |
+
rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
|
258 |
+
t2,
|
259 |
+
],
|
260 |
+
-1,
|
261 |
+
)
|
262 |
+
t2_rep = t2.repeat(4, 1).t()
|
263 |
+
|
264 |
+
t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
265 |
+
q3 = torch.stack(
|
266 |
+
[
|
267 |
+
t3,
|
268 |
+
rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
|
269 |
+
rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
|
270 |
+
rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
|
271 |
+
],
|
272 |
+
-1,
|
273 |
+
)
|
274 |
+
t3_rep = t3.repeat(4, 1).t()
|
275 |
+
|
276 |
+
mask_c0 = mask_d2 * mask_d0_d1
|
277 |
+
mask_c1 = mask_d2 * ~mask_d0_d1
|
278 |
+
mask_c2 = ~mask_d2 * mask_d0_nd1
|
279 |
+
mask_c3 = ~mask_d2 * ~mask_d0_nd1
|
280 |
+
mask_c0 = mask_c0.view(-1, 1).type_as(q0)
|
281 |
+
mask_c1 = mask_c1.view(-1, 1).type_as(q1)
|
282 |
+
mask_c2 = mask_c2.view(-1, 1).type_as(q2)
|
283 |
+
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
|
284 |
+
|
285 |
+
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
|
286 |
+
q /= torch.sqrt(
|
287 |
+
t0_rep * mask_c0
|
288 |
+
+ t1_rep * mask_c1
|
289 |
+
+ t2_rep * mask_c2 # noqa
|
290 |
+
+ t3_rep * mask_c3
|
291 |
+
) # noqa
|
292 |
+
q *= 0.5
|
293 |
+
return q.reshape(*dims, 4)
|
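Not part of the commit: a small interpolation sketch using quaternion_slerp from this module (illustrative values only).

import torch

q0 = angle_axis_to_quaternion(torch.tensor([[0.0, 0.0, 0.0]]))   # identity
q1 = angle_axis_to_quaternion(torch.tensor([[0.0, 1.0, 0.0]]))   # 1 rad about y
t = torch.tensor([[0.5]])                                         # interpolation factor in [0, 1]
q_mid = quaternion_slerp(t, q0, q1)
R_mid = quaternion_to_rotation_matrix(q_mid)                      # (1, 3, 3) halfway rotation
print(quaternion_to_angle_axis(q_mid))                            # approx. [0, 0.5, 0]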
infiller/lib/misc/sampler.py
ADDED
@@ -0,0 +1,79 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import imageio
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from PIL import Image
|
10 |
+
from sklearn.preprocessing import LabelEncoder
|
11 |
+
|
12 |
+
from cmib.data.lafan1_dataset import LAFAN1Dataset
|
13 |
+
from cmib.data.utils import write_json
|
14 |
+
from cmib.lafan1.utils import quat_ik
|
15 |
+
from cmib.model.network import TransformerModel
|
16 |
+
from cmib.model.preprocess import (lerp_input_repr, replace_constant,
|
17 |
+
slerp_input_repr, vectorize_representation)
|
18 |
+
from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, joint_names,
|
19 |
+
sk_parents)
|
20 |
+
from cmib.vis.pose import plot_pose_with_stop
|
21 |
+
|
22 |
+
|
23 |
+
def test(opt, device):
|
24 |
+
|
25 |
+
save_dir = Path(os.path.join('runs', 'train', opt.exp_name))
|
26 |
+
wdir = save_dir / 'weights'
|
27 |
+
weights = os.listdir(wdir)
|
28 |
+
weights_paths = [wdir / weight for weight in weights]
|
29 |
+
latest_weight = max(weights_paths , key = os.path.getctime)
|
30 |
+
ckpt = torch.load(latest_weight, map_location=device)
|
31 |
+
print(f"Loaded weight: {latest_weight}")
|
32 |
+
|
33 |
+
# Load Skeleton
|
34 |
+
skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device)
|
35 |
+
skeleton_mocap.remove_joints(sk_joints_to_remove)
|
36 |
+
|
37 |
+
# Load LAFAN Dataset
|
38 |
+
Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True)
|
39 |
+
lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device)
|
40 |
+
total_data = lafan_dataset.data['global_pos'].shape[0]
|
41 |
+
|
42 |
+
# Replace with noise to In-betweening Frames
|
43 |
+
from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx'] # default: 9-40, max: 48
|
44 |
+
horizon = ckpt['horizon']
|
45 |
+
print(f"HORIZON: {horizon}")
|
46 |
+
|
47 |
+
test_idx = []
|
48 |
+
for i in range(total_data):
|
49 |
+
test_idx.append(i)
|
50 |
+
|
51 |
+
# Compare Input data, Prediction, GT
|
52 |
+
save_path = os.path.join(opt.save_path, 'sampler')
|
53 |
+
for i in range(len(test_idx)):
|
54 |
+
Path(save_path).mkdir(parents=True, exist_ok=True)
|
55 |
+
|
56 |
+
start_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
|
57 |
+
target_pose = lafan_dataset.data['global_pos'][test_idx[i], target_idx]
|
58 |
+
gt_stopover_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
|
59 |
+
|
60 |
+
gt_img_path = os.path.join(save_path)
|
61 |
+
plot_pose_with_stop(start_pose, target_pose, target_pose, gt_stopover_pose, i, skeleton_mocap, save_dir=gt_img_path, prefix='gt')
|
62 |
+
print(f"ID {test_idx[i]}: completed.")
|
63 |
+
|
64 |
+
def parse_opt():
|
65 |
+
parser = argparse.ArgumentParser()
|
66 |
+
parser.add_argument('--project', default='runs/train', help='project/name')
|
67 |
+
parser.add_argument('--exp_name', default='slerp_40', help='experiment name')
|
68 |
+
parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path')
|
69 |
+
parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton')
|
70 |
+
parser.add_argument('--processed_data_dir', type=str, default='processed_data_original/', help='path to save pickled processed data')
|
71 |
+
parser.add_argument('--save_path', type=str, default='runs/test', help='path to save model')
|
72 |
+
parser.add_argument('--motion_type', type=str, default='jumps', help='motion type')
|
73 |
+
opt = parser.parse_args()
|
74 |
+
return opt
|
75 |
+
|
76 |
+
if __name__ == "__main__":
|
77 |
+
opt = parse_opt()
|
78 |
+
device = torch.device("cpu")
|
79 |
+
test(opt, device)
|
infiller/lib/model/__pycache__/network.cpython-310.pyc
ADDED
Binary file (7.82 kB).
infiller/lib/model/network.py
ADDED
@@ -0,0 +1,276 @@
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import nn, Tensor
|
5 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
6 |
+
# from cmib.model.positional_encoding import PositionalEmbedding
|
7 |
+
|
8 |
+
class SinPositionalEncoding(nn.Module):
|
9 |
+
def __init__(self, d_model, dropout=0.1, max_len=100):
|
10 |
+
super(SinPositionalEncoding, self).__init__()
|
11 |
+
self.dropout = nn.Dropout(p=dropout)
|
12 |
+
|
13 |
+
pe = torch.zeros(max_len, d_model)
|
14 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
15 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
|
16 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
17 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
18 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
19 |
+
|
20 |
+
self.register_buffer('pe', pe)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
# not used in the final model
|
24 |
+
x = x + self.pe[:x.shape[0], :]
|
25 |
+
return self.dropout(x)
|
26 |
+
|
27 |
+
|
28 |
+
class MultiHeadedAttention(nn.Module):
|
29 |
+
def __init__(self, n_head, d_model, d_head, dropout=0.1,
|
30 |
+
pre_lnorm=True, bias=False):
|
31 |
+
"""
|
32 |
+
Multi-headed attention with relative positional encoding and
|
33 |
+
memory mechanism.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
n_head (int): Number of heads.
|
37 |
+
d_model (int): Input dimension.
|
38 |
+
d_head (int): Head dimension.
|
39 |
+
dropout (float, optional): Dropout value. Defaults to 0.1.
|
40 |
+
pre_lnorm (bool, optional):
|
41 |
+
Apply layer norm before rest of calculation. Defaults to True.
|
42 |
+
In original Transformer paper (pre_lnorm=False):
|
43 |
+
LayerNorm(x + Sublayer(x))
|
44 |
+
In tensor2tensor implementation (pre_lnorm=True):
|
45 |
+
x + Sublayer(LayerNorm(x))
|
46 |
+
bias (bool, optional):
|
47 |
+
Add bias to q, k, v and output projections. Defaults to False.
|
48 |
+
|
49 |
+
"""
|
50 |
+
super(MultiHeadedAttention, self).__init__()
|
51 |
+
|
52 |
+
self.n_head = n_head
|
53 |
+
self.d_model = d_model
|
54 |
+
self.d_head = d_head
|
55 |
+
self.dropout = dropout
|
56 |
+
self.pre_lnorm = pre_lnorm
|
57 |
+
self.bias = bias
|
58 |
+
self.atten_scale = 1 / math.sqrt(self.d_model)
|
59 |
+
|
60 |
+
self.q_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
|
61 |
+
self.k_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
|
62 |
+
self.v_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
|
63 |
+
self.out_linear = nn.Linear(n_head * d_head, d_model, bias=bias)
|
64 |
+
|
65 |
+
self.droput_layer = nn.Dropout(dropout)
|
66 |
+
self.atten_dropout_layer = nn.Dropout(dropout)
|
67 |
+
|
68 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
69 |
+
|
70 |
+
def forward(self, hidden, memory=None, mask=None,
|
71 |
+
extra_atten_score=None):
|
72 |
+
"""
|
73 |
+
Args:
|
74 |
+
hidden (Tensor): Input embedding or hidden state of previous layer.
|
75 |
+
Shape: (batch, seq, dim)
|
76 |
+
pos_emb (Tensor): Relative positional embedding lookup table.
|
77 |
+
Shape: (batch, (seq+mem_len)*2-1, d_head)
|
78 |
+
pos_emb[:, seq+mem_len]
|
79 |
+
|
80 |
+
memory (Tensor): Memory tensor of previous layer.
|
81 |
+
Shape: (batch, mem_len, dim)
|
82 |
+
mask (BoolTensor, optional): Attention mask.
|
83 |
+
Set item value to True if you DO NOT want keep certain
|
84 |
+
attention score, otherwise False. Defaults to None.
|
85 |
+
Shape: (seq, seq+mem_len).
|
86 |
+
"""
|
87 |
+
combined = hidden
|
88 |
+
# if memory is None:
|
89 |
+
# combined = hidden
|
90 |
+
# mem_len = 0
|
91 |
+
# else:
|
92 |
+
# combined = torch.cat([memory, hidden], dim=1)
|
93 |
+
# mem_len = memory.shape[1]
|
94 |
+
|
95 |
+
if self.pre_lnorm:
|
96 |
+
hidden = self.layer_norm(hidden)
|
97 |
+
combined = self.layer_norm(combined)
|
98 |
+
|
99 |
+
# shape: (batch, q/k/v_len, dim)
|
100 |
+
q = self.q_linear(hidden)
|
101 |
+
k = self.k_linear(combined)
|
102 |
+
v = self.v_linear(combined)
|
103 |
+
|
104 |
+
# reshape to (batch, q/k/v_len, n_head, d_head)
|
105 |
+
q = q.reshape(q.shape[0], q.shape[1], self.n_head, self.d_head)
|
106 |
+
k = k.reshape(k.shape[0], k.shape[1], self.n_head, self.d_head)
|
107 |
+
v = v.reshape(v.shape[0], v.shape[1], self.n_head, self.d_head)
|
108 |
+
|
109 |
+
# transpose to (batch, n_head, q/k/v_len, d_head)
|
110 |
+
q = q.transpose(1, 2)
|
111 |
+
k = k.transpose(1, 2)
|
112 |
+
v = v.transpose(1, 2)
|
113 |
+
|
114 |
+
# add n_head dimension for relative positional embedding lookup table
|
115 |
+
# (batch, n_head, k/v_len*2-1, d_head)
|
116 |
+
# pos_emb = pos_emb[:, None]
|
117 |
+
|
118 |
+
# (batch, n_head, q_len, k_len)
|
119 |
+
atten_score = torch.matmul(q, k.transpose(-1, -2))
|
120 |
+
|
121 |
+
# qpos = torch.matmul(q, pos_emb.transpose(-1, -2))
|
122 |
+
# DEBUG
|
123 |
+
# ones = torch.zeros(q.shape)
|
124 |
+
# ones[:, :, :, 0] = 1.0
|
125 |
+
# qpos = torch.matmul(ones, pos_emb.transpose(-1, -2))
|
126 |
+
# atten_score = atten_score + self.skew(qpos, mem_len)
|
127 |
+
atten_score = atten_score * self.atten_scale
|
128 |
+
|
129 |
+
# if extra_atten_score is not None:
|
130 |
+
# atten_score = atten_score + extra_atten_score
|
131 |
+
|
132 |
+
if mask is not None:
|
133 |
+
# print(atten_score.shape)
|
134 |
+
# print(mask.shape)
|
135 |
+
# apply attention mask
|
136 |
+
atten_score = atten_score.masked_fill(mask, float("-inf"))
|
137 |
+
atten_score = atten_score.softmax(dim=-1)
|
138 |
+
atten_score = self.atten_dropout_layer(atten_score)
|
139 |
+
|
140 |
+
# (batch, n_head, q_len, d_head)
|
141 |
+
atten_vec = torch.matmul(atten_score, v)
|
142 |
+
# (batch, q_len, n_head*d_head)
|
143 |
+
atten_vec = atten_vec.transpose(1, 2).flatten(start_dim=-2)
|
144 |
+
|
145 |
+
# linear projection
|
146 |
+
output = self.droput_layer(self.out_linear(atten_vec))
|
147 |
+
|
148 |
+
if self.pre_lnorm:
|
149 |
+
return hidden + output
|
150 |
+
else:
|
151 |
+
return self.layer_norm(hidden + output)
|
152 |
+
|
153 |
+
|
154 |
+
class FeedForward(nn.Module):
|
155 |
+
def __init__(self, d_model, d_inner, dropout=0.1, pre_lnorm=True):
|
156 |
+
"""
|
157 |
+
Positionwise feed-forward network.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
d_model(int): Dimension of the input and output.
|
161 |
+
d_inner (int): Dimension of the middle layer(bottleneck).
|
162 |
+
dropout (float, optional): Dropout value. Defaults to 0.1.
|
163 |
+
pre_lnorm (bool, optional):
|
164 |
+
Apply layer norm before rest of calculation. Defaults to True.
|
165 |
+
In original Transformer paper (pre_lnorm=False):
|
166 |
+
LayerNorm(x + Sublayer(x))
|
167 |
+
In tensor2tensor implementation (pre_lnorm=True):
|
168 |
+
x + Sublayer(LayerNorm(x))
|
169 |
+
"""
|
170 |
+
super(FeedForward, self).__init__()
|
171 |
+
self.d_model = d_model
|
172 |
+
self.d_inner = d_inner
|
173 |
+
self.dropout = dropout
|
174 |
+
self.pre_lnorm = pre_lnorm
|
175 |
+
|
176 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
177 |
+
self.network = nn.Sequential(
|
178 |
+
nn.Linear(d_model, d_inner),
|
179 |
+
nn.ReLU(),
|
180 |
+
nn.Dropout(dropout),
|
181 |
+
nn.Linear(d_inner, d_model),
|
182 |
+
nn.Dropout(dropout),
|
183 |
+
)
|
184 |
+
|
185 |
+
def forward(self, x):
|
186 |
+
if self.pre_lnorm:
|
187 |
+
return x + self.network(self.layer_norm(x))
|
188 |
+
else:
|
189 |
+
return self.layer_norm(x + self.network(x))
|
190 |
+
class TransformerModel(nn.Module):
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
seq_len: int,
|
194 |
+
input_dim: int,
|
195 |
+
d_model: int,
|
196 |
+
nhead: int,
|
197 |
+
d_hid: int,
|
198 |
+
nlayers: int,
|
199 |
+
dropout: float = 0.5,
|
200 |
+
out_dim=91,
|
201 |
+
masked_attention_stage=False,
|
202 |
+
):
|
203 |
+
super().__init__()
|
204 |
+
self.model_type = "Transformer"
|
205 |
+
self.seq_len = seq_len
|
206 |
+
self.d_model = d_model
|
207 |
+
self.nhead = nhead
|
208 |
+
self.d_hid = d_hid
|
209 |
+
self.nlayers = nlayers
|
210 |
+
self.pos_embedding = SinPositionalEncoding(d_model=d_model, dropout=0.1, max_len=seq_len)
|
211 |
+
if masked_attention_stage:
|
212 |
+
self.input_layer = nn.Linear(input_dim+1, d_model)
|
213 |
+
# visible to invisible attention
|
214 |
+
self.att_layers = nn.ModuleList()
|
215 |
+
self.pff_layers = nn.ModuleList()
|
216 |
+
self.pre_lnorm = True
|
217 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
218 |
+
for i in range(self.nlayers):
|
219 |
+
self.att_layers.append(
|
220 |
+
MultiHeadedAttention(
|
221 |
+
self.nhead, self.d_model,
|
222 |
+
self.d_model // self.nhead, dropout=dropout,
|
223 |
+
pre_lnorm=True,
|
224 |
+
bias=False
|
225 |
+
)
|
226 |
+
)
|
227 |
+
|
228 |
+
self.pff_layers.append(
|
229 |
+
FeedForward(
|
230 |
+
self.d_model, d_hid,
|
231 |
+
dropout=dropout,
|
232 |
+
pre_lnorm=True
|
233 |
+
)
|
234 |
+
)
|
235 |
+
else:
|
236 |
+
self.att_layers = None
|
237 |
+
self.input_layer = nn.Linear(input_dim, d_model)
|
238 |
+
encoder_layers = TransformerEncoderLayer(
|
239 |
+
d_model, nhead, d_hid, dropout, activation="gelu"
|
240 |
+
)
|
241 |
+
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
|
242 |
+
self.decoder = nn.Linear(d_model, out_dim)
|
243 |
+
|
244 |
+
self.init_weights()
|
245 |
+
|
246 |
+
def init_weights(self) -> None:
|
247 |
+
initrange = 0.1
|
248 |
+
self.decoder.bias.data.zero_()
|
249 |
+
self.decoder.weight.data.uniform_(-initrange, initrange)
|
250 |
+
|
251 |
+
def forward(self, src: Tensor, src_mask: Tensor, data_mask=None, atten_mask=None) -> Tensor:
|
252 |
+
"""
|
253 |
+
Args:
|
254 |
+
src: Tensor, shape [seq_len, batch_size, embedding_dim]
|
255 |
+
src_mask: Tensor, shape [seq_len, seq_len]
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
output Tensor of shape [seq_len, batch_size, embedding_dim]
|
259 |
+
"""
|
260 |
+
if not data_mask is None:
|
261 |
+
src = torch.cat([src, data_mask.expand(*src.shape[:-1], data_mask.shape[-1])], dim=-1)
|
262 |
+
src = self.input_layer(src)
|
263 |
+
output = self.pos_embedding(src)
|
264 |
+
# output = src
|
265 |
+
if self.att_layers:
|
266 |
+
assert not atten_mask is None
|
267 |
+
output = output.permute(1, 0, 2)
|
268 |
+
for i in range(self.nlayers):
|
269 |
+
output = self.att_layers[i](output, mask=atten_mask)
|
270 |
+
output = self.pff_layers[i](output)
|
271 |
+
if self.pre_lnorm:
|
272 |
+
output = self.layer_norm(output)
|
273 |
+
output = output.permute(1, 0, 2)
|
274 |
+
output = self.transformer_encoder(output)
|
275 |
+
output = self.decoder(output)
|
276 |
+
return output
|
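For orientation, the snippet below is a minimal, hedged usage sketch of the masked-attention infiller defined above. The constructor arguments and the `forward(src, src_mask, data_mask, atten_mask)` signature come from the code; the concrete sizes, tensor values, and the import path `infiller.lib.model.network` are assumptions for illustration. Note that `src_mask` is accepted but not applied on this code path, and that `True` entries of `atten_mask` are blocked in attention.

```python
# Hedged sketch (not repository code): run TransformerModel over a sequence with a masked span.
import torch
from infiller.lib.model.network import TransformerModel  # assumed import path

seq_len, batch, in_dim = 120, 2, 96              # illustrative sizes only
model = TransformerModel(seq_len=seq_len, input_dim=in_dim, d_model=256, nhead=4,
                         d_hid=512, nlayers=6, dropout=0.1, out_dim=in_dim,
                         masked_attention_stage=True)

src = torch.randn(seq_len, batch, in_dim)        # [seq_len, batch, feature]
data_mask = torch.ones(seq_len, batch, 1)        # 1 = observed frame, 0 = missing frame
data_mask[40:80] = 0.0                           # pretend frames 40-79 are occluded
src_mask = torch.zeros(seq_len, seq_len).bool()  # placeholder; unused by this path

# True entries are blocked: stop every query from attending to missing frames.
missing = (data_mask[:, 0, 0] == 0)              # (seq_len,)
atten_mask = missing[None, :].expand(seq_len, seq_len)

out = model(src, src_mask, data_mask=data_mask, atten_mask=atten_mask)
print(out.shape)                                 # torch.Size([120, 2, 96])
```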
infiller/lib/model/positional_encoding.py
ADDED
@@ -0,0 +1,42 @@
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
class PositionalEmbedding(nn.Module):
|
7 |
+
def __init__(self, seq_len: int = 32, d_model: int = 96):
|
8 |
+
super().__init__()
|
9 |
+
self.pos_emb = nn.Embedding(seq_len + 1, d_model)
|
10 |
+
|
11 |
+
def forward(self, inputs):
|
12 |
+
positions = (
|
13 |
+
torch.arange(inputs.size(0), device=inputs.device)
|
14 |
+
.expand(inputs.size(1), inputs.size(0))
|
15 |
+
.contiguous()
|
16 |
+
+ 1
|
17 |
+
)
|
18 |
+
outputs = inputs + self.pos_emb(positions).permute(1, 0, 2)
|
19 |
+
return outputs
|
20 |
+
|
21 |
+
|
22 |
+
class PositionalEncoding(nn.Module):
|
23 |
+
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
|
24 |
+
super().__init__()
|
25 |
+
self.dropout = nn.Dropout(p=dropout)
|
26 |
+
|
27 |
+
position = torch.arange(max_len).unsqueeze(1)
|
28 |
+
div_term = torch.exp(
|
29 |
+
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
|
30 |
+
)
|
31 |
+
pe = torch.zeros(max_len, 1, d_model)
|
32 |
+
pe[:, 0, 0::2] = torch.sin(position * div_term)
|
33 |
+
pe[:, 0, 1::2] = torch.cos(position * div_term)
|
34 |
+
self.register_buffer("pe", pe)
|
35 |
+
|
36 |
+
def forward(self, x: Tensor) -> Tensor:
|
37 |
+
"""
|
38 |
+
Args:
|
39 |
+
x: Tensor, shape [seq_len, batch_size, embedding_dim]
|
40 |
+
"""
|
41 |
+
x = x + self.pe[: x.size(0)]
|
42 |
+
return self.dropout(x)
|
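A quick sanity check of the sinusoidal table built above (a sketch, not repository code; the import path is assumed): with dropout set to 0, channel 0 of the output carries sin(position) exactly.

```python
import torch
from infiller.lib.model.positional_encoding import PositionalEncoding  # assumed path

pe = PositionalEncoding(d_model=8, dropout=0.0, max_len=16)
x = torch.zeros(16, 1, 8)                      # [seq_len, batch, d_model]
y = pe(x)                                      # adds pe[:16]; dropout(p=0) is a no-op
# even channels hold sin(pos / 10000^(2i/d)), odd channels the matching cos;
# channel 0 uses divisor 1, so it is simply sin(position)
assert torch.allclose(y[:, 0, 0], torch.sin(torch.arange(16.0)))
```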
infiller/lib/model/preprocess.py
ADDED
@@ -0,0 +1,189 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def replace_constant(minibatch_pose_input, mask_start_frame):
|
5 |
+
|
6 |
+
seq_len = minibatch_pose_input.size(1)
|
7 |
+
interpolated = (
|
8 |
+
torch.ones_like(minibatch_pose_input, device=minibatch_pose_input.device) * 0.1
|
9 |
+
)
|
10 |
+
|
11 |
+
if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
|
12 |
+
interpolate_start = minibatch_pose_input[:, 0, :]
|
13 |
+
interpolate_end = minibatch_pose_input[:, seq_len - 1, :]
|
14 |
+
|
15 |
+
interpolated[:, 0, :] = interpolate_start
|
16 |
+
interpolated[:, seq_len - 1, :] = interpolate_end
|
17 |
+
|
18 |
+
assert torch.allclose(interpolated[:, 0, :], interpolate_start)
|
19 |
+
assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end)
|
20 |
+
|
21 |
+
else:
|
22 |
+
interpolate_start1 = minibatch_pose_input[:, 0, :]
|
23 |
+
interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :]
|
24 |
+
|
25 |
+
interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :]
|
26 |
+
interpolate_end2 = minibatch_pose_input[:, seq_len - 1, :]
|
27 |
+
|
28 |
+
interpolated[:, 0, :] = interpolate_start1
|
29 |
+
interpolated[:, mask_start_frame, :] = interpolate_end1
|
30 |
+
|
31 |
+
interpolated[:, mask_start_frame, :] = interpolate_start2
|
32 |
+
interpolated[:, seq_len - 1, :] = interpolate_end2
|
33 |
+
|
34 |
+
assert torch.allclose(interpolated[:, 0, :], interpolate_start1)
|
35 |
+
assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1)
|
36 |
+
|
37 |
+
assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2)
|
38 |
+
assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end2)
|
39 |
+
return interpolated
|
40 |
+
|
41 |
+
|
42 |
+
def slerp(x, y, a):
|
43 |
+
"""
|
44 |
+
Performs spherical linear interpolation (SLERP) between x and y, with proportion a
|
45 |
+
|
46 |
+
:param x: quaternion tensor
|
47 |
+
:param y: quaternion tensor
|
48 |
+
:param a: indicator (between 0 and 1) of completion of the interpolation.
|
49 |
+
:return: tensor of interpolation results
|
50 |
+
"""
|
51 |
+
device = x.device
|
52 |
+
len = torch.sum(x * y, dim=-1)
|
53 |
+
|
54 |
+
neg = len < 0.0
|
55 |
+
len[neg] = -len[neg]
|
56 |
+
y[neg] = -y[neg]
|
57 |
+
|
58 |
+
a = torch.zeros_like(x[..., 0]) + a
|
59 |
+
amount0 = torch.zeros(a.shape, device=device)
|
60 |
+
amount1 = torch.zeros(a.shape, device=device)
|
61 |
+
|
62 |
+
linear = (1.0 - len) < 0.01
|
63 |
+
omegas = torch.arccos(len[~linear])
|
64 |
+
sinoms = torch.sin(omegas)
|
65 |
+
|
66 |
+
amount0[linear] = 1.0 - a[linear]
|
67 |
+
amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms
|
68 |
+
|
69 |
+
amount1[linear] = a[linear]
|
70 |
+
amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms
|
71 |
+
# res = amount0[..., np.newaxis] * x + amount1[..., np.newaxis] * y
|
72 |
+
res = amount0.unsqueeze(3) * x + amount1.unsqueeze(3) * y
|
73 |
+
|
74 |
+
return res
|
75 |
+
|
76 |
+
|
77 |
+
def slerp_input_repr(minibatch_pose_input, mask_start_frame):
|
78 |
+
seq_len = minibatch_pose_input.size(1)
|
79 |
+
minibatch_pose_input = minibatch_pose_input.reshape(
|
80 |
+
minibatch_pose_input.size(0), seq_len, -1, 4
|
81 |
+
)
|
82 |
+
interpolated = torch.zeros_like(
|
83 |
+
minibatch_pose_input, device=minibatch_pose_input.device
|
84 |
+
)
|
85 |
+
|
86 |
+
if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
|
87 |
+
interpolate_start = minibatch_pose_input[:, 0:1]
|
88 |
+
interpolate_end = minibatch_pose_input[:, seq_len - 1 :]
|
89 |
+
|
90 |
+
for i in range(seq_len):
|
91 |
+
dt = 1 / (seq_len - 1)
|
92 |
+
interpolated[:, i : i + 1, :] = slerp(
|
93 |
+
interpolate_start, interpolate_end, dt * i
|
94 |
+
)
|
95 |
+
|
96 |
+
assert torch.allclose(interpolated[:, 0:1], interpolate_start)
|
97 |
+
assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end)
|
98 |
+
else:
|
99 |
+
interpolate_start1 = minibatch_pose_input[:, 0:1]
|
100 |
+
interpolate_end1 = minibatch_pose_input[
|
101 |
+
:, mask_start_frame : mask_start_frame + 1
|
102 |
+
]
|
103 |
+
|
104 |
+
interpolate_start2 = minibatch_pose_input[
|
105 |
+
:, mask_start_frame : mask_start_frame + 1
|
106 |
+
]
|
107 |
+
interpolate_end2 = minibatch_pose_input[:, seq_len - 1 :]
|
108 |
+
|
109 |
+
for i in range(mask_start_frame + 1):
|
110 |
+
dt = 1 / mask_start_frame
|
111 |
+
interpolated[:, i : i + 1, :] = slerp(
|
112 |
+
interpolate_start1, interpolate_end1, dt * i
|
113 |
+
)
|
114 |
+
|
115 |
+
assert torch.allclose(interpolated[:, 0:1], interpolate_start1)
|
116 |
+
assert torch.allclose(
|
117 |
+
interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_end1
|
118 |
+
)
|
119 |
+
|
120 |
+
for i in range(mask_start_frame, seq_len):
|
121 |
+
dt = 1 / (seq_len - mask_start_frame - 1)
|
122 |
+
interpolated[:, i : i + 1, :] = slerp(
|
123 |
+
interpolate_start2, interpolate_end2, dt * (i - mask_start_frame)
|
124 |
+
)
|
125 |
+
|
126 |
+
assert torch.allclose(
|
127 |
+
interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_start2
|
128 |
+
)
|
129 |
+
assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end2)
|
130 |
+
|
131 |
+
interpolated = torch.nn.functional.normalize(interpolated, p=2.0, dim=3)
|
132 |
+
return interpolated.reshape(minibatch_pose_input.size(0), seq_len, -1)
|
133 |
+
|
134 |
+
|
135 |
+
def lerp_input_repr(minibatch_pose_input, mask_start_frame):
|
136 |
+
seq_len = minibatch_pose_input.size(1)
|
137 |
+
interpolated = torch.zeros_like(
|
138 |
+
minibatch_pose_input, device=minibatch_pose_input.device
|
139 |
+
)
|
140 |
+
|
141 |
+
if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
|
142 |
+
interpolate_start = minibatch_pose_input[:, 0, :]
|
143 |
+
interpolate_end = minibatch_pose_input[:, seq_len - 1, :]
|
144 |
+
|
145 |
+
for i in range(seq_len):
|
146 |
+
dt = 1 / (seq_len - 1)
|
147 |
+
interpolated[:, i, :] = torch.lerp(
|
148 |
+
interpolate_start, interpolate_end, dt * i
|
149 |
+
)
|
150 |
+
|
151 |
+
assert torch.allclose(interpolated[:, 0, :], interpolate_start)
|
152 |
+
assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end)
|
153 |
+
else:
|
154 |
+
interpolate_start1 = minibatch_pose_input[:, 0, :]
|
155 |
+
interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :]
|
156 |
+
|
157 |
+
interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :]
|
158 |
+
interpolate_end2 = minibatch_pose_input[:, -1, :]
|
159 |
+
|
160 |
+
for i in range(mask_start_frame + 1):
|
161 |
+
dt = 1 / mask_start_frame
|
162 |
+
interpolated[:, i, :] = torch.lerp(
|
163 |
+
interpolate_start1, interpolate_end1, dt * i
|
164 |
+
)
|
165 |
+
|
166 |
+
assert torch.allclose(interpolated[:, 0, :], interpolate_start1)
|
167 |
+
assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1)
|
168 |
+
|
169 |
+
for i in range(mask_start_frame, seq_len):
|
170 |
+
dt = 1 / (seq_len - mask_start_frame - 1)
|
171 |
+
interpolated[:, i, :] = torch.lerp(
|
172 |
+
interpolate_start2, interpolate_end2, dt * (i - mask_start_frame)
|
173 |
+
)
|
174 |
+
|
175 |
+
assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2)
|
176 |
+
assert torch.allclose(interpolated[:, -1, :], interpolate_end2)
|
177 |
+
return interpolated
|
178 |
+
|
179 |
+
|
180 |
+
def vectorize_representation(global_position, global_rotation):
|
181 |
+
|
182 |
+
batch_size = global_position.shape[0]
|
183 |
+
seq_len = global_position.shape[1]
|
184 |
+
|
185 |
+
global_pos_vec = global_position.reshape(batch_size, seq_len, -1).contiguous()
|
186 |
+
global_rot_vec = global_rotation.reshape(batch_size, seq_len, -1).contiguous()
|
187 |
+
|
188 |
+
global_pose_vec_gt = torch.cat([global_pos_vec, global_rot_vec], dim=2)
|
189 |
+
return global_pose_vec_gt
|
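As a concrete illustration of the slerp() helper above (shapes follow its use in slerp_input_repr; the wxyz component order here is only an assumption, since slerp itself relies just on the quaternion dot product): interpolating halfway between the identity and a 90-degree rotation should give a 45-degree rotation.

```python
import torch
from infiller.lib.model.preprocess import slerp  # assumed import path

q0 = torch.tensor([1.0, 0.0, 0.0, 0.0]).view(1, 1, 1, 4)        # identity
q1 = torch.tensor([0.7071, 0.7071, 0.0, 0.0]).view(1, 1, 1, 4)  # 90 deg about x
q_half = slerp(q0, q1.clone(), 0.5)        # clone: slerp may flip y in place when dot < 0
print(q_half.squeeze())                    # ~ [0.9239, 0.3827, 0, 0] = [cos 22.5deg, sin 22.5deg, 0, 0]
```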
infiller/lib/model/skeleton.py
ADDED
@@ -0,0 +1,349 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from cmib.data.quaternion import qmul, qrot
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
amass_offsets = [
|
7 |
+
[0.0, 0.0, 0.0],
|
8 |
+
|
9 |
+
[0.058581, -0.082280, -0.017664],
|
10 |
+
[0.043451, -0.386469, 0.008037],
|
11 |
+
[-0.014790, -0.426874, -0.037428],
|
12 |
+
[0.041054, -0.060286, 0.122042],
|
13 |
+
[0.0, 0.0, 0.0],
|
14 |
+
|
15 |
+
[-0.060310, -0.090513, -0.013543],
|
16 |
+
[-0.043257, -0.383688, -0.004843],
|
17 |
+
[0.019056, -0.420046, -0.034562],
|
18 |
+
[-0.034840, -0.062106, 0.130323],
|
19 |
+
[0.0, 0.0, 0.0],
|
20 |
+
|
21 |
+
[0.004439, 0.124404, -0.038385],
|
22 |
+
[0.004488, 0.137956, 0.026820],
|
23 |
+
[-0.002265, 0.056032, 0.002855],
|
24 |
+
[-0.013390, 0.211636, -0.033468],
|
25 |
+
[0.010113, 0.088937, 0.050410],
|
26 |
+
[0.0, 0.0, 0.0],
|
27 |
+
|
28 |
+
[0.071702, 0.114000, -0.018898],
|
29 |
+
[0.122921, 0.045205, -0.019046],
|
30 |
+
[0.255332, -0.015649, -0.022946],
|
31 |
+
[0.265709, 0.012698, -0.007375],
|
32 |
+
[0.0, 0.0, 0.0],
|
33 |
+
|
34 |
+
[-0.082954, 0.112472, -0.023707],
|
35 |
+
[-0.113228, 0.046853, -0.008472],
|
36 |
+
[-0.260127, -0.014369, -0.031269],
|
37 |
+
[-0.269108, 0.006794, -0.006027],
|
38 |
+
[0.0, 0.0, 0.0]
|
39 |
+
]
|
40 |
+
|
41 |
+
sk_offsets = [
|
42 |
+
[-42.198200, 91.614723, -40.067841],
|
43 |
+
|
44 |
+
[0.103456, 1.857829, 10.548506],
|
45 |
+
[43.499992, -0.000038, -0.000002],
|
46 |
+
[42.372192, 0.000015, -0.000007],
|
47 |
+
[17.299999, -0.000002, 0.000003],
|
48 |
+
[0.000000, 0.000000, 0.000000],
|
49 |
+
|
50 |
+
[0.103457, 1.857829, -10.548503],
|
51 |
+
[43.500042, -0.000027, 0.000008],
|
52 |
+
[42.372257, -0.000008, 0.000014],
|
53 |
+
[17.299992, -0.000005, 0.000004],
|
54 |
+
[0.000000, 0.000000, 0.000000],
|
55 |
+
|
56 |
+
[6.901968, -2.603733, -0.000001],
|
57 |
+
[12.588099, 0.000002, 0.000000],
|
58 |
+
[12.343206, 0.000000, -0.000001],
|
59 |
+
[25.832886, -0.000004, 0.000003],
|
60 |
+
[11.766620, 0.000005, -0.000001],
|
61 |
+
[0.000000, 0.000000, 0.000000],
|
62 |
+
|
63 |
+
[19.745899, -1.480370, 6.000108],
|
64 |
+
[11.284125, -0.000009, -0.000018],
|
65 |
+
[33.000050, 0.000004, 0.000032],
|
66 |
+
[25.200008, 0.000015, 0.000008],
|
67 |
+
[0.000000, 0.000000, 0.000000],
|
68 |
+
|
69 |
+
[19.746099, -1.480375, -6.000073],
|
70 |
+
[11.284138, -0.000015, -0.000012],
|
71 |
+
[33.000092, 0.000017, 0.000013],
|
72 |
+
[25.199780, 0.000135, 0.000422],
|
73 |
+
[0.000000, 0.000000, 0.000000],
|
74 |
+
]
|
75 |
+
|
76 |
+
sk_parents = [
|
77 |
+
-1,
|
78 |
+
0,
|
79 |
+
1,
|
80 |
+
2,
|
81 |
+
3,
|
82 |
+
4,
|
83 |
+
0,
|
84 |
+
6,
|
85 |
+
7,
|
86 |
+
8,
|
87 |
+
9,
|
88 |
+
0,
|
89 |
+
11,
|
90 |
+
12,
|
91 |
+
13,
|
92 |
+
14,
|
93 |
+
15,
|
94 |
+
13,
|
95 |
+
17,
|
96 |
+
18,
|
97 |
+
19,
|
98 |
+
20,
|
99 |
+
13,
|
100 |
+
22,
|
101 |
+
23,
|
102 |
+
24,
|
103 |
+
25,
|
104 |
+
]
|
105 |
+
|
106 |
+
sk_joints_to_remove = [5, 10, 16, 21, 26]
|
107 |
+
|
108 |
+
joint_names = [
|
109 |
+
"Hips",
|
110 |
+
"LeftUpLeg",
|
111 |
+
"LeftLeg",
|
112 |
+
"LeftFoot",
|
113 |
+
"LeftToe",
|
114 |
+
"RightUpLeg",
|
115 |
+
"RightLeg",
|
116 |
+
"RightFoot",
|
117 |
+
"RightToe",
|
118 |
+
"Spine",
|
119 |
+
"Spine1",
|
120 |
+
"Spine2",
|
121 |
+
"Neck",
|
122 |
+
"Head",
|
123 |
+
"LeftShoulder",
|
124 |
+
"LeftArm",
|
125 |
+
"LeftForeArm",
|
126 |
+
"LeftHand",
|
127 |
+
"RightShoulder",
|
128 |
+
"RightArm",
|
129 |
+
"RightForeArm",
|
130 |
+
"RightHand",
|
131 |
+
]
|
132 |
+
|
133 |
+
|
134 |
+
class Skeleton:
|
135 |
+
def __init__(
|
136 |
+
self,
|
137 |
+
offsets,
|
138 |
+
parents,
|
139 |
+
joints_left=None,
|
140 |
+
joints_right=None,
|
141 |
+
bone_length=None,
|
142 |
+
device=None,
|
143 |
+
):
|
144 |
+
assert len(offsets) == len(parents)
|
145 |
+
|
146 |
+
self._offsets = torch.Tensor(offsets).to(device)
|
147 |
+
self._parents = np.array(parents)
|
148 |
+
self._joints_left = joints_left
|
149 |
+
self._joints_right = joints_right
|
150 |
+
self._compute_metadata()
|
151 |
+
|
152 |
+
def num_joints(self):
|
153 |
+
return self._offsets.shape[0]
|
154 |
+
|
155 |
+
def offsets(self):
|
156 |
+
return self._offsets
|
157 |
+
|
158 |
+
def parents(self):
|
159 |
+
return self._parents
|
160 |
+
|
161 |
+
def has_children(self):
|
162 |
+
return self._has_children
|
163 |
+
|
164 |
+
def children(self):
|
165 |
+
return self._children
|
166 |
+
|
167 |
+
def convert_to_global_pos(self, unit_vec_rerp):
|
168 |
+
"""
|
169 |
+
Convert the unit offset matrix to global position.
|
170 |
+
First row(root) will have absolute position value in global coordinates.
|
171 |
+
"""
|
172 |
+
bone_length = self.get_bone_length_weight()
|
173 |
+
batch_size = unit_vec_rerp.size(0)
|
174 |
+
seq_len = unit_vec_rerp.size(1)
|
175 |
+
unit_vec_table = unit_vec_rerp.reshape(batch_size, seq_len, 22, 3)
|
176 |
+
global_position = torch.zeros_like(unit_vec_table, device=unit_vec_table.device)
|
177 |
+
|
178 |
+
for i, parent in enumerate(self._parents):
|
179 |
+
if parent == -1: # if root
|
180 |
+
global_position[:, :, i] = unit_vec_table[:, :, i]
|
181 |
+
|
182 |
+
else:
|
183 |
+
global_position[:, :, i] = global_position[:, :, parent] + (
|
184 |
+
nn.functional.normalize(unit_vec_table[:, :, i], p=2.0, dim=-1)
|
185 |
+
* bone_length[i]
|
186 |
+
)
|
187 |
+
|
188 |
+
return global_position
|
189 |
+
|
190 |
+
def convert_to_unit_offset_mat(self, global_position):
|
191 |
+
"""
|
192 |
+
Convert the global position of the skeleton to a unit offset matrix.
|
193 |
+
First row(root) will have absolute position value in global coordinates.
|
194 |
+
"""
|
195 |
+
|
196 |
+
bone_length = self.get_bone_length_weight()
|
197 |
+
unit_offset_mat = torch.zeros_like(
|
198 |
+
global_position, device=global_position.device
|
199 |
+
)
|
200 |
+
|
201 |
+
for i, parent in enumerate(self._parents):
|
202 |
+
|
203 |
+
if parent == -1: # if root
|
204 |
+
unit_offset_mat[:, :, i] = global_position[:, :, i]
|
205 |
+
else:
|
206 |
+
unit_offset_mat[:, :, i] = (
|
207 |
+
global_position[:, :, i] - global_position[:, :, parent]
|
208 |
+
) / bone_length[i]
|
209 |
+
|
210 |
+
return unit_offset_mat
|
211 |
+
|
212 |
+
def remove_joints(self, joints_to_remove):
|
213 |
+
"""
|
214 |
+
Remove the joints specified in 'joints_to_remove', both from the
|
215 |
+
skeleton definition and from the dataset (which is modified in place).
|
216 |
+
The rotations of removed joints are propagated along the kinematic chain.
|
217 |
+
"""
|
218 |
+
valid_joints = []
|
219 |
+
for joint in range(len(self._parents)):
|
220 |
+
if joint not in joints_to_remove:
|
221 |
+
valid_joints.append(joint)
|
222 |
+
|
223 |
+
index_offsets = np.zeros(len(self._parents), dtype=int)
|
224 |
+
new_parents = []
|
225 |
+
for i, parent in enumerate(self._parents):
|
226 |
+
if i not in joints_to_remove:
|
227 |
+
new_parents.append(parent - index_offsets[parent])
|
228 |
+
else:
|
229 |
+
index_offsets[i:] += 1
|
230 |
+
self._parents = np.array(new_parents)
|
231 |
+
|
232 |
+
self._offsets = self._offsets[valid_joints]
|
233 |
+
self._compute_metadata()
|
234 |
+
|
235 |
+
def forward_kinematics(self, rotations, root_positions):
|
236 |
+
"""
|
237 |
+
Perform forward kinematics using the given trajectory and local rotations.
|
238 |
+
Arguments (where N = batch size, L = sequence length, J = number of joints):
|
239 |
+
-- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint.
|
240 |
+
-- root_positions: (N, L, 3) tensor describing the root joint positions.
|
241 |
+
"""
|
242 |
+
assert len(rotations.shape) == 4
|
243 |
+
assert rotations.shape[-1] == 4
|
244 |
+
|
245 |
+
positions_world = []
|
246 |
+
rotations_world = []
|
247 |
+
|
248 |
+
expanded_offsets = self._offsets.expand(
|
249 |
+
rotations.shape[0],
|
250 |
+
rotations.shape[1],
|
251 |
+
self._offsets.shape[0],
|
252 |
+
self._offsets.shape[1],
|
253 |
+
)
|
254 |
+
|
255 |
+
# Parallelize along the batch and time dimensions
|
256 |
+
for i in range(self._offsets.shape[0]):
|
257 |
+
if self._parents[i] == -1:
|
258 |
+
positions_world.append(root_positions)
|
259 |
+
rotations_world.append(rotations[:, :, 0])
|
260 |
+
else:
|
261 |
+
positions_world.append(
|
262 |
+
qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i])
|
263 |
+
+ positions_world[self._parents[i]]
|
264 |
+
)
|
265 |
+
if self._has_children[i]:
|
266 |
+
rotations_world.append(
|
267 |
+
qmul(rotations_world[self._parents[i]], rotations[:, :, i])
|
268 |
+
)
|
269 |
+
else:
|
270 |
+
# This joint is a terminal node -> it would be useless to compute the transformation
|
271 |
+
rotations_world.append(None)
|
272 |
+
|
273 |
+
return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2)
|
274 |
+
|
275 |
+
def forward_kinematics_with_rotation(self, rotations, root_positions):
|
276 |
+
"""
|
277 |
+
Perform forward kinematics using the given trajectory and local rotations.
|
278 |
+
Arguments (where N = batch size, L = sequence length, J = number of joints):
|
279 |
+
-- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint.
|
280 |
+
-- root_positions: (N, L, 3) tensor describing the root joint positions.
|
281 |
+
"""
|
282 |
+
assert len(rotations.shape) == 4
|
283 |
+
assert rotations.shape[-1] == 4
|
284 |
+
|
285 |
+
positions_world = []
|
286 |
+
rotations_world = []
|
287 |
+
|
288 |
+
expanded_offsets = self._offsets.expand(
|
289 |
+
rotations.shape[0],
|
290 |
+
rotations.shape[1],
|
291 |
+
self._offsets.shape[0],
|
292 |
+
self._offsets.shape[1],
|
293 |
+
)
|
294 |
+
|
295 |
+
# Parallelize along the batch and time dimensions
|
296 |
+
for i in range(self._offsets.shape[0]):
|
297 |
+
if self._parents[i] == -1:
|
298 |
+
positions_world.append(root_positions)
|
299 |
+
rotations_world.append(rotations[:, :, 0])
|
300 |
+
else:
|
301 |
+
positions_world.append(
|
302 |
+
qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i])
|
303 |
+
+ positions_world[self._parents[i]]
|
304 |
+
)
|
305 |
+
if self._has_children[i]:
|
306 |
+
rotations_world.append(
|
307 |
+
qmul(rotations_world[self._parents[i]], rotations[:, :, i])
|
308 |
+
)
|
309 |
+
else:
|
310 |
+
# This joint is a terminal node -> it would be useless to compute the transformation
|
311 |
+
rotations_world.append(
|
312 |
+
torch.Tensor([1, 0, 0, 0])
|
313 |
+
.expand(rotations.shape[0], rotations.shape[1], 4)
|
314 |
+
.to(rotations.device)
|
315 |
+
)
|
316 |
+
|
317 |
+
return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2), torch.stack(
|
318 |
+
rotations_world, dim=3
|
319 |
+
).permute(0, 1, 3, 2)
|
320 |
+
|
321 |
+
def get_bone_length_weight(self):
|
322 |
+
bone_length = []
|
323 |
+
for i, parent in enumerate(self._parents):
|
324 |
+
if parent == -1:
|
325 |
+
bone_length.append(1)
|
326 |
+
else:
|
327 |
+
bone_length.append(
|
328 |
+
torch.linalg.norm(self._offsets[i : i + 1], ord="fro").item()
|
329 |
+
)
|
330 |
+
return torch.Tensor(bone_length)
|
331 |
+
|
332 |
+
def joints_left(self):
|
333 |
+
return self._joints_left
|
334 |
+
|
335 |
+
def joints_right(self):
|
336 |
+
return self._joints_right
|
337 |
+
|
338 |
+
def _compute_metadata(self):
|
339 |
+
self._has_children = np.zeros(len(self._parents)).astype(bool)
|
340 |
+
for i, parent in enumerate(self._parents):
|
341 |
+
if parent != -1:
|
342 |
+
self._has_children[parent] = True
|
343 |
+
|
344 |
+
self._children = []
|
345 |
+
for i, parent in enumerate(self._parents):
|
346 |
+
self._children.append([])
|
347 |
+
for i, parent in enumerate(self._parents):
|
348 |
+
if parent != -1:
|
349 |
+
self._children[parent].append(i)
|
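A minimal sketch of driving the Skeleton defined above with identity rotations (assuming the cmib quaternion utilities imported at the top of this file are installed; after remove_joints the skeleton has 22 joints, matching joint_names):

```python
import torch
from infiller.lib.model.skeleton import (  # assumed import path
    Skeleton, sk_offsets, sk_parents, sk_joints_to_remove)

skel = Skeleton(sk_offsets, sk_parents, device="cpu")
skel.remove_joints(sk_joints_to_remove)

N, L, J = 1, 4, skel.num_joints()              # J == 22 after joint removal
rotations = torch.zeros(N, L, J, 4)
rotations[..., 0] = 1.0                        # identity quaternions, (w, x, y, z) order assumed
root_positions = torch.zeros(N, L, 3)

joints = skel.forward_kinematics(rotations, root_positions)
print(joints.shape)                            # torch.Size([1, 4, 22, 3])
```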
infiller/lib/vis/pose.py
ADDED
@@ -0,0 +1,248 @@
1 |
+
import os
|
2 |
+
import pathlib
|
3 |
+
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def project_root_position(position_arr: np.array, file_name: str):
|
9 |
+
"""
|
10 |
+
Take a batch of root arrays and project it onto a 2D plane
|
11 |
+
|
12 |
+
N: samples
|
13 |
+
L: trajectory length
|
14 |
+
J: joints
|
15 |
+
|
16 |
+
position_arr: [N,L,J,3]
|
17 |
+
"""
|
18 |
+
|
19 |
+
root_joints = position_arr[:, :, 0]
|
20 |
+
|
21 |
+
x_pos = root_joints[:, :, 0]
|
22 |
+
y_pos = root_joints[:, :, 2]
|
23 |
+
|
24 |
+
fig = plt.figure()
|
25 |
+
|
26 |
+
for i in range(x_pos.shape[1]):
|
27 |
+
|
28 |
+
if i == 0:
|
29 |
+
plt.scatter(x_pos[:, i], y_pos[:, i], c="b")
|
30 |
+
elif i == x_pos.shape[1] - 1:
|
31 |
+
plt.scatter(x_pos[:, i], y_pos[:, i], c="r")
|
32 |
+
else:
|
33 |
+
plt.scatter(x_pos[:, i], y_pos[:, i], c="k", marker="*", s=1)
|
34 |
+
|
35 |
+
plt.title(f"Root Position: {file_name}")
|
36 |
+
plt.xlabel("X Axis")
|
37 |
+
plt.ylabel("Y Axis")
|
38 |
+
plt.xlim((-300, 300))
|
39 |
+
plt.ylim((-300, 300))
|
40 |
+
plt.grid()
|
41 |
+
plt.savefig(f"{file_name}.png", dpi=200)
|
42 |
+
|
43 |
+
|
44 |
+
def plot_single_pose(
|
45 |
+
pose,
|
46 |
+
frame_idx,
|
47 |
+
skeleton,
|
48 |
+
save_dir,
|
49 |
+
prefix,
|
50 |
+
):
|
51 |
+
fig = plt.figure()
|
52 |
+
ax = fig.add_subplot(111, projection="3d")
|
53 |
+
|
54 |
+
parent_idx = skeleton.parents()
|
55 |
+
|
56 |
+
for i, p in enumerate(parent_idx):
|
57 |
+
if i > 0:
|
58 |
+
ax.plot(
|
59 |
+
[pose[i, 0], pose[p, 0]],
|
60 |
+
[pose[i, 2], pose[p, 2]],
|
61 |
+
[pose[i, 1], pose[p, 1]],
|
62 |
+
c="k",
|
63 |
+
)
|
64 |
+
|
65 |
+
x_min = pose[:, 0].min()
|
66 |
+
x_max = pose[:, 0].max()
|
67 |
+
|
68 |
+
y_min = pose[:, 1].min()
|
69 |
+
y_max = pose[:, 1].max()
|
70 |
+
|
71 |
+
z_min = pose[:, 2].min()
|
72 |
+
z_max = pose[:, 2].max()
|
73 |
+
|
74 |
+
ax.set_xlim(x_min, x_max)
|
75 |
+
ax.set_xlabel("$X$ Axis")
|
76 |
+
|
77 |
+
ax.set_ylim(z_min, z_max)
|
78 |
+
ax.set_ylabel("$Y$ Axis")
|
79 |
+
|
80 |
+
ax.set_zlim(y_min, y_max)
|
81 |
+
ax.set_zlabel("$Z$ Axis")
|
82 |
+
|
83 |
+
plt.draw()
|
84 |
+
|
85 |
+
title = f"{prefix}: {frame_idx}"
|
86 |
+
plt.title(title)
|
87 |
+
prefix = prefix
|
88 |
+
pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
|
89 |
+
plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
|
90 |
+
plt.close()
|
91 |
+
|
92 |
+
|
93 |
+
def plot_pose(
|
94 |
+
start_pose,
|
95 |
+
inbetween_pose,
|
96 |
+
target_pose,
|
97 |
+
frame_idx,
|
98 |
+
skeleton,
|
99 |
+
save_dir,
|
100 |
+
prefix,
|
101 |
+
):
|
102 |
+
fig = plt.figure()
|
103 |
+
ax = fig.add_subplot(111, projection="3d")
|
104 |
+
|
105 |
+
parent_idx = skeleton.parents()
|
106 |
+
|
107 |
+
for i, p in enumerate(parent_idx):
|
108 |
+
if i > 0:
|
109 |
+
ax.plot(
|
110 |
+
[start_pose[i, 0], start_pose[p, 0]],
|
111 |
+
[start_pose[i, 2], start_pose[p, 2]],
|
112 |
+
[start_pose[i, 1], start_pose[p, 1]],
|
113 |
+
c="b",
|
114 |
+
)
|
115 |
+
ax.plot(
|
116 |
+
[inbetween_pose[i, 0], inbetween_pose[p, 0]],
|
117 |
+
[inbetween_pose[i, 2], inbetween_pose[p, 2]],
|
118 |
+
[inbetween_pose[i, 1], inbetween_pose[p, 1]],
|
119 |
+
c="k",
|
120 |
+
)
|
121 |
+
ax.plot(
|
122 |
+
[target_pose[i, 0], target_pose[p, 0]],
|
123 |
+
[target_pose[i, 2], target_pose[p, 2]],
|
124 |
+
[target_pose[i, 1], target_pose[p, 1]],
|
125 |
+
c="r",
|
126 |
+
)
|
127 |
+
|
128 |
+
x_min = np.min(
|
129 |
+
[start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()]
|
130 |
+
)
|
131 |
+
x_max = np.max(
|
132 |
+
[start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()]
|
133 |
+
)
|
134 |
+
|
135 |
+
y_min = np.min(
|
136 |
+
[start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()]
|
137 |
+
)
|
138 |
+
y_max = np.max(
|
139 |
+
[start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()]
|
140 |
+
)
|
141 |
+
|
142 |
+
z_min = np.min(
|
143 |
+
[start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()]
|
144 |
+
)
|
145 |
+
z_max = np.max(
|
146 |
+
[start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()]
|
147 |
+
)
|
148 |
+
|
149 |
+
ax.set_xlim(x_min, x_max)
|
150 |
+
ax.set_xlabel("$X$ Axis")
|
151 |
+
|
152 |
+
ax.set_ylim(z_min, z_max)
|
153 |
+
ax.set_ylabel("$Y$ Axis")
|
154 |
+
|
155 |
+
ax.set_zlim(y_min, y_max)
|
156 |
+
ax.set_zlabel("$Z$ Axis")
|
157 |
+
|
158 |
+
plt.draw()
|
159 |
+
|
160 |
+
title = f"{prefix}: {frame_idx}"
|
161 |
+
plt.title(title)
|
162 |
+
prefix = prefix
|
163 |
+
pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
|
164 |
+
plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
|
165 |
+
plt.close()
|
166 |
+
|
167 |
+
|
168 |
+
def plot_pose_with_stop(
|
169 |
+
start_pose,
|
170 |
+
inbetween_pose,
|
171 |
+
target_pose,
|
172 |
+
stopover,
|
173 |
+
frame_idx,
|
174 |
+
skeleton,
|
175 |
+
save_dir,
|
176 |
+
prefix,
|
177 |
+
):
|
178 |
+
fig = plt.figure()
|
179 |
+
ax = fig.add_subplot(111, projection="3d")
|
180 |
+
|
181 |
+
parent_idx = skeleton.parents()
|
182 |
+
|
183 |
+
for i, p in enumerate(parent_idx):
|
184 |
+
if i > 0:
|
185 |
+
ax.plot(
|
186 |
+
[start_pose[i, 0], start_pose[p, 0]],
|
187 |
+
[start_pose[i, 2], start_pose[p, 2]],
|
188 |
+
[start_pose[i, 1], start_pose[p, 1]],
|
189 |
+
c="b",
|
190 |
+
)
|
191 |
+
ax.plot(
|
192 |
+
[inbetween_pose[i, 0], inbetween_pose[p, 0]],
|
193 |
+
[inbetween_pose[i, 2], inbetween_pose[p, 2]],
|
194 |
+
[inbetween_pose[i, 1], inbetween_pose[p, 1]],
|
195 |
+
c="k",
|
196 |
+
)
|
197 |
+
ax.plot(
|
198 |
+
[target_pose[i, 0], target_pose[p, 0]],
|
199 |
+
[target_pose[i, 2], target_pose[p, 2]],
|
200 |
+
[target_pose[i, 1], target_pose[p, 1]],
|
201 |
+
c="r",
|
202 |
+
)
|
203 |
+
|
204 |
+
ax.plot(
|
205 |
+
[stopover[i, 0], stopover[p, 0]],
|
206 |
+
[stopover[i, 2], stopover[p, 2]],
|
207 |
+
[stopover[i, 1], stopover[p, 1]],
|
208 |
+
c="indigo",
|
209 |
+
)
|
210 |
+
|
211 |
+
x_min = np.min(
|
212 |
+
[start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()]
|
213 |
+
)
|
214 |
+
x_max = np.max(
|
215 |
+
[start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()]
|
216 |
+
)
|
217 |
+
|
218 |
+
y_min = np.min(
|
219 |
+
[start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()]
|
220 |
+
)
|
221 |
+
y_max = np.max(
|
222 |
+
[start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()]
|
223 |
+
)
|
224 |
+
|
225 |
+
z_min = np.min(
|
226 |
+
[start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()]
|
227 |
+
)
|
228 |
+
z_max = np.max(
|
229 |
+
[start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()]
|
230 |
+
)
|
231 |
+
|
232 |
+
ax.set_xlim(x_min, x_max)
|
233 |
+
ax.set_xlabel("$X$ Axis")
|
234 |
+
|
235 |
+
ax.set_ylim(z_min, z_max)
|
236 |
+
ax.set_ylabel("$Y$ Axis")
|
237 |
+
|
238 |
+
ax.set_zlim(y_min, y_max)
|
239 |
+
ax.set_zlabel("$Z$ Axis")
|
240 |
+
|
241 |
+
plt.draw()
|
242 |
+
|
243 |
+
title = f"{prefix}: {frame_idx}"
|
244 |
+
plt.title(title)
|
245 |
+
prefix = prefix
|
246 |
+
pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
|
247 |
+
plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
|
248 |
+
plt.close()
|
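Hedged example of calling the plotting helper above on a single frame; the 22x3 joint array is random, and the skeleton comes from infiller.lib.model.skeleton, which provides a compatible parents() list (paths and sizes are assumptions, not repository defaults):

```python
import numpy as np
from infiller.lib.model.skeleton import Skeleton, sk_offsets, sk_parents, sk_joints_to_remove
from infiller.lib.vis.pose import plot_single_pose  # assumed import path

skel = Skeleton(sk_offsets, sk_parents, device="cpu")
skel.remove_joints(sk_joints_to_remove)

pose = np.random.rand(22, 3) * 50.0            # one frame of global joint positions
plot_single_pose(pose, frame_idx=0, skeleton=skel, save_dir="debug_vis", prefix="demo_")
# writes debug_vis/demo_0.png
```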
lib/core/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (2.87 kB). View file
|
|
lib/core/constants.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
FOCAL_LENGTH = 5000.
|
2 |
+
|
3 |
+
# Mean and standard deviation for normalizing input image
|
4 |
+
IMG_NORM_MEAN = [0.485, 0.456, 0.406]
|
5 |
+
IMG_NORM_STD = [0.229, 0.224, 0.225]
|
6 |
+
|
7 |
+
"""
|
8 |
+
We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides.
|
9 |
+
We keep a superset of 24 joints such that we include all joints from every dataset.
|
10 |
+
If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
|
11 |
+
The joints used here are the following:
|
12 |
+
"""
|
13 |
+
JOINT_NAMES = [
|
14 |
+
'OP Nose', 'OP Neck', 'OP RShoulder', #0,1,2
|
15 |
+
'OP RElbow', 'OP RWrist', 'OP LShoulder', #3,4,5
|
16 |
+
'OP LElbow', 'OP LWrist', 'OP MidHip', #6, 7,8
|
17 |
+
'OP RHip', 'OP RKnee', 'OP RAnkle', #9,10,11
|
18 |
+
'OP LHip', 'OP LKnee', 'OP LAnkle', #12,13,14
|
19 |
+
'OP REye', 'OP LEye', 'OP REar', #15,16,17
|
20 |
+
'OP LEar', 'OP LBigToe', 'OP LSmallToe', #18,19,20
|
21 |
+
'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', #21, 22, 23, 24 ##Total 25 joints for openpose
|
22 |
+
'Right Ankle', 'Right Knee', 'Right Hip', #0,1,2
|
23 |
+
'Left Hip', 'Left Knee', 'Left Ankle', #3, 4, 5
|
24 |
+
'Right Wrist', 'Right Elbow', 'Right Shoulder', #6
|
25 |
+
'Left Shoulder', 'Left Elbow', 'Left Wrist', #9
|
26 |
+
'Neck (LSP)', 'Top of Head (LSP)', #12, 13
|
27 |
+
'Pelvis (MPII)', 'Thorax (MPII)', #14, 15
|
28 |
+
'Spine (H36M)', 'Jaw (H36M)', #16, 17
|
29 |
+
'Head (H36M)', 'Nose', 'Left Eye', #18, 19, 20
|
30 |
+
'Right Eye', 'Left Ear', 'Right Ear' #21,22,23 (Total 24 joints)
|
31 |
+
]
|
32 |
+
|
33 |
+
# Dict containing the joints in numerical order
|
34 |
+
JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}
|
35 |
+
|
36 |
+
# Map joints to SMPL joints
|
37 |
+
JOINT_MAP = {
|
38 |
+
'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
|
39 |
+
'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
|
40 |
+
'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
|
41 |
+
'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
|
42 |
+
'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
|
43 |
+
'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
|
44 |
+
'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
|
45 |
+
'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
|
46 |
+
'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
|
47 |
+
'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
|
48 |
+
'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
|
49 |
+
'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
|
50 |
+
'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
|
51 |
+
'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
|
52 |
+
'Spine (H36M)': 51, 'Jaw (H36M)': 52,
|
53 |
+
'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
|
54 |
+
'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
|
55 |
+
}
|
56 |
+
|
57 |
+
# Joint selectors
|
58 |
+
# Indices to get the 14 LSP joints from the 17 H36M joints
|
59 |
+
H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
|
60 |
+
H36M_TO_J14 = H36M_TO_J17[:14]
|
61 |
+
# Indices to get the 14 LSP joints from the ground truth joints
|
62 |
+
J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17]
|
63 |
+
J24_TO_J14 = J24_TO_J17[:14]
|
64 |
+
|
65 |
+
# Permutation of SMPL pose parameters when flipping the shape
|
66 |
+
SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22]
|
67 |
+
SMPL_POSE_FLIP_PERM = []
|
68 |
+
for i in SMPL_JOINTS_FLIP_PERM:
|
69 |
+
SMPL_POSE_FLIP_PERM.append(3*i)
|
70 |
+
SMPL_POSE_FLIP_PERM.append(3*i+1)
|
71 |
+
SMPL_POSE_FLIP_PERM.append(3*i+2)
|
72 |
+
# Permutation indices for the 24 ground truth joints
|
73 |
+
J24_FLIP_PERM = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
|
74 |
+
# Permutation indices for the full set of 49 joints
|
75 |
+
J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\
|
76 |
+
+ [25+i for i in J24_FLIP_PERM]
|
77 |
+
|
78 |
+
|
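For reference, the lookups above are consumed like this (a small sketch; the indices follow directly from the lists defined in this file):

```python
from lib.core import constants

# position of the OpenPose right wrist inside the 49-joint superset
print(constants.JOINT_IDS['OP RWrist'])    # 4
# index of the same joint in the SMPL joint regressor output
print(constants.JOINT_MAP['OP RWrist'])    # 21
# take the 14 LSP joints out of a 17-joint H36M prediction
lsp_subset = constants.H36M_TO_J14         # first 14 entries of H36M_TO_J17
```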
lib/datasets/__pycache__/track_dataset.cpython-310.pyc
ADDED
Binary file (2.28 kB). View file
|
|
lib/datasets/track_dataset.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
from torchvision.transforms import Normalize, ToTensor, Compose
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
from lib.core import constants
|
8 |
+
from lib.utils.imutils import crop, boxes_2_cs
|
9 |
+
|
10 |
+
|
11 |
+
class TrackDatasetEval(Dataset):
|
12 |
+
"""
|
13 |
+
Track Dataset Class - Load images/crops of the tracked boxes.
|
14 |
+
"""
|
15 |
+
def __init__(self, imgfiles, boxes,
|
16 |
+
crop_size=256, dilate=1.0,
|
17 |
+
img_focal=None, img_center=None, normalization=True,
|
18 |
+
item_idx=0, do_flip=False):
|
19 |
+
super(TrackDatasetEval, self).__init__()
|
20 |
+
|
21 |
+
self.imgfiles = imgfiles
|
22 |
+
self.crop_size = crop_size
|
23 |
+
self.normalization = normalization
|
24 |
+
self.normalize_img = Compose([
|
25 |
+
ToTensor(),
|
26 |
+
Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)
|
27 |
+
])
|
28 |
+
|
29 |
+
self.boxes = boxes
|
30 |
+
self.box_dilate = dilate
|
31 |
+
self.centers, self.scales = boxes_2_cs(boxes)
|
32 |
+
|
33 |
+
self.img_focal = img_focal
|
34 |
+
self.img_center = img_center
|
35 |
+
self.item_idx = item_idx
|
36 |
+
self.do_flip = do_flip
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self.imgfiles)
|
40 |
+
|
41 |
+
|
42 |
+
def __getitem__(self, index):
|
43 |
+
item = {}
|
44 |
+
imgfile = self.imgfiles[index]
|
45 |
+
scale = self.scales[index] * self.box_dilate
|
46 |
+
center = self.centers[index]
|
47 |
+
|
48 |
+
img_focal = self.img_focal
|
49 |
+
img_center = self.img_center
|
50 |
+
|
51 |
+
img = cv2.imread(imgfile)[:,:,::-1]
|
52 |
+
if self.do_flip:
|
53 |
+
img = img[:, ::-1, :]
|
54 |
+
img_width = img.shape[1]
|
55 |
+
center[0] = img_width - center[0] - 1
|
56 |
+
img_crop = crop(img, center, scale,
|
57 |
+
[self.crop_size, self.crop_size],
|
58 |
+
rot=0).astype('uint8')
|
59 |
+
# cv2.imwrite('debug_crop.png', img_crop[:,:,::-1])
|
60 |
+
|
61 |
+
if self.normalization:
|
62 |
+
img_crop = self.normalize_img(img_crop)
|
63 |
+
else:
|
64 |
+
img_crop = torch.from_numpy(img_crop)
|
65 |
+
item['img'] = img_crop
|
66 |
+
|
67 |
+
if self.do_flip:
|
68 |
+
# center[0] = img_width - center[0] - 1
|
69 |
+
item['do_flip'] = torch.tensor(1).float()
|
70 |
+
item['img_idx'] = torch.tensor(index).long()
|
71 |
+
item['scale'] = torch.tensor(scale).float()
|
72 |
+
item['center'] = torch.tensor(center).float()
|
73 |
+
item['img_focal'] = torch.tensor(img_focal).float()
|
74 |
+
item['img_center'] = torch.tensor(img_center).float()
|
75 |
+
|
76 |
+
|
77 |
+
return item
|
78 |
+
|
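A hedged usage sketch for the dataset above: wrap per-frame image paths and tracked boxes into normalized 256x256 crops. The file paths, the (T, 5) box layout, and the intrinsics values are placeholders, not repository defaults.

```python
import numpy as np
import torch
from lib.datasets.track_dataset import TrackDatasetEval

imgfiles = [f"frames/{i:06d}.jpg" for i in range(100)]          # hypothetical frame paths
boxes = np.tile([100.0, 120.0, 400.0, 420.0, 1.0], (100, 1))    # x1, y1, x2, y2, score (assumed)
dataset = TrackDatasetEval(imgfiles, boxes, crop_size=256, dilate=1.2,
                           img_focal=600.0, img_center=(960.0, 540.0))
loader = torch.utils.data.DataLoader(dataset, batch_size=16, num_workers=2)
for batch in loader:
    print(batch['img'].shape, batch['center'].shape)            # (16, 3, 256, 256), (16, 2)
    break
```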
lib/eval_utils/__pycache__/custom_utils.cpython-310.pyc
ADDED
Binary file (2.96 kB). View file
|
|
lib/eval_utils/__pycache__/filling_utils.cpython-310.pyc
ADDED
Binary file (6.88 kB). View file
|
|
lib/eval_utils/custom_utils.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
import copy
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from hawor.utils.process import run_mano, run_mano_left
|
6 |
+
from hawor.utils.rotation import angle_axis_to_quaternion, rotation_matrix_to_angle_axis
|
7 |
+
from scipy.interpolate import interp1d
|
8 |
+
|
9 |
+
|
10 |
+
def cam2world_convert(R_c2w_sla, t_c2w_sla, data_out, handedness):
|
11 |
+
init_rot_mat = copy.deepcopy(data_out["init_root_orient"])
|
12 |
+
init_rot_mat = torch.einsum("tij,btjk->btik", R_c2w_sla, init_rot_mat)
|
13 |
+
init_rot = rotation_matrix_to_angle_axis(init_rot_mat)
|
14 |
+
init_rot_quat = angle_axis_to_quaternion(init_rot)
|
15 |
+
# data_out["init_root_orient"] = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
|
16 |
+
# data_out["init_hand_pose"] = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
|
17 |
+
data_out_init_root_orient = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
|
18 |
+
data_out_init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
|
19 |
+
|
20 |
+
init_trans = data_out["init_trans"] # (B, T, 3)
|
21 |
+
if handedness == "right":
|
22 |
+
outputs = run_mano(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
|
23 |
+
elif handedness == "left":
|
24 |
+
outputs = run_mano_left(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
|
25 |
+
root_loc = outputs["joints"][..., 0, :].cpu() # (B, T, 3)
|
26 |
+
offset = init_trans - root_loc # It is a constant, no matter what the rotation is.
|
27 |
+
init_trans = (
|
28 |
+
torch.einsum("tij,btj->bti", R_c2w_sla, root_loc)
|
29 |
+
+ t_c2w_sla[None, :]
|
30 |
+
+ offset
|
31 |
+
)
|
32 |
+
|
33 |
+
data_world = {
|
34 |
+
"init_root_orient": init_rot, # (B, T, 3)
|
35 |
+
"init_hand_pose": data_out_init_hand_pose, # (B, T, 15, 3)
|
36 |
+
"init_trans": init_trans, # (B, T, 3)
|
37 |
+
"init_betas": data_out["init_betas"] # (B, T, 10)
|
38 |
+
}
|
39 |
+
|
40 |
+
return data_world
|
41 |
+
|
42 |
+
def quaternion_to_matrix(quaternions):
|
43 |
+
"""
|
44 |
+
Convert rotations given as quaternions to rotation matrices.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
quaternions: quaternions with real part first,
|
48 |
+
as tensor of shape (..., 4).
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
52 |
+
"""
|
53 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
54 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
55 |
+
|
56 |
+
o = torch.stack(
|
57 |
+
(
|
58 |
+
1 - two_s * (j * j + k * k),
|
59 |
+
two_s * (i * j - k * r),
|
60 |
+
two_s * (i * k + j * r),
|
61 |
+
two_s * (i * j + k * r),
|
62 |
+
1 - two_s * (i * i + k * k),
|
63 |
+
two_s * (j * k - i * r),
|
64 |
+
two_s * (i * k - j * r),
|
65 |
+
two_s * (j * k + i * r),
|
66 |
+
1 - two_s * (i * i + j * j),
|
67 |
+
),
|
68 |
+
-1,
|
69 |
+
)
|
70 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
71 |
+
|
72 |
+
def load_slam_cam(fpath):
|
73 |
+
print(f"Loading cameras from {fpath}...")
|
74 |
+
pred_cam = dict(np.load(fpath, allow_pickle=True))
|
75 |
+
pred_traj = pred_cam['traj']
|
76 |
+
t_c2w_sla = torch.tensor(pred_traj[:, :3]) * pred_cam['scale']
|
77 |
+
pred_camq = torch.tensor(pred_traj[:, 3:])
|
78 |
+
R_c2w_sla = quaternion_to_matrix(pred_camq[:,[3,0,1,2]])
|
79 |
+
R_w2c_sla = R_c2w_sla.transpose(-1, -2)
|
80 |
+
t_w2c_sla = -torch.einsum("bij,bj->bi", R_w2c_sla, t_c2w_sla)
|
81 |
+
return R_w2c_sla, t_w2c_sla, R_c2w_sla, t_c2w_sla
|
82 |
+
|
83 |
+
|
84 |
+
def interpolate_bboxes(bboxes):
|
85 |
+
T = bboxes.shape[0]
|
86 |
+
|
87 |
+
zero_indices = np.where(np.all(bboxes == 0, axis=1))[0]
|
88 |
+
|
89 |
+
non_zero_indices = np.where(np.any(bboxes != 0, axis=1))[0]
|
90 |
+
|
91 |
+
if len(zero_indices) == 0:
|
92 |
+
return bboxes
|
93 |
+
|
94 |
+
interpolated_bboxes = bboxes.copy()
|
95 |
+
for i in range(5):
|
96 |
+
interp_func = interp1d(non_zero_indices, bboxes[non_zero_indices, i], kind='linear', fill_value="extrapolate")
|
97 |
+
interpolated_bboxes[zero_indices, i] = interp_func(zero_indices)
|
98 |
+
|
99 |
+
return interpolated_bboxes
|
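To make the camera convention in load_slam_cam above explicit: each row of 'traj' is read as [tx, ty, tz, qx, qy, qz, qw] (the [3, 0, 1, 2] re-ordering puts the real part first before quaternion_to_matrix), and 'scale' rescales the SLAM translation. A tiny self-contained check with an identity trajectory (the file name is made up):

```python
import numpy as np
import torch
from lib.eval_utils.custom_utils import load_slam_cam

traj = np.zeros((10, 7))
traj[:, 6] = 1.0                                   # qw = 1: identity rotations
np.savez("slam_cam_demo.npz", traj=traj, scale=1.0)

R_w2c, t_w2c, R_c2w, t_c2w = load_slam_cam("slam_cam_demo.npz")
assert torch.allclose(R_c2w[0], torch.eye(3, dtype=R_c2w.dtype))
assert torch.allclose(t_c2w, torch.zeros_like(t_c2w))
```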
lib/eval_utils/filling_utils.py
ADDED
@@ -0,0 +1,306 @@
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
import joblib
|
4 |
+
import numpy as np
|
5 |
+
from scipy.spatial.transform import Slerp, Rotation
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from hawor.utils.process import run_mano, run_mano_left
|
9 |
+
from hawor.utils.rotation import angle_axis_to_quaternion, angle_axis_to_rotation_matrix, quaternion_to_rotation_matrix, rotation_matrix_to_angle_axis
|
10 |
+
from lib.utils.geometry import rotmat_to_rot6d
|
11 |
+
from lib.utils.geometry import rot6d_to_rotmat
|
12 |
+
|
13 |
+
def slerp_interpolation_aa(pos, valid):
|
14 |
+
|
15 |
+
B, T, N, _ = pos.shape # B: batch size, T: timesteps, N: joints, last dim: rotation parameters
|
16 |
+
pos_interp = pos.copy() # make a copy to store the interpolation results
|
17 |
+
|
18 |
+
for b in range(B):
|
19 |
+
for n in range(N):
|
20 |
+
quat_b_n = pos[b, :, n, :]
|
21 |
+
valid_b_n = valid[b, :]
|
22 |
+
|
23 |
+
invalid_idxs = np.where(~valid_b_n)[0]
|
24 |
+
valid_idxs = np.where(valid_b_n)[0]
|
25 |
+
|
26 |
+
if len(invalid_idxs) == 0:
|
27 |
+
continue
|
28 |
+
|
29 |
+
if len(valid_idxs) > 1:
|
30 |
+
valid_times = valid_idxs # valid timesteps
|
31 |
+
valid_rots = Rotation.from_rotvec(quat_b_n[valid_idxs]) # valid rotations
|
32 |
+
|
33 |
+
slerp = Slerp(valid_times, valid_rots)
|
34 |
+
|
35 |
+
for idx in invalid_idxs:
|
36 |
+
if idx < valid_idxs[0]: # timestep before the first valid one: extrapolate
|
37 |
+
pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[0]] # copy the first valid rotation
|
38 |
+
elif idx > valid_idxs[-1]: # timestep after the last valid one: extrapolate
|
39 |
+
pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[-1]] # copy the last valid rotation
|
40 |
+
else:
|
41 |
+
interp_rot = slerp([idx])
|
42 |
+
pos_interp[b, idx, n, :] = interp_rot.as_rotvec()[0]
|
43 |
+
# print("#######")
|
44 |
+
# if N > 1:
|
45 |
+
# print(pos[1,0,11])
|
46 |
+
# print(pos_interp[1,0,11])
|
47 |
+
|
48 |
+
return pos_interp
|
49 |
+
|
50 |
+
def slerp_interpolation_quat(pos, valid):
|
51 |
+
|
52 |
+
# wxyz to xyzw
|
53 |
+
pos = pos[:, :, :, [1, 2, 3, 0]]
|
54 |
+
|
55 |
+
B, T, N, _ = pos.shape # B: batch size, T: timesteps, N: joints, 4: quaternion dims
|
56 |
+
pos_interp = pos.copy() # make a copy to store the interpolation results
|
57 |
+
|
58 |
+
for b in range(B):
|
59 |
+
for n in range(N):
|
60 |
+
quat_b_n = pos[b, :, n, :]
|
61 |
+
valid_b_n = valid[b, :]
|
62 |
+
|
63 |
+
invalid_idxs = np.where(~valid_b_n)[0]
|
64 |
+
valid_idxs = np.where(valid_b_n)[0]
|
65 |
+
|
66 |
+
if len(invalid_idxs) == 0:
|
67 |
+
continue
|
68 |
+
|
69 |
+
if len(valid_idxs) > 1:
|
70 |
+
valid_times = valid_idxs # valid timesteps
|
71 |
+
valid_rots = Rotation.from_quat(quat_b_n[valid_idxs]) # valid quaternions
|
72 |
+
|
73 |
+
slerp = Slerp(valid_times, valid_rots)
|
74 |
+
|
75 |
+
for idx in invalid_idxs:
|
76 |
+
if idx < valid_idxs[0]: # timestep before the first valid one: extrapolate
|
77 |
+
pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[0]] # copy the first valid quaternion
|
78 |
+
elif idx > valid_idxs[-1]: # timestep after the last valid one: extrapolate
|
79 |
+
pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[-1]] # copy the last valid quaternion
|
80 |
+
else:
|
81 |
+
interp_rot = slerp([idx])
|
82 |
+
pos_interp[b, idx, n, :] = interp_rot.as_quat()[0]
|
83 |
+
|
84 |
+
# xyzw to wxyz
|
85 |
+
pos_interp = pos_interp[:, :, :, [3, 0, 1, 2]]
|
86 |
+
return pos_interp
|
87 |
+
|
88 |
+
|
89 |
+
def linear_interpolation_nd(pos, valid):
|
90 |
+
B, T = pos.shape[:2] # batch size B and number of timesteps T
|
91 |
+
feature_dim = pos.shape[2] # arbitrary feature dimension
|
92 |
+
pos_interp = pos.copy() # make a copy to store the interpolation results
|
93 |
+
|
94 |
+
for b in range(B):
|
95 |
+
for idx in range(feature_dim): # 针对任意维度
|
96 |
+
pos_b_idx = pos[b, :, idx] # 取出第b批次对应的**维度下的一个时间序列
|
97 |
+
valid_b = valid[b, :] # 当前批次的有效标志
|
98 |
+
|
99 |
+
# 找到无效的索引(False)
|
100 |
+
invalid_idxs = np.where(~valid_b)[0]
|
101 |
+
valid_idxs = np.where(valid_b)[0]
|
102 |
+
|
103 |
+
if len(invalid_idxs) == 0:
|
104 |
+
continue
|
105 |
+
|
106 |
+
# 对无效部分进行线性插值
|
107 |
+
if len(valid_idxs) > 1: # 确保有足够的有效点用于插值
|
108 |
+
pos_b_idx[invalid_idxs] = np.interp(invalid_idxs, valid_idxs, pos_b_idx[valid_idxs])
|
109 |
+
pos_interp[b, :, idx] = pos_b_idx # 保存插值结果
|
110 |
+
|
111 |
+
return pos_interp
|
112 |
+
|
113 |
+
def world2canonical_convert(R_c2w_sla, t_c2w_sla, data_out, handedness):
    init_rot_mat = copy.deepcopy(data_out["init_root_orient"])
    init_rot_mat = torch.einsum("tij,btjk->btik", R_c2w_sla, init_rot_mat)
    init_rot = rotation_matrix_to_angle_axis(init_rot_mat)
    init_rot_quat = angle_axis_to_quaternion(init_rot)
    # data_out["init_root_orient"] = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
    # data_out["init_hand_pose"] = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
    data_out_init_root_orient = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
    data_out_init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])

    init_trans = data_out["init_trans"]  # (B, T, 3)
    if handedness == "left":
        outputs = run_mano_left(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
    elif handedness == "right":
        outputs = run_mano(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"])
    root_loc = outputs["joints"][..., 0, :].cpu()  # (B, T, 3)
    offset = init_trans - root_loc  # It is a constant, no matter what the rotation is.
    init_trans = (
        torch.einsum("tij,btj->bti", R_c2w_sla, root_loc)
        + t_c2w_sla[None, :]
        + offset
    )

    data_world = {
        "init_root_orient": init_rot,  # (B, T, 3)
        "init_hand_pose": data_out_init_hand_pose,  # (B, T, 15, 3)
        "init_trans": init_trans,  # (B, T, 3)
        "init_betas": data_out["init_betas"]  # (B, T, 10)
    }

    return data_world

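Aside (illustrative sketch, not part of the file): world2canonical_convert is reused below with both the world-to-canonical and canonical-to-world transforms; the inverse pair obeys the usual rigid-transform identity, for example:

import torch

R_w2c = torch.eye(3)                              # placeholder rotation
t_w2c = torch.tensor([0.1, 0.0, -0.2])            # placeholder translation
R_c2w = R_w2c.transpose(-1, -2)
t_c2w = -torch.einsum("ij,j->i", R_c2w, t_w2c)

x = torch.randn(3)
x_rt = torch.einsum("ij,j->i", R_c2w, torch.einsum("ij,j->i", R_w2c, x) + t_w2c) + t_c2w
print(torch.allclose(x, x_rt))                    # True: the two maps compose to identity
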
def filling_preprocess(item):

    num_joints = 15

    global_trans = item['trans']  # (2, seq_len, 3)
    global_rot = item['rot']  # (2, seq_len, 3)
    hand_pose = item['hand_pose']  # (2, seq_len, 45)
    betas = item['betas']  # (2, seq_len, 10)
    valid = item['valid']  # (2, seq_len)

    N, T, _ = global_trans.shape
    R_canonical2world_left_aa = torch.from_numpy(global_rot[0, 0])
    R_canonical2world_right_aa = torch.from_numpy(global_rot[1, 0])
    R_world2canonical_left = angle_axis_to_rotation_matrix(R_canonical2world_left_aa).t()
    R_world2canonical_right = angle_axis_to_rotation_matrix(R_canonical2world_right_aa).t()

    # transform left hand to canonical
    hand_pose = hand_pose.reshape(N, T, num_joints, 3)
    data_world_left = {
        "init_trans": torch.from_numpy(global_trans[0:1]),
        "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[0:1])),
        "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[0:1])),
        "init_betas": torch.from_numpy(betas[0:1]),
    }

    data_left_init_root_orient = rotation_matrix_to_angle_axis(data_world_left["init_root_orient"])
    data_left_init_hand_pose = rotation_matrix_to_angle_axis(data_world_left["init_hand_pose"])
    outputs = run_mano_left(data_world_left["init_trans"], data_left_init_root_orient, data_left_init_hand_pose, betas=data_world_left["init_betas"])
    init_trans = data_world_left["init_trans"][0, 0]  # (3,)
    root_loc = outputs["joints"][0, 0, 0, :].cpu()  # (3,)
    offset = init_trans - root_loc  # It is a constant, no matter what the rotation is.
    t_world2canonical_left = -torch.einsum("ij,j->i", R_world2canonical_left, root_loc) - offset

    R_world2canonical_left = R_world2canonical_left.repeat(T, 1, 1)
    t_world2canonical_left = t_world2canonical_left.repeat(T, 1)
    data_canonical_left = world2canonical_convert(R_world2canonical_left, t_world2canonical_left, data_world_left, "left")

    # transform right hand to canonical
    data_world_right = {
        "init_trans": torch.from_numpy(global_trans[1:2]),
        "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[1:2])),
        "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[1:2])),
        "init_betas": torch.from_numpy(betas[1:2]),
    }

    data_right_init_root_orient = rotation_matrix_to_angle_axis(data_world_right["init_root_orient"])
    data_right_init_hand_pose = rotation_matrix_to_angle_axis(data_world_right["init_hand_pose"])
    outputs = run_mano(data_world_right["init_trans"], data_right_init_root_orient, data_right_init_hand_pose, betas=data_world_right["init_betas"])
    init_trans = data_world_right["init_trans"][0, 0]  # (3,)
    root_loc = outputs["joints"][0, 0, 0, :].cpu()  # (3,)
    offset = init_trans - root_loc  # It is a constant, no matter what the rotation is.
    t_world2canonical_right = -torch.einsum("ij,j->i", R_world2canonical_right, root_loc) - offset

    R_world2canonical_right = R_world2canonical_right.repeat(T, 1, 1)
    t_world2canonical_right = t_world2canonical_right.repeat(T, 1)
    data_canonical_right = world2canonical_convert(R_world2canonical_right, t_world2canonical_right, data_world_right, "right")

    # merge left and right canonical data
    global_rot = torch.cat((data_canonical_left['init_root_orient'], data_canonical_right['init_root_orient']))
    global_trans = torch.cat((data_canonical_left['init_trans'], data_canonical_right['init_trans'])).numpy()

    # global_rot = angle_axis_to_quaternion(global_rot).numpy().reshape(N, T, 1, 4)
    global_rot = global_rot.reshape(N, T, 1, 3).numpy()

    hand_pose = hand_pose.reshape(N, T, 15, 3)
    # hand_pose = angle_axis_to_quaternion(torch.from_numpy(hand_pose)).numpy()

    # lerp and slerp
    global_trans_lerped = linear_interpolation_nd(global_trans, valid)
    betas_lerped = linear_interpolation_nd(betas, valid)
    global_rot_slerped = slerp_interpolation_aa(global_rot, valid)
    hand_pose_slerped = slerp_interpolation_aa(hand_pose, valid)

    # convert to rot6d
    global_rot_slerped_mat = angle_axis_to_rotation_matrix(torch.from_numpy(global_rot_slerped.reshape(N*T, -1)))
    # global_rot_slerped_mat = quaternion_to_rotation_matrix(torch.from_numpy(global_rot_slerped.reshape(N*T, -1)))
    global_rot_slerped_rot6d = rotmat_to_rot6d(global_rot_slerped_mat).reshape(N, T, -1).numpy()
    hand_pose_slerped_mat = angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose_slerped.reshape(N*T*num_joints, -1)))
    # hand_pose_slerped_mat = quaternion_to_rotation_matrix(torch.from_numpy(hand_pose_slerped.reshape(N*T*num_joints, -1)))
    hand_pose_slerped_rot6d = rotmat_to_rot6d(hand_pose_slerped_mat).reshape(N, T, -1).numpy()

    # concat to (T, concat_dim)
    global_pose_vec_input = np.concatenate((global_trans_lerped, betas_lerped, global_rot_slerped_rot6d, hand_pose_slerped_rot6d), axis=-1).transpose(1, 0, 2).reshape(T, -1)

    R_canon2w_left = R_world2canonical_left.transpose(-1, -2)
    t_canon2w_left = -torch.einsum("tij,tj->ti", R_canon2w_left, t_world2canonical_left)
    R_canon2w_right = R_world2canonical_right.transpose(-1, -2)
    t_canon2w_right = -torch.einsum("tij,tj->ti", R_canon2w_right, t_world2canonical_right)

    transform_w_canon = {
        "R_w2canon_left": R_world2canonical_left,
        "t_w2canon_left": t_world2canonical_left,
        "R_canon2w_left": R_canon2w_left,
        "t_canon2w_left": t_canon2w_left,

        "R_w2canon_right": R_world2canonical_right,
        "t_w2canon_right": t_world2canonical_right,
        "R_canon2w_right": R_canon2w_right,
        "t_canon2w_right": t_canon2w_right,
    }

    return global_pose_vec_input, transform_w_canon

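Aside (illustrative sketch, not part of the file): a rough sketch of how filling_preprocess is driven. The item keys and shapes follow the comments above, the values here are placeholders, and in practice the result feeds the motion infiller:

import numpy as np

seq_len = 64
item = {
    "trans": np.zeros((2, seq_len, 3), dtype=np.float32),
    "rot": np.zeros((2, seq_len, 3), dtype=np.float32),
    "hand_pose": np.zeros((2, seq_len, 45), dtype=np.float32),
    "betas": np.zeros((2, seq_len, 10), dtype=np.float32),
    "valid": np.ones((2, seq_len), dtype=bool),
}
pose_vec, transform_w_canon = filling_preprocess(item)
print(pose_vec.shape)   # (seq_len, 2 * (3 + 10 + 6 + 90)) = (64, 218)
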
def custom_rot6d_to_rotmat(rot6d):
    original_shape = rot6d.shape[:-1]
    rot6d = rot6d.reshape(-1, 6)
    mat = rot6d_to_rotmat(rot6d)
    mat = mat.reshape(*original_shape, 3, 3)
    return mat

def filling_postprocess(output, transform_w_canon):
    # output = output.numpy()
    output = output.permute(1, 0, 2)  # (2, T, -1)
    N, T, _ = output.shape
    canon_trans = output[:, :, :3]
    betas = output[:, :, 3:13]
    canon_rot_rot6d = output[:, :, 13:19]
    hand_pose_rot6d = output[:, :, 19:109].reshape(N, T, 15, 6)

    canon_rot_mat = custom_rot6d_to_rotmat(canon_rot_rot6d)
    hand_pose_mat = custom_rot6d_to_rotmat(hand_pose_rot6d)

    data_canonical_left = {
        "init_trans": canon_trans[[0], :, :],
        "init_root_orient": canon_rot_mat[[0], :, :, :],
        "init_hand_pose": hand_pose_mat[[0], :, :, :, :],
        "init_betas": betas[[0], :, :]
    }

    data_canonical_right = {
        "init_trans": canon_trans[[1], :, :],
        "init_root_orient": canon_rot_mat[[1], :, :, :],
        "init_hand_pose": hand_pose_mat[[1], :, :, :, :],
        "init_betas": betas[[1], :, :]
    }

    R_canon2w_left = transform_w_canon['R_canon2w_left']
    t_canon2w_left = transform_w_canon['t_canon2w_left']
    R_canon2w_right = transform_w_canon['R_canon2w_right']
    t_canon2w_right = transform_w_canon['t_canon2w_right']

    world_left = world2canonical_convert(R_canon2w_left, t_canon2w_left, data_canonical_left, "left")
    world_right = world2canonical_convert(R_canon2w_right, t_canon2w_right, data_canonical_right, "right")

    global_rot = torch.cat((world_left['init_root_orient'], world_right['init_root_orient'])).numpy()
    global_trans = torch.cat((world_left['init_trans'], world_right['init_trans'])).numpy()

    pred_data = {
        "trans": global_trans,  # (2, T, 3)
        "rot": global_rot,  # (2, T, 3)
        "hand_pose": rotation_matrix_to_angle_axis(hand_pose_mat).flatten(-2).numpy(),  # (2, T, 45)
        "betas": betas.numpy(),  # (2, T, 10)
    }

    return pred_data
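Aside (illustrative sketch, not part of the file): filling_postprocess is the inverse direction. Given an infiller output of shape (T, 2, 109) and the transforms returned by filling_preprocess, it restores world-space trajectories:

import torch

T = 64
pred = torch.randn(T, 2, 109)                     # stand-in for the infiller's prediction
pred_world = filling_postprocess(pred, transform_w_canon)
print(pred_world["trans"].shape)                  # (2, 64, 3)
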
lib/eval_utils/video_utils.py
ADDED
@@ -0,0 +1,85 @@
import cv2
import os
import subprocess

def make_video_grid_2x2(out_path, vid_paths, overwrite=False):
    """
    Tile four videos at their original resolution into a 2x2 grid.

    :param out_path: Path of the output video.
    :param vid_paths: List of input video paths (must contain exactly 4).
    :param overwrite: If True, overwrite an existing output file.
    """
    if os.path.isfile(out_path) and not overwrite:
        print(f"{out_path} already exists, skipping.")
        return

    if any(not os.path.isfile(v) for v in vid_paths):
        print("Not all inputs exist!", vid_paths)
        return

    # make sure exactly 4 video paths were given
    if len(vid_paths) != 4:
        print("Error: Exactly 4 video paths are required!")
        return

    # unpack the video paths
    v1, v2, v3, v4 = vid_paths

    # ffmpeg stacking command; the videos are tiled as-is, without resizing
    cmd = (
        f"ffmpeg -i {v1} -i {v2} -i {v3} -i {v4} "
        f"-filter_complex '[0:v][1:v][2:v][3:v]xstack=inputs=4:layout=0_0|w0_0|0_h0|w0_h0[v]' "
        f"-map '[v]' {out_path} -y"
    )

    print(cmd)
    subprocess.call(cmd, shell=True, stdin=subprocess.PIPE)

def create_video_from_images(image_list, output_path, fps=15, target_resolution=(540, 540)):
    """
    Assemble a list of images into an MP4 video.

    :param image_list: List of image paths.
    :param output_path: Path of the output video file (e.g. output.mp4).
    :param fps: Frame rate of the video (default 15 FPS).
    """
    # if not image_list:
    #     print("Image list is empty!")
    #     return

    # read the first image to get the width and height
    first_image = cv2.imread(image_list[0])
    if first_image is None:
        print(f"Cannot read image: {image_list[0]}")
        return

    height, width, _ = first_image.shape
    if height != width:
        if height < width:
            vis_w = target_resolution[0]
            vis_h = int(target_resolution[0] / width * height)
        elif height > width:
            vis_h = target_resolution[0]
            vis_w = int(target_resolution[0] / height * width)
    else:
        vis_h = target_resolution[0]
        vis_w = target_resolution[0]
    target_resolution = (vis_w, vis_h)

    # set up the video codec and writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # use the mp4v codec
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, target_resolution)

    # iterate over the images and write them to the video
    for image_path in image_list:
        frame = cv2.imread(image_path)
        if frame is None:
            print(f"Cannot read image: {image_path}")
            continue
        frame_resized = cv2.resize(frame, target_resolution)
        video_writer.write(frame_resized)

    # release the video writer
    video_writer.release()
    print(f"Video saved to: {output_path}")
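Aside (illustrative sketch, not part of the file): typical usage of the two helpers above; all paths are placeholders:

frames = [f"renders/{i:04d}.jpg" for i in range(120)]
create_video_from_images(frames, "render.mp4", fps=15)

make_video_grid_2x2(
    "grid.mp4",
    ["input.mp4", "render.mp4", "world_view.mp4", "mesh_view.mp4"],
    overwrite=True,
)
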
lib/models/__pycache__/hawor.cpython-310.pyc
ADDED
Binary file (15.3 kB).
lib/models/__pycache__/mano_wrapper.cpython-310.pyc
ADDED
Binary file (2.43 kB).
lib/models/__pycache__/modules.cpython-310.pyc
ADDED
Binary file (4.71 kB).
lib/models/backbones/__init__.py
ADDED
@@ -0,0 +1,8 @@
from .vit import vit


def create_backbone(cfg):
    if cfg.MODEL.BACKBONE.TYPE == 'vit':
        return vit(cfg)
    else:
        raise NotImplementedError('Backbone type is not implemented')
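Aside (illustrative sketch, not part of the file): create_backbone only needs a config object exposing cfg.MODEL.BACKBONE.TYPE; a yacs-style CfgNode is assumed here for illustration, and the ViT constructor reads further MODEL.BACKBONE fields in practice:

from yacs.config import CfgNode as CN

cfg = CN()
cfg.MODEL = CN()
cfg.MODEL.BACKBONE = CN()
cfg.MODEL.BACKBONE.TYPE = "vit"

backbone = create_backbone(cfg)
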
lib/models/backbones/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (445 Bytes).
lib/models/backbones/__pycache__/vit.cpython-310.pyc
ADDED
Binary file (11.3 kB).