initial commit
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitignore +145 -0
- .readthedocs.yaml +13 -0
- app.py +202 -0
- configs/lmm/lmm.py +75 -0
- configs/lmm/lmm_small_demo.py +22 -0
- examples/angry.m4a +0 -0
- examples/placeholder.m4a +0 -0
- examples/surprise.m4a +0 -0
- mogen/__init__.py +56 -0
- mogen/apis/__init__.py +8 -0
- mogen/apis/test.py +158 -0
- mogen/apis/train.py +161 -0
- mogen/core/__init__.py +0 -0
- mogen/core/distributed_wrapper.py +135 -0
- mogen/core/optimizer/__init__.py +3 -0
- mogen/core/optimizer/builder.py +52 -0
- mogen/datasets/__init__.py +12 -0
- mogen/datasets/base_dataset.py +183 -0
- mogen/datasets/builder.py +149 -0
- mogen/datasets/dataset_wrappers.py +42 -0
- mogen/datasets/human_body_prior/__init__.py +22 -0
- mogen/datasets/human_body_prior/body_model/__init__.py +22 -0
- mogen/datasets/human_body_prior/body_model/body_model.py +281 -0
- mogen/datasets/human_body_prior/body_model/lbs.py +404 -0
- mogen/datasets/human_body_prior/body_model/parts_segm/readme +1 -0
- mogen/datasets/human_body_prior/body_model/rigid_object_model.py +67 -0
- mogen/datasets/human_body_prior/models/__init__.py +22 -0
- mogen/datasets/human_body_prior/models/ik_engine.py +287 -0
- mogen/datasets/human_body_prior/models/model_components.py +41 -0
- mogen/datasets/human_body_prior/models/vposer_model.py +133 -0
- mogen/datasets/human_body_prior/tools/__init__.py +22 -0
- mogen/datasets/human_body_prior/tools/angle_continuous_repres.py +80 -0
- mogen/datasets/human_body_prior/tools/configurations.py +47 -0
- mogen/datasets/human_body_prior/tools/model_loader.py +87 -0
- mogen/datasets/human_body_prior/tools/omni_tools.py +163 -0
- mogen/datasets/human_body_prior/tools/rotation_tools.py +151 -0
- mogen/datasets/human_body_prior/tools/tgm_conversion.py +527 -0
- mogen/datasets/human_body_prior/train/README.md +41 -0
- mogen/datasets/human_body_prior/train/V02_05/V02_05.py +54 -0
- mogen/datasets/human_body_prior/train/V02_05/V02_05.yaml +84 -0
- mogen/datasets/human_body_prior/train/V02_05/__init__.py +22 -0
- mogen/datasets/human_body_prior/train/__init__.py +22 -0
- mogen/datasets/human_body_prior/train/vposer_trainer.py +337 -0
- mogen/datasets/human_body_prior/visualizations/__init__.py +22 -0
- mogen/datasets/human_body_prior/visualizations/training_visualization.py +123 -0
- mogen/datasets/motionverse_dataset.py +828 -0
- mogen/datasets/paramUtil.py +140 -0
- mogen/datasets/pipelines/__init__.py +30 -0
- mogen/datasets/pipelines/compose.py +42 -0
- mogen/datasets/pipelines/formatting.py +135 -0
.gitignore
ADDED
@@ -0,0 +1,145 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
**/*.pyc

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/en/build

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# custom
data
!mmhuman3d/data
# data for pytest moved to http server
# !tests/data
.vscode
.idea
*.pkl
*.pkl.json
*.log.json
work_dirs/
logs/

# Pytorch
*.pth
*.pt


# Visualization
*.mp4
*.png
*.gif
*.jpg
*.obj
*.ply
!demo/resources/*

# Resources as exception
!resources/*

# Loaded/Saved data files
*.npz
*.npy
*.pickle

# MacOS
*DS_Store*
# git
*.orig
.readthedocs.yaml
ADDED
@@ -0,0 +1,13 @@
version: 2

build:
  os: ubuntu-22.04
  tools:
    python: "3.9"

sphinx:
  configuration: docs/en/source/conf.py

python:
  install:
    - requirements: requirements/docs.txt
app.py
ADDED
@@ -0,0 +1,202 @@
import os
import sys
import gradio as gr
import time

os.makedirs("outputs", exist_ok=True)
sys.path.insert(0, '.')

import argparse
import os.path as osp
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmcv.parallel import MMDataParallel
from scipy.ndimage import gaussian_filter
from IPython.display import Image

from mogen.models.utils.imagebind_wrapper import (
    extract_text_feature,
    extract_audio_feature,
    imagebind_huge
)
from mogen.models import build_architecture

from mogen.utils.plot_utils import (
    plot_3d_motion,
    add_audio,
    get_audio_length
)
from mogen.datasets.paramUtil import (
    t2m_body_hand_kinematic_chain,
    t2m_kinematic_chain
)
from mogen.datasets.utils import recover_from_ric
from mogen.datasets.pipelines import RetargetSkeleton


def motion_temporal_filter(motion, sigma=1):
    motion = motion.reshape(motion.shape[0], -1)
    for i in range(motion.shape[1]):
        motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest")
    return motion.reshape(motion.shape[0], -1, 3)

def plot_tomato(data, kinematic_chain, result_path, npy_path, fps, sigma=None):
    joints = recover_from_ric(torch.from_numpy(data).float(), 52).numpy()
    joints = motion_temporal_filter(joints, sigma=2.5)
    joints = rtg_skl({"keypoints3d": joints, "meta_data": {"has_lhnd": True}})["keypoints3d"]
    plot_3d_motion(
        out_path=result_path,
        joints=joints,
        kinematic_chain=kinematic_chain,
        title=None,
        fps=fps)
    if npy_path is not None:
        np.save(npy_path, joints)

def create_lmm():
    config_path = "configs/lmm/lmm_small_demo.py"
    ckpt_path = "pretrained/lmm_small_demo.pth"
    cfg = mmcv.Config.fromfile(config_path)
    model = build_architecture(cfg.model)
    load_checkpoint(model, ckpt_path, map_location='cpu')
    if device == 'cpu':
        model = model.cpu()
    else:
        model = MMDataParallel(model, device_ids=[0])
    model.eval()
    return model

# device = 'cpu'
device = 'cuda'
# os.environ["NO_PROXY"] = os.environ["no_proxy"] = "localhost, 127.0.0.1:7860"
model_lmm = create_lmm()
model_imagebind = imagebind_huge(pretrained=True)
model_imagebind.eval()
model_imagebind.to(device)
rtg_skl = RetargetSkeleton(tgt_skel_file='data/motionverse/statistics/skeleton.npy')

mean_path = "data/mean.npy"
std_path = "data/std.npy"
mean = np.load(mean_path)
std = np.load(std_path)

def show_generation_result(model, text, audio_path, motion_length, result_path):
    fps = 20
    if audio_path is not None:
        motion_length = min(200, int(get_audio_length(audio_path) * fps) + 1)
    motion = torch.zeros(1, motion_length, 669).to(device)
    motion_mask = torch.ones(1, motion_length).to(device)
    motion_mask[0, :motion_length] = 1
    motion_mask = motion_mask.unsqueeze(-1).repeat(1, 1, 10)
    motion_mask[:, :, 9] = 0
    dataset_name = "humanml3d_t2m"
    kinematic_chain = t2m_body_hand_kinematic_chain
    rotation_type = "h3d_rot"
    motion_metas = [{
        'meta_data': dict(framerate=fps, dataset_name=dataset_name, rotation_type=rotation_type)
    }]
    motion_length = torch.Tensor([motion_length]).long().to(device)
    if text is None and audio_path is not None:
        text = "A person is standing and speaking."

    model = model.to(device)
    input = {
        'motion': motion,
        'motion_mask': motion_mask,
        'motion_length': motion_length,
        'motion_metas': motion_metas,
        'num_intervals': 1
    }
    if text is not None:
        text_word_feat, text_seq_feat = \
            extract_text_feature([text], model_imagebind, device)
        assert text_word_feat.shape[0] == 1
        assert text_word_feat.shape[1] == 77
        assert text_word_feat.shape[2] == 1024
        assert text_seq_feat.shape[0] == 1
        assert text_seq_feat.shape[1] == 1024
        input['text_word_feat'] = text_word_feat
        input['text_seq_feat'] = text_seq_feat
        input['text_cond'] = torch.Tensor([1.0] * 1).to(device)
    else:
        input['text_word_feat'] = torch.zeros(1, 77, 1024).to(device)
        input['text_seq_feat'] = torch.zeros(1, 1024)
        input['text_cond'] = torch.Tensor([0] * 1).to(device)
    if audio_path is not None:
        speech_word_feat, speech_seq_feat = \
            extract_audio_feature([audio_path], model_imagebind, device)
        assert speech_word_feat.shape[0] == 1
        assert speech_word_feat.shape[1] == 229
        assert speech_word_feat.shape[2] == 768
        assert speech_seq_feat.shape[0] == 1
        assert speech_seq_feat.shape[1] == 1024
        input['speech_word_feat'] = speech_word_feat
        input['speech_seq_feat'] = speech_seq_feat
        input['speech_cond'] = torch.Tensor([1.0] * 1).to(device)
    else:
        input['speech_word_feat'] = torch.zeros(1, 229, 768).to(device)
        input['speech_seq_feat'] = torch.zeros(1, 1024)
        input['speech_cond'] = torch.Tensor([0] * 1).to(device)

    all_pred_motion = []
    with torch.no_grad():
        input['inference_kwargs'] = {}
        output = model(**input)[0]['pred_motion'][:motion_length]
        pred_motion = output.cpu().detach().numpy()
        pred_motion = pred_motion * std + mean

    plot_tomato(pred_motion, kinematic_chain, result_path, None, fps, 2)

    if audio_path is not None:
        add_audio(result_path, [audio_path])

def generate(prompt, audio_path, length):
    if not os.path.exists("outputs"):
        os.mkdir("outputs")
    result_path = "outputs/" + str(int(time.time())) + ".mp4"
    print(audio_path)
    if audio_path.endswith("placeholder.wav"):
        audio_path = None
    if len(prompt) == 0:
        prompt = None
    show_generation_result(model_lmm, prompt, audio_path, length, result_path)
    return result_path

input_audio = gr.Audio(
    type='filepath',
    format='wav',
    label="Audio (1-10s, overwrite motion length):",
    show_label=True,
    sources=["upload", "microphone"],
    min_length=1,
    max_length=10,
    waveform_options=gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    ),
)

input_text = gr.Textbox(
    label="Text prompt:"
)

demo = gr.Interface(
    fn=generate,
    inputs=[input_text, input_audio, gr.Slider(20, 200, value=60, label="Motion length (fps 20):")],
    outputs=gr.Video(label="Video:"),
    examples=[
        ["A person walks in a circle.", "examples/placeholder.m4a", 120],
        ["A person jumps forward.", "examples/placeholder.m4a", 100],
        ["A person is stretching arms.", "examples/placeholder.m4a", 80],
        ["", "examples/surprise.m4a", 200],
        ["", "examples/angry.m4a", 200],
    ],
    title="LMM: Large Motion Model for Unified Multi-Modal Motion Generation",
    description="\nThis is an interactive demo for LMM. For more information, feel free to visit our project page (https://github.com/mingyuan-zhang/LMM).")

demo.queue()
demo.launch()
configs/lmm/lmm.py
ADDED
@@ -0,0 +1,75 @@
dataset_names = [
    'all',
    'amass_mocap', 'motionx_mocap', 'humanact12_mocap', 'uestc_mocap', 'ntu_mocap', 'aist_mocap',
    'beat_mocap', 'tedg_mocap', 'tedex_mocap', 's2g3d_mocap', 'h36m_mocap', 'mpi_mocap',

    'humanml3d_t2m', 'kitml_t2m', 'babel_t2m', 'motionx_t2m',
    'humanact12_t2m', 'uestc_t2m', 'ntu_t2m',

    'aist_m2d',
    'beat_s2g', 'tedg_s2g', 'tedex_s2g', 's2g3d_s2g',

    'h36m_v2m', 'mpi_v2m'
]
num_datasets = len(dataset_names)
# model settings
model = dict(
    type='UnifiedMotionDiffusion',
    model=dict(
        type='LargeMotionModel',
        input_feats=669,
        max_seq_len=200,
        num_parts=10,
        latent_part_dim=64,
        time_embed_dim=2048,
        dataset_names=dataset_names,
        num_layers=4,
        num_cond_layers=2,
        num_datasets=num_datasets,
        dropout=0,
        ca_block_cfg=dict(
            type='ArtAttention',
            num_experts=16,
            topk=4,
            gate_type='cosine_top',
            gate_noise=1.0,
            num_datasets=num_datasets,
            has_text=True,
            has_music=True,
            has_speech=True,
            has_video=True
        ),
        text_input_dim=1024,
        music_input_dim=768,
        speech_input_dim=768,
        video_input_dim=1024,
        guidance_cfg=dict(
            all=dict(type='linear', scale=5.5),
        ),
        moe_route_loss_weight=10.0,
        template_kl_loss_weight=0.0001,
        use_pos_embedding=False,
        cond_drop_rate=0.1
    ),
    loss_recon=dict(
        type='KinematicLoss', loss_type='mse', loss_weight=[20], reduction='none'),
    train_repeat=1,
    diffusion_train=dict(
        beta_scheduler='linear',
        diffusion_steps=1000,
        model_mean_type='start_x',
        model_var_type='fixed_large',
    ),
    diffusion_test_dict=dict(
        base=dict(
            beta_scheduler='linear',
            diffusion_steps=1000,
            model_mean_type='start_x',
            model_var_type='fixed_large',
        ),
        all='15,15,8,6,6'
    ),
    inference_type='ddim',
    loss_reduction='batch',
    loss_weight='data/motionverse/statistics/loss_weight.npy'
)
configs/lmm/lmm_small_demo.py
ADDED
@@ -0,0 +1,22 @@
_base_ = ['lmm.py']

model = dict(
    model=dict(
        latent_part_dim=64,
        num_layers=8,
        num_cond_layers=2,
        dropout=0.1,
        ca_block_cfg=dict(
            num_experts=16,
            topk=4
        ),
        guidance_cfg=dict(
            humanml3d_t2m=dict(type='linear', scale=10.5),
        ),
    ),
    diffusion_test_dict=dict(
        humanml3d_t2m='15,15,8,6,6',
    ),
)

data = dict(samples_per_gpu=32)
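
A short sketch of how the `_base_` inheritance in this demo config resolves, assuming mmcv's standard config machinery (which `app.py` in this commit already uses through `mmcv.Config.fromfile`). The printed values are only illustrative of the merge behaviour: keys set here override `lmm.py`, keys left out are inherited.

import mmcv

# Loading the demo config merges its overrides onto configs/lmm/lmm.py.
# Nested dicts are merged recursively, so num_layers becomes 8 while
# untouched keys (input_feats, max_seq_len, ...) still come from lmm.py.
cfg = mmcv.Config.fromfile('configs/lmm/lmm_small_demo.py')
print(cfg.model.model.num_layers)    # 8   (overridden in this file)
print(cfg.model.model.input_feats)   # 669 (inherited from lmm.py)
print(cfg.data.samples_per_gpu)      # 32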
examples/angry.m4a
ADDED
Binary file (108 kB).
examples/placeholder.m4a
ADDED
Binary file (30 kB).
examples/surprise.m4a
ADDED
Binary file (89.4 kB).
mogen/__init__.py
ADDED
@@ -0,0 +1,56 @@
import warnings

import mmcv
from packaging.version import parse

from .version import __version__


def digit_version(version_str: str, length: int = 4):
    """Convert a version string into a tuple of integers.
    This method is usually used for comparing two versions. For pre-release
    versions: alpha < beta < rc.
    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.
    Returns:
        tuple[int]: The version info in digits (integers).
    """
    version = parse(version_str)
    assert version.release, f'failed to parse version {version_str}'
    release = list(version.release)
    release = release[:length]
    if len(release) < length:
        release = release + [0] * (length - len(release))
    if version.is_prerelease:
        mapping = {'a': -3, 'b': -2, 'rc': -1}
        val = -4
        # version.pre can be None
        if version.pre:
            if version.pre[0] not in mapping:
                warnings.warn(f'unknown prerelease version {version.pre[0]}, '
                              'version checking may go wrong')
            else:
                val = mapping[version.pre[0]]
            release.extend([val, version.pre[-1]])
        else:
            release.extend([val, 0])

    elif version.is_postrelease:
        release.extend([1, version.post])
    else:
        release.extend([0, 0])
    return tuple(release)


mmcv_minimum_version = '1.4.2'
mmcv_maximum_version = '1.9.0'
mmcv_version = digit_version(mmcv.__version__)


assert (mmcv_version >= digit_version(mmcv_minimum_version)
        and mmcv_version <= digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'

__all__ = ['__version__', 'digit_version']
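
A quick illustration of the tuples `digit_version` produces and why the mmcv range check above orders pre-releases correctly; the concrete version strings are arbitrary examples, not requirements of this commit.

from mogen import digit_version

# Final releases are padded to `length` digits and get a (0, 0) suffix.
digit_version('1.7.0')      # (1, 7, 0, 0, 0, 0)

# Pre-releases append a negative marker so that alpha < beta < rc < final.
digit_version('1.7.0b2')    # (1, 7, 0, 0, -2, 2)
digit_version('1.7.0rc1')   # (1, 7, 0, 0, -1, 1)

assert digit_version('1.7.0rc1') < digit_version('1.7.0')
# The import-time check in mogen/__init__.py reduces to this comparison:
assert digit_version('1.4.2') <= digit_version('1.7.0') <= digit_version('1.9.0')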
mogen/apis/__init__.py
ADDED
@@ -0,0 +1,8 @@
from mogen.apis.test import (collect_results_cpu, collect_results_gpu,
                             multi_gpu_test, single_gpu_test)
from mogen.apis.train import set_random_seed, train_model

__all__ = [
    'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
    'single_gpu_test', 'set_random_seed', 'train_model'
]
mogen/apis/test.py
ADDED
@@ -0,0 +1,158 @@
import os.path as osp
import pickle
import shutil
import tempfile
import time

import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def single_gpu_test(model, data_loader):
    """Test with single gpu."""
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)

        batch_size = len(result)
        if isinstance(result, list):
            results.extend(result)
        else:
            results.append(result)

        batch_size = data['motion'].size(0)
        for _ in range(batch_size):
            prog_bar.update()
    return results


def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.
    This method tests model with multiple gpus and collects the results
    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
    it encodes results to gpu tensors and use gpu communication for results
    collection. On cpu mode it saves the results on different gpus to 'tmpdir'
    and collects them by the rank 0 worker.
    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.
    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        # Check if tmpdir is valid for cpu_collect
        if (not gpu_collect) and (tmpdir is not None and osp.exists(tmpdir)):
            raise OSError((f'The tmpdir {tmpdir} already exists.',
                           ' Since tmpdir will be deleted after testing,',
                           ' please make sure you specify an empty one.'))
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)
            if isinstance(result, list):
                results.extend(result)
            else:
                results.append(result)

        if rank == 0:
            batch_size = data['motion'].size(0)
            for _ in range(batch_size * world_size):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results


def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect results in cpu."""
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(bytearray(tmpdir.encode()),
                                  dtype=torch.uint8,
                                  device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_result = mmcv.load(part_file)
            part_list.append(part_result)
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


def collect_results_gpu(result_part, size):
    """Collect results in gpu."""
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(bytearray(pickle.dumps(result_part)),
                               dtype=torch.uint8,
                               device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        ordered_results = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
            ordered_results.extend(part_result)
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
mogen/apis/train.py
ADDED
@@ -0,0 +1,161 @@
import random
import warnings

import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (
    DistSamplerSeedHook,
    Fp16OptimizerHook,
    OptimizerHook,
    GradientCumulativeFp16OptimizerHook,
    GradientCumulativeOptimizerHook,
    build_runner)

from mogen.core.distributed_wrapper import DistributedDataParallelWrapper
from mogen.core.evaluation import DistEvalHook, EvalHook
from mogen.core.optimizer import build_optimizers
from mogen.datasets import build_dataloader, build_dataset
from mogen.utils import get_root_logger


def set_random_seed(seed, deterministic=False):
    """Set random seed.
    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                device='cuda',
                meta=None):
    """Main api for training model."""
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            round_up=True,
            sampler_cfg=cfg.data.sampler_cfg,
            batch_sampler_cfg=cfg.data.batch_sampler_cfg,
            seed=cfg.seed) for ds in dataset
    ]

    # determine whether use adversarial training precess or not
    use_adversarial_train = cfg.get('use_adversarial_train', False)

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', True)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        if use_adversarial_train:
            # Use DistributedDataParallelWrapper for adversarial training
            model = DistributedDataParallelWrapper(
                model,
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        if device == 'cuda':
            model = MMDataParallel(model.cuda(cfg.gpu_ids[0]),
                                   device_ids=cfg.gpu_ids)
        elif device == 'cpu':
            model = model.cpu()
        else:
            raise ValueError(F'unsupported device name {device}.')

    # build runner
    optimizer = build_optimizers(model, cfg.optimizer)

    if cfg.get('runner') is None:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner = build_runner(cfg.runner,
                          default_args=dict(model=model,
                                            batch_processor=None,
                                            optimizer=optimizer,
                                            work_dir=cfg.work_dir,
                                            logger=logger,
                                            meta=meta))

    # an ugly walkaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    if use_adversarial_train:
        # The optimizer step process is included in the train_step function
        # of the model, so the runner should NOT include optimizer hook.
        optimizer_config = None
    else:
        if distributed and 'type' not in cfg.optimizer_config:
            optimizer_config = OptimizerHook(**cfg.optimizer_config)
        else:
            optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config,
                                   optimizer_config,
                                   cfg.checkpoint_config,
                                   cfg.log_config,
                                   cfg.get('momentum_config', None),
                                   custom_hooks_config=cfg.get(
                                       'custom_hooks', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=cfg.data.samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            round_up=True)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
mogen/core/__init__.py
ADDED
File without changes
mogen/core/distributed_wrapper.py
ADDED
@@ -0,0 +1,135 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from torch.cuda._utils import _get_device_index


@MODULE_WRAPPERS.register_module()
class DistributedDataParallelWrapper(nn.Module):
    """A DistributedDataParallel wrapper for models in 3D mesh estimation task.

    In some pieplines, there is a need to wrap different modules in
    the models with separate DistributedDataParallel. Otherwise, it will cause
    errors for GAN training.
    More specific, the GAN model, usually has two sub-modules:
    generator and discriminator. If we wrap both of them in one
    standard DistributedDataParallel, it will cause errors during training,
    because when we update the parameters of the generator (or discriminator),
    the parameters of the discriminator (or generator) is not updated, which is
    not allowed for DistributedDataParallel.
    So we design this wrapper to separately wrap DistributedDataParallel
    for generator and discriminator.
    In this wrapper, we perform two operations:
    1. Wrap the modules in the models with separate MMDistributedDataParallel.
        Note that only modules with parameters will be wrapped.
    2. Do scatter operation for 'forward', 'train_step' and 'val_step'.
    Note that the arguments of this wrapper is the same as those in
    `torch.nn.parallel.distributed.DistributedDataParallel`.
    Args:
        module (nn.Module): Module that needs to be wrapped.
        device_ids (list[int | `torch.device`]): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
        dim (int, optional): Same as that in the official scatter function in
            pytorch. Defaults to 0.
        broadcast_buffers (bool): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Defaults to False.
        find_unused_parameters (bool, optional): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Traverse the autograd graph of all tensors contained in returned
            value of the wrapped module's forward function. Defaults to False.
        kwargs (dict): Other arguments used in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
    """

    def __init__(self,
                 module,
                 device_ids,
                 dim=0,
                 broadcast_buffers=False,
                 find_unused_parameters=False,
                 **kwargs):
        super().__init__()
        assert len(device_ids) == 1, (
            'Currently, DistributedDataParallelWrapper only supports one'
            'single CUDA device for each process.'
            f'The length of device_ids must be 1, but got {len(device_ids)}.')
        self.module = module
        self.dim = dim
        self.to_ddp(device_ids=device_ids,
                    dim=dim,
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
        self.output_device = _get_device_index(device_ids[0], True)

    def to_ddp(self, device_ids, dim, broadcast_buffers,
               find_unused_parameters, **kwargs):
        """Wrap models with separate MMDistributedDataParallel.

        It only wraps the modules with parameters.
        """
        for name, module in self.module._modules.items():
            if next(module.parameters(), None) is None:
                module = module.cuda()
            elif all(not p.requires_grad for p in module.parameters()):
                module = module.cuda()
            else:
                module = MMDistributedDataParallel(
                    module.cuda(),
                    device_ids=device_ids,
                    dim=dim,
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
            self.module._modules[name] = module

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter function.

        Args:
            inputs (Tensor): Input Tensor.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
            device_ids (int): Device id.
        """
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """Forward function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])

    def train_step(self, *inputs, **kwargs):
        """Train step function.

        Args:
            inputs (Tensor): Input Tensor.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.train_step(*inputs[0], **kwargs[0])
        return output

    def val_step(self, *inputs, **kwargs):
        """Validation step function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for ``scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.val_step(*inputs[0], **kwargs[0])
        return output
mogen/core/optimizer/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .builder import OPTIMIZERS, build_optimizers

__all__ = ['build_optimizers', 'OPTIMIZERS']
mogen/core/optimizer/builder.py
ADDED
@@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import build_optimizer
from mmcv.utils import Registry

OPTIMIZERS = Registry('optimizers')


def build_optimizers(model, cfgs):
    """Build multiple optimizers from configs. If `cfgs` contains several dicts
    for optimizers, then a dict for each constructed optimizers will be
    returned. If `cfgs` only contains one optimizer config, the constructed
    optimizer itself will be returned. For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        cfgs (dict): The config dict of the optimizer.

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizers.
    """
    optimizers = {}
    if hasattr(model, 'module'):
        model = model.module
    # determine whether 'cfgs' has several dicts for optimizers
    if all(isinstance(v, dict) for v in cfgs.values()):
        for key, cfg in cfgs.items():
            cfg_ = cfg.copy()
            module = getattr(model, key)
            optimizers[key] = build_optimizer(module, cfg_)
        return optimizers

    return build_optimizer(model, cfgs)
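
A minimal sketch of the two config shapes the docstring above describes, assuming only mmcv's stock `build_optimizer` behaviour; the toy `GanLikeModel` and its submodule names are hypothetical and exist purely to mirror the `model1`/`model2` example.

import torch.nn as nn
from mogen.core.optimizer import build_optimizers

# Hypothetical model with two named submodules, matching the docstring example.
class GanLikeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = nn.Linear(8, 8)
        self.model2 = nn.Linear(8, 8)

model = GanLikeModel()

# Several optimizer configs -> a dict of optimizers keyed by submodule name.
multi_cfg = dict(
    model1=dict(type='SGD', lr=1e-3),
    model2=dict(type='Adam', lr=1e-4))
optims = build_optimizers(model, multi_cfg)  # {'model1': SGD(...), 'model2': Adam(...)}

# A single optimizer config -> one optimizer over the whole model.
single_cfg = dict(type='SGD', lr=1e-3, momentum=0.9)
optim = build_optimizers(model, single_cfg)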
mogen/datasets/__init__.py
ADDED
@@ -0,0 +1,12 @@
from .base_dataset import BaseMotionDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .pipelines import Compose
from .samplers import DistributedSampler
from .text_motion_dataset import TextMotionDataset
from .motionverse_dataset import MotionVerse

__all__ = [
    'BaseMotionDataset', 'TextMotionDataset', 'DATASETS', 'PIPELINES',
    'build_dataloader', 'build_dataset', 'Compose', 'DistributedSampler',
    'MotionVerse'
]
mogen/datasets/base_dataset.py
ADDED
@@ -0,0 +1,183 @@
import copy
import os
import json
from abc import abstractmethod
from typing import Optional, Union, List, Dict

import numpy as np
import torch
from torch.utils.data import Dataset

# from mogen.core.evaluation import build_evaluator
from mogen.models.builder import build_submodule
from .builder import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class BaseMotionDataset(Dataset):
    """
    Base class for motion datasets.

    Args:
        data_prefix (str): The prefix of the data path.
        pipeline (list): A list of dicts, where each element represents an operation
            defined in `mogen.datasets.pipelines`.
        dataset_name (Optional[Union[str, None]]): The name of the dataset. Used to
            identify the type of evaluation metric.
        fixed_length (Optional[Union[int, None]]): The fixed length of the dataset for
            iteration. If None, the dataset length is based on the number
            of annotations.
        ann_file (Optional[Union[str, None]]): The annotation file. If it is a string,
            it is expected to be read from the file. If None, it will be
            read from `data_prefix`.
        motion_dir (Optional[Union[str, None]]): The directory containing motion data.
        eval_cfg (Optional[Union[dict, None]]): Configuration for evaluation metrics.
        test_mode (Optional[bool]): Whether the dataset is in test mode. Default is False.

    Attributes:
        data_infos (list): Loaded dataset annotations.
        evaluators (list): List of evaluation objects.
        eval_indexes (np.ndarray): Array of indices used for evaluation.
        evaluator_model (torch.nn.Module): Model used for evaluation.
        pipeline (Compose): Data processing pipeline.
    """

    def __init__(self,
                 data_prefix: str,
                 pipeline: List[Dict],
                 dataset_name: Optional[Union[str, None]] = None,
                 fixed_length: Optional[Union[int, None]] = None,
                 ann_file: Optional[Union[str, None]] = None,
                 motion_dir: Optional[Union[str, None]] = None,
                 eval_cfg: Optional[Union[dict, None]] = None,
                 test_mode: Optional[bool] = False):
        super(BaseMotionDataset, self).__init__()

        self.data_prefix = data_prefix
        self.pipeline = Compose(pipeline)
        self.dataset_name = dataset_name
        self.fixed_length = fixed_length
        self.ann_file = os.path.join(data_prefix, 'datasets', dataset_name, ann_file)
        self.motion_dir = os.path.join(data_prefix, 'datasets', dataset_name, motion_dir)
        self.eval_cfg = copy.deepcopy(eval_cfg)
        self.test_mode = test_mode

        self.load_annotations()
        if self.test_mode:
            self.prepare_evaluation()

    @abstractmethod
    def load_anno(self, name: str) -> dict:
        """
        Abstract method to load a single annotation.

        Args:
            name (str): Name or identifier of the annotation to load.

        Returns:
            dict: Loaded annotation as a dictionary.
        """
        pass

    def load_annotations(self):
        """Load annotations from `ann_file` to `data_infos`."""
        self.data_infos = []
        idx = 0
        for line in open(self.ann_file, 'r').readlines():
            line = line.strip()
            self.data_infos.append(self.load_anno(idx, line))
            idx += 1

    def prepare_data(self, idx: int) -> dict:
        """
        Prepare raw data for the given index.

        Args:
            idx (int): Index of the data to prepare.

        Returns:
            dict: Processed data for the given index.
        """
        results = copy.deepcopy(self.data_infos[idx])
        results['dataset_name'] = self.dataset_name
        results['sample_idx'] = idx
        return self.pipeline(results)

    def __len__(self) -> int:
        """Return the length of the current dataset.

        Returns:
            int: Length of the dataset.
        """
        if self.test_mode:
            return len(self.eval_indexes)
        elif self.fixed_length is not None:
            return self.fixed_length
        return len(self.data_infos)

    def __getitem__(self, idx: int) -> dict:
        """
        Prepare data for the given index.

        Args:
            idx (int): Index of the data.

        Returns:
            dict: Data for the specified index.
        """
        if self.test_mode:
            idx = self.eval_indexes[idx]
        elif self.fixed_length is not None:
            idx = idx % len(self.data_infos)
        elif self.balanced_sampling:
            cid = np.random.randint(0, len(self.category_list))
            idx = np.random.randint(0, len(self.category_list[cid]))
            idx = self.category_list[cid][idx]
        return self.prepare_data(idx)

    def prepare_evaluation(self):
        """Prepare evaluation settings, including evaluators and evaluation indices."""
        self.evaluators = []
        self.eval_indexes = []
        self.evaluator_model = build_submodule(self.eval_cfg.get('evaluator_model', None))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.evaluator_model = self.evaluator_model.to(device)
        self.evaluator_model.eval()
        self.eval_cfg['evaluator_model'] = self.evaluator_model

        for _ in range(self.eval_cfg['replication_times']):
            eval_indexes = np.arange(len(self.data_infos))
            if self.eval_cfg.get('shuffle_indexes', False):
                np.random.shuffle(eval_indexes)
            self.eval_indexes.append(eval_indexes)

        for metric in self.eval_cfg['metrics']:
            evaluator, self.eval_indexes = build_evaluator(
                metric, self.eval_cfg, len(self.data_infos), self.eval_indexes)
            self.evaluators.append(evaluator)

        self.eval_indexes = np.concatenate(self.eval_indexes)

    def evaluate(self, results: List[dict], work_dir: str, logger=None) -> dict:
        """
        Evaluate the model performance based on the results.

        Args:
            results (list): A list of result dictionaries.
            work_dir (str): Directory where evaluation logs will be stored.
            logger: Logger object to record evaluation results (optional).

        Returns:
            dict: Dictionary containing evaluation metrics.
        """
        metrics = {}
        for evaluator in self.evaluators:
            metrics.update(evaluator.evaluate(results))
        if logger is not None:
            logger.info(metrics)
        eval_output = os.path.join(work_dir, 'eval_results.log')
        with open(eval_output, 'w') as f:
            for k, v in metrics.items():
                f.write(k + ': ' + str(v) + '\n')
        return metrics
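
A minimal subclass sketch of how `BaseMotionDataset` expects `load_anno` to be provided; the `ToyMotionDataset` class and its annotation format are hypothetical. Note that `load_annotations` calls `self.load_anno(idx, line)` with both the index and the stripped annotation line, so a concrete override takes two arguments even though the abstract stub declares one.

import numpy as np

from mogen.datasets import DATASETS
from mogen.datasets.base_dataset import BaseMotionDataset


@DATASETS.register_module()
class ToyMotionDataset(BaseMotionDataset):
    # Hypothetical subclass: each line of ann_file is assumed to be the
    # basename of a .npy motion file under self.motion_dir.
    def load_anno(self, idx, name):
        motion_path = f"{self.motion_dir}/{name}.npy"
        motion = np.load(motion_path)
        return {'motion': motion, 'motion_length': motion.shape[0]}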
mogen/datasets/builder.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import platform
|
2 |
+
import random
|
3 |
+
from functools import partial
from typing import Optional, Union

import numpy as np
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from .samplers import (
    DistributedSampler,
    DistributedWeightedRandomSampler,
    MonoTaskBatchSampler
)

if platform.system() != 'Windows':
    # https://github.com/pytorch/pytorch/issues/973
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    base_soft_limit = rlimit[0]
    hard_limit = rlimit[1]
    soft_limit = min(max(4096, base_soft_limit), hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))

DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')


def build_dataset(cfg: Union[dict, list, tuple],
                  default_args: Optional[dict] = None):
    """Build dataset by the given config."""
    from .dataset_wrappers import ConcatDataset, RepeatDataset
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(build_dataset(cfg['dataset'], default_args),
                                cfg['times'])
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset


def build_dataloader(dataset: Dataset,
                     samples_per_gpu: int,
                     workers_per_gpu: int,
                     num_gpus: Optional[int] = 1,
                     dist: Optional[bool] = True,
                     shuffle: Optional[bool] = True,
                     round_up: Optional[bool] = True,
                     seed: Optional[int] = None,
                     sampler_cfg: Optional[dict] = None,
                     batch_sampler_cfg: Optional[dict] = None,
                     persistent_workers: Optional[bool] = True,
                     **kwargs):
    """Build PyTorch DataLoader.
    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.
    Args:
        dataset (:obj:`Dataset`): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int, optional): Number of GPUs. Only used in non-distributed
            training.
        dist (bool, optional): Distributed training/test or not. Default: True.
        shuffle (bool, optional): Whether to shuffle the data at every epoch.
            Default: True.
        round_up (bool, optional): Whether to round up the length of dataset by
            adding extra samples to make it evenly divisible. Default: True.
        seed (int, optional): Seed passed to ``worker_init_fn`` so that each
            worker gets a distinct, reproducible random seed. Default: None.
        sampler_cfg (dict, optional): Config of the distributed sampler;
            ``weighted_sample=True`` selects
            ``DistributedWeightedRandomSampler``. Default: None.
        batch_sampler_cfg (dict, optional): Config of the batch sampler. Only
            ``MonoTaskBatchSampler`` (which requires ``num_tasks``) is
            supported. Default: None.
        persistent_workers (bool, optional): Passed through to the
            ``DataLoader``. Default: True.
        kwargs: any keyword argument to be used to initialize DataLoader
    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    if dist:
        weighted_sample = False
        if sampler_cfg is not None:
            weighted_sample = sampler_cfg.get('weighted_sample', False)
        if weighted_sample:
            sampler_cls = DistributedWeightedRandomSampler
        else:
            sampler_cls = DistributedSampler
        sampler = sampler_cls(
            dataset,
            world_size,
            rank,
            shuffle=shuffle,
            round_up=round_up
        )
        shuffle = False
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    if batch_sampler_cfg is not None:
        type_name = batch_sampler_cfg['type']
        assert type_name == 'MonoTaskBatchSampler'
        batch_sampler = MonoTaskBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            num_tasks=batch_sampler_cfg['num_tasks']
        )
        data_loader = DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=partial(
                collate, samples_per_gpu=samples_per_gpu),
            pin_memory=False,
            shuffle=shuffle,
            worker_init_fn=init_fn,
            persistent_workers=persistent_workers,
            **kwargs)
    else:
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=partial(
                collate, samples_per_gpu=samples_per_gpu),
            pin_memory=False,
            shuffle=shuffle,
            worker_init_fn=init_fn,
            persistent_workers=persistent_workers,
            **kwargs)

    return data_loader


def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
    """Init random seed for each worker."""
    # The seed of each worker equals to
    # num_worker * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
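
For orientation, a minimal usage sketch (not part of the commit) of how `build_dataset` and `build_dataloader` fit together; `MyDataset` and its config keys are hypothetical placeholders, since the registered dataset types live elsewhere in this repository:

    # Hypothetical config: 'MyDataset' stands for any type registered in DATASETS.
    dataset_cfg = dict(
        type='RepeatDataset',
        dataset=dict(type='MyDataset', data_prefix='data/', pipeline=[]),
        times=10)
    dataset = build_dataset(dataset_cfg)
    loader = build_dataloader(
        dataset,
        samples_per_gpu=64,
        workers_per_gpu=4,
        dist=False,        # single-process sketch; a launcher would pass dist=True
        shuffle=True,
        seed=42)
    for batch in loader:
        pass               # one collated batch per iteration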
mogen/datasets/dataset_wrappers.py
ADDED
@@ -0,0 +1,42 @@
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from torch.utils.data.dataset import Dataset

from .builder import DATASETS


@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.
    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
    add `get_cat_ids` function.
    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
    """

    def __init__(self, datasets: list):
        super(ConcatDataset, self).__init__(datasets)


@DATASETS.register_module()
class RepeatDataset(object):
    """A wrapper of repeated dataset.
    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the dataset
    is small. Using RepeatDataset can reduce the data loading time between
    epochs.
    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset: Dataset, times: int):
        self.dataset = dataset
        self.times = times

        self._ori_len = len(self.dataset)

    def __getitem__(self, idx: int):
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        return self.times * self._ori_len
mogen/datasets/human_body_prior/__init__.py
ADDED
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2018.01.02
mogen/datasets/human_body_prior/body_model/__init__.py
ADDED
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2018.12.13
mogen/datasets/human_body_prior/body_model/body_model.py
ADDED
@@ -0,0 +1,281 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2018.12.13

import numpy as np

import torch
import torch.nn as nn

# from smplx.lbs import lbs
from .lbs import lbs
import sys


class BodyModel(nn.Module):

    def __init__(self,
                 bm_fname,
                 num_betas=10,
                 num_dmpls=None, dmpl_fname=None,
                 num_expressions=80,
                 use_posedirs=True,
                 dtype=torch.float32,
                 persistant_buffer=False):

        super(BodyModel, self).__init__()

        '''
        :param bm_fname: path to a SMPL-family model as an npz file
        :param num_betas: number of shape parameters to include.
        :param device: default on gpu
        :param dtype: float precision of the computations
        :return: verts, trans, pose, betas
        '''

        self.dtype = dtype

        # -- Load SMPL params --
        if '.npz' in bm_fname:
            smpl_dict = np.load(bm_fname, encoding='latin1')
        else:
            raise ValueError('bm_fname should be a .npz file')

        # these are kept for later convenient look-up
        self.num_betas = num_betas
        self.num_dmpls = num_dmpls
        self.num_expressions = num_expressions

        njoints = smpl_dict['posedirs'].shape[2] // 3
        self.model_type = {69: 'smpl', 153: 'smplh', 162: 'smplx', 45: 'mano', 105: 'animal_horse', 102: 'animal_dog', }[njoints]

        assert self.model_type in ['smpl', 'smplh', 'smplx', 'mano', 'animal_horse', 'animal_dog'], ValueError(
            'model_type should be in smpl/smplh/smplx/mano.')

        self.use_dmpl = False
        if num_dmpls is not None:
            if dmpl_fname is not None:
                self.use_dmpl = True
            else:
                raise (ValueError('dmpl_fname should be provided when using dmpls!'))

        if self.use_dmpl and self.model_type in ['smplx', 'mano', 'animal_horse', 'animal_dog']: raise (
            NotImplementedError('DMPLs only work with SMPL/SMPLH models for now.'))

        # Mean template vertices
        self.comp_register('init_v_template', torch.tensor(smpl_dict['v_template'][None], dtype=dtype), persistent=persistant_buffer)

        self.comp_register('f', torch.tensor(smpl_dict['f'].astype(np.int32), dtype=torch.int32), persistent=persistant_buffer)

        num_total_betas = smpl_dict['shapedirs'].shape[-1]
        if num_betas < 1:
            num_betas = num_total_betas

        shapedirs = smpl_dict['shapedirs'][:, :, :num_betas]
        self.comp_register('shapedirs', torch.tensor(shapedirs, dtype=dtype), persistent=persistant_buffer)

        if self.model_type == 'smplx':
            if smpl_dict['shapedirs'].shape[-1] > 300:
                begin_shape_id = 300
            else:
                begin_shape_id = 10
                num_expressions = smpl_dict['shapedirs'].shape[-1] - 10

            exprdirs = smpl_dict['shapedirs'][:, :, begin_shape_id:(begin_shape_id + num_expressions)]
            self.comp_register('exprdirs', torch.tensor(exprdirs, dtype=dtype), persistent=persistant_buffer)

            expression = torch.tensor(np.zeros((1, num_expressions)), dtype=dtype)
            self.comp_register('init_expression', expression, persistent=persistant_buffer)

        if self.use_dmpl:
            dmpldirs = np.load(dmpl_fname)['eigvec']

            dmpldirs = dmpldirs[:, :, :num_dmpls]
            self.comp_register('dmpldirs', torch.tensor(dmpldirs, dtype=dtype), persistent=persistant_buffer)

        # Regressor for joint locations given shape - 6890 x 24
        self.comp_register('J_regressor', torch.tensor(smpl_dict['J_regressor'], dtype=dtype), persistent=persistant_buffer)

        # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
        if use_posedirs:
            posedirs = smpl_dict['posedirs']
            posedirs = posedirs.reshape([posedirs.shape[0] * 3, -1]).T
            self.comp_register('posedirs', torch.tensor(posedirs, dtype=dtype), persistent=persistant_buffer)
        else:
            self.posedirs = None

        # indices of parents for each joint
        kintree_table = smpl_dict['kintree_table'].astype(np.int32)
        self.comp_register('kintree_table', torch.tensor(kintree_table, dtype=torch.int32), persistent=persistant_buffer)

        # LBS weights
        # weights = np.repeat(smpl_dict['weights'][np.newaxis], batch_size, axis=0)
        weights = smpl_dict['weights']
        self.comp_register('weights', torch.tensor(weights, dtype=dtype), persistent=persistant_buffer)

        self.comp_register('init_trans', torch.zeros((1, 3), dtype=dtype), persistent=persistant_buffer)
        # self.register_parameter('trans', nn.Parameter(trans, requires_grad=True))

        # root_orient
        # if self.model_type in ['smpl', 'smplh']:
        self.comp_register('init_root_orient', torch.zeros((1, 3), dtype=dtype), persistent=persistant_buffer)

        # pose_body
        if self.model_type in ['smpl', 'smplh', 'smplx']:
            self.comp_register('init_pose_body', torch.zeros((1, 63), dtype=dtype), persistent=persistant_buffer)
        elif self.model_type == 'animal_horse':
            self.comp_register('init_pose_body', torch.zeros((1, 105), dtype=dtype), persistent=persistant_buffer)
        elif self.model_type == 'animal_dog':
            self.comp_register('init_pose_body', torch.zeros((1, 102), dtype=dtype), persistent=persistant_buffer)

        # pose_hand
        if self.model_type in ['smpl']:
            self.comp_register('init_pose_hand', torch.zeros((1, 1*3*2), dtype=dtype), persistent=persistant_buffer)
        elif self.model_type in ['smplh', 'smplx']:
            self.comp_register('init_pose_hand', torch.zeros((1, 15*3*2), dtype=dtype), persistent=persistant_buffer)
        elif self.model_type in ['mano']:
            self.comp_register('init_pose_hand', torch.zeros((1, 15*3), dtype=dtype), persistent=persistant_buffer)

        # face poses
        if self.model_type == 'smplx':
            self.comp_register('init_pose_jaw', torch.zeros((1, 1*3), dtype=dtype), persistent=persistant_buffer)
            self.comp_register('init_pose_eye', torch.zeros((1, 2*3), dtype=dtype), persistent=persistant_buffer)

        self.comp_register('init_betas', torch.zeros((1, num_betas), dtype=dtype), persistent=persistant_buffer)

        if self.use_dmpl:
            self.comp_register('init_dmpls', torch.zeros((1, num_dmpls), dtype=dtype), persistent=persistant_buffer)

    def comp_register(self, name, value, persistent=False):
        if sys.version_info[0] > 2:
            self.register_buffer(name, value, persistent)
        else:
            self.register_buffer(name, value)

    def r(self):
        from human_body_prior.tools.omni_tools import copy2cpu as c2c
        return c2c(self.forward().v)

    def forward(self, root_orient=None, pose_body=None, pose_hand=None, pose_jaw=None, pose_eye=None, betas=None,
                trans=None, dmpls=None, expression=None, v_template=None, joints=None, v_shaped=None, return_dict=False, **kwargs):
        '''

        :param root_orient: Nx3
        :param pose_body:
        :param pose_hand:
        :param pose_jaw:
        :param pose_eye:
        :param kwargs:
        :return:
        '''
        batch_size = 1
        # compute batch_size from any of the provided variables
        for arg in [root_orient, pose_body, pose_hand, pose_jaw, pose_eye, betas, trans, dmpls, expression, v_template, joints]:
            if arg is not None:
                batch_size = arg.shape[0]
                break

        # assert not (v_template is not None and betas is not None), ValueError('vtemplate and betas could not be used jointly.')
        assert self.model_type in ['smpl', 'smplh', 'smplx', 'mano', 'animal_horse', 'animal_dog'], ValueError(
            'model_type should be in smpl/smplh/smplx/mano')
        if root_orient is None: root_orient = self.init_root_orient.expand(batch_size, -1)
        if self.model_type in ['smplh', 'smpl']:
            if pose_body is None: pose_body = self.init_pose_body.expand(batch_size, -1)
            if pose_hand is None: pose_hand = self.init_pose_hand.expand(batch_size, -1)
        elif self.model_type == 'smplx':
            if pose_body is None: pose_body = self.init_pose_body.expand(batch_size, -1)
            if pose_hand is None: pose_hand = self.init_pose_hand.expand(batch_size, -1)
            if pose_jaw is None: pose_jaw = self.init_pose_jaw.expand(batch_size, -1)
            if pose_eye is None: pose_eye = self.init_pose_eye.expand(batch_size, -1)
        elif self.model_type in ['mano', ]:
            if pose_hand is None: pose_hand = self.init_pose_hand.expand(batch_size, -1)
        elif self.model_type in ['animal_horse', 'animal_dog']:
            if pose_body is None: pose_body = self.init_pose_body.expand(batch_size, -1)

        if pose_hand is None and self.model_type not in ['animal_horse', 'animal_dog']: pose_hand = self.init_pose_hand.expand(batch_size, -1)

        if trans is None: trans = self.init_trans.expand(batch_size, -1)
        if v_template is None: v_template = self.init_v_template.expand(batch_size, -1, -1)
        if betas is None: betas = self.init_betas.expand(batch_size, -1)

        if self.model_type in ['smplh', 'smpl']:
            full_pose = torch.cat([root_orient, pose_body, pose_hand], dim=-1)
        elif self.model_type == 'smplx':
            full_pose = torch.cat([root_orient, pose_body, pose_jaw, pose_eye, pose_hand], dim=-1)  # orient:3, body:63, jaw:3, eyel:3, eyer:3, handl, handr
        elif self.model_type in ['mano', ]:
            full_pose = torch.cat([root_orient, pose_hand], dim=-1)
        elif self.model_type in ['animal_horse', 'animal_dog']:
            full_pose = torch.cat([root_orient, pose_body], dim=-1)

        if self.use_dmpl:
            if dmpls is None: dmpls = self.init_dmpls.expand(batch_size, -1)
            shape_components = torch.cat([betas, dmpls], dim=-1)
            shapedirs = torch.cat([self.shapedirs, self.dmpldirs], dim=-1)
        elif self.model_type == 'smplx':
            if expression is None: expression = self.init_expression.expand(batch_size, -1)
            shape_components = torch.cat([betas, expression], dim=-1)
            shapedirs = torch.cat([self.shapedirs, self.exprdirs], dim=-1)
        else:
            shape_components = betas
            shapedirs = self.shapedirs

        verts, Jtr = lbs(betas=shape_components, pose=full_pose, v_template=v_template,
                         shapedirs=shapedirs, posedirs=self.posedirs,
                         J_regressor=self.J_regressor, parents=self.kintree_table[0].long(),
                         lbs_weights=self.weights, joints=joints, v_shaped=v_shaped,
                         dtype=self.dtype)

        Jtr = Jtr + trans.unsqueeze(dim=1)
        verts = verts + trans.unsqueeze(dim=1)

        res = {}
        res['v'] = verts
        res['f'] = self.f
        res['Jtr'] = Jtr  # Todo: ik can be made with vposer
        # res['bStree_table'] = self.kintree_table

        # if self.model_type == 'smpl':
        #     res['pose_body'] = pose_body
        # elif self.model_type == 'smplh':
        #     res['pose_body'] = pose_body
        #     res['pose_hand'] = pose_hand
        # elif self.model_type == 'smplx':
        #     res['pose_body'] = pose_body
        #     res['pose_hand'] = pose_hand
        #     res['pose_jaw'] = pose_jaw
        #     res['pose_eye'] = pose_eye
        # elif self.model_type in ['mano', 'mano']:
        #     res['pose_hand'] = pose_hand
        res['full_pose'] = full_pose

        if not return_dict:
            class result_meta(object):
                pass

            res_class = result_meta()
            for k, v in res.items():
                res_class.__setattr__(k, v)
            res = res_class

        return res
mogen/datasets/human_body_prior/body_model/lbs.py
ADDED
@@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Vassilis Choutas <https://vchoutas.github.io/>
#

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import numpy as np

import torch
import torch.nn.functional as F


def to_tensor(array, dtype=torch.float32):
    if 'torch.tensor' not in str(type(array)):
        return torch.tensor(array, dtype=dtype)
    return array  # already a tensor; return it unchanged instead of falling through to None


class Struct(object):
    def __init__(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)


def to_np(array, dtype=np.float32):
    if 'scipy.sparse' in str(type(array)):
        array = array.todense()
    return np.array(array, dtype=dtype)


def rot_mat_to_euler(rot_mats):
    # Calculates rotation matrix to euler angles
    # Careful for extreme cases of euler angles like [0.0, pi, 0.0]

    sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
                    rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
    return torch.atan2(-rot_mats[:, 2, 0], sy)


def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
                                     dynamic_lmk_b_coords,
                                     neck_kin_chain, dtype=torch.float32):
    ''' Compute the faces, barycentric coordinates for the dynamic landmarks

    To do so, we first compute the rotation of the neck around the y-axis
    and then use a pre-computed look-up table to find the faces and the
    barycentric coordinates that will be used.

    Special thanks to Soubhik Sanyal ([email protected])
    for providing the original TensorFlow implementation and for the LUT.

    Parameters
    ----------
    vertices: torch.tensor BxVx3, dtype = torch.float32
        The tensor of input vertices
    pose: torch.tensor Bx(Jx3), dtype = torch.float32
        The current pose of the body model
    dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
        The look-up table from neck rotation to faces
    dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
        The look-up table from neck rotation to barycentric coordinates
    neck_kin_chain: list
        A python list that contains the indices of the joints that form the
        kinematic chain of the neck.
    dtype: torch.dtype, optional

    Returns
    -------
    dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
        A tensor of size BxL that contains the indices of the faces that
        will be used to compute the current dynamic landmarks.
    dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
        A tensor of size BxLx3 that contains the barycentric coordinates that
        will be used to compute the current dynamic landmarks.
    '''

    batch_size = vertices.shape[0]

    aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
                                 neck_kin_chain)
    rot_mats = batch_rodrigues(
        aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)

    rel_rot_mat = torch.eye(3, device=vertices.device,
                            dtype=dtype).unsqueeze_(dim=0)
    for idx in range(len(neck_kin_chain)):
        rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)

    y_rot_angle = torch.round(
        torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
                    max=39)).to(dtype=torch.long)
    neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
    mask = y_rot_angle.lt(-39).to(dtype=torch.long)
    neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
    y_rot_angle = (neg_mask * neg_vals +
                   (1 - neg_mask) * y_rot_angle)

    dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
                                           0, y_rot_angle)
    dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
                                          0, y_rot_angle)

    return dyn_lmk_faces_idx, dyn_lmk_b_coords


def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
    ''' Calculates landmarks by barycentric interpolation

    Parameters
    ----------
    vertices: torch.tensor BxVx3, dtype = torch.float32
        The tensor of input vertices
    faces: torch.tensor Fx3, dtype = torch.long
        The faces of the mesh
    lmk_faces_idx: torch.tensor L, dtype = torch.long
        The tensor with the indices of the faces used to calculate the
        landmarks.
    lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
        The tensor of barycentric coordinates that are used to interpolate
        the landmarks

    Returns
    -------
    landmarks: torch.tensor BxLx3, dtype = torch.float32
        The coordinates of the landmarks for each mesh in the batch
    '''
    # Extract the indices of the vertices for each face
    # BxLx3
    batch_size, num_verts = vertices.shape[:2]
    device = vertices.device

    lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
        batch_size, -1, 3)

    lmk_faces += torch.arange(
        batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts

    lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
        batch_size, -1, 3, 3)

    landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
    return landmarks


def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
        lbs_weights, joints=None, pose2rot=True, v_shaped=None, dtype=torch.float32):
    ''' Performs Linear Blend Skinning with the given shape and pose parameters

    Parameters
    ----------
    betas : torch.tensor BxNB
        The tensor of shape parameters
    pose : torch.tensor Bx(J + 1) * 3
        The pose parameters in axis-angle format
    v_template : torch.tensor BxVx3
        The template mesh that will be deformed
    shapedirs : torch.tensor 1xNB
        The tensor of PCA shape displacements
    posedirs : torch.tensor Px(V * 3)
        The pose PCA coefficients
    J_regressor : torch.tensor JxV
        The regressor array that is used to calculate the joints from
        the position of the vertices
    parents: torch.tensor J
        The array that describes the kinematic tree for the model
    lbs_weights: torch.tensor N x V x (J + 1)
        The linear blend skinning weights that represent how much the
        rotation matrix of each part affects each vertex
    pose2rot: bool, optional
        Flag on whether to convert the input pose tensor to rotation
        matrices. The default value is True. If False, then the pose tensor
        should already contain rotation matrices and have a size of
        Bx(J + 1)x9
    dtype: torch.dtype, optional

    Returns
    -------
    verts: torch.tensor BxVx3
        The vertices of the mesh after applying the shape and pose
        displacements.
    joints: torch.tensor BxJx3
        The joints of the model
    '''

    batch_size = max(betas.shape[0], pose.shape[0])
    device = betas.device

    # Add shape contribution
    if v_shaped is None:
        v_shaped = v_template + blend_shapes(betas, shapedirs)

    # Get the joints
    # NxJx3 array
    if joints is not None:
        J = joints
    else:
        J = vertices2joints(J_regressor, v_shaped)

    # 3. Add pose blend shapes
    # N x J x 3 x 3
    ident = torch.eye(3, dtype=dtype, device=device)
    if pose2rot:
        rot_mats = batch_rodrigues(
            pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])

        pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
        # (N x P) x (P, V * 3) -> N x V x 3
        pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
    else:
        pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
        rot_mats = pose.view(batch_size, -1, 3, 3)

        pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
                                    posedirs).view(batch_size, -1, 3)

    v_posed = pose_offsets + v_shaped
    # 4. Get the global joint location
    J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)

    # 5. Do skinning:
    # W is N x V x (J + 1)
    W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
    # (N x V x (J + 1)) x (N x (J + 1) x 16)
    num_joints = J_regressor.shape[0]
    T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
        .view(batch_size, -1, 4, 4)

    homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
                               dtype=dtype, device=device)
    v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
    v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))

    verts = v_homo[:, :, :3, 0]

    return verts, J_transformed


def vertices2joints(J_regressor, vertices):
    ''' Calculates the 3D joint locations from the vertices

    Parameters
    ----------
    J_regressor : torch.tensor JxV
        The regressor array that is used to calculate the joints from the
        position of the vertices
    vertices : torch.tensor BxVx3
        The tensor of mesh vertices

    Returns
    -------
    torch.tensor BxJx3
        The location of the joints
    '''

    return torch.einsum('bik,ji->bjk', [vertices, J_regressor])


def blend_shapes(betas, shape_disps):
    ''' Calculates the per vertex displacement due to the blend shapes

    Parameters
    ----------
    betas : torch.tensor Bx(num_betas)
        Blend shape coefficients
    shape_disps: torch.tensor Vx3x(num_betas)
        Blend shapes

    Returns
    -------
    torch.tensor BxVx3
        The per-vertex displacement due to shape deformation
    '''

    # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
    # i.e. Multiply each shape displacement by its corresponding beta and
    # then sum them.

    # print(betas.device, shape_disps.device)
    blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
    return blend_shape


def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
    ''' Calculates the rotation matrices for a batch of rotation vectors
    Parameters
    ----------
    rot_vecs: torch.tensor Nx3
        array of N axis-angle vectors
    Returns
    -------
    R: torch.tensor Nx3x3
        The rotation matrices for the given axis-angle parameters
    '''

    batch_size = rot_vecs.shape[0]
    device = rot_vecs.device

    angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
    rot_dir = rot_vecs / angle

    cos = torch.unsqueeze(torch.cos(angle), dim=1)
    sin = torch.unsqueeze(torch.sin(angle), dim=1)

    # Bx1 arrays
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)
    K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)

    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
        .view((batch_size, 3, 3))

    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
    return rot_mat


def transform_mat(R, t):
    ''' Creates a batch of transformation matrices
    Args:
        - R: Bx3x3 array of a batch of rotation matrices
        - t: Bx3x1 array of a batch of translation vectors
    Returns:
        - T: Bx4x4 Transformation matrix
    '''
    # No padding left or right, only add an extra row
    return torch.cat([F.pad(R, [0, 0, 0, 1]),
                      F.pad(t, [0, 0, 0, 1], value=1)], dim=2)


def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
    """
    Applies a batch of rigid transformations to the joints

    Parameters
    ----------
    rot_mats : torch.tensor BxNx3x3
        Tensor of rotation matrices
    joints : torch.tensor BxNx3
        Locations of joints
    parents : torch.tensor BxN
        The kinematic tree of each object
    dtype : torch.dtype, optional:
        The data type of the created tensors, the default is torch.float32

    Returns
    -------
    posed_joints : torch.tensor BxNx3
        The locations of the joints after applying the pose rotations
    rel_transforms : torch.tensor BxNx4x4
        The relative (with respect to the root joint) rigid transformations
        for all the joints
    """

    joints = torch.unsqueeze(joints, dim=-1)

    rel_joints = joints.clone()
    rel_joints[:, 1:] -= joints[:, parents[1:]]

    transforms_mat = transform_mat(
        rot_mats.reshape(-1, 3, 3),
        rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)

    transform_chain = [transforms_mat[:, 0]]
    for i in range(1, parents.shape[0]):
        # Subtract the joint location at the rest pose
        # No need for rotation, since it's identity when at rest
        curr_res = torch.matmul(transform_chain[parents[i]],
                                transforms_mat[:, i])
        transform_chain.append(curr_res)

    transforms = torch.stack(transform_chain, dim=1)

    # The last column of the transformations contains the posed joints
    posed_joints = transforms[:, :, :3, 3]

    joints_homogen = F.pad(joints, [0, 0, 0, 1])

    rel_transforms = transforms - F.pad(
        torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])

    return posed_joints, rel_transforms
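
To make the rotation convention concrete, a small self-contained check (a sketch, not part of the commit): `batch_rodrigues` turns axis-angle vectors into rotation matrices, so the output should be orthonormal and a rotation of pi/2 about z should map the x-axis onto the y-axis.

    import math
    import torch

    rot_vecs = torch.tensor([[0.0, 0.0, math.pi / 2]])   # 90 degrees about the z-axis
    R = batch_rodrigues(rot_vecs)                        # shape (1, 3, 3)

    x_axis = torch.tensor([1.0, 0.0, 0.0])
    assert torch.allclose(R[0] @ x_axis, torch.tensor([0.0, 1.0, 0.0]), atol=1e-6)
    assert torch.allclose(R[0] @ R[0].T, torch.eye(3), atol=1e-6)   # orthonormal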
mogen/datasets/human_body_prior/body_model/parts_segm/readme
ADDED
@@ -0,0 +1 @@
### Parts segmentation file obtained from https://github.com/vchoutas/torch-mesh-isect#examples and put here for convenience
mogen/datasets/human_body_prior/body_model/rigid_object_model.py
ADDED
@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2018.12.13

import numpy as np

import torch
import torch.nn as nn

# from smplx.lbs import lbs
from human_body_prior.body_model.lbs import lbs
# import trimesh  # don't use this package for loading meshes since it messes up the order of vertices
from psbody.mesh import Mesh
from human_body_prior.body_model.lbs import batch_rodrigues


class RigidObjectModel(nn.Module):

    def __init__(self, plpath, batch_size=1, dtype=torch.float32):
        super(RigidObjectModel, self).__init__()

        trans = torch.tensor(np.zeros((batch_size, 3)), dtype=dtype, requires_grad=True)
        self.register_parameter('trans', nn.Parameter(trans, requires_grad=True))

        root_orient = torch.tensor(np.zeros((batch_size, 3)), dtype=dtype, requires_grad=True)
        self.register_parameter('root_orient', nn.Parameter(root_orient, requires_grad=True))

        mesh = Mesh(filename=plpath)

        self.rigid_v = torch.from_numpy(np.repeat(mesh.v[np.newaxis], batch_size, axis=0)).type(dtype)
        self.f = torch.from_numpy(mesh.f.astype(np.int32))

    def forward(self, root_orient, trans):
        if root_orient is None: root_orient = self.root_orient
        if trans is None: trans = self.trans
        verts = torch.bmm(self.rigid_v, batch_rodrigues(root_orient)) + trans.view(-1, 1, 3)

        res = {}
        res['v'] = verts
        res['f'] = self.f

        class result_meta(object): pass

        res_class = result_meta()
        for k, v in res.items():
            res_class.__setattr__(k, v)
        res = res_class

        return res
mogen/datasets/human_body_prior/models/__init__.py
ADDED
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12
mogen/datasets/human_body_prior/models/ik_engine.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
#
|
3 |
+
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
|
4 |
+
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
|
5 |
+
# Max Planck Institute for Biological Cybernetics. All rights reserved.
|
6 |
+
#
|
7 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
|
8 |
+
# on this computer program. You can only use this computer program if you have closed a license agreement
|
9 |
+
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
|
10 |
+
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
|
11 |
+
# Contact: [email protected]
|
12 |
+
#
|
13 |
+
#
|
14 |
+
# If you use this code in a research publication please consider citing the following:
|
15 |
+
#
|
16 |
+
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
|
17 |
+
#
|
18 |
+
#
|
19 |
+
# Code Developed by:
|
20 |
+
# Nima Ghorbani <https://nghorbani.github.io/>
|
21 |
+
#
|
22 |
+
# 2021.02.12
|
23 |
+
|
24 |
+
from typing import List, Dict
|
25 |
+
|
26 |
+
from psbody.mesh import Mesh
|
27 |
+
from body_visualizer.tools.psbody_mesh_tools import rotateXYZ, points_to_cubes, points_to_spheres
|
28 |
+
|
29 |
+
|
30 |
+
from torch import nn
|
31 |
+
import torch
|
32 |
+
|
33 |
+
from human_body_prior.tools.model_loader import load_model
|
34 |
+
|
35 |
+
import numpy as np
|
36 |
+
|
37 |
+
from body_visualizer.tools.vis_tools import colors
|
38 |
+
from human_body_prior.tools.omni_tools import copy2cpu as c2c
|
39 |
+
from psbody.mesh import MeshViewers
|
40 |
+
|
41 |
+
from human_body_prior.tools.omni_tools import log2file
|
42 |
+
|
43 |
+
from human_body_prior.models.vposer_model import VPoser
|
44 |
+
from human_body_prior.tools.omni_tools import flatten_list
|
45 |
+
|
46 |
+
|
47 |
+
def visualize(points, bm_f, mvs, kpts_colors, verbosity=2, logger=None):
|
48 |
+
from human_body_prior.tools.omni_tools import log2file
|
49 |
+
|
50 |
+
if logger is None: logger = log2file()
|
51 |
+
|
52 |
+
def view(opt_objs, body_v, virtual_markers, opt_it):
|
53 |
+
if verbosity <= 0: return
|
54 |
+
opt_objs_cpu = {k: c2c(v) for k, v in opt_objs.items()}
|
55 |
+
|
56 |
+
total_loss = np.sum([np.sum(v) for k, v in opt_objs_cpu.items()])
|
57 |
+
message = 'it {} -- [total loss = {:.2e}] - {}'.format(opt_it, total_loss, ' | '.join(['%s = %2.2e' % (k, np.sum(v)) for k, v in opt_objs_cpu.items()]))
|
58 |
+
logger(message)
|
59 |
+
if verbosity>1:
|
60 |
+
bs = body_v.shape[0]
|
61 |
+
np.random.seed(100)
|
62 |
+
frame_ids = list(range(bs)) if bs <= len(mvs) else np.random.choice(bs , size=len(mvs), replace=False).tolist()
|
63 |
+
if bs > len(mvs): message += ' -- [frame_ids: {}]'.format(frame_ids)
|
64 |
+
for dispId, fId in enumerate(frame_ids): # check for the number of frames in mvs and show a randomly picked number of frames in body if there is more to show than row*cols available
|
65 |
+
new_body_v = rotateXYZ(body_v[fId], [-90,0,0])
|
66 |
+
|
67 |
+
orig_mrk_mesh = points_to_spheres(rotateXYZ(c2c(points[fId]), [-90,0,0]), radius=0.01, color=kpts_colors)
|
68 |
+
virtual_markers_mesh = points_to_cubes(rotateXYZ(virtual_markers[fId], [-90,0,0]), radius=0.01, color=kpts_colors)
|
69 |
+
new_body_mesh = Mesh(new_body_v, bm_f, vc=colors['grey'])
|
70 |
+
|
71 |
+
# linev = rotateXYZ(np.hstack((c2c(points[fId]), virtual_markers[fId])).reshape((-1, 3)), [-90,0,0])
|
72 |
+
# linee = np.arange(len(linev)).reshape((-1, 2))
|
73 |
+
# ll = Lines(v=linev, e=linee)
|
74 |
+
# ll.vc = (ll.v * 0. + 1) * np.array([0.00, 0.00, 1.00])
|
75 |
+
# mvs[dispId].set_dynamic_lines([ll])
|
76 |
+
|
77 |
+
# orig_mrk_mesh = points_to_spheres(data_pc, radius=0.01, vc=colors['blue'])
|
78 |
+
mvs[dispId].set_dynamic_meshes([orig_mrk_mesh, virtual_markers_mesh])
|
79 |
+
mvs[dispId].set_static_meshes([new_body_mesh])
|
80 |
+
|
81 |
+
mvs[0].set_titlebar(message)
|
82 |
+
# if out_dir is not None: mv.save_snapshot(os.path.join(out_dir, '%05d_it_%.5d.png' %(frame_id, opt_it)))
|
83 |
+
return view
|
84 |
+
|
85 |
+
|
86 |
+
class AdamInClosure():
|
87 |
+
def __init__(self, var_list, lr, max_iter=100, tolerance_change=1e-5):
|
88 |
+
self.optimizer = torch.optim.Adam(var_list, lr)
|
89 |
+
self.max_iter = max_iter
|
90 |
+
self.tolerance_change = tolerance_change
|
91 |
+
|
92 |
+
|
93 |
+
def step(self, closure):
|
94 |
+
prev_loss = None
|
95 |
+
for it in range(self.max_iter):
|
96 |
+
loss = closure()
|
97 |
+
self.optimizer.step()
|
98 |
+
if prev_loss is None:
|
99 |
+
prev_loss = loss
|
100 |
+
continue
|
101 |
+
if torch.isnan(loss):
|
102 |
+
# breakpoint()
|
103 |
+
break
|
104 |
+
if abs(loss - prev_loss) < self.tolerance_change:
|
105 |
+
print('abs(loss - prev_loss) < self.tolerance_change')
|
106 |
+
break
|
107 |
+
|
108 |
+
def zero_grad(self):
|
109 |
+
self.optimizer.zero_grad()
|
110 |
+
|
111 |
+
def ik_fit(optimizer, source_kpts_model, static_vars, vp_model, extra_params={}, on_step=None, gstep=0):
|
112 |
+
|
113 |
+
data_loss = extra_params.get('data_loss', torch.nn.SmoothL1Loss(reduction='mean'))
|
114 |
+
# data_loss =
|
115 |
+
# data_loss = torch.nn.L1Loss(reduction='mean')#change with SmoothL1
|
116 |
+
|
117 |
+
def fit(weights, free_vars):
|
118 |
+
|
119 |
+
fit.gstep += 1
|
120 |
+
optimizer.zero_grad()
|
121 |
+
|
122 |
+
free_vars['pose_body'] = vp_model.decode(free_vars['poZ_body'])['pose_body'].contiguous().view(-1, 63)
|
123 |
+
nonan_mask = torch.isnan(free_vars['poZ_body']).sum(-1) == 0
|
124 |
+
|
125 |
+
opt_objs = {}
|
126 |
+
|
127 |
+
res = source_kpts_model(free_vars)
|
128 |
+
|
129 |
+
opt_objs['data'] = data_loss(res['source_kpts'], static_vars['target_kpts'])
|
130 |
+
|
131 |
+
opt_objs['betas'] = torch.pow(free_vars['betas'][nonan_mask],2).sum()
|
132 |
+
opt_objs['poZ_body'] = torch.pow(free_vars['poZ_body'][nonan_mask],2).sum()
|
133 |
+
|
134 |
+
|
135 |
+
opt_objs = {k: opt_objs[k]*v for k, v in weights.items() if k in opt_objs.keys()}
|
136 |
+
loss_total = torch.sum(torch.stack(list(opt_objs.values())))
|
137 |
+
# breakpoint()
|
138 |
+
|
139 |
+
loss_total.backward()
|
140 |
+
|
141 |
+
if on_step is not None:
|
142 |
+
on_step(opt_objs, c2c(res['body'].v), c2c(res['source_kpts']), fit.gstep)
|
143 |
+
|
144 |
+
fit.free_vars = {k:v for k,v in free_vars.items()}# if k in IK_Engine.fields_to_optimize}
|
145 |
+
# fit.nonan_mask = nonan_mask
|
146 |
+
fit.final_loss = loss_total
|
147 |
+
|
148 |
+
return loss_total
|
149 |
+
|
150 |
+
fit.gstep = gstep
|
151 |
+
fit.final_loss = None
|
152 |
+
fit.free_vars = {}
|
153 |
+
# fit.nonan_mask = None
|
154 |
+
return fit
|
155 |
+
|
156 |
+
class IK_Engine(nn.Module):
|
157 |
+
|
158 |
+
|
159 |
+
def __init__(self,
|
160 |
+
vposer_expr_dir: str,
|
161 |
+
data_loss,
|
162 |
+
optimizer_args: dict={'type':'ADAM'},
|
163 |
+
stepwise_weights: List[Dict]=[{'data': 10., 'poZ_body': .01, 'betas': .5}],
|
164 |
+
display_rc: tuple = (2,1),
|
165 |
+
verbosity: int = 1,
|
166 |
+
logger=None,
|
167 |
+
):
|
168 |
+
'''
|
169 |
+
|
170 |
+
:param vposer_expr_dir: The vposer directory that holds the settings and model snapshot
|
171 |
+
:param data_loss: should be a pytorch callable (source, target) that returns the accumulated loss
|
172 |
+
:param optimizer_args: arguments for optimizers
|
173 |
+
:param stepwise_weights: list of dictionaries. each list element defines weights for one full step of optimization
|
174 |
+
if a weight value is left out, its respective object item will be removed as well. imagine optimizing without data term!
|
175 |
+
:param display_rc: number of row and columns in case verbosity > 1
|
176 |
+
:param verbosity: 0: silent, 1: text, 2: text/visual. running 2 over ssh would need extra work
|
177 |
+
:param logger: an instance of human_body_prior.tools.omni_tools.log2file
|
178 |
+
'''
|
179 |
+
|
180 |
+
|
181 |
+
super(IK_Engine, self).__init__()
        assert isinstance(stepwise_weights, list), ValueError('stepwise_weights should be a list of dictionaries.')
        assert np.all(['data' in l for l in stepwise_weights]), ValueError('The term data should be available in every weight of annealed optimization step: {}'.format(stepwise_weights))

        self.data_loss = torch.nn.SmoothL1Loss(reduction='mean') if data_loss is None else data_loss

        self.stepwise_weights = stepwise_weights
        self.verbosity = verbosity
        self.optimizer_args = optimizer_args

        self.logger = log2file() if logger is None else logger

        if verbosity > 1:
            mvs = MeshViewers(display_rc, keepalive=True)
            self.mvs = flatten_list(mvs)
            self.mvs[0].set_background_color(colors['white'])
        else:
            self.mvs = None

        self.vp_model, _ = load_model(vposer_expr_dir,
                                      model_code=VPoser,
                                      remove_words_in_model_weights='vp_model.',
                                      disable_grad=True)

    def forward(self, source_kpts, target_kpts, initial_body_params={}):
        '''
        source_kpts is a module that, given body parameters, computes source key points that should match the target key points.
        Try to reconstruct the bps signature by optimizing the body_poZ.
        '''
        # if self.rt_ps.verbosity > 0: self.logger('Processing {} frames'.format(points.shape[0]))

        bs = target_kpts.shape[0]

        on_step = visualize(target_kpts,
                            kpts_colors=source_kpts.kpts_colors,
                            bm_f=source_kpts.bm_f,
                            mvs=self.mvs,
                            verbosity=self.verbosity,
                            logger=self.logger)

        comp_device = target_kpts.device
        # comp_device = self.vp_model.named_parameters().__next__()[1].device
        if 'pose_body' not in initial_body_params:
            initial_body_params['pose_body'] = torch.zeros([bs, 63], device=comp_device, dtype=torch.float, requires_grad=False)
        if 'trans' not in initial_body_params:
            initial_body_params['trans'] = torch.zeros([bs, 3], device=comp_device, dtype=torch.float, requires_grad=False)
        if 'betas' not in initial_body_params:
            initial_body_params['betas'] = torch.zeros([bs, 10], device=comp_device, dtype=torch.float, requires_grad=False)
        if 'root_orient' not in initial_body_params:
            initial_body_params['root_orient'] = torch.zeros([bs, 3], device=comp_device, dtype=torch.float, requires_grad=False)

        initial_body_params['poZ_body'] = self.vp_model.encode(initial_body_params['pose_body']).mean

        free_vars = {k: torch.nn.Parameter(v.detach(), requires_grad=True) for k, v in initial_body_params.items() if k in ['betas', 'trans', 'poZ_body', 'root_orient']}
        static_vars = {
            'target_kpts': target_kpts,
            # 'trans': initial_body_params['trans'].detach(),
            # 'betas': initial_body_params['betas'].detach(),
            # 'poZ_body': initial_body_params['poZ_body'].detach()
        }

        if self.optimizer_args['type'].upper() == 'LBFGS':
            optimizer = torch.optim.LBFGS(list(free_vars.values()),
                                          lr=self.optimizer_args.get('lr', 1),
                                          max_iter=self.optimizer_args.get('max_iter', 100),
                                          tolerance_change=self.optimizer_args.get('tolerance_change', 1e-5),
                                          max_eval=self.optimizer_args.get('max_eval', None),
                                          history_size=self.optimizer_args.get('history_size', 100),
                                          line_search_fn='strong_wolfe')

        elif self.optimizer_args['type'].upper() == 'ADAM':
            optimizer = AdamInClosure(list(free_vars.values()),
                                      lr=self.optimizer_args.get('lr', 1e-3),
                                      max_iter=self.optimizer_args.get('max_iter', 100),
                                      tolerance_change=self.optimizer_args.get('tolerance_change', 1e-5),
                                      )
        else:
            raise ValueError('optimizer_type not recognized.')

        gstep = 0
        closure = ik_fit(optimizer,
                         source_kpts_model=source_kpts,
                         static_vars=static_vars,
                         vp_model=self.vp_model,
                         extra_params={'data_loss': self.data_loss},
                         on_step=on_step,
                         gstep=gstep)
        # try:

        for wts in self.stepwise_weights:
            optimizer.step(lambda: closure(wts, free_vars))
            free_vars = closure.free_vars
        # except:
        #
        #     pass

        # if closure.final_loss is None or torch.isnan(closure.final_loss) or torch.any(torch.isnan(free_vars['trans'])):
        #     if self.verbosity > 0:
        #         self.logger('NaN observed in the optimization results. you might want to restart the refinement procedure.')
        #         breakpoint()
        #     return None

        return closure.free_vars  # , closure.nonan_mask
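Note: the annealed fitting loop above repeatedly calls a closure-based optimizer, once per entry of `stepwise_weights`. The snippet below is a minimal, self-contained sketch of that pattern using plain tensors; the weight values, the quadratic prior term, and all variable names are illustrative stand-ins and are not part of the committed code.

```python
import torch

# Hypothetical stand-in for stepwise_weights: every entry must carry a 'data' term.
stepwise_weights = [{'data': 10., 'poZ_body': .01}, {'data': 10., 'poZ_body': .001}]

target = torch.randn(4, 3)                         # stands in for target_kpts
free_var = torch.nn.Parameter(torch.zeros(4, 3))   # stands in for the free body parameters
data_loss = torch.nn.SmoothL1Loss(reduction='mean')

optimizer = torch.optim.LBFGS([free_var], lr=1, max_iter=100, line_search_fn='strong_wolfe')

def make_closure(wts):
    def closure():
        optimizer.zero_grad()
        # weighted sum of loss terms, mirroring how the real closure combines data and prior terms
        loss = wts['data'] * data_loss(free_var, target) + wts['poZ_body'] * free_var.pow(2).mean()
        loss.backward()
        return loss
    return closure

for wts in stepwise_weights:   # annealing: later steps relax the prior weight
    optimizer.step(make_closure(wts))
print(data_loss(free_var, target).item())
```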
mogen/datasets/human_body_prior/models/model_components.py
ADDED
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12

from torch import nn

class View(nn.Module):
    def __init__(self, *args):
        super(View, self).__init__()
        self.shape = args
        self._name = 'reshape'

    def forward(self, x):
        return x.view(self.shape)

class BatchFlatten(nn.Module):
    def __init__(self):
        super(BatchFlatten, self).__init__()
        self._name = 'batch_flatten'

    def forward(self, x):
        return x.view(x.shape[0], -1)
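Note: a small usage sketch of these two helper modules. The import path assumes the `human_body_prior` package is importable, exactly as the modules in this commit import it; the tensor shapes are illustrative.

```python
import torch
from human_body_prior.models.model_components import BatchFlatten, View  # assumes the package is on the path

x = torch.randn(8, 21, 3)             # a batch of 21 joints with 3 values each
flat = BatchFlatten()(x)              # -> torch.Size([8, 63]); keeps the batch dimension
restored = View(-1, 21, 3)(flat)      # -> torch.Size([8, 21, 3]); View stores the target shape at construction
print(flat.shape, restored.shape)
```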
mogen/datasets/human_body_prior/models/vposer_model.py
ADDED
@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12

import numpy as np
import torch
from human_body_prior.models.model_components import BatchFlatten
from human_body_prior.tools.rotation_tools import matrot2aa
from torch import nn
from torch.nn import functional as F


class ContinousRotReprDecoder(nn.Module):
    def __init__(self):
        super(ContinousRotReprDecoder, self).__init__()

    def forward(self, module_input):
        reshaped_input = module_input.view(-1, 3, 2)

        b1 = F.normalize(reshaped_input[:, :, 0], dim=1)

        dot_prod = torch.sum(b1 * reshaped_input[:, :, 1], dim=1, keepdim=True)
        b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=-1)
        b3 = torch.cross(b1, b2, dim=1)

        return torch.stack([b1, b2, b3], dim=-1)


class NormalDistDecoder(nn.Module):
    def __init__(self, num_feat_in, latentD):
        super(NormalDistDecoder, self).__init__()

        self.mu = nn.Linear(num_feat_in, latentD)
        self.logvar = nn.Linear(num_feat_in, latentD)

    def forward(self, Xout):
        return torch.distributions.normal.Normal(self.mu(Xout), F.softplus(self.logvar(Xout)))


class VPoser(nn.Module):
    def __init__(self, model_ps):
        super(VPoser, self).__init__()

        num_neurons, self.latentD = model_ps.model_params.num_neurons, model_ps.model_params.latentD

        self.num_joints = 21
        n_features = self.num_joints * 3

        self.encoder_net = nn.Sequential(
            BatchFlatten(),
            nn.BatchNorm1d(n_features),
            nn.Linear(n_features, num_neurons),
            nn.LeakyReLU(),
            nn.BatchNorm1d(num_neurons),
            nn.Dropout(0.1),
            nn.Linear(num_neurons, num_neurons),
            nn.Linear(num_neurons, num_neurons),
            NormalDistDecoder(num_neurons, self.latentD)
        )

        self.decoder_net = nn.Sequential(
            nn.Linear(self.latentD, num_neurons),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(num_neurons, num_neurons),
            nn.LeakyReLU(),
            nn.Linear(num_neurons, self.num_joints * 6),
            ContinousRotReprDecoder(),
        )

    def encode(self, pose_body):
        '''
        :param pose_body: Nx(num_joints*3) body pose in axis-angle representation
        :return: a torch Normal distribution over the latent space
        '''
        return self.encoder_net(pose_body)

    def decode(self, Zin):
        bs = Zin.shape[0]

        prec = self.decoder_net(Zin)

        return {
            'pose_body': matrot2aa(prec.view(-1, 3, 3)).view(bs, -1, 3),
            'pose_body_matrot': prec.view(bs, -1, 9)
        }


    def forward(self, pose_body):
        '''
        :param pose_body: Nx(num_joints*3) body pose in axis-angle representation
        :return: decoded poses together with the latent distribution statistics
        '''

        q_z = self.encode(pose_body)
        q_z_sample = q_z.rsample()
        decode_results = self.decode(q_z_sample)
        decode_results.update({'poZ_body_mean': q_z.mean, 'poZ_body_std': q_z.scale, 'q_z': q_z})
        return decode_results

    def sample_poses(self, num_poses, seed=None):
        np.random.seed(seed)

        some_weight = [a for a in self.parameters()][0]
        dtype = some_weight.dtype
        device = some_weight.device
        self.eval()
        with torch.no_grad():
            Zgen = torch.tensor(np.random.normal(0., 1., size=(num_poses, self.latentD)), dtype=dtype, device=device)

        return self.decode(Zgen)
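Note: the decoder ends with the 6D continuous rotation representation (Gram-Schmidt on two 3-vectors). The following torch-only sketch checks that `ContinousRotReprDecoder` really produces proper rotation matrices; the import assumes the `human_body_prior` package is importable, and the input is random illustrative data.

```python
import torch
from human_body_prior.models.vposer_model import ContinousRotReprDecoder  # assumes the package is importable

dec = ContinousRotReprDecoder()
six_d = torch.randn(21, 6)                 # one 6D vector per joint
R = dec(six_d)                             # -> (21, 3, 3) rotation matrices via Gram-Schmidt
should_be_identity = torch.matmul(R.transpose(1, 2), R)
print(torch.allclose(should_be_identity, torch.eye(3).expand_as(should_be_identity), atol=1e-4))  # True
print(torch.det(R).mean())                 # ~1.0, i.e. proper rotations
```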
mogen/datasets/human_body_prior/tools/__init__.py
ADDED
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12
mogen/datasets/human_body_prior/tools/angle_continuous_repres.py
ADDED
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12
import torch.nn.functional as F
import torch
from torch import nn

import numpy as np

# numpy implementation of yi zhou's method
def norm(v):
    return v / np.linalg.norm(v)

def gs(M):
    a1 = M[:, 0]
    a2 = M[:, 1]
    b1 = norm(a1)
    b2 = norm(a2 - np.dot(b1, a2) * b1)
    b3 = np.cross(b1, b2)
    return np.vstack([b1, b2, b3]).T

# input sz bszx3x2
def bgs(d6s):

    bsz = d6s.shape[0]
    b1 = F.normalize(d6s[:, :, 0], p=2, dim=1)
    a2 = d6s[:, :, 1]
    c = torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1
    b2 = F.normalize(a2 - c, p=2, dim=1)
    b3 = torch.cross(b1, b2, dim=1)
    return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1)


class geodesic_loss_R(nn.Module):
    def __init__(self, reduction='batchmean'):
        super(geodesic_loss_R, self).__init__()

        self.reduction = reduction
        self.eps = 1e-6

    # batch geodesic loss for rotation matrices
    def bgdR(self, m1, m2):
        batch = m1.shape[0]
        m = torch.bmm(m1, m2.transpose(1, 2))  # batch*3*3

        cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2
        cos = torch.min(cos, m1.new(np.ones(batch)))
        cos = torch.max(cos, m1.new(np.ones(batch)) * -1)

        return torch.acos(cos)

    def forward(self, ypred, ytrue):
        theta = self.bgdR(ypred, ytrue)
        if self.reduction == 'mean':
            return torch.mean(theta)
        if self.reduction == 'batchmean':
            breakpoint()
            return torch.mean(torch.sum(theta, dim=theta.shape[1:]))

        else:
            return theta
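Note: a quick numpy check of the `gs` Gram-Schmidt helper above, showing that two arbitrary (independent) columns are turned into a proper rotation matrix. The import assumes the `human_body_prior` package is importable; the random input is illustrative.

```python
import numpy as np
from human_body_prior.tools.angle_continuous_repres import gs  # assumes the package is importable

M = np.random.randn(3, 2)                  # two arbitrary columns of the 6D representation
R = gs(M)                                  # Gram-Schmidt -> 3x3 rotation matrix
print(np.allclose(R.T @ R, np.eye(3)))     # True: orthonormal columns
print(np.isclose(np.linalg.det(R), 1.0))   # True: proper rotation (det = +1)
```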
mogen/datasets/human_body_prior/tools/configurations.py
ADDED
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12
from dotmap import DotMap
import os
import yaml

def load_config(default_ps_fname=None, **kwargs):
    if isinstance(default_ps_fname, str):
        assert os.path.exists(default_ps_fname), FileNotFoundError(default_ps_fname)
        assert default_ps_fname.lower().endswith('.yaml'), NotImplementedError('Only .yaml files are accepted.')
        default_ps = yaml.safe_load(open(default_ps_fname, 'r'))
    else:
        default_ps = {}

    default_ps.update(kwargs)

    return DotMap(default_ps, _dynamic=False)

def dump_config(data, fname):
    '''
    dump the current configuration to a yaml file
    :param fname:
    :return:
    '''
    with open(fname, 'w') as file:
        yaml.dump(data.toDict(), file)
    return fname
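Note: a minimal round trip through `load_config`/`dump_config`. It assumes `dotmap` and `pyyaml` are installed (as the module requires); the output path and the config keys are illustrative.

```python
from human_body_prior.tools.configurations import load_config, dump_config  # assumes dotmap and pyyaml are installed

# Build a config purely from kwargs (no yaml on disk), then write it out and read it back.
ps = load_config(None, general={'expr_id': 'demo'}, train_parms={'batch_size': 32})
print(ps.general.expr_id, ps.train_parms.batch_size)   # dot access via DotMap

dump_config(ps, '/tmp/demo_config.yaml')               # hypothetical output path
ps2 = load_config('/tmp/demo_config.yaml')
print(ps2.train_parms.batch_size)                      # 32
```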
mogen/datasets/human_body_prior/tools/model_loader.py
ADDED
@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by: Nima Ghorbani <https://www.linkedin.com/in/nghorbani/>
# 2018.01.02

import os, glob
import numpy as np
from human_body_prior.tools.configurations import load_config, dump_config
import os.path as osp

def exprdir2model(expr_dir):

    if not os.path.exists(expr_dir): raise ValueError('Could not find the experiment directory: %s' % expr_dir)

    model_snapshots_dir = osp.join(expr_dir, 'snapshots')
    available_ckpts = sorted(glob.glob(osp.join(model_snapshots_dir, '*.ckpt')), key=osp.getmtime)
    assert len(available_ckpts) > 0, ValueError('No checkpoints found at {}'.format(model_snapshots_dir))
    trained_weigths_fname = available_ckpts[-1]

    model_ps_fname = glob.glob(osp.join('/', '/'.join(trained_weigths_fname.split('/')[:-2]), '*.yaml'))
    if len(model_ps_fname) == 0:
        model_ps_fname = glob.glob(osp.join('/'.join(trained_weigths_fname.split('/')[:-2]), '*.yaml'))

    model_ps_fname = model_ps_fname[0]
    model_ps = load_config(default_ps_fname=model_ps_fname)

    model_ps.logging.best_model_fname = trained_weigths_fname

    return model_ps, trained_weigths_fname


def load_model(expr_dir, model_code=None, remove_words_in_model_weights=None, load_only_ps=False, disable_grad=True, custom_ps=None):
    '''

    :param expr_dir:
    :param model_code: an imported module,
        e.g. from supercap.train.supercap_smpl import SuperCap, then pass SuperCap to this function
    :param load_only_ps: if True, only the persisted training settings are returned and no model is loaded
    :return:
    '''
    import importlib
    import torch

    model_ps, trained_weigths_fname = exprdir2model(expr_dir)
    if load_only_ps: return model_ps
    if custom_ps is not None: model_ps = custom_ps
    assert model_code is not None, ValueError('model_code should be provided')
    model_instance = model_code(model_ps)
    if disable_grad:  # i had to do this. torch.no_grad() couldn't achieve what i was looking for
        for param in model_instance.parameters():
            param.requires_grad = False
    state_dict = torch.load(trained_weigths_fname)['state_dict']
    if remove_words_in_model_weights is not None:
        words = '{}'.format(remove_words_in_model_weights)
        state_dict = {k.replace(words, '') if k.startswith(words) else k: v for k, v in state_dict.items()}

    ## keys that were in the model trained file and not in the current model
    instance_model_keys = list(model_instance.state_dict().keys())
    trained_model_keys = list(state_dict.keys())
    wts_in_model_not_in_file = set(instance_model_keys).difference(set(trained_model_keys))
    ## keys that are in the current model not in the training weights
    wts_in_file_not_in_model = set(trained_model_keys).difference(set(instance_model_keys))
    # assert len(wts_in_model_not_in_file) == 0, ValueError('Some model weights are not present in the pretrained file. {}'.format(wts_in_model_not_in_file))

    state_dict = {k: v for k, v in state_dict.items() if k in instance_model_keys}
    model_instance.load_state_dict(state_dict, strict=False)  # Todo fix the issues so that we can set the strict to true. The body model uses unnecessary registered buffers
    model_instance.eval()

    return model_instance, model_ps

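Note: the core of `load_model` is (1) renaming checkpoint keys by dropping a wrapper prefix such as `'vp_model.'`, and (2) loading only the keys the fresh model instance actually has, non-strictly. The toy torch-only sketch below illustrates that pattern; the model, the fake checkpoint dictionary, and the prefix are all illustrative.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 2))           # a fresh model instance

# Pretend this came from torch.load(ckpt)['state_dict'] of a wrapper that stored the model under 'vp_model.'
ckpt_state = {'vp_model.0.weight': torch.zeros(2, 4),
              'vp_model.0.bias': torch.zeros(2),
              'extra.buffer': torch.zeros(1)}

prefix = 'vp_model.'
state = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in ckpt_state.items()}

# keep only keys the current model actually has, then load non-strictly
state = {k: v for k, v in state.items() if k in model.state_dict()}
model.load_state_dict(state, strict=False)
print(sorted(state))                             # ['0.bias', '0.weight']
```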
mogen/datasets/human_body_prior/tools/omni_tools.py
ADDED
@@ -0,0 +1,163 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2018.01.02
import numpy as np
import random
import torch
import os
import sys
import os.path as osp

def copy2cpu(tensor):
    if isinstance(tensor, np.ndarray): return tensor
    return tensor.detach().cpu().numpy()

def create_list_chunks(list_, group_size, overlap_size, cut_smaller_batches=True):
    if cut_smaller_batches:
        return [list_[i:i + group_size] for i in range(0, len(list_), group_size - overlap_size) if len(list_[i:i + group_size]) == group_size]
    else:
        return [list_[i:i + group_size] for i in range(0, len(list_), group_size - overlap_size)]


def trainable_params_count(params):
    return sum([p.numel() for p in params if p.requires_grad])

def flatten_list(l):
    return [item for sublist in l for item in sublist]

def get_support_data_dir(current_fname=__file__):
    support_data_dir = osp.abspath(current_fname)
    support_data_dir_split = support_data_dir.split('/')
    support_data_dir = '/'.join(support_data_dir_split[:support_data_dir_split.index('src')])
    support_data_dir = osp.join(support_data_dir, 'support_data')
    assert osp.exists(support_data_dir)
    return support_data_dir

def make_deterministic(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def id_generator(size=13):
    import string
    import random
    chars = string.ascii_uppercase + string.digits
    return ''.join(random.choice(chars) for _ in range(size))

def logger_sequencer(logger_list, prefix=None):
    def post_text(text):
        if prefix is not None: text = '{} -- '.format(prefix) + text
        for logger_call in logger_list: logger_call(text)
    return post_text

class log2file():
    def __init__(self, logpath=None, prefix='', auto_newline=True, write2file_only=False):
        if logpath is not None:
            makepath(logpath, isfile=True)
            self.fhandle = open(logpath, 'a+')
        else:
            self.fhandle = None

        self.prefix = prefix
        self.auto_newline = auto_newline
        self.write2file_only = write2file_only

    def __call__(self, text):
        if text is None: return
        if self.prefix != '': text = '{} -- '.format(self.prefix) + text
        # breakpoint()
        if self.auto_newline:
            if not text.endswith('\n'):
                text = text + '\n'
        if not self.write2file_only: sys.stderr.write(text)
        if self.fhandle is not None:
            self.fhandle.write(text)
            self.fhandle.flush()


def makepath(*args, **kwargs):
    '''
    if the path does not exist make it
    :param desired_path: can be path to a file or a folder name
    :return:
    '''
    isfile = kwargs.get('isfile', False)
    import os
    desired_path = os.path.join(*args)
    if isfile:
        if not os.path.exists(os.path.dirname(desired_path)): os.makedirs(os.path.dirname(desired_path))
    else:
        if not os.path.exists(desired_path): os.makedirs(desired_path)
    return desired_path

def matrot2axisangle(matrots):
    '''
    :param matrots: N*T*num_joints*9
    :return: N*T*num_joints*3
    '''
    import cv2
    N = matrots.shape[0]
    T = matrots.shape[1]
    n_joints = matrots.shape[2]
    out_axisangle = []
    for tIdx in range(T):
        T_axisangle = []
        for mIdx in range(N):
            cur_axisangle = []
            for jIdx in range(n_joints):
                cur_axisangle.append(cv2.Rodrigues(matrots[mIdx, tIdx, jIdx:jIdx + 1, :].reshape(3, 3))[0].T)
            T_axisangle.append(np.vstack(cur_axisangle)[np.newaxis])
        out_axisangle.append(np.vstack(T_axisangle).reshape([N, 1, -1, 3]))
    return np.concatenate(out_axisangle, axis=1)

def axisangle2matrots(axisangle):
    '''
    :param axisangle: N*1*num_joints*3
    :return: N*num_joints*9
    '''
    import cv2
    batch_size = axisangle.shape[0]
    axisangle = axisangle.reshape([batch_size, 1, -1, 3])
    out_matrot = []
    for mIdx in range(axisangle.shape[0]):
        cur_axisangle = []
        for jIdx in range(axisangle.shape[2]):
            a = cv2.Rodrigues(axisangle[mIdx, 0, jIdx:jIdx + 1, :].reshape(1, 3))[0].T
            cur_axisangle.append(a)

        out_matrot.append(np.array(cur_axisangle).reshape([batch_size, 1, -1, 9]))
    return np.vstack(out_matrot)


def apply_mesh_tranfsormations_(meshes, transf):
    '''
    apply in-place transformations to meshes
    :param meshes: list of trimesh meshes
    :param transf:
    :return:
    '''
    for i in range(len(meshes)):
        meshes[i] = meshes[i].apply_transform(transf)
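Note: `create_list_chunks` is used to slice frame lists into overlapping windows; the expected behavior is easiest to see on a small list. The import assumes the `human_body_prior` package is importable; the numbers are illustrative.

```python
from human_body_prior.tools.omni_tools import create_list_chunks  # assumes the package is importable

frames = list(range(10))
print(create_list_chunks(frames, group_size=4, overlap_size=2))
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]] -- windows of 4 with 2 frames of overlap
print(create_list_chunks(frames, group_size=4, overlap_size=2, cut_smaller_batches=False))
# same windows plus the trailing, shorter window [8, 9]
```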
mogen/datasets/human_body_prior/tools/rotation_tools.py
ADDED
@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12
import numpy as np

from torch.nn import functional as F
from human_body_prior.tools import tgm_conversion as tgm
import torch

def local2global_pose(local_pose, kintree):
    bs = local_pose.shape[0]

    local_pose = local_pose.view(bs, -1, 3, 3)

    global_pose = local_pose.clone()

    for jId in range(len(kintree)):
        parent_id = kintree[jId]
        if parent_id >= 0:
            global_pose[:, jId] = torch.matmul(global_pose[:, parent_id], global_pose[:, jId])

    return global_pose

def em2euler(em):
    '''
    :param em: rotation in expo-map (3,)
    :return: rotation in euler angles (3,)
    '''
    from transforms3d.euler import axangle2euler

    theta = np.sqrt((em ** 2).sum())
    axis = em / theta
    return np.array(axangle2euler(axis, theta))


def euler2em(ea):
    '''
    :param ea: rotation in euler angles (3,)
    :return: rotation in expo-map (3,)
    '''
    from transforms3d.euler import euler2axangle
    axis, theta = euler2axangle(*ea)
    return np.array(axis * theta)


def remove_zrot(pose):
    noZ = em2euler(pose[:3].copy())
    noZ[2] = 0
    pose[:3] = euler2em(noZ).copy()
    return pose

def matrot2aa(pose_matrot):
    '''
    :param pose_matrot: Nx3x3
    :return: Nx3
    '''
    bs = pose_matrot.size(0)
    homogen_matrot = F.pad(pose_matrot, [0, 1])
    pose = tgm.rotation_matrix_to_angle_axis(homogen_matrot)
    return pose

def aa2matrot(pose):
    '''
    :param pose: Nx3 axis-angle rotations
    :return: pose_matrot: Nx3x3
    '''
    bs = pose.size(0)
    num_joints = pose.size(1) // 3
    pose_body_matrot = tgm.angle_axis_to_rotation_matrix(pose)[:, :3, :3].contiguous()  # .view(bs, num_joints*9)
    return pose_body_matrot

def noisy_zrot(rot_in):
    '''
    :param rot_in: np.array Nx3 rotations in axis-angle representation
    :return: the same rotations with a random angle from a full circle added to the z-rotation
    '''
    is_batched = False
    if rot_in.ndim == 2: is_batched = True
    if not is_batched:
        rot_in = rot_in[np.newaxis]

    rnd_zrot = np.random.uniform(-np.pi, np.pi)
    rot_out = []
    for bId in range(len(rot_in)):
        pose_cpu = rot_in[bId]
        pose_euler = em2euler(pose_cpu)

        pose_euler[2] += rnd_zrot

        pose_aa = euler2em(pose_euler)
        rot_out.append(pose_aa.copy())

    return np.array(rot_out)

def rotate_points_xyz(mesh_v, Rxyz):
    '''
    :param mesh_v: Nxnum_vx3
    :param Rxyz: Nx3
    :return:
    '''

    mesh_v_rotated = []

    for fId in range(mesh_v.shape[0]):
        angle = np.radians(Rxyz[fId, 0])
        rx = np.array([
            [1., 0., 0.],
            [0., np.cos(angle), -np.sin(angle)],
            [0., np.sin(angle), np.cos(angle)]
        ])

        angle = np.radians(Rxyz[fId, 1])
        ry = np.array([
            [np.cos(angle), 0., np.sin(angle)],
            [0., 1., 0.],
            [-np.sin(angle), 0., np.cos(angle)]
        ])

        angle = np.radians(Rxyz[fId, 2])
        rz = np.array([
            [np.cos(angle), -np.sin(angle), 0.],
            [np.sin(angle), np.cos(angle), 0.],
            [0., 0., 1.]
        ])
        mesh_v_rotated.append(rz.dot(ry.dot(rx.dot(mesh_v[fId].T))).T)

    return np.array(mesh_v_rotated)
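Note: a short round-trip check of `aa2matrot`/`matrot2aa`. The imports assume the `human_body_prior` package (and therefore the `tgm_conversion` module it wraps) is importable; the random input is illustrative.

```python
import torch
from human_body_prior.tools.rotation_tools import aa2matrot, matrot2aa  # assumes the package is importable

aa = torch.rand(5, 3) * 0.5                    # small axis-angle rotations, Nx3
R = aa2matrot(aa)                              # Nx3x3 rotation matrices
aa_back = matrot2aa(R)                         # back to axis-angle
print(torch.allclose(aa, aa_back, atol=1e-3))  # True up to float32 conversion error
```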
mogen/datasets/human_body_prior/tools/tgm_conversion.py
ADDED
@@ -0,0 +1,527 @@
'''
This code is ripped from a version of torchgeometry, now called Kornia. Since Kornia has a
known bug in converting rotation representations (https://github.com/kornia/kornia/issues/317#issuecomment-751305910),
we use this code until the original bug in Kornia is addressed.
'''

import torch
import torch.nn as nn

__all__ = [
    # functional api
    "pi",
    "rad2deg",
    "deg2rad",
    "convert_points_from_homogeneous",
    "convert_points_to_homogeneous",
    "angle_axis_to_rotation_matrix",
    "rotation_matrix_to_angle_axis",
    "rotation_matrix_to_quaternion",
    "quaternion_to_angle_axis",
    "angle_axis_to_quaternion",
    "rtvec_to_pose",
    # layer api
    "RadToDeg",
    "DegToRad",
    "ConvertPointsFromHomogeneous",
    "ConvertPointsToHomogeneous",
]


"""Constant with number pi
"""
pi = torch.Tensor([3.14159265358979323846])


def rad2deg(tensor):
    r"""Function that converts angles from radians to degrees.

    See :class:`~torchgeometry.RadToDeg` for details.

    Args:
        tensor (Tensor): Tensor of arbitrary shape.

    Returns:
        Tensor: Tensor with same shape as input.

    Example:
        >>> input = tgm.pi * torch.rand(1, 3, 3)
        >>> output = tgm.rad2deg(input)
    """
    if not torch.is_tensor(tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(tensor)))

    return 180. * tensor / pi.to(tensor.device).type(tensor.dtype)


def deg2rad(tensor):
    r"""Function that converts angles from degrees to radians.

    See :class:`~torchgeometry.DegToRad` for details.

    Args:
        tensor (Tensor): Tensor of arbitrary shape.

    Returns:
        Tensor: Tensor with same shape as input.

    Examples::

        >>> input = 360. * torch.rand(1, 3, 3)
        >>> output = tgm.deg2rad(input)
    """
    if not torch.is_tensor(tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(tensor)))

    return tensor * pi.to(tensor.device).type(tensor.dtype) / 180.


def convert_points_from_homogeneous(points):
    r"""Function that converts points from homogeneous to Euclidean space.

    See :class:`~torchgeometry.ConvertPointsFromHomogeneous` for details.

    Examples::

        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = tgm.convert_points_from_homogeneous(input)  # BxNx2
    """
    if not torch.is_tensor(points):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))
    if len(points.shape) < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))

    return points[..., :-1] / points[..., -1:]


def convert_points_to_homogeneous(points):
    r"""Function that converts points from Euclidean to homogeneous space.

    See :class:`~torchgeometry.ConvertPointsToHomogeneous` for details.

    Examples::

        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = tgm.convert_points_to_homogeneous(input)  # BxNx4
    """
    if not torch.is_tensor(points):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))
    if len(points.shape) < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))

    return nn.functional.pad(points, (0, 1), "constant", 1.0)


def angle_axis_to_rotation_matrix(angle_axis):
    """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix

    Args:
        angle_axis (Tensor): tensor of 3d vector of axis-angle rotations.

    Returns:
        Tensor: tensor of 4x4 rotation matrices.

    Shape:
        - Input: :math:`(N, 3)`
        - Output: :math:`(N, 4, 4)`

    Example:
        >>> input = torch.rand(1, 3)  # Nx3
        >>> output = tgm.angle_axis_to_rotation_matrix(input)  # Nx4x4
    """
    def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6):
        # We want to be careful to only evaluate the square root if the
        # norm of the angle_axis vector is greater than zero. Otherwise
        # we get a division by zero.
        k_one = 1.0
        theta = torch.sqrt(theta2)
        wxyz = angle_axis / (theta + eps)
        wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        r00 = cos_theta + wx * wx * (k_one - cos_theta)
        r10 = wz * sin_theta + wx * wy * (k_one - cos_theta)
        r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta)
        r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta
        r11 = cos_theta + wy * wy * (k_one - cos_theta)
        r21 = wx * sin_theta + wy * wz * (k_one - cos_theta)
        r02 = wy * sin_theta + wx * wz * (k_one - cos_theta)
        r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta)
        r22 = cos_theta + wz * wz * (k_one - cos_theta)
        rotation_matrix = torch.cat(
            [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1)
        return rotation_matrix.view(-1, 3, 3)

    def _compute_rotation_matrix_taylor(angle_axis):
        rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
        k_one = torch.ones_like(rx)
        rotation_matrix = torch.cat(
            [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1)
        return rotation_matrix.view(-1, 3, 3)

    # stolen from ceres/rotation.h

    _angle_axis = torch.unsqueeze(angle_axis, dim=1)
    theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2))
    theta2 = torch.squeeze(theta2, dim=1)

    # compute rotation matrices
    rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2)
    rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)

    # create mask to handle both cases
    eps = 1e-6
    mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device)
    mask_pos = (mask).type_as(theta2)
    mask_neg = (mask == False).type_as(theta2)  # noqa

    # create output pose matrix
    batch_size = angle_axis.shape[0]
    rotation_matrix = torch.eye(4).to(angle_axis.device).type_as(angle_axis)
    rotation_matrix = rotation_matrix.view(1, 4, 4).repeat(batch_size, 1, 1)
    # fill output matrix with masked values
    rotation_matrix[..., :3, :3] = \
        mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor
    return rotation_matrix  # Nx4x4


def rtvec_to_pose(rtvec):
    """
    Convert axis-angle rotation and translation vector to 4x4 pose matrix

    Args:
        rtvec (Tensor): Rodrigues vector transformations

    Returns:
        Tensor: transformation matrices

    Shape:
        - Input: :math:`(N, 6)`
        - Output: :math:`(N, 4, 4)`

    Example:
        >>> input = torch.rand(3, 6)  # Nx6
        >>> output = tgm.rtvec_to_pose(input)  # Nx4x4
    """
    assert rtvec.shape[-1] == 6, 'rtvec=[rx, ry, rz, tx, ty, tz]'
    pose = angle_axis_to_rotation_matrix(rtvec[..., :3])
    pose[..., :3, 3] = rtvec[..., 3:]
    return pose


def rotation_matrix_to_angle_axis(rotation_matrix):
    """Convert 3x4 rotation matrix to Rodrigues vector

    Args:
        rotation_matrix (Tensor): rotation matrix.

    Returns:
        Tensor: Rodrigues vector transformation.

    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 3)`

    Example:
        >>> input = torch.rand(2, 3, 4)  # Nx3x4
        >>> output = tgm.rotation_matrix_to_angle_axis(input)  # Nx3
    """
    # todo add check that matrix is a valid rotation matrix
    quaternion = rotation_matrix_to_quaternion(rotation_matrix)
    return quaternion_to_angle_axis(quaternion)


def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """Convert 3x4 rotation matrix to 4d quaternion vector

    This algorithm is based on the algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

    Args:
        rotation_matrix (Tensor): the rotation matrix to convert.

    Return:
        Tensor: the rotation in quaternion

    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 4)`

    Example:
        >>> input = torch.rand(4, 3, 4)  # Nx3x4
        >>> output = tgm.rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not torch.is_tensor(rotation_matrix):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(rotation_matrix)))

    if len(rotation_matrix.shape) > 3:
        raise ValueError(
            "Input size must be a three dimensional tensor. Got {}".format(
                rotation_matrix.shape))
    if not rotation_matrix.shape[-2:] == (3, 4):
        raise ValueError(
            "Input size must be a N x 3 x 4 tensor. Got {}".format(
                rotation_matrix.shape))

    rmat_t = torch.transpose(rotation_matrix, 1, 2)

    mask_d2 = rmat_t[:, 2, 2] < eps

    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
                      t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
                      rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
    t0_rep = t0.repeat(4, 1).t()

    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
                      rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
                      t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
    t1_rep = t1.repeat(4, 1).t()

    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
                      rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
                      rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
    t2_rep = t2.repeat(4, 1).t()

    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
                      rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
                      rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
    t3_rep = t3.repeat(4, 1).t()

    mask_c0 = mask_d2 * mask_d0_d1
    mask_c1 = mask_d2 * torch.logical_not(mask_d0_d1)
    mask_c2 = torch.logical_not(mask_d2) * mask_d0_nd1
    mask_c3 = torch.logical_not(mask_d2) * torch.logical_not(mask_d0_nd1)
    mask_c0 = mask_c0.view(-1, 1).type_as(q0)
    mask_c1 = mask_c1.view(-1, 1).type_as(q1)
    mask_c2 = mask_c2.view(-1, 1).type_as(q2)
    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 +  # noqa
                    t2_rep * mask_c2 + t3_rep * mask_c3)  # noqa
    q *= 0.5
    return q


def quaternion_to_angle_axis(quaternion) -> torch.Tensor:
    """Convert quaternion vector to angle axis of rotation.

    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    Args:
        quaternion (torch.Tensor): tensor with quaternions.

    Return:
        torch.Tensor: tensor with angle axis of rotation.

    Shape:
        - Input: :math:`(*, 4)` where `*` means, any number of dimensions
        - Output: :math:`(*, 3)`

    Example:
        >>> quaternion = torch.rand(2, 4)  # Nx4
        >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion)  # Nx3
    """
    if not torch.is_tensor(quaternion):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))

    if not quaternion.shape[-1] == 4:
        raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
                         .format(quaternion.shape))
    # unpack input and compute conversion
    q1 = quaternion[..., 1]
    q2 = quaternion[..., 2]
    q3 = quaternion[..., 3]
    sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3

    sin_theta = torch.sqrt(sin_squared_theta)
    cos_theta = quaternion[..., 0]
    two_theta = 2.0 * torch.where(
        cos_theta < 0.0,
        torch.atan2(-sin_theta, -cos_theta),
        torch.atan2(sin_theta, cos_theta))

    k_pos = two_theta / sin_theta
    k_neg = 2.0 * torch.ones_like(sin_theta)
    k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)

    angle_axis = torch.zeros_like(quaternion)[..., :3]
    angle_axis[..., 0] += q1 * k
    angle_axis[..., 1] += q2 * k
    angle_axis[..., 2] += q3 * k
    return angle_axis

# based on:
# https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L138


def angle_axis_to_quaternion(angle_axis) -> torch.Tensor:
    """Convert an angle axis to a quaternion.

    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    Args:
        angle_axis (torch.Tensor): tensor with angle axis.

    Return:
        torch.Tensor: tensor with quaternion.

    Shape:
        - Input: :math:`(*, 3)` where `*` means, any number of dimensions
        - Output: :math:`(*, 4)`

    Example:
        >>> angle_axis = torch.rand(2, 3)  # Nx3
        >>> quaternion = tgm.angle_axis_to_quaternion(angle_axis)  # Nx4
    """
    if not torch.is_tensor(angle_axis):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(angle_axis)))

    if not angle_axis.shape[-1] == 3:
        raise ValueError("Input must be a tensor of shape Nx3 or 3. Got {}"
                         .format(angle_axis.shape))
    # unpack input and compute conversion
    a0 = angle_axis[..., 0:1]
    a1 = angle_axis[..., 1:2]
    a2 = angle_axis[..., 2:3]
    theta_squared = a0 * a0 + a1 * a1 + a2 * a2

    theta = torch.sqrt(theta_squared)
    half_theta = theta * 0.5

    mask = theta_squared > 0.0
    ones = torch.ones_like(half_theta)

    k_neg = 0.5 * ones
    k_pos = torch.sin(half_theta) / theta
    k = torch.where(mask, k_pos, k_neg)
    w = torch.where(mask, torch.cos(half_theta), ones)

    quaternion = torch.zeros_like(angle_axis)
    quaternion[..., 0:1] += a0 * k
    quaternion[..., 1:2] += a1 * k
    quaternion[..., 2:3] += a2 * k
    return torch.cat([w, quaternion], dim=-1)

# TODO: add below functionalities
# - pose_to_rtvec


# layer api


class RadToDeg(nn.Module):
    r"""Creates an object that converts angles from radians to degrees.

    Args:
        tensor (Tensor): Tensor of arbitrary shape.

    Returns:
        Tensor: Tensor with same shape as input.

    Examples::

        >>> input = tgm.pi * torch.rand(1, 3, 3)
        >>> output = tgm.RadToDeg()(input)
    """

    def __init__(self):
        super(RadToDeg, self).__init__()

    def forward(self, input):
        return rad2deg(input)


class DegToRad(nn.Module):
    r"""Function that converts angles from degrees to radians.

    Args:
        tensor (Tensor): Tensor of arbitrary shape.

    Returns:
        Tensor: Tensor with same shape as input.

    Examples::

        >>> input = 360. * torch.rand(1, 3, 3)
        >>> output = tgm.DegToRad()(input)
    """

    def __init__(self):
        super(DegToRad, self).__init__()

    def forward(self, input):
        return deg2rad(input)


class ConvertPointsFromHomogeneous(nn.Module):
    r"""Creates a transformation that converts points from homogeneous to
    Euclidean space.

    Args:
        points (Tensor): tensor of N-dimensional points.

    Returns:
        Tensor: tensor of N-1-dimensional points.

    Shape:
        - Input: :math:`(B, D, N)` or :math:`(D, N)`
        - Output: :math:`(B, D, N + 1)` or :math:`(D, N + 1)`

    Examples::

        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> transform = tgm.ConvertPointsFromHomogeneous()
        >>> output = transform(input)  # BxNx2
    """

    def __init__(self):
        super(ConvertPointsFromHomogeneous, self).__init__()

    def forward(self, input):
        return convert_points_from_homogeneous(input)


class ConvertPointsToHomogeneous(nn.Module):
    r"""Creates a transformation to convert points from Euclidean to
    homogeneous space.

    Args:
        points (Tensor): tensor of N-dimensional points.

    Returns:
        Tensor: tensor of N+1-dimensional points.

    Shape:
        - Input: :math:`(B, D, N)` or :math:`(D, N)`
        - Output: :math:`(B, D, N + 1)` or :math:`(D, N + 1)`

    Examples::

        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> transform = tgm.ConvertPointsToHomogeneous()
        >>> output = transform(input)  # BxNx4
    """

    def __init__(self):
        super(ConvertPointsToHomogeneous, self).__init__()

    def forward(self, input):
        return convert_points_to_homogeneous(input)
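Note: a brief round-trip check of the conversions above. The import assumes the `human_body_prior` package is importable under that name; the random inputs are illustrative.

```python
import torch
from human_body_prior.tools import tgm_conversion as tgm  # assumes the package is importable

aa = torch.rand(4, 3) * 0.5                                # small axis-angle vectors, Nx3
R4 = tgm.angle_axis_to_rotation_matrix(aa)                 # Nx4x4 homogeneous rotation matrices
aa_back = tgm.rotation_matrix_to_angle_axis(R4[:, :3, :])  # the function expects Nx3x4
print(torch.allclose(aa, aa_back, atol=1e-4))              # True

q = tgm.angle_axis_to_quaternion(aa)                       # Nx4, w first
print(torch.allclose(q.norm(dim=-1), torch.ones(4), atol=1e-5))  # unit quaternions
```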
mogen/datasets/human_body_prior/train/README.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Train VPoser from Scratch
To train your own VPoser with a new configuration, duplicate the provided **V02_05** folder under a new experiment ID
and change the settings as you desire.
First download the
[AMASS](https://amass.is.tue.mpg.de/) dataset, then follow the [data preparation tutorial](../data/README.md)
to prepare the data for training.
Below is the training snippet from the [example training experiment](https://github.com/nghorbani/human_body_prior/blob/master/src/human_body_prior/train/V02_05/V02_05.py):

```python
import glob
import os.path as osp

from human_body_prior.tools.configurations import load_config
from human_body_prior.train.vposer_trainer import train_vposer_once

def main():
    expr_id = 'V02_05'

    default_ps_fname = glob.glob(osp.join(osp.dirname(__file__), '*.yaml'))[0]

    vp_ps = load_config(default_ps_fname)

    vp_ps.train_parms.batch_size = 128

    vp_ps.general.expr_id = expr_id

    total_jobs = []
    total_jobs.append(vp_ps.toDict().copy())

    print('#training_jobs to be done: {}'.format(len(total_jobs)))
    if len(total_jobs) == 0:
        print('No jobs to be done')
        return

    for job in total_jobs:
        train_vposer_once(job)
```
The above code uses YAML configuration files to handle experiment settings.
It loads the default settings from *<expr_id>.yaml* and overrides them with your new arguments.

The training code will dump a log file along with a TensorBoard-readable events file.
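If you want to override more than the batch size before dispatching jobs, the same dotted access works for any field of the YAML. A minimal sketch, assuming the keys of the V02_05.yaml shown later in this commit (the file path here is illustrative):

```python
from human_body_prior.tools.configurations import load_config

vp_ps = load_config('V02_05.yaml')   # illustrative path to the experiment YAML
vp_ps.train_parms.num_epochs = 50    # fewer epochs for a quick run
vp_ps.model_params.latentD = 32      # latent dimensionality, as in V02_05.yaml
vp_ps.general.rnd_seed = 100         # seed consumed by make_deterministic in the trainer
job = vp_ps.toDict().copy()          # jobs are plain dicts, as in the snippet above
```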
mogen/datasets/human_body_prior/train/V02_05/V02_05.py
ADDED
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12

import glob
import os.path as osp

from human_body_prior.tools.configurations import load_config
from human_body_prior.train.vposer_trainer import train_vposer_once

def main():
    expr_id = 'V02_05'

    default_ps_fname = glob.glob(osp.join(osp.dirname(__file__), '*.yaml'))[0]

    vp_ps = load_config(default_ps_fname)

    vp_ps.train_parms.batch_size = 128

    vp_ps.general.expr_id = expr_id

    total_jobs = []
    total_jobs.append(vp_ps.toDict().copy())

    print('#training_jobs to be done: {}'.format(len(total_jobs)))
    if len(total_jobs) == 0:
        print('No jobs to be done')
        return

    for job in total_jobs:
        train_vposer_once(job)


if __name__ == '__main__':
    main()
mogen/datasets/human_body_prior/train/V02_05/V02_05.yaml
ADDED
@@ -0,0 +1,84 @@
---
body_model:
  gender: neutral
  bm_fname: ../../../../support_data/dowloads/models/smplx/neutral/model.npz

general:
  verbosity: 0
  expr_id:
  dataset_id: V02_03 #SMPLx neutral
  rnd_seed: 100
  work_basedir: ../../../../support_data/training/training_experiments
  dataset_basedir: ../../../../support_data/training/data

logging:
  expr_msg:
  num_bodies_to_display: 25
  work_dir:
  dataset_dir:
  render_during_training: False
  best_model_fname:

train_parms:
  batch_size:
  num_epochs: 100
  restore_optimizer: False
  gen_optimizer:
    type: Adam
    args:
      lr: 0.001
      weight_decay: 0.00001
  lr_scheduler:
    type: ReduceLROnPlateau
    args:
      # metrics: val_loss
      verbose: true
      patience: 5
  early_stopping:
    monitor: val_loss
    min_delta: 0.0
    patience: 10
    verbose: True
    mode: min
  keep_extra_loss_terms_until_epoch: 15
  loss_weights:
    loss_kl_wt: 0.005
    loss_rec_wt: 4
    loss_matrot_wt: 2
    loss_jtr_wt: 2


data_parms:
  num_workers: 5 # Used for dataloaders
  amass_dir: support_data/dowloads/amass/smplx_neutral
  num_timeseq_frames: 1
  amass_splits:
    vald:
      # - HumanEva
      # - MPI_HDM05
      # - SFU
      # - MPI_mosh
      - BMLrub_vald
    train:
      - CMU
      - BMLrub_train
      # - MPI_Limits
      # - TotalCapture
      # - Eyes_Japan_Dataset
      # - KIT
      # - BMLrub
      # - EKUT
      # - TCD_handMocap
      # - ACCAD
      # - BMLmovi
    test:
      - BMLrub_test
      # - Transitions_mocap
      # - SSM_synced
      # - DFaust_67


model_params:
  num_neurons: 512
  latentD: 32
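The `gen_optimizer` and `lr_scheduler` blocks above are resolved by name against `torch.optim` and `torch.optim.lr_scheduler` in `vposer_trainer.py` (included later in this commit). A minimal sketch of that resolution, using a stand-in parameter list instead of the VPoser model's parameters:

```python
import torch
from torch import optim as optim_module
from torch.optim import lr_scheduler as lr_sched_module

# Stand-in parameters; the trainer passes the VPoser model's trainable parameters here.
params = [torch.nn.Parameter(torch.zeros(10))]

# 'Adam' and its args mirror the train_parms.gen_optimizer block above.
optimizer = getattr(optim_module, 'Adam')(params, lr=0.001, weight_decay=0.00001)

# 'ReduceLROnPlateau' and its args mirror the train_parms.lr_scheduler block.
scheduler = getattr(lr_sched_module, 'ReduceLROnPlateau')(optimizer, verbose=True, patience=5)
```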
mogen/datasets/human_body_prior/train/V02_05/__init__.py
ADDED
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12
mogen/datasets/human_body_prior/train/__init__.py
ADDED
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2018.01.02
mogen/datasets/human_body_prior/train/vposer_trainer.py
ADDED
@@ -0,0 +1,337 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12

# from pytorch_lightning import Trainer

import glob
import os
import os.path as osp
from datetime import datetime as dt
from pytorch_lightning.plugins import DDPPlugin

import numpy as np
import pytorch_lightning as pl
import torch
from human_body_prior.body_model.body_model import BodyModel
from human_body_prior.data.dataloader import VPoserDS
from human_body_prior.data.prepare_data import dataset_exists
from human_body_prior.data.prepare_data import prepare_vposer_datasets
from human_body_prior.models.vposer_model import VPoser
from human_body_prior.tools.angle_continuous_repres import geodesic_loss_R
from human_body_prior.tools.configurations import load_config, dump_config
from human_body_prior.tools.omni_tools import copy2cpu as c2c
from human_body_prior.tools.omni_tools import get_support_data_dir
from human_body_prior.tools.omni_tools import log2file
from human_body_prior.tools.omni_tools import make_deterministic
from human_body_prior.tools.omni_tools import makepath
from human_body_prior.tools.rotation_tools import aa2matrot
from human_body_prior.visualizations.training_visualization import vposer_trainer_renderer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.core import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities import rank_zero_only
from torch import optim as optim_module
from torch.optim import lr_scheduler as lr_sched_module
from torch.utils.data import DataLoader


class VPoserTrainer(LightningModule):
    """
    Includes all data loading and train / val logic, and is used for both training and testing models.
    """

    def __init__(self, _config):
        super(VPoserTrainer, self).__init__()

        _support_data_dir = get_support_data_dir()

        vp_ps = load_config(**_config)

        make_deterministic(vp_ps.general.rnd_seed)

        self.expr_id = vp_ps.general.expr_id
        self.dataset_id = vp_ps.general.dataset_id

        self.work_dir = vp_ps.logging.work_dir = makepath(vp_ps.general.work_basedir, self.expr_id)
        self.dataset_dir = vp_ps.logging.dataset_dir = osp.join(vp_ps.general.dataset_basedir, vp_ps.general.dataset_id)

        self._log_prefix = '[{}]'.format(self.expr_id)
        self.text_logger = log2file(prefix=self._log_prefix)

        self.seq_len = vp_ps.data_parms.num_timeseq_frames

        self.vp_model = VPoser(vp_ps)

        with torch.no_grad():

            self.bm_train = BodyModel(vp_ps.body_model.bm_fname)

        if vp_ps.logging.render_during_training:
            self.renderer = vposer_trainer_renderer(self.bm_train, vp_ps.logging.num_bodies_to_display)
        else:
            self.renderer = None

        self.example_input_array = {'pose_body': torch.ones(vp_ps.train_parms.batch_size, 63), }
        self.vp_ps = vp_ps

    def forward(self, pose_body):

        return self.vp_model(pose_body)

    def _get_data(self, split_name):

        assert split_name in ('train', 'vald', 'test')

        split_name = split_name.replace('vald', 'vald')

        assert dataset_exists(self.dataset_dir), FileNotFoundError('Dataset does not exist dataset_dir = {}'.format(self.dataset_dir))
        dataset = VPoserDS(osp.join(self.dataset_dir, split_name), data_fields=['pose_body'])

        assert len(dataset) != 0, ValueError('Dataset has nothing in it!')

        return DataLoader(dataset,
                          batch_size=self.vp_ps.train_parms.batch_size,
                          shuffle=True if split_name == 'train' else False,
                          num_workers=self.vp_ps.data_parms.num_workers,
                          pin_memory=True)

    @rank_zero_only
    def on_train_start(self):
        if self.global_rank != 0: return
        self.train_starttime = dt.now().replace(microsecond=0)

        ######## make a backup of vposer
        git_repo_dir = os.path.abspath(__file__).split('/')
        git_repo_dir = '/'.join(git_repo_dir[:git_repo_dir.index('human_body_prior') + 1])
        starttime = dt.strftime(self.train_starttime, '%Y_%m_%d_%H_%M_%S')
        archive_path = makepath(self.work_dir, 'code', 'vposer_{}.tar.gz'.format(starttime), isfile=True)
        cmd = 'cd %s && git ls-files -z | xargs -0 tar -czf %s' % (git_repo_dir, archive_path)
        os.system(cmd)
        ########
        self.text_logger('Created a git archive backup at {}'.format(archive_path))
        dump_config(self.vp_ps, osp.join(self.work_dir, '{}.yaml'.format(self.expr_id)))

    def train_dataloader(self):
        return self._get_data('train')

    def val_dataloader(self):
        return self._get_data('vald')

    def configure_optimizers(self):
        params_count = lambda params: sum(p.numel() for p in params if p.requires_grad)

        gen_params = [a[1] for a in self.vp_model.named_parameters() if a[1].requires_grad]
        gen_optimizer_class = getattr(optim_module, self.vp_ps.train_parms.gen_optimizer.type)
        gen_optimizer = gen_optimizer_class(gen_params, **self.vp_ps.train_parms.gen_optimizer.args)

        self.text_logger('Total Trainable Parameters Count in vp_model is %2.2f M.' % (params_count(gen_params) * 1e-6))

        lr_sched_class = getattr(lr_sched_module, self.vp_ps.train_parms.lr_scheduler.type)

        gen_lr_scheduler = lr_sched_class(gen_optimizer, **self.vp_ps.train_parms.lr_scheduler.args)

        schedulers = [
            {
                'scheduler': gen_lr_scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'frequency': 1
            },
        ]
        return [gen_optimizer], schedulers

    def _compute_loss(self, dorig, drec):
        l1_loss = torch.nn.L1Loss(reduction='mean')
        geodesic_loss = geodesic_loss_R(reduction='mean')

        bs, latentD = drec['poZ_body_mean'].shape
        device = drec['poZ_body_mean'].device

        loss_kl_wt = self.vp_ps.train_parms.loss_weights.loss_kl_wt
        loss_rec_wt = self.vp_ps.train_parms.loss_weights.loss_rec_wt
        loss_matrot_wt = self.vp_ps.train_parms.loss_weights.loss_matrot_wt
        loss_jtr_wt = self.vp_ps.train_parms.loss_weights.loss_jtr_wt

        # q_z = torch.distributions.normal.Normal(drec['mean'], drec['std'])
        q_z = drec['q_z']
        # dorig['fullpose'] = torch.cat([dorig['root_orient'], dorig['pose_body']], dim=-1)

        # Reconstruction loss - L1 on the output mesh
        with torch.no_grad():
            bm_orig = self.bm_train(pose_body=dorig['pose_body'])

        bm_rec = self.bm_train(pose_body=drec['pose_body'].contiguous().view(bs, -1))

        v2v = l1_loss(bm_rec.v, bm_orig.v)

        # KL loss
        p_z = torch.distributions.normal.Normal(
            loc=torch.zeros((bs, latentD), device=device, requires_grad=False),
            scale=torch.ones((bs, latentD), device=device, requires_grad=False))
        weighted_loss_dict = {
            'loss_kl': loss_kl_wt * torch.mean(torch.sum(torch.distributions.kl.kl_divergence(q_z, p_z), dim=[1])),
            'loss_mesh_rec': loss_rec_wt * v2v
        }

        if (self.current_epoch < self.vp_ps.train_parms.keep_extra_loss_terms_until_epoch):
            # breakpoint()
            weighted_loss_dict['matrot'] = loss_matrot_wt * geodesic_loss(drec['pose_body_matrot'].view(-1, 3, 3), aa2matrot(dorig['pose_body'].view(-1, 3)))
            weighted_loss_dict['jtr'] = loss_jtr_wt * l1_loss(bm_rec.Jtr, bm_orig.Jtr)

        weighted_loss_dict['loss_total'] = torch.stack(list(weighted_loss_dict.values())).sum()

        with torch.no_grad():
            unweighted_loss_dict = {'v2v': torch.sqrt(torch.pow(bm_rec.v - bm_orig.v, 2).sum(-1)).mean()}
            unweighted_loss_dict['loss_total'] = torch.cat(
                list({k: v.view(-1) for k, v in unweighted_loss_dict.items()}.values()), dim=-1).sum().view(1)

        return {'weighted_loss': weighted_loss_dict, 'unweighted_loss': unweighted_loss_dict}

    def training_step(self, batch, batch_idx, optimizer_idx=None):

        drec = self(batch['pose_body'].view(-1, 63))

        loss = self._compute_loss(batch, drec)

        train_loss = loss['weighted_loss']['loss_total']

        tensorboard_logs = {'train_loss': train_loss}
        progress_bar = {k: c2c(v) for k, v in loss['weighted_loss'].items()}
        return {'loss': train_loss, 'progress_bar': progress_bar, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):

        drec = self(batch['pose_body'].view(-1, 63))

        loss = self._compute_loss(batch, drec)
        val_loss = loss['unweighted_loss']['loss_total']

        if self.renderer is not None and self.global_rank == 0 and batch_idx % 500 == 0 and np.random.rand() > 0.5:
            out_fname = makepath(self.work_dir, 'renders/vald_rec_E{:03d}_It{:04d}_val_loss_{:.2f}.png'.format(self.current_epoch, batch_idx, val_loss.item()), isfile=True)
            self.renderer([batch, drec], out_fname=out_fname)
            dgen = self.vp_model.sample_poses(self.vp_ps.logging.num_bodies_to_display)
            out_fname = makepath(self.work_dir, 'renders/vald_gen_E{:03d}_I{:04d}.png'.format(self.current_epoch, batch_idx), isfile=True)
            self.renderer([dgen], out_fname=out_fname)

        progress_bar = {'v2v': val_loss}
        return {'val_loss': c2c(val_loss), 'progress_bar': progress_bar, 'log': progress_bar}

    def validation_epoch_end(self, outputs):
        metrics = {'val_loss': np.nanmean(np.concatenate([v['val_loss'] for v in outputs]))}

        if self.global_rank == 0:

            self.text_logger('Epoch {}: {}'.format(self.current_epoch, ', '.join('{}:{:.2f}'.format(k, v) for k, v in metrics.items())))
            self.text_logger('lr is {}'.format([pg['lr'] for opt in self.trainer.optimizers for pg in opt.param_groups]))

        metrics = {k: torch.as_tensor(v) for k, v in metrics.items()}

        return {'val_loss': metrics['val_loss'], 'log': metrics}

    @rank_zero_only
    def on_train_end(self):

        self.train_endtime = dt.now().replace(microsecond=0)
        endtime = dt.strftime(self.train_endtime, '%Y_%m_%d_%H_%M_%S')
        elapsedtime = self.train_endtime - self.train_starttime
        self.vp_ps.logging.best_model_fname = self.trainer.checkpoint_callback.best_model_path

        self.text_logger('Epoch {} - Finished training at {} after {}'.format(self.current_epoch, endtime, elapsedtime))
        self.text_logger('best_model_fname: {}'.format(self.vp_ps.logging.best_model_fname))

        dump_config(self.vp_ps, osp.join(self.work_dir, '{}_{}.yaml'.format(self.expr_id, self.dataset_id)))
        self.hparams = self.vp_ps.toDict()

    @rank_zero_only
    def prepare_data(self):
        '''Similar to the standard AMASS dataset preparation pipeline:
        Download the npz files corresponding to body data from https://amass.is.tue.mpg.de/ and place them under amass_dir
        '''
        self.text_logger = log2file(makepath(self.work_dir, '{}.log'.format(self.expr_id), isfile=True), prefix=self._log_prefix)

        prepare_vposer_datasets(self.dataset_dir, self.vp_ps.data_parms.amass_splits, self.vp_ps.data_parms.amass_dir, logger=self.text_logger)


def create_expr_message(ps):
    expr_msg = '[{}] batch_size = {}.'.format(ps.general.expr_id, ps.train_parms.batch_size)

    return expr_msg


def train_vposer_once(_config):

    resume_training_if_possible = True

    model = VPoserTrainer(_config)
    model.vp_ps.logging.expr_msg = create_expr_message(model.vp_ps)
    # model.text_logger(model.vp_ps.logging.expr_msg.replace(". ", '.\n'))
    dump_config(model.vp_ps, osp.join(model.work_dir, '{}.yaml'.format(model.expr_id)))

    logger = TensorBoardLogger(model.work_dir, name='tensorboard')
    lr_monitor = LearningRateMonitor()

    snapshots_dir = osp.join(model.work_dir, 'snapshots')
    checkpoint_callback = ModelCheckpoint(
        dirpath=makepath(snapshots_dir, isfile=True),
        filename="%s_{epoch:02d}_{val_loss:.2f}" % model.expr_id,
        save_top_k=1,
        verbose=True,
        monitor='val_loss',
        mode='min',
    )
    early_stop_callback = EarlyStopping(**model.vp_ps.train_parms.early_stopping)

    resume_from_checkpoint = None
    if resume_training_if_possible:
        available_ckpts = sorted(glob.glob(osp.join(snapshots_dir, '*.ckpt')), key=os.path.getmtime)
        if len(available_ckpts) > 0:
            resume_from_checkpoint = available_ckpts[-1]
            model.text_logger('Resuming the training from {}'.format(resume_from_checkpoint))

    trainer = pl.Trainer(gpus=1,
                         weights_summary='top',
                         distributed_backend='ddp',
                         # replace_sampler_ddp=False,
                         # accumulate_grad_batches=4,
                         # profiler=False,
                         # overfit_batches=0.05,
                         # fast_dev_run = True,
                         # limit_train_batches=0.02,
                         # limit_val_batches=0.02,
                         # num_sanity_val_steps=2,
                         plugins=[DDPPlugin(find_unused_parameters=False)],

                         callbacks=[lr_monitor, early_stop_callback, checkpoint_callback],

                         max_epochs=model.vp_ps.train_parms.num_epochs,
                         logger=logger,
                         resume_from_checkpoint=resume_from_checkpoint
                         )

    trainer.fit(model)
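For reference, the core of `_compute_loss` above is a standard VAE objective: a weighted KL term against a standard-normal prior plus mesh reconstruction terms. A self-contained sketch of just the KL part, using toy posterior statistics in place of the VPoser encoder output:

```python
import torch

bs, latentD = 4, 32                      # toy sizes; V02_05.yaml uses latentD = 32
mean = torch.zeros(bs, latentD)          # stand-ins for the encoder's posterior mean/std
std = torch.ones(bs, latentD)

q_z = torch.distributions.normal.Normal(mean, std)
p_z = torch.distributions.normal.Normal(torch.zeros(bs, latentD), torch.ones(bs, latentD))

# Sum the per-dimension KL, then average over the batch, as _compute_loss does.
loss_kl = torch.mean(torch.sum(torch.distributions.kl.kl_divergence(q_z, p_z), dim=1))
weighted_kl = 0.005 * loss_kl            # loss_kl_wt from V02_05.yaml
```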
mogen/datasets/human_body_prior/visualizations/__init__.py
ADDED
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12
mogen/datasets/human_body_prior/visualizations/training_visualization.py
ADDED
@@ -0,0 +1,123 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by:
# Nima Ghorbani <https://nghorbani.github.io/>
#
# 2020.12.12

def pyrenderer(imw=2048, imh=2048):

    from body_visualizer.mesh.mesh_viewer import MeshViewer
    import cv2

    import numpy as np
    import trimesh

    try:
        mv = MeshViewer(width=imw, height=imh, use_offscreen=True)
    except:
        import os
        os.environ['PYOPENGL_PLATFORM'] = 'egl'
        os.environ['EGL_DEVICE_ID'] = os.environ['GPU_DEVICE_ORDINAL'].split(',')[0]

        mv = MeshViewer(width=imw, height=imh, use_offscreen=True)

    mv.set_cam_trans([0, -0.5, 2.])

    def render_an_image(meshes):
        n_all = len(meshes)
        nc = int(np.sqrt(n_all))

        out_image = np.zeros([1, 1, 1, mv.width, mv.height, 4])

        scale_percent = 100. / nc
        width = int(mv.width * scale_percent / 100)
        height = int(mv.height * scale_percent / 100)
        dim = (width, height)

        for rId in range(nc):
            for cId in range(nc):
                i = (nc * rId) + cId
                if i >= len(meshes): break

                mesh = meshes[i]

                # mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(-90), (1, 0, 0)))
                mesh.vertices -= np.median(np.array(mesh.vertices), axis=0)
                mv.set_dynamic_meshes([mesh])
                img = mv.render(render_wireframe=False, RGBA=True)
                img_resized = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)

                out_image[0, 0, 0, (rId * width):((rId + 1) * width), (cId * height):((cId + 1) * height)] = cv2.cvtColor(img_resized, cv2.COLOR_BGRA2RGBA)

        return out_image.astype(np.uint8)

    return render_an_image

def vposer_trainer_renderer(bm, num_bodies_to_display=5):
    import numpy as np
    import trimesh
    import torch

    from body_visualizer.tools.vis_tools import imagearray2file, colors
    from human_body_prior.tools.omni_tools import copy2cpu as c2c
    from human_body_prior.tools.omni_tools import makepath
    from trimesh import Trimesh as Mesh
    from trimesh.util import concatenate as mesh_cat

    renderer = pyrenderer(1024, 1024)

    faces = c2c(bm.f)

    def render_once(body_parms, body_colors=[colors['grey'], colors['brown-light']], out_fname=None):
        '''

        :param body_parms: list of dictionaries of body parameters.
        :param body_colors: list of np arrays of color rgb values
        :param out_fname: output image path
        :return:
        '''

        if out_fname is not None: makepath(out_fname, isfile=True)
        assert len(body_parms) <= len(body_colors), ValueError('Not enough colors provided for #{} body_parms'.format(len(body_parms)))

        bs = body_parms[0]['pose_body'].shape[0]

        body_ids = np.random.choice(bs, num_bodies_to_display)

        body_evals = [c2c(bm(root_orient=v['root_orient'].view(bs, -1) if 'root_orient' in v else torch.zeros(bs, 3).type_as(v['pose_body']),
                             pose_body=v['pose_body'].contiguous().view(bs, -1)).v) for v in body_parms]
        num_verts = body_evals[0].shape[1]

        render_meshes = []
        for bId in body_ids:
            concat_cur_meshes = None
            for body, body_color in zip(body_evals, body_colors):
                cur_body_mesh = Mesh(body[bId], faces, vertex_colors=np.ones([num_verts, 3]) * body_color)
                concat_cur_meshes = cur_body_mesh if concat_cur_meshes is None else mesh_cat(concat_cur_meshes, cur_body_mesh)
            render_meshes.append(concat_cur_meshes)

        img = renderer(render_meshes)

        if out_fname is not None: imagearray2file(img, out_fname, fps=10)

        return

    return render_once
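`render_an_image` tiles the rendered bodies into a square grid; the row and column of the i-th mesh follow from integer division by the grid side. A tiny standalone sketch of that indexing, with toy image patches standing in for MeshViewer renders:

```python
import numpy as np

n_all = 9
nc = int(np.sqrt(n_all))                 # grid side, as in render_an_image
tile = 64                                # toy tile size in pixels
canvas = np.zeros((nc * tile, nc * tile, 4), dtype=np.uint8)

for i in range(n_all):
    rId, cId = divmod(i, nc)             # same ordering as i = nc * rId + cId above
    patch = np.full((tile, tile, 4), i * 25, dtype=np.uint8)  # stand-in for a resized render
    canvas[rId * tile:(rId + 1) * tile, cId * tile:(cId + 1) * tile] = patch
```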
mogen/datasets/motionverse_dataset.py
ADDED
@@ -0,0 +1,828 @@
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
import pickle as pkl
|
4 |
+
from typing import Optional, Union, List
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import json
|
10 |
+
from torch.utils.data import ConcatDataset, Dataset, WeightedRandomSampler
|
11 |
+
from .builder import DATASETS
|
12 |
+
from .pipelines import Compose, RetargetSkeleton
|
13 |
+
import random
|
14 |
+
import pytorch3d.transforms as geometry
|
15 |
+
from scipy.ndimage import gaussian_filter
|
16 |
+
# from mogen.core.evaluation import build_evaluator
|
17 |
+
# from mogen.core.evaluation.utils import compute_similarity_transform, transform_pose_sequence
|
18 |
+
from mogen.models.builder import build_submodule
|
19 |
+
from .utils import copy_repr_data, extract_repr_data, move_repr_data, recover_from_ric
|
20 |
+
|
21 |
+
class SingleMotionVerseDataset(Dataset):
|
22 |
+
"""
|
23 |
+
A dataset class for handling single MotionVerse datasets.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
dataset_name (str): Name of the dataset and task to load.
|
27 |
+
data_prefix (str): Path to the directory containing the dataset.
|
28 |
+
ann_file (str): Path to the annotation file.
|
29 |
+
pipeline (list): A list of transformations to apply on the data.
|
30 |
+
mode (str): the mode of current work. Choices: ['pretrain', 'train', 'test'].
|
31 |
+
eval_cfg (dict): Configuration for evaluation metrics.
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self,
|
35 |
+
dataset_path: Optional[str] = None,
|
36 |
+
task_name: Optional[str] = None,
|
37 |
+
data_prefix: Optional[str] = None,
|
38 |
+
ann_file: Optional[str] = None,
|
39 |
+
pipeline: Optional[List[dict]] = None,
|
40 |
+
|
41 |
+
# for text2motion and speech2gesture
|
42 |
+
tgt_min_motion_length: int = 20,
|
43 |
+
tgt_max_motion_length: int = 200,
|
44 |
+
|
45 |
+
# for video2motion
|
46 |
+
v2m_window_size: int = 20,
|
47 |
+
|
48 |
+
# for motion prediction
|
49 |
+
mp_input_length: int = 50,
|
50 |
+
mp_output_length: int = 25,
|
51 |
+
mp_stride_step: int = 5,
|
52 |
+
|
53 |
+
# for general test
|
54 |
+
test_rotation_type: str = 'h3d_rot',
|
55 |
+
target_framerate: float = 20,
|
56 |
+
eval_cfg: Optional[dict] = None,
|
57 |
+
test_mode: Optional[bool] = False):
|
58 |
+
data_prefix = os.path.join(data_prefix, 'datasets', dataset_path)
|
59 |
+
self.dataset_path = dataset_path
|
60 |
+
assert task_name in ['mocap', 't2m', 'v2m', 's2g', 'm2d']
|
61 |
+
self.task_name = task_name
|
62 |
+
self.dataset_name = dataset_path + '_' + task_name
|
63 |
+
|
64 |
+
# define subdirectories
|
65 |
+
self.meta_dir = os.path.join(data_prefix, 'metas')
|
66 |
+
self.motion_dir = os.path.join(data_prefix, 'motions')
|
67 |
+
self.eval_motion_dir = os.path.join(data_prefix, 'eval_motions')
|
68 |
+
self.text_dir = os.path.join(data_prefix, 'texts')
|
69 |
+
self.text_feat_dir = os.path.join(data_prefix, 'text_feats')
|
70 |
+
self.speech_dir = os.path.join(data_prefix, 'speeches')
|
71 |
+
self.speech_feat_dir = os.path.join(data_prefix, 'speech_feats')
|
72 |
+
self.music_dir = os.path.join(data_prefix, 'musics')
|
73 |
+
self.music_feat_dir = os.path.join(data_prefix, 'music_feats')
|
74 |
+
self.video_feat_dir = os.path.join(data_prefix, 'video_feats')
|
75 |
+
self.anno_file = os.path.join(data_prefix, 'splits', ann_file)
|
76 |
+
|
77 |
+
self.pipeline = Compose(pipeline)
|
78 |
+
|
79 |
+
self.tgt_min_motion_length = tgt_min_motion_length
|
80 |
+
self.tgt_max_motion_length = tgt_max_motion_length
|
81 |
+
|
82 |
+
self.v2m_window_size = v2m_window_size
|
83 |
+
|
84 |
+
self.mp_input_length = mp_input_length
|
85 |
+
self.mp_output_length = mp_output_length
|
86 |
+
self.mp_stride_step = mp_stride_step
|
87 |
+
|
88 |
+
self.target_framerate = target_framerate
|
89 |
+
self.test_rotation_type = test_rotation_type
|
90 |
+
self.test_mode = test_mode
|
91 |
+
self.load_annotations()
|
92 |
+
self.eval_cfg = copy.deepcopy(eval_cfg)
|
93 |
+
if self.test_mode:
|
94 |
+
self.prepare_evaluation()
|
95 |
+
|
96 |
+
def __len__(self) -> int:
|
97 |
+
"""Return the length of the current dataset."""
|
98 |
+
if self.test_mode:
|
99 |
+
return len(self.eval_indexes)
|
100 |
+
return len(self.name_list)
|
101 |
+
|
102 |
+
def __getitem__(self, idx: int) -> dict:
|
103 |
+
"""Prepare data for the given index."""
|
104 |
+
if self.test_mode:
|
105 |
+
idx = self.eval_indexes[idx]
|
106 |
+
return self.prepare_data(idx)
|
107 |
+
|
108 |
+
def load_annotations(self):
|
109 |
+
if self.task_name == 'mocap':
|
110 |
+
self.load_annotations_mocap()
|
111 |
+
elif self.task_name == 't2m':
|
112 |
+
self.load_annotations_t2m()
|
113 |
+
elif self.task_name == 'v2m':
|
114 |
+
self.load_annotations_v2m()
|
115 |
+
elif self.task_name == 's2g':
|
116 |
+
self.load_annotations_s2g()
|
117 |
+
elif self.task_name == 'm2d':
|
118 |
+
self.load_annotations_m2d()
|
119 |
+
else:
|
120 |
+
raise NotImplementedError()
|
121 |
+
|
122 |
+
def load_annotations_mocap(self):
|
123 |
+
if self.test_mode:
|
124 |
+
self.name_list = []
|
125 |
+
self.src_start_frame = []
|
126 |
+
self.src_end_frame = []
|
127 |
+
self.tgt_start_frame = []
|
128 |
+
self.tgt_end_frame = []
|
129 |
+
tgt_motion_length = self.mp_input_length + self.mp_output_length
|
130 |
+
for name in open(self.anno_file):
|
131 |
+
name = name.strip()
|
132 |
+
meta_path = os.path.join(self.meta_dir, name + ".json")
|
133 |
+
meta_data = json.load(open(meta_path))
|
134 |
+
num_frames = meta_data['num_frames']
|
135 |
+
downrate = int(meta_data['framerate'] / self.target_framerate + 0.1)
|
136 |
+
if num_frames < (self.mp_input_length + self.mp_output_length) * downrate:
|
137 |
+
continue
|
138 |
+
lim = num_frames // downrate - tgt_motion_length
|
139 |
+
for start_frame in range(0, lim, self.mp_stride_step):
|
140 |
+
self.name_list.append(name)
|
141 |
+
self.src_start_frame.append((start_frame + 1) * downrate)
|
142 |
+
self.src_end_frame.append((start_frame + tgt_motion_length + 1) * downrate)
|
143 |
+
self.tgt_start_frame.append(start_frame + self.mp_input_length)
|
144 |
+
self.tgt_end_frame.append(start_frame + tgt_motion_length)
|
145 |
+
else:
|
146 |
+
self.name_list = []
|
147 |
+
for name in open(self.anno_file):
|
148 |
+
name = name.strip()
|
149 |
+
self.name_list.append(name)
|
150 |
+
|
151 |
+
def load_annotations_t2m(self):
|
152 |
+
self.name_list = []
|
153 |
+
self.text_idx = []
|
154 |
+
for name in open(self.anno_file):
|
155 |
+
name = name.strip()
|
156 |
+
meta_path = os.path.join(self.meta_dir, name + ".json")
|
157 |
+
meta_data = json.load(open(meta_path))
|
158 |
+
downrate = int(meta_data['framerate'] / self.target_framerate + 0.1)
|
159 |
+
text_path = os.path.join(self.text_dir, name + ".json")
|
160 |
+
text_data = json.load(open(text_path))
|
161 |
+
for i, anno in enumerate(text_data):
|
162 |
+
start_frame = anno['start_frame'] // downrate
|
163 |
+
end_frame = min(anno['end_frame'], meta_data['num_frames']) // downrate
|
164 |
+
num_frame = end_frame - start_frame
|
165 |
+
if num_frame < self.tgt_min_motion_length or num_frame > self.tgt_max_motion_length:
|
166 |
+
continue
|
167 |
+
if len(anno['body_text']) > 0:
|
168 |
+
self.name_list.append(name)
|
169 |
+
self.text_idx.append(i)
|
170 |
+
|
171 |
+
def load_annotations_v2m(self):
|
172 |
+
if not self.test_mode:
|
173 |
+
self.name_list = []
|
174 |
+
for name in open(self.anno_file):
|
175 |
+
name = name.strip()
|
176 |
+
self.name_list.append(name)
|
177 |
+
else:
|
178 |
+
self.name_list = []
|
179 |
+
self.start_frame = []
|
180 |
+
self.end_frame = []
|
181 |
+
self.valid_start_frame = []
|
182 |
+
self.valid_end_frame = []
|
183 |
+
for name in open(self.anno_file):
|
184 |
+
name = name.strip()
|
185 |
+
meta_path = os.path.join(self.meta_dir, name + ".json")
|
186 |
+
meta_data = json.load(open(meta_path))
|
187 |
+
num_frames = meta_data['num_frames']
|
188 |
+
assert num_frames >= self.v2m_window_size
|
189 |
+
cur_idx = 0
|
190 |
+
while cur_idx < num_frames:
|
191 |
+
if cur_idx + self.v2m_window_size < num_frames:
|
192 |
+
self.name_list.append(name)
|
193 |
+
self.start_frame.append(cur_idx)
|
194 |
+
self.end_frame.append(cur_idx + self.v2m_window_size)
|
195 |
+
self.valid_start_frame.append(cur_idx)
|
196 |
+
self.valid_end_frame.append(cur_idx + self.v2m_window_size)
|
197 |
+
cur_idx += self.v2m_window_size
|
198 |
+
else:
|
199 |
+
self.name_list.append(name)
|
200 |
+
self.start_frame.append(num_frames - self.v2m_window_size)
|
201 |
+
self.end_frame.append(num_frames)
|
202 |
+
self.valid_start_frame.append(cur_idx)
|
203 |
+
self.valid_end_frame.append(num_frames)
|
204 |
+
break
|
205 |
+
|
206 |
+
def load_annotations_s2g(self):
|
207 |
+
self.name_list = []
|
208 |
+
self.speech_idx = []
|
209 |
+
for name in open(self.anno_file):
|
210 |
+
name = name.strip()
|
211 |
+
meta_path = os.path.join(self.meta_dir, name + ".json")
|
212 |
+
meta_data = json.load(open(meta_path))
|
213 |
+
downrate = int(meta_data['framerate'] / self.target_framerate + 0.1)
|
214 |
+
speech_path = os.path.join(self.speech_dir, name + ".json")
|
215 |
+
speech_data = json.load(open(speech_path))
|
216 |
+
for i, anno in enumerate(speech_data):
|
217 |
+
start_frame = anno['start_frame'] // downrate
|
218 |
+
end_frame = min(anno['end_frame'], meta_data['num_frames']) // downrate
|
219 |
+
num_frame = end_frame - start_frame
|
220 |
+
if num_frame < self.tgt_min_motion_length or num_frame > self.tgt_max_motion_length:
|
221 |
+
continue
|
222 |
+
self.name_list.append(name)
|
223 |
+
self.speech_idx.append(i)
|
224 |
+
|
225 |
+
def load_annotations_m2d(self):
|
226 |
+
self.name_list = []
|
227 |
+
self.music_idx = []
|
228 |
+
for name in open(self.anno_file):
|
229 |
+
name = name.strip()
|
230 |
+
meta_path = os.path.join(self.meta_dir, name + ".json")
|
231 |
+
meta_data = json.load(open(meta_path))
|
232 |
+
downrate = int(meta_data['framerate'] / self.target_framerate + 0.1)
|
233 |
+
music_path = os.path.join(self.music_dir, name + ".json")
|
234 |
+
music_data = json.load(open(music_path))
|
235 |
+
for i, anno in enumerate(music_data):
|
236 |
+
start_frame = anno['start_frame'] // downrate
|
237 |
+
end_frame = min(anno['end_frame'], meta_data['num_frames']) // downrate
|
238 |
+
num_frame = end_frame - start_frame
|
239 |
+
if num_frame < self.tgt_min_motion_length or num_frame > self.tgt_max_motion_length:
|
240 |
+
continue
|
241 |
+
self.name_list.append(name)
|
242 |
+
self.music_idx.append(i)
|
243 |
+
|
244 |
+
def prepare_data_base(self, idx: int) -> dict:
|
245 |
+
results = {}
|
246 |
+
name = self.name_list[idx]
|
247 |
+
results['motion_path'] = os.path.join(self.motion_dir, name + ".npz")
|
248 |
+
meta_path = os.path.join(self.meta_dir, name + ".json")
|
249 |
+
meta_data = json.load(open(meta_path))
|
250 |
+
meta_data['dataset_name'] = self.dataset_name
|
251 |
+
results['meta_data'] = meta_data
|
252 |
+
results['meta_data']['sample_idx'] = idx
|
253 |
+
results.update({
|
254 |
+
'text_word_feat': np.zeros((77, 1024)).astype(np.float32),
|
255 |
+
'text_seq_feat': np.zeros((1024)).astype(np.float32),
|
256 |
+
'text_cond': 0,
|
257 |
+
'music_word_feat': np.zeros((229, 768)).astype(np.float32),
|
258 |
+
'music_seq_feat': np.zeros((1024)).astype(np.float32),
|
259 |
+
'music_cond': 0,
|
260 |
+
'speech_word_feat': np.zeros((229, 768)).astype(np.float32),
|
261 |
+
'speech_seq_feat': np.zeros((1024)).astype(np.float32),
|
262 |
+
'speech_cond': 0,
|
263 |
+
'video_seq_feat': np.zeros((1024)).astype(np.float32),
|
264 |
+
'video_cond': 0,
|
265 |
+
})
|
266 |
+
return results
|
267 |
+
|
268 |
+
def prepare_data(self, idx: int) -> dict:
|
269 |
+
if self.task_name == 'mocap':
|
270 |
+
results = self.prepare_data_mocap(idx)
|
271 |
+
elif self.task_name == 't2m':
|
272 |
+
results = self.prepare_data_t2m(idx)
|
273 |
+
elif self.task_name == 'v2m':
|
274 |
+
results = self.prepare_data_v2m(idx)
|
275 |
+
elif self.task_name == 's2g':
|
276 |
+
results = self.prepare_data_s2g(idx)
|
277 |
+
elif self.task_name == 'm2d':
|
278 |
+
results = self.prepare_data_m2d(idx)
|
279 |
+
else:
|
280 |
+
raise NotImplementedError()
|
281 |
+
results = self.pipeline(results)
|
282 |
+
return results
|
283 |
+
|
284 |
+
def prepare_data_mocap(self, idx: int) -> dict:
|
285 |
+
results = self.prepare_data_base(idx)
|
286 |
+
if self.test_mode:
|
287 |
+
results['meta_data']['start_frame'] = self.src_start_frame[idx]
|
288 |
+
results['meta_data']['end_frame'] = self.src_end_frame[idx]
|
289 |
+
results['context_mask'] = np.concatenate(
|
290 |
+
(np.ones((self.mp_input_length - 1)), np.zeros((self.mp_output_length))),
|
291 |
+
axis=-1
|
292 |
+
)
|
293 |
+
return results
|
294 |
+
|
295 |
+
def prepare_data_t2m(self, idx: int) -> dict:
|
296 |
+
results = self.prepare_data_base(idx)
|
297 |
+
name = self.name_list[idx]
|
298 |
+
text_idx = self.text_idx[idx]
|
299 |
+
text_path = os.path.join(self.text_dir, name + ".json")
|
300 |
+
text_data = json.load(open(text_path))[text_idx]
|
301 |
+
text_feat_path = os.path.join(self.text_feat_dir, name + ".pkl")
|
302 |
+
text_feat_data = pkl.load(open(text_feat_path, "rb"))
|
303 |
+
text_list = text_data['body_text']
|
304 |
+
tid = np.random.randint(len(text_list))
|
305 |
+
text = text_list[tid]
|
306 |
+
text_word_feat = text_feat_data['text_word_feats'][text_idx][tid]
|
307 |
+
text_seq_feat = text_feat_data['text_seq_feats'][text_idx][tid]
|
308 |
+
assert text_word_feat.shape[0] == 77
|
309 |
+
assert text_word_feat.shape[1] == 1024
|
310 |
+
assert text_seq_feat.shape[0] == 1024
|
311 |
+
|
312 |
+
if self.test_mode:
|
313 |
+
motion_path = os.path.join(self.eval_motion_dir, name + ".npy")
|
314 |
+
motion_data = np.load(motion_path)
|
315 |
+
assert not np.isnan(motion_data).any()
|
316 |
+
downrate = int(results['meta_data']['framerate'] / self.target_framerate + 0.1)
|
317 |
+
start_frame = text_data['start_frame'] // downrate
|
318 |
+
end_frame = text_data['end_frame'] // downrate
|
319 |
+
motion_data = motion_data[start_frame: end_frame]
|
320 |
+
results['meta_data']['framerate'] = self.target_framerate
|
321 |
+
results['meta_data']['rotation_type'] = self.test_rotation_type
|
322 |
+
assert motion_data.shape[0] > 0
|
323 |
+
if 'body_tokens' in text_data:
|
324 |
+
token = text_data['body_tokens'][tid]
|
325 |
+
else:
|
326 |
+
token = ""
|
327 |
+
text_cond = 1
|
328 |
+
results.update({
|
329 |
+
'motion': motion_data,
|
330 |
+
'text_word_feat': text_word_feat,
|
331 |
+
'text_seq_feat': text_seq_feat,
|
332 |
+
'text_cond': text_cond,
|
333 |
+
'text': text,
|
334 |
+
'token': token
|
335 |
+
})
|
336 |
+
else:
|
337 |
+
results['meta_data']['start_frame'] = text_data['start_frame']
|
338 |
+
results['meta_data']['end_frame'] = text_data['end_frame']
|
339 |
+
text_cond = 1
|
340 |
+
results.update({
|
341 |
+
'text_word_feat': text_word_feat,
|
342 |
+
'text_seq_feat': text_seq_feat,
|
343 |
+
'text_cond': text_cond
|
344 |
+
})
|
345 |
+
return results
|
346 |
+
|
347 |
+
def prepare_data_v2m(self, idx: int) -> dict:
|
348 |
+
results = self.prepare_data_base(idx)
|
349 |
+
name = self.name_list[idx]
|
350 |
+
video_feat_path = os.path.join(self.video_feat_dir, name + ".pkl")
|
351 |
+
video_feat_data = pkl.load(open(video_feat_path, "rb"))
|
352 |
+
video_word_feat = video_feat_data['video_word_feats']
|
353 |
+
video_seq_feat = video_feat_data['video_seq_feats']
|
354 |
+
assert video_word_feat.shape[0] == results['meta_data']['num_frames']
|
355 |
+
assert video_word_feat.shape[1] == 1024
|
356 |
+
assert video_seq_feat.shape[0] == 1024
|
357 |
+
video_cond = 1
|
358 |
+
if self.test_mode:
|
359 |
+
results['meta_data']['start_frame'] = self.start_frame[idx]
|
360 |
+
results['meta_data']['end_frame'] = self.end_frame[idx]
|
361 |
+
motion_path = os.path.join(self.eval_motion_dir, name + ".npy")
|
362 |
+
motion_data = np.load(motion_path)
|
363 |
+
assert not np.isnan(motion_data).any()
|
364 |
+
|
365 |
+
start_frame = self.start_frame[idx]
|
366 |
+
end_frame = self.end_frame[idx]
|
367 |
+
motion_data = motion_data[start_frame: end_frame]
|
368 |
+
video_word_feat = video_word_feat[start_frame: end_frame]
|
369 |
+
results['meta_data']['framerate'] = self.target_framerate
|
370 |
+
results['meta_data']['rotation_type'] = self.test_rotation_type
|
371 |
+
assert motion_data.shape[0] > 0
|
372 |
+
results.update({
|
373 |
+
'motion': motion_data,
|
374 |
+
'video_word_feat': video_word_feat,
|
375 |
+
'video_seq_feat': video_seq_feat,
|
376 |
+
'video_cond': video_cond
|
377 |
+
})
|
378 |
+
else:
|
379 |
+
results.update({
|
380 |
+
'video_word_feat': video_word_feat,
|
381 |
+
'video_seq_feat': video_seq_feat,
|
382 |
+
'video_cond': video_cond
|
383 |
+
})
|
384 |
+
return results
|
385 |
+
|
386 |
+
def prepare_data_s2g(self, idx: int) -> dict:
|
387 |
+
results = self.prepare_data_base(idx)
|
388 |
+
name = self.name_list[idx]
|
389 |
+
speech_idx = self.speech_idx[idx]
|
390 |
+
speech_path = os.path.join(self.speech_dir, name + ".json")
|
391 |
+
speech_data = json.load(open(speech_path))[speech_idx]
|
392 |
+
speech_feat_path = os.path.join(self.speech_feat_dir, name + ".pkl")
|
393 |
+
speech_feat_data = pkl.load(open(speech_feat_path, "rb"))
|
394 |
+
try:
|
395 |
+
speech_word_feat = speech_feat_data['audio_word_feats'][speech_idx]
|
396 |
+
speech_seq_feat = speech_feat_data['audio_seq_feats'][speech_idx]
|
397 |
+
except:
|
398 |
+
speech_word_feat = speech_feat_data['speech_word_feats'][speech_idx]
|
399 |
+
speech_seq_feat = speech_feat_data['speech_seq_feats'][speech_idx]
|
400 |
+
del speech_feat_data
|
401 |
+
assert speech_word_feat.shape[0] == 229
|
402 |
+
assert speech_word_feat.shape[1] == 768
|
403 |
+
assert speech_seq_feat.shape[0] == 1024
|
404 |
+
|
405 |
+
results['meta_data']['start_frame'] = speech_data['start_frame']
|
406 |
+
results['meta_data']['end_frame'] = speech_data['end_frame']
|
407 |
+
speech_cond = 1
|
408 |
+
results.update({
|
409 |
+
'speech_word_feat': speech_word_feat,
|
410 |
+
'speech_seq_feat': speech_seq_feat,
|
411 |
+
'speech_cond': speech_cond
|
412 |
+
})
|
413 |
+
if self.test_mode:
|
414 |
+
results['meta_data']['framerate'] = self.target_framerate
|
415 |
+
results['meta_data']['rotation_type'] = self.test_rotation_type
|
416 |
+
eval_data_path = os.path.join(self.eval_motion_dir, name + ".npz")
|
417 |
+
eval_data = np.load(eval_data_path)
|
418 |
+
motion_data = eval_data["bvh_rot_beat141"]
|
419 |
+
sem_data = eval_data["sem"]
|
420 |
+
wav_data = eval_data["wave16k"]
|
421 |
+
assert not np.isnan(motion_data).any()
|
422 |
+
|
423 |
+
start_frame = results['meta_data']['start_frame']
|
424 |
+
end_frame = results['meta_data']['end_frame']
|
425 |
+
wav_start_frame = start_frame / results['meta_data']['framerate'] * 16000
|
426 |
+
wav_end_frame = end_frame / results['meta_data']['framerate'] * 16000
|
427 |
+
motion_data = motion_data[start_frame: end_frame]
|
428 |
+
sem_data = sem_data[start_frame: end_frame]
|
429 |
+
wav_data = wav_data[wav_start_frame: wav_end_frame]
|
430 |
+
assert motion_data.shape[0] > 0
|
431 |
+
results.update({
|
432 |
+
'motion': motion_data,
|
433 |
+
'sem_score': sem_data,
|
434 |
+
'wav_feat': wav_data
|
435 |
+
})
|
436 |
+
return results
|
437 |
+
|
438 |
+
    def prepare_data_m2d(self, idx: int) -> dict:
        results = self.prepare_data_base(idx)
        name = self.name_list[idx]
        music_idx = self.music_idx[idx]
        music_path = os.path.join(self.music_dir, name + ".json")
        music_data = json.load(open(music_path))[music_idx]
        music_feat_path = os.path.join(self.music_feat_dir, name + ".pkl")
        music_feat_data = pkl.load(open(music_feat_path, "rb"))
        music_word_feat = music_feat_data['audio_word_feats'][music_idx]
        music_seq_feat = music_feat_data['audio_seq_feats'][music_idx]
        assert music_word_feat.shape[0] == 229
        assert music_word_feat.shape[1] == 768
        assert music_seq_feat.shape[0] == 1024

        results['meta_data']['start_frame'] = music_data['start_frame']
        results['meta_data']['end_frame'] = music_data['end_frame']
        music_cond = 1
        results.update({
            'music_word_feat': music_word_feat,
            'music_seq_feat': music_seq_feat,
            'music_cond': music_cond
        })
        return results

    def prepare_evaluation(self):
        """
        Prepare the dataset for evaluation by initializing evaluators and
        creating evaluation indexes.
        """
        self.evaluators = []
        self.eval_indexes = []
        self.evaluator_model = build_submodule(self.eval_cfg.get('evaluator_model', None))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if self.evaluator_model is not None:
            self.evaluator_model = self.evaluator_model.to(device)
            self.evaluator_model.eval()
        self.eval_cfg['evaluator_model'] = self.evaluator_model

        for _ in range(self.eval_cfg['replication_times']):
            eval_indexes = np.arange(len(self.name_list))
            if self.eval_cfg.get('shuffle_indexes', False):
                np.random.shuffle(eval_indexes)
            self.eval_indexes.append(eval_indexes)

        for metric in self.eval_cfg['metrics']:
            evaluator, self.eval_indexes = build_evaluator(
                metric, self.eval_cfg, len(self.name_list), self.eval_indexes)
            self.evaluators.append(evaluator)

        self.eval_indexes = np.concatenate(self.eval_indexes)

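    # Illustrative sketch only (not part of the original file): prepare_evaluation
    # reads these keys from ``self.eval_cfg``; a minimal, hypothetical config could be
    #
    #     eval_cfg = dict(
    #         replication_times=1,
    #         shuffle_indexes=False,
    #         metrics=[dict(type='FID')],     # hypothetical metric config
    #         evaluator_model=None,
    #     )
    #
    # Each metric dict is handed to build_evaluator together with the dataset size
    # and the replicated index lists created above.
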
    def process_outputs(self, results):
        return results

    def evaluate(self, results: List[dict], work_dir: str, logger=None) -> dict:
        """
        Evaluate the model performance based on the results.

        Args:
            results (list): A list of result dictionaries.
            work_dir (str): Directory where evaluation logs will be stored.
            logger: Logger object to record evaluation results (optional).

        Returns:
            dict: Dictionary containing evaluation metrics.
        """
        metrics = {}
        results = self.process_outputs(results)
        for evaluator in self.evaluators:
            metrics.update(evaluator.evaluate(results))
        if logger is not None:
            logger.info(metrics)
        eval_output = os.path.join(work_dir, 'eval_results.log')
        with open(eval_output, 'w') as f:
            for k, v in metrics.items():
                f.write(k + ': ' + str(v) + '\n')
        return metrics


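# Illustrative usage sketch (not part of the original file), assuming ``ds`` is a
# SingleMotionVerseDataset in test mode and ``outputs`` is the list of per-sample
# result dicts produced by the test loop:
#
#     ds.prepare_evaluation()
#     metrics = ds.evaluate(outputs, work_dir='work_dirs/eval', logger=None)
#     # ``metrics`` is a plain dict; it is also written to work_dirs/eval/eval_results.log

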
def create_single_dataset(cfg: dict):
    dataset_path = cfg['dataset_path']
    if dataset_path == 'amass':
        return MotionVerseAMASS(**cfg)
    elif dataset_path == 'humanml3d':
        return MotionVerseH3D(**cfg)
    elif dataset_path == 'kitml':
        return MotionVerseKIT(**cfg)
    elif dataset_path == 'babel':
        return MotionVerseBABEL(**cfg)
    elif dataset_path == 'motionx':
        return MotionVerseMotionX(**cfg)
    elif dataset_path == 'humanact12':
        return MotionVerseACT12(**cfg)
    elif dataset_path == 'uestc':
        return MotionVerseUESTC(**cfg)
    elif dataset_path == 'ntu':
        return MotionVerseNTU(**cfg)
    elif dataset_path == 'h36m':
        return MotionVerseH36M(**cfg)
    elif dataset_path == 'mpi':
        return MotionVerseMPI(**cfg)
    elif dataset_path == 'pw3d':
        return MotionVersePW3D(**cfg)
    elif dataset_path == 'aist':
        return MotionVerseAIST(**cfg)
    elif dataset_path == 'beat':
        return MotionVerseBEAT(**cfg)
    elif dataset_path == 'tedg':
        return MotionVerseTEDG(**cfg)
    elif dataset_path == 'tedex':
        return MotionVerseTEDEx(**cfg)
    elif dataset_path == 's2g3d':
        return MotionVerseS2G3D(**cfg)
    else:
        raise NotImplementedError()


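# Illustrative sketch only (not part of the original file): ``create_single_dataset``
# dispatches on ``cfg['dataset_path']``; a hypothetical text-to-motion config could be
#
#     cfg = dict(
#         dataset_path='humanml3d',
#         task_name='t2m',
#         data_prefix='data/motionverse',
#         # ... remaining SingleMotionVerseDataset kwargs ...
#     )
#     dataset = create_single_dataset(cfg)   # returns a MotionVerseH3D instance

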
@DATASETS.register_module()
class MotionVerse(Dataset):
    """
    A dataset class that handles multiple MotionBench datasets.

    Args:
        dataset_cfgs (list[dict]): List of dataset configurations.
        partitions (list[float]): List of partition weights corresponding to the datasets.
        num_data (Optional[int]): Number of data samples to load. Defaults to None.
        data_prefix (str): Path to the directory containing the dataset.
    """

    def __init__(self,
                 dataset_cfgs: List[dict],
                 partitions: List[int],
                 num_data: Optional[int] = None,
                 data_prefix: Optional[str] = None):
        """Load data from multiple datasets."""
        assert min(partitions) >= 0
        assert len(dataset_cfgs) == len(partitions)
        datasets = []
        new_partitions = []
        for idx, cfg in enumerate(dataset_cfgs):
            if partitions[idx] == 0:
                continue
            new_partitions.append(partitions[idx])
            cfg.update({
                'data_prefix': data_prefix
            })
            datasets.append(create_single_dataset(cfg))
        self.dataset = ConcatDataset(datasets)
        if num_data is not None:
            self.length = num_data
        else:
            self.length = max(len(ds) for ds in datasets)
        partitions = new_partitions
        weights = [np.ones(len(ds)) * p / len(ds) for (p, ds) in zip(partitions, datasets)]
        weights = np.concatenate(weights, axis=0)
        self.weights = weights
        self.task_proj = {
            'mocap': 0,
            't2m': 1,
            'v2m': 2,
            's2g': 3,
            'm2d': 4
        }
        self.task_idx_list = []
        for ds in datasets:
            self.task_idx_list += [self.task_proj[ds.task_name]] * len(ds)

    def __len__(self) -> int:
        """Get the size of the dataset."""
        return self.length

    def __getitem__(self, idx: int) -> dict:
        """Given an index, sample data from multiple datasets with the specified proportion."""
        return self.dataset[idx]

    def get_task_idx(self, idx: int) -> int:
        return self.task_idx_list[idx]


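# Illustrative sketch only (not part of the original file): ``MotionVerse.weights``
# holds one sampling weight per concatenated sample, so a proportion-respecting
# loader can be built with a standard PyTorch weighted sampler; ``cfgs`` below is
# a hypothetical list of per-dataset configs.
#
#     from torch.utils.data import DataLoader, WeightedRandomSampler
#
#     mixed = MotionVerse(dataset_cfgs=cfgs, partitions=[0.5, 0.3, 0.2],
#                         data_prefix='data/motionverse')
#     sampler = WeightedRandomSampler(mixed.weights, num_samples=len(mixed),
#                                     replacement=True)
#     loader = DataLoader(mixed, batch_size=32, sampler=sampler)

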
@DATASETS.register_module()
class MotionVerseEval(Dataset):

    def __init__(self,
                 eval_cfgs: dict,
                 testset: str,
                 test_mode: bool = True):
        """Build the single evaluation dataset selected by ``testset``."""
        assert testset in eval_cfgs
        dataset_path, task_name = testset.split('_')
        dataset_cfg = eval_cfgs[testset]
        dataset_cfg['dataset_path'] = dataset_path
        dataset_cfg['task_name'] = task_name
        dataset_cfg['test_mode'] = test_mode
        self.dataset = create_single_dataset(dataset_cfg)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict:
        return self.dataset[idx]

    def load_annotation(self):
        self.dataset.load_annotation()

    def prepare_data(self, idx: int) -> dict:
        return self.dataset.prepare_data(idx)

    def prepare_evaluation(self):
        self.dataset.prepare_evaluation()

    def process_outputs(self, results):
        return self.dataset.process_outputs(results)

    def evaluate(self, results: List[dict], work_dir: str, logger=None) -> dict:
        return self.dataset.evaluate(results=results, work_dir=work_dir, logger=logger)


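# Illustrative sketch only (not part of the original file): MotionVerseEval keys
# follow the '<dataset_path>_<task_name>' convention used by ``testset.split('_')``,
# e.g. with a hypothetical config
#
#     eval_cfgs = {'humanml3d_t2m': dict(...), 'aist_m2d': dict(...)}
#     eval_ds = MotionVerseEval(eval_cfgs=eval_cfgs, testset='humanml3d_t2m')
#
# the key is split into dataset_path='humanml3d' and task_name='t2m' before calling
# create_single_dataset.

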
@DATASETS.register_module()
class MotionVerseAMASS(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'amass'
        task_name = kwargs['task_name']
        assert task_name in ['mocap']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseH3D(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'humanml3d'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 't2m']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseKIT(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'kitml'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 't2m']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseBABEL(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'babel'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 't2m']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseMotionX(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'motionx'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 't2m']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseACT12(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'humanact12'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 't2m']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseUESTC(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'uestc'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 't2m']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseNTU(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'ntu'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 't2m']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseH36M(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'h36m'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 'v2m']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseMPI(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'mpi'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 'v2m']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVersePW3D(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = '3dpw'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 'v2m']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseAIST(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'aist'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 'm2d']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseBEAT(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'beat'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 's2g']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseTEDG(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'tedg'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 's2g']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseTEDEx(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 'tedex'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 's2g']
        super().__init__(**kwargs)


@DATASETS.register_module()
class MotionVerseS2G3D(SingleMotionVerseDataset):

    def __init__(self, **kwargs):
        if 'dataset_path' not in kwargs:
            kwargs['dataset_path'] = 's2g3d'
        task_name = kwargs['task_name']
        assert task_name in ['mocap', 's2g']
        super().__init__(**kwargs)

mogen/datasets/paramUtil.py
ADDED
@@ -0,0 +1,140 @@
# coding=utf-8
# Copyright 2022 The IDEA Authors (Shunlin Lu and Ling-Hao Chen). All rights reserved.
#
# For all the datasets, be sure to read and follow their license agreements,
# and cite them accordingly.
# If the unifier is used in your research, please consider to cite as:
#
# @article{humantomato,
#   title={HumanTOMATO: Text-aligned Whole-body Motion Generation},
#   author={Lu, Shunlin and Chen, Ling-Hao and Zeng, Ailing and Lin, Jing and Zhang, Ruimao and Zhang, Lei and Shum, Heung-Yeung},
#   journal={arxiv:2310.12978},
#   year={2023}
# }
#
# @InProceedings{Guo_2022_CVPR,
#   author = {Guo, Chuan and Zou, Shihao and Zuo, Xinxin and Wang, Sen and Ji, Wei and Li, Xingyu and Cheng, Li},
#   title = {Generating Diverse and Natural 3D Human Motions From Text},
#   booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
#   month = {June},
#   year = {2022},
#   pages = {5152-5161}
# }
#
# Licensed under the IDEA License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://github.com/IDEA-Research/HumanTOMATO/blob/main/LICENSE
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. We provide a license to use the code,
# please read the specific details carefully.
#
# ------------------------------------------------------------------------------------------------
# Copyright (c) Chuan Guo.
# ------------------------------------------------------------------------------------------------
# Portions of this code were adapted from the following open-source project:
# https://github.com/EricGuo5513/HumanML3D
# ------------------------------------------------------------------------------------------------


import numpy as np

# Define a kinematic tree for the skeletal structure
kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]

kit_raw_offsets = np.array(
    [
        [0, 0, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [-1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [0, 0, 1],
        [0, 0, 1],
        [-1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [0, 0, 1],
        [0, 0, 1]
    ]
)

t2m_raw_body_offsets = np.array([[0, 0, 0],
                                 [1, 0, 0],
                                 [-1, 0, 0],
                                 [0, 1, 0],
                                 [0, -1, 0],
                                 [0, -1, 0],
                                 [0, 1, 0],
                                 [0, -1, 0],
                                 [0, -1, 0],
                                 [0, 1, 0],
                                 [0, 0, 1],
                                 [0, 0, 1],
                                 [0, 1, 0],
                                 [1, 0, 0],
                                 [-1, 0, 0],
                                 [0, 0, 1],
                                 [0, -1, 0],
                                 [0, -1, 0],
                                 [0, -1, 0],
                                 [0, -1, 0],
                                 [0, -1, 0],
                                 [0, -1, 0]])

t2m_raw_hand_offsets = np.array([[1, 0, 0],    # left_index1
                                 [1, 0, 0],    # left_index2
                                 [1, 0, 0],    # left_index3
                                 [1, 0, 0],    # left_middle1
                                 [1, 0, 0],    # left_middle2
                                 [1, 0, 0],    # left_middle3
                                 [1, 0, 0],    # left_pinky1
                                 [1, 0, 0],    # left_pinky2
                                 [1, 0, 0],    # left_pinky3
                                 [1, 0, 0],    # left_ring1
                                 [1, 0, 0],    # left_ring2
                                 [1, 0, 0],    # left_ring3
                                 [1, 0, 0],    # left_thumb1
                                 [1, 0, 0],    # left_thumb2
                                 [1, 0, 0],    # left_thumb3
                                 [-1, 0, 0],   # right_index1
                                 [-1, 0, 0],   # right_index2
                                 [-1, 0, 0],   # right_index3
                                 [-1, 0, 0],   # right_middle1
                                 [-1, 0, 0],   # right_middle2
                                 [-1, 0, 0],   # right_middle3
                                 [-1, 0, 0],   # right_pinky1
                                 [-1, 0, 0],   # right_pinky2
                                 [-1, 0, 0],   # right_pinky3
                                 [-1, 0, 0],   # right_ring1
                                 [-1, 0, 0],   # right_ring2
                                 [-1, 0, 0],   # right_ring3
                                 [-1, 0, 0],   # right_thumb1
                                 [-1, 0, 0],   # right_thumb2
                                 [-1, 0, 0]])  # right_thumb3

t2m_raw_offsets = np.concatenate(
    (t2m_raw_body_offsets, t2m_raw_hand_offsets), axis=0)
t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]

t2m_body_hand_kinematic_chain = t2m_kinematic_chain + t2m_left_hand_chain + t2m_right_hand_chain

kit_tgt_skel_id = '03950'

t2m_tgt_skel_id = '000021'
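# Illustrative sketch only (not part of the original file): each kinematic chain
# lists joint indices from a root joint outward, so the bones of a skeleton can be
# enumerated by walking consecutive pairs:
#
#     for chain in t2m_kinematic_chain:
#         for parent, child in zip(chain[:-1], chain[1:]):
#             print(parent, '->', child)   # one bone of the T2M body skeleton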
mogen/datasets/pipelines/__init__.py
ADDED
@@ -0,0 +1,30 @@
from .compose import Compose
from .formatting import (
    Collect,
    ToTensor,
    Transpose,
    WrapFieldsToLists,
    to_tensor
)
from .siamese_motion import ProcessSiameseMotion, SwapSiameseMotion
from .transforms import Crop, Normalize, RandomCrop
from .motionverse import (
    LoadMotion,
    RetargetSkeleton,
    MotionDownsample,
    PutOnFloor,
    MoveToOrigin,
    RotateToZ,
    KeypointsToTomato,
    RandomCropKeypoints,
    MaskedCrop,
    MaskedRandomCrop
)

__all__ = [
    'Compose', 'to_tensor', 'Transpose', 'Collect', 'WrapFieldsToLists',
    'ToTensor', 'Crop', 'RandomCrop', 'Normalize', 'SwapSiameseMotion',
    'ProcessSiameseMotion', 'LoadMotion', 'RetargetSkeleton', 'MotionDownsample',
    'PutOnFloor', 'MoveToOrigin', 'RotateToZ', 'KeypointsToTomato', 'RandomCropKeypoints',
    'MaskedCrop', 'MaskedRandomCrop'
]
mogen/datasets/pipelines/compose.py
ADDED
@@ -0,0 +1,42 @@
from collections.abc import Sequence

from mmcv.utils import build_from_cfg

from ..builder import PIPELINES


@PIPELINES.register_module()
class Compose(object):
    """Compose a data pipeline with a sequence of transforms.

    Args:
        transforms (list[dict | callable]):
            Either config dicts of transforms or transform objects.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, Sequence)
        self.transforms = []
        for transform in transforms:
            if isinstance(transform, dict):
                transform = build_from_cfg(transform, PIPELINES)
                self.transforms.append(transform)
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict, but got'
                                f' {type(transform)}')

    def __call__(self, data):
        for t in self.transforms:
            data = t(data)
            if data is None:
                return None
        return data

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += f'\n    {t}'
        format_string += '\n)'
        return format_string
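# Illustrative usage sketch (not part of the original file): Compose accepts either
# registered transform configs or ready-made callables; ToTensor is registered in
# formatting.py, and the input dict below is a hypothetical example.
#
#     pipeline = Compose([
#         dict(type='ToTensor', keys=['motion']),
#         lambda results: results,              # any callable is accepted as-is
#     ])
#     results = pipeline({'motion': np.zeros((196, 263), np.float32)})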
mogen/datasets/pipelines/formatting.py
ADDED
@@ -0,0 +1,135 @@
from collections.abc import Sequence

import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC

from ..builder import PIPELINES


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(
            f'Type {type(data)} cannot be converted to tensor. '
            'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
            '`Sequence`, `int` and `float`')


@PIPELINES.register_module()
class ToTensor(object):

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        for key in self.keys:
            results[key] = to_tensor(results[key])
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'


@PIPELINES.register_module()
class Transpose(object):

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def __call__(self, results):
        for key in self.keys:
            results[key] = results[key].transpose(self.order)
        return results

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, order={self.order})'


@PIPELINES.register_module()
class Collect(object):
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline.

    Args:
        keys (Sequence[str]): Keys of results to be collected in ``data``.
        meta_keys (Sequence[str], optional): Meta keys to be converted to
            ``mmcv.DataContainer`` and collected in ``data[motion_metas]``.
            Default: ``('filename', 'ori_filename',
            'ori_shape', 'motion_shape', 'motion_mask')``

    Returns:
        dict: The result dict contains the following keys
            - keys in ``self.keys``
            - ``motion_metas`` if available
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_filename', 'ori_shape',
                            'motion_shape', 'motion_mask')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        data = {}
        motion_meta = {}
        for key in self.meta_keys:
            if key in results:
                motion_meta[key] = results[key]
        data['motion_metas'] = DC(motion_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, meta_keys={self.meta_keys})'


@PIPELINES.register_module()
class WrapFieldsToLists(object):
    """Wrap fields of the data dictionary into lists for evaluation.

    This class can be used as a last step of a test or validation
    pipeline for single image evaluation or inference.

    Example:
        >>> test_pipeline = [
        >>>     dict(type='LoadImageFromFile'),
        >>>     dict(type='Normalize',
        >>>          mean=[123.675, 116.28, 103.53],
        >>>          std=[58.395, 57.12, 57.375],
        >>>          to_rgb=True),
        >>>     dict(type='ImageToTensor', keys=['img']),
        >>>     dict(type='Collect', keys=['img']),
        >>>     dict(type='WrapFieldsToLists')
        >>> ]
    """

    def __call__(self, results):
        # Wrap dict fields into lists
        for key, val in results.items():
            results[key] = [val]
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}()'
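# Illustrative usage sketch (not part of the original file): a minimal tail of a
# data pipeline built from the transforms above; the key names and shapes are
# assumptions for the example only.
#
#     from .compose import Compose
#
#     pipeline = Compose([
#         dict(type='ToTensor', keys=['motion']),
#         dict(type='Collect', keys=['motion'],
#              meta_keys=('motion_shape', 'motion_mask')),
#     ])
#     data = pipeline({'motion': np.zeros((196, 263), np.float32),
#                      'motion_shape': (196, 263)})
#     # data['motion'] is a torch.Tensor; data['motion_metas'] is a DataContainer.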