diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..283fbfa5f5b59c93d9e4e77879c65a409bbe2afc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 98a49b56b246188ff059169ea200a675b3cb450a..213bdb447883f54623dee74c08504c2b37a2e1a3 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,94 @@ ---- -title: HaWoR -emoji: 👁 -colorFrom: green -colorTo: indigo -sdk: gradio -sdk_version: 5.9.1 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +
+ +# HaWoR: World-Space Hand Motion Reconstruction from Egocentric Videos + +[Jinglei Zhang]()1   [Jiankang Deng](https://jiankangdeng.github.io/)2   [Chao Ma](https://scholar.google.com/citations?user=syoPhv8AAAAJ&hl=en)1   [Rolandos Alexandros Potamias](https://rolpotamias.github.io)2   + +1Shanghai Jiao Tong University, China +2Imperial College London, UK
+ + + +
+
+This is the official implementation of **[HaWoR](https://hawor-project.github.io/)**, a hand reconstruction model in world coordinates:
+
+![teaser](assets/teaser.png)
+
+## Installation
+
+### Clone the repository
+```
+git clone --recursive https://github.com/ThunderVVV/HaWoR.git
+cd HaWoR
+```
+
+The code has been tested with PyTorch 1.13 and CUDA 11.7. It is suggested to use an Anaconda environment to install the required dependencies:
+```bash
+conda create --name hawor python=3.10
+conda activate hawor
+
+pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
+# Install requirements
+pip install -r requirements.txt
+pip install pytorch-lightning==2.2.4 --no-deps
+pip install lightning-utilities torchmetrics==1.4.0
+```
+
+### Install masked DROID-SLAM
+
+```
+cd thirdparty/DROID-SLAM
+python setup.py install
+```
+
+Download the official DROID-SLAM weights [droid.pth](https://drive.google.com/file/d/1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh/view?usp=sharing) and put them under `./weights/external/`.
+
+### Install Metric3D
+
+Download the official Metric3D weights [metric_depth_vit_large_800k.pth](https://drive.google.com/file/d/1eT2gG-kwsVzNy5nJrbm4KC-9DbNKyLnr/view?usp=drive_link) and put them under `thirdparty/Metric3D/weights`.
+
+### Download the model weights
+
+```bash
+wget https://huggingface.co/spaces/rolpotamias/WiLoR/resolve/main/pretrained_models/detector.pt -P ./weights/external/
+wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/checkpoints/hawor.ckpt -P ./weights/hawor/checkpoints/
+wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/checkpoints/infiller.pt -P ./weights/hawor/checkpoints/
+wget https://huggingface.co/ThunderVVV/HaWoR/resolve/main/hawor/model_config.yaml -P ./weights/hawor/
+```
+It is also required to download the MANO model from the [MANO website](https://mano.is.tue.mpg.de).
+Create an account by clicking Sign Up and download the models (mano_v*_*.zip). Unzip them and place the hand models at `_DATA/data/mano/MANO_RIGHT.pkl` and `_DATA/data_left/mano_left/MANO_LEFT.pkl`.
+
+Note that the MANO model falls under the [MANO license](https://mano.is.tue.mpg.de/license.html).
+
+## Demo
+
+For visualization in the world view, run:
+```bash
+python demo.py --video_path ./example/video_0.mp4 --vis_mode world
+```
+
+For visualization in the camera view, run:
+```bash
+python demo.py --video_path ./example/video_0.mp4 --vis_mode cam
+```
+
+## Training
+The training code will be released soon.
+
+## Acknowledgements
+Parts of the code are taken or adapted from the following repos:
+- [HaMeR](https://github.com/geopavlakos/hamer/)
+- [WiLoR](https://github.com/rolpotamias/WiLoR)
+- [SLAHMR](https://github.com/vye16/slahmr)
+- [TRAM](https://github.com/yufu-wang/tram)
+- [CMIB](https://github.com/jihoonerd/Conditional-Motion-In-Betweening)
+
+## License
+HaWoR models fall under the [CC-BY-NC-ND License](./license.txt). This repository also depends on the [MANO model](https://mano.is.tue.mpg.de/license.html), which falls under its own license. By using this repository, you must also comply with the terms of these external licenses.
+
+## Citing +If you find HaWoR useful for your research, please consider citing our paper: + +```bibtex + +``` diff --git a/_DATA/data/mano/.gitkeep b/_DATA/data/mano/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/_DATA/data/mano/MANO_RIGHT.pkl b/_DATA/data/mano/MANO_RIGHT.pkl new file mode 100644 index 0000000000000000000000000000000000000000..8e7ac7faf64ad51096ec1da626ea13757ed7f665 --- /dev/null +++ b/_DATA/data/mano/MANO_RIGHT.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45d60aa3b27ef9107a7afd4e00808f307fd91111e1cfa35afd5c4a62de264767 +size 3821356 diff --git a/_DATA/data/mano_mean_params.npz b/_DATA/data/mano_mean_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..dc294b01fb78a9cd6636c87a69b59cf82d28d15b --- /dev/null +++ b/_DATA/data/mano_mean_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:efc0ec58e4a5cef78f3abfb4e8f91623b8950be9eff8b8e0dbb0d036ebc63988 +size 1178 diff --git a/_DATA/data_left/mano_left/.gitkeep b/_DATA/data_left/mano_left/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/_DATA/data_left/mano_left/MANO_LEFT.pkl b/_DATA/data_left/mano_left/MANO_LEFT.pkl new file mode 100755 index 0000000000000000000000000000000000000000..32cdc533e2c01ed4995db2dc1302520d7d374c5a --- /dev/null +++ b/_DATA/data_left/mano_left/MANO_LEFT.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4022f7083f2ca7c78b2b3d595abbab52debd32b09d372b16923a801f0ea6a30 +size 3821391 diff --git a/assets/teaser.png b/assets/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..1f3e4905bb6fc0d2d4f09966e379f78ba92c39b4 --- /dev/null +++ b/assets/teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b33d76e9a10f215f0777612dd32ac73a5ce3b0e8735813968e7048ecd1ed3a1 +size 1118621 diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..2a437853015f84e74a059289f217b5bc12a00ed5 --- /dev/null +++ b/demo.py @@ -0,0 +1,113 @@ +import argparse +import sys +import os + +import torch +sys.path.insert(0, os.path.dirname(__file__)) +import numpy as np +import joblib +from scripts.scripts_test_video.detect_track_video import detect_track_video +from scripts.scripts_test_video.hawor_video import hawor_motion_estimation, hawor_infiller +from scripts.scripts_test_video.hawor_slam import hawor_slam +from hawor.utils.process import get_mano_faces, run_mano, run_mano_left +from lib.eval_utils.custom_utils import load_slam_cam +from lib.vis.run_vis2 import run_vis2_on_video, run_vis2_on_video_cam + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--img_focal", type=float) + parser.add_argument("--video_path", type=str, default='example/video_0.mp4') + parser.add_argument("--input_type", type=str, default='file') + parser.add_argument("--checkpoint", type=str, default='./weights/hawor/checkpoints/hawor.ckpt') + parser.add_argument("--infiller_weight", type=str, default='./weights/hawor/checkpoints/infiller.pt') + parser.add_argument("--vis_mode", type=str, default='world', help='cam | world') + args = parser.parse_args() + + start_idx, end_idx, seq_folder, imgfiles = detect_track_video(args) + + frame_chunks_all, img_focal = hawor_motion_estimation(args, start_idx, end_idx, seq_folder) + + hawor_slam(args, start_idx, 
end_idx) + slam_path = os.path.join(seq_folder, f"SLAM/hawor_slam_w_scale_{start_idx}_{end_idx}.npz") + R_w2c_sla_all, t_w2c_sla_all, R_c2w_sla_all, t_c2w_sla_all = load_slam_cam(slam_path) + + pred_trans, pred_rot, pred_hand_pose, pred_betas, pred_valid = hawor_infiller(args, start_idx, end_idx, frame_chunks_all) + + # vis sequence for this video + hand2idx = { + "right": 1, + "left": 0 + } + vis_start = 0 + vis_end = pred_trans.shape[1] - 1 + + # get faces + faces = get_mano_faces() + faces_new = np.array([[92, 38, 234], + [234, 38, 239], + [38, 122, 239], + [239, 122, 279], + [122, 118, 279], + [279, 118, 215], + [118, 117, 215], + [215, 117, 214], + [117, 119, 214], + [214, 119, 121], + [119, 120, 121], + [121, 120, 78], + [120, 108, 78], + [78, 108, 79]]) + faces_right = np.concatenate([faces, faces_new], axis=0) + + # get right hand vertices + hand = 'right' + hand_idx = hand2idx[hand] + pred_glob_r = run_mano(pred_trans[hand_idx:hand_idx+1, vis_start:vis_end], pred_rot[hand_idx:hand_idx+1, vis_start:vis_end], pred_hand_pose[hand_idx:hand_idx+1, vis_start:vis_end], betas=pred_betas[hand_idx:hand_idx+1, vis_start:vis_end]) + right_verts = pred_glob_r['vertices'][0] + right_dict = { + 'vertices': right_verts.unsqueeze(0), + 'faces': faces_right, + } + + # get left hand vertices + faces_left = faces_right[:,[0,2,1]] + hand = 'left' + hand_idx = hand2idx[hand] + pred_glob_l = run_mano_left(pred_trans[hand_idx:hand_idx+1, vis_start:vis_end], pred_rot[hand_idx:hand_idx+1, vis_start:vis_end], pred_hand_pose[hand_idx:hand_idx+1, vis_start:vis_end], betas=pred_betas[hand_idx:hand_idx+1, vis_start:vis_end]) + left_verts = pred_glob_l['vertices'][0] + left_dict = { + 'vertices': left_verts.unsqueeze(0), + 'faces': faces_left, + } + + R_x = torch.tensor([[1, 0, 0], + [0, -1, 0], + [0, 0, -1]]).float() + R_c2w_sla_all = torch.einsum('ij,njk->nik', R_x, R_c2w_sla_all) + t_c2w_sla_all = torch.einsum('ij,nj->ni', R_x, t_c2w_sla_all) + R_w2c_sla_all = R_c2w_sla_all.transpose(-1, -2) + t_w2c_sla_all = -torch.einsum("bij,bj->bi", R_w2c_sla_all, t_c2w_sla_all) + left_dict['vertices'] = torch.einsum('ij,btnj->btni', R_x, left_dict['vertices'].cpu()) + right_dict['vertices'] = torch.einsum('ij,btnj->btni', R_x, right_dict['vertices'].cpu()) + + # Here we use aitviewer(https://github.com/eth-ait/aitviewer) for simple visualization. 
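+    # Note: R_x above is a 180-degree rotation about the x-axis. Applying it to the
+    # camera-to-world extrinsics and to both hands' vertices flips the y- and z-axes,
+    # presumably converting the SLAM/OpenCV convention (y-down, z-forward) into the
+    # y-up convention expected by the viewer; the world-to-camera extrinsics are then
+    # recomputed as the inverse of the flipped camera-to-world transform.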
+ if args.vis_mode == 'world': + output_pth = os.path.join(seq_folder, f"vis_{vis_start}_{vis_end}") + if not os.path.exists(output_pth): + os.makedirs(output_pth) + image_names = imgfiles[vis_start:vis_end] + print(f"vis {vis_start} to {vis_end}") + run_vis2_on_video(left_dict, right_dict, output_pth, img_focal, image_names, R_c2w=R_c2w_sla_all[vis_start:vis_end], t_c2w=t_c2w_sla_all[vis_start:vis_end]) + elif args.vis_mode == 'cam': + output_pth = os.path.join(seq_folder, f"vis_{vis_start}_{vis_end}") + if not os.path.exists(output_pth): + os.makedirs(output_pth) + image_names = imgfiles[vis_start:vis_end] + print(f"vis {vis_start} to {vis_end}") + run_vis2_on_video_cam(left_dict, right_dict, output_pth, img_focal, image_names, R_w2c=R_w2c_sla_all[vis_start:vis_end], t_w2c=t_w2c_sla_all[vis_start:vis_end]) + + print("finish") + + + diff --git a/example/video_0.mp4 b/example/video_0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..69082314e091591656cfff9993853abc208a9568 --- /dev/null +++ b/example/video_0.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13ff124a68e4b48190e0c3f0ce9f38db59c5e3bb8a093b3c7fc9c67276be2062 +size 6515891 diff --git a/hawor/configs/__init__.py b/hawor/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..569bf1675619d2765c8135482bfdd9bebb032482 --- /dev/null +++ b/hawor/configs/__init__.py @@ -0,0 +1,120 @@ +import os +from typing import Dict +from yacs.config import CfgNode as CN + +CACHE_DIR_HAWOR = "./_DATA" + +def to_lower(x: Dict) -> Dict: + """ + Convert all dictionary keys to lowercase + Args: + x (dict): Input dictionary + Returns: + dict: Output dictionary with all keys converted to lowercase + """ + return {k.lower(): v for k, v in x.items()} + +_C = CN(new_allowed=True) + +_C.GENERAL = CN(new_allowed=True) +_C.GENERAL.RESUME = True +_C.GENERAL.TIME_TO_RUN = 3300 +_C.GENERAL.VAL_STEPS = 100 +_C.GENERAL.LOG_STEPS = 100 +_C.GENERAL.CHECKPOINT_STEPS = 20000 +_C.GENERAL.CHECKPOINT_DIR = "checkpoints" +_C.GENERAL.SUMMARY_DIR = "tensorboard" +_C.GENERAL.NUM_GPUS = 1 +_C.GENERAL.NUM_WORKERS = 4 +_C.GENERAL.MIXED_PRECISION = True +_C.GENERAL.ALLOW_CUDA = True +_C.GENERAL.PIN_MEMORY = False +_C.GENERAL.DISTRIBUTED = False +_C.GENERAL.LOCAL_RANK = 0 +_C.GENERAL.USE_SYNCBN = False +_C.GENERAL.WORLD_SIZE = 1 + +_C.TRAIN = CN(new_allowed=True) +_C.TRAIN.NUM_EPOCHS = 100 +_C.TRAIN.BATCH_SIZE = 32 +_C.TRAIN.SHUFFLE = True +_C.TRAIN.WARMUP = False +_C.TRAIN.NORMALIZE_PER_IMAGE = False +_C.TRAIN.CLIP_GRAD = False +_C.TRAIN.CLIP_GRAD_VALUE = 1.0 +_C.LOSS_WEIGHTS = CN(new_allowed=True) + +_C.DATASETS = CN(new_allowed=True) + +_C.MODEL = CN(new_allowed=True) +_C.MODEL.IMAGE_SIZE = 224 + +_C.EXTRA = CN(new_allowed=True) +_C.EXTRA.FOCAL_LENGTH = 5000 + +_C.DATASETS.CONFIG = CN(new_allowed=True) +_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3 +_C.DATASETS.CONFIG.ROT_FACTOR = 30 +_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02 +_C.DATASETS.CONFIG.COLOR_SCALE = 0.2 +_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6 +_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5 +_C.DATASETS.CONFIG.DO_FLIP = False +_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5 +_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10 + +def default_config() -> CN: + """ + Get a yacs CfgNode object with the default config values. 
+ """ + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() + +def dataset_config() -> CN: + """ + Get dataset config file + Returns: + CfgNode: Dataset config as a yacs CfgNode object. + """ + cfg = CN(new_allowed=True) + config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_tar.yaml') + cfg.merge_from_file(config_file) + cfg.freeze() + return cfg + +def dataset_eval_config() -> CN: + cfg = CN(new_allowed=True) + config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_eval.yaml') + cfg.merge_from_file(config_file) + cfg.freeze() + return cfg + +def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN: + """ + Read a config file and optionally merge it with the default config file. + Args: + config_file (str): Path to config file. + merge (bool): Whether to merge with the default config or not. + Returns: + CfgNode: Config as a yacs CfgNode object. + """ + if merge: + cfg = default_config() + else: + cfg = CN(new_allowed=True) + cfg.merge_from_file(config_file) + + if update_cachedir: + def update_path(path: str) -> str: + if os.path.basename(CACHE_DIR_HAWOR) in path: + return path + if os.path.isabs(path): + return path + return os.path.join(CACHE_DIR_HAWOR, path) + + cfg.MANO.MODEL_PATH = update_path(cfg.MANO.MODEL_PATH) + cfg.MANO.MEAN_PARAMS = update_path(cfg.MANO.MEAN_PARAMS) + + cfg.freeze() + return cfg diff --git a/hawor/configs/__pycache__/__init__.cpython-310.pyc b/hawor/configs/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61f3434ffd66806e5e80fa5eab81122a1bbb807f Binary files /dev/null and b/hawor/configs/__pycache__/__init__.cpython-310.pyc differ diff --git a/hawor/utils/__pycache__/geometry.cpython-310.pyc b/hawor/utils/__pycache__/geometry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f059c3f3ff802d128d320ced5558abee3c1afdf Binary files /dev/null and b/hawor/utils/__pycache__/geometry.cpython-310.pyc differ diff --git a/hawor/utils/__pycache__/process.cpython-310.pyc b/hawor/utils/__pycache__/process.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9da2e4829b6ba7e66fc498b011525820bb4f7e25 Binary files /dev/null and b/hawor/utils/__pycache__/process.cpython-310.pyc differ diff --git a/hawor/utils/__pycache__/pylogger.cpython-310.pyc b/hawor/utils/__pycache__/pylogger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7ac89c1cd866f31319e154116338138699ab077 Binary files /dev/null and b/hawor/utils/__pycache__/pylogger.cpython-310.pyc differ diff --git a/hawor/utils/__pycache__/render_openpose.cpython-310.pyc b/hawor/utils/__pycache__/render_openpose.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f151b9daaf3d603d87e22ee034475d5dbe7acf2 Binary files /dev/null and b/hawor/utils/__pycache__/render_openpose.cpython-310.pyc differ diff --git a/hawor/utils/__pycache__/rotation.cpython-310.pyc b/hawor/utils/__pycache__/rotation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07c890f30559e00cf99c04ffaa821aedf48a0570 Binary files /dev/null and b/hawor/utils/__pycache__/rotation.cpython-310.pyc differ diff --git a/hawor/utils/geometry.py b/hawor/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..1effe3d9aa66386aa3dd114f07646c6a96a5035e --- /dev/null 
+++ b/hawor/utils/geometry.py @@ -0,0 +1,102 @@ +from typing import Optional +import torch +from torch.nn import functional as F + +def aa_to_rotmat(theta: torch.Tensor): + """ + Convert axis-angle representation to rotation matrix. + Works by first converting it to a quaternion. + Args: + theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations. + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat_to_rotmat(quat) + +def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor: + """ + Convert quaternion representation to rotation matrix. + Args: + quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z). + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + + +def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor: + """ + Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Args: + x (torch.Tensor): (B,6) Batch of 6-D rotation representations. + Returns: + torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3). + """ + x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous() + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.linalg.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + +def perspective_projection(points: torch.Tensor, + translation: torch.Tensor, + focal_length: torch.Tensor, + camera_center: Optional[torch.Tensor] = None, + rotation: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Computes the perspective projection of a set of 3D points. + Args: + points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points. + translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation. + focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels. + camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels. + rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation. + Returns: + torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points. + """ + batch_size = points.shape[0] + if rotation is None: + rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1) + if camera_center is None: + camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype) + # Populate intrinsic camera matrix K. + K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype) + K[:,0,0] = focal_length[:,0] + K[:,1,1] = focal_length[:,1] + K[:,2,2] = 1. 
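+    # The first two rows of the last column hold the principal point (cx, cy), filled
+    # in below, completing the pinhole intrinsics [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].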
+ K[:,:-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:,:,-1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + return projected_points[:, :, :-1] \ No newline at end of file diff --git a/hawor/utils/process.py b/hawor/utils/process.py new file mode 100644 index 0000000000000000000000000000000000000000..89bfc33c069b78bd20e72844b3b42e06becf0b84 --- /dev/null +++ b/hawor/utils/process.py @@ -0,0 +1,198 @@ +import torch +from lib.models.mano_wrapper import MANO +from hawor.utils.geometry import aa_to_rotmat +import numpy as np +import sys +import os + +def block_print(): + sys.stdout = open(os.devnull, 'w') + +def enable_print(): + sys.stdout = sys.__stdout__ + +def get_mano_faces(): + block_print() + MANO_cfg = { + 'DATA_DIR': '_DATA/data/', + 'MODEL_PATH': '_DATA/data/mano', + 'GENDER': 'neutral', + 'NUM_HAND_JOINTS': 15, + 'CREATE_BODY_POSE': False + } + mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()} + mano = MANO(**mano_cfg) + enable_print() + return mano.faces + + +def run_mano(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True): + """ + Forward pass of the SMPL model and populates pred_data accordingly with + joints3d, verts3d, points3d. + + trans : B x T x 3 + root_orient : B x T x 3 + body_pose : B x T x J*3 + betas : (optional) B x D + """ + block_print() + MANO_cfg = { + 'DATA_DIR': '_DATA/data/', + 'MODEL_PATH': '_DATA/data/mano', + 'GENDER': 'neutral', + 'NUM_HAND_JOINTS': 15, + 'CREATE_BODY_POSE': False + } + mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()} + mano = MANO(**mano_cfg) + if use_cuda: + mano = mano.cuda() + + B, T, _ = root_orient.shape + NUM_JOINTS = 15 + mano_params = { + 'global_orient': root_orient.reshape(B*T, -1), + 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3), + 'betas': betas.reshape(B*T, -1), + } + rotmat_mano_params = mano_params + rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3) + rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3) + rotmat_mano_params['transl'] = trans.reshape(B*T, 3) + + if use_cuda: + mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False) + else: + mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False) + + faces_right = mano.faces + faces_new = np.array([[92, 38, 234], + [234, 38, 239], + [38, 122, 239], + [239, 122, 279], + [122, 118, 279], + [279, 118, 215], + [118, 117, 215], + [215, 117, 214], + [117, 119, 214], + [214, 119, 121], + [119, 120, 121], + [121, 120, 78], + [120, 108, 78], + [78, 108, 79]]) + faces_right = np.concatenate([faces_right, faces_new], axis=0) + faces_n = len(faces_right) + faces_left = faces_right[:,[0,2,1]] + + outputs = { + "joints": mano_output.joints.reshape(B, T, -1, 3), + "vertices": mano_output.vertices.reshape(B, T, -1, 3), + } + + if not is_right is None: + # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0] + # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0] + is_right = (is_right[:, :, 0].cpu().numpy() > 0) + faces_result = np.zeros((B, T, faces_n, 3)) + faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0) + faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, 
axis=0), axis=0) + faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded) + outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32)) + + + enable_print() + return outputs + +def run_mano_left(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True, fix_shapedirs=True): + """ + Forward pass of the SMPL model and populates pred_data accordingly with + joints3d, verts3d, points3d. + + trans : B x T x 3 + root_orient : B x T x 3 + body_pose : B x T x J*3 + betas : (optional) B x D + """ + block_print() + MANO_cfg = { + 'DATA_DIR': '_DATA/data_left/', + 'MODEL_PATH': '_DATA/data_left/mano_left', + 'GENDER': 'neutral', + 'NUM_HAND_JOINTS': 15, + 'CREATE_BODY_POSE': False, + 'is_rhand': False + } + mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()} + mano = MANO(**mano_cfg) + if use_cuda: + mano = mano.cuda() + + # fix MANO shapedirs of the left hand bug (https://github.com/vchoutas/smplx/issues/48) + if fix_shapedirs: + mano.shapedirs[:, 0, :] *= -1 + + B, T, _ = root_orient.shape + NUM_JOINTS = 15 + mano_params = { + 'global_orient': root_orient.reshape(B*T, -1), + 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3), + 'betas': betas.reshape(B*T, -1), + } + rotmat_mano_params = mano_params + rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3) + rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3) + rotmat_mano_params['transl'] = trans.reshape(B*T, 3) + + if use_cuda: + mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False) + else: + mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False) + + faces_right = mano.faces + faces_new = np.array([[92, 38, 234], + [234, 38, 239], + [38, 122, 239], + [239, 122, 279], + [122, 118, 279], + [279, 118, 215], + [118, 117, 215], + [215, 117, 214], + [117, 119, 214], + [214, 119, 121], + [119, 120, 121], + [121, 120, 78], + [120, 108, 78], + [78, 108, 79]]) + faces_right = np.concatenate([faces_right, faces_new], axis=0) + faces_n = len(faces_right) + faces_left = faces_right[:,[0,2,1]] + + outputs = { + "joints": mano_output.joints.reshape(B, T, -1, 3), + "vertices": mano_output.vertices.reshape(B, T, -1, 3), + } + + if not is_right is None: + # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0] + # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0] + is_right = (is_right[:, :, 0].cpu().numpy() > 0) + faces_result = np.zeros((B, T, faces_n, 3)) + faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0) + faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0) + faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded) + outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32)) + + + enable_print() + return outputs + +def run_mano_twohands(init_trans, init_rot, init_hand_pose, is_right, init_betas, use_cuda=True, fix_shapedirs=True): + outputs_left = run_mano_left(init_trans[0:1], init_rot[0:1], init_hand_pose[0:1], None, init_betas[0:1], use_cuda=use_cuda, fix_shapedirs=fix_shapedirs) + outputs_right = run_mano(init_trans[1:2], init_rot[1:2], init_hand_pose[1:2], None, init_betas[1:2], use_cuda=use_cuda) + outputs_two = { + "vertices": torch.cat((outputs_left["vertices"], outputs_right["vertices"]), dim=0), + "joints": torch.cat((outputs_left["joints"], 
outputs_right["joints"]), dim=0) + + } + return outputs_two \ No newline at end of file diff --git a/hawor/utils/pylogger.py b/hawor/utils/pylogger.py new file mode 100644 index 0000000000000000000000000000000000000000..92ffa71893ec20acde65e44d899334a38d8d1333 --- /dev/null +++ b/hawor/utils/pylogger.py @@ -0,0 +1,17 @@ +import logging + +from pytorch_lightning.utilities import rank_zero_only + + +def get_pylogger(name=__name__) -> logging.Logger: + """Initializes multi-GPU-friendly python command line logger.""" + + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger diff --git a/hawor/utils/render_openpose.py b/hawor/utils/render_openpose.py new file mode 100644 index 0000000000000000000000000000000000000000..2ffcb5c6b52cdec2058f0f3d3b2ec5b705d5b2a9 --- /dev/null +++ b/hawor/utils/render_openpose.py @@ -0,0 +1,225 @@ +""" +Render OpenPose keypoints. +Code was ported to Python from the official C++ implementation https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/utilities/keypoint.cpp +""" +import cv2 +import math +import numpy as np +from typing import List, Tuple + +def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]: + """ + Compute rectangle enclosing keypoints above the threshold. + Args: + keypoints (np.array): Keypoint array of shape (N, 3). + threshold (float): Confidence visualization threshold. + Returns: + Tuple[float, float, float]: Rectangle width, height and area. + """ + valid_ind = keypoints[:, -1] > threshold + if valid_ind.sum() > 0: + valid_keypoints = keypoints[valid_ind][:, :-1] + max_x = valid_keypoints[:,0].max() + max_y = valid_keypoints[:,1].max() + min_x = valid_keypoints[:,0].min() + min_y = valid_keypoints[:,1].min() + width = max_x - min_x + height = max_y - min_y + area = width * height + return width, height, area + else: + return 0,0,0 + +def render_keypoints(img: np.array, + keypoints: np.array, + pairs: List, + colors: List, + thickness_circle_ratio: float, + thickness_line_ratio_wrt_circle: float, + pose_scales: List, + threshold: float = 0.1, + alpha: float = 1.0) -> np.array: + """ + Render keypoints on input image. + Args: + img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range. + keypoints (np.array): Keypoint array of shape (N, 3). + pairs (List): List of keypoint pairs per limb. + colors: (List): List of colors per keypoint. + thickness_circle_ratio (float): Circle thickness ratio. + thickness_line_ratio_wrt_circle (float): Line thickness ratio wrt the circle. + pose_scales (List): List of pose scales. + threshold (float): Only visualize keypoints with confidence above the threshold. + Returns: + (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image. 
+ """ + img_orig = img.copy() + width, height = img.shape[1], img.shape[2] + area = width * height + + lineType = 8 + shift = 0 + numberColors = len(colors) + thresholdRectangle = 0.1 + + person_width, person_height, person_area = get_keypoints_rectangle(keypoints, thresholdRectangle) + if person_area > 0: + ratioAreas = min(1, max(person_width / width, person_height / height)) + thicknessRatio = np.maximum(np.round(math.sqrt(area) * thickness_circle_ratio * ratioAreas), 2) + thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio)) + thicknessLine = np.maximum(1, np.round(thicknessRatio * thickness_line_ratio_wrt_circle)) + radius = thicknessRatio / 2 + + img = np.ascontiguousarray(img.copy()) + for i, pair in enumerate(pairs): + index1, index2 = pair + if keypoints[index1, -1] > threshold and keypoints[index2, -1] > threshold: + thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * pose_scales[0])) + colorIndex = index2 + color = colors[colorIndex % numberColors] + keypoint1 = keypoints[index1, :-1].astype(int) + keypoint2 = keypoints[index2, :-1].astype(int) + cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), thicknessLineScaled, lineType, shift) + for part in range(len(keypoints)): + faceIndex = part + if keypoints[faceIndex, -1] > threshold: + radiusScaled = int(round(radius[faceIndex] * pose_scales[0])) + thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * pose_scales[0])) + colorIndex = part + color = colors[colorIndex % numberColors] + center = keypoints[faceIndex, :-1].astype(int) + cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, lineType, shift) + return img + +def render_hand_keypoints(img, right_hand_keypoints, threshold=0.1, use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0): + if use_confidence and map_fn is not None: + #thicknessCircleRatioLeft = 1./50 * map_fn(left_hand_keypoints[:, -1]) + thicknessCircleRatioRight = 1./50 * map_fn(right_hand_keypoints[:, -1]) + else: + #thicknessCircleRatioLeft = 1./50 * np.ones(left_hand_keypoints.shape[0]) + thicknessCircleRatioRight = 1./50 * np.ones(right_hand_keypoints.shape[0]) + thicknessLineRatioWRTCircle = 0.75 + pairs = [0,1, 1,2, 2,3, 3,4, 0,5, 5,6, 6,7, 7,8, 0,9, 9,10, 10,11, 11,12, 0,13, 13,14, 14,15, 15,16, 0,17, 17,18, 18,19, 19,20] + pairs = np.array(pairs).reshape(-1,2) + + colors = [100., 100., 100., + 100., 0., 0., + 150., 0., 0., + 200., 0., 0., + 255., 0., 0., + 100., 100., 0., + 150., 150., 0., + 200., 200., 0., + 255., 255., 0., + 0., 100., 50., + 0., 150., 75., + 0., 200., 100., + 0., 255., 125., + 0., 50., 100., + 0., 75., 150., + 0., 100., 200., + 0., 125., 255., + 100., 0., 100., + 150., 0., 150., + 200., 0., 200., + 255., 0., 255.] 
+ colors = np.array(colors).reshape(-1,3) + #colors = np.zeros_like(colors) + poseScales = [1] + #img = render_keypoints(img, left_hand_keypoints, pairs, colors, thicknessCircleRatioLeft, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha) + img = render_keypoints(img, right_hand_keypoints, pairs, colors, thicknessCircleRatioRight, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha) + #img = render_keypoints(img, right_hand_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1) + return img + +def render_hand_landmarks(img, right_hand_keypoints, threshold=0.1, use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0): + if use_confidence and map_fn is not None: + #thicknessCircleRatioLeft = 1./50 * map_fn(left_hand_keypoints[:, -1]) + thicknessCircleRatioRight = 1./50 * map_fn(right_hand_keypoints[:, -1]) + else: + #thicknessCircleRatioLeft = 1./50 * np.ones(left_hand_keypoints.shape[0]) + thicknessCircleRatioRight = 1./50 * np.ones(right_hand_keypoints.shape[0]) + thicknessLineRatioWRTCircle = 0.75 + pairs = [] + pairs = np.array(pairs).reshape(-1,2) + + colors = [255, 0, 0] + colors = np.array(colors).reshape(-1,3) + #colors = np.zeros_like(colors) + poseScales = [1] + #img = render_keypoints(img, left_hand_keypoints, pairs, colors, thicknessCircleRatioLeft, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha) + img = render_keypoints(img, right_hand_keypoints, pairs, colors, thicknessCircleRatioRight * 0.1, thicknessLineRatioWRTCircle * 0.1, poseScales, threshold, alpha=alpha) + #img = render_keypoints(img, right_hand_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1) + return img + +def render_body_keypoints(img: np.array, + body_keypoints: np.array) -> np.array: + """ + Render OpenPose body keypoints on input image. + Args: + img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range. + body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence). + Returns: + (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image. + """ + + thickness_circle_ratio = 1./75. * np.ones(body_keypoints.shape[0]) + thickness_line_ratio_wrt_circle = 0.75 + pairs = [] + pairs = [1,8,1,2,1,5,2,3,3,4,5,6,6,7,8,9,9,10,10,11,8,12,12,13,13,14,1,0,0,15,15,17,0,16,16,18,14,19,19,20,14,21,11,22,22,23,11,24] + pairs = np.array(pairs).reshape(-1,2) + colors = [255., 0., 85., + 255., 0., 0., + 255., 85., 0., + 255., 170., 0., + 255., 255., 0., + 170., 255., 0., + 85., 255., 0., + 0., 255., 0., + 255., 0., 0., + 0., 255., 85., + 0., 255., 170., + 0., 255., 255., + 0., 170., 255., + 0., 85., 255., + 0., 0., 255., + 255., 0., 170., + 170., 0., 255., + 255., 0., 255., + 85., 0., 255., + 0., 0., 255., + 0., 0., 255., + 0., 0., 255., + 0., 255., 255., + 0., 255., 255., + 0., 255., 255.] + colors = np.array(colors).reshape(-1,3) + pose_scales = [1] + return render_keypoints(img, body_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1) + +def render_openpose(img: np.array, + hand_keypoints: np.array) -> np.array: + """ + Render keypoints in the OpenPose format on input image. + Args: + img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range. + body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence). 
+ Returns: + (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image. + """ + #img = render_body_keypoints(img, body_keypoints) + img = render_hand_keypoints(img, hand_keypoints) + return img + +def render_openpose_landmarks(img: np.array, + hand_keypoints: np.array) -> np.array: + """ + Render keypoints in the OpenPose format on input image. + Args: + img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range. + body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence). + Returns: + (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image. + """ + #img = render_body_keypoints(img, body_keypoints) + img = render_hand_landmarks(img, hand_keypoints) + return img diff --git a/hawor/utils/rotation.py b/hawor/utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..a7bf9dd99d4ec03802055c03177be5a1a634eac0 --- /dev/null +++ b/hawor/utils/rotation.py @@ -0,0 +1,293 @@ +import torch +import numpy as np +from torch.nn import functional as F + + +def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): + """ + Taken from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py + Calculates the rotation matrices for a batch of rotation vectors + - param rot_vecs: torch.tensor (N, 3) array of N axis-angle vectors + - returns R: torch.tensor (N, 3, 3) rotation matrices + """ + batch_size = rot_vecs.shape[0] + device = rot_vecs.device + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view( + (batch_size, 3, 3) + ) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + +def quaternion_mul(q0, q1): + """ + EXPECTS WXYZ + :param q0 (*, 4) + :param q1 (*, 4) + """ + r0, r1 = q0[..., :1], q1[..., :1] + v0, v1 = q0[..., 1:], q1[..., 1:] + r = r0 * r1 - (v0 * v1).sum(dim=-1, keepdim=True) + v = r0 * v1 + r1 * v0 + torch.linalg.cross(v0, v1) + return torch.cat([r, v], dim=-1) + + +def quaternion_inverse(q, eps=1e-8): + """ + EXPECTS WXYZ + :param q (*, 4) + """ + conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1) + mag = torch.square(q).sum(dim=-1, keepdim=True) + eps + return conj / mag + + +def quaternion_slerp(t, q0, q1, eps=1e-8): + """ + :param t (*, 1) must be between 0 and 1 + :param q0 (*, 4) + :param q1 (*, 4) + """ + dims = q0.shape[:-1] + t = t.view(*dims, 1) + + q0 = F.normalize(q0, p=2, dim=-1) + q1 = F.normalize(q1, p=2, dim=-1) + dot = (q0 * q1).sum(dim=-1, keepdim=True) + + # make sure we give the shortest rotation path (< 180d) + neg = dot < 0 + q1 = torch.where(neg, -q1, q1) + dot = torch.where(neg, -dot, dot) + angle = torch.acos(dot) + + # if angle is too small, just do linear interpolation + collin = torch.abs(dot) > 1 - eps + fac = 1 / torch.sin(angle) + w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac) + w1 = torch.where(collin, t, torch.sin(t * angle) * fac) + slerp = q0 * w0 + q1 * w1 + return slerp + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from 
https://github.com/kornia/kornia + + Convert rotation matrix to Rodrigues vector + """ + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + return aa + + +def quaternion_to_angle_axis(quaternion): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + :param quaternion (*, 4) expects WXYZ + :returns angle_axis (*, 3) + """ + # unpack input and compute conversion + q1 = quaternion[..., 1] + q2 = quaternion[..., 2] + q3 = quaternion[..., 3] + sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta = torch.sqrt(sin_squared_theta) + cos_theta = quaternion[..., 0] + two_theta = 2.0 * torch.where( + cos_theta < 0.0, + torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta), + ) + + k_pos = two_theta / sin_theta + k_neg = 2.0 * torch.ones_like(sin_theta) + k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def angle_axis_to_rotation_matrix(angle_axis): + """ + :param angle_axis (*, 3) + return (*, 3, 3) + """ + quat = angle_axis_to_quaternion(angle_axis) + return quaternion_to_rotation_matrix(quat) + + +def quaternion_to_rotation_matrix(quaternion): + """ + Convert a quaternion to a rotation matrix. + Taken from https://github.com/kornia/kornia, based on + https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101 + https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247 + :param quaternion (N, 4) expects WXYZ order + returns rotation matrix (N, 3, 3) + """ + # normalize the input quaternion + quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-12) + *dims, _ = quaternion_norm.shape + + # unpack the normalized quaternion components + w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1) + + # compute the actual conversion + tx = 2.0 * x + ty = 2.0 * y + tz = 2.0 * z + twx = tx * w + twy = ty * w + twz = tz * w + txx = tx * x + txy = ty * x + txz = tz * x + tyy = ty * y + tyz = tz * y + tzz = tz * z + one = torch.tensor(1.0) + + matrix = torch.stack( + ( + one - (tyy + tzz), + txy - twz, + txz + twy, + txy + twz, + one - (txx + tzz), + tyz - twx, + txz - twy, + tyz + twx, + one - (txx + tyy), + ), + dim=-1, + ).view(*dims, 3, 3) + return matrix + + +def angle_axis_to_quaternion(angle_axis): + """ + This function is borrowed from https://github.com/kornia/kornia + Convert angle axis to quaternion in WXYZ order + :param angle_axis (*, 3) + :returns quaternion (*, 4) WXYZ order + """ + theta_sq = torch.sum(angle_axis**2, dim=-1, keepdim=True) # (*, 1) + # need to handle the zero rotation case + valid = theta_sq > 0 + theta = torch.sqrt(theta_sq) + half_theta = 0.5 * theta + ones = torch.ones_like(half_theta) + # fill zero with the limit of sin ax / x -> a + k = torch.where(valid, torch.sin(half_theta) / theta, 0.5 * ones) + w = torch.where(valid, torch.cos(half_theta), ones) + quat = torch.cat([w, k * angle_axis], dim=-1) + return quat + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + Convert rotation matrix to 4d quaternion vector + This 
algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + :param rotation_matrix (N, 3, 3) + """ + *dims, m, n = rotation_matrix.shape + rmat_t = torch.transpose(rotation_matrix.reshape(-1, m, n), -1, -2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack( + [ + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + t0, + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + ], + -1, + ) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack( + [ + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], + ], + -1, + ) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack( + [ + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], + t2, + ], + -1, + ) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack( + [ + t3, + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], + ], + -1, + ) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt( + t0_rep * mask_c0 + + t1_rep * mask_c1 + + t2_rep * mask_c2 # noqa + + t3_rep * mask_c3 + ) # noqa + q *= 0.5 + return q.reshape(*dims, 4) diff --git a/imgui.ini b/imgui.ini new file mode 100644 index 0000000000000000000000000000000000000000..3cf22d85b75378eea2a4e7e1df56f60199fd9708 --- /dev/null +++ b/imgui.ini @@ -0,0 +1,15 @@ +[Window][Debug##Default] +Pos=60,60 +Size=400,400 +Collapsed=0 + +[Window][Editor] +Pos=50,50 +Size=250,700 +Collapsed=0 + +[Window][Playback] +Pos=50,800 +Size=400,175 +Collapsed=1 + diff --git a/infiller/hand_utils/geometry.py b/infiller/hand_utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d65066ce383c744407f8d800a7342cd324142b --- /dev/null +++ b/infiller/hand_utils/geometry.py @@ -0,0 +1,412 @@ +import numpy as np +import torch +from torch.nn import functional as F + + +def perspective_projection(points, rotation, translation, + focal_length, camera_center, distortion=None): + """ + This function computes the perspective projection of a set of points. 
+ Input: + points (bs, N, 3): 3D points + rotation (bs, 3, 3): Camera rotation + translation (bs, 3): Camera translation + focal_length (bs,) or scalar: Focal length + camera_center (bs, 2): Camera center + """ + batch_size = points.shape[0] + + # Extrinsic + if rotation is not None: + points = torch.einsum('bij,bkj->bki', rotation, points) + + if translation is not None: + points = points + translation.unsqueeze(1) + + if distortion is not None: + kc = distortion + points = points[:,:,:2] / points[:,:,2:] + + r2 = points[:,:,0]**2 + points[:,:,1]**2 + dx = (2 * kc[:,[2]] * points[:,:,0] * points[:,:,1] + + kc[:,[3]] * (r2 + 2*points[:,:,0]**2)) + + dy = (2 * kc[:,[3]] * points[:,:,0] * points[:,:,1] + + kc[:,[2]] * (r2 + 2*points[:,:,1]**2)) + + x = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,0] + dx + y = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,1] + dy + + points = torch.stack([x, y, torch.ones_like(x)], dim=-1) + + # Intrinsic + K = torch.zeros([batch_size, 3, 3], device=points.device) + K[:,0,0] = focal_length + K[:,1,1] = focal_length + K[:,2,2] = 1. + K[:,:-1, -1] = camera_center + + # Apply camera intrinsicsrf + points = points / points[:,:,-1].unsqueeze(-1) + projected_points = torch.einsum('bij,bkj->bki', K, points) + projected_points = projected_points[:, :, :-1] + + return projected_points + + +def avg_rot(rot): + # input [B,...,3,3] --> output [...,3,3] + rot = rot.mean(dim=0) + U, _, V = torch.svd(rot) + rot = U @ V.transpose(-1, -2) + return rot + + +def rot9d_to_rotmat(x): + """Convert 9D rotation representation to 3x3 rotation matrix. + Based on Levinson et al., "An Analysis of SVD for Deep Rotation Estimation" + Input: + (B,9) or (B,J*9) Batch of 9D rotation (interpreted as 3x3 est rotmat) + Output: + (B,3,3) or (B*J,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1,3,3) + u, _, vh = torch.linalg.svd(x) + + sig = torch.eye(3).expand(len(x), 3, 3).clone() + sig = sig.to(x.device) + sig[:, -1, -1] = (u @ vh).det() + + R = u @ sig @ vh + + return R + + +""" +Deprecated in favor of: rotation_conversions.py + +Useful geometric operations, e.g. differentiable Rodrigues formula +Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR +""" +def batch_rodrigues(theta): + """Convert axis-angle representation to rotation matrix. + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat_to_rotmat(quat) + +def quat_to_rotmat(quat): + """Convert quaternion coefficients to rotation matrix. 
+ Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + +def rot6d_to_rotmat(x): + """Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1,3,2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + +def rot6d_to_rotmat_hmr2(x: torch.Tensor) -> torch.Tensor: + """ + Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Args: + x (torch.Tensor): (B,6) Batch of 6-D rotation representations. + Returns: + torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3). + """ + x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous() + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + +def rotmat_to_rot6d(rotmat): + """ Inverse function of the above. + Input: + (B,3,3) Batch of corresponding rotation matrices + Output: + (B,6) Batch of 6-D rotation representations + """ + # rot6d = rotmat[:, :, :2] + rot6d = rotmat[...,:2] + rot6d = rot6d.reshape(rot6d.size(0), -1) + return rot6d + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to Rodrigues vector + + Args: + rotation_matrix (Tensor): rotation matrix. + + Returns: + Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 3)` + + Example: + >>> input = torch.rand(2, 3, 4) # Nx4x4 + >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 + """ + if rotation_matrix.shape[1:] == (3,3): + rot_mat = rotation_matrix.reshape(-1, 3, 3) + hom = torch.tensor([0, 0, 1], dtype=torch.float32, + device=rotation_matrix.device).reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1) + rotation_matrix = torch.cat([rot_mat, hom], dim=-1) + + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + return aa + + +def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. 
+ + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}" + .format(quaternion.shape)) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, + torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta)) + + k_pos: torch.Tensor = two_theta / sin_theta + k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) + k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. + + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format( + rotation_matrix.shape)) + if not rotation_matrix.shape[-2:] == (3, 4): + raise ValueError( + "Input size must be a N x 3 x 4 tensor. 
Got {}".format( + rotation_matrix.shape)) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa + q *= 0.5 + return q + + +def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000., img_size=224.): + """ + This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py + + Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. + Input: + S: (25, 3) 3D joint locations + joints: (25, 3) 2D joint locations and confidence + Returns: + (3,) camera translation vector + """ + + num_joints = S.shape[0] + # focal length + f = np.array([focal_length,focal_length]) + # optical center + center = np.array([img_size/2., img_size/2.]) + + # transformations + Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1) + XY = np.reshape(S[:,0:2],-1) + O = np.tile(center,num_joints) + F = np.tile(f,num_joints) + weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1) + + # least squares + Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T + c = (np.reshape(joints_2d,-1)-O)*Z - F*XY + + # weighted least squares + W = np.diagflat(weight2) + Q = np.dot(W,Q) + c = np.dot(W,c) + + # square matrix + A = np.dot(Q.T,Q) + b = np.dot(Q.T,c) + + # solution + trans = np.linalg.solve(A, b) + + return trans + + +def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.): + """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. 
+ Input: + S: (B, 49, 3) 3D joint locations + joints: (B, 49, 3) 2D joint locations and confidence + Returns: + (B, 3) camera translation vectors + """ + + device = S.device + # Use only joints 25:49 (GT joints) + S = S[:, -24:, :3].cpu().numpy() + joints_2d = joints_2d[:, -24:, :].cpu().numpy() + + joints_conf = joints_2d[:, :, -1] + joints_2d = joints_2d[:, :, :-1] + trans = np.zeros((S.shape[0], 3), dtype=np.float32) + # Find the translation for each example in the batch + for i in range(S.shape[0]): + S_i = S[i] + joints_i = joints_2d[i] + conf_i = joints_conf[i] + trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size) + return torch.from_numpy(trans).to(device) + + diff --git a/infiller/hand_utils/geometry_utils.py b/infiller/hand_utils/geometry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1effe3d9aa66386aa3dd114f07646c6a96a5035e --- /dev/null +++ b/infiller/hand_utils/geometry_utils.py @@ -0,0 +1,102 @@ +from typing import Optional +import torch +from torch.nn import functional as F + +def aa_to_rotmat(theta: torch.Tensor): + """ + Convert axis-angle representation to rotation matrix. + Works by first converting it to a quaternion. + Args: + theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations. + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat_to_rotmat(quat) + +def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor: + """ + Convert quaternion representation to rotation matrix. + Args: + quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z). + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + + +def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor: + """ + Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Args: + x (torch.Tensor): (B,6) Batch of 6-D rotation representations. + Returns: + torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3). + """ + x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous() + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.linalg.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + +def perspective_projection(points: torch.Tensor, + translation: torch.Tensor, + focal_length: torch.Tensor, + camera_center: Optional[torch.Tensor] = None, + rotation: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Computes the perspective projection of a set of 3D points. + Args: + points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points. 
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation. + focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels. + camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels. + rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation. + Returns: + torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points. + """ + batch_size = points.shape[0] + if rotation is None: + rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1) + if camera_center is None: + camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype) + # Populate intrinsic camera matrix K. + K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype) + K[:,0,0] = focal_length[:,0] + K[:,1,1] = focal_length[:,1] + K[:,2,2] = 1. + K[:,:-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:,:,-1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + return projected_points[:, :, :-1] \ No newline at end of file diff --git a/infiller/hand_utils/mano_wrapper.py b/infiller/hand_utils/mano_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..4801c26e3410035db3d09f8dac70c637dd7b00f5 --- /dev/null +++ b/infiller/hand_utils/mano_wrapper.py @@ -0,0 +1,52 @@ +import torch +import numpy as np +import pickle +from typing import Optional +import smplx +from smplx.lbs import vertices2joints +from smplx.utils import MANOOutput, to_tensor +from smplx.vertex_ids import vertex_ids + + +class MANO(smplx.MANOLayer): + def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs): + """ + Extension of the official MANO implementation to support more joints. + Args: + Same as MANOLayer. + joint_regressor_extra (str): Path to extra joint regressor. + """ + super(MANO, self).__init__(*args, **kwargs) + mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20] + + #2, 3, 5, 4, 1 + if joint_regressor_extra is not None: + self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32)) + self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long)) + self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long)) + + def forward(self, *args, **kwargs) -> MANOOutput: + """ + Run forward pass. Same as MANO and also append an extra set of joints if joint_regressor_extra is specified. 
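+
+        Note: the 16 MANO joints are concatenated with the 5 fingertip vertices
+        selected by `extra_joints_idxs` and reordered with `joint_map`
+        (MANO -> OpenPose ordering), so the returned `joints` tensor has shape
+        (B, 21, 3) before any extra regressed joints are appended.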
+ """ + mano_output = super(MANO, self).forward(*args, **kwargs) + extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs) + joints = torch.cat([mano_output.joints, extra_joints], dim=1) + joints = joints[:, self.joint_map, :] + if hasattr(self, 'joint_regressor_extra'): + extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices) + joints = torch.cat([joints, extra_joints], dim=1) + mano_output.joints = joints + return mano_output + + def query(self, hmr_output): + batch_size = hmr_output['pred_rotmat'].shape[0] + pred_rotmat = hmr_output['pred_rotmat'].reshape(batch_size, -1, 3, 3) + pred_shape = hmr_output['pred_shape'].reshape(batch_size, 10) + + mano_output = self(global_orient=pred_rotmat[:, [0]], + hand_pose = pred_rotmat[:, 1:], + betas = pred_shape, + pose2rot=False) + + return mano_output \ No newline at end of file diff --git a/infiller/hand_utils/process.py b/infiller/hand_utils/process.py new file mode 100644 index 0000000000000000000000000000000000000000..bb40b94c964837733bc0403adf802f4f9ba50654 --- /dev/null +++ b/infiller/hand_utils/process.py @@ -0,0 +1,171 @@ +import torch +from hand_utils.mano_wrapper import MANO +from hand_utils.geometry_utils import aa_to_rotmat +import numpy as np + +def run_mano(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True): + """ + Forward pass of the SMPL model and populates pred_data accordingly with + joints3d, verts3d, points3d. + + trans : B x T x 3 + root_orient : B x T x 3 + body_pose : B x T x J*3 + betas : (optional) B x D + """ + MANO_cfg = { + 'DATA_DIR': '_DATA/data/', + 'MODEL_PATH': '_DATA/data/mano', + 'GENDER': 'neutral', + 'NUM_HAND_JOINTS': 15, + 'CREATE_BODY_POSE': False + } + mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()} + mano = MANO(**mano_cfg) + if use_cuda: + mano = mano.cuda() + + B, T, _ = root_orient.shape + NUM_JOINTS = 15 + mano_params = { + 'global_orient': root_orient.reshape(B*T, -1), + 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3), + 'betas': betas.reshape(B*T, -1), + } + rotmat_mano_params = mano_params + rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3) + rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3) + rotmat_mano_params['transl'] = trans.reshape(B*T, 3) + + if use_cuda: + mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False) + else: + mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False) + + faces_right = mano.faces + faces_new = np.array([[92, 38, 234], + [234, 38, 239], + [38, 122, 239], + [239, 122, 279], + [122, 118, 279], + [279, 118, 215], + [118, 117, 215], + [215, 117, 214], + [117, 119, 214], + [214, 119, 121], + [119, 120, 121], + [121, 120, 78], + [120, 108, 78], + [78, 108, 79]]) + faces_right = np.concatenate([faces_right, faces_new], axis=0) + faces_n = len(faces_right) + faces_left = faces_right[:,[0,2,1]] + + outputs = { + "joints": mano_output.joints.reshape(B, T, -1, 3), + "vertices": mano_output.vertices.reshape(B, T, -1, 3), + } + + if not is_right is None: + # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0] + # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0] + is_right = (is_right[:, :, 0].cpu().numpy() > 0) + faces_result = np.zeros((B, T, faces_n, 3)) + faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0) + faces_left_expanded = 
np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0) + faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded) + outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32)) + + + return outputs + +def run_mano_left(trans, root_orient, hand_pose, is_right=None, betas=None, use_cuda=True, fix_shapedirs=True): + """ + Forward pass of the SMPL model and populates pred_data accordingly with + joints3d, verts3d, points3d. + + trans : B x T x 3 + root_orient : B x T x 3 + body_pose : B x T x J*3 + betas : (optional) B x D + """ + MANO_cfg = { + 'DATA_DIR': '_DATA/data_left/', + 'MODEL_PATH': '_DATA/data_left/mano_left', + 'GENDER': 'neutral', + 'NUM_HAND_JOINTS': 15, + 'CREATE_BODY_POSE': False, + 'is_rhand': False + } + mano_cfg = {k.lower(): v for k,v in MANO_cfg.items()} + mano = MANO(**mano_cfg) + if use_cuda: + mano = mano.cuda() + + # fix MANO shapedirs of the left hand bug (https://github.com/vchoutas/smplx/issues/48) + if fix_shapedirs: + mano.shapedirs[:, 0, :] *= -1 + + B, T, _ = root_orient.shape + NUM_JOINTS = 15 + mano_params = { + 'global_orient': root_orient.reshape(B*T, -1), + 'hand_pose': hand_pose.reshape(B*T*NUM_JOINTS, 3), + 'betas': betas.reshape(B*T, -1), + } + rotmat_mano_params = mano_params + rotmat_mano_params['global_orient'] = aa_to_rotmat(mano_params['global_orient']).view(B*T, 1, 3, 3) + rotmat_mano_params['hand_pose'] = aa_to_rotmat(mano_params['hand_pose']).view(B*T, NUM_JOINTS, 3, 3) + rotmat_mano_params['transl'] = trans.reshape(B*T, 3) + + if use_cuda: + mano_output = mano(**{k: v.float().cuda() for k,v in rotmat_mano_params.items()}, pose2rot=False) + else: + mano_output = mano(**{k: v.float() for k,v in rotmat_mano_params.items()}, pose2rot=False) + + faces_right = mano.faces + faces_new = np.array([[92, 38, 234], + [234, 38, 239], + [38, 122, 239], + [239, 122, 279], + [122, 118, 279], + [279, 118, 215], + [118, 117, 215], + [215, 117, 214], + [117, 119, 214], + [214, 119, 121], + [119, 120, 121], + [121, 120, 78], + [120, 108, 78], + [78, 108, 79]]) + faces_right = np.concatenate([faces_right, faces_new], axis=0) + faces_n = len(faces_right) + faces_left = faces_right[:,[0,2,1]] + + outputs = { + "joints": mano_output.joints.reshape(B, T, -1, 3), + "vertices": mano_output.vertices.reshape(B, T, -1, 3), + } + + if not is_right is None: + # outputs["vertices"][..., 0] = (2*is_right-1)*outputs["vertices"][..., 0] + # outputs["joints"][..., 0] = (2*is_right-1)*outputs["joints"][..., 0] + is_right = (is_right[:, :, 0].cpu().numpy() > 0) + faces_result = np.zeros((B, T, faces_n, 3)) + faces_right_expanded = np.expand_dims(np.expand_dims(faces_right, axis=0), axis=0) + faces_left_expanded = np.expand_dims(np.expand_dims(faces_left, axis=0), axis=0) + faces_result = np.where(is_right[..., np.newaxis, np.newaxis], faces_right_expanded, faces_left_expanded) + outputs["faces"] = torch.from_numpy(faces_result.astype(np.int32)) + + + return outputs + +def run_mano_twohands(init_trans, init_rot, init_hand_pose, is_right, init_betas, use_cuda=True, fix_shapedirs=True): + outputs_left = run_mano_left(init_trans[0:1], init_rot[0:1], init_hand_pose[0:1], None, init_betas[0:1], use_cuda=use_cuda, fix_shapedirs=fix_shapedirs) + outputs_right = run_mano(init_trans[1:2], init_rot[1:2], init_hand_pose[1:2], None, init_betas[1:2], use_cuda=use_cuda) + outputs_two = { + "vertices": torch.cat((outputs_left["vertices"], outputs_right["vertices"]), dim=0), + "joints": torch.cat((outputs_left["joints"], 
outputs_right["joints"]), dim=0) + + } + return outputs_two \ No newline at end of file diff --git a/infiller/hand_utils/rotation.py b/infiller/hand_utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..a7bf9dd99d4ec03802055c03177be5a1a634eac0 --- /dev/null +++ b/infiller/hand_utils/rotation.py @@ -0,0 +1,293 @@ +import torch +import numpy as np +from torch.nn import functional as F + + +def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): + """ + Taken from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py + Calculates the rotation matrices for a batch of rotation vectors + - param rot_vecs: torch.tensor (N, 3) array of N axis-angle vectors + - returns R: torch.tensor (N, 3, 3) rotation matrices + """ + batch_size = rot_vecs.shape[0] + device = rot_vecs.device + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view( + (batch_size, 3, 3) + ) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + +def quaternion_mul(q0, q1): + """ + EXPECTS WXYZ + :param q0 (*, 4) + :param q1 (*, 4) + """ + r0, r1 = q0[..., :1], q1[..., :1] + v0, v1 = q0[..., 1:], q1[..., 1:] + r = r0 * r1 - (v0 * v1).sum(dim=-1, keepdim=True) + v = r0 * v1 + r1 * v0 + torch.linalg.cross(v0, v1) + return torch.cat([r, v], dim=-1) + + +def quaternion_inverse(q, eps=1e-8): + """ + EXPECTS WXYZ + :param q (*, 4) + """ + conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1) + mag = torch.square(q).sum(dim=-1, keepdim=True) + eps + return conj / mag + + +def quaternion_slerp(t, q0, q1, eps=1e-8): + """ + :param t (*, 1) must be between 0 and 1 + :param q0 (*, 4) + :param q1 (*, 4) + """ + dims = q0.shape[:-1] + t = t.view(*dims, 1) + + q0 = F.normalize(q0, p=2, dim=-1) + q1 = F.normalize(q1, p=2, dim=-1) + dot = (q0 * q1).sum(dim=-1, keepdim=True) + + # make sure we give the shortest rotation path (< 180d) + neg = dot < 0 + q1 = torch.where(neg, -q1, q1) + dot = torch.where(neg, -dot, dot) + angle = torch.acos(dot) + + # if angle is too small, just do linear interpolation + collin = torch.abs(dot) > 1 - eps + fac = 1 / torch.sin(angle) + w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac) + w1 = torch.where(collin, t, torch.sin(t * angle) * fac) + slerp = q0 * w0 + q1 * w1 + return slerp + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert rotation matrix to Rodrigues vector + """ + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + return aa + + +def quaternion_to_angle_axis(quaternion): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. 
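+
+    For q = (w, x, y, z) this returns theta * n, where n is the unit vector along
+    (x, y, z) and theta (with |theta| <= pi) is twice the half-angle recovered by
+    atan2; when the vector part vanishes, the scale factor falls back to 2, the
+    small-angle limit of theta / sin(theta / 2).
+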
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + :param quaternion (*, 4) expects WXYZ + :returns angle_axis (*, 3) + """ + # unpack input and compute conversion + q1 = quaternion[..., 1] + q2 = quaternion[..., 2] + q3 = quaternion[..., 3] + sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta = torch.sqrt(sin_squared_theta) + cos_theta = quaternion[..., 0] + two_theta = 2.0 * torch.where( + cos_theta < 0.0, + torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta), + ) + + k_pos = two_theta / sin_theta + k_neg = 2.0 * torch.ones_like(sin_theta) + k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def angle_axis_to_rotation_matrix(angle_axis): + """ + :param angle_axis (*, 3) + return (*, 3, 3) + """ + quat = angle_axis_to_quaternion(angle_axis) + return quaternion_to_rotation_matrix(quat) + + +def quaternion_to_rotation_matrix(quaternion): + """ + Convert a quaternion to a rotation matrix. + Taken from https://github.com/kornia/kornia, based on + https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101 + https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247 + :param quaternion (N, 4) expects WXYZ order + returns rotation matrix (N, 3, 3) + """ + # normalize the input quaternion + quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-12) + *dims, _ = quaternion_norm.shape + + # unpack the normalized quaternion components + w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1) + + # compute the actual conversion + tx = 2.0 * x + ty = 2.0 * y + tz = 2.0 * z + twx = tx * w + twy = ty * w + twz = tz * w + txx = tx * x + txy = ty * x + txz = tz * x + tyy = ty * y + tyz = tz * y + tzz = tz * z + one = torch.tensor(1.0) + + matrix = torch.stack( + ( + one - (tyy + tzz), + txy - twz, + txz + twy, + txy + twz, + one - (txx + tzz), + tyz - twx, + txz - twy, + tyz + twx, + one - (txx + tyy), + ), + dim=-1, + ).view(*dims, 3, 3) + return matrix + + +def angle_axis_to_quaternion(angle_axis): + """ + This function is borrowed from https://github.com/kornia/kornia + Convert angle axis to quaternion in WXYZ order + :param angle_axis (*, 3) + :returns quaternion (*, 4) WXYZ order + """ + theta_sq = torch.sum(angle_axis**2, dim=-1, keepdim=True) # (*, 1) + # need to handle the zero rotation case + valid = theta_sq > 0 + theta = torch.sqrt(theta_sq) + half_theta = 0.5 * theta + ones = torch.ones_like(half_theta) + # fill zero with the limit of sin ax / x -> a + k = torch.where(valid, torch.sin(half_theta) / theta, 0.5 * ones) + w = torch.where(valid, torch.cos(half_theta), ones) + quat = torch.cat([w, k * angle_axis], dim=-1) + return quat + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + Convert rotation matrix to 4d quaternion vector + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + :param rotation_matrix (N, 3, 3) + """ + *dims, m, n = rotation_matrix.shape + rmat_t = torch.transpose(rotation_matrix.reshape(-1, m, n), -1, -2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < 
-rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack( + [ + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + t0, + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + ], + -1, + ) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack( + [ + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], + ], + -1, + ) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack( + [ + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], + t2, + ], + -1, + ) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack( + [ + t3, + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], + ], + -1, + ) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt( + t0_rep * mask_c0 + + t1_rep * mask_c1 + + t2_rep * mask_c2 # noqa + + t3_rep * mask_c3 + ) # noqa + q *= 0.5 + return q.reshape(*dims, 4) diff --git a/infiller/lib/misc/sampler.py b/infiller/lib/misc/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..16d44f6682886c70f30749aae1ff890a2b1557af --- /dev/null +++ b/infiller/lib/misc/sampler.py @@ -0,0 +1,79 @@ +import argparse +import os +from pathlib import Path + +import imageio +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from sklearn.preprocessing import LabelEncoder + +from cmib.data.lafan1_dataset import LAFAN1Dataset +from cmib.data.utils import write_json +from cmib.lafan1.utils import quat_ik +from cmib.model.network import TransformerModel +from cmib.model.preprocess import (lerp_input_repr, replace_constant, + slerp_input_repr, vectorize_representation) +from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, joint_names, + sk_parents) +from cmib.vis.pose import plot_pose_with_stop + + +def test(opt, device): + + save_dir = Path(os.path.join('runs', 'train', opt.exp_name)) + wdir = save_dir / 'weights' + weights = os.listdir(wdir) + weights_paths = [wdir / weight for weight in weights] + latest_weight = max(weights_paths , key = os.path.getctime) + ckpt = torch.load(latest_weight, map_location=device) + print(f"Loaded weight: {latest_weight}") + + # Load Skeleton + skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device) + skeleton_mocap.remove_joints(sk_joints_to_remove) + + # Load LAFAN Dataset + Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True) + lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device) + total_data = lafan_dataset.data['global_pos'].shape[0] + + # Replace with noise to In-betweening Frames + from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx'] # default: 9-40, max: 48 + horizon = ckpt['horizon'] + print(f"HORIZON: {horizon}") + + test_idx = [] + for i in range(total_data): + test_idx.append(i) + + # Compare Input 
data, Prediction, GT + save_path = os.path.join(opt.save_path, 'sampler') + for i in range(len(test_idx)): + Path(save_path).mkdir(parents=True, exist_ok=True) + + start_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx] + target_pose = lafan_dataset.data['global_pos'][test_idx[i], target_idx] + gt_stopover_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx] + + gt_img_path = os.path.join(save_path) + plot_pose_with_stop(start_pose, target_pose, target_pose, gt_stopover_pose, i, skeleton_mocap, save_dir=gt_img_path, prefix='gt') + print(f"ID {test_idx[i]}: completed.") + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--project', default='runs/train', help='project/name') + parser.add_argument('--exp_name', default='slerp_40', help='experiment name') + parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path') + parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton') + parser.add_argument('--processed_data_dir', type=str, default='processed_data_original/', help='path to save pickled processed data') + parser.add_argument('--save_path', type=str, default='runs/test', help='path to save model') + parser.add_argument('--motion_type', type=str, default='jumps', help='motion type') + opt = parser.parse_args() + return opt + +if __name__ == "__main__": + opt = parse_opt() + device = torch.device("cpu") + test(opt, device) diff --git a/infiller/lib/model/__pycache__/network.cpython-310.pyc b/infiller/lib/model/__pycache__/network.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b25cc9023b522dabca5d407fb3d10e7bd2a0c7a9 Binary files /dev/null and b/infiller/lib/model/__pycache__/network.cpython-310.pyc differ diff --git a/infiller/lib/model/network.py b/infiller/lib/model/network.py new file mode 100644 index 0000000000000000000000000000000000000000..6f812e0b1ffd1d876d6466eb4f4e974e6f57bd01 --- /dev/null +++ b/infiller/lib/model/network.py @@ -0,0 +1,276 @@ +import math +import numpy as np +import torch +from torch import nn, Tensor +from torch.nn import TransformerEncoder, TransformerEncoderLayer +# from cmib.model.positional_encoding import PositionalEmbedding + +class SinPositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=100): + super(SinPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + + +class MultiHeadedAttention(nn.Module): + def __init__(self, n_head, d_model, d_head, dropout=0.1, + pre_lnorm=True, bias=False): + """ + Multi-headed attention with relative positional encoding and + memory mechanism. + + Args: + n_head (int): Number of heads. + d_model (int): Input dimension. + d_head (int): Head dimension. + dropout (float, optional): Dropout value. Defaults to 0.1. + pre_lnorm (bool, optional): + Apply layer norm before rest of calculation. Defaults to True. 
+ In original Transformer paper (pre_lnorm=False): + LayerNorm(x + Sublayer(x)) + In tensor2tensor implementation (pre_lnorm=True): + x + Sublayer(LayerNorm(x)) + bias (bool, optional): + Add bias to q, k, v and output projections. Defaults to False. + + """ + super(MultiHeadedAttention, self).__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_head = d_head + self.dropout = dropout + self.pre_lnorm = pre_lnorm + self.bias = bias + self.atten_scale = 1 / math.sqrt(self.d_model) + + self.q_linear = nn.Linear(d_model, n_head * d_head, bias=bias) + self.k_linear = nn.Linear(d_model, n_head * d_head, bias=bias) + self.v_linear = nn.Linear(d_model, n_head * d_head, bias=bias) + self.out_linear = nn.Linear(n_head * d_head, d_model, bias=bias) + + self.droput_layer = nn.Dropout(dropout) + self.atten_dropout_layer = nn.Dropout(dropout) + + self.layer_norm = nn.LayerNorm(d_model) + + def forward(self, hidden, memory=None, mask=None, + extra_atten_score=None): + """ + Args: + hidden (Tensor): Input embedding or hidden state of previous layer. + Shape: (batch, seq, dim) + pos_emb (Tensor): Relative positional embedding lookup table. + Shape: (batch, (seq+mem_len)*2-1, d_head) + pos_emb[:, seq+mem_len] + + memory (Tensor): Memory tensor of previous layer. + Shape: (batch, mem_len, dim) + mask (BoolTensor, optional): Attention mask. + Set item value to True if you DO NOT want keep certain + attention score, otherwise False. Defaults to None. + Shape: (seq, seq+mem_len). + """ + combined = hidden + # if memory is None: + # combined = hidden + # mem_len = 0 + # else: + # combined = torch.cat([memory, hidden], dim=1) + # mem_len = memory.shape[1] + + if self.pre_lnorm: + hidden = self.layer_norm(hidden) + combined = self.layer_norm(combined) + + # shape: (batch, q/k/v_len, dim) + q = self.q_linear(hidden) + k = self.k_linear(combined) + v = self.v_linear(combined) + + # reshape to (batch, q/k/v_len, n_head, d_head) + q = q.reshape(q.shape[0], q.shape[1], self.n_head, self.d_head) + k = k.reshape(k.shape[0], k.shape[1], self.n_head, self.d_head) + v = v.reshape(v.shape[0], v.shape[1], self.n_head, self.d_head) + + # transpose to (batch, n_head, q/k/v_len, d_head) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # add n_head dimension for relative positional embedding lookup table + # (batch, n_head, k/v_len*2-1, d_head) + # pos_emb = pos_emb[:, None] + + # (batch, n_head, q_len, k_len) + atten_score = torch.matmul(q, k.transpose(-1, -2)) + + # qpos = torch.matmul(q, pos_emb.transpose(-1, -2)) + # DEBUG + # ones = torch.zeros(q.shape) + # ones[:, :, :, 0] = 1.0 + # qpos = torch.matmul(ones, pos_emb.transpose(-1, -2)) + # atten_score = atten_score + self.skew(qpos, mem_len) + atten_score = atten_score * self.atten_scale + + # if extra_atten_score is not None: + # atten_score = atten_score + extra_atten_score + + if mask is not None: + # print(atten_score.shape) + # print(mask.shape) + # apply attention mask + atten_score = atten_score.masked_fill(mask, float("-inf")) + atten_score = atten_score.softmax(dim=-1) + atten_score = self.atten_dropout_layer(atten_score) + + # (batch, n_head, q_len, d_head) + atten_vec = torch.matmul(atten_score, v) + # (batch, q_len, n_head*d_head) + atten_vec = atten_vec.transpose(1, 2).flatten(start_dim=-2) + + # linear projection + output = self.droput_layer(self.out_linear(atten_vec)) + + if self.pre_lnorm: + return hidden + output + else: + return self.layer_norm(hidden + output) + + +class FeedForward(nn.Module): + def 
__init__(self, d_model, d_inner, dropout=0.1, pre_lnorm=True): + """ + Positionwise feed-forward network. + + Args: + d_model(int): Dimension of the input and output. + d_inner (int): Dimension of the middle layer(bottleneck). + dropout (float, optional): Dropout value. Defaults to 0.1. + pre_lnorm (bool, optional): + Apply layer norm before rest of calculation. Defaults to True. + In original Transformer paper (pre_lnorm=False): + LayerNorm(x + Sublayer(x)) + In tensor2tensor implementation (pre_lnorm=True): + x + Sublayer(LayerNorm(x)) + """ + super(FeedForward, self).__init__() + self.d_model = d_model + self.d_inner = d_inner + self.dropout = dropout + self.pre_lnorm = pre_lnorm + + self.layer_norm = nn.LayerNorm(d_model) + self.network = nn.Sequential( + nn.Linear(d_model, d_inner), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout), + ) + + def forward(self, x): + if self.pre_lnorm: + return x + self.network(self.layer_norm(x)) + else: + return self.layer_norm(x + self.network(x)) +class TransformerModel(nn.Module): + def __init__( + self, + seq_len: int, + input_dim: int, + d_model: int, + nhead: int, + d_hid: int, + nlayers: int, + dropout: float = 0.5, + out_dim=91, + masked_attention_stage=False, + ): + super().__init__() + self.model_type = "Transformer" + self.seq_len = seq_len + self.d_model = d_model + self.nhead = nhead + self.d_hid = d_hid + self.nlayers = nlayers + self.pos_embedding = SinPositionalEncoding(d_model=d_model, dropout=0.1, max_len=seq_len) + if masked_attention_stage: + self.input_layer = nn.Linear(input_dim+1, d_model) + # visible to invisible attention + self.att_layers = nn.ModuleList() + self.pff_layers = nn.ModuleList() + self.pre_lnorm = True + self.layer_norm = nn.LayerNorm(d_model) + for i in range(self.nlayers): + self.att_layers.append( + MultiHeadedAttention( + self.nhead, self.d_model, + self.d_model // self.nhead, dropout=dropout, + pre_lnorm=True, + bias=False + ) + ) + + self.pff_layers.append( + FeedForward( + self.d_model, d_hid, + dropout=dropout, + pre_lnorm=True + ) + ) + else: + self.att_layers = None + self.input_layer = nn.Linear(input_dim, d_model) + encoder_layers = TransformerEncoderLayer( + d_model, nhead, d_hid, dropout, activation="gelu" + ) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + self.decoder = nn.Linear(d_model, out_dim) + + self.init_weights() + + def init_weights(self) -> None: + initrange = 0.1 + self.decoder.bias.data.zero_() + self.decoder.weight.data.uniform_(-initrange, initrange) + + def forward(self, src: Tensor, src_mask: Tensor, data_mask=None, atten_mask=None) -> Tensor: + """ + Args: + src: Tensor, shape [seq_len, batch_size, embedding_dim] + src_mask: Tensor, shape [seq_len, seq_len] + + Returns: + output Tensor of shape [seq_len, batch_size, embedding_dim] + """ + if not data_mask is None: + src = torch.cat([src, data_mask.expand(*src.shape[:-1], data_mask.shape[-1])], dim=-1) + src = self.input_layer(src) + output = self.pos_embedding(src) + # output = src + if self.att_layers: + assert not atten_mask is None + output = output.permute(1, 0, 2) + for i in range(self.nlayers): + output = self.att_layers[i](output, mask=atten_mask) + output = self.pff_layers[i](output) + if self.pre_lnorm: + output = self.layer_norm(output) + output = output.permute(1, 0, 2) + output = self.transformer_encoder(output) + output = self.decoder(output) + return output diff --git a/infiller/lib/model/positional_encoding.py 
b/infiller/lib/model/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..a40cfeea570bc9f52af1240dc8529baf4b38d7a0 --- /dev/null +++ b/infiller/lib/model/positional_encoding.py @@ -0,0 +1,42 @@ +import torch +from torch import nn, Tensor +import math + + +class PositionalEmbedding(nn.Module): + def __init__(self, seq_len: int = 32, d_model: int = 96): + super().__init__() + self.pos_emb = nn.Embedding(seq_len + 1, d_model) + + def forward(self, inputs): + positions = ( + torch.arange(inputs.size(0), device=inputs.device) + .expand(inputs.size(1), inputs.size(0)) + .contiguous() + + 1 + ) + outputs = inputs + self.pos_emb(positions).permute(1, 0, 2) + return outputs + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [seq_len, batch_size, embedding_dim] + """ + x = x + self.pe[: x.size(0)] + return self.dropout(x) diff --git a/infiller/lib/model/preprocess.py b/infiller/lib/model/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..44cdae00f464fbf55177edf83a1dea72a5c9090a --- /dev/null +++ b/infiller/lib/model/preprocess.py @@ -0,0 +1,189 @@ +import torch + + +def replace_constant(minibatch_pose_input, mask_start_frame): + + seq_len = minibatch_pose_input.size(1) + interpolated = ( + torch.ones_like(minibatch_pose_input, device=minibatch_pose_input.device) * 0.1 + ) + + if mask_start_frame == 0 or mask_start_frame == (seq_len - 1): + interpolate_start = minibatch_pose_input[:, 0, :] + interpolate_end = minibatch_pose_input[:, seq_len - 1, :] + + interpolated[:, 0, :] = interpolate_start + interpolated[:, seq_len - 1, :] = interpolate_end + + assert torch.allclose(interpolated[:, 0, :], interpolate_start) + assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end) + + else: + interpolate_start1 = minibatch_pose_input[:, 0, :] + interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :] + + interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :] + interpolate_end2 = minibatch_pose_input[:, seq_len - 1, :] + + interpolated[:, 0, :] = interpolate_start1 + interpolated[:, mask_start_frame, :] = interpolate_end1 + + interpolated[:, mask_start_frame, :] = interpolate_start2 + interpolated[:, seq_len - 1, :] = interpolate_end2 + + assert torch.allclose(interpolated[:, 0, :], interpolate_start1) + assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1) + + assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2) + assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end2) + return interpolated + + +def slerp(x, y, a): + """ + Perfroms spherical linear interpolation (SLERP) between x and y, with proportion a + + :param x: quaternion tensor + :param y: quaternion tensor + :param a: indicator (between 0 and 1) of completion of the interpolation. 
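+        (a may be a plain float or any tensor broadcastable against x[..., 0];
+        entries of y are negated in place when their dot product with x is
+        negative, so the shorter arc is always taken)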
+ :return: tensor of interpolation results + """ + device = x.device + len = torch.sum(x * y, dim=-1) + + neg = len < 0.0 + len[neg] = -len[neg] + y[neg] = -y[neg] + + a = torch.zeros_like(x[..., 0]) + a + amount0 = torch.zeros(a.shape, device=device) + amount1 = torch.zeros(a.shape, device=device) + + linear = (1.0 - len) < 0.01 + omegas = torch.arccos(len[~linear]) + sinoms = torch.sin(omegas) + + amount0[linear] = 1.0 - a[linear] + amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms + + amount1[linear] = a[linear] + amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms + # res = amount0[..., np.newaxis] * x + amount1[..., np.newaxis] * y + res = amount0.unsqueeze(3) * x + amount1.unsqueeze(3) * y + + return res + + +def slerp_input_repr(minibatch_pose_input, mask_start_frame): + seq_len = minibatch_pose_input.size(1) + minibatch_pose_input = minibatch_pose_input.reshape( + minibatch_pose_input.size(0), seq_len, -1, 4 + ) + interpolated = torch.zeros_like( + minibatch_pose_input, device=minibatch_pose_input.device + ) + + if mask_start_frame == 0 or mask_start_frame == (seq_len - 1): + interpolate_start = minibatch_pose_input[:, 0:1] + interpolate_end = minibatch_pose_input[:, seq_len - 1 :] + + for i in range(seq_len): + dt = 1 / (seq_len - 1) + interpolated[:, i : i + 1, :] = slerp( + interpolate_start, interpolate_end, dt * i + ) + + assert torch.allclose(interpolated[:, 0:1], interpolate_start) + assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end) + else: + interpolate_start1 = minibatch_pose_input[:, 0:1] + interpolate_end1 = minibatch_pose_input[ + :, mask_start_frame : mask_start_frame + 1 + ] + + interpolate_start2 = minibatch_pose_input[ + :, mask_start_frame : mask_start_frame + 1 + ] + interpolate_end2 = minibatch_pose_input[:, seq_len - 1 :] + + for i in range(mask_start_frame + 1): + dt = 1 / mask_start_frame + interpolated[:, i : i + 1, :] = slerp( + interpolate_start1, interpolate_end1, dt * i + ) + + assert torch.allclose(interpolated[:, 0:1], interpolate_start1) + assert torch.allclose( + interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_end1 + ) + + for i in range(mask_start_frame, seq_len): + dt = 1 / (seq_len - mask_start_frame - 1) + interpolated[:, i : i + 1, :] = slerp( + interpolate_start2, interpolate_end2, dt * (i - mask_start_frame) + ) + + assert torch.allclose( + interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_start2 + ) + assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end2) + + interpolated = torch.nn.functional.normalize(interpolated, p=2.0, dim=3) + return interpolated.reshape(minibatch_pose_input.size(0), seq_len, -1) + + +def lerp_input_repr(minibatch_pose_input, mask_start_frame): + seq_len = minibatch_pose_input.size(1) + interpolated = torch.zeros_like( + minibatch_pose_input, device=minibatch_pose_input.device + ) + + if mask_start_frame == 0 or mask_start_frame == (seq_len - 1): + interpolate_start = minibatch_pose_input[:, 0, :] + interpolate_end = minibatch_pose_input[:, seq_len - 1, :] + + for i in range(seq_len): + dt = 1 / (seq_len - 1) + interpolated[:, i, :] = torch.lerp( + interpolate_start, interpolate_end, dt * i + ) + + assert torch.allclose(interpolated[:, 0, :], interpolate_start) + assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end) + else: + interpolate_start1 = minibatch_pose_input[:, 0, :] + interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :] + + interpolate_start2 = minibatch_pose_input[:, 
mask_start_frame, :] + interpolate_end2 = minibatch_pose_input[:, -1, :] + + for i in range(mask_start_frame + 1): + dt = 1 / mask_start_frame + interpolated[:, i, :] = torch.lerp( + interpolate_start1, interpolate_end1, dt * i + ) + + assert torch.allclose(interpolated[:, 0, :], interpolate_start1) + assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1) + + for i in range(mask_start_frame, seq_len): + dt = 1 / (seq_len - mask_start_frame - 1) + interpolated[:, i, :] = torch.lerp( + interpolate_start2, interpolate_end2, dt * (i - mask_start_frame) + ) + + assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2) + assert torch.allclose(interpolated[:, -1, :], interpolate_end2) + return interpolated + + +def vectorize_representation(global_position, global_rotation): + + batch_size = global_position.shape[0] + seq_len = global_position.shape[1] + + global_pos_vec = global_position.reshape(batch_size, seq_len, -1).contiguous() + global_rot_vec = global_rotation.reshape(batch_size, seq_len, -1).contiguous() + + global_pose_vec_gt = torch.cat([global_pos_vec, global_rot_vec], dim=2) + return global_pose_vec_gt diff --git a/infiller/lib/model/skeleton.py b/infiller/lib/model/skeleton.py new file mode 100644 index 0000000000000000000000000000000000000000..d69797b4af0672c54c3635b5e4a39d3966e90e9b --- /dev/null +++ b/infiller/lib/model/skeleton.py @@ -0,0 +1,349 @@ +import torch +import numpy as np +from cmib.data.quaternion import qmul, qrot +import torch.nn as nn + +amass_offsets = [ + [0.0, 0.0, 0.0], + + [0.058581, -0.082280, -0.017664], + [0.043451, -0.386469, 0.008037], + [-0.014790, -0.426874, -0.037428], + [0.041054, -0.060286, 0.122042], + [0.0, 0.0, 0.0], + + [-0.060310, -0.090513, -0.013543], + [-0.043257, -0.383688, -0.004843], + [0.019056, -0.420046, -0.034562], + [-0.034840, -0.062106, 0.130323], + [0.0, 0.0, 0.0], + + [0.004439, 0.124404, -0.038385], + [0.004488, 0.137956, 0.026820], + [-0.002265, 0.056032, 0.002855], + [-0.013390, 0.211636, -0.033468], + [0.010113, 0.088937, 0.050410], + [0.0, 0.0, 0.0], + + [0.071702, 0.114000, -0.018898], + [0.122921, 0.045205, -0.019046], + [0.255332, -0.015649, -0.022946], + [0.265709, 0.012698, -0.007375], + [0.0, 0.0, 0.0], + + [-0.082954, 0.112472, -0.023707], + [-0.113228, 0.046853, -0.008472], + [-0.260127, -0.014369, -0.031269], + [-0.269108, 0.006794, -0.006027], + [0.0, 0.0, 0.0] +] + +sk_offsets = [ + [-42.198200, 91.614723, -40.067841], + + [0.103456, 1.857829, 10.548506], + [43.499992, -0.000038, -0.000002], + [42.372192, 0.000015, -0.000007], + [17.299999, -0.000002, 0.000003], + [0.000000, 0.000000, 0.000000], + + [0.103457, 1.857829, -10.548503], + [43.500042, -0.000027, 0.000008], + [42.372257, -0.000008, 0.000014], + [17.299992, -0.000005, 0.000004], + [0.000000, 0.000000, 0.000000], + + [6.901968, -2.603733, -0.000001], + [12.588099, 0.000002, 0.000000], + [12.343206, 0.000000, -0.000001], + [25.832886, -0.000004, 0.000003], + [11.766620, 0.000005, -0.000001], + [0.000000, 0.000000, 0.000000], + + [19.745899, -1.480370, 6.000108], + [11.284125, -0.000009, -0.000018], + [33.000050, 0.000004, 0.000032], + [25.200008, 0.000015, 0.000008], + [0.000000, 0.000000, 0.000000], + + [19.746099, -1.480375, -6.000073], + [11.284138, -0.000015, -0.000012], + [33.000092, 0.000017, 0.000013], + [25.199780, 0.000135, 0.000422], + [0.000000, 0.000000, 0.000000], +] + +sk_parents = [ + -1, + 0, + 1, + 2, + 3, + 4, + 0, + 6, + 7, + 8, + 9, + 0, + 11, + 12, + 13, + 14, + 15, + 13, + 17, + 18, + 19, + 
20, + 13, + 22, + 23, + 24, + 25, +] + +sk_joints_to_remove = [5, 10, 16, 21, 26] + +joint_names = [ + "Hips", + "LeftUpLeg", + "LeftLeg", + "LeftFoot", + "LeftToe", + "RightUpLeg", + "RightLeg", + "RightFoot", + "RightToe", + "Spine", + "Spine1", + "Spine2", + "Neck", + "Head", + "LeftShoulder", + "LeftArm", + "LeftForeArm", + "LeftHand", + "RightShoulder", + "RightArm", + "RightForeArm", + "RightHand", +] + + +class Skeleton: + def __init__( + self, + offsets, + parents, + joints_left=None, + joints_right=None, + bone_length=None, + device=None, + ): + assert len(offsets) == len(parents) + + self._offsets = torch.Tensor(offsets).to(device) + self._parents = np.array(parents) + self._joints_left = joints_left + self._joints_right = joints_right + self._compute_metadata() + + def num_joints(self): + return self._offsets.shape[0] + + def offsets(self): + return self._offsets + + def parents(self): + return self._parents + + def has_children(self): + return self._has_children + + def children(self): + return self._children + + def convert_to_global_pos(self, unit_vec_rerp): + """ + Convert the unit offset matrix to global position. + First row(root) will have absolute position value in global coordinates. + """ + bone_length = self.get_bone_length_weight() + batch_size = unit_vec_rerp.size(0) + seq_len = unit_vec_rerp.size(1) + unit_vec_table = unit_vec_rerp.reshape(batch_size, seq_len, 22, 3) + global_position = torch.zeros_like(unit_vec_table, device=unit_vec_table.device) + + for i, parent in enumerate(self._parents): + if parent == -1: # if root + global_position[:, :, i] = unit_vec_table[:, :, i] + + else: + global_position[:, :, i] = global_position[:, :, parent] + ( + nn.functional.normalize(unit_vec_table[:, :, i], p=2.0, dim=-1) + * bone_length[i] + ) + + return global_position + + def convert_to_unit_offset_mat(self, global_position): + """ + Convert the global position of the skeleton to a unit offset matrix. + First row(root) will have absolute position value in global coordinates. + """ + + bone_length = self.get_bone_length_weight() + unit_offset_mat = torch.zeros_like( + global_position, device=global_position.device + ) + + for i, parent in enumerate(self._parents): + + if parent == -1: # if root + unit_offset_mat[:, :, i] = global_position[:, :, i] + else: + unit_offset_mat[:, :, i] = ( + global_position[:, :, i] - global_position[:, :, parent] + ) / bone_length[i] + + return unit_offset_mat + + def remove_joints(self, joints_to_remove): + """ + Remove the joints specified in 'joints_to_remove', both from the + skeleton definition and from the dataset (which is modified in place). + The rotations of removed joints are propagated along the kinematic chain. + """ + valid_joints = [] + for joint in range(len(self._parents)): + if joint not in joints_to_remove: + valid_joints.append(joint) + + index_offsets = np.zeros(len(self._parents), dtype=int) + new_parents = [] + for i, parent in enumerate(self._parents): + if i not in joints_to_remove: + new_parents.append(parent - index_offsets[parent]) + else: + index_offsets[i:] += 1 + self._parents = np.array(new_parents) + + self._offsets = self._offsets[valid_joints] + self._compute_metadata() + + def forward_kinematics(self, rotations, root_positions): + """ + Perform forward kinematics using the given trajectory and local rotations. + Arguments (where N = batch size, L = sequence length, J = number of joints): + -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint. 
+ -- root_positions: (N, L, 3) tensor describing the root joint positions. + """ + assert len(rotations.shape) == 4 + assert rotations.shape[-1] == 4 + + positions_world = [] + rotations_world = [] + + expanded_offsets = self._offsets.expand( + rotations.shape[0], + rotations.shape[1], + self._offsets.shape[0], + self._offsets.shape[1], + ) + + # Parallelize along the batch and time dimensions + for i in range(self._offsets.shape[0]): + if self._parents[i] == -1: + positions_world.append(root_positions) + rotations_world.append(rotations[:, :, 0]) + else: + positions_world.append( + qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i]) + + positions_world[self._parents[i]] + ) + if self._has_children[i]: + rotations_world.append( + qmul(rotations_world[self._parents[i]], rotations[:, :, i]) + ) + else: + # This joint is a terminal node -> it would be useless to compute the transformation + rotations_world.append(None) + + return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2) + + def forward_kinematics_with_rotation(self, rotations, root_positions): + """ + Perform forward kinematics using the given trajectory and local rotations. + Arguments (where N = batch size, L = sequence length, J = number of joints): + -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint. + -- root_positions: (N, L, 3) tensor describing the root joint positions. + """ + assert len(rotations.shape) == 4 + assert rotations.shape[-1] == 4 + + positions_world = [] + rotations_world = [] + + expanded_offsets = self._offsets.expand( + rotations.shape[0], + rotations.shape[1], + self._offsets.shape[0], + self._offsets.shape[1], + ) + + # Parallelize along the batch and time dimensions + for i in range(self._offsets.shape[0]): + if self._parents[i] == -1: + positions_world.append(root_positions) + rotations_world.append(rotations[:, :, 0]) + else: + positions_world.append( + qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i]) + + positions_world[self._parents[i]] + ) + if self._has_children[i]: + rotations_world.append( + qmul(rotations_world[self._parents[i]], rotations[:, :, i]) + ) + else: + # This joint is a terminal node -> it would be useless to compute the transformation + rotations_world.append( + torch.Tensor([1, 0, 0, 0]) + .expand(rotations.shape[0], rotations.shape[1], 4) + .to(rotations.device) + ) + + return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2), torch.stack( + rotations_world, dim=3 + ).permute(0, 1, 3, 2) + + def get_bone_length_weight(self): + bone_length = [] + for i, parent in enumerate(self._parents): + if parent == -1: + bone_length.append(1) + else: + bone_length.append( + torch.linalg.norm(self._offsets[i : i + 1], ord="fro").item() + ) + return torch.Tensor(bone_length) + + def joints_left(self): + return self._joints_left + + def joints_right(self): + return self._joints_right + + def _compute_metadata(self): + self._has_children = np.zeros(len(self._parents)).astype(bool) + for i, parent in enumerate(self._parents): + if parent != -1: + self._has_children[parent] = True + + self._children = [] + for i, parent in enumerate(self._parents): + self._children.append([]) + for i, parent in enumerate(self._parents): + if parent != -1: + self._children[parent].append(i) diff --git a/infiller/lib/vis/pose.py b/infiller/lib/vis/pose.py new file mode 100644 index 0000000000000000000000000000000000000000..767edd1090b93044e4584624e7eb580c0285cb3b --- /dev/null +++ b/infiller/lib/vis/pose.py @@ -0,0 +1,248 @@ 
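+# The pose plotting helpers in this module (plot_single_pose, plot_pose,
+# plot_pose_with_stop) draw a skeleton by connecting each joint to its parent
+# (skeleton.parents()) with one 3D line per bone; pose columns are plotted as
+# (x, z, y) so the data's y-axis becomes the vertical axis of the figure, and
+# each frame is saved as "<prefix><frame_idx>.png" under save_dir.
+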
+import os +import pathlib + +import matplotlib.pyplot as plt +import numpy as np + + +def project_root_position(position_arr: np.array, file_name: str): + """ + Take batch of root arrays and porject it on 2D plane + + N: samples + L: trajectory length + J: joints + + position_arr: [N,L,J,3] + """ + + root_joints = position_arr[:, :, 0] + + x_pos = root_joints[:, :, 0] + y_pos = root_joints[:, :, 2] + + fig = plt.figure() + + for i in range(x_pos.shape[1]): + + if i == 0: + plt.scatter(x_pos[:, i], y_pos[:, i], c="b") + elif i == x_pos.shape[1] - 1: + plt.scatter(x_pos[:, i], y_pos[:, i], c="r") + else: + plt.scatter(x_pos[:, i], y_pos[:, i], c="k", marker="*", s=1) + + plt.title(f"Root Position: {file_name}") + plt.xlabel("X Axis") + plt.ylabel("Y Axis") + plt.xlim((-300, 300)) + plt.ylim((-300, 300)) + plt.grid() + plt.savefig(f"{file_name}.png", dpi=200) + + +def plot_single_pose( + pose, + frame_idx, + skeleton, + save_dir, + prefix, +): + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + parent_idx = skeleton.parents() + + for i, p in enumerate(parent_idx): + if i > 0: + ax.plot( + [pose[i, 0], pose[p, 0]], + [pose[i, 2], pose[p, 2]], + [pose[i, 1], pose[p, 1]], + c="k", + ) + + x_min = pose[:, 0].min() + x_max = pose[:, 0].max() + + y_min = pose[:, 1].min() + y_max = pose[:, 1].max() + + z_min = pose[:, 2].min() + z_max = pose[:, 2].max() + + ax.set_xlim(x_min, x_max) + ax.set_xlabel("$X$ Axis") + + ax.set_ylim(z_min, z_max) + ax.set_ylabel("$Y$ Axis") + + ax.set_zlim(y_min, y_max) + ax.set_zlabel("$Z$ Axis") + + plt.draw() + + title = f"{prefix}: {frame_idx}" + plt.title(title) + prefix = prefix + pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True) + plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60) + plt.close() + + +def plot_pose( + start_pose, + inbetween_pose, + target_pose, + frame_idx, + skeleton, + save_dir, + prefix, +): + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + parent_idx = skeleton.parents() + + for i, p in enumerate(parent_idx): + if i > 0: + ax.plot( + [start_pose[i, 0], start_pose[p, 0]], + [start_pose[i, 2], start_pose[p, 2]], + [start_pose[i, 1], start_pose[p, 1]], + c="b", + ) + ax.plot( + [inbetween_pose[i, 0], inbetween_pose[p, 0]], + [inbetween_pose[i, 2], inbetween_pose[p, 2]], + [inbetween_pose[i, 1], inbetween_pose[p, 1]], + c="k", + ) + ax.plot( + [target_pose[i, 0], target_pose[p, 0]], + [target_pose[i, 2], target_pose[p, 2]], + [target_pose[i, 1], target_pose[p, 1]], + c="r", + ) + + x_min = np.min( + [start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()] + ) + x_max = np.max( + [start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()] + ) + + y_min = np.min( + [start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()] + ) + y_max = np.max( + [start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()] + ) + + z_min = np.min( + [start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()] + ) + z_max = np.max( + [start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()] + ) + + ax.set_xlim(x_min, x_max) + ax.set_xlabel("$X$ Axis") + + ax.set_ylim(z_min, z_max) + ax.set_ylabel("$Y$ Axis") + + ax.set_zlim(y_min, y_max) + ax.set_zlabel("$Z$ Axis") + + plt.draw() + + title = f"{prefix}: {frame_idx}" + plt.title(title) + prefix = prefix + pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True) + plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), 
dpi=60) + plt.close() + + +def plot_pose_with_stop( + start_pose, + inbetween_pose, + target_pose, + stopover, + frame_idx, + skeleton, + save_dir, + prefix, +): + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + parent_idx = skeleton.parents() + + for i, p in enumerate(parent_idx): + if i > 0: + ax.plot( + [start_pose[i, 0], start_pose[p, 0]], + [start_pose[i, 2], start_pose[p, 2]], + [start_pose[i, 1], start_pose[p, 1]], + c="b", + ) + ax.plot( + [inbetween_pose[i, 0], inbetween_pose[p, 0]], + [inbetween_pose[i, 2], inbetween_pose[p, 2]], + [inbetween_pose[i, 1], inbetween_pose[p, 1]], + c="k", + ) + ax.plot( + [target_pose[i, 0], target_pose[p, 0]], + [target_pose[i, 2], target_pose[p, 2]], + [target_pose[i, 1], target_pose[p, 1]], + c="r", + ) + + ax.plot( + [stopover[i, 0], stopover[p, 0]], + [stopover[i, 2], stopover[p, 2]], + [stopover[i, 1], stopover[p, 1]], + c="indigo", + ) + + x_min = np.min( + [start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()] + ) + x_max = np.max( + [start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()] + ) + + y_min = np.min( + [start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()] + ) + y_max = np.max( + [start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()] + ) + + z_min = np.min( + [start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()] + ) + z_max = np.max( + [start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()] + ) + + ax.set_xlim(x_min, x_max) + ax.set_xlabel("$X$ Axis") + + ax.set_ylim(z_min, z_max) + ax.set_ylabel("$Y$ Axis") + + ax.set_zlim(y_min, y_max) + ax.set_zlabel("$Z$ Axis") + + plt.draw() + + title = f"{prefix}: {frame_idx}" + plt.title(title) + prefix = prefix + pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True) + plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60) + plt.close() diff --git a/lib/core/__pycache__/constants.cpython-310.pyc b/lib/core/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3d76e2ff1c3a5d84ff35a861b049b49923bada5 Binary files /dev/null and b/lib/core/__pycache__/constants.cpython-310.pyc differ diff --git a/lib/core/constants.py b/lib/core/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..818df466acf09a3952ba49075fe8b66902e19d73 --- /dev/null +++ b/lib/core/constants.py @@ -0,0 +1,78 @@ +FOCAL_LENGTH = 5000. + +# Mean and standard deviation for normalizing input image +IMG_NORM_MEAN = [0.485, 0.456, 0.406] +IMG_NORM_STD = [0.229, 0.224, 0.225] + +""" +We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides. +We keep a superset of 24 joints such that we include all joints from every dataset. +If a dataset doesn't provide annotations for a specific joint, we simply ignore it. 
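+The first 25 names below are the OpenPose keypoints and the remaining 24 are the
+ground-truth joints, giving 49 joints in total (see J49_FLIP_PERM below).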
+The joints used here are the following: +""" +JOINT_NAMES = [ +'OP Nose', 'OP Neck', 'OP RShoulder', #0,1,2 +'OP RElbow', 'OP RWrist', 'OP LShoulder', #3,4,5 +'OP LElbow', 'OP LWrist', 'OP MidHip', #6, 7,8 +'OP RHip', 'OP RKnee', 'OP RAnkle', #9,10,11 +'OP LHip', 'OP LKnee', 'OP LAnkle', #12,13,14 +'OP REye', 'OP LEye', 'OP REar', #15,16,17 +'OP LEar', 'OP LBigToe', 'OP LSmallToe', #18,19,20 +'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', #21, 22, 23, 24 ##Total 25 joints for openpose +'Right Ankle', 'Right Knee', 'Right Hip', #0,1,2 +'Left Hip', 'Left Knee', 'Left Ankle', #3, 4, 5 +'Right Wrist', 'Right Elbow', 'Right Shoulder', #6 +'Left Shoulder', 'Left Elbow', 'Left Wrist', #9 +'Neck (LSP)', 'Top of Head (LSP)', #12, 13 +'Pelvis (MPII)', 'Thorax (MPII)', #14, 15 +'Spine (H36M)', 'Jaw (H36M)', #16, 17 +'Head (H36M)', 'Nose', 'Left Eye', #18, 19, 20 +'Right Eye', 'Left Ear', 'Right Ear' #21,22,23 (Total 24 joints) +] + +# Dict containing the joints in numerical order +JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))} + +# Map joints to SMPL joints +JOINT_MAP = { +'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, +'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16, +'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, +'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8, +'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, +'OP REye': 25, 'OP LEye': 26, 'OP REar': 27, +'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, +'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, +'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, +'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, +'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, +'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, +'Neck (LSP)': 47, 'Top of Head (LSP)': 48, +'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, +'Spine (H36M)': 51, 'Jaw (H36M)': 52, +'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, +'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27 +} + +# Joint selectors +# Indices to get the 14 LSP joints from the 17 H36M joints +H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9] +H36M_TO_J14 = H36M_TO_J17[:14] +# Indices to get the 14 LSP joints from the ground truth joints +J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17] +J24_TO_J14 = J24_TO_J17[:14] + +# Permutation of SMPL pose parameters when flipping the shape +SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22] +SMPL_POSE_FLIP_PERM = [] +for i in SMPL_JOINTS_FLIP_PERM: + SMPL_POSE_FLIP_PERM.append(3*i) + SMPL_POSE_FLIP_PERM.append(3*i+1) + SMPL_POSE_FLIP_PERM.append(3*i+2) +# Permutation indices for the 24 ground truth joints +J24_FLIP_PERM = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22] +# Permutation indices for the full set of 49 joints +J49_FLIP_PERM = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]\ + + [25+i for i in J24_FLIP_PERM] + + diff --git a/lib/datasets/__pycache__/track_dataset.cpython-310.pyc b/lib/datasets/__pycache__/track_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..238cee6d17b180d75b563f6a34b95d85a20866bd Binary files /dev/null and b/lib/datasets/__pycache__/track_dataset.cpython-310.pyc differ diff --git a/lib/datasets/track_dataset.py b/lib/datasets/track_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..076cf8897f87fd03f65eaea7d0fb9d8001eface0 --- /dev/null 
+++ b/lib/datasets/track_dataset.py @@ -0,0 +1,78 @@ +import torch +from torch.utils.data import Dataset +from torchvision.transforms import Normalize, ToTensor, Compose +import numpy as np +import cv2 + +from lib.core import constants +from lib.utils.imutils import crop, boxes_2_cs + + +class TrackDatasetEval(Dataset): + """ + Track Dataset Class - Load images/crops of the tracked boxes. + """ + def __init__(self, imgfiles, boxes, + crop_size=256, dilate=1.0, + img_focal=None, img_center=None, normalization=True, + item_idx=0, do_flip=False): + super(TrackDatasetEval, self).__init__() + + self.imgfiles = imgfiles + self.crop_size = crop_size + self.normalization = normalization + self.normalize_img = Compose([ + ToTensor(), + Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD) + ]) + + self.boxes = boxes + self.box_dilate = dilate + self.centers, self.scales = boxes_2_cs(boxes) + + self.img_focal = img_focal + self.img_center = img_center + self.item_idx = item_idx + self.do_flip = do_flip + + def __len__(self): + return len(self.imgfiles) + + + def __getitem__(self, index): + item = {} + imgfile = self.imgfiles[index] + scale = self.scales[index] * self.box_dilate + center = self.centers[index] + + img_focal = self.img_focal + img_center = self.img_center + + img = cv2.imread(imgfile)[:,:,::-1] + if self.do_flip: + img = img[:, ::-1, :] + img_width = img.shape[1] + center[0] = img_width - center[0] - 1 + img_crop = crop(img, center, scale, + [self.crop_size, self.crop_size], + rot=0).astype('uint8') + # cv2.imwrite('debug_crop.png', img_crop[:,:,::-1]) + + if self.normalization: + img_crop = self.normalize_img(img_crop) + else: + img_crop = torch.from_numpy(img_crop) + item['img'] = img_crop + + if self.do_flip: + # center[0] = img_width - center[0] - 1 + item['do_flip'] = torch.tensor(1).float() + item['img_idx'] = torch.tensor(index).long() + item['scale'] = torch.tensor(scale).float() + item['center'] = torch.tensor(center).float() + item['img_focal'] = torch.tensor(img_focal).float() + item['img_center'] = torch.tensor(img_center).float() + + + return item + diff --git a/lib/eval_utils/__pycache__/custom_utils.cpython-310.pyc b/lib/eval_utils/__pycache__/custom_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..434627c371a53973a3f661331ef3b8c1920ee8ac Binary files /dev/null and b/lib/eval_utils/__pycache__/custom_utils.cpython-310.pyc differ diff --git a/lib/eval_utils/__pycache__/filling_utils.cpython-310.pyc b/lib/eval_utils/__pycache__/filling_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd71e1bcff744702fd95d00968a8b5955282a967 Binary files /dev/null and b/lib/eval_utils/__pycache__/filling_utils.cpython-310.pyc differ diff --git a/lib/eval_utils/custom_utils.py b/lib/eval_utils/custom_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9fcae184117ef0186ded0e661d0e73060e158529 --- /dev/null +++ b/lib/eval_utils/custom_utils.py @@ -0,0 +1,99 @@ +import copy +import numpy as np +import torch + +from hawor.utils.process import run_mano, run_mano_left +from hawor.utils.rotation import angle_axis_to_quaternion, rotation_matrix_to_angle_axis +from scipy.interpolate import interp1d + + +def cam2world_convert(R_c2w_sla, t_c2w_sla, data_out, handedness): + init_rot_mat = copy.deepcopy(data_out["init_root_orient"]) + init_rot_mat = torch.einsum("tij,btjk->btik", R_c2w_sla, init_rot_mat) + init_rot = rotation_matrix_to_angle_axis(init_rot_mat) + init_rot_quat = 
angle_axis_to_quaternion(init_rot) + # data_out["init_root_orient"] = rotation_matrix_to_angle_axis(data_out["init_root_orient"]) + # data_out["init_hand_pose"] = rotation_matrix_to_angle_axis(data_out["init_hand_pose"]) + data_out_init_root_orient = rotation_matrix_to_angle_axis(data_out["init_root_orient"]) + data_out_init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"]) + + init_trans = data_out["init_trans"] # (B, T, 3) + if handedness == "right": + outputs = run_mano(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"]) + elif handedness == "left": + outputs = run_mano_left(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"]) + root_loc = outputs["joints"][..., 0, :].cpu() # (B, T, 3) + offset = init_trans - root_loc # It is a constant, no matter what the rotation is. + init_trans = ( + torch.einsum("tij,btj->bti", R_c2w_sla, root_loc) + + t_c2w_sla[None, :] + + offset + ) + + data_world = { + "init_root_orient": init_rot, # (B, T, 3) + "init_hand_pose": data_out_init_hand_pose, # (B, T, 15, 3) + "init_trans": init_trans, # (B, T, 3) + "init_betas": data_out["init_betas"] # (B, T, 10) + } + + return data_world + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + +def load_slam_cam(fpath): + print(f"Loading cameras from {fpath}...") + pred_cam = dict(np.load(fpath, allow_pickle=True)) + pred_traj = pred_cam['traj'] + t_c2w_sla = torch.tensor(pred_traj[:, :3]) * pred_cam['scale'] + pred_camq = torch.tensor(pred_traj[:, 3:]) + R_c2w_sla = quaternion_to_matrix(pred_camq[:,[3,0,1,2]]) + R_w2c_sla = R_c2w_sla.transpose(-1, -2) + t_w2c_sla = -torch.einsum("bij,bj->bi", R_w2c_sla, t_c2w_sla) + return R_w2c_sla, t_w2c_sla, R_c2w_sla, t_c2w_sla + + +def interpolate_bboxes(bboxes): + T = bboxes.shape[0] + + zero_indices = np.where(np.all(bboxes == 0, axis=1))[0] + + non_zero_indices = np.where(np.any(bboxes != 0, axis=1))[0] + + if len(zero_indices) == 0: + return bboxes + + interpolated_bboxes = bboxes.copy() + for i in range(5): + interp_func = interp1d(non_zero_indices, bboxes[non_zero_indices, i], kind='linear', fill_value="extrapolate") + interpolated_bboxes[zero_indices, i] = interp_func(zero_indices) + + return interpolated_bboxes \ No newline at end of file diff --git a/lib/eval_utils/filling_utils.py b/lib/eval_utils/filling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..672d98753d2c2c088d011c2d8b2d00bac16c2257 --- /dev/null +++ b/lib/eval_utils/filling_utils.py @@ -0,0 +1,306 @@ +import copy +import os +import joblib +import numpy as np +from scipy.spatial.transform import Slerp, Rotation +import torch + +from hawor.utils.process import run_mano, run_mano_left +from hawor.utils.rotation import angle_axis_to_quaternion, angle_axis_to_rotation_matrix, quaternion_to_rotation_matrix, 
rotation_matrix_to_angle_axis
+from lib.utils.geometry import rotmat_to_rot6d
+from lib.utils.geometry import rot6d_to_rotmat
+
+def slerp_interpolation_aa(pos, valid):
+
+    B, T, N, _ = pos.shape  # B: batch size, T: number of timesteps, N: number of joints, last dim: axis-angle rotation
+    pos_interp = pos.copy()  # copy that will hold the interpolated result
+
+    for b in range(B):
+        for n in range(N):
+            quat_b_n = pos[b, :, n, :]
+            valid_b_n = valid[b, :]
+
+            invalid_idxs = np.where(~valid_b_n)[0]
+            valid_idxs = np.where(valid_b_n)[0]
+
+            if len(invalid_idxs) == 0:
+                continue
+
+            if len(valid_idxs) > 1:
+                valid_times = valid_idxs  # valid timesteps
+                valid_rots = Rotation.from_rotvec(quat_b_n[valid_idxs])  # valid rotations
+
+                slerp = Slerp(valid_times, valid_rots)
+
+                for idx in invalid_idxs:
+                    if idx < valid_idxs[0]:  # before the first valid timestep: extrapolate
+                        pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[0]]  # copy the first valid rotation
+                    elif idx > valid_idxs[-1]:  # after the last valid timestep: extrapolate
+                        pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[-1]]  # copy the last valid rotation
+                    else:
+                        interp_rot = slerp([idx])
+                        pos_interp[b, idx, n, :] = interp_rot.as_rotvec()[0]
+    # print("#######")
+    # if N > 1:
+    #     print(pos[1,0,11])
+    #     print(pos_interp[1,0,11])
+
+    return pos_interp
+
+def slerp_interpolation_quat(pos, valid):
+
+    # wxyz to xyzw
+    pos = pos[:, :, :, [1, 2, 3, 0]]
+
+    B, T, N, _ = pos.shape  # B: batch size, T: number of timesteps, N: number of joints, 4: quaternion dimension
+    pos_interp = pos.copy()  # copy that will hold the interpolated result
+
+    for b in range(B):
+        for n in range(N):
+            quat_b_n = pos[b, :, n, :]
+            valid_b_n = valid[b, :]
+
+            invalid_idxs = np.where(~valid_b_n)[0]
+            valid_idxs = np.where(valid_b_n)[0]
+
+            if len(invalid_idxs) == 0:
+                continue
+
+            if len(valid_idxs) > 1:
+                valid_times = valid_idxs  # valid timesteps
+                valid_rots = Rotation.from_quat(quat_b_n[valid_idxs])  # valid quaternions
+
+                slerp = Slerp(valid_times, valid_rots)
+
+                for idx in invalid_idxs:
+                    if idx < valid_idxs[0]:  # before the first valid timestep: extrapolate
+                        pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[0]]  # copy the first valid quaternion
+                    elif idx > valid_idxs[-1]:  # after the last valid timestep: extrapolate
+                        pos_interp[b, idx, n, :] = quat_b_n[valid_idxs[-1]]  # copy the last valid quaternion
+                    else:
+                        interp_rot = slerp([idx])
+                        pos_interp[b, idx, n, :] = interp_rot.as_quat()[0]
+
+    # xyzw to wxyz
+    pos_interp = pos_interp[:, :, :, [3, 0, 1, 2]]
+    return pos_interp
+
+
+def linear_interpolation_nd(pos, valid):
+    B, T = pos.shape[:2]  # batch size B and number of timesteps T
+    feature_dim = pos.shape[2]  # arbitrary feature dimension
+    pos_interp = pos.copy()  # copy that will hold the interpolated result
+
+    for b in range(B):
+        for idx in range(feature_dim):  # loop over every feature dimension
+            pos_b_idx = pos[b, :, idx]  # time series of this feature for batch b
+            valid_b = valid[b, :]  # valid flags for this batch
+
+            # find the invalid (False) indices
+            invalid_idxs = np.where(~valid_b)[0]
+            valid_idxs = np.where(valid_b)[0]
+
+            if len(invalid_idxs) == 0:
+                continue
+
+            # linearly interpolate the invalid part
+            if len(valid_idxs) > 1:  # make sure there are enough valid points to interpolate
+                pos_b_idx[invalid_idxs] = np.interp(invalid_idxs, valid_idxs, pos_b_idx[valid_idxs])
+            pos_interp[b, :, idx] = pos_b_idx  # store the interpolated result
+
+    return pos_interp
+
+def world2canonical_convert(R_c2w_sla, t_c2w_sla, data_out, handedness):
+    init_rot_mat = copy.deepcopy(data_out["init_root_orient"])
+    init_rot_mat = torch.einsum("tij,btjk->btik", R_c2w_sla, init_rot_mat)
+    init_rot = rotation_matrix_to_angle_axis(init_rot_mat)
+    init_rot_quat = angle_axis_to_quaternion(init_rot)
+    # data_out["init_root_orient"] = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
+    # data_out["init_hand_pose"] = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
+    data_out_init_root_orient = rotation_matrix_to_angle_axis(data_out["init_root_orient"])
+    data_out_init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"])
+
+    init_trans = 
data_out["init_trans"] # (B, T, 3) + if handedness == "left": + outputs = run_mano_left(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"]) + + elif handedness == "right": + outputs = run_mano(data_out["init_trans"], data_out_init_root_orient, data_out_init_hand_pose, betas=data_out["init_betas"]) + root_loc = outputs["joints"][..., 0, :].cpu() # (B, T, 3) + offset = init_trans - root_loc # It is a constant, no matter what the rotation is. + init_trans = ( + torch.einsum("tij,btj->bti", R_c2w_sla, root_loc) + + t_c2w_sla[None, :] + + offset + ) + + data_world = { + "init_root_orient": init_rot, # (B, T, 3) + "init_hand_pose": data_out_init_hand_pose, # (B, T, 15, 3) + "init_trans": init_trans, # (B, T, 3) + "init_betas": data_out["init_betas"] # (B, T, 10) + } + + return data_world + +def filling_preprocess(item): + + num_joints = 15 + + global_trans = item['trans'] # (2, seq_len, 3) + global_rot = item['rot'] #(2, seq_len, 3) + hand_pose = item['hand_pose'] # (2, seq_len, 45) + betas = item['betas'] # (2, seq_len, 10) + valid = item['valid'] # (2, seq_len) + + N, T, _ = global_trans.shape + R_canonical2world_left_aa = torch.from_numpy(global_rot[0, 0]) + R_canonical2world_right_aa = torch.from_numpy(global_rot[1, 0]) + R_world2canonical_left = angle_axis_to_rotation_matrix(R_canonical2world_left_aa).t() + R_world2canonical_right = angle_axis_to_rotation_matrix(R_canonical2world_right_aa).t() + + + # transform left hand to canonical + hand_pose = hand_pose.reshape(N, T, num_joints, 3) + data_world_left = { + "init_trans": torch.from_numpy(global_trans[0:1]), + "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[0:1])), + "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[0:1])), + "init_betas": torch.from_numpy(betas[0:1]), + } + + data_left_init_root_orient = rotation_matrix_to_angle_axis(data_world_left["init_root_orient"]) + data_left_init_hand_pose = rotation_matrix_to_angle_axis(data_world_left["init_hand_pose"]) + outputs = run_mano_left(data_world_left["init_trans"], data_left_init_root_orient, data_left_init_hand_pose, betas=data_world_left["init_betas"]) + init_trans = data_world_left["init_trans"][0, 0] # (3,) + root_loc = outputs["joints"][0, 0, 0, :].cpu() # (3,) + offset = init_trans - root_loc # It is a constant, no matter what the rotation is. 
+ t_world2canonical_left = -torch.einsum("ij,j->i", R_world2canonical_left, root_loc) - offset + + R_world2canonical_left = R_world2canonical_left.repeat(T, 1, 1) + t_world2canonical_left = t_world2canonical_left.repeat(T, 1) + data_canonical_left = world2canonical_convert(R_world2canonical_left, t_world2canonical_left, data_world_left, "left") + + # transform right hand to canonical + data_world_right = { + "init_trans": torch.from_numpy(global_trans[1:2]), + "init_root_orient": angle_axis_to_rotation_matrix(torch.from_numpy(global_rot[1:2])), + "init_hand_pose": angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose[1:2])), + "init_betas": torch.from_numpy(betas[1:2]), + } + + data_right_init_root_orient = rotation_matrix_to_angle_axis(data_world_right["init_root_orient"]) + data_right_init_hand_pose = rotation_matrix_to_angle_axis(data_world_right["init_hand_pose"]) + outputs = run_mano(data_world_right["init_trans"], data_right_init_root_orient, data_right_init_hand_pose, betas=data_world_right["init_betas"]) + init_trans = data_world_right["init_trans"][0, 0] # (3,) + root_loc = outputs["joints"][0, 0, 0, :].cpu() # (3,) + offset = init_trans - root_loc # It is a constant, no matter what the rotation is. + t_world2canonical_right = -torch.einsum("ij,j->i", R_world2canonical_right, root_loc) - offset + + R_world2canonical_right = R_world2canonical_right.repeat(T, 1, 1) + t_world2canonical_right = t_world2canonical_right.repeat(T, 1) + data_canonical_right = world2canonical_convert(R_world2canonical_right, t_world2canonical_right, data_world_right, "right") + + # merge left and right canonical data + global_rot = torch.cat((data_canonical_left['init_root_orient'], data_canonical_right['init_root_orient'])) + global_trans = torch.cat((data_canonical_left['init_trans'], data_canonical_right['init_trans'])).numpy() + + # global_rot = angle_axis_to_quaternion(global_rot).numpy().reshape(N, T, 1, 4) + global_rot = global_rot.reshape(N, T, 1, 3).numpy() + + hand_pose = hand_pose.reshape(N, T, 15, 3) + # hand_pose = angle_axis_to_quaternion(torch.from_numpy(hand_pose)).numpy() + + # lerp and slerp + global_trans_lerped = linear_interpolation_nd(global_trans, valid) + betas_lerped = linear_interpolation_nd(betas, valid) + global_rot_slerped = slerp_interpolation_aa(global_rot, valid) + hand_pose_slerped = slerp_interpolation_aa(hand_pose, valid) + + + # convert to rot6d + + global_rot_slerped_mat = angle_axis_to_rotation_matrix(torch.from_numpy(global_rot_slerped.reshape(N*T, -1))) + # global_rot_slerped_mat = quaternion_to_rotation_matrix(torch.from_numpy(global_rot_slerped.reshape(N*T, -1))) + global_rot_slerped_rot6d = rotmat_to_rot6d(global_rot_slerped_mat).reshape(N, T, -1).numpy() + hand_pose_slerped_mat = angle_axis_to_rotation_matrix(torch.from_numpy(hand_pose_slerped.reshape(N*T*num_joints, -1))) + # hand_pose_slerped_mat = quaternion_to_rotation_matrix(torch.from_numpy(hand_pose_slerped.reshape(N*T*num_joints, -1))) + hand_pose_slerped_rot6d = rotmat_to_rot6d(hand_pose_slerped_mat).reshape(N, T, -1).numpy() + + + # concat to (T, concat_dim) + global_pose_vec_input = np.concatenate((global_trans_lerped, betas_lerped, global_rot_slerped_rot6d, hand_pose_slerped_rot6d), axis=-1).transpose(1, 0, 2).reshape(T, -1) + + R_canon2w_left = R_world2canonical_left.transpose(-1, -2) + t_canon2w_left = -torch.einsum("tij,tj->ti", R_canon2w_left, t_world2canonical_left) + R_canon2w_right = R_world2canonical_right.transpose(-1, -2) + t_canon2w_right = -torch.einsum("tij,tj->ti", R_canon2w_right, 
t_world2canonical_right)
+
+    transform_w_canon = {
+        "R_w2canon_left": R_world2canonical_left,
+        "t_w2canon_left": t_world2canonical_left,
+        "R_canon2w_left": R_canon2w_left,
+        "t_canon2w_left": t_canon2w_left,
+
+        "R_w2canon_right": R_world2canonical_right,
+        "t_w2canon_right": t_world2canonical_right,
+        "R_canon2w_right": R_canon2w_right,
+        "t_canon2w_right": t_canon2w_right,
+    }
+
+    return global_pose_vec_input, transform_w_canon
+
+def custom_rot6d_to_rotmat(rot6d):
+    original_shape = rot6d.shape[:-1]
+    rot6d = rot6d.reshape(-1, 6)
+    mat = rot6d_to_rotmat(rot6d)
+    mat = mat.reshape(*original_shape, 3, 3)
+    return mat
+
+def filling_postprocess(output, transform_w_canon):
+    # output = output.numpy()
+    output = output.permute(1, 0, 2) # (2, T, -1)
+    N, T, _ = output.shape
+    canon_trans = output[:, :, :3]
+    betas = output[:, :, 3:13]
+    canon_rot_rot6d = output[:, :, 13:19]
+    hand_pose_rot6d = output[:, :, 19:109].reshape(N, T, 15, 6)
+
+    canon_rot_mat = custom_rot6d_to_rotmat(canon_rot_rot6d)
+    hand_pose_mat = custom_rot6d_to_rotmat(hand_pose_rot6d)
+
+    data_canonical_left = {
+        "init_trans": canon_trans[[0], :, :],
+        "init_root_orient": canon_rot_mat[[0], :, :, :],
+        "init_hand_pose": hand_pose_mat[[0], :, :, :, :],
+        "init_betas": betas[[0], :, :]
+    }
+
+    data_canonical_right = {
+        "init_trans": canon_trans[[1], :, :],
+        "init_root_orient": canon_rot_mat[[1], :, :, :],
+        "init_hand_pose": hand_pose_mat[[1], :, :, :, :],
+        "init_betas": betas[[1], :, :]
+    }
+
+    R_canon2w_left = transform_w_canon['R_canon2w_left']
+    t_canon2w_left = transform_w_canon['t_canon2w_left']
+    R_canon2w_right = transform_w_canon['R_canon2w_right']
+    t_canon2w_right = transform_w_canon['t_canon2w_right']
+
+
+    world_left = world2canonical_convert(R_canon2w_left, t_canon2w_left, data_canonical_left, "left")
+    world_right = world2canonical_convert(R_canon2w_right, t_canon2w_right, data_canonical_right, "right")
+
+    global_rot = torch.cat((world_left['init_root_orient'], world_right['init_root_orient'])).numpy()
+    global_trans = torch.cat((world_left['init_trans'], world_right['init_trans'])).numpy()
+
+    pred_data = {
+        "trans": global_trans, # (2, T, 3)
+        "rot": global_rot, # (2, T, 3)
+        "hand_pose": rotation_matrix_to_angle_axis(hand_pose_mat).flatten(-2).numpy(), # (2, T, 45)
+        "betas": betas.numpy(), # (2, T, 10)
+    }
+
+    return pred_data
+
diff --git a/lib/eval_utils/video_utils.py b/lib/eval_utils/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fd29f9c3088d6a7556b0566a1953950e8f639a8
--- /dev/null
+++ b/lib/eval_utils/video_utils.py
@@ -0,0 +1,85 @@
+import cv2
+import os
+import subprocess
+
+def make_video_grid_2x2(out_path, vid_paths, overwrite=False):
+    """
+    Tile four videos into a 2x2 grid at their original resolution.
+
+    :param out_path: Path of the output video.
+    :param vid_paths: List of input video paths (must contain exactly 4).
+    :param overwrite: If True, overwrite an existing output file.
+    """
+    if os.path.isfile(out_path) and not overwrite:
+        print(f"{out_path} already exists, skipping.")
+        return
+
+    if any(not os.path.isfile(v) for v in vid_paths):
+        print("Not all inputs exist!", vid_paths)
+        return
+
+    # make sure exactly four video paths were given
+    if len(vid_paths) != 4:
+        print("Error: Exactly 4 video paths are required!")
+        return
+
+    # unpack the video paths
+    v1, v2, v3, v4 = vid_paths
+
+    # ffmpeg xstack command: tile the inputs directly, without resizing
+    cmd = (
+        f"ffmpeg -i {v1} -i {v2} -i {v3} -i {v4} "
+        f"-filter_complex '[0:v][1:v][2:v][3:v]xstack=inputs=4:layout=0_0|w0_0|0_h0|w0_h0[v]' "
+        f"-map '[v]' {out_path} -y"
+    )
+
+    print(cmd)
+    subprocess.call(cmd, shell=True, stdin=subprocess.PIPE)
+
+def create_video_from_images(image_list, 
output_path, fps=15, target_resolution=(540, 540)):
+    """
+    Combine a list of images into an MP4 video.
+
+    :param image_list: List of image paths.
+    :param output_path: Output video file path (e.g. output.mp4).
+    :param fps: Frame rate of the video (default 15 FPS).
+    """
+    # if not image_list:
+    #     print("Image list is empty!")
+    #     return
+
+    # read the first image to get the width and height
+    first_image = cv2.imread(image_list[0])
+    if first_image is None:
+        print(f"Failed to read image: {image_list[0]}")
+        return
+
+    height, width, _ = first_image.shape
+    if height != width:
+        if height < width:
+            vis_w = target_resolution[0]
+            vis_h = int(target_resolution[0] / width * height)
+        elif height > width:
+            vis_h = target_resolution[0]
+            vis_w = int(target_resolution[0] / height * width)
+    else:
+        vis_h = target_resolution[0]
+        vis_w = target_resolution[0]
+    target_resolution = (vis_w, vis_h)
+
+    # define the video codec and output parameters
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # use the mp4v codec
+    video_writer = cv2.VideoWriter(output_path, fourcc, fps, target_resolution)
+
+    # iterate over the images and write them to the video
+    for image_path in image_list:
+        frame = cv2.imread(image_path)
+        if frame is None:
+            print(f"Failed to read image: {image_path}")
+            continue
+        frame_resized = cv2.resize(frame, target_resolution)
+        video_writer.write(frame_resized)
+
+    # release the video writer
+    video_writer.release()
+    print(f"Video saved to: {output_path}")
\ No newline at end of file
diff --git a/lib/models/__pycache__/hawor.cpython-310.pyc b/lib/models/__pycache__/hawor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea55f4816cf240a4b2f3786f8f560f0e794a830b
Binary files /dev/null and b/lib/models/__pycache__/hawor.cpython-310.pyc differ
diff --git a/lib/models/__pycache__/mano_wrapper.cpython-310.pyc b/lib/models/__pycache__/mano_wrapper.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d945c4f2c9b28af8dd870fe910e7575fa184b649
Binary files /dev/null and b/lib/models/__pycache__/mano_wrapper.cpython-310.pyc differ
diff --git a/lib/models/__pycache__/modules.cpython-310.pyc b/lib/models/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab0189e3d153407e46e5e07fd8d5685639c73364
Binary files /dev/null and b/lib/models/__pycache__/modules.cpython-310.pyc differ
diff --git a/lib/models/backbones/__init__.py b/lib/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8288cfd84555de1720b81129fa8accf76948cb88
--- /dev/null
+++ b/lib/models/backbones/__init__.py
@@ -0,0 +1,8 @@
+from .vit import vit
+
+
+def create_backbone(cfg):
+    if cfg.MODEL.BACKBONE.TYPE == 'vit':
+        return vit(cfg)
+    else:
+        raise NotImplementedError('Backbone type is not implemented')
diff --git a/lib/models/backbones/__pycache__/__init__.cpython-310.pyc b/lib/models/backbones/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5b426f6ccab9cf5c36e1444934e14c0bbcc8eec
Binary files /dev/null and b/lib/models/backbones/__pycache__/__init__.cpython-310.pyc differ
diff --git a/lib/models/backbones/__pycache__/vit.cpython-310.pyc b/lib/models/backbones/__pycache__/vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59f170f6599832d8098e7142b3778fd7a83529b2
Binary files /dev/null and b/lib/models/backbones/__pycache__/vit.cpython-310.pyc differ
diff --git a/lib/models/backbones/vit.py b/lib/models/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c56c71889cd441294f57ad687d0678d2443d1eed
--- /dev/null
+++ b/lib/models/backbones/vit.py
@@ -0,0 +1,348 @@
+# Copyright (c) OpenMMLab. 
All rights reserved. +import math + +import torch +from functools import partial +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + +def vit(cfg): + return ViT( + img_size=(256, 192), + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.55, + ) + +def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + cls_token = None + B, L, C = abs_pos.shape + if has_cls_token: + cls_token = abs_pos[:, 0:1] + abs_pos = abs_pos[:, 1:] + + if ori_h != h or ori_w != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).reshape(B, -1, C) + + else: + new_abs_pos = abs_pos + + if cls_token is not None: + new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1) + return new_abs_pos + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return 'p={}'.format(self.drop_prob) + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., attn_head_dim=None,): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.dim = dim + + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., drop_path=0., 
act_layer=nn.GELU, + norm_layer=nn.LayerNorm, attn_head_dim=None + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim + ) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2) + self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio)) + self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1])) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1)) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. 
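+    If `feature_size` is not given, the feature map size and channel dimension are
+    inferred from a dummy forward pass through the backbone.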
+ """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class ViT(nn.Module): + + def __init__(self, + img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, + frozen_stages=-1, ratio=1, last_norm=True, + patch_padding='pad', freeze_attn=False, freeze_ffn=False, + ): + # Protect mutable default arguments + super(ViT, self).__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.frozen_stages = frozen_stages + self.use_checkpoint = use_checkpoint + self.patch_padding = patch_padding + self.freeze_attn = freeze_attn + self.freeze_ffn = freeze_ffn + self.depth = depth + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio) + num_patches = self.patch_embed.num_patches + + # since the pretraining model has class token + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + ) + for i in range(depth)]) + + self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + self._freeze_stages() + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if self.freeze_attn: + for i in range(0, self.depth): + m = self.blocks[i] + m.attn.eval() + m.norm1.eval() + for param in m.attn.parameters(): + param.requires_grad = False + for param in m.norm1.parameters(): + param.requires_grad = False + + if self.freeze_ffn: + self.pos_embed.requires_grad = False + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + for i in range(0, self.depth): + m = self.blocks[i] + m.mlp.eval() + m.norm2.eval() + for param 
in m.mlp.parameters(): + param.requires_grad = False + for param in m.norm2.parameters(): + param.requires_grad = False + + def init_weights(self): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + + if self.pos_embed is not None: + # fit for multiple GPU training + # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference + x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.last_norm(x) + + xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() + + return xp + + def forward(self, x): + x = self.forward_features(x) + return x + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() diff --git a/lib/models/components/__init__.py b/lib/models/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/models/components/__pycache__/__init__.cpython-310.pyc b/lib/models/components/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..305210c4ee998e41e6ecfd36a2321e54e24d6578 Binary files /dev/null and b/lib/models/components/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/models/components/__pycache__/pose_transformer.cpython-310.pyc b/lib/models/components/__pycache__/pose_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50a53a44f68fd4da61ddefec37603449f76a1dcc Binary files /dev/null and b/lib/models/components/__pycache__/pose_transformer.cpython-310.pyc differ diff --git a/lib/models/components/__pycache__/t_cond_mlp.cpython-310.pyc b/lib/models/components/__pycache__/t_cond_mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d244f3e7957c1d3eb97a77ece03adf3cc0e94fe7 Binary files /dev/null and b/lib/models/components/__pycache__/t_cond_mlp.cpython-310.pyc differ diff --git a/lib/models/components/pose_transformer.py b/lib/models/components/pose_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ac04971407cb59637490cc4842f048b9bc4758be --- /dev/null +++ b/lib/models/components/pose_transformer.py @@ -0,0 +1,358 @@ +from inspect import isfunction +from typing import Callable, Optional + +import torch +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import nn + +from .t_cond_mlp import ( + AdaptiveLayerNorm1D, + FrequencyEmbedder, + normalization_layer, +) +# from .vit import Attention, FeedForward + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class PreNorm(nn.Module): + def __init__(self, dim: int, fn: Callable, norm: str = "layer", 
norm_cond_dim: int = -1): + super().__init__() + self.norm = normalization_layer(norm, dim, norm_cond_dim) + self.fn = fn + + def forward(self, x: torch.Tensor, *args, **kwargs): + if isinstance(self.norm, AdaptiveLayerNorm1D): + return self.fn(self.norm(x, *args), **kwargs) + else: + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class CrossAttention(nn.Module): + def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + context_dim = default(context_dim, dim) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) + self.to_q = nn.Linear(dim, inner_dim, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x, context=None): + context = default(context, x) + k, v = self.to_kv(context).chunk(2, dim=-1) + q = self.to_q(x) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v]) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + norm: str = "layer", + norm_cond_dim: int = -1, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + ff = FeedForward(dim, mlp_dim, dropout=dropout) + self.layers.append( + nn.ModuleList( + [ + PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim), + ] + ) + ) + + def forward(self, x: torch.Tensor, *args): + for attn, ff in self.layers: + x = attn(x, *args) + x + x = ff(x, *args) + x + return x + + +class TransformerCrossAttn(nn.Module): + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + norm: str = "layer", + norm_cond_dim: int = -1, + 
context_dim: Optional[int] = None, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + ca = CrossAttention( + dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout + ) + ff = FeedForward(dim, mlp_dim, dropout=dropout) + self.layers.append( + nn.ModuleList( + [ + PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim), + ] + ) + ) + + def forward(self, x: torch.Tensor, *args, context=None, context_list=None): + if context_list is None: + context_list = [context] * len(self.layers) + if len(context_list) != len(self.layers): + raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})") + + for i, (self_attn, cross_attn, ff) in enumerate(self.layers): + x = self_attn(x, *args) + x + x = cross_attn(x, *args, context=context_list[i]) + x + x = ff(x, *args) + x + return x + + +class DropTokenDropout(nn.Module): + def __init__(self, p: float = 0.1): + super().__init__() + if p < 0 or p > 1: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + self.p = p + + def forward(self, x: torch.Tensor): + # x: (batch_size, seq_len, dim) + if self.training and self.p > 0: + zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool() + # TODO: permutation idx for each batch using torch.argsort + if zero_mask.any(): + x = x[:, ~zero_mask, :] + return x + + +class ZeroTokenDropout(nn.Module): + def __init__(self, p: float = 0.1): + super().__init__() + if p < 0 or p > 1: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + self.p = p + + def forward(self, x: torch.Tensor): + # x: (batch_size, seq_len, dim) + if self.training and self.p > 0: + zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool() + # Zero-out the masked tokens + x[zero_mask, :] = 0 + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + num_tokens: int, + token_dim: int, + dim: int, + depth: int, + heads: int, + mlp_dim: int, + dim_head: int = 64, + dropout: float = 0.0, + emb_dropout: float = 0.0, + emb_dropout_type: str = "drop", + emb_dropout_loc: str = "token", + norm: str = "layer", + norm_cond_dim: int = -1, + token_pe_numfreq: int = -1, + ): + super().__init__() + if token_pe_numfreq > 0: + token_dim_new = token_dim * (2 * token_pe_numfreq + 1) + self.to_token_embedding = nn.Sequential( + Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim), + FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1), + Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new), + nn.Linear(token_dim_new, dim), + ) + else: + self.to_token_embedding = nn.Linear(token_dim, dim) + self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim)) + if emb_dropout_type == "drop": + self.dropout = DropTokenDropout(emb_dropout) + elif emb_dropout_type == "zero": + self.dropout = ZeroTokenDropout(emb_dropout) + else: + raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}") + self.emb_dropout_loc = emb_dropout_loc + + self.transformer = Transformer( + dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim + ) + + def forward(self, inp: torch.Tensor, *args, **kwargs): + x = inp + + if self.emb_dropout_loc == "input": + x = self.dropout(x) + x = self.to_token_embedding(x) + + if 
self.emb_dropout_loc == "token": + x = self.dropout(x) + b, n, _ = x.shape + x += self.pos_embedding[:, :n] + + if self.emb_dropout_loc == "token_afterpos": + x = self.dropout(x) + x = self.transformer(x, *args) + return x + + +class TransformerDecoder(nn.Module): + def __init__( + self, + num_tokens: int, + token_dim: int, + dim: int, + depth: int, + heads: int, + mlp_dim: int, + dim_head: int = 64, + dropout: float = 0.0, + emb_dropout: float = 0.0, + emb_dropout_type: str = 'drop', + norm: str = "layer", + norm_cond_dim: int = -1, + context_dim: Optional[int] = None, + skip_token_embedding: bool = False, + ): + super().__init__() + if not skip_token_embedding: + self.to_token_embedding = nn.Linear(token_dim, dim) + else: + self.to_token_embedding = nn.Identity() + if token_dim != dim: + raise ValueError( + f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True" + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim)) + if emb_dropout_type == "drop": + self.dropout = DropTokenDropout(emb_dropout) + elif emb_dropout_type == "zero": + self.dropout = ZeroTokenDropout(emb_dropout) + elif emb_dropout_type == "normal": + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = TransformerCrossAttn( + dim, + depth, + heads, + dim_head, + mlp_dim, + dropout, + norm=norm, + norm_cond_dim=norm_cond_dim, + context_dim=context_dim, + ) + + def forward(self, inp: torch.Tensor, *args, context=None, context_list=None): + x = self.to_token_embedding(inp) + b, n, _ = x.shape + + x = self.dropout(x) + x += self.pos_embedding[:, :n] + + x = self.transformer(x, *args, context=context, context_list=context_list) + return x + diff --git a/lib/models/components/t_cond_mlp.py b/lib/models/components/t_cond_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5a09bf54f67712a69953039b7b5af41c3f029 --- /dev/null +++ b/lib/models/components/t_cond_mlp.py @@ -0,0 +1,199 @@ +import copy +from typing import List, Optional + +import torch + + +class AdaptiveLayerNorm1D(torch.nn.Module): + def __init__(self, data_dim: int, norm_cond_dim: int): + super().__init__() + if data_dim <= 0: + raise ValueError(f"data_dim must be positive, but got {data_dim}") + if norm_cond_dim <= 0: + raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}") + self.norm = torch.nn.LayerNorm( + data_dim + ) # TODO: Check if elementwise_affine=True is correct + self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim) + torch.nn.init.zeros_(self.linear.weight) + torch.nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + # x: (batch, ..., data_dim) + # t: (batch, norm_cond_dim) + # return: (batch, data_dim) + x = self.norm(x) + alpha, beta = self.linear(t).chunk(2, dim=-1) + + # Add singleton dimensions to alpha and beta + if x.dim() > 2: + alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1]) + beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1]) + + return x * (1 + alpha) + beta + + +class SequentialCond(torch.nn.Sequential): + def forward(self, input, *args, **kwargs): + for module in self: + if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)): + # print(f'Passing on args to {module}', [a.shape for a in args]) + input = module(input, *args, **kwargs) + else: + # print(f'Skipping passing args to {module}', [a.shape for a in args]) + input = module(input) + return input + + +def normalization_layer(norm: Optional[str], dim: int, 
norm_cond_dim: int = -1): + if norm == "batch": + return torch.nn.BatchNorm1d(dim) + elif norm == "layer": + return torch.nn.LayerNorm(dim) + elif norm == "ada": + assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}" + return AdaptiveLayerNorm1D(dim, norm_cond_dim) + elif norm is None: + return torch.nn.Identity() + else: + raise ValueError(f"Unknown norm: {norm}") + + +def linear_norm_activ_dropout( + input_dim: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, +) -> SequentialCond: + layers = [] + layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias)) + if norm is not None: + layers.append(normalization_layer(norm, output_dim, norm_cond_dim)) + layers.append(copy.deepcopy(activation)) + if dropout > 0.0: + layers.append(torch.nn.Dropout(dropout)) + return SequentialCond(*layers) + + +def create_simple_mlp( + input_dim: int, + hidden_dims: List[int], + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, +) -> SequentialCond: + layers = [] + prev_dim = input_dim + for hidden_dim in hidden_dims: + layers.extend( + linear_norm_activ_dropout( + prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ) + ) + prev_dim = hidden_dim + layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias)) + return SequentialCond(*layers) + + +class ResidualMLPBlock(torch.nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + num_hidden_layers: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, + ): + super().__init__() + if not (input_dim == output_dim == hidden_dim): + raise NotImplementedError( + f"input_dim {input_dim} != output_dim {output_dim} is not implemented" + ) + + layers = [] + prev_dim = input_dim + for i in range(num_hidden_layers): + layers.append( + linear_norm_activ_dropout( + prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ) + ) + prev_dim = hidden_dim + self.model = SequentialCond(*layers) + self.skip = torch.nn.Identity() + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x + self.model(x, *args, **kwargs) + + +class ResidualMLP(torch.nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + num_hidden_layers: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + num_blocks: int = 1, + norm_cond_dim: int = -1, + ): + super().__init__() + self.input_dim = input_dim + self.model = SequentialCond( + linear_norm_activ_dropout( + input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ), + *[ + ResidualMLPBlock( + hidden_dim, + hidden_dim, + num_hidden_layers, + hidden_dim, + activation, + bias, + norm, + dropout, + norm_cond_dim, + ) + for _ in range(num_blocks) + ], + torch.nn.Linear(hidden_dim, output_dim, bias=bias), + ) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return self.model(x, *args, **kwargs) + + +class FrequencyEmbedder(torch.nn.Module): + def __init__(self, num_frequencies, max_freq_log2): + super().__init__() + 
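# Log-spaced frequencies 2**0 ... 2**max_freq_log2 (used in the forward pass below).
+        # forward() multiplies each input coordinate by every frequency and concatenates
+        # sin, cos and the raw value, giving D * (2 * num_frequencies + 1) features.
+        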
frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies) + self.register_buffer("frequencies", frequencies) + + def forward(self, x): + # x should be of size (N,) or (N, D) + N = x.size(0) + if x.dim() == 1: # (N,) + x = x.unsqueeze(1) # (N, D) where D=1 + x_unsqueezed = x.unsqueeze(-1) # (N, D, 1) + scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies) + s = torch.sin(scaled) + c = torch.cos(scaled) + embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view( + N, -1 + ) # (N, D * 2 * num_frequencies + D) + return embedded + diff --git a/lib/models/hawor.py b/lib/models/hawor.py new file mode 100644 index 0000000000000000000000000000000000000000..34a081cd63d0a3e3be894fda21375683f3e61d51 --- /dev/null +++ b/lib/models/hawor.py @@ -0,0 +1,527 @@ +import einops +import numpy as np +import torch +import pytorch_lightning as pl +from typing import Dict +from torchvision.utils import make_grid + +from tqdm import tqdm +from yacs.config import CfgNode + +from lib.datasets.track_dataset import TrackDatasetEval +from lib.models.modules import MANOTransformerDecoderHead, temporal_attention +from hawor.utils.pylogger import get_pylogger +from hawor.utils.render_openpose import render_openpose +from lib.utils.geometry import rot6d_to_rotmat_hmr2 as rot6d_to_rotmat +from lib.utils.geometry import perspective_projection +from hawor.utils.rotation import angle_axis_to_rotation_matrix +from torch.utils.data import default_collate + +from .backbones import create_backbone +from .mano_wrapper import MANO + + +log = get_pylogger(__name__) +idx = 0 + +class HAWOR(pl.LightningModule): + + def __init__(self, cfg: CfgNode): + """ + Setup HAWOR model + Args: + cfg (CfgNode): Config file as a yacs CfgNode + """ + super().__init__() + + # Save hyperparameters + self.save_hyperparameters(logger=False, ignore=['init_renderer']) + + self.cfg = cfg + self.crop_size = cfg.MODEL.IMAGE_SIZE + self.seq_len = 16 + self.pose_num = 16 + self.pose_dim = 6 # rot6d representation + self.box_info_dim = 3 + + # Create backbone feature extractor + self.backbone = create_backbone(cfg) + try: + if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None): + whole_state_dict = torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'] + backbone_state_dict = {} + for key in whole_state_dict: + if key[:9] == 'backbone.': + backbone_state_dict[key[9:]] = whole_state_dict[key] + self.backbone.load_state_dict(backbone_state_dict) + print(f'Loaded backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}') + for param in self.backbone.parameters(): + param.requires_grad = False + else: + print('WARNING: init backbone from sratch !!!') + except: + print('WARNING: init backbone from sratch !!!') + + # Space-time memory + if cfg.MODEL.ST_MODULE: + hdim = cfg.MODEL.ST_HDIM + nlayer = cfg.MODEL.ST_NLAYER + self.st_module = temporal_attention(in_dim=1280+3, + out_dim=1280, + hdim=hdim, + nlayer=nlayer, + residual=True) + print(f'Using Temporal Attention space-time: {nlayer} layers {hdim} dim.') + else: + self.st_module = None + + # Motion memory + if cfg.MODEL.MOTION_MODULE: + hdim = cfg.MODEL.MOTION_HDIM + nlayer = cfg.MODEL.MOTION_NLAYER + + self.motion_module = temporal_attention(in_dim=self.pose_num * self.pose_dim + self.box_info_dim, + out_dim=self.pose_num * self.pose_dim, + hdim=hdim, + nlayer=nlayer, + residual=False) + print(f'Using Temporal Attention motion layer: {nlayer} layers {hdim} dim.') + else: + self.motion_module = None + + # Create MANO head + # 
self.mano_head = build_mano_head(cfg) + self.mano_head = MANOTransformerDecoderHead(cfg) + + + # default open torch compile + if cfg.MODEL.BACKBONE.get('TORCH_COMPILE', 0): + log.info("Model will use torch.compile") + self.backbone = torch.compile(self.backbone) + self.mano_head = torch.compile(self.mano_head) + + # Define loss functions + # self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1') + # self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1') + # self.mano_parameter_loss = ParameterLoss() + + # Instantiate MANO model + mano_cfg = {k.lower(): v for k,v in dict(cfg.MANO).items()} + self.mano = MANO(**mano_cfg) + + # Buffer that shows whetheer we need to initialize ActNorm layers + self.register_buffer('initialized', torch.tensor(False)) + + # Disable automatic optimization since we use adversarial training + self.automatic_optimization = False + + if cfg.MODEL.get('LOAD_WEIGHTS', None): + whole_state_dict = torch.load(cfg.MODEL.LOAD_WEIGHTS, map_location='cpu')['state_dict'] + self.load_state_dict(whole_state_dict, strict=True) + print(f"load {cfg.MODEL.LOAD_WEIGHTS}") + + def get_parameters(self): + all_params = list(self.mano_head.parameters()) + if not self.st_module is None: + all_params += list(self.st_module.parameters()) + if not self.motion_module is None: + all_params += list(self.motion_module.parameters()) + all_params += list(self.backbone.parameters()) + return all_params + + def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Setup model and distriminator Optimizers + Returns: + Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers + """ + param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}] + + optimizer = torch.optim.AdamW(params=param_groups, + # lr=self.cfg.TRAIN.LR, + weight_decay=self.cfg.TRAIN.WEIGHT_DECAY) + return optimizer + + def forward_step(self, batch: Dict, train: bool = False) -> Dict: + """ + Run a forward step of the network + Args: + batch (Dict): Dictionary containing batch data + train (bool): Flag indicating whether it is training or validation mode + Returns: + Dict: Dictionary containing the regression output + """ + + image = batch['img'].flatten(0, 1) + center = batch['center'].flatten(0, 1) + scale = batch['scale'].flatten(0, 1) + img_focal = batch['img_focal'].flatten(0, 1) + img_center = batch['img_center'].flatten(0, 1) + bn = len(image) + + # estimate focal length, and bbox + bbox_info = self.bbox_est(center, scale, img_focal, img_center) + + # backbone + feature = self.backbone(image[:,:,:,32:-32]) + feature = feature.float() + + # space-time module + if self.st_module is not None: + bb = einops.repeat(bbox_info, 'b c -> b c h w', h=16, w=12) + feature = torch.cat([feature, bb], dim=1) + + feature = einops.rearrange(feature, '(b t) c h w -> (b h w) t c', t=16) + feature = self.st_module(feature) + feature = einops.rearrange(feature, '(b h w) t c -> (b t) c h w', h=16, w=12) + + # smpl_head: transformer + smpl + # pred_mano_params, pred_cam, pred_mano_params_list = self.mano_head(feature) + # pred_shape = pred_mano_params_list['pred_shape'] + # pred_pose = pred_mano_params_list['pred_pose'] + pred_pose, pred_shape, pred_cam = self.mano_head(feature) + pred_rotmat_0 = rot6d_to_rotmat(pred_pose).reshape(-1, self.pose_num, 3, 3) + + # smpl motion module + if self.motion_module is not None: + bb = einops.rearrange(bbox_info, '(b t) c -> b t c', t=16) + pred_pose = einops.rearrange(pred_pose, '(b t) c -> b t c', t=16) + pred_pose = 
torch.cat([pred_pose, bb], dim=2) + + pred_pose = self.motion_module(pred_pose) + pred_pose = einops.rearrange(pred_pose, 'b t c -> (b t) c') + + out = {} + if 'do_flip' in batch: + pred_cam[..., 1] *= -1 + center[..., 0] = img_center[..., 0]*2 - center[..., 0] - 1 + out['pred_cam'] = pred_cam + out['pred_pose'] = pred_pose + out['pred_shape'] = pred_shape + out['pred_rotmat'] = rot6d_to_rotmat(out['pred_pose']).reshape(-1, self.pose_num, 3, 3) + out['pred_rotmat_0'] = pred_rotmat_0 + + s_out = self.mano.query(out) + j3d = s_out.joints + j2d = self.project(j3d, out['pred_cam'], center, scale, img_focal, img_center) + j2d = j2d / self.crop_size - 0.5 # norm to [-0.5, 0.5] + + trans_full = self.get_trans(out['pred_cam'], center, scale, img_focal, img_center) + out['trans_full'] = trans_full + + output = { + 'pred_mano_params': { + 'global_orient': out['pred_rotmat'][:, :1].clone(), + 'hand_pose': out['pred_rotmat'][:, 1:].clone(), + 'betas': out['pred_shape'].clone(), + }, + 'pred_keypoints_3d': j3d.clone(), + 'pred_keypoints_2d': j2d.clone(), + 'out': out, + } + # print(output) + # output['gt_project_j2d'] = self.project(batch['gt_j3d_wo_trans'].clone().flatten(0,1), out['pred_cam'], center, scale, img_focal, img_center) + # output['gt_project_j2d'] = output['gt_project_j2d'] / self.crop_size - 0.5 + + + return output + + def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor: + """ + Compute losses given the input batch and the regression output + Args: + batch (Dict): Dictionary containing batch data + output (Dict): Dictionary containing the regression output + train (bool): Flag indicating whether it is training or validation mode + Returns: + torch.Tensor : Total loss for current batch + """ + + pred_mano_params = output['pred_mano_params'] + pred_keypoints_2d = output['pred_keypoints_2d'] + pred_keypoints_3d = output['pred_keypoints_3d'] + + + batch_size = pred_mano_params['hand_pose'].shape[0] + device = pred_mano_params['hand_pose'].device + dtype = pred_mano_params['hand_pose'].dtype + + # Get annotations + gt_keypoints_2d = batch['gt_cam_j2d'].flatten(0, 1) + gt_keypoints_2d = torch.cat([gt_keypoints_2d, torch.ones(*gt_keypoints_2d.shape[:-1], 1, device=gt_keypoints_2d.device)], dim=-1) + gt_keypoints_3d = batch['gt_j3d_wo_trans'].flatten(0, 1) + gt_keypoints_3d = torch.cat([gt_keypoints_3d, torch.ones(*gt_keypoints_3d.shape[:-1], 1, device=gt_keypoints_3d.device)], dim=-1) + pose_gt = batch['gt_cam_full_pose'].flatten(0, 1).reshape(-1, 16, 3) + rotmat_gt = angle_axis_to_rotation_matrix(pose_gt) + gt_mano_params = { + 'global_orient': rotmat_gt[:, :1], + 'hand_pose': rotmat_gt[:, 1:], + 'betas': batch['gt_cam_betas'], + } + + # Compute 3D keypoint loss + loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d) + loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0) + + # to avoid nan + loss_keypoints_2d = torch.nan_to_num(loss_keypoints_2d) + + # Compute loss on MANO parameters + loss_mano_params = {} + for k, pred in pred_mano_params.items(): + gt = gt_mano_params[k].view(batch_size, -1) + loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1)) + + loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d+\ + self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d+\ + sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params]) + + losses = dict(loss=loss.detach(), + 
loss_keypoints_2d=loss_keypoints_2d.detach() * self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'], + loss_keypoints_3d=loss_keypoints_3d.detach() * self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D']) + + for k, v in loss_mano_params.items(): + losses['loss_' + k] = v.detach() * self.cfg.LOSS_WEIGHTS[k.upper()] + + output['losses'] = losses + + return loss + + # Tensoroboard logging should run from first rank only + @pl.utilities.rank_zero.rank_zero_only + def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True, render_log: bool = True) -> None: + """ + Log results to Tensorboard + Args: + batch (Dict): Dictionary containing batch data + output (Dict): Dictionary containing the regression output + step_count (int): Global training step count + train (bool): Flag indicating whether it is training or validation mode + """ + + mode = 'train' if train else 'val' + batch_size = output['pred_keypoints_2d'].shape[0] + images = batch['img'].flatten(0,1) + images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1,3,1,1) + images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1) + + losses = output['losses'] + if write_to_summary_writer: + summary_writer = self.logger.experiment + for loss_name, val in losses.items(): + summary_writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count) + + if render_log: + gt_keypoints_2d = batch['gt_cam_j2d'].flatten(0,1).clone() + pred_keypoints_2d = output['pred_keypoints_2d'].clone().detach().reshape(batch_size, -1, 2) + gt_project_j2d = pred_keypoints_2d + # gt_project_j2d = output['gt_project_j2d'].clone().detach().reshape(batch_size, -1, 2) + + num_images = 4 + skip=16 + + predictions = self.visualize_tensorboard(images[:num_images*skip:skip].cpu().numpy(), + pred_keypoints_2d[:num_images*skip:skip].cpu().numpy(), + gt_project_j2d[:num_images*skip:skip].cpu().numpy(), + gt_keypoints_2d[:num_images*skip:skip].cpu().numpy(), + ) + summary_writer.add_image('%s/predictions' % mode, predictions, step_count) + + + def forward(self, batch: Dict) -> Dict: + """ + Run a forward step of the network in val mode + Args: + batch (Dict): Dictionary containing batch data + Returns: + Dict: Dictionary containing the regression output + """ + return self.forward_step(batch, train=False) + + def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict: + """ + Run a full training step + Args: + joint_batch (Dict): Dictionary containing image and mocap batch data + batch_idx (int): Unused. + batch_idx (torch.Tensor): Unused. + Returns: + Dict: Dictionary containing regression output. 
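+        Note: `self.automatic_optimization` is disabled for this model, so the
+        backward pass, optional gradient clipping (when `TRAIN.GRAD_CLIP_VAL` > 0)
+        and the optimizer step are all performed explicitly inside this method.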
+ """ + batch = joint_batch['img'] + optimizer = self.optimizers(use_pl_optimizer=True) + + batch_size = batch['img'].shape[0] + output = self.forward_step(batch, train=True) + # pred_mano_params = output['pred_mano_params'] + loss = self.compute_loss(batch, output, train=True) + + # Error if Nan + if torch.isnan(loss): + raise ValueError('Loss is NaN') + + optimizer.zero_grad() + self.manual_backward(loss) + # Clip gradient + if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0: + gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True) + self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size) + optimizer.step() + + # if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0: + if self.global_step > 0 and self.global_step % 100 == 0: + self.tensorboard_logging(batch, output, self.global_step, train=True, render_log=self.cfg.TRAIN.get("RENDER_LOG", True)) + + self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=batch_size) + + return output + + def inference(self, imgfiles, boxes, img_focal, img_center, device='cuda', do_flip=False): + db = TrackDatasetEval(imgfiles, boxes, img_focal=img_focal, + img_center=img_center, normalization=True, dilate=1.2, do_flip=do_flip) + + # Results + pred_cam = [] + pred_pose = [] + pred_shape = [] + pred_rotmat = [] + pred_trans = [] + + # To-do: efficient implementation with batch + items = [] + for i in tqdm(range(len(db))): + item = db[i] + items.append(item) + + # padding to 16 + if i == len(db) - 1 and len(db) % 16 != 0: + pad = 16 - len(db) % 16 + for _ in range(pad): + items.append(item) + + if len(items) < 16: + continue + elif len(items) == 16: + batch = default_collate(items) + items = [] + else: + raise NotImplementedError + + with torch.no_grad(): + batch = {k: v.to(device).unsqueeze(0) for k, v in batch.items() if type(v)==torch.Tensor} + # for image_i in range(16): + # hawor_input_cv2 = vis_tensor_cv2(batch['img'][:, image_i]) + # cv2.imwrite(f'debug_vis_model.png', hawor_input_cv2) + # print("vis") + output = self.forward(batch) + out = output['out'] + + if i == len(db) - 1 and len(db) % 16 != 0: + out = {k:v[:len(db) % 16] for k,v in out.items()} + else: + out = {k:v for k,v in out.items()} + + pred_cam.append(out['pred_cam'].cpu()) + pred_pose.append(out['pred_pose'].cpu()) + pred_shape.append(out['pred_shape'].cpu()) + pred_rotmat.append(out['pred_rotmat'].cpu()) + pred_trans.append(out['trans_full'].cpu()) + + + results = {'pred_cam': torch.cat(pred_cam), + 'pred_pose': torch.cat(pred_pose), + 'pred_shape': torch.cat(pred_shape), + 'pred_rotmat': torch.cat(pred_rotmat), + 'pred_trans': torch.cat(pred_trans), + 'img_focal': img_focal, + 'img_center': img_center} + + return results + + def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict: + """ + Run a validation step and log to Tensorboard + Args: + batch (Dict): Dictionary containing batch data + batch_idx (int): Unused. + Returns: + Dict: Dictionary containing regression output. 
+ """ + # batch_size = batch['img'].shape[0] + output = self.forward_step(batch, train=False) + loss = self.compute_loss(batch, output, train=False) + output['loss'] = loss + self.tensorboard_logging(batch, output, self.global_step, train=False) + + return output + + def visualize_tensorboard(self, images, pred_keypoints, gt_project_j2d, gt_keypoints): + pred_keypoints = 256 * (pred_keypoints + 0.5) + gt_keypoints = 256 * (gt_keypoints + 0.5) + gt_project_j2d = 256 * (gt_project_j2d + 0.5) + pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1) + gt_keypoints = np.concatenate((gt_keypoints, np.ones_like(gt_keypoints)[:, :, [0]]), axis=-1) + gt_project_j2d = np.concatenate((gt_project_j2d, np.ones_like(gt_project_j2d)[:, :, [0]]), axis=-1) + images_np = np.transpose(images, (0,2,3,1)) + rend_imgs = [] + for i in range(images_np.shape[0]): + pred_keypoints_img = render_openpose(255 * images_np[i].copy(), pred_keypoints[i]) / 255 + gt_project_j2d_img = render_openpose(255 * images_np[i].copy(), gt_project_j2d[i]) / 255 + gt_keypoints_img = render_openpose(255*images_np[i].copy(), gt_keypoints[i]) / 255 + rend_imgs.append(torch.from_numpy(images[i])) + rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2,0,1)) + rend_imgs.append(torch.from_numpy(gt_project_j2d_img).permute(2,0,1)) + rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2,0,1)) + rend_imgs = make_grid(rend_imgs, nrow=4, padding=2) + return rend_imgs + + def project(self, points, pred_cam, center, scale, img_focal, img_center, return_full=False): + + trans_full = self.get_trans(pred_cam, center, scale, img_focal, img_center) + + # Projection in full frame image coordinate + points = points + trans_full + points2d_full = perspective_projection(points, rotation=None, translation=None, + focal_length=img_focal, camera_center=img_center) + + # Adjust projected points to crop image coordinate + # (s.t. 1. we can calculate loss in crop image easily + # 2. we can query its pixel in the crop + # ) + b = scale * 200 + points2d = points2d_full - (center - b[:,None]/2)[:,None,:] + points2d = points2d * (self.crop_size / b)[:,None,None] + + if return_full: + return points2d_full, points2d + else: + return points2d + + def get_trans(self, pred_cam, center, scale, img_focal, img_center): + b = scale * 200 + cx, cy = center[:,0], center[:,1] # center of crop + s, tx, ty = pred_cam.unbind(-1) + + img_cx, img_cy = img_center[:,0], img_center[:,1] # center of original image + + bs = b*s + tx_full = tx + 2*(cx-img_cx)/bs + ty_full = ty + 2*(cy-img_cy)/bs + tz_full = 2*img_focal/bs + + trans_full = torch.stack([tx_full, ty_full, tz_full], dim=-1) + trans_full = trans_full.unsqueeze(1) + + return trans_full + + def bbox_est(self, center, scale, img_focal, img_center): + # Original image center + img_cx, img_cy = img_center[:,0], img_center[:,1] + + # Implement CLIFF (Li et al.) 
bbox feature + cx, cy, b = center[:, 0], center[:, 1], scale * 200 + bbox_info = torch.stack([cx - img_cx, cy - img_cy, b], dim=-1) + bbox_info[:, :2] = bbox_info[:, :2] / img_focal.unsqueeze(-1) * 2.8 + bbox_info[:, 2] = (bbox_info[:, 2] - 0.24 * img_focal) / (0.06 * img_focal) + + return bbox_info diff --git a/lib/models/mano_wrapper.py b/lib/models/mano_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..4801c26e3410035db3d09f8dac70c637dd7b00f5 --- /dev/null +++ b/lib/models/mano_wrapper.py @@ -0,0 +1,52 @@ +import torch +import numpy as np +import pickle +from typing import Optional +import smplx +from smplx.lbs import vertices2joints +from smplx.utils import MANOOutput, to_tensor +from smplx.vertex_ids import vertex_ids + + +class MANO(smplx.MANOLayer): + def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs): + """ + Extension of the official MANO implementation to support more joints. + Args: + Same as MANOLayer. + joint_regressor_extra (str): Path to extra joint regressor. + """ + super(MANO, self).__init__(*args, **kwargs) + mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20] + + #2, 3, 5, 4, 1 + if joint_regressor_extra is not None: + self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32)) + self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long)) + self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long)) + + def forward(self, *args, **kwargs) -> MANOOutput: + """ + Run forward pass. Same as MANO and also append an extra set of joints if joint_regressor_extra is specified. + """ + mano_output = super(MANO, self).forward(*args, **kwargs) + extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs) + joints = torch.cat([mano_output.joints, extra_joints], dim=1) + joints = joints[:, self.joint_map, :] + if hasattr(self, 'joint_regressor_extra'): + extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices) + joints = torch.cat([joints, extra_joints], dim=1) + mano_output.joints = joints + return mano_output + + def query(self, hmr_output): + batch_size = hmr_output['pred_rotmat'].shape[0] + pred_rotmat = hmr_output['pred_rotmat'].reshape(batch_size, -1, 3, 3) + pred_shape = hmr_output['pred_shape'].reshape(batch_size, 10) + + mano_output = self(global_orient=pred_rotmat[:, [0]], + hand_pose = pred_rotmat[:, 1:], + betas = pred_shape, + pose2rot=False) + + return mano_output \ No newline at end of file diff --git a/lib/models/modules.py b/lib/models/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b0adbef1ddfb6f15e8cf0df6b99f924f9c409d6f --- /dev/null +++ b/lib/models/modules.py @@ -0,0 +1,133 @@ +import numpy as np +import einops +import torch +import torch.nn as nn +from .components.pose_transformer import TransformerDecoder + +if torch.cuda.is_available(): + autocast = torch.cuda.amp.autocast + # print('Using autocast') +else: + # dummy GradScaler for PyTorch < 1.6 OR no cuda + class autocast: + def __init__(self, enabled=True): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + +class MANOTransformerDecoderHead(nn.Module): + """ HMR2 Cross-attention based SMPL Transformer decoder + """ + def __init__(self, cfg): + super().__init__() + transformer_args = dict( + depth = 6, # originally 6 + heads = 8, + mlp_dim = 1024, + 
dim_head = 64, + dropout = 0.0, + emb_dropout = 0.0, + norm = "layer", + context_dim = 1280, + num_tokens = 1, + token_dim = 1, + dim = 1024 + ) + self.transformer = TransformerDecoder(**transformer_args) + + dim = 1024 + npose = 16*6 + self.decpose = nn.Linear(dim, npose) + self.decshape = nn.Linear(dim, 10) + self.deccam = nn.Linear(dim, 3) + nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) + nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) + nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) + + mean_params = np.load(cfg.MANO.MEAN_PARAMS) + init_hand_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0) + init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0) + init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0) + self.register_buffer('init_hand_pose', init_hand_pose) + self.register_buffer('init_betas', init_betas) + self.register_buffer('init_cam', init_cam) + + + def forward(self, x, **kwargs): + + batch_size = x.shape[0] + # vit pretrained backbone is channel-first. Change to token-first + x = einops.rearrange(x, 'b c h w -> b (h w) c') + + init_hand_pose = self.init_hand_pose.expand(batch_size, -1) + init_betas = self.init_betas.expand(batch_size, -1) + init_cam = self.init_cam.expand(batch_size, -1) + + # Pass through transformer + token = torch.zeros(batch_size, 1, 1).to(x.device) + token_out = self.transformer(token, context=x) + token_out = token_out.squeeze(1) # (B, C) + + # Readout from token_out + pred_pose = self.decpose(token_out) + init_hand_pose + pred_shape = self.decshape(token_out) + init_betas + pred_cam = self.deccam(token_out) + init_cam + + return pred_pose, pred_shape, pred_cam + + + +class temporal_attention(nn.Module): + def __init__(self, in_dim=1280, out_dim=1280, hdim=512, nlayer=6, nhead=4, residual=False): + super(temporal_attention, self).__init__() + self.hdim = hdim + self.out_dim = out_dim + self.residual = residual + self.l1 = nn.Linear(in_dim, hdim) + self.l2 = nn.Linear(hdim, out_dim) + + self.pos_embedding = PositionalEncoding(hdim, dropout=0.1) + TranLayer = nn.TransformerEncoderLayer(d_model=hdim, nhead=nhead, dim_feedforward=1024, + dropout=0.1, activation='gelu') + self.trans = nn.TransformerEncoder(TranLayer, num_layers=nlayer) + + nn.init.xavier_uniform_(self.l1.weight, gain=0.01) + nn.init.xavier_uniform_(self.l2.weight, gain=0.01) + + def forward(self, x): + x = x.permute(1,0,2) # (b,t,c) -> (t,b,c) + + h = self.l1(x) + h = self.pos_embedding(h) + h = self.trans(h) + h = self.l2(h) + + if self.residual: + x = x[..., :self.out_dim] + h + else: + x = h + x = x.permute(1,0,2) + + return x + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=100): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) diff --git a/lib/pipeline/__init__.py b/lib/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/lib/pipeline/__pycache__/__init__.cpython-310.pyc b/lib/pipeline/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b0ca428c1966c571d69b2796cd2c187dc66446f Binary files /dev/null and b/lib/pipeline/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/pipeline/__pycache__/est_scale.cpython-310.pyc b/lib/pipeline/__pycache__/est_scale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc53b52d4d22b1a1a17175a5f62c04b706cf276a Binary files /dev/null and b/lib/pipeline/__pycache__/est_scale.cpython-310.pyc differ diff --git a/lib/pipeline/__pycache__/masked_droid_slam.cpython-310.pyc b/lib/pipeline/__pycache__/masked_droid_slam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0f93b76370783d2774c468d83be5d415b749115 Binary files /dev/null and b/lib/pipeline/__pycache__/masked_droid_slam.cpython-310.pyc differ diff --git a/lib/pipeline/__pycache__/tools.cpython-310.pyc b/lib/pipeline/__pycache__/tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43b0b6082146b5ba984551e79ef8fe27a66f91f6 Binary files /dev/null and b/lib/pipeline/__pycache__/tools.cpython-310.pyc differ diff --git a/lib/pipeline/est_scale.py b/lib/pipeline/est_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..780d7a1e2dd4d7b2680e49ba13e3a7feb8cf99c5 --- /dev/null +++ b/lib/pipeline/est_scale.py @@ -0,0 +1,195 @@ +import numpy as np +import cv2 +import torch +from torchmin import minimize + + +def est_scale_iterative(slam_depth, pred_depth, iters=10, msk=None): + """ Simple depth-align by iterative median and thresholding """ + s = pred_depth / slam_depth + + if msk is None: + msk = np.zeros_like(pred_depth) + else: + msk = cv2.resize(msk, (pred_depth.shape[1], pred_depth.shape[0])) + + robust = (msk<0.5) * (0 4: + image = cv2.undistort(image, K, calib[4:]) + + h0, w0, _ = image.shape + h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0))) + w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0))) + + image = cv2.resize(image, (w1, h1)) + image = image[:h1-h1%8, :w1-w1%8] + image = torch.as_tensor(image).permute(2, 0, 1) + + intrinsics = torch.as_tensor([fx, fy, cx, cy]) + intrinsics[0::2] *= (w1 / w0) + intrinsics[1::2] *= (h1 / h0) + + yield t, image[None], intrinsics + + +def run_slam(imagedir, masks, calib=None, depth=None, stride=1, + filter_thresh=2.4, disable_vis=True): + """ Maksed DROID-SLAM """ + droid = None + depth = None + args.filter_thresh = filter_thresh + args.disable_vis = disable_vis + masks = masks[::stride] + + img_msks, conf_msks = preprocess_masks(imagedir, masks) + if calib is None: + calib = est_calib(imagedir) + + for (t, image, intrinsics) in tqdm(image_stream(imagedir, calib, stride)): + + if droid is None: + args.image_size = [image.shape[2], image.shape[3]] + droid = Droid(args) + + img_msk = img_msks[t] + conf_msk = conf_msks[t] + image = image * (img_msk < 0.5) + # cv2.imwrite('debug.png', image[0].permute(1, 2, 0).numpy()) + + droid.track(t, image, intrinsics=intrinsics, depth=depth, mask=conf_msk) + + traj = droid.terminate(image_stream(imagedir, calib, stride)) + + return droid, traj + +def run_droid_slam(imagedir, calib=None, depth=None, stride=1, + filter_thresh=2.4, disable_vis=True): + """ Maksed DROID-SLAM """ + droid = None + depth = None + args.filter_thresh = filter_thresh + args.disable_vis = disable_vis + + if calib is None: + calib = est_calib(imagedir) + + for (t, image, intrinsics) in 
tqdm(image_stream(imagedir, calib, stride)): + + if droid is None: + args.image_size = [image.shape[2], image.shape[3]] + droid = Droid(args) + + droid.track(t, image, intrinsics=intrinsics, depth=depth) + + traj = droid.terminate(image_stream(imagedir, calib, stride)) + + return droid, traj + + +def eval_slam(traj_est, cam_t, cam_q, return_traj=True, correct_scale=False, align=True, align_origin=False): + """ Evaluation for SLAM """ + tstamps = np.array([i for i in range(len(traj_est))], dtype=np.float32) + + traj_est = PoseTrajectory3D( + positions_xyz=traj_est[:,:3], + orientations_quat_wxyz=traj_est[:,3:], + timestamps=tstamps) + + traj_ref = PoseTrajectory3D( + positions_xyz=cam_t.copy(), + orientations_quat_wxyz=cam_q.copy(), + timestamps=tstamps) + + traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est) + result = main_ape.ape(traj_ref, traj_est, est_name='traj', + pose_relation=PoseRelation.translation_part, align=align, align_origin=align_origin, + correct_scale=correct_scale) + + stats = result.stats + + if return_traj: + return stats, traj_ref, traj_est + + return stats + + +def test_slam(imagedir, img_msks, conf_msks, calib, stride=10, max_frame=50): + """ Shorter SLAM step to test reprojection error """ + args = parser.parse_args([]) + args.stereo = False + args.upsample = False + args.disable_vis = True + args.frontend_window = 10 + args.frontend_thresh = 10 + droid = None + + for (t, image, intrinsics) in image_stream(imagedir, calib, stride, max_frame): + if droid is None: + args.image_size = [image.shape[2], image.shape[3]] + droid = Droid(args) + + img_msk = img_msks[t] + conf_msk = conf_msks[t] + image = image * (img_msk < 0.5) + droid.track(t, image, intrinsics=intrinsics, mask=conf_msk) + + reprojection_error = droid.compute_error() + del droid + + return reprojection_error + + +def search_focal_length(img_folder, masks, stride=10, max_frame=50, + low=500, high=1500, step=100): + """ Search for a good focal length by SLAM reprojection error """ + masks = masks[::stride] + masks = masks[:max_frame] + img_msks, conf_msks = preprocess_masks(img_folder, masks) + + # default estimate + calib = np.array(est_calib(img_folder)) + best_focal = calib[0] + best_err = test_slam(img_folder, img_msks, conf_msks, + stride=stride, calib=calib, max_frame=max_frame) + + # search based on slam reprojection error + for focal in range(low, high, step): + calib[:2] = focal + err = test_slam(img_folder, img_msks, conf_msks, + stride=stride, calib=calib, max_frame=max_frame) + + if err < best_err: + best_err = err + best_focal = focal + + print('Best focal length:', best_focal) + + return best_focal + + +def preprocess_masks(img_folder, masks): + """ Resize masks for masked droid """ + H, W = get_dimention(img_folder) + resize_1 = Resize((H, W), antialias=True) + resize_2 = Resize((H//8, W//8), antialias=True) + + img_msks = [] + for i in range(0, len(masks), 500): + m = resize_1(masks[i:i+500]) + img_msks.append(m) + img_msks = torch.cat(img_msks) + + conf_msks = [] + for i in range(0, len(masks), 500): + m = resize_2(masks[i:i+500]) + conf_msks.append(m) + conf_msks = torch.cat(conf_msks) + + return img_msks, conf_msks + + + + + diff --git a/lib/pipeline/tools.py b/lib/pipeline/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..1869faf35b8532b26b25fcb1971fa9046e9b3b15 --- /dev/null +++ b/lib/pipeline/tools.py @@ -0,0 +1,129 @@ +import cv2 +from tqdm import tqdm +import numpy as np +import torch + +from ultralytics import YOLO + + +if 
torch.cuda.is_available(): + autocast = torch.cuda.amp.autocast +else: + class autocast: + def __init__(self, enabled=True): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + + +def detect_track(imgfiles, thresh=0.5): + + hand_det_model = YOLO('./weights/external/detector.pt') + + # Run + boxes_ = [] + tracks = {} + for t, imgpath in enumerate(tqdm(imgfiles)): + img_cv2 = cv2.imread(imgpath) + + ### --- Detection --- + with torch.no_grad(): + with autocast(): + results = hand_det_model.track(img_cv2, conf=thresh, persist=True, verbose=False) + + boxes = results[0].boxes.xyxy.cpu().numpy() + confs = results[0].boxes.conf.cpu().numpy() + handedness = results[0].boxes.cls.cpu().numpy() + if not results[0].boxes.id is None: + track_id = results[0].boxes.id.cpu().numpy() + else: + track_id = [-1] * len(boxes) + + boxes = np.hstack([boxes, confs[:, None]]) + find_right = False + find_left = False + for idx, box in enumerate(boxes): + if track_id[idx] == -1: + if handedness[[idx]] > 0: + id = int(10000) + else: + id = int(5000) + else: + id = track_id[idx] + subj = dict() + subj['frame'] = t + subj['det'] = True + subj['det_box'] = boxes[[idx]] + subj['det_handedness'] = handedness[[idx]] + + + if (not find_right and handedness[[idx]] > 0) or (not find_left and handedness[[idx]]==0): + if id in tracks: + tracks[id].append(subj) + else: + tracks[id] = [subj] + + if handedness[[idx]] > 0: + find_right = True + elif handedness[[idx]] == 0: + find_left = True + tracks = np.array(tracks, dtype=object) + boxes_ = np.array(boxes_, dtype=object) + + return boxes_, tracks + + +def parse_chunks(frame, boxes, min_len=16): + """ If a track disappear in the middle, + we separate it to different segments to estimate the HPS independently. + If a segment is less than 16 frames, we get rid of it for now. + """ + frame_chunks = [] + boxes_chunks = [] + step = frame[1:] - frame[:-1] + step = np.concatenate([[0], step]) + breaks = np.where(step != 1)[0] + + start = 0 + for bk in breaks: + f_chunk = frame[start:bk] + b_chunk = boxes[start:bk] + start = bk + if len(f_chunk)>=min_len: + frame_chunks.append(f_chunk) + boxes_chunks.append(b_chunk) + + if bk==breaks[-1]: # last chunk + f_chunk = frame[bk:] + b_chunk = boxes[bk:] + if len(f_chunk)>=min_len: + frame_chunks.append(f_chunk) + boxes_chunks.append(b_chunk) + + return frame_chunks, boxes_chunks + +def parse_chunks_hand_frame(frame): + """ If a track disappear in the middle, + we separate it to different segments to estimate the HPS independently. + If a segment is less than 16 frames, we get rid of it for now. 
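+    Unlike `parse_chunks`, this variant keeps every non-empty contiguous segment
+    (there is no 16-frame minimum); e.g. frame indices [0, 1, 2, 5, 6, 7] are
+    split into the chunks [0, 1, 2] and [5, 6, 7].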
+ """ + frame_chunks = [] + step = frame[1:] - frame[:-1] + step = np.concatenate([[0], step]) + breaks = np.where(step != 1)[0] + + start = 0 + for bk in breaks: + f_chunk = frame[start:bk] + start = bk + if len(f_chunk) > 0: + frame_chunks.append(f_chunk) + + if bk==breaks[-1]: # last chunk + f_chunk = frame[bk:] + if len(f_chunk) > 0: + frame_chunks.append(f_chunk) + + return frame_chunks diff --git a/lib/utils/__pycache__/geometry.cpython-310.pyc b/lib/utils/__pycache__/geometry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9c90a145f832d48366222f7963cf1bc1f0cdf04 Binary files /dev/null and b/lib/utils/__pycache__/geometry.cpython-310.pyc differ diff --git a/lib/utils/__pycache__/imutils.cpython-310.pyc b/lib/utils/__pycache__/imutils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b6a7494b10d0336eecadabbd494530d3b956ca4 Binary files /dev/null and b/lib/utils/__pycache__/imutils.cpython-310.pyc differ diff --git a/lib/utils/geometry.py b/lib/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d65066ce383c744407f8d800a7342cd324142b --- /dev/null +++ b/lib/utils/geometry.py @@ -0,0 +1,412 @@ +import numpy as np +import torch +from torch.nn import functional as F + + +def perspective_projection(points, rotation, translation, + focal_length, camera_center, distortion=None): + """ + This function computes the perspective projection of a set of points. + Input: + points (bs, N, 3): 3D points + rotation (bs, 3, 3): Camera rotation + translation (bs, 3): Camera translation + focal_length (bs,) or scalar: Focal length + camera_center (bs, 2): Camera center + """ + batch_size = points.shape[0] + + # Extrinsic + if rotation is not None: + points = torch.einsum('bij,bkj->bki', rotation, points) + + if translation is not None: + points = points + translation.unsqueeze(1) + + if distortion is not None: + kc = distortion + points = points[:,:,:2] / points[:,:,2:] + + r2 = points[:,:,0]**2 + points[:,:,1]**2 + dx = (2 * kc[:,[2]] * points[:,:,0] * points[:,:,1] + + kc[:,[3]] * (r2 + 2*points[:,:,0]**2)) + + dy = (2 * kc[:,[3]] * points[:,:,0] * points[:,:,1] + + kc[:,[2]] * (r2 + 2*points[:,:,1]**2)) + + x = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,0] + dx + y = (1 + kc[:,[0]]*r2 + kc[:,[1]]*r2.pow(2) + kc[:,[4]]*r2.pow(3)) * points[:,:,1] + dy + + points = torch.stack([x, y, torch.ones_like(x)], dim=-1) + + # Intrinsic + K = torch.zeros([batch_size, 3, 3], device=points.device) + K[:,0,0] = focal_length + K[:,1,1] = focal_length + K[:,2,2] = 1. + K[:,:-1, -1] = camera_center + + # Apply camera intrinsicsrf + points = points / points[:,:,-1].unsqueeze(-1) + projected_points = torch.einsum('bij,bkj->bki', K, points) + projected_points = projected_points[:, :, :-1] + + return projected_points + + +def avg_rot(rot): + # input [B,...,3,3] --> output [...,3,3] + rot = rot.mean(dim=0) + U, _, V = torch.svd(rot) + rot = U @ V.transpose(-1, -2) + return rot + + +def rot9d_to_rotmat(x): + """Convert 9D rotation representation to 3x3 rotation matrix. 
+ Based on Levinson et al., "An Analysis of SVD for Deep Rotation Estimation" + Input: + (B,9) or (B,J*9) Batch of 9D rotation (interpreted as 3x3 est rotmat) + Output: + (B,3,3) or (B*J,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1,3,3) + u, _, vh = torch.linalg.svd(x) + + sig = torch.eye(3).expand(len(x), 3, 3).clone() + sig = sig.to(x.device) + sig[:, -1, -1] = (u @ vh).det() + + R = u @ sig @ vh + + return R + + +""" +Deprecated in favor of: rotation_conversions.py + +Useful geometric operations, e.g. differentiable Rodrigues formula +Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR +""" +def batch_rodrigues(theta): + """Convert axis-angle representation to rotation matrix. + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat_to_rotmat(quat) + +def quat_to_rotmat(quat): + """Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + +def rot6d_to_rotmat(x): + """Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1,3,2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + +def rot6d_to_rotmat_hmr2(x: torch.Tensor) -> torch.Tensor: + """ + Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Args: + x (torch.Tensor): (B,6) Batch of 6-D rotation representations. + Returns: + torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3). + """ + x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous() + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + +def rotmat_to_rot6d(rotmat): + """ Inverse function of the above. 
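+    Keeps the first two columns of each rotation matrix, flattened row-wise;
+    e.g. the identity matrix maps to [1, 0, 0, 1, 0, 0].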
+ Input: + (B,3,3) Batch of corresponding rotation matrices + Output: + (B,6) Batch of 6-D rotation representations + """ + # rot6d = rotmat[:, :, :2] + rot6d = rotmat[...,:2] + rot6d = rot6d.reshape(rot6d.size(0), -1) + return rot6d + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to Rodrigues vector + + Args: + rotation_matrix (Tensor): rotation matrix. + + Returns: + Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 3)` + + Example: + >>> input = torch.rand(2, 3, 4) # Nx4x4 + >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 + """ + if rotation_matrix.shape[1:] == (3,3): + rot_mat = rotation_matrix.reshape(-1, 3, 3) + hom = torch.tensor([0, 0, 1], dtype=torch.float32, + device=rotation_matrix.device).reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1) + rotation_matrix = torch.cat([rot_mat, hom], dim=-1) + + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + return aa + + +def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. + + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}" + .format(quaternion.shape)) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, + torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta)) + + k_pos: torch.Tensor = two_theta / sin_theta + k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) + k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. 
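+
+    Note: the returned quaternion is in (w, x, y, z) order, consistent with
+    `quaternion_to_angle_axis` above.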
+ + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format( + rotation_matrix.shape)) + if not rotation_matrix.shape[-2:] == (3, 4): + raise ValueError( + "Input size must be a N x 3 x 4 tensor. Got {}".format( + rotation_matrix.shape)) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa + q *= 0.5 + return q + + +def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000., img_size=224.): + """ + This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py + + Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. 
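+    The translation is found by building a linear system in (tx, ty, tz) from the
+    perspective projection equations and solving the confidence-weighted normal
+    equations with `np.linalg.solve`.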
+ Input: + S: (25, 3) 3D joint locations + joints: (25, 3) 2D joint locations and confidence + Returns: + (3,) camera translation vector + """ + + num_joints = S.shape[0] + # focal length + f = np.array([focal_length,focal_length]) + # optical center + center = np.array([img_size/2., img_size/2.]) + + # transformations + Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1) + XY = np.reshape(S[:,0:2],-1) + O = np.tile(center,num_joints) + F = np.tile(f,num_joints) + weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1) + + # least squares + Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T + c = (np.reshape(joints_2d,-1)-O)*Z - F*XY + + # weighted least squares + W = np.diagflat(weight2) + Q = np.dot(W,Q) + c = np.dot(W,c) + + # square matrix + A = np.dot(Q.T,Q) + b = np.dot(Q.T,c) + + # solution + trans = np.linalg.solve(A, b) + + return trans + + +def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.): + """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. + Input: + S: (B, 49, 3) 3D joint locations + joints: (B, 49, 3) 2D joint locations and confidence + Returns: + (B, 3) camera translation vectors + """ + + device = S.device + # Use only joints 25:49 (GT joints) + S = S[:, -24:, :3].cpu().numpy() + joints_2d = joints_2d[:, -24:, :].cpu().numpy() + + joints_conf = joints_2d[:, :, -1] + joints_2d = joints_2d[:, :, :-1] + trans = np.zeros((S.shape[0], 3), dtype=np.float32) + # Find the translation for each example in the batch + for i in range(S.shape[0]): + S_i = S[i] + joints_i = joints_2d[i] + conf_i = joints_conf[i] + trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size) + return torch.from_numpy(trans).to(device) + + diff --git a/lib/utils/imutils.py b/lib/utils/imutils.py new file mode 100644 index 0000000000000000000000000000000000000000..d38db0bdd4446245f4bfdafb43643d92125f9d8d --- /dev/null +++ b/lib/utils/imutils.py @@ -0,0 +1,286 @@ +""" +This file contains functions that are used to perform data augmentation. 
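+The cropping helpers describe the crop region by a center (x, y) and a scale,
+where the region's side length in the original image is roughly scale * 200
+pixels (see `get_transform` and `boxes_2_cs`).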
+""" +import torch +import numpy as np +from skimage.transform import rotate, resize +import cv2 +from torchvision.transforms import Normalize, ToTensor, Compose + +from lib.core import constants + +def get_normalization(): + normalize_img = Compose([ToTensor(), + Normalize(mean=constants.IMG_NORM_MEAN, + std=constants.IMG_NORM_STD) + ]) + return normalize_img + +def get_transform(center, scale, res, rot=0): + """Generate transformation matrix.""" + h = 200 * scale + 1e-6 + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / h + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / h + .5) + t[1, 2] = res[0] * (-float(center[1]) / h + .5) + t[2, 2] = 1 + if not rot == 0: + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3,3)) + rot_rad = rot * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + rot_mat[2,2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0,2] = -res[1]/2 + t_mat[1,2] = -res[0]/2 + t_inv = t_mat.copy() + t_inv[:2,2] *= -1 + t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) + return t + +def transform(pt, center, scale, res, invert=0, rot=0, asint=True): + """Transform pixel location to different reference.""" + t = get_transform(center, scale, res, rot=rot) + if invert: + t = np.linalg.inv(t) + new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T + new_pt = np.dot(t, new_pt) + + if asint: + return new_pt[:2].astype(int)+1 + else: + return new_pt[:2]+1 + +def transform_pts(pts, center, scale, res, invert=0, rot=0, asint=True): + """Transform pixel location to different reference.""" + t = get_transform(center, scale, res, rot=rot) + if invert: + t = np.linalg.inv(t) + pts = np.concatenate((pts, np.ones_like(pts)[:, [0]]), axis=-1) + new_pt = pts.T + new_pt = np.dot(t, new_pt) + + if asint: + return new_pt[:2, :].T.astype(int) + else: + return new_pt[:2, :].T + +def crop(img, center, scale, res, rot=0): + """Crop image according to the supplied bounding box.""" + # Upper left point + ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 + # Bottom right point + br = np.array(transform([res[0]+1, + res[1]+1], center, scale, res, invert=1))-1 + + # Padding so that when rotated proper amount of context is included + pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + if not rot == 0: + ul -= pad + br += pad + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(new_shape) + + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] + new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(len(img[0]), br[0]) + old_y = max(0, ul[1]), min(len(img), br[1]) + try: + new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], + old_x[0]:old_x[1]] + except: + print("invlid bbox, fill with 0") + + if not rot == 0: + # Remove padding + new_img = rotate(new_img, rot) + new_img = new_img[pad:-pad, pad:-pad] + + new_img = resize(new_img, res) + return new_img + +def crop_j2d(j2d, center, scale, res, rot=0): + """Crop image according to the supplied bounding box.""" + # Upper left point + # crop_j2d = np.array(transform_pts(j2d, center, scale, res, invert=0)) + b = scale * 200 + points2d = j2d - (center - b/2) + points2d = points2d * (res[0] / b) + + return points2d + + +def crop_crop(img, center, scale, res, rot=0): + """Crop image according to the supplied bounding box.""" + # 
Upper left point + ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 + # Bottom right point + br = np.array(transform([res[0]+1, + res[1]+1], center, scale, res, invert=1))-1 + + # Padding so that when rotated proper amount of context is included + pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + if not rot == 0: + ul -= pad + br += pad + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(new_shape) + + + if new_img.shape[0] > img.shape[0]: + p = (new_img.shape[0] - img.shape[0]) / 2 + p = int(p) + new_img = cv2.copyMakeBorder(img, p, p, p, p, cv2.BORDER_REPLICATE) + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] + new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(len(img[0]), br[0]) + old_y = max(0, ul[1]), min(len(img), br[1]) + new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], + old_x[0]:old_x[1]] + + if not rot == 0: + # Remove padding + new_img = rotate(new_img, rot) + new_img = new_img[pad:-pad, pad:-pad] + + new_img = resize(new_img, res) + return new_img + +def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True): + """'Undo' the image cropping/resizing. + This function is used when evaluating mask/part segmentation. + """ + res = img.shape[:2] + # Upper left point + ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 + # Bottom right point + br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1 + # size of cropped image + crop_shape = [br[1] - ul[1], br[0] - ul[0]] + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(orig_shape, dtype=np.uint8) + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0] + new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(orig_shape[1], br[0]) + old_y = max(0, ul[1]), min(orig_shape[0], br[1]) + img = resize(img, crop_shape, interp='nearest') + new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]] + return new_img + +def rot_aa(aa, rot): + """Rotate axis angle parameters.""" + # pose parameters + R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], + [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], + [0, 0, 1]]) + # find the rotation of the body in camera frame + per_rdg, _ = cv2.Rodrigues(aa) + # apply the global rotation to the global orientation + resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg)) + aa = (resrot.T)[0] + return aa + +def flip_img(img): + """Flip rgb images or masks. + channels come last, e.g. (256,256,3). + """ + img = np.fliplr(img) + return img + +def flip_kp(kp): + """Flip keypoints.""" + if len(kp) == 24: + flipped_parts = constants.J24_FLIP_PERM + elif len(kp) == 49: + flipped_parts = constants.J49_FLIP_PERM + kp = kp[flipped_parts] + kp[:,0] = - kp[:,0] + return kp + +def flip_pose(pose): + """Flip pose. + The flipping is based on SMPL parameters. 
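+    Left/right parameters are swapped via `constants.SMPL_POSE_FLIP_PERM`, and
+    the y- and z-components of each axis-angle rotation are negated.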
+ """ + flipped_parts = constants.SMPL_POSE_FLIP_PERM + pose = pose[flipped_parts] + # we also negate the second and the third dimension of the axis-angle + pose[1::3] = -pose[1::3] + pose[2::3] = -pose[2::3] + return pose + + +def crop_img(img, center, scale, res, val=255): + """Crop image according to the supplied bounding box.""" + # Upper left point + ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 + # Bottom right point + br = np.array(transform([res[0]+1, + res[1]+1], center, scale, res, invert=1))-1 + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.ones(new_shape) * val + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] + new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(len(img[0]), br[0]) + old_y = max(0, ul[1]), min(len(img), br[1]) + new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], + old_x[0]:old_x[1]] + new_img = resize(new_img, res) + return new_img + + +def boxes_2_cs(boxes): + x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] + w, h = x2-x1, y2-y1 + cx, cy = x1+w/2, y1+h/2 + size = np.stack([w, h]).max(axis=0) + + centers = np.stack([cx, cy], axis=1) + scales = size / 200 + return centers, scales + + +def box_2_cs(box): + x1,y1,x2,y2 = box[:4].int().tolist() + + w, h = x2-x1, y2-y1 + cx, cy = x1+w/2, y1+h/2 + size = max(w, h) + + center = [cx, cy] + scale = size / 200 + return center, scale + + +def est_intrinsics(img_shape): + h, w, c = img_shape + img_center = torch.tensor([w/2., h/2.]).float() + img_focal = torch.tensor(np.sqrt(h**2 + w**2)).float() + return img_center, img_focal + diff --git a/lib/vis/__pycache__/renderer.cpython-310.pyc b/lib/vis/__pycache__/renderer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd3fcdddc6426fe08a0300b4ec618c1c97c4dad3 Binary files /dev/null and b/lib/vis/__pycache__/renderer.cpython-310.pyc differ diff --git a/lib/vis/__pycache__/run_vis2.cpython-310.pyc b/lib/vis/__pycache__/run_vis2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f8f39d1ffb22e7a23459d38fc903eb10802f76a Binary files /dev/null and b/lib/vis/__pycache__/run_vis2.cpython-310.pyc differ diff --git a/lib/vis/__pycache__/tools.cpython-310.pyc b/lib/vis/__pycache__/tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b96ce951765fead7570b9951a63e48b0e9ebb0b Binary files /dev/null and b/lib/vis/__pycache__/tools.cpython-310.pyc differ diff --git a/lib/vis/__pycache__/viewer.cpython-310.pyc b/lib/vis/__pycache__/viewer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5679447fbf1a890041998ef990ec13444f1d734 Binary files /dev/null and b/lib/vis/__pycache__/viewer.cpython-310.pyc differ diff --git a/lib/vis/renderer.py b/lib/vis/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..efe7e6ab2c89cdc5d95d85dfb62ca0c2d643a0c5 --- /dev/null +++ b/lib/vis/renderer.py @@ -0,0 +1,356 @@ +# Useful rendering functions from WHAM (some modification) + +import cv2 +import torch +import numpy as np + +from pytorch3d.renderer import ( + PerspectiveCameras, + TexturesVertex, + PointLights, + Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, +) +from pytorch3d.structures import Meshes +from pytorch3d.structures.meshes import join_meshes_as_scene +from 
pytorch3d.renderer.cameras import look_at_rotation +from pytorch3d.renderer.camera_conversions import _cameras_from_opencv_projection + +from .tools import get_colors, checkerboard_geometry + + +def overlay_image_onto_background(image, mask, bbox, background): + if isinstance(image, torch.Tensor): + image = image.detach().cpu().numpy() + if isinstance(mask, torch.Tensor): + mask = mask.detach().cpu().numpy() + + out_image = background.copy() + bbox = bbox[0].int().cpu().numpy().copy() + roi_image = out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]] + + roi_image[mask] = image[mask] + out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]] = roi_image + + return out_image + + +def update_intrinsics_from_bbox(K_org, bbox): + device, dtype = K_org.device, K_org.dtype + + K = torch.zeros((K_org.shape[0], 4, 4) + ).to(device=device, dtype=dtype) + K[:, :3, :3] = K_org.clone() + K[:, 2, 2] = 0 + K[:, 2, -1] = 1 + K[:, -1, 2] = 1 + + image_sizes = [] + for idx, bbox in enumerate(bbox): + left, upper, right, lower = bbox + cx, cy = K[idx, 0, 2], K[idx, 1, 2] + + new_cx = cx - left + new_cy = cy - upper + new_height = max(lower - upper, 1) + new_width = max(right - left, 1) + new_cx = new_width - new_cx + new_cy = new_height - new_cy + + K[idx, 0, 2] = new_cx + K[idx, 1, 2] = new_cy + image_sizes.append((int(new_height), int(new_width))) + + return K, image_sizes + + +def perspective_projection(x3d, K, R=None, T=None): + if R != None: + x3d = torch.matmul(R, x3d.transpose(1, 2)).transpose(1, 2) + if T != None: + x3d = x3d + T.transpose(1, 2) + + x2d = torch.div(x3d, x3d[..., 2:]) + x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2] + return x2d + + +def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2): + left = torch.clamp(X.min(1)[0][:, 0], min=0, max=img_w) + right = torch.clamp(X.max(1)[0][:, 0], min=0, max=img_w) + top = torch.clamp(X.min(1)[0][:, 1], min=0, max=img_h) + bottom = torch.clamp(X.max(1)[0][:, 1], min=0, max=img_h) + + cx = (left + right) / 2 + cy = (top + bottom) / 2 + width = (right - left) + height = (bottom - top) + + new_left = torch.clamp(cx - width/2 * scaleFactor, min=0, max=img_w-1) + new_right = torch.clamp(cx + width/2 * scaleFactor, min=1, max=img_w) + new_top = torch.clamp(cy - height / 2 * scaleFactor, min=0, max=img_h-1) + new_bottom = torch.clamp(cy + height / 2 * scaleFactor, min=1, max=img_h) + + bbox = torch.stack((new_left.detach(), new_top.detach(), + new_right.detach(), new_bottom.detach())).int().float().T + + return bbox + + +class Renderer(): + def __init__(self, width, height, focal_length, device, + bin_size=None, max_faces_per_bin=None): + + self.width = width + self.height = height + self.focal_length = focal_length + + self.device = device + + self.initialize_camera_params() + self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]]) + self.create_renderer(bin_size, max_faces_per_bin) + + def create_renderer(self, bin_size, max_faces_per_bin): + self.renderer = MeshRenderer( + rasterizer=MeshRasterizer( + raster_settings=RasterizationSettings( + image_size=self.image_sizes[0], + blur_radius=1e-5, bin_size=bin_size, + max_faces_per_bin=max_faces_per_bin), + ), + shader=SoftPhongShader( + device=self.device, + lights=self.lights, + ) + ) + + def initialize_camera_params(self): + """Hard coding for camera parameters + TODO: Do some soft coding""" + + # Extrinsics + self.R = torch.diag( + torch.tensor([1, 1, 1]) + ).float().to(self.device).unsqueeze(0) + + self.T = torch.tensor( + [0, 0, 0] + ).unsqueeze(0).float().to(self.device) 
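+        # Default extrinsics: identity rotation and zero translation keep the
+        # render camera at the world origin; create_camera() below swaps in
+        # per-frame R and T when they are provided.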
+ + # Intrinsics + self.K = torch.tensor( + [[self.focal_length, 0, self.width/2], + [0, self.focal_length, self.height/2], + [0, 0, 1]] + ).unsqueeze(0).float().to(self.device) + self.bboxes = torch.tensor([[0, 0, self.width, self.height]]).float() + self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes) + + # self.K_full = self.K # test + self.cameras = self.create_camera() + + def create_camera(self, R=None, T=None): + if R is not None: + self.R = R.clone().view(1, 3, 3).to(self.device) + if T is not None: + self.T = T.clone().view(1, 3).to(self.device) + + return PerspectiveCameras( + device=self.device, + R=self.R, #.mT, + T=self.T, + K=self.K_full, + image_size=self.image_sizes, + in_ndc=False) + + def create_camera_from_cv(self, R, T, K=None, image_size=None): + # R: [1, 3, 3] Tensor + # T: [1, 3] Tensor + # K: [1, 3, 3] Tensor + # image_size: [1, 2] Tensor in HW + if K is None: + K = self.K + + if image_size is None: + image_size = torch.tensor(self.image_sizes) + + cameras = _cameras_from_opencv_projection(R, T, K, image_size) + lights = PointLights(device=K.device, location=T) + + return cameras, lights + + def set_ground(self, length, center_x, center_z): + device = self.device + v, f, vc, fc = map(torch.from_numpy, checkerboard_geometry(length=length, c1=center_x, c2=center_z, up="y")) + v, f, vc = v.to(device), f.to(device), vc.to(device) + self.ground_geometry = [v, f, vc] + + + def update_bbox(self, x3d, scale=2.0, mask=None): + """ Update bbox of cameras from the given 3d points + + x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3) + """ + + if x3d.size(-1) != 3: + x2d = x3d.unsqueeze(0) + else: + x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1)) + + if mask is not None: + x2d = x2d[:, ~mask] + + bbox = compute_bbox_from_points(x2d, self.width, self.height, scale) + self.bboxes = bbox + + self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox) + self.cameras = self.create_camera() + self.create_renderer() + + def reset_bbox(self,): + bbox = torch.zeros((1, 4)).float().to(self.device) + bbox[0, 2] = self.width + bbox[0, 3] = self.height + self.bboxes = bbox + + self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox) + self.cameras = self.create_camera() + self.create_renderer() + + def render_mesh(self, vertices, background, colors=[0.8, 0.8, 0.8]): + self.update_bbox(vertices[::50], scale=1.2) + vertices = vertices.unsqueeze(0) + + if colors[0] > 1: colors = [c / 255. 
for c in colors] + verts_features = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype) + verts_features = verts_features.repeat(1, vertices.shape[1], 1) + textures = TexturesVertex(verts_features=verts_features) + + mesh = Meshes(verts=vertices, + faces=self.faces, + textures=textures,) + + materials = Materials( + device=self.device, + specular_color=(colors, ), + shininess=0 + ) + + results = torch.flip( + self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights), + [1, 2] + ) + image = results[0, ..., :3] * 255 + mask = results[0, ..., -1] > 1e-3 + + image = overlay_image_onto_background(image, mask, self.bboxes, background.copy()) + self.reset_bbox() + return image + + + def render_with_ground(self, verts, faces, colors, cameras, lights): + """ + :param verts (B, V, 3) + :param faces (F, 3) + :param colors (B, 3) + """ + + # (B, V, 3), (B, F, 3), (B, V, 3) + verts, faces, colors = prep_shared_geometry(verts, faces, colors) + # (V, 3), (F, 3), (V, 3) + gv, gf, gc = self.ground_geometry + verts = list(torch.unbind(verts, dim=0)) + [gv] + faces = list(torch.unbind(faces, dim=0)) + [gf] + colors = list(torch.unbind(colors, dim=0)) + [gc[..., :3]] + mesh = create_meshes(verts, faces, colors) + + materials = Materials( + device=self.device, + shininess=0 + ) + + results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials) + image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + + return image + + def render_multiple(self, verts_list, faces, colors_list, cameras, lights): + """ + :param verts (B, V, 3) + :param faces (F, 3) + :param colors (B, 3) + """ + # (B, V, 3), (B, F, 3), (B, V, 3) + verts_, faces_, colors_ = [], [], [] + for i, verts in enumerate(verts_list): + colors = colors_list[[i]] + verts_i, faces_i, colors_i = prep_shared_geometry(verts, faces, colors) + if i == 0: + verts_ = list(torch.unbind(verts_i, dim=0)) + faces_ = list(torch.unbind(faces_i, dim=0)) + colors_ = list(torch.unbind(colors_i, dim=0)) + else: + verts_ += list(torch.unbind(verts_i, dim=0)) + faces_ += list(torch.unbind(faces_i, dim=0)) + colors_ += list(torch.unbind(colors_i, dim=0)) + + # # (V, 3), (F, 3), (V, 3) + # gv, gf, gc = self.ground_geometry + # verts_ += [gv] + # faces_ += [gf] + # colors_ += [gc[..., :3]] + mesh = create_meshes(verts_, faces_, colors_) + + materials = Materials( + device=self.device, + shininess=0 + ) + results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials) + image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + mask = results[0, ..., -1].cpu().numpy() > 0 + return image, mask + + +def prep_shared_geometry(verts, faces, colors): + """ + :param verts (B, V, 3) + :param faces (F, 3) + :param colors (B, 4) + """ + B, V, _ = verts.shape + F, _ = faces.shape + colors = colors.unsqueeze(1).expand(B, V, -1)[..., :3] + faces = faces.unsqueeze(0).expand(B, F, -1) + return verts, faces, colors + + +def create_meshes(verts, faces, colors): + """ + :param verts (B, V, 3) + :param faces (B, F, 3) + :param colors (B, V, 3) + """ + textures = TexturesVertex(verts_features=colors) + meshes = Meshes(verts=verts, faces=faces, textures=textures) + return join_meshes_as_scene(meshes) + + +def get_global_cameras(verts, device, distance=5, position=(-5.0, 5.0, 0.0)): + positions = torch.tensor([position]).repeat(len(verts), 1) + targets = verts.mean(1) + + directions = targets - positions + directions = directions / torch.norm(directions, dim=-1).unsqueeze(-1) * 
distance + positions = targets - directions + + rotation = look_at_rotation(positions, targets, ).mT + translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1) + + lights = PointLights(device=device, location=[position]) + return rotation, translation, lights + + diff --git a/lib/vis/run_vis2.py b/lib/vis/run_vis2.py new file mode 100644 index 0000000000000000000000000000000000000000..390a1dfe2da0949d14f2a1283fbf30f280d4295d --- /dev/null +++ b/lib/vis/run_vis2.py @@ -0,0 +1,250 @@ +import os +import cv2 +import numpy as np +import torch +import trimesh + +import lib.vis.viewer as viewer_utils +from lib.vis.wham_tools.tools import checkerboard_geometry + +def camera_marker_geometry(radius, height): + vertices = np.array( + [ + [-radius, -radius, 0], + [radius, -radius, 0], + [radius, radius, 0], + [-radius, radius, 0], + [0, 0, - height], + ] + ) + + + faces = np.array( + [[0, 1, 2], [0, 2, 3], [1, 0, 4], [2, 1, 4], [3, 2, 4], [0, 3, 4],] + ) + + face_colors = np.array( + [ + [0.5, 0.5, 0.5, 1.0], + [0.5, 0.5, 0.5, 1.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + ] + ) + return vertices, faces, face_colors + + +def run_vis2_on_video(res_dict, res_dict2, output_pth, focal_length, image_names, R_c2w=None, t_c2w=None): + + img0 = cv2.imread(image_names[0]) + height, width, _ = img0.shape + + world_mano = {} + world_mano['vertices'] = res_dict['vertices'] + world_mano['faces'] = res_dict['faces'] + + world_mano2 = {} + world_mano2['vertices'] = res_dict2['vertices'] + world_mano2['faces'] = res_dict2['faces'] + + vis_dict = {} + color_idx = 0 + world_mano['vertices'] = world_mano['vertices'] + for _id, _verts in enumerate(world_mano['vertices']): + verts = _verts.cpu().numpy() # T, N, 3 + body_faces = world_mano['faces'] + body_meshes = { + "v3d": verts, + "f3d": body_faces, + "vc": None, + "name": f"hand_{_id}", + # "color": "pace-green", + "color": "director-purple", + } + vis_dict[f"hand_{_id}"] = body_meshes + color_idx += 1 + + world_mano2['vertices'] = world_mano2['vertices'] + for _id, _verts in enumerate(world_mano2['vertices']): + verts = _verts.cpu().numpy() # T, N, 3 + body_faces = world_mano2['faces'] + body_meshes = { + "v3d": verts, + "f3d": body_faces, + "vc": None, + "name": f"hand2_{_id}", + # "color": "pace-blue", + "color": "director-blue", + } + vis_dict[f"hand2_{_id}"] = body_meshes + color_idx += 1 + + v, f, vc, fc = checkerboard_geometry(length=100, c1=0, c2=0, up="z") + v[:, 2] -= 2 # z plane + gound_meshes = { + "v3d": v, + "f3d": f, + "vc": vc, + "name": "ground", + "fc": fc, + "color": -1, + } + vis_dict["ground"] = gound_meshes + + num_frames = len(world_mano['vertices'][_id]) + Rt = np.zeros((num_frames, 3, 4)) + Rt[:, :3, :3] = R_c2w[:num_frames] + Rt[:, :3, 3] = t_c2w[:num_frames] + + verts, faces, face_colors = camera_marker_geometry(0.05, 0.1) + verts = np.einsum("tij,nj->tni", Rt[:, :3, :3], verts) + Rt[:, None, :3, 3] + camera_meshes = { + "v3d": verts, + "f3d": faces, + "vc": None, + "name": "camera", + "fc": face_colors, + "color": -1, + } + vis_dict["camera"] = camera_meshes + + side_source = torch.tensor([0.463, -0.478, 2.456]) + side_target = torch.tensor([0.026, -0.481, -3.184]) + up = torch.tensor([1.0, 0.0, 0.0]) + view_camera = lookat_matrix(side_source, side_target, up) + viewer_Rt = np.tile(view_camera[:3, :4], (num_frames, 1, 1)) + + meshes = viewer_utils.construct_viewer_meshes( + vis_dict, draw_edges=False, flat_shading=False + ) + + vis_h, vis_w = (1000, 1000) + K = np.array( + [ + 
[1000, 0, vis_w / 2], + [0, 1000, vis_h / 2], + [0, 0, 1] + ] + ) + + data = viewer_utils.ViewerData(viewer_Rt, K, vis_w, vis_h) + batch = (meshes, data) + + viewer = viewer_utils.ARCTICViewer(interactive=True, size=(vis_w, vis_h)) + viewer.render_seq(batch, out_folder=os.path.join(output_pth, 'aitviewer')) + +def run_vis2_on_video_cam(res_dict, res_dict2, output_pth, focal_length, image_names, R_w2c=None, t_w2c=None): + + img0 = cv2.imread(image_names[0]) + height, width, _ = img0.shape + + world_mano = {} + world_mano['vertices'] = res_dict['vertices'] + world_mano['faces'] = res_dict['faces'] + + world_mano2 = {} + world_mano2['vertices'] = res_dict2['vertices'] + world_mano2['faces'] = res_dict2['faces'] + + vis_dict = {} + color_idx = 0 + world_mano['vertices'] = world_mano['vertices'] + for _id, _verts in enumerate(world_mano['vertices']): + verts = _verts.cpu().numpy() # T, N, 3 + body_faces = world_mano['faces'] + body_meshes = { + "v3d": verts, + "f3d": body_faces, + "vc": None, + "name": f"hand_{_id}", + # "color": "pace-green", + "color": "director-purple", + } + vis_dict[f"hand_{_id}"] = body_meshes + color_idx += 1 + + world_mano2['vertices'] = world_mano2['vertices'] + for _id, _verts in enumerate(world_mano2['vertices']): + verts = _verts.cpu().numpy() # T, N, 3 + body_faces = world_mano2['faces'] + body_meshes = { + "v3d": verts, + "f3d": body_faces, + "vc": None, + "name": f"hand2_{_id}", + # "color": "pace-blue", + "color": "director-blue", + } + vis_dict[f"hand2_{_id}"] = body_meshes + color_idx += 1 + + meshes = viewer_utils.construct_viewer_meshes( + vis_dict, draw_edges=False, flat_shading=False + ) + + num_frames = len(world_mano['vertices'][_id]) + Rt = np.zeros((num_frames, 3, 4)) + Rt[:, :3, :3] = R_w2c[:num_frames] + Rt[:, :3, 3] = t_w2c[:num_frames] + + cols, rows = (width, height) + K = np.array( + [ + [focal_length, 0, width / 2], + [0, focal_length, height / 2], + [0, 0, 1] + ] + ) + vis_h = height + vis_w = width + + data = viewer_utils.ViewerData(Rt, K, cols, rows, imgnames=image_names) + batch = (meshes, data) + + viewer = viewer_utils.ARCTICViewer(interactive=True, size=(vis_w, vis_h)) + viewer.render_seq(batch, out_folder=os.path.join(output_pth, 'aitviewer')) + +def lookat_matrix(source_pos, target_pos, up): + """ + IMPORTANT: USES RIGHT UP BACK XYZ CONVENTION + :param source_pos (*, 3) + :param target_pos (*, 3) + :param up (3,) + """ + *dims, _ = source_pos.shape + up = up.reshape(*(1,) * len(dims), 3) + up = up / torch.linalg.norm(up, dim=-1, keepdim=True) + back = normalize(target_pos - source_pos) + right = normalize(torch.linalg.cross(up, back)) + up = normalize(torch.linalg.cross(back, right)) + R = torch.stack([right, up, back], dim=-1) + return make_4x4_pose(R, source_pos) + +def make_4x4_pose(R, t): + """ + :param R (*, 3, 3) + :param t (*, 3) + return (*, 4, 4) + """ + dims = R.shape[:-2] + pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1) + bottom = ( + torch.tensor([0, 0, 0, 1], device=R.device) + .reshape(*(1,) * len(dims), 1, 4) + .expand(*dims, 1, 4) + ) + return torch.cat([pose_3x4, bottom], dim=-2) + +def normalize(x): + return x / torch.linalg.norm(x, dim=-1, keepdim=True) + +def save_mesh_to_obj(vertices, faces, file_path): + # 创建一个 Trimesh 对象 + mesh = trimesh.Trimesh(vertices=vertices, faces=faces) + + # 导出为 .obj 文件 + mesh.export(file_path) + print(f"Mesh saved to {file_path}") + diff --git a/lib/vis/tools.py b/lib/vis/tools.py new file mode 100644 index 
0000000000000000000000000000000000000000..da8695e6391633b6143adebfb154566ee6ebf697 --- /dev/null +++ b/lib/vis/tools.py @@ -0,0 +1,825 @@ +# Useful visualization functions from SLAHMR and WHAM + +import os +import cv2 +import numpy as np +import torch +from PIL import Image + + +def read_image(path, scale=1): + im = Image.open(path) + if scale == 1: + return np.array(im) + W, H = im.size + w, h = int(scale * W), int(scale * H) + return np.array(im.resize((w, h), Image.ANTIALIAS)) + + +def transform_torch3d(T_c2w): + """ + :param T_c2w (*, 4, 4) + returns (*, 3, 3), (*, 3) + """ + R1 = torch.tensor( + [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0],], device=T_c2w.device, + ) + R2 = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0],], device=T_c2w.device, + ) + cam_R, cam_t = T_c2w[..., :3, :3], T_c2w[..., :3, 3] + cam_R = torch.einsum("...ij,jk->...ik", cam_R, R1) + cam_t = torch.einsum("ij,...j->...i", R2, cam_t) + return cam_R, cam_t + + +def transform_pyrender(T_c2w): + """ + :param T_c2w (*, 4, 4) + """ + T_vis = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + [0.0, 0.0, -1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + device=T_c2w.device, + ) + return torch.einsum( + "...ij,jk->...ik", torch.einsum("ij,...jk->...ik", T_vis, T_c2w), T_vis + ) + + +def smpl_to_geometry(verts, faces, vis_mask=None, track_ids=None): + """ + :param verts (B, T, V, 3) + :param faces (F, 3) + :param vis_mask (optional) (B, T) visibility of each person + :param track_ids (optional) (B,) + returns list of T verts (B, V, 3), faces (F, 3), colors (B, 3) + where B is different depending on the visibility of the people + """ + B, T = verts.shape[:2] + device = verts.device + + # (B, 3) + colors = ( + track_to_colors(track_ids) + if track_ids is not None + else torch.ones(B, 3, device) * 0.5 + ) + + # list T (B, V, 3), T (B, 3), T (F, 3) + return filter_visible_meshes(verts, colors, faces, vis_mask) + + +def filter_visible_meshes(verts, colors, faces, vis_mask=None, vis_opacity=False): + """ + :param verts (B, T, V, 3) + :param colors (B, 3) + :param faces (F, 3) + :param vis_mask (optional tensor, default None) (B, T) ternary mask + -1 if not in frame + 0 if temporarily occluded + 1 if visible + :param vis_opacity (optional bool, default False) + if True, make occluded people alpha=0.5, otherwise alpha=1 + returns a list of T lists verts (Bi, V, 3), colors (Bi, 4), faces (F, 3) + """ + # import ipdb; ipdb.set_trace() + B, T = verts.shape[:2] + faces = [faces for t in range(T)] + if vis_mask is None: + verts = [verts[:, t] for t in range(T)] + colors = [colors for t in range(T)] + return verts, colors, faces + + # render occluded and visible, but not removed + vis_mask = vis_mask >= 0 + if vis_opacity: + alpha = 0.5 * (vis_mask[..., None] + 1) + else: + alpha = (vis_mask[..., None] >= 0).float() + vert_list = [verts[vis_mask[:, t], t] for t in range(T)] + colors = [ + torch.cat([colors[vis_mask[:, t]], alpha[vis_mask[:, t], t]], dim=-1) + for t in range(T) + ] + bounds = get_bboxes(verts, vis_mask) + return vert_list, colors, faces, bounds + + +def get_bboxes(verts, vis_mask): + """ + return bb_min, bb_max, and mean for each track (B, 3) over entire trajectory + :param verts (B, T, V, 3) + :param vis_mask (B, T) + """ + B, T, *_ = verts.shape + bb_min, bb_max, mean = [], [], [] + for b in range(B): + v = verts[b, vis_mask[b, :T]] # (Tb, V, 3) + bb_min.append(v.amin(dim=(0, 1))) + bb_max.append(v.amax(dim=(0, 1))) + mean.append(v.mean(dim=(0, 1))) + bb_min = torch.stack(bb_min, 
dim=0) + bb_max = torch.stack(bb_max, dim=0) + mean = torch.stack(mean, dim=0) + # point to a track that's long and close to the camera + zs = mean[:, 2] + counts = vis_mask[:, :T].sum(dim=-1) # (B,) + mask = counts < 0.8 * T + zs[mask] = torch.inf + sel = torch.argmin(zs) + return bb_min.amin(dim=0), bb_max.amax(dim=0), mean[sel] + + +def track_to_colors(track_ids): + """ + :param track_ids (B) + """ + color_map = torch.from_numpy(get_colors()).to(track_ids) + return color_map[track_ids] / 255 # (B, 3) + + +def get_colors(): + # color_file = os.path.abspath(os.path.join(__file__, "../colors_phalp.txt")) + color_file = os.path.abspath(os.path.join(__file__, "../colors.txt")) + RGB_tuples = np.vstack( + [ + np.loadtxt(color_file, skiprows=0), + # np.loadtxt(color_file, skiprows=1), + np.random.uniform(0, 255, size=(10000, 3)), + [[0, 0, 0]], + ] + ) + b = np.where(RGB_tuples == 0) + RGB_tuples[b] = 1 + return RGB_tuples.astype(np.float32) + + +def checkerboard_geometry( + length=12.0, + color0=[0.8, 0.9, 0.9], + color1=[0.6, 0.7, 0.7], + tile_width=0.5, + alpha=1.0, + up="y", + c1=0.0, + c2=0.0, +): + assert up == "y" or up == "z" + color0 = np.array(color0 + [alpha]) + color1 = np.array(color1 + [alpha]) + radius = length / 2.0 + num_rows = num_cols = max(2, int(length / tile_width)) + vertices = [] + vert_colors = [] + faces = [] + face_colors = [] + for i in range(num_rows): + for j in range(num_cols): + u0, v0 = j * tile_width - radius, i * tile_width - radius + us = np.array([u0, u0, u0 + tile_width, u0 + tile_width]) + vs = np.array([v0, v0 + tile_width, v0 + tile_width, v0]) + zs = np.zeros(4) + if up == "y": + cur_verts = np.stack([us, zs, vs], axis=-1) # (4, 3) + cur_verts[:, 0] += c1 + cur_verts[:, 2] += c2 + else: + cur_verts = np.stack([us, vs, zs], axis=-1) # (4, 3) + cur_verts[:, 0] += c1 + cur_verts[:, 1] += c2 + + cur_faces = np.array( + [[0, 1, 3], [1, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=np.int64 + ) + cur_faces += 4 * (i * num_cols + j) # the number of previously added verts + use_color0 = (i % 2 == 0 and j % 2 == 0) or (i % 2 == 1 and j % 2 == 1) + cur_color = color0 if use_color0 else color1 + cur_colors = np.array([cur_color, cur_color, cur_color, cur_color]) + + vertices.append(cur_verts) + faces.append(cur_faces) + vert_colors.append(cur_colors) + face_colors.append(cur_colors) + + vertices = np.concatenate(vertices, axis=0).astype(np.float32) + vert_colors = np.concatenate(vert_colors, axis=0).astype(np.float32) + faces = np.concatenate(faces, axis=0).astype(np.float32) + face_colors = np.concatenate(face_colors, axis=0).astype(np.float32) + + return vertices, faces, vert_colors, face_colors + + +def camera_marker_geometry(radius, height, up): + assert up == "y" or up == "z" + if up == "y": + vertices = np.array( + [ + [-radius, -radius, 0], + [radius, -radius, 0], + [radius, radius, 0], + [-radius, radius, 0], + [0, 0, height], + ] + ) + else: + vertices = np.array( + [ + [-radius, 0, -radius], + [radius, 0, -radius], + [radius, 0, radius], + [-radius, 0, radius], + [0, -height, 0], + ] + ) + + faces = np.array( + [[0, 3, 1], [1, 3, 2], [0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4],] + ) + + face_colors = np.array( + [ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + ] + ) + return vertices, faces, face_colors + + +def vis_keypoints( + keypts_list, + img_size, + radius=6, + thickness=3, + kpt_score_thr=0.3, + dataset="TopDownCocoDataset", +): + """ + Visualize keypoints + 
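+    Keypoints are drawn on a white canvas and returned as an RGBA image
+    whose alpha channel marks the drawn pixels.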
From ViTPose/mmpose/apis/inference.py + """ + palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ] + ) + + if dataset in ( + "TopDownCocoDataset", + "BottomUpCocoDataset", + "TopDownOCHumanDataset", + "AnimalMacaqueDataset", + ): + # show the results + skeleton = [ + [15, 13], + [13, 11], + [16, 14], + [14, 12], + [11, 12], + [5, 11], + [6, 12], + [5, 6], + [5, 7], + [6, 8], + [7, 9], + [8, 10], + [1, 2], + [0, 1], + [0, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + ] + + pose_link_color = palette[ + [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16] + ] + pose_kpt_color = palette[ + [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + ] + + elif dataset == "TopDownCocoWholeBodyDataset": + # show the results + skeleton = [ + [15, 13], + [13, 11], + [16, 14], + [14, 12], + [11, 12], + [5, 11], + [6, 12], + [5, 6], + [5, 7], + [6, 8], + [7, 9], + [8, 10], + [1, 2], + [0, 1], + [0, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + [15, 17], + [15, 18], + [15, 19], + [16, 20], + [16, 21], + [16, 22], + [91, 92], + [92, 93], + [93, 94], + [94, 95], + [91, 96], + [96, 97], + [97, 98], + [98, 99], + [91, 100], + [100, 101], + [101, 102], + [102, 103], + [91, 104], + [104, 105], + [105, 106], + [106, 107], + [91, 108], + [108, 109], + [109, 110], + [110, 111], + [112, 113], + [113, 114], + [114, 115], + [115, 116], + [112, 117], + [117, 118], + [118, 119], + [119, 120], + [112, 121], + [121, 122], + [122, 123], + [123, 124], + [112, 125], + [125, 126], + [126, 127], + [127, 128], + [112, 129], + [129, 130], + [130, 131], + [131, 132], + ] + + pose_link_color = palette[ + [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16] + + [16, 16, 16, 16, 16, 16] + + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16] + + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16] + ] + pose_kpt_color = palette[ + [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + + [0, 0, 0, 0, 0, 0] + + [19] * (68 + 42) + ] + + elif dataset == "TopDownAicDataset": + skeleton = [ + [2, 1], + [1, 0], + [0, 13], + [13, 3], + [3, 4], + [4, 5], + [8, 7], + [7, 6], + [6, 9], + [9, 10], + [10, 11], + [12, 13], + [0, 6], + [3, 9], + ] + + pose_link_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 0, 7, 7]] + pose_kpt_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 0, 0]] + + elif dataset == "TopDownMpiiDataset": + skeleton = [ + [0, 1], + [1, 2], + [2, 6], + [6, 3], + [3, 4], + [4, 5], + [6, 7], + [7, 8], + [8, 9], + [8, 12], + [12, 11], + [11, 10], + [8, 13], + [13, 14], + [14, 15], + ] + + pose_link_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 9, 9, 9, 9, 9, 9]] + pose_kpt_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 0, 9, 9, 9, 9, 9, 9]] + + elif dataset == "TopDownMpiiTrbDataset": + skeleton = [ + [12, 13], + [13, 0], + [13, 1], + [0, 2], + [1, 3], + [2, 4], + [3, 5], + [0, 6], + [1, 7], + [6, 7], + [6, 8], + [7, 9], + [8, 10], + [9, 11], + [14, 15], + [16, 17], + [18, 19], + [20, 21], + [22, 23], + [24, 25], + [26, 27], + [28, 29], + [30, 31], + [32, 33], + [34, 35], + [36, 37], + [38, 39], + ] + + pose_link_color = palette[[16] * 14 + [19] * 13] + pose_kpt_color = palette[[16] * 14 + [0] * 26] + + elif dataset in 
("OneHand10KDataset", "FreiHandDataset", "PanopticDataset"): + skeleton = [ + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [0, 5], + [5, 6], + [6, 7], + [7, 8], + [0, 9], + [9, 10], + [10, 11], + [11, 12], + [0, 13], + [13, 14], + [14, 15], + [15, 16], + [0, 17], + [17, 18], + [18, 19], + [19, 20], + ] + + pose_link_color = palette[ + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16] + ] + pose_kpt_color = palette[ + [0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16] + ] + + elif dataset == "InterHand2DDataset": + skeleton = [ + [0, 1], + [1, 2], + [2, 3], + [4, 5], + [5, 6], + [6, 7], + [8, 9], + [9, 10], + [10, 11], + [12, 13], + [13, 14], + [14, 15], + [16, 17], + [17, 18], + [18, 19], + [3, 20], + [7, 20], + [11, 20], + [15, 20], + [19, 20], + ] + + pose_link_color = palette[ + [0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, 16] + ] + pose_kpt_color = palette[ + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 0] + ] + + elif dataset == "Face300WDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 68] + kpt_score_thr = 0 + + elif dataset == "FaceAFLWDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 19] + kpt_score_thr = 0 + + elif dataset == "FaceCOFWDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 29] + kpt_score_thr = 0 + + elif dataset == "FaceWFLWDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 98] + kpt_score_thr = 0 + + elif dataset == "AnimalHorse10Dataset": + skeleton = [ + [0, 1], + [1, 12], + [12, 16], + [16, 21], + [21, 17], + [17, 11], + [11, 10], + [10, 8], + [8, 9], + [9, 12], + [2, 3], + [3, 4], + [5, 6], + [6, 7], + [13, 14], + [14, 15], + [18, 19], + [19, 20], + ] + + pose_link_color = palette[[4] * 10 + [6] * 2 + [6] * 2 + [7] * 2 + [7] * 2] + pose_kpt_color = palette[ + [4, 4, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 4, 7, 7, 7, 4, 4, 7, 7, 7, 4] + ] + + elif dataset == "AnimalFlyDataset": + skeleton = [ + [1, 0], + [2, 0], + [3, 0], + [4, 3], + [5, 4], + [7, 6], + [8, 7], + [9, 8], + [11, 10], + [12, 11], + [13, 12], + [15, 14], + [16, 15], + [17, 16], + [19, 18], + [20, 19], + [21, 20], + [23, 22], + [24, 23], + [25, 24], + [27, 26], + [28, 27], + [29, 28], + [30, 3], + [31, 3], + ] + + pose_link_color = palette[[0] * 25] + pose_kpt_color = palette[[0] * 32] + + elif dataset == "AnimalLocustDataset": + skeleton = [ + [1, 0], + [2, 1], + [3, 2], + [4, 3], + [6, 5], + [7, 6], + [9, 8], + [10, 9], + [11, 10], + [13, 12], + [14, 13], + [15, 14], + [17, 16], + [18, 17], + [19, 18], + [21, 20], + [22, 21], + [24, 23], + [25, 24], + [26, 25], + [28, 27], + [29, 28], + [30, 29], + [32, 31], + [33, 32], + [34, 33], + ] + + pose_link_color = palette[[0] * 26] + pose_kpt_color = palette[[0] * 35] + + elif dataset == "AnimalZebraDataset": + skeleton = [[1, 0], [2, 1], [3, 2], [4, 2], [5, 7], [6, 7], [7, 2], [8, 7]] + + pose_link_color = palette[[0] * 8] + pose_kpt_color = palette[[0] * 9] + + elif dataset in "AnimalPoseDataset": + skeleton = [ + [0, 1], + [0, 2], + [1, 3], + [0, 4], + [1, 4], + [4, 5], + [5, 7], + [6, 7], + [5, 8], + [8, 12], + [12, 16], + [5, 9], + [9, 13], + [13, 17], + [6, 10], + [10, 14], + [14, 18], + [6, 11], + [11, 15], + [15, 19], + ] + + pose_link_color = palette[[0] * 20] + pose_kpt_color = palette[[0] * 20] + else: + 
NotImplementedError() + + img_w, img_h = img_size + img = 255 * np.ones((img_h, img_w, 3), dtype=np.uint8) + img = imshow_keypoints( + img, + keypts_list, + skeleton, + kpt_score_thr, + pose_kpt_color, + pose_link_color, + radius, + thickness, + ) + alpha = 255 * (img != 255).any(axis=-1, keepdims=True).astype(np.uint8) + return np.concatenate([img, alpha], axis=-1) + + +def imshow_keypoints( + img, + pose_result, + skeleton=None, + kpt_score_thr=0.3, + pose_kpt_color=None, + pose_link_color=None, + radius=4, + thickness=1, + show_keypoint_weight=False, +): + """Draw keypoints and links on an image. + From ViTPose/mmpose/core/visualization/image.py + + Args: + img (H, W, 3) array + pose_result (list[kpts]): The poses to draw. Each element kpts is + a set of K keypoints as an Kx3 numpy.ndarray, where each + keypoint is represented as x, y, score. + kpt_score_thr (float, optional): Minimum score of keypoints + to be shown. Default: 0.3. + pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, + the keypoint will not be drawn. + pose_link_color (np.array[Mx3]): Color of M links. If None, the + links will not be drawn. + thickness (int): Thickness of lines. + show_keypoint_weight (bool): If True, opacity indicates keypoint score + """ + import math + img_h, img_w, _ = img.shape + idcs = [0, 16, 15, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11] + for kpts in pose_result: + kpts = np.array(kpts, copy=False)[idcs] + + # draw each point on image + if pose_kpt_color is not None: + assert len(pose_kpt_color) == len(kpts) + for kid, kpt in enumerate(kpts): + x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] + if kpt_score > kpt_score_thr: + color = tuple(int(c) for c in pose_kpt_color[kid]) + if show_keypoint_weight: + img_copy = img.copy() + cv2.circle( + img_copy, (int(x_coord), int(y_coord)), radius, color, -1 + ) + transparency = max(0, min(1, kpt_score)) + cv2.addWeighted( + img_copy, transparency, img, 1 - transparency, 0, dst=img + ) + else: + cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1) + + # draw links + if skeleton is not None and pose_link_color is not None: + assert len(pose_link_color) == len(skeleton) + for sk_id, sk in enumerate(skeleton): + pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) + pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) + if ( + pos1[0] > 0 + and pos1[0] < img_w + and pos1[1] > 0 + and pos1[1] < img_h + and pos2[0] > 0 + and pos2[0] < img_w + and pos2[1] > 0 + and pos2[1] < img_h + and kpts[sk[0], 2] > kpt_score_thr + and kpts[sk[1], 2] > kpt_score_thr + ): + color = tuple(int(c) for c in pose_link_color[sk_id]) + if show_keypoint_weight: + img_copy = img.copy() + X = (pos1[0], pos2[0]) + Y = (pos1[1], pos2[1]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1])) + stickwidth = 2 + polygon = cv2.ellipse2Poly( + (int(mX), int(mY)), + (int(length / 2), int(stickwidth)), + int(angle), + 0, + 360, + 1, + ) + cv2.fillConvexPoly(img_copy, polygon, color) + transparency = max( + 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2])) + ) + cv2.addWeighted( + img_copy, transparency, img, 1 - transparency, 0, dst=img + ) + else: + cv2.line(img, pos1, pos2, color, thickness=thickness) + + return img \ No newline at end of file diff --git a/lib/vis/viewer.py b/lib/vis/viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0f2ba200c3066bda4c6193a9645f39f27e20a9 --- /dev/null +++ b/lib/vis/viewer.py @@ -0,0 
+1,308 @@ +import os +import os.path as op +import re +from abc import abstractmethod + +import matplotlib.cm as cm +import numpy as np +from aitviewer.headless import HeadlessRenderer +from aitviewer.renderables.billboard import Billboard +from aitviewer.renderables.meshes import Meshes +from aitviewer.scene.camera import OpenCVCamera +from aitviewer.scene.material import Material +from aitviewer.utils.so3 import aa2rot_numpy +from aitviewer.viewer import Viewer +from easydict import EasyDict as edict +from loguru import logger +from PIL import Image +from tqdm import tqdm +import random + +OBJ_ID = 100 +SMPLX_ID = 150 +LEFT_ID = 200 +RIGHT_ID = 250 +SEGM_IDS = {"object": OBJ_ID, "smplx": SMPLX_ID, "left": LEFT_ID, "right": RIGHT_ID} + +cmap = cm.get_cmap("plasma") +materials = { + "white": Material(color=(1.0, 1.0, 1.0, 1.0), ambient=0.2), + "green": Material(color=(0.0, 1.0, 0.0, 1.0), ambient=0.2), + "blue": Material(color=(0.0, 0.0, 1.0, 1.0), ambient=0.2), + "red": Material(color=(0.969, 0.106, 0.059, 1.0), ambient=0.2), + "cyan": Material(color=(0.051, 0.659, 0.051, 1.0), ambient=0.2), + "light-blue": Material(color=(0.588, 0.5647, 0.9725, 1.0), ambient=0.2), + "cyan-light": Material(color=(0.051, 0.659, 0.051, 1.0), ambient=0.2), + "dark-light": Material(color=(0.404, 0.278, 0.278, 1.0), ambient=0.2), + "rice": Material(color=(0.922, 0.922, 0.102, 1.0), ambient=0.2), + "whac-whac": Material(color=(167/255, 193/255, 203/255, 1.0), ambient=0.2), + "whac-wham": Material(color=(165/255, 153/255, 174/255, 1.0), ambient=0.2), + "pace-blue": Material(color=(0.584, 0.902, 0.976, 1.0), ambient=0.2), + "pace-green": Material(color=(0.631, 1.0, 0.753, 1.0), ambient=0.2), + "director-purple": Material(color=(0.804, 0.6, 0.820, 1.0), ambient=0.2), + "director-blue": Material(color=(0.207, 0.596, 0.792, 1.0), ambient=0.2), + "none": None, +} +color_list = list(materials.keys()) + +def random_material(): + return Material(color=(random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1), 1), ambient=0.2) + + +class ViewerData(edict): + """ + Interface to standardize viewer data. 
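+    Rt: (num_frames, 3, 4) per-frame camera extrinsics passed to aitviewer's OpenCVCamera
+    K: (3, 3) camera intrinsics
+    cols, rows: image width and height in pixels
+    imgnames: optional list of per-frame image paths (length num_frames)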
+ """ + + def __init__(self, Rt, K, cols, rows, imgnames=None): + self.imgnames = imgnames + self.Rt = Rt + self.K = K + self.num_frames = Rt.shape[0] + self.cols = cols + self.rows = rows + self.validate_format() + + def validate_format(self): + assert len(self.Rt.shape) == 3 + assert self.Rt.shape[0] == self.num_frames + assert self.Rt.shape[1] == 3 + assert self.Rt.shape[2] == 4 + + assert len(self.K.shape) == 2 + assert self.K.shape[0] == 3 + assert self.K.shape[1] == 3 + if self.imgnames is not None: + assert self.num_frames == len(self.imgnames) + assert self.num_frames > 0 + im_p = self.imgnames[0] + assert op.exists(im_p), f"Image path {im_p} does not exist" + + +class ARCTICViewer: + def __init__( + self, + render_types=["rgb", "depth", "mask"], + interactive=True, + size=(2024, 2024), + ): + if not interactive: + v = HeadlessRenderer() + else: + v = Viewer(size=size) + + self.v = v + self.interactive = interactive + # self.layers = layers + self.render_types = render_types + + def view_interactive(self): + self.v.run() + + def view_fn_headless(self, num_iter, out_folder): + v = self.v + + v._init_scene() + + logger.info("Rendering to video") + if "video" in self.render_types: + vid_p = op.join(out_folder, "video.mp4") + v.save_video(video_dir=vid_p) + + pbar = tqdm(range(num_iter)) + for fidx in pbar: + out_rgb = op.join(out_folder, "images", f"rgb/{fidx:04d}.png") + out_mask = op.join(out_folder, "images", f"mask/{fidx:04d}.png") + out_depth = op.join(out_folder, "images", f"depth/{fidx:04d}.npy") + + # render RGB, depth, segmentation masks + if "rgb" in self.render_types: + v.export_frame(out_rgb) + if "depth" in self.render_types: + os.makedirs(op.dirname(out_depth), exist_ok=True) + render_depth(v, out_depth) + if "mask" in self.render_types: + os.makedirs(op.dirname(out_mask), exist_ok=True) + render_mask(v, out_mask) + v.scene.next_frame() + logger.info(f"Exported to {out_folder}") + + @abstractmethod + def load_data(self): + pass + + def check_format(self, batch): + meshes_all, data = batch + assert isinstance(meshes_all, dict) + assert len(meshes_all) > 0 + for mesh in meshes_all.values(): + assert isinstance(mesh, Meshes) + assert isinstance(data, ViewerData) + + def render_seq(self, batch, out_folder="./render_out", floor_y=0): + meshes_all, data = batch + self.setup_viewer(data, floor_y) + for mesh in meshes_all.values(): + self.v.scene.add(mesh) + if self.interactive: + self.view_interactive() + else: + num_iter = data["num_frames"] + self.view_fn_headless(num_iter, out_folder) + + def setup_viewer(self, data, floor_y): + v = self.v + fps = 30 + if "imgnames" in data: + setup_billboard(data, v) + + # camera.show_path() + v.run_animations = True # autoplay + v.run_animations = False # autoplay + v.playback_fps = fps + v.scene.fps = fps + v.scene.origin.enabled = False + v.scene.floor.enabled = False + v.auto_set_floor = False + # v.scene.camera.position = np.array((0.0, 0.0, 0)) + self.v = v + + +def dist2vc(dist_ro, dist_lo, dist_o, _cmap, tf_fn=None): + if tf_fn is not None: + exp_map = tf_fn + else: + exp_map = small_exp_map + dist_ro = exp_map(dist_ro) + dist_lo = exp_map(dist_lo) + dist_o = exp_map(dist_o) + + vc_ro = _cmap(dist_ro) + vc_lo = _cmap(dist_lo) + vc_o = _cmap(dist_o) + return vc_ro, vc_lo, vc_o + + +def small_exp_map(_dist): + dist = np.copy(_dist) + # dist = 1.0 - np.clip(dist, 0, 0.1) / 0.1 + dist = np.exp(-20.0 * dist) + return dist + + +def construct_viewer_meshes(data, draw_edges=False, flat_shading=True): + rotation_flip = 
aa2rot_numpy(np.array([1, 0, 0]) * np.pi) + meshes = {} + for key, val in data.items(): + if 'single' in key: + draw_edges = True + else: + draw_edges = False + if "object" in key: + flat_shading = False + else: + flat_shading = False + v3d = val["v3d"] + if not isinstance(val["color"], str): + val["color"] = color_list[val["color"]] + if val["color"] == "random": + mesh_material = random_material() + else: + mesh_material = materials[val["color"]] + meshes[key] = Meshes( + v3d, + val["f3d"], + vertex_colors=val["vc"], + face_colors=val["fc"] if "fc" in val else None, + name=val["name"], + flat_shading=flat_shading, + draw_edges=draw_edges, + material=mesh_material, + # rotation=rotation_flip, + ) + return meshes + + +def setup_viewer( + v, shared_folder_p, video, images_path, data, flag, seq_name, side_angle +): + fps = 10 + cols, rows = 224, 224 + focal = 1000.0 + + # setup image paths + regex = re.compile(r"(\d*)$") + + def sort_key(x): + name = os.path.splitext(x)[0] + return int(regex.search(name).group(0)) + + # setup billboard + images_path = op.join(shared_folder_p, "images") + images_paths = [ + os.path.join(images_path, f) + for f in sorted(os.listdir(images_path), key=sort_key) + ] + assert len(images_paths) > 0 + + cam_t = data[f"{flag}.object.cam_t"] + num_frames = min(cam_t.shape[0], len(images_paths)) + cam_t = cam_t[:num_frames] + # setup camera + K = np.array([[focal, 0, rows / 2.0], [0, focal, cols / 2.0], [0, 0, 1]]) + Rt = np.zeros((num_frames, 3, 4)) + Rt[:, :, 3] = cam_t + Rt[:, :3, :3] = np.eye(3) + Rt[:, 1:3, :3] *= -1.0 + + camera = OpenCVCamera(K, Rt, cols, rows, viewer=v) + if side_angle is None: + billboard = Billboard.from_camera_and_distance( + camera, 10.0, cols, rows, images_paths + ) + v.scene.add(billboard) + v.scene.add(camera) + v.run_animations = True # autoplay + v.playback_fps = fps + v.scene.fps = fps + v.scene.origin.enabled = False + v.scene.floor.enabled = False + v.auto_set_floor = False + v.scene.floor.position[1] = -3 + v.set_temp_camera(camera) + # v.scene.camera.position = np.array((0.0, 0.0, 0)) + return v + + +def render_depth(v, depth_p): + depth = np.array(v.get_depth()).astype(np.float16) + np.save(depth_p, depth) + + +def render_mask(v, mask_p): + nodes_uid = {node.name: node.uid for node in v.scene.collect_nodes()} + my_cmap = { + uid: [SEGM_IDS[name], SEGM_IDS[name], SEGM_IDS[name]] + for name, uid in nodes_uid.items() + if name in SEGM_IDS.keys() + } + mask = np.array(v.get_mask(color_map=my_cmap)).astype(np.uint8) + mask = Image.fromarray(mask) + mask.save(mask_p) + + +def setup_billboard(data, v): + images_paths = data.imgnames + K = data.K + Rt = data.Rt + rows = data.rows + cols = data.cols + camera = OpenCVCamera(K, Rt, cols, rows, viewer=v) + if images_paths is not None: + billboard = Billboard.from_camera_and_distance( + camera, 10.0, cols, rows, images_paths + ) + v.scene.add(billboard) + v.scene.add(camera) + v.scene.camera.load_cam() + v.set_temp_camera(camera) diff --git a/lib/vis/wham_tools/__pycache__/tools.cpython-310.pyc b/lib/vis/wham_tools/__pycache__/tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df1f556d13adac6ae0fe0c8b0cdd1b1c48662637 Binary files /dev/null and b/lib/vis/wham_tools/__pycache__/tools.cpython-310.pyc differ diff --git a/lib/vis/wham_tools/tools.py b/lib/vis/wham_tools/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..c9e7f053f49a11e12a7f01a76dbca9059c359f89 --- /dev/null +++ b/lib/vis/wham_tools/tools.py @@ -0,0 +1,56 @@ +import 
numpy as np + + +def checkerboard_geometry( + length=12.0, + color0=[172/255, 172/255, 172/255], + color1=[215/255, 215/255, 215/255], + tile_width=0.5, + alpha=1.0, + up="y", + c1=0.0, + c2=0.0, +): + assert up == "y" or up == "z" + color0 = np.array(color0 + [alpha]) + color1 = np.array(color1 + [alpha]) + radius = length / 2.0 + num_rows = num_cols = max(2, int(length / tile_width)) + vertices = [] + vert_colors = [] + faces = [] + face_colors = [] + for i in range(num_rows): + for j in range(num_cols): + u0, v0 = j * tile_width - radius, i * tile_width - radius + us = np.array([u0, u0, u0 + tile_width, u0 + tile_width]) + vs = np.array([v0, v0 + tile_width, v0 + tile_width, v0]) + zs = np.zeros(4) + if up == "y": + cur_verts = np.stack([us, zs, vs], axis=-1) # (4, 3) + cur_verts[:, 0] += c1 + cur_verts[:, 2] += c2 + else: + cur_verts = np.stack([us, vs, zs], axis=-1) # (4, 3) + cur_verts[:, 0] += c1 + cur_verts[:, 1] += c2 + + cur_faces = np.array( + [[0, 1, 3], [1, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=np.int64 + ) + cur_faces += 4 * (i * num_cols + j) # the number of previously added verts + use_color0 = (i % 2 == 0 and j % 2 == 0) or (i % 2 == 1 and j % 2 == 1) + cur_color = color0 if use_color0 else color1 + cur_colors = np.array([cur_color, cur_color, cur_color, cur_color]) + + vertices.append(cur_verts) + faces.append(cur_faces) + vert_colors.append(cur_colors) + face_colors.append(cur_colors) + + vertices = np.concatenate(vertices, axis=0).astype(np.float32) + vert_colors = np.concatenate(vert_colors, axis=0).astype(np.float32) + faces = np.concatenate(faces, axis=0).astype(np.float32) + face_colors = np.concatenate(face_colors, axis=0).astype(np.float32) + + return vertices, faces, vert_colors, face_colors \ No newline at end of file diff --git a/license.txt b/license.txt new file mode 100644 index 0000000000000000000000000000000000000000..3811ab9a1eb86022f72d9c871edb3c05907a158f --- /dev/null +++ b/license.txt @@ -0,0 +1,402 @@ +Attribution-NonCommercial-NoDerivatives 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. 
+ Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 +International Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-NoDerivatives 4.0 International Public +License ("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + c. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + d. 
Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + e. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + f. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + g. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + h. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + i. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + j. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + k. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce and reproduce, but not Share, Adapted Material + for NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. 
Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material, You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + For the avoidance of doubt, You do not have permission under + this Public License to Share Adapted Material. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only and provided You do not Share Adapted Material; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. 
The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. 
\ No newline at end of file diff --git a/requirements.txt b/requirements.txt index bcb91bb3e0308f3fdadc2c88a746919e194f3426..66530b53768d7236f7f9dc8fae910d2718218e52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,37 @@ -git+https://github.com/princeton-vl/DROID-SLAM.git \ No newline at end of file +numpy==1.26.4 +opencv-python +pyrender +scikit-image +smplx==0.1.28 +yacs +mmcv==1.3.9 +timm +einops +xtcocotools +pandas +hydra-core +hydra-submitit-launcher +hydra-colorlog +pyrootutils +rich +webdataset +ultralytics +pulp +supervision +pycocotools +joblib +natsort +git+https://github.com/facebookresearch/pytorch3d.git@stable +torch-scatter==2.1.2 +evo +pytorch-minimize +mmengine==0.10.4 +HTML4Vision +plyfile +aitviewer +easydict +loguru +dill +lapx +chumpy@git+https://github.com/mattloper/chumpy +moderngl-window==2.4.6 \ No newline at end of file diff --git a/scripts/scripts_test_video/__pycache__/detect_track_video.cpython-310.pyc b/scripts/scripts_test_video/__pycache__/detect_track_video.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a58ee3385e3d8e5e1d63216579c2c4e8b770cab Binary files /dev/null and b/scripts/scripts_test_video/__pycache__/detect_track_video.cpython-310.pyc differ diff --git a/scripts/scripts_test_video/__pycache__/hawor_slam.cpython-310.pyc b/scripts/scripts_test_video/__pycache__/hawor_slam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..108382dc5790315884ab5209a7038990e6a7f377 Binary files /dev/null and b/scripts/scripts_test_video/__pycache__/hawor_slam.cpython-310.pyc differ diff --git a/scripts/scripts_test_video/__pycache__/hawor_video.cpython-310.pyc b/scripts/scripts_test_video/__pycache__/hawor_video.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fd050c335a2d7f4cd128873e58fe5d28c17f20f Binary files /dev/null and b/scripts/scripts_test_video/__pycache__/hawor_video.cpython-310.pyc differ diff --git a/scripts/scripts_test_video/detect_track_video.py b/scripts/scripts_test_video/detect_track_video.py new file mode 100644 index 0000000000000000000000000000000000000000..308dedbaf7a5f58f9a042dc1076fbb3a0668c8a0 --- /dev/null +++ b/scripts/scripts_test_video/detect_track_video.py @@ -0,0 +1,72 @@ +import sys +import os +sys.path.insert(0, os.path.dirname(__file__) + '/../..') + +import argparse +import numpy as np +from glob import glob +from lib.pipeline.tools import detect_track +from natsort import natsorted +import subprocess + + +def extract_frames(video_path, output_folder): + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + command = [ + 'ffmpeg', + '-i', video_path, + '-vf', 'fps=30', + '-start_number', '0', + os.path.join(output_folder, '%04d.jpg') + ] + + subprocess.run(command, check=True) + + +def detect_track_video(args): + file = args.video_path + root = os.path.dirname(file) + seq = os.path.basename(file).split('.')[0] + + seq_folder = f'{root}/{seq}' + img_folder = f'{seq_folder}/extracted_images' + os.makedirs(seq_folder, exist_ok=True) + os.makedirs(img_folder, exist_ok=True) + print(f'Running detect_track on {file} ...') + + ##### Extract Frames ##### + imgfiles = natsorted(glob(f'{img_folder}/*.jpg')) + # print(imgfiles[:10]) + if len(imgfiles) > 0: + print("Skip extracting frames") + else: + _ = extract_frames(file, img_folder) + imgfiles = natsorted(glob(f'{img_folder}/*.jpg')) + + ##### Detection + Track ##### + print('Detect and Track ...') + + start_idx = 0 + end_idx = len(imgfiles) + + if 
os.path.exists(f'{seq_folder}/tracks_{start_idx}_{end_idx}/model_boxes.npy'): + print(f"skip track for {start_idx}_{end_idx}") + return start_idx, end_idx, seq_folder, imgfiles + os.makedirs(f"{seq_folder}/tracks_{start_idx}_{end_idx}", exist_ok=True) + boxes_, tracks_ = detect_track(imgfiles, thresh=0.2) + np.save(f'{seq_folder}/tracks_{start_idx}_{end_idx}/model_boxes.npy', boxes_) + np.save(f'{seq_folder}/tracks_{start_idx}_{end_idx}/model_tracks.npy', tracks_) + + return start_idx, end_idx, seq_folder, imgfiles + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument("--img_focal", type=float) + parser.add_argument("--video_path", type=str, default='') + parser.add_argument("--input_type", type=str, default='file') + args = parser.parse_args() + + detect_track_video(args) \ No newline at end of file diff --git a/scripts/scripts_test_video/hawor_slam.py b/scripts/scripts_test_video/hawor_slam.py new file mode 100644 index 0000000000000000000000000000000000000000..a557bd102cca778bc77eb424b9b92317f326b6cf --- /dev/null +++ b/scripts/scripts_test_video/hawor_slam.py @@ -0,0 +1,141 @@ +import sys +import os + +from natsort import natsorted + +sys.path.insert(0, os.path.dirname(__file__) + '/../..') + +import argparse +from tqdm import tqdm +import numpy as np +import torch +import cv2 +from PIL import Image +from glob import glob +from pycocotools import mask as masktool +from lib.pipeline.masked_droid_slam import * +from lib.pipeline.est_scale import * +from hawor.utils.process import block_print, enable_print + +sys.path.insert(0, os.path.dirname(__file__) + '/../../thirdparty/Metric3D') +from metric import Metric3D + + +def get_all_mp4_files(folder_path): + # Ensure the folder path is absolute + folder_path = os.path.abspath(folder_path) + + # Recursively search for all .mp4 files in the folder and its subfolders + mp4_files = glob(os.path.join(folder_path, '**', '*.mp4'), recursive=True) + + return mp4_files + +def split_list_by_interval(lst, interval=1000): + start_indices = [] + end_indices = [] + split_lists = [] + + for i in range(0, len(lst), interval): + start_indices.append(i) + end_indices.append(min(i + interval, len(lst))) + split_lists.append(lst[i:i + interval]) + + return start_indices, end_indices, split_lists + +def hawor_slam(args, start_idx, end_idx): + # File and folders + file = args.video_path + video_root = os.path.dirname(file) + video = os.path.basename(file).split('.')[0] + seq_folder = os.path.join(video_root, video) + os.makedirs(seq_folder, exist_ok=True) + video_folder = os.path.join(video_root, video) + + img_folder = f'{video_folder}/extracted_images' + imgfiles = natsorted(glob(f'{img_folder}/*.jpg')) + + first_img = cv2.imread(imgfiles[0]) + height, width, _ = first_img.shape + + print(f'Running slam on {video_folder} ...') + + ##### Run SLAM ##### + # Use Masking + masks = np.load(f'{video_folder}/tracks_{start_idx}_{end_idx}/model_masks.npy', allow_pickle=True) + masks = torch.from_numpy(masks) + print(masks.shape) + + # Camera calibration (intrinsics) for SLAM + focal = args.img_focal + if focal is None: + try: + with open(os.path.join(video_folder, 'est_focal.txt'), 'r') as file: + focal = file.read() + focal = float(focal) + except: + + print('No focal length provided') + focal = 600 + with open(os.path.join(video_folder, 'est_focal.txt'), 'w') as file: + file.write(str(focal)) + calib = np.array(est_calib(imgfiles)) # [focal, focal, cx, cy] + center = calib[2:] + calib[:2] = focal + + # Droid-slam with masking + 
droid, traj = run_slam(imgfiles, masks=masks, calib=calib) + n = droid.video.counter.value + tstamp = droid.video.tstamp.cpu().int().numpy()[:n] + disps = droid.video.disps_up.cpu().numpy()[:n] + print('DBA errors:', droid.backend.errors) + + del droid + torch.cuda.empty_cache() + + # Estimate scale + block_print() + metric = Metric3D('thirdparty/Metric3D/weights/metric_depth_vit_large_800k.pth') + enable_print() + min_threshold = 0.4 + max_threshold = 0.7 + + print('Predicting Metric Depth ...') + pred_depths = [] + H, W = get_dimention(imgfiles) + for t in tqdm(tstamp): + pred_depth = metric(imgfiles[t], calib) + pred_depth = cv2.resize(pred_depth, (W, H)) + pred_depths.append(pred_depth) + + ##### Estimate Metric Scale ##### + print('Estimating Metric Scale ...') + scales_ = [] + n = len(tstamp) # for each keyframe + for i in tqdm(range(n)): + t = tstamp[i] + disp = disps[i] + pred_depth = pred_depths[i] + slam_depth = 1/disp + + # Estimate scene scale + msk = masks[t].numpy().astype(np.uint8) + scale = est_scale_hybrid(slam_depth, pred_depth, sigma=0.5, msk=msk, near_thresh=min_threshold, far_thresh=max_threshold) + scales_.append(scale) + + median_s = np.median(scales_) + print(f"estimated scale: {median_s}") + + # Save results + os.makedirs(f"{seq_folder}/SLAM", exist_ok=True) + save_path = f'{seq_folder}/SLAM/hawor_slam_w_scale_{start_idx}_{end_idx}.npz' + np.savez(save_path, + tstamp=tstamp, disps=disps, traj=traj, + img_focal=focal, img_center=calib[-2:], + scale=median_s) + + + + + + + diff --git a/scripts/scripts_test_video/hawor_video.py b/scripts/scripts_test_video/hawor_video.py new file mode 100644 index 0000000000000000000000000000000000000000..eed5bcc6f0470cb3000737ead187226691000b18 --- /dev/null +++ b/scripts/scripts_test_video/hawor_video.py @@ -0,0 +1,361 @@ +from collections import defaultdict + +import json +import os +import joblib +import numpy as np +import torch +import cv2 +from tqdm import tqdm +from glob import glob +from natsort import natsorted + +from lib.pipeline.tools import parse_chunks, parse_chunks_hand_frame +from lib.models.hawor import HAWOR +from lib.eval_utils.custom_utils import cam2world_convert, load_slam_cam +from lib.eval_utils.custom_utils import interpolate_bboxes +from lib.eval_utils.filling_utils import filling_postprocess, filling_preprocess +from lib.vis.renderer import Renderer +from hawor.utils.process import get_mano_faces, run_mano, run_mano_left +from hawor.utils.rotation import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis +from infiller.lib.model.network import TransformerModel + +def load_hawor(checkpoint_path): + from pathlib import Path + from hawor.configs import get_config + model_cfg = str(Path(checkpoint_path).parent.parent / 'model_config.yaml') + model_cfg = get_config(model_cfg, update_cachedir=True) + + # Override some config values, to crop bbox correctly + if (model_cfg.MODEL.BACKBONE.TYPE == 'vit') and ('BBOX_SHAPE' not in model_cfg.MODEL): + model_cfg.defrost() + assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone" + model_cfg.MODEL.BBOX_SHAPE = [192,256] + model_cfg.freeze() + + model = HAWOR.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg) + return model, model_cfg + + + +def hawor_motion_estimation(args, start_idx, end_idx, seq_folder): + model, model_cfg = load_hawor(args.checkpoint) + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + model = model.to(device) + model.eval() 
+ + file = args.video_path + video_root = os.path.dirname(file) + video = os.path.basename(file).split('.')[0] + img_folder = f"{video_root}/{video}/extracted_images" + imgfiles = np.array(natsorted(glob(f'{img_folder}/*.jpg'))) + + tracks = np.load(f'{seq_folder}/tracks_{start_idx}_{end_idx}/model_tracks.npy', allow_pickle=True).item() + img_focal = args.img_focal + if img_focal is None: + try: + with open(os.path.join(seq_folder, 'est_focal.txt'), 'r') as file: + img_focal = file.read() + img_focal = float(img_focal) + except: + img_focal = 600 + print(f'No focal length provided, use default {img_focal}') + with open(os.path.join(seq_folder, 'est_focal.txt'), 'w') as file: + file.write(str(img_focal)) + + tid = np.array([tr for tr in tracks]) + + print(f'Running hawor on {video} ...') + + left_trk = [] + right_trk = [] + for k, idx in enumerate(tid): + trk = tracks[idx] + + valid = np.array([t['det'] for t in trk]) + is_right = np.concatenate([t['det_handedness'] for t in trk])[valid] + + if is_right.sum() / len(is_right) < 0.5: + left_trk.extend(trk) + else: + right_trk.extend(trk) + left_trk = sorted(left_trk, key=lambda x: x['frame']) + right_trk = sorted(right_trk, key=lambda x: x['frame']) + final_tracks = { + 0: left_trk, + 1: right_trk + } + tid = [0, 1] + + img = cv2.imread(imgfiles[0]) + img_center = [img.shape[1] / 2, img.shape[0] / 2]# w/2, h/2 + H, W = img.shape[:2] + model_masks = np.zeros((len(imgfiles), H, W)) + + bin_size = 128 + max_faces_per_bin = 20000 + renderer = Renderer(img.shape[1], img.shape[0], img_focal, 'cuda', + bin_size=bin_size, max_faces_per_bin=max_faces_per_bin) + # get faces + faces = get_mano_faces() + faces_new = np.array([[92, 38, 234], + [234, 38, 239], + [38, 122, 239], + [239, 122, 279], + [122, 118, 279], + [279, 118, 215], + [118, 117, 215], + [215, 117, 214], + [117, 119, 214], + [214, 119, 121], + [119, 120, 121], + [121, 120, 78], + [120, 108, 78], + [78, 108, 79]]) + faces_right = np.concatenate([faces, faces_new], axis=0) + faces_left = faces_right[:,[0,2,1]] + + frame_chunks_all = defaultdict(list) + for idx in tid: + print(f"tracklet {idx}:") + trk = final_tracks[idx] + + # interp bboxes + valid = np.array([t['det'] for t in trk]) + if valid.sum() < 2: + continue + boxes = np.concatenate([t['det_box'] for t in trk]) + non_zero_indices = np.where(np.any(boxes != 0, axis=1))[0] + first_non_zero = non_zero_indices[0] + last_non_zero = non_zero_indices[-1] + boxes[first_non_zero:last_non_zero+1] = interpolate_bboxes(boxes[first_non_zero:last_non_zero+1]) + valid[first_non_zero:last_non_zero+1] = True + + + boxes = boxes[first_non_zero:last_non_zero+1] + is_right = np.concatenate([t['det_handedness'] for t in trk])[valid] + frame = np.array([t['frame'] for t in trk])[valid] + + if is_right.sum() / len(is_right) < 0.5: + is_right = np.zeros((len(boxes), 1)) + else: + is_right = np.ones((len(boxes), 1)) + + frame_chunks, boxes_chunks = parse_chunks(frame, boxes, min_len=1) + frame_chunks_all[idx] = frame_chunks + + if len(frame_chunks) == 0: + continue + + for frame_ck, boxes_ck in zip(frame_chunks, boxes_chunks): + print(f"inference from frame {frame_ck[0]} to {frame_ck[-1]}") + img_ck = imgfiles[frame_ck] + if is_right[0] > 0: + do_flip = False + else: + do_flip = True + + results = model.inference(img_ck, boxes_ck, img_focal=img_focal, img_center=img_center, do_flip=do_flip) + + data_out = { + "init_root_orient": results["pred_rotmat"][None, :, 0], # (B, T, 3, 3) + "init_hand_pose": results["pred_rotmat"][None, :, 1:], # (B, T, 15, 3, 3) + 
"init_trans": results["pred_trans"][None, :, 0], # (B, T, 3) + "init_betas": results["pred_shape"][None, :] # (B, T, 10) + } + + # flip left hand + init_root = rotation_matrix_to_angle_axis(data_out["init_root_orient"]) + init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"]) + if do_flip: + init_root[..., 1] *= -1 + init_root[..., 2] *= -1 + init_hand_pose[..., 1] *= -1 + init_hand_pose[..., 2] *= -1 + data_out["init_root_orient"] = angle_axis_to_rotation_matrix(init_root) + data_out["init_hand_pose"] = angle_axis_to_rotation_matrix(init_hand_pose) + + # save camera-space results + pred_dict={ + k:v.tolist() for k, v in data_out.items() + } + pred_path = os.path.join(seq_folder, 'cam_space', str(idx), f"{frame_ck[0]}_{frame_ck[-1]}.json") + if not os.path.exists(os.path.join(seq_folder, 'cam_space', str(idx))): + os.makedirs(os.path.join(seq_folder, 'cam_space', str(idx))) + with open(pred_path, "w") as f: + json.dump(pred_dict, f, indent=1) + + + # get hand mask + data_out["init_root_orient"] = rotation_matrix_to_angle_axis(data_out["init_root_orient"]) + data_out["init_hand_pose"] = rotation_matrix_to_angle_axis(data_out["init_hand_pose"]) + if do_flip: # left + outputs = run_mano_left(data_out["init_trans"], data_out["init_root_orient"], data_out["init_hand_pose"], betas=data_out["init_betas"]) + else: # right + outputs = run_mano(data_out["init_trans"], data_out["init_root_orient"], data_out["init_hand_pose"], betas=data_out["init_betas"]) + + vertices = outputs["vertices"][0].cpu() # (T, N, 3) + for img_i, _ in enumerate(img_ck): + if do_flip: + faces = torch.from_numpy(faces_left).cuda() + else: + faces = torch.from_numpy(faces_right).cuda() + cam_R = torch.eye(3).unsqueeze(0).cuda() + cam_T = torch.zeros(1, 3).cuda() + cameras, lights = renderer.create_camera_from_cv(cam_R, cam_T) + verts_color = torch.tensor([0, 0, 255, 255]) / 255 + vertices_i = vertices[[img_i]] + rend, mask = renderer.render_multiple(vertices_i.unsqueeze(0).cuda(), faces, verts_color.unsqueeze(0).cuda(), cameras, lights) + + model_masks[frame_ck[img_i]] += mask + + model_masks = model_masks > 0 # bool + np.save(f'{seq_folder}/tracks_{start_idx}_{end_idx}/model_masks.npy', model_masks) + return frame_chunks_all, img_focal + +def hawor_infiller(args, start_idx, end_idx, frame_chunks_all): + # load infiller + weight_path = args.infiller_weight + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + ckpt = torch.load(weight_path, map_location=device) + pos_dim = 3 + shape_dim = 10 + num_joints = 15 + rot_dim = (num_joints + 1) * 6 # rot6d + repr_dim = 2 * (pos_dim + shape_dim + rot_dim) + nhead = 8 # repr_dim = 154 + horizon = 120 + filling_model = TransformerModel(seq_len=horizon, input_dim=repr_dim, d_model=384, nhead=nhead, d_hid=2048, nlayers=8, dropout=0.05, out_dim=repr_dim, masked_attention_stage=True) + filling_model.to(device) + filling_model.load_state_dict(ckpt['transformer_encoder_state_dict']) + filling_model.eval() + + file = args.video_path + video_root = os.path.dirname(file) + video = os.path.basename(file).split('.')[0] + seq_folder = os.path.join(video_root, video) + img_folder = f"{video_root}/{video}/extracted_images" + + # Previous steps + imgfiles = np.array(natsorted(glob(f'{img_folder}/*.jpg'))) + + idx2hand = ['left', 'right'] + filling_length = 120 + + fpath = os.path.join(seq_folder, f"SLAM/hawor_slam_w_scale_{start_idx}_{end_idx}.npz") + R_w2c_sla_all, t_w2c_sla_all, R_c2w_sla_all, t_c2w_sla_all = load_slam_cam(fpath) + + pred_trans = 
torch.zeros(2, len(imgfiles), 3)
+    pred_rot = torch.zeros(2, len(imgfiles), 3)
+    pred_hand_pose = torch.zeros(2, len(imgfiles), 45)
+    pred_betas = torch.zeros(2, len(imgfiles), 10)
+    pred_valid = torch.zeros((2, pred_betas.size(1)))
+
+    # camera space to world space
+    tid = [0, 1]
+    for k, idx in enumerate(tid):
+        frame_chunks = frame_chunks_all[idx]
+
+        if len(frame_chunks) == 0:
+            continue
+
+        for frame_ck in frame_chunks:
+            print(f"from frame {frame_ck[0]} to {frame_ck[-1]}")
+            pred_path = os.path.join(seq_folder, 'cam_space', str(idx), f"{frame_ck[0]}_{frame_ck[-1]}.json")
+            with open(pred_path, "r") as f:
+                pred_dict = json.load(f)
+            data_out = {
+                k:torch.tensor(v) for k, v in pred_dict.items()
+            }
+
+            R_c2w_sla = R_c2w_sla_all[frame_ck]
+            t_c2w_sla = t_c2w_sla_all[frame_ck]
+
+            data_world = cam2world_convert(R_c2w_sla, t_c2w_sla, data_out, 'right' if idx > 0 else 'left')
+
+            pred_trans[[idx], frame_ck] = data_world["init_trans"]
+            pred_rot[[idx], frame_ck] = data_world["init_root_orient"]
+            pred_hand_pose[[idx], frame_ck] = data_world["init_hand_pose"].flatten(-2)
+            pred_betas[[idx], frame_ck] = data_world["init_betas"]
+            pred_valid[[idx], frame_ck] = 1
+
+
+    # run the infiller network on the missing frames of this video
+    frame_list = torch.tensor(list(range(pred_trans.size(1))))
+    pred_valid = (pred_valid > 0).numpy()
+    for k, idx in enumerate([1, 0]):
+        missing = ~pred_valid[idx]
+
+        frame = frame_list[missing]
+        frame_chunks = parse_chunks_hand_frame(frame)
+
+        print(f"run infiller on {idx2hand[idx]} hand ...")
+        for frame_ck in tqdm(frame_chunks):
+            start_shift = -1
+            while frame_ck[0] + start_shift >= 0 and pred_valid[:, frame_ck[0] + start_shift].sum() != 2:
+                start_shift -= 1 # Shift to find the previous valid frame as start
+            print(f"run infiller on frame {frame_ck[0] + start_shift} to frame {min(len(imgfiles)-1, frame_ck[0] + start_shift + filling_length)}")
+
+            frame_start = frame_ck[0]
+            filling_net_start = max(0, frame_start + start_shift)
+            filling_net_end = min(len(imgfiles)-1, filling_net_start + filling_length)
+            seq_valid = pred_valid[:, filling_net_start:filling_net_end]
+            filling_seq = {}
+            filling_seq['trans'] = pred_trans[:, filling_net_start:filling_net_end].numpy()
+            filling_seq['rot'] = pred_rot[:, filling_net_start:filling_net_end].numpy()
+            filling_seq['hand_pose'] = pred_hand_pose[:, filling_net_start:filling_net_end].numpy()
+            filling_seq['betas'] = pred_betas[:, filling_net_start:filling_net_end].numpy()
+            filling_seq['valid'] = seq_valid
+            # preprocess (convert to canonical + slerp)
+            filling_input, transform_w_canon = filling_preprocess(filling_seq)
+            src_mask = torch.zeros((filling_length, filling_length), device=device).type(torch.bool)
+            src_mask = src_mask.to(device)
+            filling_input = torch.from_numpy(filling_input).unsqueeze(0).to(device).permute(1,0,2) # (seq_len, B, in_dim)
+            T_original = len(filling_input)
+            filling_length = 120
+            if T_original < filling_length:
+                pad_length = filling_length - T_original
+                last_time_step = filling_input[-1, :, :]
+                padding = last_time_step.unsqueeze(0).repeat(pad_length, 1, 1)
+                filling_input = torch.cat([filling_input, padding], dim=0)
+                seq_valid_padding = np.ones((2, filling_length - T_original))
+                seq_valid_padding = np.concatenate([seq_valid, seq_valid_padding], axis=1)
+            else:
+                seq_valid_padding = seq_valid
+
+
+            T, B, _ = filling_input.shape
+
+            valid = torch.from_numpy(seq_valid_padding).unsqueeze(0).all(dim=1).permute(1, 0) # (T,B)
+            valid_atten = torch.from_numpy(seq_valid_padding).unsqueeze(0).all(dim=1).unsqueeze(1) # (B,1,T)
+            data_mask = torch.zeros((horizon, B, 1), device=device, dtype=filling_input.dtype)
+            data_mask[valid] = 1
+            atten_mask = torch.ones((B, 1, horizon),
+                                    device=device, dtype=torch.bool)
+            atten_mask[valid_atten] = False
+            atten_mask = atten_mask.unsqueeze(2).repeat(1, 1, T, 1) # (B,1,T,T)
+
+            output_ck = filling_model(filling_input, src_mask, data_mask, atten_mask)
+
+            output_ck = output_ck.permute(1,0,2).reshape(T, 2, -1).cpu().detach() # two hands
+
+            output_ck = output_ck[:T_original]
+
+            filling_output = filling_postprocess(output_ck, transform_w_canon)
+
+            # replace the missing predictions with the infiller output
+            filling_seq['trans'][~seq_valid] = filling_output['trans'][~seq_valid]
+            filling_seq['rot'][~seq_valid] = filling_output['rot'][~seq_valid]
+            filling_seq['hand_pose'][~seq_valid] = filling_output['hand_pose'][~seq_valid]
+            filling_seq['betas'][~seq_valid] = filling_output['betas'][~seq_valid]
+
+            pred_trans[:, filling_net_start:filling_net_end] = torch.from_numpy(filling_seq['trans'][:])
+            pred_rot[:, filling_net_start:filling_net_end] = torch.from_numpy(filling_seq['rot'][:])
+            pred_hand_pose[:, filling_net_start:filling_net_end] = torch.from_numpy(filling_seq['hand_pose'][:])
+            pred_betas[:, filling_net_start:filling_net_end] = torch.from_numpy(filling_seq['betas'][:])
+            pred_valid[:, filling_net_start:filling_net_end] = 1
+    save_path = os.path.join(seq_folder, "world_space_res.pth")
+    joblib.dump([pred_trans, pred_rot, pred_hand_pose, pred_betas, pred_valid], save_path)
+    return pred_trans, pred_rot, pred_hand_pose, pred_betas, pred_valid
+    
+    
\ No newline at end of file
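Taken together, the scripts added in this patch form one pipeline: frame extraction and hand tracking, per-chunk camera-space hand pose estimation, masked SLAM with metric scale recovery, and world-space infilling of missing frames. The sketch below is one way these entry points appear to chain together, inferred from the files each stage reads and writes; it is not part of the patch. The `argparse.Namespace` fields mirror the parsers defined above, the import paths assume the repository root is on `sys.path`, and the checkpoint and infiller weight paths are assumptions that should be adjusted to the local weight layout.

```python
# Hypothetical driver sketch: chains the functions defined in
# scripts/scripts_test_video/. Run from the repository root.
from argparse import Namespace

from scripts.scripts_test_video.detect_track_video import detect_track_video
from scripts.scripts_test_video.hawor_slam import hawor_slam
from scripts.scripts_test_video.hawor_video import hawor_motion_estimation, hawor_infiller

args = Namespace(
    video_path="./example/video_0.mp4",
    img_focal=None,  # falls back to est_focal.txt, then to the default of 600
    checkpoint="./weights/hawor/checkpoints/hawor.ckpt",        # assumed location
    infiller_weight="./weights/hawor/checkpoints/infiller.pt",  # assumed location
)

# 1) Extract frames with ffmpeg and detect/track hands; results are cached
#    under <video>/tracks_<start>_<end>/.
start_idx, end_idx, seq_folder, imgfiles = detect_track_video(args)

# 2) Camera-space MANO estimation per tracklet chunk; also renders hand masks
#    (model_masks.npy) that the SLAM stage uses for masking.
frame_chunks_all, img_focal = hawor_motion_estimation(args, start_idx, end_idx, seq_folder)

# 3) Masked DROID-SLAM plus Metric3D depth to recover a metric-scale camera trajectory.
hawor_slam(args, start_idx, end_idx)

# 4) Lift predictions to world space and infill missing frames with the transformer.
pred_trans, pred_rot, pred_hand_pose, pred_betas, pred_valid = hawor_infiller(
    args, start_idx, end_idx, frame_chunks_all
)
print("world-space results saved to", f"{seq_folder}/world_space_res.pth")
```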