tuandunghcmut
committed on
Upload folder using huggingface_hub
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +8 -0
- LICENSE +21 -0
- assets/framework.png +0 -0
- assets/teaser.png +0 -0
- core/__init__.py +0 -0
- core/__pycache__/__init__.cpython-312.pyc +0 -0
- core/__pycache__/comm_.cpython-312.pyc +0 -0
- core/__pycache__/config.cpython-312.pyc +0 -0
- core/__pycache__/distributed_utils.cpython-312.pyc +0 -0
- core/__pycache__/make_param_group.cpython-312.pyc +0 -0
- core/__pycache__/memory.cpython-312.pyc +0 -0
- core/__pycache__/utils.cpython-312.pyc +0 -0
- core/clipping.py +92 -0
- core/comm_.py +307 -0
- core/config.py +600 -0
- core/data/__init__.py +0 -0
- core/data/__pycache__/__init__.cpython-312.pyc +0 -0
- core/data/datasets/__init__.py +15 -0
- core/data/datasets/__pycache__/__init__.cpython-312.pyc +0 -0
- core/data/datasets/images/__pycache__/image_caption_dataset.cpython-312.pyc +0 -0
- core/data/datasets/images/__pycache__/multi_posedataset.cpython-312.pyc +0 -0
- core/data/datasets/images/__pycache__/parsing_dataset.cpython-312.pyc +0 -0
- core/data/datasets/images/__pycache__/pedattr_dataset.cpython-312.pyc +0 -0
- core/data/datasets/images/__pycache__/peddet_dataset_v2.cpython-312.pyc +0 -0
- core/data/datasets/images/__pycache__/pos_dataset_dev.cpython-312.pyc +0 -0
- core/data/datasets/images/__pycache__/seg_dataset_dev.cpython-312.pyc +0 -0
- core/data/datasets/images/__pycache__/smpl_dataset_v2.cpython-312.pyc +0 -0
- core/data/datasets/images/image_caption_dataset.py +261 -0
- core/data/datasets/images/multi_posedataset.py +413 -0
- core/data/datasets/images/parsing_dataset.py +1084 -0
- core/data/datasets/images/pedattr_dataset.py +665 -0
- core/data/datasets/images/peddet_dataset_v2.py +578 -0
- core/data/datasets/images/pos_dataset_dev.py +713 -0
- core/data/datasets/images/resources/CHval.odgt +3 -0
- core/data/datasets/images/resources/COCO_val2017_detections_AP_H_56_person.json +3 -0
- core/data/datasets/images/resources/mpii_gt_val.mat +3 -0
- core/data/datasets/images/resources/test_caltech_heavy_1xnew.odgt +0 -0
- core/data/datasets/images/seg_data_tools/__init__.py +0 -0
- core/data/datasets/images/seg_data_tools/collate.py +143 -0
- core/data/datasets/images/seg_data_tools/cv2_aug_transforms.py +889 -0
- core/data/datasets/images/seg_data_tools/transforms.py +106 -0
- core/data/datasets/images/seg_dataset_dev.py +293 -0
- core/data/datasets/images/smpl_data_tools/__pycache__/_smpl.cpython-312.pyc +0 -0
- core/data/datasets/images/smpl_data_tools/__pycache__/config_smpl.cpython-312.pyc +0 -0
- core/data/datasets/images/smpl_data_tools/__pycache__/image_ops.cpython-312.pyc +0 -0
- core/data/datasets/images/smpl_data_tools/__pycache__/tsv_file.cpython-312.pyc +0 -0
- core/data/datasets/images/smpl_data_tools/_smpl.py +333 -0
- core/data/datasets/images/smpl_data_tools/config_smpl.py +53 -0
- core/data/datasets/images/smpl_data_tools/image_ops.py +230 -0
- core/data/datasets/images/smpl_data_tools/smpl_modeling/data/J_regressor_extra.npy +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+core/data/datasets/images/resources/CHval.odgt filter=lfs diff=lfs merge=lfs -text
+core/data/datasets/images/resources/COCO_val2017_detections_AP_H_56_person.json filter=lfs diff=lfs merge=lfs -text
+core/data/datasets/images/resources/mpii_gt_val.mat filter=lfs diff=lfs merge=lfs -text
+core/solvers/utils/pycocoevalcap/meteor/meteor-1.5.jar filter=lfs diff=lfs merge=lfs -text
+core/solvers/utils/pycocoevalcap/spice/lib/Meteor-1.5.jar filter=lfs diff=lfs merge=lfs -text
+core/solvers/utils/pycocoevalcap/spice/lib/guava-19.0.jar filter=lfs diff=lfs merge=lfs -text
+core/solvers/utils/pycocoevalcap/spice/spice-1.0.jar filter=lfs diff=lfs merge=lfs -text
+core/solvers/utils/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Shanghai AI Laboratory
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
assets/framework.png
ADDED
assets/teaser.png
ADDED
core/__init__.py
ADDED
File without changes
core/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (183 Bytes)
core/__pycache__/comm_.cpython-312.pyc
ADDED
Binary file (11.9 kB)
core/__pycache__/config.cpython-312.pyc
ADDED
Binary file (26 kB)
core/__pycache__/distributed_utils.cpython-312.pyc
ADDED
Binary file (91.8 kB)
core/__pycache__/make_param_group.cpython-312.pyc
ADDED
Binary file (4.17 kB)
core/__pycache__/memory.cpython-312.pyc
ADDED
Binary file (3.77 kB)
core/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (57.9 kB)
core/clipping.py
ADDED
@@ -0,0 +1,92 @@
+import time
+
+import numpy as np
+import torch
+
+from easydict import EasyDict as edict
+
+import warnings
+from math import inf
+from core.utils import sync_print
+
+# return if any inf/nan
+# div norm by loss_scale, for 'real' norm
+# if auto_clipper provided, compute max_norm using auto_clipper
+# else, using given max_norm
+def clip_grad_norm_(parameters, max_norm=1000000, norm_type=2, auto_clipper=None, loss_scale=1.0):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p[1].grad is not None, parameters))
+
+    if len(parameters) == 0: return None
+
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if norm_type == inf:
+        total_norm = max(p.grad.data.abs().max() for _, p in parameters)
+    else:
+        total_norm = 0
+        for name, p in parameters:
+            param_norm = p.grad.data.norm(norm_type)
+            total_norm += param_norm.item() ** norm_type
+
+        total_norm = total_norm ** (1. / norm_type)
+
+    # check inf/nan across all ranks
+    overflow_num = torch.zeros(1)
+    if np.isinf(total_norm) or np.isnan(total_norm):
+        overflow_num[0] = 1
+    torch.distributed.all_reduce(overflow_num)
+
+    if overflow_num > 0:
+        for name, p in parameters:
+            p.grad.data.fill_(float('nan'))
+        sync_print('total_norm is inf({})/nan({}), skip clipping!!!'.format(np.isinf(total_norm), np.isnan(total_norm)))
+        return total_norm
+
+    # rescale the total_norm by loss_scale
+    total_norm /= loss_scale
+
+    # update auto_clipper, compute max_norm
+    if auto_clipper is not None:
+        max_norm = auto_clipper.update(total_norm)
+
+    # do clipping
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        # sync_print('clip_coef: {}'.format(clip_coef))
+        for _, p in parameters:
+            p.grad.data.mul_(clip_coef)
+
+    return total_norm
+
+class ClipMeter(object):
+    def __init__(self, mom=None, thresh=None, min_max=False, mean=False, init=False):
+        self.thresh = thresh
+        self.mom = mom
+        self.min_max = min_max
+        self.mean = mean
+        self.val = 1.0
+        self.init = init
+
+    def get_mean(self):
+        return self.val
+
+    def get_clip_val(self):
+        if self.mean:
+            return self.get_mean()
+        else:
+            return self.get_mean() * (1+self.thresh)
+
+    def update(self, x):
+        if self.init:
+            self.val = x
+            self.init = False
+        mean = self.get_mean()
+        if self.min_max:
+            x = max(min(x, mean*(1+self.thresh)), mean*(1-self.thresh))
+        else:
+            x = min(x, mean*(1+self.thresh))
+
+        self.val = self.mom * self.val + (1-self.mom)*x
+        return self.get_clip_val()
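A minimal usage sketch for the clipper above, assuming the repository root is on PYTHONPATH: ClipMeter keeps a momentum-averaged gradient norm and clip_grad_norm_ accepts it as auto_clipper. The norm sequence below is hypothetical.

from core.clipping import ClipMeter, clip_grad_norm_

# ClipMeter tracks an exponential moving average of observed gradient norms and
# returns an adaptive cap of mean * (1 + thresh) for the next clipping step.
meter = ClipMeter(mom=0.9, thresh=0.2, min_max=True, init=True)
for step, norm in enumerate([5.0, 4.8, 9.0, 4.5, 30.0, 4.2]):   # hypothetical norms
    cap = meter.update(norm)
    print(f"step {step}: observed norm {norm:5.1f} -> clip threshold {cap:.2f}")

# In a training loop the meter would be passed as auto_clipper; that call assumes
# torch.distributed is initialized, because clip_grad_norm_ all-reduces an overflow flag:
#   total_norm = clip_grad_norm_(list(model.named_parameters()), auto_clipper=meter)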
core/comm_.py
ADDED
@@ -0,0 +1,307 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import logging
+import numpy as np
+import pickle
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+
+_CAPTION_GEN_MODE = False
+
+temp_dir = TEMP_DIR = './data/temp'
+IDS = 'IDS'
+image_features = 'image_features'
+text_features = 'text_features'
+
+old_checkpoint = True
+
+"""
+A torch process group which only includes processes that on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size() -> int:
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank() -> int:
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def get_local_rank() -> int:
+    """
+    Returns:
+        The rank of the current process within the local (per-machine) process group.
+    """
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    # assert _LOCAL_PROCESS_GROUP is not None
+    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+    """
+    Returns:
+        The size of the per-machine process group,
+        i.e. the number of processes per machine.
+    """
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+    return get_rank() == 0
+
+
+def synchronize():
+    """
+    Helper function to synchronize (barrier) among all processes when
+    using distributed training
+    """
+    if not dist.is_available():
+        return
+    if not dist.is_initialized():
+        return
+    world_size = dist.get_world_size()
+    if world_size == 1:
+        return
+    dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+    """
+    Return a process group based on gloo backend, containing all the ranks
+    The result is cached.
+    """
+    if dist.get_backend() == "nccl":
+        return dist.new_group(backend="gloo")
+    else:
+        return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+    backend = dist.get_backend(group)
+    assert backend in ["gloo", "nccl"]
+    device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+    buffer = pickle.dumps(data)
+    if len(buffer) > 1024 ** 3:
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+                get_rank(), len(buffer) / (1024 ** 3), device
+            )
+        )
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to(device=device)
+    return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+    """
+    Returns:
+        list[int]: size of the tensor, on each rank
+        Tensor: padded tensor that has the max size
+    """
+    world_size = dist.get_world_size(group=group)
+    assert (
+        world_size >= 1
+    ), "comm.gather/all_gather must be called from ranks within the given group!"
+    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+    size_list = [
+        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+    ]
+    dist.all_gather(size_list, local_size, group=group)
+    size_list = [int(size.item()) for size in size_list]
+
+    max_size = max(size_list)
+
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    if local_size != max_size:
+        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+        tensor = torch.cat((tensor, padding), dim=0)
+    return size_list, tensor
+
+
+def all_gather(data, group=None):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors).
+    Args:
+        data: any picklable object
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group) == 1:
+        return [data]
+
+    tensor = _serialize_to_tensor(data, group)
+
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    tensor_list = [
+        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+    ]
+    dist.all_gather(tensor_list, tensor, group=group)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def gather(data, dst=0, group=None):
+    """
+    Run gather on arbitrary picklable data (not necessarily tensors).
+    Args:
+        data: any picklable object
+        dst (int): destination rank
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+    Returns:
+        list[data]: on dst, a list of data gathered from each rank. Otherwise,
+            an empty list.
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group=group) == 1:
+        return [data]
+    rank = dist.get_rank(group=group)
+
+    tensor = _serialize_to_tensor(data, group)
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+    # receiving Tensor from all ranks
+    if rank == dst:
+        max_size = max(size_list)
+        tensor_list = [
+            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+        ]
+        dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+        data_list = []
+        for size, tensor in zip(size_list, tensor_list):
+            buffer = tensor.cpu().numpy().tobytes()[:size]
+            data_list.append(pickle.loads(buffer))
+        return data_list
+    else:
+        dist.gather(tensor, [], dst=dst, group=group)
+        return []
+
+
+def broadcast_object(data, src=0, group=None):
+    """
+    Run gather on arbitrary picklable data (not necessarily tensors).
+    Args:
+        data: any picklable object
+        dst (int): destination rank
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+    Returns:
+        list[data]: on dst, a list of data gathered from each rank. Otherwise,
+            an empty list.
+    """
+    # if get_world_size() == 1:
+    #     return data
+    # if group is None:
+    #     group = _get_global_gloo_group()
+    # if dist.get_world_size(group=group) == 1:
+    #     return data
+
+    if not isinstance(data, list):
+        data_list = [data]
+        dist.broadcast_object_list(data_list, src=src, group=group)
+        return data_list[0]
+    else:
+        dist.broadcast_object_list(data, src=src, group=group)
+        return data
+    return data
+
+
+def shared_random_seed():
+    """
+    Returns:
+        int: a random number that is the same across all workers.
+        If workers need a shared RNG, they can use this shared seed to
+        create one.
+    All workers must call this function, otherwise it will deadlock.
+    """
+    ints = np.random.randint(2 ** 31)
+    all_ints = all_gather(ints)
+    return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Reduce the values in the dictionary from all processes so that process with rank
+    0 has the reduced results.
+    Args:
+        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+        average (bool): whether to do average or sum
+    Returns:
+        a dict with the same keys as input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.reduce(values, dst=0)
+        if dist.get_rank() == 0 and average:
+            # only main process gets accumulated, so only divide by
+            # world_size in this case
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
+
+
+def unwrap_model(model):
+    return model.module if hasattr(model, 'module') else model
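A short sketch of how these helpers are typically consumed, assuming the repository root is on PYTHONPATH. Every call below degrades to single-process behaviour when torch.distributed is not initialized (get_world_size() returns 1), so the same evaluation code runs unchanged on one GPU or many; the prediction record is hypothetical.

import torch
from core.comm_ import all_gather, reduce_dict, is_main_process, get_world_size

per_rank_predictions = [{"image_id": 1, "caption": "a person riding a bike"}]  # hypothetical results
gathered = all_gather(per_rank_predictions)            # one list entry per rank
losses = reduce_dict({"loss": torch.tensor(0.7)})      # averaged onto rank 0 when world_size > 1

if is_main_process():
    merged = [p for rank_preds in gathered for p in rank_preds]
    print(f"world_size={get_world_size()}: merged {len(merged)} predictions, loss={losses['loss']:.2f}")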
core/config.py
ADDED
@@ -0,0 +1,600 @@
+import yaml
+import logging
+import numpy as np
+from easydict import EasyDict as edict
+import copy
+import re
+import torch.distributed as dist
+
+from .utils import printlog
+from torch.distributed.distributed_c10d import _get_global_rank
+
+
+task_specific_param = ['backbone', 'neck', 'decoder', 'dataset', 'sampler', 'lr_scheduler', 'optimizer',
+                       'extra', 'evaluation', 'model_entry_type', 'load_ignore', 'ckpt_task_id',
+                       'patch_neck','patch_adapter', 'patch_proj', 'label_neck', 'label_adapter', 'label_proj',]
+
+loader = yaml.SafeLoader
+loader.add_implicit_resolver(
+    u'tag:yaml.org,2002:float',
+    re.compile(u'''^(?:
+     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+    |[-+]?\\.(?:inf|Inf|INF)
+    |\\.(?:nan|NaN|NAN))$''', re.X),
+    list(u'-+0123456789.'))
+
+def flat(nums):
+    res = []
+    for i in nums:
+        if isinstance(i, list):
+            res.extend(flat(i))
+        else:
+            res.append(i)
+    return res
+
+def specific_group_split_modality_groups(group_spec, share_backbone_group_ids,
+                                         share_decoder_group_ids, share_rgb_group_ids,
+                                         share_video_group_ids, share_dense_labeling_group_ids,
+                                         share_sparse_labeling_group_ids, share_text_group_ids, share_modality_group_ids=None):
+    ## sanity check
+    assert type(group_spec) is list
+    assert all(map(lambda x: type(x) is int, group_spec))
+
+    num_groups = len(group_spec)
+    splits = np.sum(group_spec)
+
+    if dist.is_initialized():
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+    else:
+        world_size = 1
+        rank = 0
+
+    assert world_size % splits == 0, f"{world_size} % {splits}"
+    unit = int(world_size / splits)
+
+    ## split
+    group_sizes = [x*unit for x in group_spec] # [8,8,8] / [32, 16]
+    groups = []
+    roots = []
+    last = 0
+    task_info = edict()
+    all_ranks = []
+
+    for i,gs in enumerate(group_sizes):
+        ranks = list(map(int, np.arange(last, last+gs))) #[0...8], [9...15], ...
+        groups.append(dist.new_group(ranks=ranks))
+        roots.append(last) # 0, 8, 16
+        all_ranks.append(ranks)
+        if rank in ranks: # if current gpu rank in traversed rank task group
+            printlog(f">> task_info.group[{i}] ranks {ranks}")
+            task_info.group = groups[-1] # subordinate to what group
+            task_info.task_size = gs # 8
+            task_info.task_id = i
+            task_info.task_rank = rank - last
+            task_info.task_root_rank = last
+        last += gs
+    task_info.root_group = dist.new_group(ranks=roots)
+    printlog(f">> task_info.root_group ranks {roots}")
+    task_info.task_sizes = group_sizes
+    task_info.task_root_ranks = roots
+    task_info.task_num = num_groups
+
+    ## share_backbone_group spec
+    if share_backbone_group_ids is not None: # *[0,0,0]*(default) | [0,1,0]task ids
+        # group size must equal within a share_group
+        backboneshareid2idx = {}
+        for idx, this_id in enumerate(share_backbone_group_ids):
+            if this_id not in backboneshareid2idx:
+                backboneshareid2idx[this_id] = list()
+            backboneshareid2idx[this_id].append(idx) # {0: [0,1,2]}| {0: [0,2], 1: [1]}
+
+        ## create backbone share group
+        for idxs in backboneshareid2idx.values(): # idxs = [0, 1, 2]
+            this_group_ranks = flat([all_ranks[i] for i in idxs])
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks)
+            if rank in this_group_ranks:
+                task_info.backbone_share_group = this_share_group
+                printlog(f">> task_info.backbone_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.backbone_group_size = len(backboneshareid2idx)
+                task_info.backbone_task_size = len(backboneshareid2idx) * this_group_size
+                task_info.backbone_task_rank = np.sum(rank < np.array(this_group_ranks))
+
+    ## share_decoder_group spec
+    if share_decoder_group_ids is not None:
+        # group size must equal within a share_group
+        decodershareid2idx = {}
+        for idx, this_id in enumerate(share_decoder_group_ids):
+            if this_id not in decodershareid2idx:
+                decodershareid2idx[this_id] = list()
+            decodershareid2idx[this_id].append(idx)
+
+        ## create decoder share group
+        for idxs in decodershareid2idx.values():
+            this_group_ranks = flat([all_ranks[i] for i in idxs])
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks)
+            if rank in this_group_ranks:
+                task_info.decoder_share_group = this_share_group
+                printlog(f">> task_info.decoder_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.decoder_group_size = len(decodershareid2idx)
+                task_info.decoder_task_size = len(decodershareid2idx) * this_group_size
+                task_info.decoder_task_rank = np.sum(rank < np.array(this_group_ranks))
+
+
+    # Now, only for sparse labeling to deal with the modality sharing problem,
+    # which is not a good solution, but it works.
+    # parameters that have grads in [0,1,2] are in modality share group,
+    # parameters that do not have grads in [3,4] should be set in the task-specific group.
+    if share_modality_group_ids is not None:
+        # group size must equal within a share_group
+        modalityshareid2idx = {}
+        for idx, this_id in enumerate(share_modality_group_ids):
+            # -1 denotes that this modality does not appear in the current task
+            # if this_id == -1:
+            #     continue
+            if this_id not in modalityshareid2idx:
+                modalityshareid2idx[this_id] = list()
+            modalityshareid2idx[this_id].append(idx)
+
+        ## create modality share group
+        for idxs in modalityshareid2idx.values(): # 0: [1,2] 1: [3]
+            this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks) # 2
+            if rank in this_group_ranks:
+                task_info.modality_share_group = this_share_group
+                printlog(f">> task_info.modality_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.modality_group_size = len(modalityshareid2idx)
+
+    if share_rgb_group_ids is not None:
+        # group size must equal within a share_group
+        rgbshareid2idx = {}
+        for idx, this_id in enumerate(share_rgb_group_ids):
+            # -1 denotes that this modality does not appear in the current task
+            # if this_id == -1:
+            #     continue
+            if this_id not in rgbshareid2idx:
+                rgbshareid2idx[this_id] = list()
+            rgbshareid2idx[this_id].append(idx)
+
+        ## create rgb share group
+        for idxs in rgbshareid2idx.values(): # 0: [1,2] 1: [3]
+            this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks) # 2
+            if rank in this_group_ranks:
+                task_info.rgb_share_group = this_share_group
+                printlog(f">> task_info.rgb_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.rgb_group_size = len(rgbshareid2idx)
+                # task_info.rgb_task_size = len(rgbshareid2idx) * this_group_size
+                # task_info.rgb_task_rank = np.sum(rank < np.array(this_group_ranks))
+        # all_group_ranks = flat(rgbshareid2idx.values())
+        # if not len(rgbshareid2idx.values()) or dist.get_rank() not in all_group_ranks:
+        #     task_info.rgb_share_group = None
+
+    if share_dense_labeling_group_ids is not None:
+        # group size must equal within a share_group
+        dense_labelingshareid2idx = {}
+        for idx, this_id in enumerate(share_dense_labeling_group_ids):
+            # -1 denotes that this modality does not appear in the current task
+            # if this_id == -1:
+            #     continue
+            if this_id not in dense_labelingshareid2idx:
+                dense_labelingshareid2idx[this_id] = list()
+            dense_labelingshareid2idx[this_id].append(idx)
+
+        ## create dense share group
+        for idxs in dense_labelingshareid2idx.values(): # 0: [1,2] 1: [3]
+            this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks) # 2
+            if rank in this_group_ranks:
+                task_info.dense_labeling_share_group = this_share_group
+                printlog(f">> task_info.dense_labeling_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.dense_labeling_group_size = len(dense_labelingshareid2idx)
+
+
+    if share_sparse_labeling_group_ids is not None:
+        # group size must equal within a share_group
+        sparse_labelingshareid2idx = {}
+        for idx, this_id in enumerate(share_sparse_labeling_group_ids):
+            # -1 denotes that this modality does not appear in the current task
+            # if this_id == -1:
+            #     continue
+            if this_id not in sparse_labelingshareid2idx:
+                sparse_labelingshareid2idx[this_id] = list()
+            sparse_labelingshareid2idx[this_id].append(idx)
+
+        ## create sparse share group
+        for idxs in sparse_labelingshareid2idx.values(): # 0: [1,2] 1: [3]
+            this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks) # 2
+            if rank in this_group_ranks:
+                task_info.sparse_labeling_share_group = this_share_group
+                printlog(f">> task_info.sparse_labeling_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.sparse_labeling_group_size = len(sparse_labelingshareid2idx)
+
+
+    if share_text_group_ids is not None:
+        # group size must equal within a share_group
+        textshareid2idx = {}
+        for idx, this_id in enumerate(share_text_group_ids):
+            # -1 denotes that this modality does not appear in the current task
+            if this_id not in textshareid2idx:
+                textshareid2idx[this_id] = list()
+            textshareid2idx[this_id].append(idx)
+
+        ## create text share group
+        for idxs in textshareid2idx.values(): # 0: [1,2] 1: [3]
+            this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks) # 2
+            if rank in this_group_ranks:
+                task_info.text_share_group = this_share_group
+                printlog(f">> task_info.text_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.text_group_size = len(textshareid2idx)
+
+
+    if share_video_group_ids is not None:
+        # group size must equal within a share_group
+        videoshareid2idx = {}
+        for idx, this_id in enumerate(share_video_group_ids):
+            # -1 denotes that this modality does not appear in the current task
+            # if this_id == -1:
+            #     continue
+            if this_id not in videoshareid2idx:
+                videoshareid2idx[this_id] = list()
+            videoshareid2idx[this_id].append(idx)
+
+        ## create video share group
+        for idxs in videoshareid2idx.values(): # 0: [1,2] 1: [3]
+            this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks) # 2
+            if rank in this_group_ranks:
+                task_info.video_share_group = this_share_group
+                printlog(f">> task_info.video_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.video_group_size = len(videoshareid2idx)
+
+    return task_info
+
+def specific_group_split(group_spec, share_backbone_group_ids, \
+                         share_neck_group_ids, share_decoder_group_ids, share_adapter_group_ids):
+    ## sanity check
+    assert type(group_spec) is list
+    assert all(map(lambda x: type(x) is int, group_spec))
+
+    num_groups = len(group_spec)
+    splits = np.sum(group_spec)
+
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
+
+    assert world_size % splits == 0, f"{world_size} % {splits}"
+    unit = int(world_size / splits)
+
+    ## split
+    group_sizes = [x*unit for x in group_spec] # [8,8,8] / [32, 16]
+    groups = []
+    roots = []
+    last = 0
+    task_info = edict()
+    all_ranks = []
+    # import pdb;
+    # pdb.set_trace()
+    for i,gs in enumerate(group_sizes):
+        ranks = list(map(int, np.arange(last, last+gs))) #[0...8], [9...15], ...
+        groups.append(dist.new_group(ranks=ranks))
+        roots.append(last) # 0, 8, 16
+        all_ranks.append(ranks)
+        if rank in ranks: # if current gpu rank in traversed rank task group
+            printlog(f">> task_info.group[{i}] ranks {ranks}")
+            task_info.group = groups[-1] # subordinate to what group
+            task_info.task_size = gs # 8
+            task_info.task_id = i
+            task_info.task_rank = rank - last
+            task_info.task_root_rank = last
+        last += gs
+    task_info.root_group = dist.new_group(ranks=roots)
+    printlog(f">> task_info.root_group ranks {roots}")
+    task_info.task_sizes = group_sizes
+    task_info.task_root_ranks = roots
+    task_info.task_num = num_groups
+    # pdb.set_trace()
+    ## share_backbone_group spec
+    if share_backbone_group_ids is not None: # *[0,0,0]*(default) | [0,1,0]task ids
+        # group size must equal within a share_group
+        backboneshareid2idx = {}
+        for idx, this_id in enumerate(share_backbone_group_ids):
+            if this_id not in backboneshareid2idx:
+                backboneshareid2idx[this_id] = list()
+            backboneshareid2idx[this_id].append(idx) # {0: [0,1,2]}| {0: [0,2], 1: [1]}
+
+        ## create backbone share group
+        for idxs in backboneshareid2idx.values(): # idxs = [0, 1, 2]
+            this_group_ranks = flat([all_ranks[i] for i in idxs])
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks)
+            if rank in this_group_ranks:
+                task_info.backbone_share_group = this_share_group
+                printlog(f">> task_info.backbone_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.backbone_group_size = len(backboneshareid2idx)
+                task_info.backbone_task_size = len(backboneshareid2idx) * this_group_size
+                task_info.backbone_task_rank = np.sum(rank < np.array(this_group_ranks))
+    ## share_adapter_group spec
+    if share_adapter_group_ids is not None: # *[0,0,0]*(default) | [0,1,0]task ids
+        # group size must equal within a share_group
+        adaptershareid2idx = {}
+        for idx, this_id in enumerate(share_adapter_group_ids):
+            if this_id not in adaptershareid2idx:
+                adaptershareid2idx[this_id] = list()
+            adaptershareid2idx[this_id].append(idx) # {0: [0,1,2]}| {0: [0,2], 1: [1]}
+
+        ## create adapter share group
+        for idxs in adaptershareid2idx.values(): # idxs = [0, 1, 2]
+            this_group_ranks = flat([all_ranks[i] for i in idxs])
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks)
+            if rank in this_group_ranks:
+                task_info.adapter_share_group = this_share_group
+                printlog(f">> task_info.adapter_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.adapter_group_size = len(adaptershareid2idx)
+                task_info.adapter_task_size = len(adaptershareid2idx) * this_group_size
+                task_info.adapter_task_rank = np.sum(rank < np.array(this_group_ranks))
+
+    # pdb.set_trace()
+    ## share_neck_group spec
+    if share_neck_group_ids is not None:
+        # group size must equal within a share_group
+        neckshareid2idx = {}
+        for idx, this_id in enumerate(share_neck_group_ids):
+            if this_id not in neckshareid2idx:
+                neckshareid2idx[this_id] = list()
+            neckshareid2idx[this_id].append(idx)
+
+        ## create neck share group
+        for idxs in neckshareid2idx.values():
+            this_group_ranks = flat([all_ranks[i] for i in idxs])
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks)
+            if rank in this_group_ranks:
+                task_info.neck_share_group = this_share_group
+                printlog(f">> task_info.neck_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.neck_group_size = len(neckshareid2idx)
+                task_info.neck_task_size = len(neckshareid2idx) * this_group_size
+                task_info.neck_task_rank = np.sum(rank < np.array(this_group_ranks))
+
+    ## share_decoder_group spec
+    if share_decoder_group_ids is not None:
+        # group size must equal within a share_group
+        decodershareid2idx = {}
+        for idx, this_id in enumerate(share_decoder_group_ids):
+            if this_id not in decodershareid2idx:
+                decodershareid2idx[this_id] = list()
+            decodershareid2idx[this_id].append(idx)
+
+        ## create decoder share group
+        for idxs in decodershareid2idx.values():
+            this_group_ranks = flat([all_ranks[i] for i in idxs])
+            this_share_group = dist.new_group(ranks=this_group_ranks)
+            this_group_size = len(this_group_ranks)
+            if rank in this_group_ranks:
+                task_info.decoder_share_group = this_share_group
+                printlog(f">> task_info.decoder_share_group[{idxs}] ranks {this_group_ranks}")
+                task_info.decoder_group_size = len(decodershareid2idx)
+                task_info.decoder_task_size = len(decodershareid2idx) * this_group_size
+                task_info.decoder_task_rank = np.sum(rank < np.array(this_group_ranks))
+    return task_info
+
+class Config(object):
+
+    def __init__(self, config_file, noginfo=False, spec_ginfo_index=None):
+
+        with open(config_file) as f:
+            config = yaml.load(f, Loader=loader)
+        # print('config',config)
+        self.config_path = config_file
+
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        if noginfo:
+            ginfo = None
+        else: # cherrypick from tasks
+            tasks = config['tasks']
+            num_tasks = len(tasks)
+            if spec_ginfo_index is not None:
+                assert spec_ginfo_index < len(tasks), \
+                    'spec_ginfo_index={} is larger than num_tasks={}'.format(spec_ginfo_index, len(tasks))
+                tmp_config = copy.deepcopy(config)
+                config['tasks'] = dict()
+                config['tasks'][0] = tmp_config['tasks'][spec_ginfo_index]
+                config['tasks'][0]['gres_ratio'] = 1
+                tasks = config['tasks']
+                num_tasks = len(tasks)
+
+            # parse task_common and assign to each task
+            task_common = config.get('task_common', None)
+            if task_common is not None:
+                for i in range(num_tasks):
+                    for k,v in task_common.items():
+                        if not k in tasks[i]:
+                            printlog('setting {} to {} for task {}'.format(k, v, i))
+                            tasks[i][k] = v
+
+            group_spec = [tasks[i].get('gres_ratio',1) for i in range(num_tasks)]
+
+            ## share group spec
+            if config['common'].get('share_backbone_group', False):
+                share_backbone_group_ids = config['common']['share_backbone_group'][:num_tasks]
+            else:
+                share_backbone_group_ids = [0 for i in range(num_tasks)] # hardcoded prior
+            if config['common'].get('share_adapter_group', False):
+                if len(config['common']['share_adapter_group']) == 1:
+                    adapter_list = []
+                share_adapter_group_ids = config['common']['share_adapter_group'][:num_tasks]
+            else:
+                share_adapter_group_ids = [0 for i in range(num_tasks)] # hardcoded prior
+
+            if config['common'].get('share_neck_group', False):
+                share_neck_group_ids = config['common']['share_neck_group'][:num_tasks]
+            else:
+                share_neck_group_ids = [0 for i in range(num_tasks)] # hardcoded prior
+
+            if config['common'].get('share_decoder_group', False):
+                share_decoder_group_ids = config['common']['share_decoder_group'][:num_tasks]
+            else:
+                share_decoder_group_ids = [i for i in range(num_tasks)] # hardcoded prior
+            ginfo = specific_group_split(group_spec, share_backbone_group_ids, share_neck_group_ids,
+                                         share_decoder_group_ids, share_adapter_group_ids)
+            loss_weight_sum = float(np.sum(np.array([task['loss_weight'] for task in tasks.values()])))
+            ginfo.task_name = tasks[ginfo.task_id]['name']
+            ginfo.task_names = [tasks[i]['name'] for i in range(ginfo.task_num)]
+            ginfo.task_weight = float(tasks[ginfo.task_id]['loss_weight']) / loss_weight_sum
+            ginfo.task_type = tasks[ginfo.task_id].get('type', 'normal')
+            ginfo.task_types = [tasks[i].get('type', 'normal') for i in range(ginfo.task_num)]
+            ginfo.task_random_seed = tasks[ginfo.task_id].get('random_seed', 0)
+
+            for p in task_specific_param:
+                if p in config['tasks'][ginfo.task_id]:
+                    config['common'][p] = config['tasks'][ginfo.task_id][p]
+                    printlog('{} of task{} has been overided to {}'.format(p, ginfo.task_id, config['common'][p]))
+
+        logger = logging.getLogger('global_logger')
+
+        self.world_size = world_size
+        self.rank = rank
+        self.ginfo = ginfo
+        self.config = config
+        self.config_file = config_file
+
+class Config_Hulk(object):
+
+    def __init__(self, config_file, noginfo=False, spec_ginfo_index=None):
+
+        with open(config_file) as f:
+            config = yaml.load(f, Loader=loader)
+        # print('config',config)
+        self.config_path = config_file
+
+
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+            rank = dist.get_rank()
+        else:
+            world_size = 1
+            rank = 0
+
+        if noginfo:
+            ginfo = None
+        else: # cherrypick from tasks
+            tasks = config['tasks']
+            num_tasks = len(tasks)
+            if spec_ginfo_index is not None:
+                assert spec_ginfo_index < len(tasks), \
+                    'spec_ginfo_index={} is larger than num_tasks={}'.format(spec_ginfo_index, len(tasks))
+                tmp_config = copy.deepcopy(config)
+                config['tasks'] = dict()
+                config['tasks'][0] = tmp_config['tasks'][spec_ginfo_index]
+                config['tasks'][0]['gres_ratio'] = 1
+                tasks = config['tasks']
+                num_tasks = len(tasks)
+
+            # parse task_common and assign to each task
+            task_common = config.get('task_common', None)
+            if task_common is not None:
+                for i in range(num_tasks):
+                    for k,v in task_common.items():
+                        if not k in tasks[i]:
+                            printlog('setting {} to {} for task {}'.format(k, v, i))
+                            tasks[i][k] = v
+
+            group_spec = [tasks[i].get('gres_ratio',1) for i in range(num_tasks)]
+
+            ## share group spec
+            if config['common'].get('share_backbone_group', False):
+                share_backbone_group_ids = config['common']['share_backbone_group'][:num_tasks]
+            else:
+                share_backbone_group_ids = [0 for i in range(num_tasks)] # hardcoded prior
+
+
+            if config['common'].get('share_decoder_group', False):
+                share_decoder_group_ids = config['common']['share_decoder_group'][:num_tasks]
+            else:
+                share_decoder_group_ids = [i for i in range(num_tasks)] # hardcoded prior
+
+            # use modality groups to control the communication of neck, adapter, and output proj
+
+            if config['common'].get('share_rgb_group', False):
+                share_rgb_group_ids = config['common']['share_rgb_group'][:num_tasks]
+            else:
+                share_rgb_group_ids = [i for i in range(num_tasks)] # hardcoded prior
+
+            if config['common'].get('share_dense_labeling_group', False):
+                share_dense_labeling_group_ids = config['common']['share_dense_labeling_group'][:num_tasks]
+            else:
+                share_dense_labeling_group_ids = [i for i in range(num_tasks)]
+
+            if config['common'].get('share_sparse_labeling_group', False):
+                share_sparse_labeling_group_ids = config['common']['share_sparse_labeling_group'][:num_tasks]
+            else:
+                share_sparse_labeling_group_ids = [i for i in range(num_tasks)]
+
+            if config['common'].get('share_text_group', False):
+                share_text_group_ids = config['common']['share_text_group'][:num_tasks]
+            else:
+                share_text_group_ids = [i for i in range(num_tasks)]
+
+            if config['common'].get('share_video_group', False):
+                share_video_group_ids = config['common']['share_video_group'][:num_tasks]
+            else:
+                share_video_group_ids = [i for i in range(num_tasks)]
+
+            if config['common'].get('share_modality_group', False):
+                share_modality_group_ids = config['common']['share_modality_group'][:num_tasks]
+            else:
+                share_modality_group_ids = [i for i in range(num_tasks)]
+
+            # ginfo = specific_group_split_modality_groups(group_spec, share_backbone_group_ids,
+            #                                              share_decoder_group_ids, share_rgb_group_ids,
+            #                                              share_video_group_ids, share_dense_labeling_group_ids,
+            #                                              share_sparse_labeling_group_ids, share_text_group_ids,
+            #                                              share_modality_group_ids)
+            import easydict
+            ginfo = easydict.EasyDict()
+            ginfo.task_id = 5
+            ginfo.task_num = 5
+            ginfo.backbone_share_group = None
+            ginfo.task_rank = 0
+
+            loss_weight_sum = float(np.sum(np.array([task['loss_weight'] for task in tasks.values()])))
+            ginfo.task_name = tasks[ginfo.task_id]['name']
+            ginfo.task_names = [tasks[i]['name'] for i in range(ginfo.task_num)]
+            # ginfo.task_weight = float(tasks[ginfo.task_id]['loss_weight']) / loss_weight_sum
+            ginfo.task_weight = float(tasks[ginfo.task_id]['loss_weight'])
+            ginfo.task_type = tasks[ginfo.task_id].get('type', 'normal')
+            ginfo.task_types = [tasks[i].get('type', 'normal') for i in range(ginfo.task_num)]
+            ginfo.task_random_seed = tasks[ginfo.task_id].get('random_seed', 0)
+
+            for p in task_specific_param:
+                if p in config['tasks'][ginfo.task_id]:
+                    config['common'][p] = config['tasks'][ginfo.task_id][p]
+                    printlog('{} of task{} has been overided to {}'.format(p, ginfo.task_id, config['common'][p]))
+
+        logger = logging.getLogger('global_logger')
+
+        self.world_size = world_size
+        self.rank = rank
+        self.ginfo = ginfo
+        self.config = config
+        self.config_file = config_file
+
+    # def __repr__(self) -> str:
+    #     return str(self.config)
+
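A standalone sketch of the rank bookkeeping that specific_group_split and its modality-group variant perform before any process groups are created, assuming a hypothetical 16-GPU job: each task's gres_ratio is scaled by world_size / sum(gres_ratio), and tasks own consecutive rank ranges whose first rank becomes the task root.

import numpy as np

group_spec = [2, 1, 1]                         # hypothetical gres_ratio per task
world_size = 16                                # hypothetical number of GPUs
unit = world_size // int(np.sum(group_spec))   # ranks per ratio unit -> 4

last = 0
for task_id, gs in enumerate(g * unit for g in group_spec):
    ranks = list(range(last, last + gs))
    print(f"task {task_id}: root rank {last}, ranks {ranks}")  # task 0 -> 0..7, task 1 -> 8..11, task 2 -> 12..15
    last += gs
# specific_group_split would call dist.new_group(ranks=ranks) for each of these
# ranges and dist.new_group(ranks=[0, 8, 12]) for the root group.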
core/data/__init__.py
ADDED
File without changes
core/data/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (188 Bytes)
core/data/datasets/__init__.py
ADDED
@@ -0,0 +1,15 @@
+from .images.pedattr_dataset import MultiAttrDataset
+from .images.pos_dataset_dev import COCOPosDatasetDev, MPIIPosDatasetDev
+from .images.parsing_dataset import (Human3M6ParsingDataset, LIPParsingDataset, CIHPParsingDataset, ATRParsingDataset,
+                                     DeepFashionParsingDataset, VIPParsingDataset, ModaNetParsingDataset,
+                                     PaperDollParsingDataset)
+from .images.multi_posedataset import MultiPoseDatasetDev
+from .images.peddet_dataset_v2 import PedestrainDetectionDataset_v2, PedestrainDetectionDataset_v2demo
+from .images.image_caption_dataset import CocoCaption, CocoCaptiondemo
+from .sequences.skeleton_action_dataset import mmSkeletonDataset
+from .images.smpl_dataset_v2 import MeshTSVYamlDataset
+from core.utils import printlog
+
+def dataset_entry(config):
+    printlog('config[kwargs]',config['kwargs'])
+    return globals()[config['type']](**config['kwargs'])
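A sketch of how dataset_entry dispatches on the YAML task config: 'type' names one of the classes imported above and 'kwargs' is splatted into its constructor. The paths and the minimal ginfo below are hypothetical, and actually constructing CocoCaption additionally requires an initialized torch.distributed process group because it queries its rank.

from easydict import EasyDict as edict
from core.data.datasets import dataset_entry

config = {
    "type": "CocoCaption",
    "kwargs": {
        "ginfo": edict(task_name="caption"),            # hypothetical group info
        "coco_train": True,
        "coco_root": "/path/to/coco2014",               # hypothetical paths
        "anno_root": "/path/to/caption_annotations",
        "bert_dir": "/path/to/bert-base-uncased",
    },
}
dataset = dataset_entry(config)   # equivalent to CocoCaption(**config["kwargs"])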
core/data/datasets/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.41 kB)
core/data/datasets/images/__pycache__/image_caption_dataset.cpython-312.pyc
ADDED
Binary file (13.9 kB)
core/data/datasets/images/__pycache__/multi_posedataset.cpython-312.pyc
ADDED
Binary file (18.9 kB)
core/data/datasets/images/__pycache__/parsing_dataset.cpython-312.pyc
ADDED
Binary file (40.6 kB)
core/data/datasets/images/__pycache__/pedattr_dataset.cpython-312.pyc
ADDED
Binary file (34.7 kB)
core/data/datasets/images/__pycache__/peddet_dataset_v2.cpython-312.pyc
ADDED
Binary file (28.9 kB)
core/data/datasets/images/__pycache__/pos_dataset_dev.cpython-312.pyc
ADDED
Binary file (33.8 kB)
core/data/datasets/images/__pycache__/seg_dataset_dev.cpython-312.pyc
ADDED
Binary file (16.2 kB)
core/data/datasets/images/__pycache__/smpl_dataset_v2.cpython-312.pyc
ADDED
Binary file (18 kB)
core/data/datasets/images/image_caption_dataset.py
ADDED
@@ -0,0 +1,261 @@
+import os
+import re
+import json
+import random
+import torch
+import torchvision
+import numpy as np
+import pandas as pd
+import os.path as osp
+from PIL import Image
+from collections import defaultdict
+from transformers import BertTokenizer
+from torch.utils.data import Dataset
+import torch.distributed as dist
+from torchvision import transforms
+from torchvision.transforms import PILToTensor, ToTensor
+from core.data.transforms.caption_transforms import RandomAugment
+
+def pre_caption(caption, max_words=30):
+    caption = re.sub(
+        r"([.!\"()*#:;~])",
+        ' ',
+        caption.lower(),
+    )
+    caption = re.sub(
+        r"\s{2,}",
+        ' ',
+        caption,
+    )
+    caption = caption.rstrip('\n')
+    caption = caption.strip(' ')
+
+    #truncate caption
+    caption_words = caption.split(' ')
+    if len(caption_words) > max_words:
+        caption = ' '.join(caption_words[ :max_words])
+
+    return caption
+
+def data_transforms(split_type='train', img_size=384, min_scale=0.5):
+    if split_type == 'train':
+        data_transforms = transforms.Compose([
+            transforms.RandomResizedCrop(img_size, scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
+            transforms.RandomHorizontalFlip(),
+            RandomAugment(2, 5, isPIL=True,
+                          augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
+                                'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
+            PILToTensor()
+            # ToTensor()
+        ])
+    else:
+        data_transforms = transforms.Compose([
+            transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),
+            PILToTensor(),
+            # ToTensor()
+        ])
+    return data_transforms
+
+
+class CocoCaption(Dataset):
+    """
+    Implementation of the dataloader for coco_caption.
+    Mainly used in the model training and evaluation.
+    Params:
+        ginfo: group information for Multitask learning.
+        coco_root: root path of coco2014 dataset.
+        anno_root: annotation path of coco captions.
+        bert_dir: path of bert-base-uncased for loading tokenizer.
+        max_words: max length of input captions.
+        img_size: image size.
+        prompt: given prompt to add before captions.
+    """
+    def __init__(self, ginfo, max_words=30, img_size=384, beam_size=1, prompt='', split_type='train',
+                 cuhk_peds=False, cuhk_peds_root=None, cuhk_peds_anno_root=None, cuhk_peds_gt_root=None,
+                 joint_train=False, synth_peds_root=None, joint_train_anno_root=None, coco_train=False,
+                 coco_root=None, anno_root=None, bert_dir='', mals_root=None, luperson_root=None):
+        self.task_name = ginfo.task_name
+        self.rank = dist.get_rank()
+        self.prompt = prompt
+
+        # plus one for bos token
+        self.max_words = max_words + 1
+        self.img_size = img_size
+        self.split_type = split_type
+        self.beam_size = beam_size
+
+        self.transforms = data_transforms(split_type, img_size)
+        self.tokenizer = BertTokenizer.from_pretrained(bert_dir, do_lower=True)
+        self.cuhk_peds = cuhk_peds
+        self.joint_train = joint_train
+        self.coco_train = coco_train
+        if joint_train:
+            self.annotation = json.load(open(joint_train_anno_root, 'r'))
+            self.cuhk_peds_root = cuhk_peds_root
+            self.synth_peds_root = synth_peds_root
+            self.mals_root = mals_root
+            self.luperson_root = luperson_root
+            self.coco_gt_file = cuhk_peds_gt_root
+        elif cuhk_peds:
+            self.annotation = json.load(open(cuhk_peds_anno_root, 'r'))
+            self.coco_gt_file = cuhk_peds_gt_root
+            self.coco_root = cuhk_peds_root
+        elif coco_train:
+            self.coco_root = coco_root
+            self.coco_gt_file = osp.join(anno_root, 'coco_gt', 'coco_karpathy_' + split_type + '_gt.json')
+            self.annotation = json.load(open(osp.join(anno_root, 'coco_karpathy_' + split_type + '.json'), 'r'))
+
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
|
113 |
+
sample = self.annotation[index]
|
114 |
+
if self.joint_train and self.split_type == 'train':
|
115 |
+
if sample['split'] == 'cuhk_peds':
|
116 |
+
image_path = osp.join(self.cuhk_peds_root, sample['image'])
|
117 |
+
elif sample['split'] == 'mals':
|
118 |
+
image_path = osp.join(self.mals_root, sample['image'])
|
119 |
+
elif sample['split'] == 'luperson':
|
120 |
+
image_path = osp.join(self.luperson_root, sample['image'])
|
121 |
+
else:
|
122 |
+
image_path = osp.join(self.synth_peds_root, sample['image'])
|
123 |
+
else:
|
124 |
+
image_path = osp.join(self.coco_root, sample['image'])
|
125 |
+
image = Image.open(image_path).convert('RGB')
|
126 |
+
image = self.transforms(image)
|
127 |
+
if self.split_type != 'train':
|
128 |
+
caption_id = np.zeros(self.max_words - 1, dtype=np.int32)
|
129 |
+
token_type_id = np.zeros(self.max_words - 1, dtype=np.int32)
|
130 |
+
caption_pad_mask = np.zeros(self.max_words - 1, dtype=np.int32)
|
131 |
+
if self.cuhk_peds:
|
132 |
+
img_id = sample['image'].split('.')[0]
|
133 |
+
else:
|
134 |
+
img_id = sample['image'].split('/')[-1].strip('.jpg').split('_')[-1]
|
135 |
+
coco_gt_file = self.coco_gt_file
|
136 |
+
beam_size = self.beam_size
|
137 |
+
return {'image': image, 'input_id': caption_id, 'image_id': int(img_id) if not self.cuhk_peds else img_id,
|
138 |
+
'coco_gt_file': coco_gt_file, 'beam_size': beam_size,
|
139 |
+
'token_type_id': token_type_id, 'padding_mask': caption_pad_mask}
|
140 |
+
caption = self.prompt + pre_caption(sample['caption'], self.max_words)
|
141 |
+
caption_encode = self.tokenizer.encode_plus(caption, max_length=self.max_words, pad_to_max_length=True,
|
142 |
+
return_attention_mask=True, return_token_type_ids=True,
|
143 |
+
truncation=True)
|
144 |
+
caption_id, caption_pad_mask, token_type_id = caption_encode['input_ids'], caption_encode['attention_mask'], caption_encode['token_type_ids']
|
145 |
+
caption_id = np.array(caption_id)
|
146 |
+
token_type_id = np.array(token_type_id)
|
147 |
+
caption_pad_mask = np.array(caption_pad_mask)
|
148 |
+
# caption_pad_mask = (1 - np.array(caption_pad_mask)).astype(bool)
|
149 |
+
caption = [caption]
|
150 |
+
output = {'image': image, 'input_id': caption_id, 'token_type_id': token_type_id, 'padding_mask': caption_pad_mask, 'label': caption_id}
|
151 |
+
return output
|
152 |
+
|
153 |
+
def __repr__(self):
|
154 |
+
return self.__class__.__name__ + \
|
155 |
+
f'rank: {self.rank} task: {self.task_name} mode:{"training" if self.split_type == "train" else "inference"} ' \
|
156 |
+
f'dataset_len:{len(self.annotation)} augmentation: {self.transforms}'
|
157 |
+
|
158 |
+
|
159 |
+
class CocoCaptiondemo(Dataset):
|
160 |
+
"""
|
161 |
+
Implementation of the dataloader for coco_caption.
|
162 |
+
Mainly used in the model training and evaluation.
|
163 |
+
Params:
|
164 |
+
ginfo: group information for Multitask learning.
|
165 |
+
coco_root: root path of coco2014 dataset.
|
166 |
+
anno_root: annotation path of coco captions.
|
167 |
+
bert_dir: path of bert-base-uncased for loading tokenizer.
|
168 |
+
max_words: max length of input captions.
|
169 |
+
img_size: image size.
|
170 |
+
prompt: given prompt to add before captions.
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self, ginfo, max_words=30, img_size=384, beam_size=1, prompt='', split_type='train', demo_dir='/mnt/cache/tangshixiang/wyz_proj/demo_video_unihcpv2/folder0',
|
174 |
+
cuhk_peds=False, cuhk_peds_root=None, cuhk_peds_anno_root=None, cuhk_peds_gt_root=None,
|
175 |
+
joint_train=False, synth_peds_root=None, joint_train_anno_root=None, coco_train=False,
|
176 |
+
coco_root=None, anno_root=None, bert_dir='', mals_root=None, luperson_root=None):
|
177 |
+
self.task_name = ginfo.task_name
|
178 |
+
self.rank = dist.get_rank()
|
179 |
+
self.prompt = prompt
|
180 |
+
|
181 |
+
# plus one for bos token
|
182 |
+
self.max_words = max_words + 1
|
183 |
+
self.img_size = img_size
|
184 |
+
self.split_type = split_type
|
185 |
+
self.beam_size = beam_size
|
186 |
+
|
187 |
+
self.transforms = data_transforms(split_type, img_size)
|
188 |
+
self.tokenizer = BertTokenizer.from_pretrained(bert_dir, do_lower=True)
|
189 |
+
self.cuhk_peds = cuhk_peds
|
190 |
+
self.joint_train = joint_train
|
191 |
+
self.coco_train = coco_train
|
192 |
+
if joint_train:
|
193 |
+
self.annotation = json.load(open(joint_train_anno_root, 'r'))
|
194 |
+
self.cuhk_peds_root = cuhk_peds_root
|
195 |
+
self.synth_peds_root = synth_peds_root
|
196 |
+
self.mals_root = mals_root
|
197 |
+
self.luperson_root = luperson_root
|
198 |
+
self.coco_gt_file = cuhk_peds_gt_root
|
199 |
+
elif cuhk_peds:
|
200 |
+
self.annotation = json.load(open(cuhk_peds_anno_root, 'r'))
|
201 |
+
self.coco_gt_file = cuhk_peds_gt_root
|
202 |
+
self.coco_root = cuhk_peds_root
|
203 |
+
elif coco_train:
|
204 |
+
self.coco_root = coco_root
|
205 |
+
self.coco_gt_file = osp.join(anno_root, 'coco_gt', 'coco_karpathy_' + split_type + '_gt.json')
|
206 |
+
self.annotation = json.load(open(osp.join(anno_root, 'coco_karpathy_' + split_type + '.json'), 'r'))
|
207 |
+
self.demo_dir = demo_dir
|
208 |
+
|
209 |
+
|
210 |
+
def __len__(self):
|
211 |
+
return len(os.listdir(self.demo_dir))
|
212 |
+
|
213 |
+
def __getitem__(self, index):
|
214 |
+
# import pdb; pdb.set_trace()
|
215 |
+
sample = self.annotation[index]
|
216 |
+
if self.joint_train and self.split_type == 'train':
|
217 |
+
if sample['split'] == 'cuhk_peds':
|
218 |
+
image_path = osp.join(self.cuhk_peds_root, sample['image'])
|
219 |
+
elif sample['split'] == 'mals':
|
220 |
+
image_path = osp.join(self.mals_root, sample['image'])
|
221 |
+
elif sample['split'] == 'luperson':
|
222 |
+
image_path = osp.join(self.luperson_root, sample['image'])
|
223 |
+
else:
|
224 |
+
image_path = osp.join(self.synth_peds_root, sample['image'])
|
225 |
+
else:
|
226 |
+
image_path = osp.join(self.coco_root, sample['image'])
|
227 |
+
filename = os.path.join(self.demo_dir, f'frame_{index}.jpg')
|
228 |
+
image = Image.open(filename).convert('RGB')
|
229 |
+
image = self.transforms(image)
|
230 |
+
if self.split_type != 'train':
|
231 |
+
caption_id = np.zeros(self.max_words - 1, dtype=np.int32)
|
232 |
+
token_type_id = np.zeros(self.max_words - 1, dtype=np.int32)
|
233 |
+
caption_pad_mask = np.zeros(self.max_words - 1, dtype=np.int32)
|
234 |
+
if self.cuhk_peds:
|
235 |
+
img_id = sample['image'].split('.')[0]
|
236 |
+
else:
|
237 |
+
img_id = sample['image'].split('/')[-1].strip('.jpg').split('_')[-1]
|
238 |
+
coco_gt_file = self.coco_gt_file
|
239 |
+
beam_size = self.beam_size
|
240 |
+
return {'image': image, 'input_id': caption_id, 'image_id': filename,
|
241 |
+
'coco_gt_file': coco_gt_file, 'beam_size': beam_size,
|
242 |
+
'token_type_id': token_type_id, 'padding_mask': caption_pad_mask}
|
243 |
+
caption = self.prompt + pre_caption(sample['caption'], self.max_words)
|
244 |
+
caption_encode = self.tokenizer.encode_plus(caption, max_length=self.max_words, pad_to_max_length=True,
|
245 |
+
return_attention_mask=True, return_token_type_ids=True,
|
246 |
+
truncation=True)
|
247 |
+
caption_id, caption_pad_mask, token_type_id = caption_encode['input_ids'], caption_encode['attention_mask'], \
|
248 |
+
caption_encode['token_type_ids']
|
249 |
+
caption_id = np.array(caption_id)
|
250 |
+
token_type_id = np.array(token_type_id)
|
251 |
+
caption_pad_mask = np.array(caption_pad_mask)
|
252 |
+
# caption_pad_mask = (1 - np.array(caption_pad_mask)).astype(bool)
|
253 |
+
caption = [caption]
|
254 |
+
output = {'image': image, 'input_id': filename, 'token_type_id': token_type_id,
|
255 |
+
'padding_mask': caption_pad_mask, 'label': caption_id}
|
256 |
+
return output
|
257 |
+
|
258 |
+
def __repr__(self):
|
259 |
+
return self.__class__.__name__ + \
|
260 |
+
f'rank: {self.rank} task: {self.task_name} mode:{"training" if self.split_type == "train" else "inference"} ' \
|
261 |
+
f'dataset_len:{len(self.annotation)} augmentation: {self.transforms}'
|
core/data/datasets/images/multi_posedataset.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from abc import ABCMeta, abstractmethod
|
3 |
+
import numpy as np
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import os
|
8 |
+
import cv2
|
9 |
+
import random
|
10 |
+
import time
|
11 |
+
import os.path as osp
|
12 |
+
import os
|
13 |
+
import warnings
|
14 |
+
from collections import OrderedDict, defaultdict
|
15 |
+
from core.data.transforms.pose_transforms import *
|
16 |
+
import json #_tricks as json
|
17 |
+
import numpy as np
|
18 |
+
from xtcocotools.coco import COCO
|
19 |
+
from xtcocotools.cocoeval import COCOeval
|
20 |
+
import torch.distributed as dist
|
21 |
+
|
22 |
+
|
23 |
+
from core.utils import sync_print
|
24 |
+
|
25 |
+
|
26 |
+
class PetrelCOCO(COCO):
|
27 |
+
def __init__(self, annotation_file=None, test_index=None, ann_data=None):
|
28 |
+
"""
|
29 |
+
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
|
30 |
+
:param annotation_file (str): location of annotation file
|
31 |
+
:param image_folder (str): location to the folder that hosts images.
|
32 |
+
:return:
|
33 |
+
"""
|
34 |
+
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
|
35 |
+
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
|
36 |
+
self.anno_file = [annotation_file]
|
37 |
+
self.test_index = test_index
|
38 |
+
if annotation_file is not None:
|
39 |
+
print('loading annotations into memory...')
|
40 |
+
tic = time.time()
|
41 |
+
# https://github.com/cocodataset/cocoapi/pull/453/
|
42 |
+
if ann_data == None:
|
43 |
+
with open(annotation_file, 'r') as f:
|
44 |
+
dataset = json.load(f)
|
45 |
+
else:
|
46 |
+
dataset = ann_data
|
47 |
+
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
|
48 |
+
print('Done (t={:0.2f}s)'.format(time.time()- tic))
|
49 |
+
self.dataset = dataset
|
50 |
+
self.createIndex()
|
51 |
+
if 'annotations' in self.dataset:
|
52 |
+
for i in range(len(self.dataset['annotations'])):
|
53 |
+
if self.test_index is not None:
|
54 |
+
keypoints = np.array(self.dataset['annotations'][i]['keypoints']).reshape([-1, 3])
|
55 |
+
keypoints = keypoints[self.test_index, :]
|
56 |
+
self.dataset['annotations'][i]['keypoints'] = keypoints.reshape([-1]).tolist()
|
57 |
+
if 'iscrowd' not in self.dataset['annotations'][i]:
|
58 |
+
self.dataset['annotations'][i]['iscrowd'] = False
|
59 |
+
|
60 |
+
|
61 |
+
class MultiPoseDatasetDev(Dataset):
|
62 |
+
def __init__(self,
|
63 |
+
ginfo,
|
64 |
+
ann_file,
|
65 |
+
img_prefix,
|
66 |
+
data_cfg,
|
67 |
+
test_mode=False,
|
68 |
+
use_udp=False,
|
69 |
+
dataset_name='coco',
|
70 |
+
use_ceph=False,
|
71 |
+
simp_aug=False,
|
72 |
+
**kwargs):
|
73 |
+
|
74 |
+
assert dataset_name in ['coco', 'aic', 'posetrack', 'halpe', 'JRDB2022', 'h36m', 'mhp', 'penn_action', '3DPW', '3DHP', 'AIST'], "invalid dataset name input"
|
75 |
+
self.dataset_name = dataset_name
|
76 |
+
self.image_info = {}
|
77 |
+
self.ann_info = {}
|
78 |
+
self.initialized = False
|
79 |
+
|
80 |
+
self.use_ceph = True
|
81 |
+
self.annotations_path = ann_file
|
82 |
+
self.img_prefix = img_prefix
|
83 |
+
self.test_mode = test_mode
|
84 |
+
print('data_cfg0',data_cfg)
|
85 |
+
# data_cfg=demjson.decode(data_cfg)
|
86 |
+
# print('data_cfg',data_cfg)
|
87 |
+
self.ann_info['image_size'] = np.array(data_cfg['image_size'])
|
88 |
+
self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
|
89 |
+
self.ann_info['num_joints'] = data_cfg['num_joints']
|
90 |
+
|
91 |
+
self.ann_info['inference_channel'] = data_cfg['inference_channel']
|
92 |
+
self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
|
93 |
+
self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
|
94 |
+
|
95 |
+
self.db = []
|
96 |
+
self.task_name = ginfo.task_name
|
97 |
+
|
98 |
+
if test_mode:
|
99 |
+
pipeline = [
|
100 |
+
LoadImageFromFile(use_ceph=use_ceph),
|
101 |
+
TopDownAffine(use_udp=use_udp),
|
102 |
+
ToUNTensor(),
|
103 |
+
Collect(keys=['image'],
|
104 |
+
meta_keys=['image_file', 'center', 'scale', 'rotation', 'bbox_score', 'flip_pairs'])
|
105 |
+
]
|
106 |
+
else:
|
107 |
+
if self.dataset_name in ['coco', 'aic'] or simp_aug:
|
108 |
+
pipeline = [
|
109 |
+
LoadImageFromFile(use_ceph=use_ceph),
|
110 |
+
TopDownRandomFlip(flip_prob=0.5),
|
111 |
+
TopDownHalfBodyTransform(num_joints_half_body=8, prob_half_body=0.3),
|
112 |
+
TopDownGetRandomScaleRotation(rot_factor=40, scale_factor=0.5),
|
113 |
+
TopDownAffine(use_udp=use_udp),
|
114 |
+
ToUNTensor(),
|
115 |
+
TopDownGenerateTarget(sigma=2, encoding='UDP' if use_udp else 'MSRA'),
|
116 |
+
Collect(keys=['image', 'label', 'target_weight'],
|
117 |
+
meta_keys=['image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale','rotation',
|
118 |
+
'bbox_score', 'flip_pairs'])
|
119 |
+
]
|
120 |
+
elif self.dataset_name in ['posetrack', 'halpe', 'penn_action', '3DPW', 'mhp']:
|
121 |
+
pipeline = [
|
122 |
+
LoadImageFromFile(),
|
123 |
+
TopDownGetBboxCenterScale(padding=1.25),
|
124 |
+
TopDownRandomShiftBboxCenter(shift_factor=0.16, prob=0.3),
|
125 |
+
TopDownRandomFlip(flip_prob=0.5),
|
126 |
+
TopDownHalfBodyTransform(num_joints_half_body=8, prob_half_body=0.3),
|
127 |
+
TopDownGetRandomScaleRotation(rot_factor=40, scale_factor=0.5),
|
128 |
+
TopDownAffine(use_udp=use_udp),
|
129 |
+
ToUNTensor(),
|
130 |
+
TopDownGenerateTarget(sigma=2, encoding='UDP' if use_udp else 'MSRA'),
|
131 |
+
Collect(keys=['image', 'label', 'target_weight'],
|
132 |
+
meta_keys=['image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale','rotation',
|
133 |
+
'bbox_score', 'flip_pairs'])
|
134 |
+
]
|
135 |
+
else:
|
136 |
+
pipeline = [
|
137 |
+
LoadImageFromFile(),
|
138 |
+
# TopDownGetBboxCenterScale(padding=1.25),
|
139 |
+
TopDownRandomShiftBboxCenter(shift_factor=0.16, prob=0.3),
|
140 |
+
TopDownRandomFlip(flip_prob=0.5),
|
141 |
+
TopDownHalfBodyTransform(num_joints_half_body=8, prob_half_body=0.3),
|
142 |
+
TopDownGetRandomScaleRotation(rot_factor=40, scale_factor=0.5),
|
143 |
+
TopDownAffine(use_udp=use_udp),
|
144 |
+
ToUNTensor(),
|
145 |
+
TopDownGenerateTarget(sigma=2, encoding='UDP' if use_udp else 'MSRA'),
|
146 |
+
Collect(keys=['image', 'label', 'target_weight'],
|
147 |
+
meta_keys=['image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation',
|
148 |
+
'bbox_score', 'flip_pairs'])
|
149 |
+
]
|
150 |
+
|
151 |
+
|
152 |
+
self.pipeline = ComposeX(pipeline)
|
153 |
+
# dict_keys(['image_file', 'center', 'scale', 'bbox', 'rotation', 'joints_3d', 'joints_3d_visible', 'dataset',
|
154 |
+
# 'bbox_score', 'bbox_id', 'ann_info', 'image', 'flipped',
|
155 |
+
# 'label', ****'target' as in mmlab****
|
156 |
+
# 'target_weight'])
|
157 |
+
|
158 |
+
|
159 |
+
self.use_gt_bbox = data_cfg['use_gt_bbox']
|
160 |
+
self.bbox_file = data_cfg['bbox_file'] if data_cfg['bbox_file'].startswith('/mnt') else (Path(__file__).parent / 'resources' / data_cfg['bbox_file']).resolve()
|
161 |
+
self.det_bbox_thr = data_cfg.get('det_bbox_thr', 0.0)
|
162 |
+
if 'image_thr' in data_cfg:
|
163 |
+
warnings.warn(
|
164 |
+
'image_thr is deprecated, '
|
165 |
+
'please use det_bbox_thr instead', DeprecationWarning)
|
166 |
+
self.det_bbox_thr = data_cfg['image_thr']
|
167 |
+
|
168 |
+
|
169 |
+
self.ann_info['flip_pairs'] = data_cfg['flip_pairs']
|
170 |
+
|
171 |
+
self.ann_info['upper_body_ids'] = data_cfg['upper_body_ids']
|
172 |
+
self.ann_info['lower_body_ids'] = data_cfg['lower_body_ids']
|
173 |
+
|
174 |
+
self.ann_info['use_different_joint_weights'] = False
|
175 |
+
self.ann_info['joint_weights'] = np.array(
|
176 |
+
data_cfg['joint_weights'],
|
177 |
+
dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
|
178 |
+
|
179 |
+
# 'https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/'
|
180 |
+
# 'pycocotools/cocoeval.py#L523'
|
181 |
+
|
182 |
+
self.coco = PetrelCOCO(ann_file)
|
183 |
+
|
184 |
+
cats = [
|
185 |
+
cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
|
186 |
+
]
|
187 |
+
self.classes = ['__background__'] + cats
|
188 |
+
self.num_classes = len(self.classes)
|
189 |
+
self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
|
190 |
+
self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds()))
|
191 |
+
self._coco_ind_to_class_ind = dict(
|
192 |
+
(self._class_to_coco_ind[cls], self._class_to_ind[cls])
|
193 |
+
for cls in self.classes[1:])
|
194 |
+
self.img_ids = self.coco.getImgIds()
|
195 |
+
self.num_images = len(self.img_ids)
|
196 |
+
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
|
197 |
+
|
198 |
+
|
199 |
+
self.db = self._get_db()
|
200 |
+
print(f'=> dataset: {self.dataset_name} num_images: {self.num_images}')
|
201 |
+
print(f'=> dataset: {self.dataset_name} load {len(self.db)} samples')
|
202 |
+
|
203 |
+
@staticmethod
|
204 |
+
def _get_mapping_id_name(imgs):
|
205 |
+
"""
|
206 |
+
Args:
|
207 |
+
imgs (dict): dict of image info.
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
tuple: Image name & id mapping dicts.
|
211 |
+
|
212 |
+
- id2name (dict): Mapping image id to name.
|
213 |
+
- name2id (dict): Mapping image name to id.
|
214 |
+
"""
|
215 |
+
id2name = {}
|
216 |
+
name2id = {}
|
217 |
+
for image_id, image in imgs.items():
|
218 |
+
file_name = image['file_name']
|
219 |
+
id2name[image_id] = file_name
|
220 |
+
name2id[file_name] = image_id
|
221 |
+
|
222 |
+
return id2name, name2id
|
223 |
+
|
224 |
+
def _get_db(self):
|
225 |
+
"""Load dataset."""
|
226 |
+
if (not self.test_mode) or self.use_gt_bbox:
|
227 |
+
# use ground truth bbox
|
228 |
+
gt_db = self._load_coco_keypoint_annotations()
|
229 |
+
else:
|
230 |
+
# use bbox from detection
|
231 |
+
gt_db = self._load_coco_person_detection_results()
|
232 |
+
return gt_db
|
233 |
+
|
234 |
+
def _load_coco_keypoint_annotations(self):
|
235 |
+
"""Ground truth bbox and keypoints."""
|
236 |
+
gt_db = []
|
237 |
+
for img_id in self.img_ids:
|
238 |
+
gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
|
239 |
+
return gt_db
|
240 |
+
|
241 |
+
def _load_coco_keypoint_annotation_kernel(self, img_id):
|
242 |
+
"""load annotation from COCOAPI.
|
243 |
+
|
244 |
+
Note:
|
245 |
+
bbox:[x1, y1, w, h]
|
246 |
+
Args:
|
247 |
+
img_id: coco image id
|
248 |
+
Returns:
|
249 |
+
dict: db entry
|
250 |
+
"""
|
251 |
+
img_ann = self.coco.loadImgs(img_id)[0]
|
252 |
+
width = img_ann['width']
|
253 |
+
height = img_ann['height']
|
254 |
+
num_joints = self.ann_info['num_joints']
|
255 |
+
|
256 |
+
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
|
257 |
+
objs = self.coco.loadAnns(ann_ids)
|
258 |
+
|
259 |
+
# sanitize bboxes
|
260 |
+
valid_objs = []
|
261 |
+
for obj in objs:
|
262 |
+
if 'bbox' not in obj:
|
263 |
+
continue
|
264 |
+
x, y, w, h = obj['bbox']
|
265 |
+
x1 = max(0, x)
|
266 |
+
y1 = max(0, y)
|
267 |
+
x2 = min(width - 1, x1 + max(0, w - 1))
|
268 |
+
y2 = min(height - 1, y1 + max(0, h - 1))
|
269 |
+
if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
|
270 |
+
obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
|
271 |
+
valid_objs.append(obj)
|
272 |
+
objs = valid_objs
|
273 |
+
|
274 |
+
bbox_id = 0
|
275 |
+
rec = []
|
276 |
+
for obj in objs:
|
277 |
+
if 'keypoints' not in obj:
|
278 |
+
continue
|
279 |
+
if max(obj['keypoints']) == 0:
|
280 |
+
continue
|
281 |
+
if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
|
282 |
+
continue
|
283 |
+
joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
|
284 |
+
joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32)
|
285 |
+
|
286 |
+
keypoints = np.array(obj['keypoints']).reshape(-1, 3)
|
287 |
+
|
288 |
+
if self.dataset_name == 'posetrack':
|
289 |
+
keypoints = np.delete(keypoints, [3, 4], axis=0) # keypoint idx == 3 and 4 not annot
|
290 |
+
elif self.dataset_name == 'halpe':
|
291 |
+
keypoints = keypoints[:17,:] # halpe has only 17 valid kp
|
292 |
+
|
293 |
+
joints_3d[:, :2] = keypoints[:, :2]
|
294 |
+
joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
|
295 |
+
|
296 |
+
center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
|
297 |
+
|
298 |
+
image_file = os.path.join(self.img_prefix, self.id2name[img_id])
|
299 |
+
rec.append({
|
300 |
+
'image_file': image_file,
|
301 |
+
'center': center,
|
302 |
+
'scale': scale,
|
303 |
+
'bbox': obj['clean_bbox'][:4],
|
304 |
+
'rotation': 0,
|
305 |
+
'joints_3d': joints_3d,
|
306 |
+
'joints_3d_visible': joints_3d_visible,
|
307 |
+
'dataset': self.dataset_name,
|
308 |
+
'bbox_score': 1,
|
309 |
+
'bbox_id': bbox_id
|
310 |
+
})
|
311 |
+
bbox_id = bbox_id + 1
|
312 |
+
|
313 |
+
return rec
|
314 |
+
|
315 |
+
def _xywh2cs(self, x, y, w, h):
|
316 |
+
"""This encodes bbox(x,y,w,w) into (center, scale)
|
317 |
+
|
318 |
+
Args:
|
319 |
+
x, y, w, h
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
tuple: A tuple containing center and scale.
|
323 |
+
|
324 |
+
- center (np.ndarray[float32](2,)): center of the bbox (x, y).
|
325 |
+
- scale (np.ndarray[float32](2,)): scale of the bbox w & h.
|
326 |
+
"""
|
327 |
+
aspect_ratio = self.ann_info['image_size'][0] / self.ann_info[
|
328 |
+
'image_size'][1]
|
329 |
+
center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
|
330 |
+
|
331 |
+
if (not self.test_mode) and np.random.rand() < 0.3:
|
332 |
+
center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
|
333 |
+
|
334 |
+
if w > aspect_ratio * h:
|
335 |
+
h = w * 1.0 / aspect_ratio
|
336 |
+
elif w < aspect_ratio * h:
|
337 |
+
w = h * aspect_ratio
|
338 |
+
|
339 |
+
# pixel std is 200.0
|
340 |
+
scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
|
341 |
+
# padding to include proper amount of context
|
342 |
+
scale = scale * 1.25
|
343 |
+
|
344 |
+
return center, scale
|
345 |
+
|
346 |
+
def _load_coco_person_detection_results(self):
|
347 |
+
"""Load coco person detection results."""
|
348 |
+
num_joints = self.ann_info['num_joints']
|
349 |
+
all_boxes = None
|
350 |
+
with open(self.bbox_file, 'r') as f:
|
351 |
+
all_boxes = json.load(f)
|
352 |
+
|
353 |
+
if not all_boxes:
|
354 |
+
raise ValueError('=> Load %s fail!' % self.bbox_file)
|
355 |
+
|
356 |
+
print(f'=> Total boxes: {len(all_boxes)}')
|
357 |
+
|
358 |
+
kpt_db = []
|
359 |
+
bbox_id = 0
|
360 |
+
for det_res in all_boxes:
|
361 |
+
if det_res['category_id'] != 1:
|
362 |
+
continue
|
363 |
+
|
364 |
+
image_file = os.path.join(self.img_prefix,
|
365 |
+
self.id2name[det_res['image_id']])
|
366 |
+
box = det_res['bbox']
|
367 |
+
score = det_res['score']
|
368 |
+
|
369 |
+
if score < self.det_bbox_thr:
|
370 |
+
continue
|
371 |
+
|
372 |
+
center, scale = self._xywh2cs(*box[:4])
|
373 |
+
joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
|
374 |
+
joints_3d_visible = np.ones((num_joints, 3), dtype=np.float32)
|
375 |
+
kpt_db.append({
|
376 |
+
'image_file': image_file,
|
377 |
+
'center': center,
|
378 |
+
'scale': scale,
|
379 |
+
'rotation': 0,
|
380 |
+
'bbox': box[:4],
|
381 |
+
'bbox_score': score,
|
382 |
+
'dataset': self.dataset_name,
|
383 |
+
'joints_3d': joints_3d,
|
384 |
+
'joints_3d_visible': joints_3d_visible,
|
385 |
+
'bbox_id': bbox_id
|
386 |
+
})
|
387 |
+
bbox_id = bbox_id + 1
|
388 |
+
print(f'=> Total boxes after filter '
|
389 |
+
f'low score@{self.det_bbox_thr}: {bbox_id}')
|
390 |
+
return kpt_db
|
391 |
+
|
392 |
+
def __len__(self):
|
393 |
+
"""Get the size of the dataset."""
|
394 |
+
return len(self.db)
|
395 |
+
|
396 |
+
def __getitem__(self, idx):
|
397 |
+
"""Get the sample given index."""
|
398 |
+
results = copy.deepcopy(self.db[idx])
|
399 |
+
results['ann_info'] = self.ann_info
|
400 |
+
out = self.pipeline(results)
|
401 |
+
|
402 |
+
C = self.ann_info['num_joints']
|
403 |
+
|
404 |
+
if 'label' in out:
|
405 |
+
out['dense_labeling'] = np.resize(out['label'],
|
406 |
+
(C, self.ann_info['image_size'][0], self.ann_info['image_size'][1]))
|
407 |
+
else:
|
408 |
+
out['dense_labeling'] = np.zeros((C, self.ann_info['image_size'][0], self.ann_info['image_size'][1]))
|
409 |
+
|
410 |
+
# del out['ann_info']
|
411 |
+
return out # dict_keys(['image_file', 'center', 'scale', 'bbox', 'rotation', 'joints_3d', 'joints_3d_visible',
|
412 |
+
# 'dataset', 'bbox_score', 'bbox_id', 'ann_info', 'image', 'flipped', 'label',
|
413 |
+
# 'target_weight'])
|
core/data/datasets/images/parsing_dataset.py
ADDED
@@ -0,0 +1,1084 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
import io
|
6 |
+
import numpy as np
|
7 |
+
import itertools
|
8 |
+
from typing import Any, Dict, List, Tuple, Union
|
9 |
+
from torch.utils import data
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from PIL import Image
|
12 |
+
from core.utils import cv2_loader, pil_loader
|
13 |
+
from core.data.datasets.images.seg_dataset_dev import Instances, BitMasks
|
14 |
+
import random
|
15 |
+
import torch.distributed as dist
|
16 |
+
|
17 |
+
import core.data.transforms.parsing_transforms as T
|
18 |
+
from core.data.transforms.pose_transforms import DataContainer
|
19 |
+
|
20 |
+
try:
|
21 |
+
from petrel_client.client import Client as Client
|
22 |
+
|
23 |
+
s3client = Client(boto=True,
|
24 |
+
enable_multi_cluster=True,
|
25 |
+
enable_mc=True)
|
26 |
+
except:
|
27 |
+
print("ceph can not be used")
|
28 |
+
|
29 |
+
palette_dict = {
|
30 |
+
'human3m6_parsing': np.array(
|
31 |
+
[[0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51], [255, 85, 0], [0, 0, 85], [0, 119, 221],
|
32 |
+
[85, 85, 0], [0, 85, 85],
|
33 |
+
[85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 255], [0, 255, 0],
|
34 |
+
[51, 170, 221], [0, 255, 255], [255, 170, 85], [85, 255, 170], [170, 85, 52],
|
35 |
+
[170, 255, 85], [255, 255, 0], [255, 170, 0], [255, 0, 170], [170, 0, 255]]),
|
36 |
+
'LIP_parsing': np.array([[0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51],
|
37 |
+
[255, 85, 0], [0, 0, 85], [0, 119, 221], [85, 85,
|
38 |
+
0], [0, 85, 85],
|
39 |
+
[85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 255],
|
40 |
+
[51, 170, 221], [0, 255, 255], [85, 255, 170], [170, 255, 85],
|
41 |
+
[255, 255, 0], [255, 170, 0]]),
|
42 |
+
'CIHP_parsing': np.array([[0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51],
|
43 |
+
[255, 85, 0], [0, 0, 85], [0, 119, 221], [85, 85,
|
44 |
+
0], [0, 85, 85],
|
45 |
+
[85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 255],
|
46 |
+
[51, 170, 221], [0, 255, 255], [85, 255, 170], [170, 255, 85],
|
47 |
+
[255, 255, 0], [255, 170, 0]]),
|
48 |
+
'ATR_parsing': np.array([[0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51],
|
49 |
+
[255, 85, 0], [0, 0, 85], [0, 119, 221], [85, 85,
|
50 |
+
0], [0, 85, 85],
|
51 |
+
[85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 255],
|
52 |
+
[51, 170, 221], [0, 255, 255], [85, 255, 170], [170, 255, 85]]),
|
53 |
+
|
54 |
+
}
|
55 |
+
|
56 |
+
def get_unk_mask_indices(image,num_labels,known_labels,epoch=1,testing=False,):
|
57 |
+
if testing:
|
58 |
+
# for consistency across epochs and experiments, seed using hashed image array
|
59 |
+
random.seed(hashlib.sha1(np.array(image)).hexdigest())
|
60 |
+
unk_mask_indices = random.sample(range(num_labels), (num_labels-int(known_labels)))
|
61 |
+
else:
|
62 |
+
# sample random number of known labels during training
|
63 |
+
if known_labels>0:
|
64 |
+
random.seed()
|
65 |
+
num_known = random.randint(0,int(num_labels*0.75))
|
66 |
+
else:
|
67 |
+
num_known = 0
|
68 |
+
|
69 |
+
unk_mask_indices = random.sample(range(num_labels), (num_labels-num_known))
|
70 |
+
|
71 |
+
return unk_mask_indices
|
72 |
+
|
73 |
+
class Human3M6ParsingDataset(data.Dataset):
|
74 |
+
task_name = 'human3m6_parsing'
|
75 |
+
left_right_pairs = np.array([[1, 6],
|
76 |
+
[2, 7],
|
77 |
+
[3, 8],
|
78 |
+
[17, 25],
|
79 |
+
[18, 26],
|
80 |
+
[19, 27],
|
81 |
+
[33, 38],
|
82 |
+
[34, 39],
|
83 |
+
[49, 56],
|
84 |
+
[50, 58]])
|
85 |
+
|
86 |
+
label_mapper = np.arange(60)
|
87 |
+
|
88 |
+
evaluate_size = (1000, 1000)
|
89 |
+
|
90 |
+
def __init__(self,
|
91 |
+
ginfo,
|
92 |
+
data_path,
|
93 |
+
dataset='train',
|
94 |
+
data_use_ratio=1,
|
95 |
+
is_train=True,
|
96 |
+
cfg=None,
|
97 |
+
**kwargs):
|
98 |
+
"""human3.6m dataset for human parsing
|
99 |
+
Args:
|
100 |
+
root_dir ([str]): where dataset
|
101 |
+
dataset: train / val
|
102 |
+
cfg: yaml format config
|
103 |
+
|
104 |
+
# 0 : background
|
105 |
+
# 1 : right hip
|
106 |
+
# 2 : right knee
|
107 |
+
# 3 : right foot
|
108 |
+
# 6 : left hip
|
109 |
+
# 7 : left knee
|
110 |
+
# 8 : left foot
|
111 |
+
# 17 : left shoulder
|
112 |
+
# 18 : left elbow
|
113 |
+
# 19 : left hand
|
114 |
+
# 25 : right shoulder
|
115 |
+
# 26 : right elbow
|
116 |
+
# 27 : right hand
|
117 |
+
# 32 : crotch
|
118 |
+
# 33 : right thigh
|
119 |
+
# 34 : right calf
|
120 |
+
# 38 : left thigh
|
121 |
+
# 39 : left calf
|
122 |
+
# 43 : lower spine
|
123 |
+
# 44 : upper spine
|
124 |
+
# 46 : head
|
125 |
+
# 49 : left arm
|
126 |
+
# 50 : left forearm
|
127 |
+
# 56 : right arm
|
128 |
+
# 58 : right forearm
|
129 |
+
|
130 |
+
"""
|
131 |
+
# self.task_name = 'human3m6_parsing'
|
132 |
+
self.cfg = cfg
|
133 |
+
self.dataset = dataset
|
134 |
+
self.is_train = is_train
|
135 |
+
self.data_use_ratio = data_use_ratio
|
136 |
+
self.pseudo_labels = self.cfg.get('Pseudo_labels', False)
|
137 |
+
self.stride_level = self.cfg.get('stride_level', 1)
|
138 |
+
# self.palette = palette_dict[self.task_name]
|
139 |
+
self.pseudo_labels_palette = palette_dict[self.cfg.get('Pseudo_labels_palette','human3m6_parsing')]
|
140 |
+
self.ignore2endclass = self.cfg.get('ignore2endclass', False)
|
141 |
+
|
142 |
+
self.img_list, self.label_list, self.name_list = self._list_dirs(data_path)
|
143 |
+
|
144 |
+
index = np.arange(0, len(self.img_list))
|
145 |
+
random.shuffle(index)
|
146 |
+
self.img_list = np.array(self.img_list)
|
147 |
+
self.label_list = np.array(self.label_list)
|
148 |
+
self.name_list = np.array(self.name_list)
|
149 |
+
|
150 |
+
self.img_list = self.img_list[index].tolist()
|
151 |
+
self.label_list = self.label_list[index].tolist()
|
152 |
+
self.name_list = self.name_list[index].tolist()
|
153 |
+
|
154 |
+
self.images = self.img_list
|
155 |
+
self.labels = self.label_list
|
156 |
+
self.ignore_label = cfg.ignore_value
|
157 |
+
self.num = len(self.images)
|
158 |
+
self.num_classes = len(self.cfg.label_list) # - 1
|
159 |
+
assert self.num_classes == self.cfg.num_classes, f"num of class mismatch, len(label_list)={self.num_classes}, num_classes:{self.cfg.num_classes}"
|
160 |
+
|
161 |
+
self.rank = dist.get_rank()
|
162 |
+
self.world_size = dist.get_world_size()
|
163 |
+
|
164 |
+
self.original_label = np.array(self.cfg.label_list)
|
165 |
+
|
166 |
+
for i, l in enumerate(self.original_label):
|
167 |
+
self.label_mapper[l] = i
|
168 |
+
self.mapped_left_right_pairs = self.label_mapper[self.left_right_pairs] if self.left_right_pairs is not None else None
|
169 |
+
|
170 |
+
if self.is_train:
|
171 |
+
augs = T.compose([T.hflip(cfg.get("is_flip", False), self.mapped_left_right_pairs),
|
172 |
+
T.resize_image(cfg.crop_size),
|
173 |
+
T.multi_scale(cfg.get("is_multi_scale", False), scale_factor=cfg.get("scale_factor", 11),
|
174 |
+
center_crop_test=cfg.get("center_crop_test", False),
|
175 |
+
base_size=cfg.base_size,
|
176 |
+
crop_size=cfg.crop_size,
|
177 |
+
ignore_label=cfg.get("ignore_value", 255)),
|
178 |
+
T.rotate(cfg.get("is_rotate", False), degree=cfg.get("degree", 30),
|
179 |
+
p=cfg.get("possibility", 0.6), pad_val=cfg.get("pad_val", 0),
|
180 |
+
seg_pad_val=cfg.get("ignore_value", 255)),
|
181 |
+
T.PhotoMetricDistortion(cfg.get('is_photometricdistortion', False),
|
182 |
+
brightness_delta=cfg.get('brightness', 32),
|
183 |
+
contrast_range=cfg.get('contrast_range', [0.5, 1.5]),
|
184 |
+
saturation_range=cfg.get("saturation_range", [0.5, 1.5]),
|
185 |
+
hue_delta=cfg.get('hue', 18)
|
186 |
+
),
|
187 |
+
T.transpose()])
|
188 |
+
else:
|
189 |
+
augs = T.compose([T.resize_image_eval(cfg.eval_crop_size),
|
190 |
+
T.transpose()])
|
191 |
+
self.augs = augs
|
192 |
+
|
193 |
+
self.initialized = False
|
194 |
+
self.use_ceph = True
|
195 |
+
|
196 |
+
def __len__(self):
|
197 |
+
return len(self.img_list)
|
198 |
+
|
199 |
+
def _ignore_to_endclass(self, label):
|
200 |
+
label[label==self.ignore_label] = self.num_classes
|
201 |
+
return label
|
202 |
+
|
203 |
+
def _read_one(self, index=None):
|
204 |
+
if index == None:
|
205 |
+
index = np.random.randint(self.num)
|
206 |
+
|
207 |
+
filename = self.img_list[index]
|
208 |
+
try:
|
209 |
+
img = Image.open(filename).convert('RGB')
|
210 |
+
img = np.array(img)[:,:,::-1]
|
211 |
+
|
212 |
+
except:
|
213 |
+
outputName = "failed_to_read_in_train.txt"
|
214 |
+
with open(outputName, "a") as g:
|
215 |
+
g.write("%s\n" % (filename))
|
216 |
+
print('Read image[{}] failed ({})'.format(index, filename))
|
217 |
+
## if fail then recursive call _read_one without idx
|
218 |
+
return self._read_one()
|
219 |
+
|
220 |
+
gt_label = self.label_list[index]
|
221 |
+
|
222 |
+
try:
|
223 |
+
label = np.array(Image.open(gt_label))
|
224 |
+
except:
|
225 |
+
outputName = "failed_to_read_in_train_labels.txt"
|
226 |
+
with open(outputName, "a") as g:
|
227 |
+
g.write("%s\n" % (filename))
|
228 |
+
print('Read image[{}] failed ({})'.format(index, gt_label))
|
229 |
+
## if fail then recursive call _read_one without idx
|
230 |
+
return self._read_one()
|
231 |
+
|
232 |
+
return img, label
|
233 |
+
|
234 |
+
def __getitem__(self, index):
|
235 |
+
|
236 |
+
dataset_dict = {}
|
237 |
+
dataset_dict["filename"] = self.name_list[index]
|
238 |
+
|
239 |
+
image, parsing_seg_gt = self._read_one(index)
|
240 |
+
|
241 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
242 |
+
|
243 |
+
self._record_image_size(dataset_dict, image)
|
244 |
+
|
245 |
+
if self.pseudo_labels:
|
246 |
+
image, parsing_seg_gt = self.augs(image, parsing_seg_gt)
|
247 |
+
image = torch.as_tensor(np.ascontiguousarray(image))
|
248 |
+
parsing_seg_gt = torch.as_tensor(np.ascontiguousarray(parsing_seg_gt))
|
249 |
+
dataset_dict["image"] = image
|
250 |
+
dataset_dict['PL_gt'] = parsing_seg_gt
|
251 |
+
return dataset_dict
|
252 |
+
|
253 |
+
parsing_seg_gt = self._encode_label(parsing_seg_gt) # - 1 no need to filter background in human parsing
|
254 |
+
|
255 |
+
size = parsing_seg_gt.size
|
256 |
+
|
257 |
+
if not self.is_train:
|
258 |
+
if len(self.evaluate_size) == 2:
|
259 |
+
dataset_dict["gt"] = np.copy(
|
260 |
+
cv2.resize(parsing_seg_gt, self.evaluate_size, interpolation=cv2.INTER_LINEAR_EXACT).astype(np.int_))
|
261 |
+
else:
|
262 |
+
# use DataContainer type to avoid being batched as tensors
|
263 |
+
dataset_dict["gt"] = DataContainer(np.copy(parsing_seg_gt.astype(np.int_)))
|
264 |
+
|
265 |
+
parsing_seg_gt = parsing_seg_gt.astype("double")
|
266 |
+
assert len(parsing_seg_gt), "parsing needs gt to train"
|
267 |
+
image, parsing_seg_gt = self.augs(image, parsing_seg_gt)
|
268 |
+
if self.stride_level>1:
|
269 |
+
temp_h, temp_w = parsing_seg_gt.shape
|
270 |
+
parsing_seg_gt = np.asarray(Image.fromarray(parsing_seg_gt).convert('P').resize((int(temp_h/self.stride_level), int(temp_w/self.stride_level))))
|
271 |
+
if self.ignore2endclass:
|
272 |
+
parsing_seg_gt = self._ignore_to_endclass(parsing_seg_gt)
|
273 |
+
image = torch.as_tensor(np.ascontiguousarray(image))
|
274 |
+
parsing_seg_gt = torch.as_tensor(parsing_seg_gt.astype("long"))
|
275 |
+
|
276 |
+
image_shape = (image.shape[-2], image.shape[-1]) # h, w
|
277 |
+
|
278 |
+
dataset_dict["image"] = image
|
279 |
+
if not self.is_train:
|
280 |
+
if self.cfg.get('label_mask', False):
|
281 |
+
m = torch.ones(self.num_classes,dtype=torch.int64)*-1 # mask all labels
|
282 |
+
dataset_dict['mask'] = m
|
283 |
+
return dataset_dict
|
284 |
+
|
285 |
+
dataset_dict["label"] = parsing_seg_gt.long() # not used in test
|
286 |
+
|
287 |
+
# Prepare per-category binary masks
|
288 |
+
parsing_seg_gt = parsing_seg_gt.numpy()
|
289 |
+
instances = Instances(image_shape)
|
290 |
+
classes = np.unique(parsing_seg_gt)
|
291 |
+
# remove ignored region
|
292 |
+
if self.cfg.get('add_zero_mask',False):
|
293 |
+
classes = np.array(list(range(self.num_classes)))
|
294 |
+
classes = classes[classes != self.ignore_label]
|
295 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
296 |
+
|
297 |
+
if self.cfg.get('label_mask', False):
|
298 |
+
m = np.zeros(self.num_classes)
|
299 |
+
m[classes] = 1
|
300 |
+
mask = torch.tensor(m, dtype=torch.int64).clone()
|
301 |
+
unk_mask_indices = get_unk_mask_indices(image, self.num_classes, known_labels=100,) # set known_labels>1 to use label masking training
|
302 |
+
mask.scatter_(0, torch.Tensor(unk_mask_indices).long(), -1)
|
303 |
+
dataset_dict['mask'] = mask
|
304 |
+
|
305 |
+
masks = []
|
306 |
+
for class_id in classes:
|
307 |
+
masks.append(parsing_seg_gt == class_id)
|
308 |
+
|
309 |
+
if len(masks) == 0:
|
310 |
+
# Some image does not have annotation (all ignored)
|
311 |
+
instances.gt_masks = torch.zeros((0, parsing_seg_gt.shape[-2], parsing_seg_gt.shape[-1]))
|
312 |
+
else:
|
313 |
+
masks = BitMasks(
|
314 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
315 |
+
)
|
316 |
+
instances.gt_masks = masks.tensor
|
317 |
+
|
318 |
+
dataset_dict["instances"] = instances # not used in test
|
319 |
+
|
320 |
+
return dataset_dict # {'image': img_mask, 'label': target_mask, 'instances': xxx, 'filename': img_name}
|
321 |
+
|
322 |
+
@staticmethod
|
323 |
+
def _record_image_size(dataset_dict, image):
|
324 |
+
"""
|
325 |
+
Raise an error if the image does not match the size specified in the dict.
|
326 |
+
"""
|
327 |
+
# To ensure bbox always remap to original image size
|
328 |
+
if "width" not in dataset_dict:
|
329 |
+
dataset_dict["width"] = image.shape[1]
|
330 |
+
if "height" not in dataset_dict:
|
331 |
+
dataset_dict["height"] = image.shape[0]
|
332 |
+
|
333 |
+
def _list_dirs(self, data_path):
|
334 |
+
img_list = list()
|
335 |
+
label_list = list()
|
336 |
+
name_list = list()
|
337 |
+
|
338 |
+
if self.dataset == 'train':
|
339 |
+
train_type = 'train'
|
340 |
+
elif self.dataset == 'val':
|
341 |
+
train_type = 'eval'
|
342 |
+
list_txt = osp.join(data_path, f'flist_2hz_{train_type}.txt')
|
343 |
+
|
344 |
+
# with open(list_txt, 'r') as f:
|
345 |
+
# data = f.readlines()
|
346 |
+
# data = [d.strip() for d in data]
|
347 |
+
list_lines = s3client.Get(list_txt)
|
348 |
+
if not list_lines:
|
349 |
+
print('File not exist', list_txt)
|
350 |
+
import pdb;
|
351 |
+
pdb.set_trace()
|
352 |
+
raise IOError('File not exist', list_file)
|
353 |
+
list_lines = list_lines.decode('ascii')
|
354 |
+
data = list_lines.split('\n')
|
355 |
+
data = [d for d in data if len(d)]
|
356 |
+
|
357 |
+
|
358 |
+
if self.data_use_ratio != 1:
|
359 |
+
data = random.sample(data, int(len(data) * self.data_use_ratio))
|
360 |
+
|
361 |
+
for d in data:
|
362 |
+
img_path = osp.join(data_path, d)
|
363 |
+
image_name = '/'.join(d.split('/')[2:])
|
364 |
+
label_path = img_path.replace('rgb', 'seg', 1)
|
365 |
+
|
366 |
+
img_list.append(img_path)
|
367 |
+
label_list.append(label_path)
|
368 |
+
name_list.append(image_name)
|
369 |
+
|
370 |
+
return img_list, label_list, name_list
|
371 |
+
|
372 |
+
def _encode_label(self, labelmap):
|
373 |
+
shape = labelmap.shape
|
374 |
+
encoded_labelmap = np.zeros(shape=(shape[0], shape[1]), dtype=np.uint8)
|
375 |
+
for i, class_id in enumerate(self.cfg.label_list):
|
376 |
+
encoded_labelmap[labelmap == class_id] = i
|
377 |
+
|
378 |
+
return encoded_labelmap
|
379 |
+
|
380 |
+
def __repr__(self):
|
381 |
+
return self.__class__.__name__ + \
|
382 |
+
f'rank: {self.rank} task: {self.task_name} mode:{"training" if self.is_train else "inference"} ' \
|
383 |
+
f'dataset_len:{len(self.img_list)} id_num:{self.cfg["num_classes"]} augmentation: {self.augs}'
|
384 |
+
|
385 |
+
class LIPParsingDataset(Human3M6ParsingDataset):
|
386 |
+
"""
|
387 |
+
0:'background',
|
388 |
+
1:'hat',
|
389 |
+
2:'hair',
|
390 |
+
3:'glove',
|
391 |
+
4:'sunglasses',
|
392 |
+
5:'upperclothes',
|
393 |
+
6:'dress',
|
394 |
+
7:'coat',
|
395 |
+
8:'socks',
|
396 |
+
9:'pants',
|
397 |
+
10:'jumpsuits',
|
398 |
+
11:'scarf',
|
399 |
+
12:'skirt',
|
400 |
+
13:'face',
|
401 |
+
14:'leftArm',
|
402 |
+
15:'rightArm',
|
403 |
+
16:'leftLeg',
|
404 |
+
17:'rightLeg',
|
405 |
+
18:'leftShoe',
|
406 |
+
19:'rightShoe'
|
407 |
+
"""
|
408 |
+
task_name = 'LIP_parsing'
|
409 |
+
|
410 |
+
left_right_pairs = np.array(
|
411 |
+
[[14, 15], [16, 17], [18, 19]]
|
412 |
+
)
|
413 |
+
|
414 |
+
label_mapper = np.arange(60)
|
415 |
+
|
416 |
+
evaluate_size = ()
|
417 |
+
|
418 |
+
def __init__(self,
|
419 |
+
ginfo,
|
420 |
+
data_path,
|
421 |
+
dataset='train',
|
422 |
+
data_use_ratio=1,
|
423 |
+
is_train=True,
|
424 |
+
cfg=None,
|
425 |
+
**kwargs):
|
426 |
+
super(LIPParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,
|
427 |
+
data_use_ratio=data_use_ratio,
|
428 |
+
dataset=dataset, is_train=is_train,
|
429 |
+
cfg=cfg, **kwargs)
|
430 |
+
|
431 |
+
def _list_dirs(self, data_path):
|
432 |
+
img_list = list()
|
433 |
+
label_list = list()
|
434 |
+
name_list = list()
|
435 |
+
|
436 |
+
if self.dataset == 'train':
|
437 |
+
train_type = 'train'
|
438 |
+
elif self.dataset == 'val':
|
439 |
+
train_type = 'val'
|
440 |
+
"""
|
441 |
+
- LIP
|
442 |
+
-data
|
443 |
+
-train_id.txt
|
444 |
+
-train_images
|
445 |
+
-1000_1234574.jpg
|
446 |
+
-val_images
|
447 |
+
-val_id.txt
|
448 |
+
-Trainval_parsing_annotations
|
449 |
+
-train_segmentations
|
450 |
+
-1000_1234574.png
|
451 |
+
"""
|
452 |
+
list_txt = osp.join(data_path, 'data', f'{train_type}_id.txt')
|
453 |
+
|
454 |
+
with open(list_txt, 'r') as f:
|
455 |
+
data = f.readlines()
|
456 |
+
data = [d.strip() for d in data]
|
457 |
+
|
458 |
+
|
459 |
+
if self.data_use_ratio != 1:
|
460 |
+
data = random.sample(data, int(len(data) * self.data_use_ratio))
|
461 |
+
|
462 |
+
postfix_img = '.jpg'
|
463 |
+
postfix_ann = '.png'
|
464 |
+
for d in data:
|
465 |
+
img_path = osp.join(data_path, f'data/{train_type}_images', d + postfix_img)
|
466 |
+
image_name = d
|
467 |
+
label_path = osp.join(data_path, f'TrainVal_parsing_annotations/{train_type}_segmentations',
|
468 |
+
d + postfix_ann)
|
469 |
+
|
470 |
+
img_list.append(img_path)
|
471 |
+
label_list.append(label_path)
|
472 |
+
name_list.append(image_name)
|
473 |
+
|
474 |
+
return img_list, label_list, name_list
|
475 |
+
|
476 |
+
class CIHPParsingDataset(Human3M6ParsingDataset):
|
477 |
+
"""
|
478 |
+
0:'background',
|
479 |
+
1:'hat',
|
480 |
+
2:'hair',
|
481 |
+
3:'glove',
|
482 |
+
4:'sunglasses',
|
483 |
+
5:'upperclothes',
|
484 |
+
6:'dress',
|
485 |
+
7:'coat',
|
486 |
+
8:'socks',
|
487 |
+
9:'pants',
|
488 |
+
10:'torsoSkin',
|
489 |
+
11:'scarf',
|
490 |
+
12:'skirt',
|
491 |
+
13:'face',
|
492 |
+
14:'leftArm',
|
493 |
+
15:'rightArm',
|
494 |
+
16:'leftLeg',
|
495 |
+
17:'rightLeg',
|
496 |
+
18:'leftShoe',
|
497 |
+
19:'rightShoe'
|
498 |
+
"""
|
499 |
+
task_name = 'CIHP_parsing'
|
500 |
+
|
501 |
+
left_right_pairs = np.array(
|
502 |
+
[[14, 15], [16, 17], [18, 19]]
|
503 |
+
)
|
504 |
+
|
505 |
+
label_mapper = np.arange(60)
|
506 |
+
|
507 |
+
evaluate_size = ()
|
508 |
+
|
509 |
+
def __init__(self,
|
510 |
+
ginfo,
|
511 |
+
data_path,
|
512 |
+
dataset='train',
|
513 |
+
data_use_ratio=1,
|
514 |
+
is_train=True,
|
515 |
+
cfg=None,
|
516 |
+
**kwargs):
|
517 |
+
super(CIHPParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,data_use_ratio=data_use_ratio,
|
518 |
+
dataset=dataset, is_train=is_train,
|
519 |
+
cfg=cfg, **kwargs)
|
520 |
+
|
521 |
+
def _list_dirs(self, data_path):
|
522 |
+
img_list = list()
|
523 |
+
label_list = list()
|
524 |
+
name_list = list()
|
525 |
+
|
526 |
+
if self.dataset == 'train':
|
527 |
+
train_type = 'train'
|
528 |
+
elif self.dataset == 'val':
|
529 |
+
train_type = 'val'
|
530 |
+
"""
|
531 |
+
- CHIP
|
532 |
+
-instance-level_human_parsing
|
533 |
+
-Training
|
534 |
+
-Images
|
535 |
+
-0008522.jpg
|
536 |
+
-Category_ids
|
537 |
+
-0008522.png
|
538 |
+
-train_id.txt
|
539 |
+
-Validation
|
540 |
+
-val_id.txt
|
541 |
+
"""
|
542 |
+
Infix = 'Training' if train_type == 'train' else 'Validation'
|
543 |
+
list_txt = osp.join(data_path, 'instance-level_human_parsing', Infix, f'{train_type}_id.txt')
|
544 |
+
|
545 |
+
with open(list_txt, 'r') as f:
|
546 |
+
data = f.readlines()
|
547 |
+
data = [d.strip() for d in data]
|
548 |
+
|
549 |
+
if self.data_use_ratio != 1:
|
550 |
+
data = random.sample(data, int(len(data) * self.data_use_ratio))
|
551 |
+
|
552 |
+
postfix_img = '.jpg'
|
553 |
+
postfix_ann = '.png'
|
554 |
+
for d in data:
|
555 |
+
img_path = osp.join(data_path, 'instance-level_human_parsing', Infix, f'Images', d + postfix_img)
|
556 |
+
image_name = d
|
557 |
+
label_path = osp.join(data_path, 'instance-level_human_parsing', Infix, 'Category_ids', d + postfix_ann)
|
558 |
+
|
559 |
+
img_list.append(img_path)
|
560 |
+
label_list.append(label_path)
|
561 |
+
name_list.append(image_name)
|
562 |
+
|
563 |
+
return img_list, label_list, name_list
|
564 |
+
|
565 |
+
|
566 |
+


class ATRParsingDataset(Human3M6ParsingDataset):
    """
    0: 'background',
    1: 'hat',
    2: 'hair',
    3: 'sunglasses',
    4: 'upperclothes',
    5: 'skirt',
    6: 'pants',
    7: 'dress',
    8: 'belt',
    9: 'leftshoe',
    10: 'rightshoe',
    11: 'face',
    12: 'leftleg',
    13: 'rightleg',
    14: 'leftarm',
    15: 'rightarm',
    16: 'bag',
    17: 'scarf',
    """
    task_name = 'ATR_parsing'

    left_right_pairs = np.array(
        [[9, 10], [12, 13], [14, 15]]
    )

    label_mapper = np.arange(60)

    evaluate_size = ()

    def __init__(self,
                 ginfo,
                 data_path,
                 dataset='train',
                 data_use_ratio=1,
                 is_train=True,
                 cfg=None,
                 **kwargs):
        super(ATRParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,
                                                data_use_ratio=data_use_ratio,
                                                dataset=dataset, is_train=is_train,
                                                cfg=cfg, **kwargs)

    def _list_dirs(self, data_path):
        img_list = list()
        label_list = list()
        name_list = list()

        if self.dataset == 'train':
            train_type = 'train'
        elif self.dataset == 'val':
            train_type = 'val'
        """
        - ATR
          - humanparsing
            - JPEGImages
            - SegmentationClassAug
          - train_id.txt
          - val_id.txt
        """
        list_txt = osp.join(data_path, f'{train_type}_id.txt')
        with open(list_txt, 'r') as f:
            data = f.readlines()
            data = [d.strip() for d in data]

        if self.data_use_ratio != 1:
            data = random.sample(data, int(len(data) * self.data_use_ratio))

        postfix_img = '.jpg'
        postfix_ann = '.png'
        for d in data:
            img_path = osp.join(data_path, 'humanparsing/JPEGImages', d + postfix_img)
            image_name = d
            label_path = osp.join(data_path, 'humanparsing/SegmentationClassAug', d + postfix_ann)

            img_list.append(img_path)
            label_list.append(label_path)
            name_list.append(image_name)

        return img_list, label_list, name_list


class DeepFashionParsingDataset(Human3M6ParsingDataset):
    """
    0: 'background',
    1: 'hat',
    2: 'hair',
    3: 'sunglasses',
    4: 'upperclothes',
    5: 'skirt',
    6: 'pants',
    7: 'dress',
    8: 'belt',
    9: 'leftshoe',
    10: 'rightshoe',
    11: 'face',
    12: 'leftleg',
    13: 'rightleg',
    14: 'leftarm',
    15: 'rightarm',
    16: 'bag',
    17: 'scarf',
    """
    task_name = 'DeepFashion_parsing'
    label_mapper = np.arange(60)
    left_right_pairs = None
    evaluate_size = ()

    def __init__(self,
                 ginfo,
                 data_path,
                 dataset='train',
                 data_use_ratio=1,
                 is_train=True,
                 cfg=None,
                 **kwargs):
        super(DeepFashionParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,
                                                        data_use_ratio=data_use_ratio,
                                                        dataset=dataset, is_train=is_train,
                                                        cfg=cfg, **kwargs)

    def _list_dirs(self, data_path):
        img_list = list()
        label_list = list()
        name_list = list()

        if self.dataset == 'train':
            train_type = 'train'
        elif self.dataset == 'val':
            train_type = 'val'
        """
        - DeepFashion
          - humanparsing
            - JPEGImages
            - SegmentationClassAug
          - train_id.txt
          - val_id.txt
        """
        list_txt = osp.join(data_path, f'{train_type}_id.txt')

        with open(list_txt, 'r') as f:
            data = f.readlines()
            data = [d.strip() for d in data]

        if self.data_use_ratio != 1:
            data = random.sample(data, int(len(data) * self.data_use_ratio))

        postfix_img = '.jpg'
        postfix_ann = '.png'

        if train_type == 'train':
            for d in data:
                img_path = osp.join(data_path, 'train/image', d + postfix_img)
                image_name = d
                label_path = osp.join(data_path, 'train/seg', d + postfix_ann)

                img_list.append(img_path)
                label_list.append(label_path)
                name_list.append(image_name)

            return img_list, label_list, name_list
        else:
            raise ValueError("not implemented")


class VIPParsingDataset(Human3M6ParsingDataset):
    task_name = 'VIP_parsing'
    left_right_pairs = np.array(
        [[14, 15], [16, 17], [18, 19]]
    )

    label_mapper = np.arange(60)

    evaluate_size = ()

    def __init__(self,
                 ginfo,
                 data_path,
                 dataset='train',
                 data_use_ratio=1,
                 is_train=True,
                 cfg=None,
                 **kwargs):
        super(VIPParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path, data_use_ratio=data_use_ratio,
                                                dataset=dataset, is_train=is_train,
                                                cfg=cfg, **kwargs)

    def _list_dirs(self, data_path):
        img_list = list()
        label_list = list()
        name_list = list()

        if self.dataset == 'train':
            train_type = 'train'
        elif self.dataset == 'val':
            train_type = 'val'

        list_txt = osp.join(data_path, f'{train_type}_id.txt')

        with open(list_txt, 'r') as f:
            data = f.readlines()
            data = [d.strip() for d in data]

        postfix_img = '.jpg'
        postfix_ann = '.png'

        if self.data_use_ratio != 1:
            data = random.sample(data, int(len(data) * self.data_use_ratio))

        if train_type == 'train':
            for d in data:
                img_path = osp.join(data_path, 'Images', d + postfix_img)
                image_name = d
                label_path = osp.join(data_path, 'Annotations/Category_ids', d + postfix_ann)

                img_list.append(img_path)
                label_list.append(label_path)
                name_list.append(image_name)

            return img_list, label_list, name_list
        else:
            raise ValueError("not implemented")


class PaperDollParsingDataset(Human3M6ParsingDataset):
    """
    0: 'background',
    1: 'hat',
    2: 'hair',
    3: 'glove',
    4: 'sunglasses',
    5: 'upperclothes',
    6: 'dress',
    7: 'coat',
    8: 'socks',
    9: 'pants',
    10: 'torsoSkin',
    11: 'scarf',
    12: 'skirt',
    13: 'face',
    14: 'leftArm',
    15: 'rightArm',
    16: 'leftLeg',
    17: 'rightLeg',
    18: 'leftShoe',
    19: 'rightShoe'
    """
    task_name = 'PaperDoll_parsing'

    left_right_pairs = np.array(
        [[14, 15], [16, 17], [18, 19]]
    )

    label_mapper = np.arange(60)

    evaluate_size = ()

    def __init__(self,
                 ginfo,
                 data_path,
                 data_use_ratio=1,
                 dataset='train',
                 is_train=True,
                 cfg=None,
                 **kwargs):
        super(PaperDollParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path, data_use_ratio=data_use_ratio,
                                                      dataset=dataset, is_train=is_train,
                                                      cfg=cfg, **kwargs)

    def _list_dirs(self, data_path):
        img_list = list()
        label_list = list()
        name_list = list()

        if self.dataset == 'train':
            train_type = 'train'
        elif self.dataset == 'val':
            train_type = 'val'
        """
        - PaperDoll_folder
          - TrainVal_parsing_annotations/
            - 0000000.png
          - images
            - 0000000.jpg
        """
        list_txt = osp.join(data_path, f'{train_type}_id.txt')

        with open(list_txt, 'r') as f:
            data = f.readlines()
            data = [d.strip() for d in data]

        postfix_img = '.jpg'
        postfix_ann = '.png'

        if self.data_use_ratio != 1:
            data = random.sample(data, int(len(data) * self.data_use_ratio))

        for d in data:
            img_path = osp.join(data_path, 'images', d + postfix_img)
            image_name = d
            label_path = osp.join(data_path, 'TrainVal_parsing_annotations/', d + postfix_ann)

            img_list.append(img_path)
            label_list.append(label_path)
            name_list.append(image_name)

        return img_list, label_list, name_list
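

# --- Illustrative sketch, not part of the original file ---------------------
# The ATR, DeepFashion, VIP and PaperDoll loaders above all read a plain
# '{train,val}_id.txt' file with one sample id (no extension) per line.  If a
# dataset does not ship such files, a helper along these lines could derive
# them from the annotation folder.  The function name, split ratio and random
# seed are assumptions for illustration only.
def _write_split_ids_sketch(ann_dir, out_dir, val_ratio=0.1, seed=0):
    # local imports keep this sketch self-contained
    import os
    import random
    ids = sorted(os.path.splitext(f)[0] for f in os.listdir(ann_dir) if f.endswith('.png'))
    random.Random(seed).shuffle(ids)
    n_val = int(len(ids) * val_ratio)
    splits = {'val': ids[:n_val], 'train': ids[n_val:]}
    for split, split_ids in splits.items():
        with open(os.path.join(out_dir, f'{split}_id.txt'), 'w') as f:
            f.write('\n'.join(split_ids) + '\n')
# ---------------------------------------------------------------------------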


class FashionPediaParsingDataset(Human3M6ParsingDataset):
    task_name = 'FashionPedia_parsing'

    label_mapper = np.arange(60)

    evaluate_size = ()

    def __init__(self,
                 ginfo,
                 data_path,
                 data_use_ratio=1,
                 dataset='train',
                 is_train=True,
                 cfg=None,
                 **kwargs):
        super(FashionPediaParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,
                                                         data_use_ratio=data_use_ratio,
                                                         dataset=dataset, is_train=is_train,
                                                         cfg=cfg, **kwargs)

    def _list_dirs(self, data_path):
        img_list = list()
        label_list = list()
        name_list = list()

        if self.dataset == 'train':
            train_type = 'train'
        elif self.dataset == 'val':
            train_type = 'val'

        list_txt = osp.join(data_path, f'{train_type}_id.txt')

        with open(list_txt, 'r') as f:
            data = f.readlines()
            data = [d.strip() for d in data]

        postfix_img = '.jpg'
        postfix_ann = '.png'

        if self.data_use_ratio != 1:
            data = random.sample(data, int(len(data) * self.data_use_ratio))

        if train_type == 'train':
            for d in data:
                img_path = osp.join(data_path, 'train/', d + postfix_img)
                image_name = d
                label_path = osp.join(data_path, 'train_annotation/', d + postfix_ann)

                img_list.append(img_path)
                label_list.append(label_path)
                name_list.append(image_name)

            return img_list, label_list, name_list

        elif train_type == 'val':
            for d in data:
                img_path = osp.join(data_path, 'test/', d + postfix_img)
                image_name = d
                label_path = osp.join(data_path, 'test_annotation/', d + postfix_ann)

                img_list.append(img_path)
                label_list.append(label_path)
                name_list.append(image_name)

            return img_list, label_list, name_list
        else:
            raise ValueError(f"unknown split {train_type}")


class ModaNetParsingDataset(Human3M6ParsingDataset):
    """
    modanet_par = {
        0: 'Background',
        1: 'Bag',
        2: 'Belt',
        3: 'Boots',
        4: 'Footwear',
        5: 'Outer',
        6: 'Dress',
        7: 'Sunglasses',
        8: 'Pants',
        9: 'Top',
        10: 'Shorts',
        11: 'Skirt',
        12: 'Headwear',
        13: 'Scarf & Tie'
    }
    """

    task_name = 'ModaNet_parsing'
    label_mapper = np.arange(60)
    left_right_pairs = None
    evaluate_size = ()

    def __init__(self,
                 ginfo,
                 data_path,
                 dataset='train',
                 data_use_ratio=1,
                 is_train=True,
                 cfg=None,
                 **kwargs):
        super(ModaNetParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path, data_use_ratio=data_use_ratio,
                                                    dataset=dataset, is_train=is_train,
                                                    cfg=cfg, **kwargs)

    def _list_dirs(self, data_path):
        img_list = list()
        label_list = list()
        name_list = list()
        # image_dir = osp.join(data_path, 'protocal_1', 'rgb')
        # label_dir = osp.join(data_path, 'protocal_1', 'seg')

        if self.dataset == 'train':
            train_type = 'train'
        elif self.dataset == 'val':
            train_type = 'val'

        list_txt = osp.join(data_path, f'{train_type}_id.txt')

        with open(list_txt, 'r') as f:
            data = f.readlines()
            data = [d.strip() for d in data]

        if self.data_use_ratio != 1:
            data = random.sample(data, int(len(data) * self.data_use_ratio))

        postfix_img = '.jpg'
        postfix_ann = '.png'

        if train_type == 'train':
            for d in data:
                img_path = osp.join(data_path, 'images', d + postfix_img)
                image_name = d
                label_path = osp.join(data_path, 'seg', d + postfix_ann)

                img_list.append(img_path)
                label_list.append(label_path)
                name_list.append(image_name)

            return img_list, label_list, name_list
        else:
            raise ValueError("not implemented")


class MHPParsingDataset(Human3M6ParsingDataset):
    task_name = 'MHP_parsing'
    label_mapper = np.arange(60)
    left_right_pairs = None
    evaluate_size = ()

    def __init__(self,
                 ginfo,
                 data_path,
                 dataset='train',
                 data_use_ratio=1,
                 is_train=True,
                 cfg=None,
                 **kwargs):
        super(MHPParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path, data_use_ratio=data_use_ratio,
                                                dataset=dataset, is_train=is_train,
                                                cfg=cfg, **kwargs)

    def _list_dirs(self, data_path):
        img_list = list()
        label_list = list()
        name_list = list()

        if self.dataset == 'train':
            train_type = 'train'
        elif self.dataset == 'val':
            train_type = 'val'

        list_txt = osp.join(data_path, f'{train_type}_id.txt')

        with open(list_txt, 'r') as f:
            data = f.readlines()
            data = [d.strip() for d in data]

        if self.data_use_ratio != 1:
            data = random.sample(data, int(len(data) * self.data_use_ratio))

        postfix_img = '.jpg'
        postfix_ann = '.png'

        if train_type == 'train':
            for d in data:
                img_path = osp.join(data_path, 'images/', d + postfix_img)
                image_name = d
                label_path = osp.join(data_path, 'processed_label/', d + postfix_ann)

                img_list.append(img_path)
                label_list.append(label_path)
                name_list.append(image_name)

            return img_list, label_list, name_list
        else:
            raise ValueError("not implemented")
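

# --- Illustrative sketch, not part of the original file ---------------------
# Every dataset above declares `label_mapper` and `left_right_pairs`; the base
# class Human3M6ParsingDataset (presumably defined earlier in this file) uses
# them to remap raw label ids and to keep left/right classes consistent under
# horizontal flips.  A minimal sketch of that idea; the helper name is an
# assumption, not repository code.
def _flip_parsing_label_sketch(label_map, label_mapper, left_right_pairs):
    remapped = label_mapper[label_map]      # remap raw ids to the task's class ids
    flipped = remapped[:, ::-1].copy()      # mirror the label mask horizontally
    if left_right_pairs is not None:
        for left, right in left_right_pairs:
            left_mask = flipped == left     # remember left pixels before overwriting
            flipped[flipped == right] = left
            flipped[left_mask] = right
    return flipped
# ---------------------------------------------------------------------------
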
core/data/datasets/images/pedattr_dataset.py
ADDED
@@ -0,0 +1,665 @@
import os
import time
import pickle
import random
from easydict import EasyDict as edict
import numpy as np
import torch.utils.data as data
from PIL import Image
from core.data.transforms.pedattr_transforms import PedAttrAugmentation, PedAttrTestAugmentation, PedAttrRandomAugmentation
import torch.distributed as dist


__all__ = ['AttrDataset', 'MultiAttrDataset']


def merge_pedattr_datasets(data_path_list, root_path_list, dataset_name_list, train,
                           data_use_ratio, text_label_return, select_data, ignore_other_attrs=True):
    total_img_id = []
    total_attr_num = 0
    total_img_num = 0
    total_attr_begin = []
    total_attr_end = []
    total_img_begin = []
    total_img_end = []
    total_text_dict = {}
    attr_begin = []
    attr_end = []

    for data_path, root_path, dataset_name in zip(data_path_list, root_path_list, dataset_name_list):
        assert dataset_name in ['peta', 'PA_100k', 'rap', 'rap2', 'uavhuman', 'HARDHC',
                                'ClothingAttribute', 'parse27k', 'duke', 'market', 'lup_0_200w', 'lup_0_600w', 'lup_600_1200w'], \
            'dataset name {} does not exist'.format(dataset_name)

        with open(data_path, 'rb') as f:
            dataset_info = pickle.load(f)
            dataset_info = edict(dataset_info)
            img_id = dataset_info.image_name
            attr_label = dataset_info.label

        if train:
            split = 'trainval'
        else:
            split = 'test'

        attr_id = dataset_info.attr_name
        attr_num = len(attr_id)

        total_attr_begin.append(total_attr_num)
        total_attr_num = total_attr_num + attr_num
        total_attr_end.append(total_attr_num)

        if select_data is None or dataset_name == select_data:
            assert split in dataset_info.partition.keys(), f'split {split} does not exist'
            img_idx = dataset_info.partition[split]

            if isinstance(img_idx, list):
                img_idx = img_idx[0]  # default partition 0

            if data_use_ratio != 1:
                img_idx = random.sample(list(img_idx), int(len(img_idx) * data_use_ratio))

            img_num = len(img_idx)
            img_idx = np.array(img_idx)

            img_id = [os.path.join(root_path, img_id[i]) for i in img_idx]
            label = attr_label[img_idx]

            total_img_begin.append(total_img_num)
            total_img_num = total_img_num + len(img_id)
            total_img_end.append(total_img_num)
        else:
            # When testing on a single dataset, the split may not exist in the other
            # datasets, so register a fake, empty split (0 images) here to keep the
            # bookkeeping consistent.
            # TODO: find a better way to handle this, e.g. loop only over the selected dataset.
            img_id = []
            label = []
            img_num = 0
            total_img_begin.append(total_img_num)
            total_img_num = total_img_num + len(img_id)
            total_img_end.append(total_img_num)

    infilling_class = -1 if ignore_other_attrs else 0
    total_label = np.full((total_img_num, total_attr_num), infilling_class, dtype=np.int32)
    select_attr_begin = 0
    select_attr_end = total_attr_num

    for index, (data_path, root_path, dataset_name) in enumerate(zip(data_path_list, root_path_list, dataset_name_list)):

        assert dataset_name in ['peta', 'PA_100k', 'rap', 'rap2', 'uavhuman', 'HARDHC',
                                'ClothingAttribute', 'parse27k', 'duke', 'market', 'lup_0_200w', 'lup_0_600w', 'lup_600_1200w'], \
            'dataset name {} does not exist'.format(dataset_name)
        with open(data_path, 'rb') as f:
            dataset_info = pickle.load(f)
            dataset_info = edict(dataset_info)

        img_id = dataset_info.image_name
        attr_label = dataset_info.label

        if train:
            split = 'trainval'
        else:
            split = 'test'

        if not train and dataset_name != select_data:
            continue

        assert split in dataset_info.partition.keys(), f'split {split} does not exist'

        attr_id = dataset_info.attr_name
        attr_num = len(attr_id)

        img_idx = dataset_info.partition[split]

        if isinstance(img_idx, list):
            img_idx = img_idx[0]  # default partition 0

        if data_use_ratio != 1:
            img_idx = random.sample(list(img_idx), int(len(img_idx) * data_use_ratio))

        img_num = len(img_idx)
        img_idx = np.array(img_idx)

        img_id = [os.path.join(root_path, img_id[i]) for i in img_idx]
        label = attr_label[img_idx]

        if text_label_return:
            for idx in range(attr_num):
                total_text_dict[total_attr_begin[index] + idx] = eval(f"{dataset_name}_attr_name")[idx]

        if not train:
            if dataset_name == select_data:
                total_label[total_img_begin[index]: total_img_end[index], total_attr_begin[index]: total_attr_end[index]] = label
                total_img_id.extend(img_id)
                attr_begin.extend([total_attr_begin[index] for i in img_idx])
                attr_end.extend([total_attr_end[index] for i in img_idx])
        else:
            total_label[total_img_begin[index]: total_img_end[index], total_attr_begin[index]: total_attr_end[index]] = label
            total_img_id.extend(img_id)
            attr_begin.extend([total_attr_begin[index] for i in img_idx])
            attr_end.extend([total_attr_end[index] for i in img_idx])

    return total_img_id, total_label, total_text_dict, attr_begin, attr_end


class MultiAttrDataset(data.Dataset):

    def __init__(self, ginfo, augmentation, task_spec, train=True, data_use_ratio=1, text_label_return=False,
                 select_data=None, ignore_other_attrs=True,
                 **kwargs):
        data_path = task_spec.data_path
        root_path = task_spec.root_path
        dataset_name = task_spec.dataset
        self.rank = dist.get_rank()
        self.train = train
        self.img_id, self.label, self.text_dict, self.attr_begin, self.attr_end = \
            merge_pedattr_datasets(data_path, root_path, dataset_name, train,
                                   data_use_ratio, text_label_return, select_data, ignore_other_attrs)
        height = augmentation.height
        width = augmentation.width
        self.img_num = len(self.img_id)

        if train:
            self.transform = PedAttrAugmentation(height, width)
            if augmentation.get('use_random_aug', False):
                self.transform = PedAttrRandomAugmentation(height, width,
                                                           augmentation.use_random_aug.m, augmentation.use_random_aug.n)
        else:
            self.transform = PedAttrTestAugmentation(height, width)

        self.task_name = ginfo.task_name

    def __getitem__(self, index):
        return self.read_one(index)

    def __len__(self):
        return len(self.img_id)

    def read_one(self, idx=None):
        if idx is None:
            idx = np.random.randint(self.img_num)

        imgname, gt_label = self.img_id[idx], self.label[idx]
        imgpath = imgname

        try:
            img = Image.open(imgpath).convert('RGB')
            if self.transform is not None:
                img = self.transform(img)

            gt_label = gt_label.astype(np.float32)

            output = {'image': img, 'label': gt_label, 'filename': imgname,
                      'attr_begin': self.attr_begin[idx], 'attr_end': self.attr_end[idx]}

            return output
        except Exception:
            print('{} load failed'.format(imgpath))
            return self.read_one()

    def __repr__(self):
        return self.__class__.__name__ + \
            f' rank: {self.rank} task: {self.task_name} mode: {"training" if self.train else "inference"} ' \
            f'dataset_len: {len(self.img_id)} augmentation: {self.transform}'
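

# --- Illustrative sketch, not part of the original file ---------------------
# merge_pedattr_datasets() concatenates the attribute spaces of all source
# datasets along the column axis and fills attributes that a sample's own
# dataset does not annotate with -1 (or 0 when ignore_other_attrs=False).
# A toy illustration with two hypothetical datasets, A (2 attributes) and
# B (3 attributes), one image each:
def _merged_label_layout_sketch():
    labels_a = np.array([[1, 0]])                  # one image from hypothetical dataset A
    labels_b = np.array([[0, 1, 1]])               # one image from hypothetical dataset B
    total = np.full((2, 5), -1, dtype=np.int32)    # 2 images, 2 + 3 merged attribute columns
    total[0, 0:2] = labels_a                       # A occupies columns [0, 2) -> attr_begin=0, attr_end=2
    total[1, 2:5] = labels_b                       # B occupies columns [2, 5) -> attr_begin=2, attr_end=5
    # total is now [[ 1,  0, -1, -1, -1],
    #               [-1, -1,  0,  1,  1]]
    return total
# ---------------------------------------------------------------------------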


rap2_attr_name = {
    0: {0:'without a bald head',1:'with a bald head'},
    1: {0:'with short hair',1:'with long hair'},
    2: {0:'with non-black hair',1:'with black hair'},
    3: {0:'without a hat',1:'with a hat'},
    4: {0:'without glasses',1:'with glasses'},
    5: {0:'without a shirt',1:'with a shirt'},
    6: {0:'without a sweater',1:'with a sweater'},
    7: {0:'without a vest',1:'with a vest'},
    8: {0:'without a t-shirt',1:'with a t-shirt'},
    9: {0:'without cotton',1:'with cotton'},
    10: {0:'without a jacket',1:'with a jacket'},
    11: {0:'without formal wear',1:'with formal wear'},
    12: {0:'without tight clothes',1:'with tight clothes'},
    13: {0:'without short sleeves',1:'with short sleeves'},
    14: {0:'without other upper-body clothing',1:'with other upper-body clothing'},
    15: {0:'without long trousers',1:'with long trousers'},
    16: {0:'without a skirt',1:'with a skirt'},
    17: {0:'without a short skirt',1:'with a short skirt'},
    18: {0:'without a dress',1:'with a dress'},
    19: {0:'without jeans',1:'with jeans'},
    20: {0:'without tight trousers',1:'with tight trousers'},
    21: {0:'without leather shoes',1:'with leather shoes'},
    22: {0:'without sport shoes',1:'with sport shoes'},
    23: {0:'without boots',1:'with boots'},
    24: {0:'without cloth shoes',1:'with cloth shoes'},
    25: {0:'without casual shoes',1:'with casual shoes'},
    26: {0:'without other shoes',1:'with other shoes'},
    27: {0:'without a backpack',1:'with a backpack'},
    28: {0:'without a shoulder bag',1:'with a shoulder bag'},
    29: {0:'without a handbag',1:'with a handbag'},
    30: {0:'without a box',1:'with a box'},
    31: {0:'without a plastic bag',1:'with a plastic bag'},
    32: {0:'without a paper bag',1:'with a paper bag'},
    33: {0:'without a hand trunk',1:'with a hand trunk'},
    34: {0:'without other attachments',1:'with other attachments'},
    35: {0:'age greater than 16',1:'age less than or equal to 16'},
    36: {0:'age less than 17 or greater than 30',1:'age between 17 and 30'},
    37: {0:'age less than 31 or greater than 45',1:'age between 31 and 45'},
    38: {0:'age less than 46 or greater than 60',1:'age between 46 and 60'},
    39: {0:'male',1:'female', 2:'gender unknown'},
    40: {0:'without excess body fat',1:'with excess body fat'},
    41: {0:'without normal body shape',1:'with normal body shape'},
    42: {0:'without thin body shape',1:'with thin body shape'},
    43: {0:'not a customer',1:'is a customer'},
    44: {0:'not an employee',1:'is an employee'},
    45: {0:'not calling',1:'calling'},
    46: {0:'not talking',1:'talking'},
    47: {0:'not gathering',1:'gathering'},
    48: {0:'not holding anything',1:'holding something'},
    49: {0:'not pushing anything',1:'pushing something'},
    50: {0:'not pulling anything',1:'pulling something'},
    51: {0:'not carrying anything in arms',1:'carrying something in arms'},
    52: {0:'not carrying anything in hands',1:'carrying something in hands'},
    53: {0:'no other actions',1:'performing other actions'}
}

PA_100k_attr_name = {
    0: {0:'without a hat',1:'with a hat'},
    1: {0:'without glasses',1:'with glasses'},
    2: {0:'without short sleeves',1:'with short sleeves'},
    3: {0:'without long sleeves',1:'with long sleeves'},
    4: {0:'without stripe upper-clothes',1:'with stripe upper-clothes'},
    5: {0:'without logo upper-clothes',1:'with logo upper-clothes'},
    6: {0:'without plaid upper-clothes',1:'with plaid upper-clothes'},
    7: {0:'without splice upper-clothes',1:'with splice upper-clothes'},
    8: {0:'without stripe lower-clothes',1:'with stripe lower-clothes'},
    9: {0:'without pattern lower-clothes',1:'with pattern lower-clothes'},
    10: {0:'without long coat',1:'with long coat'},
    11: {0:'without long trousers',1:'with long trousers'},
    12: {0:'without short trousers',1:'with short trousers'},
    13: {0:'without skirt or dress',1:'with skirt or dress'},
    14: {0:'without boots',1:'with boots'},
    15: {0:'without a handbag',1:'with a handbag'},
    16: {0:'without a shoulder bag',1:'with a shoulder bag'},
    17: {0:'without a backpack',1:'with a backpack'},
    18: {0:'not hold objects in front',1:'hold objects in front'},
    19: {0:'age less than or equal to 60',1:'age greater than 60'},
    20: {0:'age less than 18 or greater than 60',1:'age between 18 and 60'},
    21: {0:'age greater than or equal to 18',1:'age less than 18'},
    22: {0:'male',1:'female', 2:'gender unknown'},
    23: {0:'not in the front position',1:'in the front position'},
    24: {0:'not in the side position',1:'in the side position'},
    25: {0:'not in the back position',1:'in the back position'},
}

HARDHC_attr_name = {
    0: {0:'female', 1:'male', -1:'gender unknown'},
    1: {0:'with short hair',1:'with long hair'},
    2: {0:'without sunglass',1:'with sunglass'},
    3: {0:'without a hat',1:'with a hat'},
    4: {0:'without T-skirt',1:'with T-skirt'},
    5: {0:'without long sleeves',1:'with long sleeves'},
    6: {0:'without formal clothes',1:'with formal clothes'},
    7: {0:'without short trousers',1:'with short trousers'},
    8: {0:'without jeans',1:'with jeans'},
    9: {0:'without long pants',1:'with long pants'},
    10: {0:'without skirt',1:'with skirt'},
    11: {0:'without face mask',1:'with face mask'},
    12: {0:'without logo clothes',1:'with logo clothes'},
    13: {0:'without stripe clothes',1:'with stripe clothes'},
}

parse27k_attr_name = {
    0: {0:'without a bald head',1:'with a bald head'},
    1: {0:'with short hair',1:'with long hair'},
    2: {0:'with non-black hair',1:'with black hair'},
    3: {0:'without a hat',1:'with a hat'},
    4: {0:'without glasses',1:'with glasses'},
    5: {0:'without a shirt',1:'with a shirt'},
    6: {0:'without a sweater',1:'with a sweater'},
    7: {0:'without a vest',1:'with a vest'},
    8: {0:'without a t-shirt',1:'with a t-shirt'},
    9: {0:'without cotton',1:'with cotton'},
    10: {0:'without a jacket',1:'with a jacket'},
    11: {0:'without formal wear',1:'with formal wear'},
    12: {0:'without tight clothes',1:'with tight clothes'},
    13: {0:'without short sleeves',1:'with short sleeves'},
    14: {0:'without other upper-body clothing',1:'with other upper-body clothing'},
    15: {0:'without long trousers',1:'with long trousers'},
    16: {0:'without a skirt',1:'with a skirt'},
    17: {0:'without a short skirt',1:'with a short skirt'},
    18: {0:'without a dress',1:'with a dress'},
    19: {0:'without jeans',1:'with jeans'},
    20: {0:'without tight trousers',1:'with tight trousers'},
    21: {0:'without leather shoes',1:'with leather shoes'},
    22: {0:'without sport shoes',1:'with sport shoes'},
    23: {0:'without boots',1:'with boots'},
    24: {0:'without cloth shoes',1:'with cloth shoes'},
    25: {0:'without casual shoes',1:'with casual shoes'},
    26: {0:'without other shoes',1:'with other shoes'},
    27: {0:'without a backpack',1:'with a backpack'},
    28: {0:'without a shoulder bag',1:'with a shoulder bag'},
    29: {0:'without a handbag',1:'with a handbag'},
    30: {0:'without a box',1:'with a box'},
    31: {0:'without a plastic bag',1:'with a plastic bag'},
    32: {0:'without a paper bag',1:'with a paper bag'},
    33: {0:'without a hand trunk',1:'with a hand trunk'},
    34: {0:'without other attachments',1:'with other attachments'},
    35: {0:'age greater than 16',1:'age less than or equal to 16'},
    36: {0:'age less than 17 or greater than 30',1:'age between 17 and 30'},
    37: {0:'age less than 31 or greater than 45',1:'age between 31 and 45'},
    38: {0:'age less than 46 or greater than 60',1:'age between 46 and 60'},
    39: {0:'male',1:'female', 2:'gender unknown'},
    40: {0:'without excess body fat',1:'with excess body fat'},
    41: {0:'without normal body shape',1:'with normal body shape'},
    42: {0:'without thin body shape',1:'with thin body shape'},
    43: {0:'not a customer',1:'is a customer'}
}

uavhuman_attr_name = {
    0: {0:'female',1:'male'},
    1: {0:'without red backpack',1:'with red backpack'},
    2: {0:'without black backpack',1:'with black backpack'},
    3: {0:'without green backpack',1:'with green backpack'},
    4: {0:'without yellow backpack',1:'with yellow backpack'},
    5: {0:'without other backpack',1:'with other backpack'},
    6: {0:'without red hat',1:'with red hat'},
    7: {0:'without black hat',1:'with black hat'},
    8: {0:'without yellow hat',1:'with yellow hat'},
    9: {0:'without white hat',1:'with white hat'},
    10: {0:'without other hat',1:'with other hat'},
    11: {0:'without red upper-clothes',1:'with red upper-clothes'},
    12: {0:'without black upper-clothes',1:'with black upper-clothes'},
    13: {0:'without blue upper-clothes',1:'with blue upper-clothes'},
    14: {0:'without green upper-clothes',1:'with green upper-clothes'},
    15: {0:'without multicolor upper-clothes',1:'with multicolor upper-clothes'},
    16: {0:'without grey upper-clothes',1:'with grey upper-clothes'},
    17: {0:'without white upper-clothes',1:'with white upper-clothes'},
    18: {0:'without yellow upper-clothes',1:'with yellow upper-clothes'},
    19: {0:'without dark brown upper-clothes',1:'with dark brown upper-clothes'},
    20: {0:'without purple upper-clothes',1:'with purple upper-clothes'},
    21: {0:'without pink upper-clothes',1:'with pink upper-clothes'},
    22: {0:'without other upper-clothes',1:'with other upper-clothes'},
    23: {0:'without long upper-clothes style',1:'with long upper-clothes style'},
    24: {0:'without short upper-clothes style',1:'with short upper-clothes style'},
    25: {0:'without skirt upper-clothes style',1:'with skirt upper-clothes style'},
    26: {0:'without other upper-clothes style',1:'with other upper-clothes style'},
    27: {0:'without red lower clothes',1:'with red lower clothes'},
    28: {0:'without black lower clothes',1:'with black lower clothes'},
    29: {0:'without blue lower clothes',1:'with blue lower clothes'},
    30: {0:'without green lower clothes',1:'with green lower clothes'},
    31: {0:'without multicolor lower clothes',1:'with multicolor lower clothes'},
    32: {0:'without grey lower clothes',1:'with grey lower clothes'},
    33: {0:'without white lower-clothes',1:'with white lower-clothes'},
    34: {0:'without yellow lower-clothes',1:'with yellow lower-clothes'},
    35: {0:'without dark brown lower-clothes',1:'with dark brown lower-clothes'},
    36: {0:'without purple lower-clothes',1:'with purple lower-clothes'},
    37: {0:'without pink lower-clothes',1:'with pink lower-clothes'},
    38: {0:'without other lower-clothes',1:'with other lower-clothes'},
    39: {0:'without long lower-clothes style',1:'with long lower-clothes style'},
    40: {0:'without short lower-clothes style',1:'with short lower-clothes style'},
    41: {0:'without skirt lower-clothes style',1:'with skirt lower-clothes style'},
    42: {0:'without other lower-clothes style',1:'with other lower-clothes style'}
}

market_attr_name = {
    0: {0:'without a backpack',1:'with a backpack'},
    1: {0:'without a bag',1:'with a bag'},
    2: {0:'without a handbag',1:'with a handbag'},
    3: {0:'without black lower-clothes',1:'with black lower-clothes'},
    4: {0:'without blue lower-clothes',1:'with blue lower-clothes'},
    5: {0:'without brown lower-clothes',1:'with brown lower-clothes'},
    6: {0:'without gray lower-clothes',1:'with gray lower-clothes'},
    7: {0:'without green lower-clothes',1:'with green lower-clothes'},
    8: {0:'without pink lower-clothes',1:'with pink lower-clothes'},
    9: {0:'without purple lower-clothes',1:'with purple lower-clothes'},
    10: {0:'without white lower-clothes',1:'with white lower-clothes'},
    11: {0:'without yellow lower-clothes',1:'with yellow lower-clothes'},
    12: {0:'without black upper-clothes',1:'with black upper-clothes'},
    13: {0:'without blue upper-clothes',1:'with blue upper-clothes'},
    14: {0:'without green upper-clothes',1:'with green upper-clothes'},
    15: {0:'without gray upper-clothes',1:'with gray upper-clothes'},
    16: {0:'without purple upper-clothes',1:'with purple upper-clothes'},
    17: {0:'without red upper-clothes',1:'with red upper-clothes'},
    18: {0:'without white upper-clothes',1:'with white upper-clothes'},
    19: {0:'without yellow upper-clothes',1:'with yellow upper-clothes'},
    20: {0:'with dress',1:'with pants'},
    21: {0:'with long lower body clothing',1:'with short lower body clothing'},
    22: {0:'with long sleeve upper body clothing',1:'with short upper body clothing'},
    23: {0:'with short hair',1:'with long hair'},
    24: {0:'without a hat',1:'with a hat'},
    25: {0:'male',1:'female'},
    26: {0:'not a young person',1:'a young person'},
    27: {0:'not a teenager',1:'a teenager'},
    28: {0:'not an adult',1:'an adult'},
    29: {0:'not an old person',1:'an old person'}
}

# peta attr names still have some bugs
peta_attr_name = {
    0: {0:'without hat accessory',1:'with hat accessory'},
    1: {0:'without muffler accessory',1:'with muffler accessory'},
    2: {0:'with accessory',1:'with nothing accessory'},
    3: {0:'without sunglasses accessory',1:'with sunglasses accessory'},
    4: {0:'with short hair',1:'with long hair'},
    5: {0:'without casual upper body wear',1:'with casual upper body wear'},
    6: {0:'without formal upper body wear',1:'with formal upper body wear'},
    7: {0:'without jacket upper body wear',1:'with jacket upper body wear'},
    8: {0:'without logo upper body wear',1:'with logo upper body wear'},
    9: {0:'without plaid upper body wear',1:'with plaid upper body wear'},
    10: {0:'without short sleeve upper body wear',1:'with short sleeve upper body wear'},
    11: {0:'without thin stripes upper body wear',1:'with thin stripes upper body wear'},
    12: {0:'without t-shirt upper body wear',1:'with t-shirt upper body wear'},
    13: {0:'without other upper body wear',1:'with other upper body wear'},
    14: {0:'without vneck upper body wear',1:'with vneck upper body wear'},
    15: {0:'without casual lower body wear',1:'with casual lower body wear'},
    16: {0:'without formal lower body wear',1:'with formal lower body wear'},
    17: {0:'without jeans lower body wear',1:'with jeans lower body wear'},
    18: {0:'without shorts lower body wear',1:'with shorts lower body wear'},
    19: {0:'without shortskirt lower body wear',1:'with shortskirt lower body wear'},
    20: {0:'without trousers lower body wear',1:'with trousers lower body wear'},
    21: {0: 'without leather shoes', 1: 'with leather shoes'},
    22: {0: 'without sandals', 1: 'with sandals'},
    23: {0: 'without shoes', 1: 'with shoes'},
    24: {0: 'without sneaker', 1: 'with sneaker'},
    25: {0: 'without carrying backpack', 1: 'carrying backpack'},
    26: {0: 'with carrying other things', 1: 'carrying other things'},
    27: {0: 'without carrying messengerbag', 1: 'carrying messengerbag'},
    28: {0: 'carrying something', 1: 'carrying nothing'},
    29: {0: 'without carrying plasticbags', 1: 'carrying plasticbags'},
    30: {0:'age greater than or equal to 30',1:'age less than 30'},
    31: {0:'age less than 31 or greater than 45',1:'age between 31 and 45'},
    32: {0:'age less than 46 or greater than 60',1:'age between 46 and 60'},
    33: {0:'age less than or equal to 60',1:'age larger than 60'},
    34: {0:'female',1:'male'}
}

duke_attr_name = {
    0: {0:'without a backpack',1:'with a backpack'},
    1: {0:'without a bag',1:'with a bag'},
    2: {0:'without boots',1:'with boots'},
    3: {0:'without black lower-clothes',1:'with black lower-clothes'},
    4: {0:'without blue lower-clothes',1:'with blue lower-clothes'},
    5: {0:'without brown lower-clothes',1:'with brown lower-clothes'},
    6: {0:'without gray lower-clothes',1:'with gray lower-clothes'},
    7: {0:'without green lower-clothes',1:'with green lower-clothes'},
    8: {0:'without red lower-clothes',1:'with red lower-clothes'},
    9: {0:'without white lower-clothes',1:'with white lower-clothes'},
    10: {0:'male',1:'female'},
    11: {0:'without a handbag',1:'with a handbag'},
    12: {0:'without a hat',1:'with a hat'},
    13: {0:'with dark shoes',1:'with light shoes'},
    14: {0:'short top clothing',1:'long top clothing'},
    15: {0:'without black upper-clothes',1:'with black upper-clothes'},
    16: {0:'without blue upper-clothes',1:'with blue upper-clothes'},
    17: {0:'without brown upper-clothes',1:'with brown upper-clothes'},
    18: {0:'without gray upper-clothes',1:'with gray upper-clothes'},
    19: {0:'without green upper-clothes',1:'with green upper-clothes'},
    20: {0:'without purple upper-clothes',1:'with purple upper-clothes'},
    21: {0:'without red upper-clothes',1:'with red upper-clothes'},
    22: {0:'without white upper-clothes',1:'with white upper-clothes'},
}

ClothingAttribute_attr_name = ['pattern_spot', 'cyan', 'brown', 'v_shape_neckline', 'round_neckline', 'other_neckline', 'no_sleevelength', 'short_sleevelength', 'long_sleevelength', 'pattern_graphics', 'gender', 'black', 'many_colors', 'white', 'pattern_floral', 'collar', 'blue', 'necktie', 'pattern_stripe', 'pattern_solid', 'gray', 'shirt_category', 'sweater_category', 't_shirt_category', 'outerwear_category', 'suit_category', 'tank_top_category', 'dress_category', 'placket', 'pattern_plaid', 'purple', 'scarf', 'green', 'yellow', 'skin_exposure', 'red']

lup_0_200w_attr_name = {
    0: {0: 'male', 1: 'female', -1: 'gender unknown'},
    1: {0: 'age greater than 6', 1: "age less than or equal to 6", -1: 'age unknown'},
    2: {0: 'age less than 7 or greater than 18', 1: "age between 7 and 18", -1: 'age unknown'},
    3: {0: 'age less than 19 or greater than 65', 1: "age between 19 and 65", -1: 'age unknown'},
    4: {0: 'age less than 66', 1: "age greater than or equal to 66", -1: 'age unknown'},
    5: {0: 'with short sleeve coat', 1: 'with long sleeves', -1: 'coat length unknown'},
    6: {0: 'with shorts trousers', 1: 'with long trousers'},
    7: {0: 'without a skirt', 1: 'with a skirt'},
    8: {0: 'without a pure pattern coat', 1: 'with a pure upper-clothes'},
    9: {0: 'without a stripe pattern coat', 1: 'with a stripe upper-clothes'},
    10: {0: 'without a design pattern coat', 1: 'with a design upper-clothes'},
    11: {0: 'without a joint pattern coat', 1: 'with a joint upper-clothes'},
    12: {0: 'without a lattic pattern coat', 1: 'with a lattic upper-clothes'},
    13: {0: 'without a black color trousers', 1: 'with black lower-clothes'},
    14: {0: 'without a white color trousers', 1: 'with white lower-clothes'},
    15: {0: 'without a gray color trousers', 1: 'with a gray color trousers'},
    16: {0: 'without a red color trousers', 1: 'with a red color trousers'},
    17: {0: 'without a yellow color trousers', 1: 'with a yellow color trousers'},
    18: {0: 'without a blue color trousers', 1: 'with a blue color trousers'},
    19: {0: 'without a green color trousers', 1: 'with a green color trousers'},
    20: {0: 'without a purple color trousers', 1: 'with a purple color trousers'},
    21: {0: 'without a pure pattern trousers', 1: 'with a pure lower-clothes'},
    22: {0: 'without a stripe pattern trousers', 1: 'with a stripe lower-clothes'},
    23: {0: 'without a design pattern trousers', 1: 'with a design lower-clothes'},
    24: {0: 'without a joint pattern trousers', 1: 'with a joint lower-clothes'},
    25: {0: 'without a lattic pattern trousers', 1: 'with a lattic lower-clothes'},
    26: {0: 'without a hat', 1: 'with a hat', -1: 'hat unknown'},
    27: {0: 'without a jacket', 1: 'with a jacket'},
    28: {0: 'without a sweater', 1: 'with a sweater'},
    29: {0: 'without a long coat', 1: 'with a long coat'},
    30: {0: 'without a shirt', 1: 'with a shirt'},
    31: {0: 'without a dress', 1: 'with a dress'},
    32: {0: 'without a business suit', 1: 'with a business suit'},
    33: {0: 'without a black color coat', 1: 'with a black color coat', -1: 'unknown coat color'},
    34: {0: 'without a white color coat', 1: 'with a white color coat', -1: 'unknown coat color'},
    35: {0: 'without a gray color coat', 1: 'with a gray color coat', -1: 'unknown coat color'},
    36: {0: 'without a red color coat', 1: 'with a red color coat', -1: 'unknown coat color'},
    37: {0: 'without a yellow color coat', 1: 'with a yellow color coat', -1: 'unknown coat color'},
    38: {0: 'without a blue color coat', 1: 'with a blue color coat', -1: 'unknown coat color'},
    39: {0: 'without a green color coat', 1: 'with a green color coat', -1: 'unknown coat color'},
    40: {0: 'without a purple color coat', 1: 'with a purple color coat', -1: 'unknown coat color'},
    41: {0: 'with short hair', 1: 'with long hair', -1: 'unknown hair style'},
    42: {0: 'without leather shoes', 1: 'with leather shoes'},
    43: {0: 'without boots', 1: 'with boots'},
    44: {0: 'without walking shoes', 1: 'with walking shoes'},
    45: {0: 'without sandal', 1: 'with sandal'},
    46: {0: 'without a bag', 1: 'with a bag', -1: 'unknown bag style'},
    47: {0: 'without glasses', 1: 'with glasses'},
    48: {0: 'not stand', 1: 'stand', -1: 'unknown pose'},
    49: {0: 'not sit', 1: 'sit', -1: 'unknown pose'},
    50: {0: 'not lie', 1: 'lie', -1: 'unknown pose'},
    51: {0: 'not stoop', 1: 'stoop', -1: 'unknown pose'}
}

# The lup_0_600w and lup_600_1200w splits use exactly the same attribute
# definitions as lup_0_200w, so they share the same attribute-name map.
lup_0_600w_attr_name = lup_0_200w_attr_name
lup_600_1200w_attr_name = lup_0_200w_attr_name
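

# --- Illustrative sketch, not part of the original file ---------------------
# The *_attr_name maps above pair each attribute index with a negative/positive
# phrase (plus an "unknown" phrase for some attributes).  Together with the
# attr_begin / attr_end indices returned by merge_pedattr_datasets(), they can
# be turned into a textual description of a prediction.  The helper name and
# the 0.5 threshold are assumptions for illustration only.
def _describe_person_sketch(scores, attr_name_map, attr_begin, attr_end, threshold=0.5):
    phrases = []
    for col in range(attr_begin, attr_end):
        value = int(scores[col] > threshold)     # 0 -> negative phrase, 1 -> positive phrase
        phrases.append(attr_name_map[col - attr_begin][value])
    return phrases
# e.g. _describe_person_sketch(scores, rap2_attr_name, attr_begin=0, attr_end=54)
# ---------------------------------------------------------------------------
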
core/data/datasets/images/peddet_dataset_v2.py
ADDED
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.utils.data as data
|
2 |
+
from PIL import ImageFile
|
3 |
+
ImageFile.LOAD_TRUNCATED_IMAGES=True
|
4 |
+
|
5 |
+
import os
|
6 |
+
import os.path
|
7 |
+
|
8 |
+
import random
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
import copy
|
12 |
+
|
13 |
+
import time
|
14 |
+
from core.data.transforms.peddet_transforms import PedestrainDetectionAugmentation
|
15 |
+
|
16 |
+
from core.data.datasets.images.seg_dataset_dev import Instances
|
17 |
+
from typing import *
|
18 |
+
import torch.distributed as dist
|
19 |
+
from PIL import Image
|
20 |
+
import json
|
21 |
+
from pycocotools.coco import COCO
|
22 |
+
|
23 |
+
from collections import defaultdict
|
24 |
+
|
25 |
+
__all__ = ['PedestrainDetectionDataset_v2']
|
26 |
+
|
27 |
+
class PetrelCOCO(COCO):
|
28 |
+
def __init__(self, annotation_file=None, annotation=None):
|
29 |
+
"""
|
30 |
+
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
|
31 |
+
:param annotation_file (str): location of annotation file
|
32 |
+
:param annotation (?): partially processed annotation file
|
33 |
+
:return:
|
34 |
+
"""
|
35 |
+
# load dataset
|
36 |
+
self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict()
|
37 |
+
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
|
38 |
+
assert annotation_file is None or annotation is None
|
39 |
+
if annotation_file is not None:
|
40 |
+
print('loading annotations into memory...')
|
41 |
+
tic = time.time()
|
42 |
+
with open(annotation_file, 'r') as f:
|
43 |
+
dataset = json.load(f)
|
44 |
+
assert type(dataset) == dict, 'annotation file format {} not supported'.format(type(dataset))
|
45 |
+
print('Done (t={:0.2f}s)'.format(time.time() - tic))
|
46 |
+
self.dataset = dataset
|
47 |
+
self.createIndex()
|
48 |
+
|
49 |
+
if annotation is not None:
|
50 |
+
print('adding annotations into memory...')
|
51 |
+
tic = time.time()
|
52 |
+
dataset = annotation
|
53 |
+
self.dataset = dataset
|
54 |
+
self.createIndex()
|
55 |
+
|
56 |
+
def convert_coco_poly_to_mask(segmentations, height, width):
|
57 |
+
masks = []
|
58 |
+
for polygons in segmentations:
|
59 |
+
rles = coco_mask.frPyObjects(polygons, height, width)
|
60 |
+
mask = coco_mask.decode(rles)
|
61 |
+
if len(mask.shape) < 3:
|
62 |
+
mask = mask[..., None]
|
63 |
+
mask = torch.as_tensor(mask, dtype=torch.uint8)
|
64 |
+
mask = mask.any(dim=2)
|
65 |
+
masks.append(mask)
|
66 |
+
if masks:
|
67 |
+
masks = torch.stack(masks, dim=0)
|
68 |
+
else:
|
69 |
+
masks = torch.zeros((0, height, width), dtype=torch.uint8)
|
70 |
+
return masks
|
71 |
+
|
72 |
+
class ConvertCocoPolysToMask(object):
|
73 |
+
def __init__(self, return_masks=False):
|
74 |
+
self.return_masks = return_masks
|
75 |
+
|
76 |
+
def __call__(self, image, target):
|
77 |
+
w, h = image.size
|
78 |
+
|
79 |
+
image_id = target["image_id"]
|
80 |
+
image_id = torch.tensor([image_id])
|
81 |
+
|
82 |
+
anno = target["annotations"]
|
83 |
+
|
84 |
+
boxes = [obj["bbox"] for obj in anno]
|
85 |
+
# guard against no boxes via resizing
|
86 |
+
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
|
87 |
+
boxes[:, 2:] += boxes[:, :2]
|
88 |
+
boxes[:, 0::2].clamp_(min=0, max=w)
|
89 |
+
boxes[:, 1::2].clamp_(min=0, max=h)
|
90 |
+
|
91 |
+
classes = [obj["category_id"] for obj in anno]
|
92 |
+
classes = torch.tensor(classes, dtype=torch.int64)
|
93 |
+
|
94 |
+
if self.return_masks:
|
95 |
+
segmentations = [obj["segmentation"] for obj in anno]
|
96 |
+
masks = convert_coco_poly_to_mask(segmentations, h, w)
|
97 |
+
|
98 |
+
keypoints = None
|
99 |
+
if anno and "keypoints" in anno[0]:
|
100 |
+
keypoints = [obj["keypoints"] for obj in anno]
|
101 |
+
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
|
102 |
+
num_keypoints = keypoints.shape[0]
|
103 |
+
if num_keypoints:
|
104 |
+
keypoints = keypoints.view(num_keypoints, -1, 3)
|
105 |
+
|
106 |
+
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
107 |
+
|
108 |
+
# for conversion to coco api
|
109 |
+
area = torch.tensor([obj["area"] for obj in anno])
|
110 |
+
iscrowd = torch.BoolTensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
|
111 |
+
iscrowd |= classes != 0
|
112 |
+
|
113 |
+
target = {}
|
114 |
+
target["boxes"] = boxes[keep]
|
115 |
+
target["labels"] = classes[keep]
|
116 |
+
if self.return_masks:
|
117 |
+
target["masks"] = masks[keep]
|
118 |
+
target["image_id"] = image_id
|
119 |
+
if keypoints is not None:
|
120 |
+
target["keypoints"] = keypoints[keep]
|
121 |
+
|
122 |
+
target["area"] = area[keep]
|
123 |
+
target["iscrowd"] = iscrowd[keep]
|
124 |
+
|
125 |
+
target["orig_size"] = torch.as_tensor([int(h), int(w)])
|
126 |
+
target["size"] = torch.as_tensor([int(h), int(w)])
|
127 |
+
|
128 |
+
return image, target
|
129 |
+
|
130 |
+
|
131 |
+
class CocoDetection(data.Dataset):
|
132 |
+
"""`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
root (string): Root directory where images are downloaded to.
|
136 |
+
annFile (string): Path to json annotation file.
|
137 |
+
transform (callable, optional): A function/transform that takes in an PIL image
|
138 |
+
and returns a transformed version. E.g, ``transforms.ToTensor``
|
139 |
+
target_transform (callable, optional): A function/transform that takes in the
|
140 |
+
target and transforms it.
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(self, ann, phase, transform=None, target_transform=None):
|
144 |
+
self.coco = PetrelCOCO(annotation=ann)
|
145 |
+
|
146 |
+
self.ids = list(self.coco.imgs.keys())
|
147 |
+
assert phase in ['train', 'val']
|
148 |
+
self.transform = transform
|
149 |
+
self.phase = phase
|
150 |
+
self.target_transform = target_transform
|
151 |
+
|
152 |
+
self.rank = dist.get_rank()
|
153 |
+
self.world_size = dist.get_world_size()
|
154 |
+
|
155 |
+
self.initialized = True
|
156 |
+
|
157 |
+
def _init_memcached(self):
|
158 |
+
if not self.initialized:
|
159 |
+
## only use mc default
|
160 |
+
print("==> will load files from local machine")
|
161 |
+
server_list_config_file = "/mnt/lustre/share/memcached_client/server_list.conf"
|
162 |
+
client_config_file = "/mnt/lustre/share/memcached_client/client.conf"
|
163 |
+
self.memcached_mclient = mc.MemcachedClient.GetInstance(server_list_config_file, client_config_file)
|
164 |
+
## mc-support-ceph
|
165 |
+
print('mc-support-ceph')
|
166 |
+
self.ceph_mclient = s3client
|
167 |
+
|
168 |
+
self.initialized = True
|
169 |
+
|
170 |
+
def _read_one(self, index=None):
|
171 |
+
"""
|
172 |
+
Args:
|
173 |
+
index (int): Index
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
|
177 |
+
"""
|
178 |
+
if index is None:
|
179 |
+
index = np.random.randint(len(self.ids))
|
180 |
+
|
181 |
+
coco = self.coco
|
182 |
+
img_id = self.ids[index]
|
183 |
+
|
184 |
+
ann_ids = coco.getAnnIds(imgIds=img_id)
|
185 |
+
target = copy.deepcopy(coco.loadAnns(ann_ids))
|
186 |
+
|
187 |
+
for one_target in target:
|
188 |
+
if 'segmentation' in one_target: del one_target['segmentation']
|
189 |
+
if 'keypoints' in one_target: del one_target['keypoints']
|
190 |
+
|
191 |
+
path = coco.loadImgs(img_id)[0]['file_name']
|
192 |
+
img_root = coco.loadImgs(img_id)[0]['img_root']
|
193 |
+
imgname = os.path.splitext(path)[0]
|
194 |
+
|
195 |
+
if self.phase == 'val':
|
196 |
+
if 'CrowdHuman' in img_root:
|
197 |
+
path = path.replace('.png', '.jpg')
|
198 |
+
## for code in lab, we use jpg
|
199 |
+
if 'CrowdHuman' in img_root:
|
200 |
+
path = path.replace('.png', '.jpg')
|
201 |
+
filename = os.path.join(img_root, path)
|
202 |
+
try:
|
203 |
+
img = Image.open(filename).convert('RGB')
|
204 |
+
if img is None:
|
205 |
+
raise Exception("None Image")
|
206 |
+
except:
|
207 |
+
outputName = "failed_to_read_in_train.txt"
|
208 |
+
with open(outputName,"a") as g:
|
209 |
+
g.write("%s\n"%(filename))
|
210 |
+
print('Read image[{}] failed ({})'.format(index, filename))
|
211 |
+
## if fail then recursive call _read_one without idx
|
212 |
+
return self._read_one()
|
213 |
+
else:
|
214 |
+
output = dict()
|
215 |
+
##set random_seed with img idx
|
216 |
+
random.seed(index+self.rank)
|
217 |
+
np.random.seed(index+self.rank)
|
218 |
+
|
219 |
+
if self.transform is not None:
|
220 |
+
img = self.transform(img)
|
221 |
+
|
222 |
+
if self.target_transform is not None:
|
223 |
+
target = self.target_transform(target)
|
224 |
+
|
225 |
+
return img, target, imgname
|
226 |
+
|
227 |
+
def __getitem__(self, index):
|
228 |
+
"""
|
229 |
+
Args:
|
230 |
+
index (int): Index
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
|
234 |
+
"""
|
235 |
+
self._init_memcached()
|
236 |
+
img, target, imgname = self._read_one(index)
|
237 |
+
|
238 |
+
return img, target, imgname
|
239 |
+
|
240 |
+
def __len__(self):
|
241 |
+
return len(self.ids)
|
242 |
+
|
243 |
+
def __repr__(self):
|
244 |
+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
|
245 |
+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
|
246 |
+
fmt_str += ' Root Location: {}\n'.format(self.root)
|
247 |
+
tmp = ' Transforms (if any): '
|
248 |
+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
|
249 |
+
tmp = ' Target Transforms (if any): '
|
250 |
+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
|
251 |
+
return fmt_str
|
252 |
+
|
253 |
+
|
254 |
+
def coco_merge(
|
255 |
+
img_root_list: List[str], input_list: List[str],
|
256 |
+
indent: Optional[int] = None,
|
257 |
+
) -> str:
|
258 |
+
"""Merge COCO annotation files.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
input_extend: Path to input file to be extended.
|
262 |
+
input_add: Path to input file to be added.
|
263 |
+
output_file : Path to output file with merged annotations.
|
264 |
+
indent: Argument passed to `json.dump`. See https://docs.python.org/3/library/json.html#json.dump.
|
265 |
+
"""
|
266 |
+
data_list = []
|
267 |
+
|
268 |
+
for input in input_list:
|
269 |
+
with open(input, 'r') as f:
|
270 |
+
data_extend = json.load(f)
|
271 |
+
|
272 |
+
data_list.append(data_extend)
|
273 |
+
|
274 |
+
output= {'categories': data_list[0]['categories']}
|
275 |
+
|
276 |
+
output["images"], output["annotations"] = [], []
|
277 |
+
|
278 |
+
for i, (data, img_root) in enumerate(zip(data_list, img_root_list)):
|
279 |
+
print(
|
280 |
+
"Input {}: {} images, {} annotations".format(
|
281 |
+
i + 1, len(data["images"]), len(data["annotations"])
|
282 |
+
)
|
283 |
+
)
|
284 |
+
|
285 |
+
cat_id_map = {}
|
286 |
+
for new_cat in data["categories"]:
|
287 |
+
new_id = None
|
288 |
+
for output_cat in output["categories"]:
|
289 |
+
if new_cat["name"] == output_cat["name"]:
|
290 |
+
new_id = output_cat["id"]
|
291 |
+
break
|
292 |
+
|
293 |
+
if new_id is not None:
|
294 |
+
cat_id_map[new_cat["id"]] = new_id
|
295 |
+
else:
|
296 |
+
new_cat_id = max(c["id"] for c in output["categories"]) + 1
|
297 |
+
cat_id_map[new_cat["id"]] = new_cat_id
|
298 |
+
new_cat["id"] = new_cat_id
|
299 |
+
output["categories"].append(new_cat)
|
300 |
+
|
301 |
+
img_id_map = {}
|
302 |
+
for image in data["images"]:
|
303 |
+
n_imgs = len(output["images"])
|
304 |
+
img_id_map[image["id"]] = n_imgs
|
305 |
+
image["id"] = n_imgs
|
306 |
+
image["img_root"] = img_root
|
307 |
+
|
308 |
+
output["images"].append(image)
|
309 |
+
|
310 |
+
for annotation in data["annotations"]:
|
311 |
+
n_anns = len(output["annotations"])
|
312 |
+
annotation["id"] = n_anns
|
313 |
+
annotation["image_id"] = img_id_map[annotation["image_id"]]
|
314 |
+
annotation["category_id"] = cat_id_map[annotation["category_id"]]
|
315 |
+
|
316 |
+
output["annotations"].append(annotation)
|
317 |
+
|
318 |
+
print(
|
319 |
+
"Result: {} images, {} annotations".format(
|
320 |
+
len(output["images"]), len(output["annotations"])
|
321 |
+
)
|
322 |
+
)
|
323 |
+
return output
|
324 |
+
|
325 |
+
|
326 |
+
class PedestrainDetectionDataset_v2(CocoDetection):
|
327 |
+
def __init__(self, ginfo, augmentation, task_spec, train=True, vit=False,
|
328 |
+
num_append_fake_boxes=0,
|
329 |
+
# append to 900 for a fixed length gt input in the sparse labeling (label) branch
|
330 |
+
return_box_xyxy=False,
|
331 |
+
append_z=True,
|
332 |
+
test_trainset=False,
|
333 |
+
**kwargs):
|
334 |
+
img_folder = task_spec['img_folder'] if isinstance(task_spec['img_folder'], list) else [task_spec['img_folder']]
|
335 |
+
ann_file = task_spec['ann_file'] if isinstance(task_spec['ann_file'], list) else [task_spec['ann_file']]
|
336 |
+
self.root = img_folder
|
337 |
+
|
338 |
+
ann = coco_merge(img_folder, ann_file)
|
339 |
+
|
340 |
+
return_masks = task_spec['return_masks']
|
341 |
+
phase = 'train' if train else 'val'
|
342 |
+
|
343 |
+
super(PedestrainDetectionDataset_v2, self).__init__(ann=ann, phase=phase)
|
344 |
+
|
345 |
+
self.return_box_xyxy = return_box_xyxy
|
346 |
+
transforms = PedestrainDetectionAugmentation(phase=phase if not test_trainset else 'val', vit=vit, return_box_xyxy=self.return_box_xyxy,
|
347 |
+
max_size=augmentation.get('max_size',1333),)
|
348 |
+
|
349 |
+
name2wh = {}
|
350 |
+
for img_id in self.ids:
|
351 |
+
img_name = self.coco.loadImgs(img_id)[0]['file_name'].split('.')[0]
|
352 |
+
height = self.coco.loadImgs(img_id)[0]['height']
|
353 |
+
width = self.coco.loadImgs(img_id)[0]['width']
|
354 |
+
name2wh[img_name]={'width':width, 'height': height}
|
355 |
+
|
356 |
+
self.flag = np.zeros(len(self.ids), dtype=np.uint8)
|
357 |
+
for i, img_id in enumerate(self.ids):
|
358 |
+
img_info = self.coco.loadImgs(img_id)[0]['file_name'].split('.')[0]
|
359 |
+
if name2wh[img_info]['width'] / name2wh[img_info]['height'] > 1:
|
360 |
+
self.flag[i] = 1
|
361 |
+
|
362 |
+
self._transforms = transforms
|
363 |
+
self.phase = phase
|
364 |
+
self.prepare = ConvertCocoPolysToMask(return_masks)
|
365 |
+
self.task_name = ginfo.task_name
|
366 |
+
|
367 |
+
self.num_append_fake_boxes = num_append_fake_boxes
|
368 |
+
self.append_z = append_z
|
369 |
+
|
370 |
+
def _filter_ignores(self, target):
|
371 |
+
target = list(filter(lambda rb: rb['category_id'] > -1, target))
|
372 |
+
|
373 |
+
return target
|
374 |
+
|
375 |
+
def _minus_target_label(self, target, value):
|
376 |
+
|
377 |
+
results = []
|
378 |
+
for t in target:
|
379 |
+
t['category_id'] -= value
|
380 |
+
results.append(t)
|
381 |
+
return results
|
382 |
+
|
383 |
+
def __getitem__(self, idx):
|
384 |
+
dataset_dict = {}
|
385 |
+
img, target, imgname = super(PedestrainDetectionDataset_v2, self).__getitem__(idx)
|
386 |
+
target = self._minus_target_label(target, 1)
|
387 |
+
total = len(target)
|
388 |
+
image_id = self.ids[idx]
|
389 |
+
|
390 |
+
target = {'image_id': image_id, 'annotations': target}
|
391 |
+
img, target = self.prepare(img, target)
|
392 |
+
image_shape = (img.size[-1], img.size[-2]) # h, w
|
393 |
+
self._record_image_size(dataset_dict, img)
|
394 |
+
|
395 |
+
if self._transforms is not None:
|
396 |
+
img, target = self._transforms(img, target)
|
397 |
+
|
398 |
+
if self.num_append_fake_boxes > 0:
|
399 |
+
# not take iscrowded boxes into consideration
|
400 |
+
len_target = target['labels'].shape[0]
|
401 |
+
len_append = self.num_append_fake_boxes - len_target
|
402 |
+
target['boxes'] = torch.cat([target['boxes'], torch.zeros([len_append, 4])], dim=0)
|
403 |
+
# the appended label is set to 1(background), as ped det only has one class 0 for pedestrian
|
404 |
+
append_label = 1
|
405 |
+
target['labels'] = torch.cat([target['labels'], torch.ones([len_append]).long()*append_label], dim=0)
|
406 |
+
target['iscrowd'] = torch.cat([target['iscrowd'], torch.ones([len_append]).bool()], dim=0)
|
407 |
+
target['area'] = torch.cat([target['area'], torch.zeros([len_append])], dim=0)
|
408 |
+
|
409 |
+
dataset_dict['orig_size'] = target['orig_size']
|
410 |
+
dataset_dict['size'] = target['size']
|
411 |
+
del target['image_id']
|
412 |
+
del target['orig_size']
|
413 |
+
del target['size']
|
414 |
+
|
415 |
+
instances = Instances(image_shape, **target)
|
416 |
+
|
417 |
+
# sparse_labeling should have a shape of [xyz, T(temperal)=2, V=num_append_fake_boxes, M(num_peopoe)=1]
|
418 |
+
# T=2, as we consider x1y1, x2y2 as two points. Info in two points will be integrated in conv to
|
419 |
+
# have a token representing a box.
|
420 |
+
# import pdb;
|
421 |
+
# pdb.set_trace()
|
422 |
+
sparse_labeling = target['boxes'].reshape(target['boxes'].shape[0], 2, 2).contiguous()
|
423 |
+
if self.append_z:
|
424 |
+
append_z = torch.zeros([target['boxes'].shape[0], 2, 1])
|
425 |
+
sparse_labeling = torch.cat([sparse_labeling, append_z], dim=2) # num_append_fake_boxes, T, xyz
|
426 |
+
sparse_labeling = sparse_labeling.unsqueeze(-1).permute(2, 1, 0, 3).contiguous()
|
427 |
+
|
428 |
+
dataset_dict['sparse_labeling'] = sparse_labeling
|
429 |
+
dataset_dict["image"] = img
|
430 |
+
dataset_dict["image_id"] = image_id
|
431 |
+
dataset_dict["label"] = -1
|
432 |
+
dataset_dict["instances"] = instances
|
433 |
+
dataset_dict["filename"] = imgname
|
434 |
+
|
435 |
+
return dataset_dict
|
436 |
+
|
437 |
+
@staticmethod
|
438 |
+
def _record_image_size(dataset_dict, image):
|
439 |
+
"""
|
440 |
+
Raise an error if the image does not match the size specified in the dict.
|
441 |
+
"""
|
442 |
+
# To ensure bbox always remap to original image size # when in PIL, reversed.
|
443 |
+
if "width" not in dataset_dict:
|
444 |
+
dataset_dict["width"] = image.size[1]
|
445 |
+
if "height" not in dataset_dict:
|
446 |
+
dataset_dict["height"] = image.size[0]
|
447 |
+
|
448 |
+
|
449 |
+
class PedestrainDetectionDataset_v2demo(CocoDetection):
|
450 |
+
def __init__(self, ginfo, augmentation, task_spec, train=True, vit=False,
|
451 |
+
num_append_fake_boxes=0,
|
452 |
+
# append to 900 for a fixed length gt input in the sparse labeling (label) branch
|
453 |
+
return_box_xyxy=False,
|
454 |
+
append_z=True,
|
455 |
+
test_trainset=False,
|
456 |
+
demo_dir='/mnt/cache/tangshixiang/wyz_proj/demo_video_unihcpv2/folder0',
|
457 |
+
**kwargs):
|
458 |
+
img_folder = task_spec['img_folder'] if isinstance(task_spec['img_folder'], list) else [task_spec['img_folder']]
|
459 |
+
ann_file = task_spec['ann_file'] if isinstance(task_spec['ann_file'], list) else [task_spec['ann_file']]
|
460 |
+
self.root = img_folder
|
461 |
+
|
462 |
+
ann = coco_merge(img_folder, ann_file)
|
463 |
+
|
464 |
+
return_masks = task_spec['return_masks']
|
465 |
+
phase = 'train' if train else 'val'
|
466 |
+
|
467 |
+
super(PedestrainDetectionDataset_v2demo, self).__init__(ann=ann, phase=phase)
|
468 |
+
|
469 |
+
self.return_box_xyxy = return_box_xyxy
|
470 |
+
transforms = PedestrainDetectionAugmentation(phase=phase if not test_trainset else 'val', vit=vit, return_box_xyxy=self.return_box_xyxy,
|
471 |
+
max_size=augmentation.get('max_size',1333),)
|
472 |
+
|
473 |
+
name2wh = {}
|
474 |
+
for img_id in self.ids:
|
475 |
+
img_name = self.coco.loadImgs(img_id)[0]['file_name'].split('.')[0]
|
476 |
+
height = self.coco.loadImgs(img_id)[0]['height']
|
477 |
+
width = self.coco.loadImgs(img_id)[0]['width']
|
478 |
+
name2wh[img_name]={'width':width, 'height': height}
|
479 |
+
|
480 |
+
self.flag = np.zeros(len(self.ids), dtype=np.uint8)
|
481 |
+
for i, img_id in enumerate(self.ids):
|
482 |
+
img_info = self.coco.loadImgs(img_id)[0]['file_name'].split('.')[0]
|
483 |
+
if name2wh[img_info]['width'] / name2wh[img_info]['height'] > 1:
|
484 |
+
self.flag[i] = 1
|
485 |
+
|
486 |
+
self._transforms = transforms
|
487 |
+
self.phase = phase
|
488 |
+
self.prepare = ConvertCocoPolysToMask(return_masks)
|
489 |
+
self.task_name = ginfo.task_name
|
490 |
+
|
491 |
+
self.num_append_fake_boxes = num_append_fake_boxes
|
492 |
+
self.append_z = append_z
|
493 |
+
self.demo_dir = demo_dir
|
494 |
+
self.listdir = os.listdir(self.demo_dir)
|
495 |
+
|
496 |
+
def _filter_ignores(self, target):
|
497 |
+
target = list(filter(lambda rb: rb['category_id'] > -1, target))
|
498 |
+
|
499 |
+
return target
|
500 |
+
|
501 |
+
def _minus_target_label(self, target, value):
|
502 |
+
|
503 |
+
results = []
|
504 |
+
for t in target:
|
505 |
+
t['category_id'] -= value
|
506 |
+
results.append(t)
|
507 |
+
return results
|
508 |
+
|
509 |
+
def __len__(self):
|
510 |
+
return len(os.listdir(self.demo_dir))
|
511 |
+
|
512 |
+
def __getitem__(self, idx):
|
513 |
+
dataset_dict = {}
|
514 |
+
img, target, imgname = super(PedestrainDetectionDataset_v2demo, self).__getitem__(0)
|
515 |
+
demo_dir = self.demo_dir
|
516 |
+
filename = os.path.join(demo_dir, self.listdir[idx])
|
517 |
+
img = Image.open(filename).convert('RGB')
|
518 |
+
target = self._minus_target_label(target, 1)
|
519 |
+
total = len(target)
|
520 |
+
image_id = self.ids[0]
|
521 |
+
|
522 |
+
target = {'image_id': image_id, 'annotations': target}
|
523 |
+
img, target = self.prepare(img, target)
|
524 |
+
image_shape = (img.size[-1], img.size[-2]) # h, w
|
525 |
+
self._record_image_size(dataset_dict, img)
|
526 |
+
|
527 |
+
if self._transforms is not None:
|
528 |
+
img, target = self._transforms(img, target)
|
529 |
+
|
530 |
+
if self.num_append_fake_boxes > 0:
|
531 |
+
# not take iscrowded boxes into consideration
|
532 |
+
len_target = target['labels'].shape[0]
|
533 |
+
len_append = self.num_append_fake_boxes - len_target
|
534 |
+
target['boxes'] = torch.cat([target['boxes'], torch.zeros([len_append, 4])], dim=0)
|
535 |
+
# the appended label is set to 1(background), as ped det only has one class 0 for pedestrian
|
536 |
+
append_label = 1
|
537 |
+
target['labels'] = torch.cat([target['labels'], torch.ones([len_append]).long()*append_label], dim=0)
|
538 |
+
target['iscrowd'] = torch.cat([target['iscrowd'], torch.ones([len_append]).bool()], dim=0)
|
539 |
+
target['area'] = torch.cat([target['area'], torch.zeros([len_append])], dim=0)
|
540 |
+
|
541 |
+
dataset_dict['orig_size'] = target['orig_size']
|
542 |
+
dataset_dict['size'] = target['size']
|
543 |
+
del target['image_id']
|
544 |
+
del target['orig_size']
|
545 |
+
del target['size']
|
546 |
+
|
547 |
+
instances = Instances(image_shape, **target)
|
548 |
+
|
549 |
+
# sparse_labeling should have a shape of [xyz, T(temperal)=2, V=num_append_fake_boxes, M(num_peopoe)=1]
|
550 |
+
# T=2, as we consider x1y1, x2y2 as two points. Info in two points will be integrated in conv to
|
551 |
+
# have a token representing a box.
|
552 |
+
# import pdb;
|
553 |
+
# pdb.set_trace()
|
554 |
+
sparse_labeling = target['boxes'].reshape(target['boxes'].shape[0], 2, 2).contiguous()
|
555 |
+
if self.append_z:
|
556 |
+
append_z = torch.zeros([target['boxes'].shape[0], 2, 1])
|
557 |
+
sparse_labeling = torch.cat([sparse_labeling, append_z], dim=2) # num_append_fake_boxes, T, xyz
|
558 |
+
sparse_labeling = sparse_labeling.unsqueeze(-1).permute(2, 1, 0, 3).contiguous()
|
559 |
+
|
560 |
+
dataset_dict['sparse_labeling'] = sparse_labeling
|
561 |
+
dataset_dict["image"] = img
|
562 |
+
dataset_dict["image_id"] = image_id
|
563 |
+
dataset_dict["label"] = -1
|
564 |
+
dataset_dict["instances"] = instances
|
565 |
+
dataset_dict["filename"] = filename
|
566 |
+
|
567 |
+
return dataset_dict
|
568 |
+
|
569 |
+
@staticmethod
|
570 |
+
def _record_image_size(dataset_dict, image):
|
571 |
+
"""
|
572 |
+
Raise an error if the image does not match the size specified in the dict.
|
573 |
+
"""
|
574 |
+
# To ensure bbox always remap to original image size # when in PIL, reversed.
|
575 |
+
if "width" not in dataset_dict:
|
576 |
+
dataset_dict["width"] = image.size[1]
|
577 |
+
if "height" not in dataset_dict:
|
578 |
+
dataset_dict["height"] = image.size[0]
|
core/data/datasets/images/pos_dataset_dev.py
ADDED
@@ -0,0 +1,713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
from torch.utils.data import Dataset
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from abc import ABCMeta, abstractmethod
|
7 |
+
import numpy as np
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
import os
|
10 |
+
import cv2
|
11 |
+
import time
|
12 |
+
import random
|
13 |
+
import os.path as osp
|
14 |
+
import os
|
15 |
+
import torch
|
16 |
+
import warnings
|
17 |
+
from collections import OrderedDict, defaultdict
|
18 |
+
from core.data.transforms.pose_transforms import *
|
19 |
+
import json #_tricks as json
|
20 |
+
import numpy as np
|
21 |
+
from xtcocotools.coco import COCO
|
22 |
+
from xtcocotools.cocoeval import COCOeval
|
23 |
+
import torch.distributed as dist
|
24 |
+
|
25 |
+
|
26 |
+
from core.utils import sync_print
|
27 |
+
|
28 |
+
|
29 |
+
class PetrelCOCO(COCO):
|
30 |
+
def __init__(self, annotation_file=None, test_index=None, ann_data=None):
|
31 |
+
"""
|
32 |
+
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
|
33 |
+
:param annotation_file (str): location of annotation file
|
34 |
+
:param image_folder (str): location to the folder that hosts images.
|
35 |
+
:return:
|
36 |
+
"""
|
37 |
+
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
|
38 |
+
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
|
39 |
+
self.anno_file = [annotation_file]
|
40 |
+
self.test_index = test_index
|
41 |
+
if annotation_file is not None:
|
42 |
+
print('loading annotations into memory...')
|
43 |
+
tic = time.time()
|
44 |
+
# https://github.com/cocodataset/cocoapi/pull/453/
|
45 |
+
if ann_data == None:
|
46 |
+
with open(annotation_file, 'r') as f:
|
47 |
+
dataset = json.load(f)
|
48 |
+
else:
|
49 |
+
dataset = ann_data
|
50 |
+
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
|
51 |
+
print('Done (t={:0.2f}s)'.format(time.time()- tic))
|
52 |
+
self.dataset = dataset
|
53 |
+
self.createIndex()
|
54 |
+
if 'annotations' in self.dataset:
|
55 |
+
for i in range(len(self.dataset['annotations'])):
|
56 |
+
if self.test_index is not None:
|
57 |
+
keypoints = np.array(self.dataset['annotations'][i]['keypoints']).reshape([-1, 3])
|
58 |
+
keypoints = keypoints[self.test_index, :]
|
59 |
+
self.dataset['annotations'][i]['keypoints'] = keypoints.reshape([-1]).tolist()
|
60 |
+
if 'iscrowd' not in self.dataset['annotations'][i]:
|
61 |
+
self.dataset['annotations'][i]['iscrowd'] = False
|
62 |
+
|
63 |
+
|
64 |
+
class COCOPosDatasetDev(Dataset):
|
65 |
+
"""CocoDataset dataset for top-down pose estimation.
|
66 |
+
|
67 |
+
"Microsoft COCO: Common Objects in Context", ECCV'2014.
|
68 |
+
More details can be found in the `paper
|
69 |
+
<https://arxiv.org/abs/1405.0312>`__ .
|
70 |
+
|
71 |
+
The dataset loads raw features and apply specified transforms
|
72 |
+
to return a dict containing the image tensors and other information.
|
73 |
+
|
74 |
+
COCO keypoint indexes::
|
75 |
+
|
76 |
+
0: 'nose',
|
77 |
+
1: 'left_eye',
|
78 |
+
2: 'right_eye',
|
79 |
+
3: 'left_ear',
|
80 |
+
4: 'right_ear',
|
81 |
+
5: 'left_shoulder',
|
82 |
+
6: 'right_shoulder',
|
83 |
+
7: 'left_elbow',
|
84 |
+
8: 'right_elbow',
|
85 |
+
9: 'left_wrist',
|
86 |
+
10: 'right_wrist',
|
87 |
+
11: 'left_hip',
|
88 |
+
12: 'right_hip',
|
89 |
+
13: 'left_knee',
|
90 |
+
14: 'right_knee',
|
91 |
+
15: 'left_ankle',
|
92 |
+
16: 'right_ankle'
|
93 |
+
|
94 |
+
Args:
|
95 |
+
ann_file (str): Path to the annotation file.
|
96 |
+
img_prefix (str): Path to a directory where images are held.
|
97 |
+
Default: None.
|
98 |
+
data_cfg (dict): config
|
99 |
+
pipeline (list[dict | callable]): A sequence of data transforms.
|
100 |
+
dataset_info (DatasetInfo): A class containing all dataset info.
|
101 |
+
test_mode (bool): Store True when building test or
|
102 |
+
validation dataset. Default: False.
|
103 |
+
"""
|
104 |
+
def __init__(self,
|
105 |
+
ginfo,
|
106 |
+
ann_file,
|
107 |
+
img_prefix,
|
108 |
+
data_cfg,
|
109 |
+
test_mode=False,
|
110 |
+
use_udp=False,
|
111 |
+
use_ceph=False,
|
112 |
+
data_use_ratio=1,
|
113 |
+
**kwargs):
|
114 |
+
self.image_info = {}
|
115 |
+
self.ann_info = {}
|
116 |
+
self.initialized = False
|
117 |
+
|
118 |
+
self.use_ceph = True
|
119 |
+
self.annotations_path = ann_file
|
120 |
+
self.img_prefix = img_prefix
|
121 |
+
self.test_mode = test_mode
|
122 |
+
print('data_cfg0', data_cfg)
|
123 |
+
# data_cfg=demjson.decode(data_cfg)
|
124 |
+
# print('data_cfg',data_cfg)
|
125 |
+
self.ann_info['image_size'] = np.array(data_cfg['image_size'])
|
126 |
+
self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
|
127 |
+
self.ann_info['num_joints'] = data_cfg['num_joints']
|
128 |
+
|
129 |
+
self.ann_info['inference_channel'] = data_cfg['inference_channel']
|
130 |
+
self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
|
131 |
+
self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
|
132 |
+
|
133 |
+
self.db = []
|
134 |
+
self.task_name = ginfo.task_name
|
135 |
+
|
136 |
+
if test_mode:
|
137 |
+
pipeline = [
|
138 |
+
LoadImageFromFile(use_ceph=use_ceph),
|
139 |
+
TopDownAffine(use_udp=use_udp),
|
140 |
+
ToUNTensor(),
|
141 |
+
Collect(keys=['image'],
|
142 |
+
meta_keys=['image_file', 'center', 'bbox', 'scale', 'rotation', 'bbox_score', 'flip_pairs'])
|
143 |
+
]
|
144 |
+
else:
|
145 |
+
pipeline = [
|
146 |
+
LoadImageFromFile(use_ceph=use_ceph),
|
147 |
+
TopDownRandomFlip(flip_prob=0.5),
|
148 |
+
TopDownHalfBodyTransform(num_joints_half_body=8,prob_half_body=0.3),
|
149 |
+
TopDownGetRandomScaleRotation(rot_factor=40, scale_factor=0.5),
|
150 |
+
TopDownAffine(use_udp=use_udp),
|
151 |
+
ToUNTensor(),
|
152 |
+
TopDownGenerateTarget(sigma=2, encoding='UDP' if use_udp else 'MSRA'),
|
153 |
+
Collect(keys=['image', 'label', 'target_weight'],
|
154 |
+
meta_keys=['image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale','rotation',
|
155 |
+
'bbox_score', 'flip_pairs'])
|
156 |
+
]
|
157 |
+
|
158 |
+
self.pipeline = ComposeX(pipeline)
|
159 |
+
self.use_gt_bbox = data_cfg['use_gt_bbox']
|
160 |
+
self.bbox_file = data_cfg['bbox_file'] if data_cfg['bbox_file'].startswith('/mnt') else (Path(__file__).parent / 'resources' / data_cfg['bbox_file']).resolve()
|
161 |
+
self.det_bbox_thr = data_cfg.get('det_bbox_thr', 0.0)
|
162 |
+
if 'image_thr' in data_cfg:
|
163 |
+
warnings.warn(
|
164 |
+
'image_thr is deprecated, '
|
165 |
+
'please use det_bbox_thr instead', DeprecationWarning)
|
166 |
+
self.det_bbox_thr = data_cfg['image_thr']
|
167 |
+
|
168 |
+
|
169 |
+
self.ann_info['flip_pairs'] = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
|
170 |
+
[11, 12], [13, 14], [15, 16]]
|
171 |
+
|
172 |
+
self.ann_info['upper_body_ids'] = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
|
173 |
+
self.ann_info['lower_body_ids'] = (11, 12, 13, 14, 15, 16)
|
174 |
+
|
175 |
+
self.ann_info['use_different_joint_weights'] = False
|
176 |
+
self.ann_info['joint_weights'] = np.array(
|
177 |
+
[
|
178 |
+
1., 1., 1., 1., 1., 1., 1., 1.2, 1.2, 1.5, 1.5, 1., 1., 1.2,
|
179 |
+
1.2, 1.5, 1.5
|
180 |
+
],
|
181 |
+
dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
|
182 |
+
|
183 |
+
# 'https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/'
|
184 |
+
# 'pycocotools/cocoeval.py#L523'
|
185 |
+
|
186 |
+
self.coco = PetrelCOCO(ann_file)
|
187 |
+
|
188 |
+
cats = [
|
189 |
+
cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
|
190 |
+
]
|
191 |
+
self.classes = ['__background__'] + cats
|
192 |
+
self.num_classes = len(self.classes)
|
193 |
+
self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
|
194 |
+
self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds()))
|
195 |
+
self._coco_ind_to_class_ind = dict(
|
196 |
+
(self._class_to_coco_ind[cls], self._class_to_ind[cls])
|
197 |
+
for cls in self.classes[1:])
|
198 |
+
self.img_ids = self.coco.getImgIds()
|
199 |
+
self.num_images = len(self.img_ids)
|
200 |
+
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
|
201 |
+
self.dataset_name = 'coco'
|
202 |
+
|
203 |
+
self.db = self._get_db()
|
204 |
+
if data_use_ratio != 1:
|
205 |
+
self.db = random.sample(self.db, int(len(self.db) * data_use_ratio))
|
206 |
+
|
207 |
+
print(f'=> COCOPosDatasetDev num_images: {self.num_images}')
|
208 |
+
print(f'=> COCOPosDatasetDev load {len(self.db)} samples')
|
209 |
+
|
210 |
+
@staticmethod
|
211 |
+
def _get_mapping_id_name(imgs):
|
212 |
+
"""
|
213 |
+
Args:
|
214 |
+
imgs (dict): dict of image info.
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
tuple: Image name & id mapping dicts.
|
218 |
+
|
219 |
+
- id2name (dict): Mapping image id to name.
|
220 |
+
- name2id (dict): Mapping image name to id.
|
221 |
+
"""
|
222 |
+
id2name = {}
|
223 |
+
name2id = {}
|
224 |
+
for image_id, image in imgs.items():
|
225 |
+
file_name = image['file_name']
|
226 |
+
id2name[image_id] = file_name
|
227 |
+
name2id[file_name] = image_id
|
228 |
+
|
229 |
+
return id2name, name2id
|
230 |
+
|
231 |
+
def _get_db(self):
|
232 |
+
"""Load dataset."""
|
233 |
+
if (not self.test_mode) or self.use_gt_bbox:
|
234 |
+
# use ground truth bbox
|
235 |
+
gt_db = self._load_coco_keypoint_annotations()
|
236 |
+
else:
|
237 |
+
# use bbox from detection
|
238 |
+
gt_db = self._load_coco_person_detection_results()
|
239 |
+
return gt_db
|
240 |
+
|
241 |
+
def _load_coco_keypoint_annotations(self):
|
242 |
+
"""Ground truth bbox and keypoints."""
|
243 |
+
gt_db = []
|
244 |
+
for img_id in self.img_ids:
|
245 |
+
gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
|
246 |
+
return gt_db
|
247 |
+
|
248 |
+
def _load_coco_keypoint_annotation_kernel(self, img_id):
|
249 |
+
"""load annotation from COCOAPI.
|
250 |
+
|
251 |
+
Note:
|
252 |
+
bbox:[x1, y1, w, h]
|
253 |
+
Args:
|
254 |
+
img_id: coco image id
|
255 |
+
Returns:
|
256 |
+
dict: db entry
|
257 |
+
"""
|
258 |
+
img_ann = self.coco.loadImgs(img_id)[0]
|
259 |
+
width = img_ann['width']
|
260 |
+
height = img_ann['height']
|
261 |
+
num_joints = self.ann_info['num_joints']
|
262 |
+
|
263 |
+
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
|
264 |
+
objs = self.coco.loadAnns(ann_ids)
|
265 |
+
|
266 |
+
# sanitize bboxes
|
267 |
+
valid_objs = []
|
268 |
+
for obj in objs:
|
269 |
+
if 'bbox' not in obj:
|
270 |
+
continue
|
271 |
+
x, y, w, h = obj['bbox']
|
272 |
+
x1 = max(0, x)
|
273 |
+
y1 = max(0, y)
|
274 |
+
x2 = min(width - 1, x1 + max(0, w - 1))
|
275 |
+
y2 = min(height - 1, y1 + max(0, h - 1))
|
276 |
+
if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
|
277 |
+
obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
|
278 |
+
valid_objs.append(obj)
|
279 |
+
objs = valid_objs
|
280 |
+
|
281 |
+
bbox_id = 0
|
282 |
+
rec = []
|
283 |
+
for obj in objs:
|
284 |
+
if 'keypoints' not in obj:
|
285 |
+
continue
|
286 |
+
if max(obj['keypoints']) == 0:
|
287 |
+
continue
|
288 |
+
if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
|
289 |
+
continue
|
290 |
+
joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
|
291 |
+
joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32)
|
292 |
+
|
293 |
+
keypoints = np.array(obj['keypoints']).reshape(-1, 3)
|
294 |
+
joints_3d[:, :2] = keypoints[:, :2]
|
295 |
+
joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
|
296 |
+
|
297 |
+
center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
|
298 |
+
|
299 |
+
image_file = os.path.join(self.img_prefix, self.id2name[img_id])
|
300 |
+
rec.append({
|
301 |
+
'image_file': image_file,
|
302 |
+
'center': center,
|
303 |
+
'scale': scale,
|
304 |
+
'bbox': obj['clean_bbox'][:4],
|
305 |
+
'rotation': 0,
|
306 |
+
'joints_3d': joints_3d,
|
307 |
+
'joints_3d_visible': joints_3d_visible,
|
308 |
+
'dataset': self.dataset_name,
|
309 |
+
'bbox_score': 1,
|
310 |
+
'bbox_id': bbox_id
|
311 |
+
})
|
312 |
+
bbox_id = bbox_id + 1
|
313 |
+
|
314 |
+
return rec
|
315 |
+
|
316 |
+
def _xywh2cs(self, x, y, w, h):
|
317 |
+
"""This encodes bbox(x,y,w,w) into (center, scale)
|
318 |
+
|
319 |
+
Args:
|
320 |
+
x, y, w, h
|
321 |
+
|
322 |
+
Returns:
|
323 |
+
tuple: A tuple containing center and scale.
|
324 |
+
|
325 |
+
- center (np.ndarray[float32](2,)): center of the bbox (x, y).
|
326 |
+
- scale (np.ndarray[float32](2,)): scale of the bbox w & h.
|
327 |
+
"""
|
328 |
+
aspect_ratio = self.ann_info['image_size'][0] / self.ann_info[
|
329 |
+
'image_size'][1]
|
330 |
+
center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
|
331 |
+
|
332 |
+
if (not self.test_mode) and np.random.rand() < 0.3:
|
333 |
+
center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
|
334 |
+
|
335 |
+
if w > aspect_ratio * h:
|
336 |
+
h = w * 1.0 / aspect_ratio
|
337 |
+
elif w < aspect_ratio * h:
|
338 |
+
w = h * aspect_ratio
|
339 |
+
|
340 |
+
# pixel std is 200.0
|
341 |
+
scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
|
342 |
+
# padding to include proper amount of context
|
343 |
+
scale = scale * 1.25
|
344 |
+
|
345 |
+
return center, scale
|
346 |
+
|
347 |
+
def _load_coco_person_detection_results(self):
|
348 |
+
"""Load coco person detection results."""
|
349 |
+
num_joints = self.ann_info['num_joints']
|
350 |
+
|
351 |
+
with open(self.bbox_file, 'r') as f:
|
352 |
+
all_boxes = json.load(f)
|
353 |
+
|
354 |
+
if not all_boxes:
|
355 |
+
raise ValueError('=> Load %s fail!' % self.bbox_file)
|
356 |
+
|
357 |
+
print(f'=> Total boxes: {len(all_boxes)}')
|
358 |
+
|
359 |
+
kpt_db = []
|
360 |
+
bbox_id = 0
|
361 |
+
for det_res in all_boxes:
|
362 |
+
if det_res['category_id'] != 1:
|
363 |
+
continue
|
364 |
+
|
365 |
+
image_file = os.path.join(self.img_prefix,
|
366 |
+
self.id2name[det_res['image_id']])
|
367 |
+
box = det_res['bbox']
|
368 |
+
score = det_res['score']
|
369 |
+
|
370 |
+
if score < self.det_bbox_thr:
|
371 |
+
continue
|
372 |
+
|
373 |
+
center, scale = self._xywh2cs(*box[:4])
|
374 |
+
joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
|
375 |
+
joints_3d_visible = np.ones((num_joints, 3), dtype=np.float32)
|
376 |
+
kpt_db.append({
|
377 |
+
'image_file': image_file,
|
378 |
+
'center': center,
|
379 |
+
'scale': scale,
|
380 |
+
'rotation': 0,
|
381 |
+
'bbox': box[:4],
|
382 |
+
'bbox_score': score,
|
383 |
+
'dataset': self.dataset_name,
|
384 |
+
'joints_3d': joints_3d,
|
385 |
+
'joints_3d_visible': joints_3d_visible,
|
386 |
+
'bbox_id': bbox_id
|
387 |
+
})
|
388 |
+
bbox_id = bbox_id + 1
|
389 |
+
print(f'=> Total boxes after filter '
|
390 |
+
f'low score@{self.det_bbox_thr}: {bbox_id}')
|
391 |
+
return kpt_db
|
392 |
+
|
393 |
+
def __len__(self):
|
394 |
+
"""Get the size of the dataset."""
|
395 |
+
return len(self.db)
|
396 |
+
|
397 |
+
def __getitem__(self, idx):
|
398 |
+
"""Get the sample given index."""
|
399 |
+
results = copy.deepcopy(self.db[idx])
|
400 |
+
results['ann_info'] = self.ann_info
|
401 |
+
out = self.pipeline(results)
|
402 |
+
C = self.num_classes - 1 # delete the background class
|
403 |
+
if 'label' in out:
|
404 |
+
out['dense_labeling'] = np.resize(out['label'], (C, self.ann_info['image_size'][0], self.ann_info['image_size'][1]))
|
405 |
+
else:
|
406 |
+
out['dense_labeling'] = np.zeros((C, self.ann_info['image_size'][0], self.ann_info['image_size'][1]))
|
407 |
+
# del out['ann_info']
|
408 |
+
return out # dict_keys(['image_file', 'center', 'scale', 'bbox', 'rotation', 'joints_3d', 'joints_3d_visible',
|
409 |
+
# 'dataset', 'bbox_score', 'bbox_id', 'ann_info', 'image', 'flipped', 'label',
|
410 |
+
# 'target_weight'])
|
411 |
+
|
412 |
+
|
413 |
+
class MPIIPosDatasetDev(Dataset):
|
414 |
+
def __init__(self,
|
415 |
+
ginfo,
|
416 |
+
ann_file,
|
417 |
+
img_prefix,
|
418 |
+
data_cfg,
|
419 |
+
test_mode=False,
|
420 |
+
use_udp=False,
|
421 |
+
data_use_ratio=1,
|
422 |
+
**kwargs):
|
423 |
+
|
424 |
+
self.image_info = {}
|
425 |
+
self.ann_info = {}
|
426 |
+
|
427 |
+
self.ann_file = ann_file
|
428 |
+
self.img_prefix = img_prefix
|
429 |
+
|
430 |
+
self.test_mode = test_mode
|
431 |
+
|
432 |
+
self.ann_info['image_size'] = np.array(data_cfg['image_size'])
|
433 |
+
self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
|
434 |
+
self.ann_info['num_joints'] = data_cfg['num_joints']
|
435 |
+
|
436 |
+
self.ann_info['inference_channel'] = data_cfg['inference_channel']
|
437 |
+
self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
|
438 |
+
self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
|
439 |
+
|
440 |
+
self.ann_info['use_different_joint_weights'] = data_cfg.get(
|
441 |
+
'use_different_joint_weights', False)
|
442 |
+
|
443 |
+
assert self.ann_info['num_joints'] == 16
|
444 |
+
self.ann_info['flip_pairs'] = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
|
445 |
+
self.ann_info['flip_index'] = [5, 4, 3, 2, 1, 0, 6, 7, 8, 9, 15, 14, 13, 12, 11, 10]
|
446 |
+
self.ann_info['upper_body_ids'] = [7, 8, 9, 10, 11, 12, 13, 14, 15]
|
447 |
+
self.ann_info['lower_body_ids'] = [0, 1, 2, 3, 4, 5, 6]
|
448 |
+
self.ann_info['joint_weights'] = np.array([
|
449 |
+
1.5, 1.2, 1., 1., 1.2, 1.5, 1., 1., 1., 1., 1.5, 1.2, 1., 1., 1.2, 1.5
|
450 |
+
])
|
451 |
+
self.ann_info['skeleton'] = [[0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7],
|
452 |
+
[7, 8], [8, 9], [8, 12], [12, 11], [11, 10], [8, 13], [13, 14], [14, 15]]
|
453 |
+
self.sigmas = np.array(
|
454 |
+
[0.089, 0.083, 0.107, 0.107, 0.083, 0.089, 0.026, 0.026, 0.026, 0.026, 0.062, 0.072, 0.179, 0.179, 0.072,
|
455 |
+
0.062])
|
456 |
+
self.dataset_name = 'mpii'
|
457 |
+
|
458 |
+
self.db = self._get_db()
|
459 |
+
if data_use_ratio != 1:
|
460 |
+
self.db = random.sample(self.db, int(len(self.db) * data_use_ratio))
|
461 |
+
|
462 |
+
self.image_set = set(x['image_file'] for x in self.db)
|
463 |
+
self.num_images = len(self.image_set)
|
464 |
+
|
465 |
+
print(f'=> num_images: {self.num_images}')
|
466 |
+
print(f'=> load {len(self.db)} samples')
|
467 |
+
|
468 |
+
if test_mode:
|
469 |
+
pipeline = [
|
470 |
+
LoadImageFromFile(),
|
471 |
+
TopDownAffine(use_udp=use_udp),
|
472 |
+
ToUNTensor(),
|
473 |
+
Collect(keys=['image'],
|
474 |
+
meta_keys=['image_file', 'center', 'scale', 'rotation', 'flip_pairs'])
|
475 |
+
]
|
476 |
+
else:
|
477 |
+
pipeline = [
|
478 |
+
LoadImageFromFile(),
|
479 |
+
TopDownRandomFlip(flip_prob=0.5),
|
480 |
+
TopDownGetRandomScaleRotation(rot_factor=40, scale_factor=0.5),
|
481 |
+
TopDownAffine(use_udp=use_udp),
|
482 |
+
ToUNTensor(),
|
483 |
+
TopDownGenerateTarget(sigma=2, encoding='UDP' if use_udp else 'MSRA'),
|
484 |
+
Collect(keys=['image', 'label', 'target_weight'],
|
485 |
+
meta_keys=['image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation',
|
486 |
+
'flip_pairs'])
|
487 |
+
]
|
488 |
+
self.pipeline = ComposeX(pipeline)
|
489 |
+
|
490 |
+
self.task_name = ginfo.task_name
|
491 |
+
|
492 |
+
@staticmethod
|
493 |
+
def _get_mapping_id_name(imgs):
|
494 |
+
"""
|
495 |
+
Args:
|
496 |
+
imgs (dict): dict of image info.
|
497 |
+
Returns:
|
498 |
+
tuple: Image name & id mapping dicts.
|
499 |
+
- id2name (dict): Mapping image id to name.
|
500 |
+
- name2id (dict): Mapping image name to id.
|
501 |
+
"""
|
502 |
+
id2name = {}
|
503 |
+
name2id = {}
|
504 |
+
for image_id, image in imgs.items():
|
505 |
+
file_name = image['file_name']
|
506 |
+
id2name[image_id] = file_name
|
507 |
+
name2id[file_name] = image_id
|
508 |
+
|
509 |
+
return id2name, name2id
|
510 |
+
|
511 |
+
def _xywh2cs(self, x, y, w, h, padding=1.25):
|
512 |
+
"""This encodes bbox(x,y,w,h) into (center, scale)
|
513 |
+
Args:
|
514 |
+
x, y, w, h (float): left, top, width and height
|
515 |
+
padding (float): bounding box padding factor
|
516 |
+
Returns:
|
517 |
+
center (np.ndarray[float32](2,)): center of the bbox (x, y).
|
518 |
+
scale (np.ndarray[float32](2,)): scale of the bbox w & h.
|
519 |
+
"""
|
520 |
+
aspect_ratio = self.ann_info['image_size'][0] / self.ann_info[
|
521 |
+
'image_size'][1]
|
522 |
+
center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
|
523 |
+
|
524 |
+
if (not self.test_mode) and np.random.rand() < 0.3:
|
525 |
+
center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
|
526 |
+
|
527 |
+
if w > aspect_ratio * h:
|
528 |
+
h = w * 1.0 / aspect_ratio
|
529 |
+
elif w < aspect_ratio * h:
|
530 |
+
w = h * aspect_ratio
|
531 |
+
|
532 |
+
# pixel std is 200.0
|
533 |
+
scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
|
534 |
+
# padding to include proper amount of context
|
535 |
+
scale = scale * padding
|
536 |
+
|
537 |
+
return center, scale
|
538 |
+
|
539 |
+
def _get_normalize_factor(self, gts, *args, **kwargs):
|
540 |
+
"""Get the normalize factor. generally inter-ocular distance measured
|
541 |
+
as the Euclidean distance between the outer corners of the eyes is
|
542 |
+
used. This function should be overrode, to measure NME.
|
543 |
+
Args:
|
544 |
+
gts (np.ndarray[N, K, 2]): Groundtruth keypoint location.
|
545 |
+
Returns:
|
546 |
+
np.ndarray[N, 2]: normalized factor
|
547 |
+
"""
|
548 |
+
return np.ones([gts.shape[0], 2], dtype=np.float32)
|
549 |
+
|
550 |
+
def _get_db(self):
|
551 |
+
"""Load dataset."""
|
552 |
+
# create train/val split
|
553 |
+
with open(self.ann_file, 'r') as f:
|
554 |
+
anno = json.load(f)
|
555 |
+
|
556 |
+
gt_db = []
|
557 |
+
bbox_id = 0
|
558 |
+
for a in anno:
|
559 |
+
image_name = a['image']
|
560 |
+
|
561 |
+
center = np.array(a['center'], dtype=np.float32)
|
562 |
+
scale = np.array([a['scale'], a['scale']], dtype=np.float32)
|
563 |
+
|
564 |
+
# Adjust center/scale slightly to avoid cropping limbs
|
565 |
+
if center[0] != -1:
|
566 |
+
center[1] = center[1] + 15 * scale[1]
|
567 |
+
# padding to include proper amount of context
|
568 |
+
scale = scale * 1.25
|
569 |
+
|
570 |
+
# MPII uses matlab format, index is 1-based,
|
571 |
+
# we should first convert to 0-based index
|
572 |
+
center = center - 1
|
573 |
+
|
574 |
+
joints_3d = np.zeros((self.ann_info['num_joints'], 3),
|
575 |
+
dtype=np.float32)
|
576 |
+
joints_3d_visible = np.zeros((self.ann_info['num_joints'], 3),
|
577 |
+
dtype=np.float32)
|
578 |
+
if not self.test_mode:
|
579 |
+
joints = np.array(a['joints'])
|
580 |
+
joints_vis = np.array(a['joints_vis'])
|
581 |
+
assert len(joints) == self.ann_info['num_joints'], \
|
582 |
+
f'joint num diff: {len(joints)}' + \
|
583 |
+
f' vs {self.ann_info["num_joints"]}'
|
584 |
+
|
585 |
+
joints_3d[:, 0:2] = joints[:, 0:2] - 1
|
586 |
+
joints_3d_visible[:, :2] = joints_vis[:, None]
|
587 |
+
image_file = osp.join(self.img_prefix, image_name)
|
588 |
+
gt_db.append({
|
589 |
+
'image_file': image_file,
|
590 |
+
'bbox_id': bbox_id,
|
591 |
+
'center': center,
|
592 |
+
'scale': scale,
|
593 |
+
'rotation': 0,
|
594 |
+
'joints_3d': joints_3d,
|
595 |
+
'joints_3d_visible': joints_3d_visible,
|
596 |
+
'dataset': self.dataset_name,
|
597 |
+
'bbox_score': 1
|
598 |
+
})
|
599 |
+
bbox_id = bbox_id + 1
|
600 |
+
gt_db = sorted(gt_db, key=lambda x: x['bbox_id'])
|
601 |
+
|
602 |
+
return gt_db
|
603 |
+
|
604 |
+
@staticmethod
|
605 |
+
def _write_keypoint_results(keypoints, res_file):
|
606 |
+
"""Write results into a json file."""
|
607 |
+
|
608 |
+
with open(res_file, 'w') as f:
|
609 |
+
json.dump(keypoints, f, sort_keys=True, indent=4)
|
610 |
+
|
611 |
+
def _report_metric(self,
|
612 |
+
res_file,
|
613 |
+
metrics,
|
614 |
+
pck_thr=0.2,
|
615 |
+
pckh_thr=0.7,
|
616 |
+
auc_nor=30):
|
617 |
+
"""Keypoint evaluation.
|
618 |
+
Args:
|
619 |
+
res_file (str): Json file stored prediction results.
|
620 |
+
metrics (str | list[str]): Metric to be performed.
|
621 |
+
Options: 'PCK', 'PCKh', 'AUC', 'EPE', 'NME'.
|
622 |
+
pck_thr (float): PCK threshold, default as 0.2.
|
623 |
+
pckh_thr (float): PCKh threshold, default as 0.7.
|
624 |
+
auc_nor (float): AUC normalization factor, default as 30 pixel.
|
625 |
+
Returns:
|
626 |
+
List: Evaluation results for evaluation metric.
|
627 |
+
"""
|
628 |
+
info_str = []
|
629 |
+
|
630 |
+
with open(res_file, 'r') as fin:
|
631 |
+
preds = json.load(fin)
|
632 |
+
assert len(preds) == len(self.db)
|
633 |
+
|
634 |
+
outputs = []
|
635 |
+
gts = []
|
636 |
+
masks = []
|
637 |
+
box_sizes = []
|
638 |
+
threshold_bbox = []
|
639 |
+
threshold_head_box = []
|
640 |
+
|
641 |
+
for pred, item in zip(preds, self.db):
|
642 |
+
outputs.append(np.array(pred['keypoints'])[:, :-1])
|
643 |
+
gts.append(np.array(item['joints_3d'])[:, :-1])
|
644 |
+
masks.append((np.array(item['joints_3d_visible'])[:, 0]) > 0)
|
645 |
+
if 'PCK' in metrics:
|
646 |
+
bbox = np.array(item['bbox'])
|
647 |
+
bbox_thr = np.max(bbox[2:])
|
648 |
+
threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
|
649 |
+
if 'PCKh' in metrics:
|
650 |
+
head_box_thr = item['head_size']
|
651 |
+
threshold_head_box.append(
|
652 |
+
np.array([head_box_thr, head_box_thr]))
|
653 |
+
box_sizes.append(item.get('box_size', 1))
|
654 |
+
|
655 |
+
outputs = np.array(outputs)
|
656 |
+
gts = np.array(gts)
|
657 |
+
masks = np.array(masks)
|
658 |
+
threshold_bbox = np.array(threshold_bbox)
|
659 |
+
threshold_head_box = np.array(threshold_head_box)
|
660 |
+
box_sizes = np.array(box_sizes).reshape([-1, 1])
|
661 |
+
|
662 |
+
if 'PCK' in metrics:
|
663 |
+
_, pck, _ = keypoint_pck_accuracy(outputs, gts, masks, pck_thr,
|
664 |
+
threshold_bbox)
|
665 |
+
info_str.append(('PCK', pck))
|
666 |
+
|
667 |
+
if 'PCKh' in metrics:
|
668 |
+
_, pckh, _ = keypoint_pck_accuracy(outputs, gts, masks, pckh_thr,
|
669 |
+
threshold_head_box)
|
670 |
+
info_str.append(('PCKh', pckh))
|
671 |
+
|
672 |
+
if 'AUC' in metrics:
|
673 |
+
info_str.append(('AUC', keypoint_auc(outputs, gts, masks,
|
674 |
+
auc_nor)))
|
675 |
+
|
676 |
+
if 'EPE' in metrics:
|
677 |
+
info_str.append(('EPE', keypoint_epe(outputs, gts, masks)))
|
678 |
+
|
679 |
+
if 'NME' in metrics:
|
680 |
+
normalize_factor = self._get_normalize_factor(
|
681 |
+
gts=gts, box_sizes=box_sizes)
|
682 |
+
info_str.append(
|
683 |
+
('NME', keypoint_nme(outputs, gts, masks, normalize_factor)))
|
684 |
+
|
685 |
+
return info_str
|
686 |
+
|
687 |
+
def __len__(self):
|
688 |
+
"""Get the size of the dataset."""
|
689 |
+
return len(self.db)
|
690 |
+
|
691 |
+
def __getitem__(self, idx):
|
692 |
+
"""Get the sample given index."""
|
693 |
+
results = copy.deepcopy(self.db[idx])
|
694 |
+
results['ann_info'] = self.ann_info
|
695 |
+
out = self.pipeline(results)
|
696 |
+
C = self.ann_info['num_joints']
|
697 |
+
if 'label' in out:
|
698 |
+
out['dense_labeling'] = np.resize(out['label'],
|
699 |
+
(C, self.ann_info['image_size'][1], self.ann_info['image_size'][0]))
|
700 |
+
else:
|
701 |
+
out['dense_labeling'] = np.zeros((C, self.ann_info['image_size'][1], self.ann_info['image_size'][0]))
|
702 |
+
# import pdb;pdb.set_trace()
|
703 |
+
return out
|
704 |
+
|
705 |
+
def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
|
706 |
+
"""sort kpts and remove the repeated ones."""
|
707 |
+
kpts = sorted(kpts, key=lambda x: x[key])
|
708 |
+
num = len(kpts)
|
709 |
+
for i in range(num - 1, 0, -1):
|
710 |
+
if kpts[i][key] == kpts[i - 1][key]:
|
711 |
+
del kpts[i]
|
712 |
+
|
713 |
+
return kpts
|
core/data/datasets/images/resources/CHval.odgt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e8f9c1b0cb455d6b0fb53b73d1dd92fbb3b3b02bfd897661ceddc246e4991b5e
size 19994003
core/data/datasets/images/resources/COCO_val2017_detections_AP_H_56_person.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53ba0ad8d0fd461c5a000cd90797fa8c39cd8c38cd125125c0412626ff592d59
size 16383781
core/data/datasets/images/resources/mpii_gt_val.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1ab6874f858046c74acdd1b9dacb8746a2ddddc331952487a9774e3ee0c2b075
size 1257356
core/data/datasets/images/resources/test_caltech_heavy_1xnew.odgt
ADDED
The diff for this file is too large to render.
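These resource files are tracked with Git LFS, so the three lines shown for each of them are the LFS pointer (version, oid, size) rather than the data itself. A small, purely illustrative sketch of reading such a pointer, not code from this repository:

def read_lfs_pointer(path):
    # Parse a Git LFS pointer file into its version / oid / size fields.
    fields = {}
    with open(path, 'r') as f:
        for line in f:
            key, _, value = line.strip().partition(' ')
            fields[key] = value
    return fields  # e.g. {'version': ..., 'oid': 'sha256:...', 'size': '19994003'}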
core/data/datasets/images/seg_data_tools/__init__.py
ADDED
File without changes
core/data/datasets/images/seg_data_tools/collate.py
ADDED
@@ -0,0 +1,143 @@
import random

import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate

from lib.extensions.parallel.data_container import DataContainer


def stack(batch, data_key=None, return_dc=False):
    # Collate one key of the batch: stacked DataContainers go through
    # default_collate, the rest are returned as lists (optionally re-wrapped).
    if isinstance(batch[0][data_key], DataContainer):
        if batch[0][data_key].stack:
            assert isinstance(batch[0][data_key].data, torch.Tensor)
            samples = [sample[data_key].data for sample in batch]
            return default_collate(samples)
        elif not return_dc:
            return [sample[data_key].data for sample in batch]
        else:
            return DataContainer([sample[data_key].data for sample in batch])
    else:
        return default_collate([sample[data_key] for sample in batch])


def collate(batch, trans_dict):
    # Bring every sample to trans_dict['input_size'] by scaling and/or padding,
    # then stack all keys into a batch dict.
    data_keys = batch[0].keys()

    target_width, target_height = trans_dict['input_size']
    target_widths, target_heights = [target_width] * len(batch), [target_height] * len(batch)

    for i in range(len(batch)):
        target_width, target_height = target_widths[i], target_heights[i]

        if 'meta' in data_keys:
            batch[i]['meta'].data['input_size'] = [target_width, target_height]

        channels, height, width = batch[i]['img'].size()
        if height == target_height and width == target_width:
            continue

        scaled_size = [width, height]

        if trans_dict['align_method'] in ['only_scale', 'scale_and_pad']:
            w_scale_ratio = target_width / width
            h_scale_ratio = target_height / height
            if trans_dict['align_method'] == 'scale_and_pad':
                w_scale_ratio = min(w_scale_ratio, h_scale_ratio)
                h_scale_ratio = w_scale_ratio

            scaled_size = (int(round(width * w_scale_ratio)), int(round(height * h_scale_ratio)))
            if 'meta' in data_keys and 'border_size' in batch[i]['meta'].data:
                batch[i]['meta'].data['border_size'] = scaled_size

            scaled_size_hw = (scaled_size[1], scaled_size[0])
            batch[i]['img'] = DataContainer(F.interpolate(batch[i]['img'].data.unsqueeze(0),
                scaled_size_hw, mode='bilinear', align_corners=True).squeeze(0), stack=True)
            if 'labelmap' in data_keys:
                labelmap = batch[i]['labelmap'].data.unsqueeze(0).unsqueeze(0).float()
                labelmap = F.interpolate(labelmap, scaled_size_hw, mode='nearest').long().squeeze(0).squeeze(0)
                batch[i]['labelmap'] = DataContainer(labelmap, stack=True)

            if 'maskmap' in data_keys:
                maskmap = batch[i]['maskmap'].data.unsqueeze(0).unsqueeze(0).float()
                maskmap = F.interpolate(maskmap, scaled_size_hw, mode='nearest').long().squeeze(0).squeeze(0)
                batch[i]['maskmap'] = DataContainer(maskmap, stack=True)

        pad_width = target_width - scaled_size[0]
        pad_height = target_height - scaled_size[1]
        assert pad_height >= 0 and pad_width >= 0
        if pad_width > 0 or pad_height > 0:
            assert trans_dict['align_method'] in ['only_pad', 'scale_and_pad']
            left_pad = 0
            up_pad = 0
            if 'pad_mode' not in trans_dict or trans_dict['pad_mode'] == 'random':
                left_pad = random.randint(0, pad_width)   # pad_left
                up_pad = random.randint(0, pad_height)    # pad_up
            elif trans_dict['pad_mode'] == 'pad_left_up':
                left_pad = pad_width
                up_pad = pad_height
            elif trans_dict['pad_mode'] == 'pad_right_down':
                left_pad = 0
                up_pad = 0
            elif trans_dict['pad_mode'] == 'pad_center':
                left_pad = pad_width // 2
                up_pad = pad_height // 2
            elif trans_dict['pad_mode'] == 'pad_border':
                if random.randint(0, 1) == 0:
                    left_pad = pad_width
                    up_pad = pad_height
                else:
                    left_pad = 0
                    up_pad = 0
            else:
                raise ValueError('pad mode {} is not defined'.format(trans_dict['pad_mode']))

            pad = (left_pad, pad_width - left_pad, up_pad, pad_height - up_pad)

            batch[i]['img'] = DataContainer(F.pad(batch[i]['img'].data, pad=pad, value=0), stack=batch[i]['img'].stack)

            if 'labelmap' in data_keys:
                batch[i]['labelmap'] = DataContainer(F.pad(batch[i]['labelmap'].data, pad=pad, value=-1), stack=batch[i]['labelmap'].stack)

            if 'maskmap' in data_keys:
                batch[i]['maskmap'] = DataContainer(F.pad(batch[i]['maskmap'].data, pad=pad, value=0), stack=batch[i]['maskmap'].stack)

            if 'distance_map' in data_keys:
                batch[i]['distance_map'] = DataContainer(F.pad(batch[i]['distance_map'].data, pad=pad, value=255), stack=batch[i]['distance_map'].stack)

            if 'angle_map' in data_keys:
                batch[i]['angle_map'] = DataContainer(F.pad(batch[i]['angle_map'].data, pad=pad, value=0), stack=batch[i]['angle_map'].stack)

            if 'mask_label_map' in data_keys:
                batch[i]['mask_label_map'] = DataContainer(F.pad(batch[i]['mask_label_map'].data, pad=pad, value=-1), stack=batch[i]['mask_label_map'].stack)

            if 'direction_label_map' in data_keys:
                batch[i]['direction_label_map'] = DataContainer(F.pad(batch[i]['direction_label_map'].data, pad=pad, value=-1), stack=batch[i]['direction_label_map'].stack)

            if 'multi_label_direction_map' in data_keys:
                batch[i]['multi_label_direction_map'] = DataContainer(F.pad(batch[i]['multi_label_direction_map'].data, pad=pad, value=-1), stack=batch[i]['multi_label_direction_map'].stack)

            if 'energy_label_map' in data_keys:
                batch[i]['energy_label_map'] = DataContainer(F.pad(batch[i]['energy_label_map'].data, pad=pad, value=-1), stack=batch[i]['energy_label_map'].stack)

            if 'offsetmap_h' in data_keys:
                batch[i]['offsetmap_h'] = DataContainer(F.pad(batch[i]['offsetmap_h'].data, pad=pad, value=0), stack=batch[i]['offsetmap_h'].stack)

            if 'offsetmap_w' in data_keys:
                batch[i]['offsetmap_w'] = DataContainer(F.pad(batch[i]['offsetmap_w'].data, pad=pad, value=0), stack=batch[i]['offsetmap_w'].stack)

    return dict({key: stack(batch, data_key=key) for key in data_keys})
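A minimal sketch of how this collate function is typically plugged into a PyTorch DataLoader; the dataset object and the exact trans_dict values below are placeholders, not configuration taken from this repository.

from functools import partial
from torch.utils.data import DataLoader

trans_dict = {
    'input_size': [520, 520],         # target (width, height) after collation
    'align_method': 'scale_and_pad',  # 'only_scale' | 'only_pad' | 'scale_and_pad'
    'pad_mode': 'random',             # how the leftover padding is distributed
}

loader = DataLoader(
    seg_dataset,                      # assumed: yields dicts of DataContainers ('img', 'labelmap', 'meta', ...)
    batch_size=8,
    shuffle=True,
    collate_fn=partial(collate, trans_dict=trans_dict),
)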
core/data/datasets/images/seg_data_tools/cv2_aug_transforms.py
ADDED
@@ -0,0 +1,889 @@
1 |
+
import collections
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
class _BaseTransform(object):
|
9 |
+
|
10 |
+
DATA_ITEMS = (
|
11 |
+
'labelmap', 'maskmap',
|
12 |
+
'distance_map', 'angle_map', 'multi_label_direction_map',
|
13 |
+
'boundary_map', 'offsetmap',
|
14 |
+
# 'offsetmap_h', 'offsetmap_w',
|
15 |
+
'region_indexmap'
|
16 |
+
)
|
17 |
+
|
18 |
+
def __call__(self, img, **kwargs):
|
19 |
+
|
20 |
+
data_dict = collections.defaultdict(lambda: None)
|
21 |
+
data_dict.update(kwargs)
|
22 |
+
|
23 |
+
return img, data_dict
|
24 |
+
|
25 |
+
def _process(self, img, data_dict, skip_condition, *args, **kwargs):
|
26 |
+
assert isinstance(img, np.ndarray), \
|
27 |
+
"img should be numpy array, got {}.".format(type(img))
|
28 |
+
if not skip_condition:
|
29 |
+
img = self._process_img(img, *args, **kwargs)
|
30 |
+
|
31 |
+
ret_dict = collections.defaultdict(lambda: None)
|
32 |
+
for name in self.DATA_ITEMS:
|
33 |
+
func_name = '_process_' + name
|
34 |
+
x = data_dict[name]
|
35 |
+
|
36 |
+
assert isinstance(x, np.ndarray) or x is None, \
|
37 |
+
"{} should be numpy array or None, got {}.".format(
|
38 |
+
name, type(x))
|
39 |
+
|
40 |
+
if hasattr(self, func_name) and x is not None and not skip_condition:
|
41 |
+
ret_dict[name] = getattr(self, func_name)(x, *args, **kwargs)
|
42 |
+
else:
|
43 |
+
ret_dict[name] = x
|
44 |
+
|
45 |
+
return img, ret_dict
|
46 |
+
|
47 |
+
|
48 |
+
class Padding(_BaseTransform):
|
49 |
+
""" Padding the Image to proper size.
|
50 |
+
Args:
|
51 |
+
stride: the stride of the network.
|
52 |
+
pad_value: the value that pad to the image border.
|
53 |
+
img: Image object as input.
|
54 |
+
Returns::
|
55 |
+
img: Image object.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, pad=None, pad_ratio=0.5, mean=(104, 117, 123), allow_outside_center=True):
|
59 |
+
self.pad = pad
|
60 |
+
self.ratio = pad_ratio
|
61 |
+
self.mean = mean
|
62 |
+
self.allow_outside_center = allow_outside_center
|
63 |
+
|
64 |
+
def _pad(self, x, pad_value, height, width, target_size, offset_left, offset_up):
|
65 |
+
expand_x = np.zeros((
|
66 |
+
max(height, target_size[1]) + abs(offset_up),
|
67 |
+
max(width, target_size[0]) + abs(offset_left),
|
68 |
+
*x.shape[2:]
|
69 |
+
), dtype=x.dtype)
|
70 |
+
expand_x[:, :] = pad_value
|
71 |
+
expand_x[
|
72 |
+
abs(min(offset_up, 0)):abs(min(offset_up, 0)) + height,
|
73 |
+
abs(min(offset_left, 0)):abs(min(offset_left, 0)) + width] = x
|
74 |
+
x = expand_x[
|
75 |
+
max(offset_up, 0):max(offset_up, 0) + target_size[1],
|
76 |
+
max(offset_left, 0):max(offset_left, 0) + target_size[0]
|
77 |
+
]
|
78 |
+
return x
|
79 |
+
|
80 |
+
def _process_img(self, img, *args):
|
81 |
+
return self._pad(img, self.mean, *args)
|
82 |
+
|
83 |
+
def _process_labelmap(self, x, *args):
|
84 |
+
return self._pad(x, 255, *args)
|
85 |
+
|
86 |
+
def _process_region_indexmap(self, x, *args):
|
87 |
+
return self._pad(x, 0, *args)
|
88 |
+
|
89 |
+
def _process_maskmap(self, x, *args):
|
90 |
+
return self._pad(x, 1, *args)
|
91 |
+
|
92 |
+
def _process_distance_map(self, x, *args):
|
93 |
+
return self._pad(x, 255, *args)
|
94 |
+
|
95 |
+
def _process_angle_map(self, x, *args):
|
96 |
+
return self._pad(x, 0, *args)
|
97 |
+
|
98 |
+
def _process_boundary_map(self, x, *args):
|
99 |
+
return self._pad(x, 0, *args)
|
100 |
+
|
101 |
+
def _process_multi_label_direction_map(self, x, *args):
|
102 |
+
return self._pad(x, 0, *args)
|
103 |
+
|
104 |
+
# def _process_offsetmap_h(self, x, *args):
|
105 |
+
# return self._pad(x, 0, *args)
|
106 |
+
|
107 |
+
# def _process_offsetmap_w(self, x, *args):
|
108 |
+
# return self._pad(x, 0, *args)
|
109 |
+
|
110 |
+
def _process_offsetmap(self, x, *args):
|
111 |
+
return self._pad(x, 0, *args)
|
112 |
+
|
113 |
+
def __call__(self, img, **kwargs):
|
114 |
+
img, data_dict = super().__call__(img, **kwargs)
|
115 |
+
|
116 |
+
height, width, channels = img.shape
|
117 |
+
left_pad, up_pad, right_pad, down_pad = self.pad
|
118 |
+
|
119 |
+
target_size = [width + left_pad +
|
120 |
+
right_pad, height + up_pad + down_pad]
|
121 |
+
offset_left = -left_pad
|
122 |
+
offset_up = -up_pad
|
123 |
+
|
124 |
+
return self._process(
|
125 |
+
img, data_dict,
|
126 |
+
random.random() > self.ratio,
|
127 |
+
height, width, target_size, offset_left, offset_up
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
class RandomHFlip(_BaseTransform):
|
132 |
+
def __init__(self, swap_pair=None, flip_ratio=0.5):
|
133 |
+
self.swap_pair = swap_pair
|
134 |
+
self.ratio = flip_ratio
|
135 |
+
|
136 |
+
def _process_img(self, img):
|
137 |
+
return cv2.flip(img, 1)
|
138 |
+
|
139 |
+
def _process_labelmap(self, labelmap):
|
140 |
+
labelmap = cv2.flip(labelmap, 1)
|
141 |
+
# to handle datasets with left/right annatations
|
142 |
+
if self.swap_pair is not None:
|
143 |
+
assert isinstance(self.swap_pair, (tuple, list))
|
144 |
+
temp = labelmap.copy()
|
145 |
+
for pair in self.swap_pair:
|
146 |
+
assert isinstance(pair, (tuple, list)) and len(pair) == 2
|
147 |
+
labelmap[temp == pair[0]] = pair[1]
|
148 |
+
labelmap[temp == pair[1]] = pair[0]
|
149 |
+
|
150 |
+
return labelmap
|
151 |
+
|
152 |
+
def _process_region_indexmap(self, labelmap):
|
153 |
+
return cv2.flip(labelmap, 1)
|
154 |
+
|
155 |
+
def _process_maskmap(self, x):
|
156 |
+
return cv2.flip(x, 1)
|
157 |
+
|
158 |
+
def _process_distance_map(self, x):
|
159 |
+
return cv2.flip(x, 1)
|
160 |
+
|
161 |
+
def _process_angle_map(self, angle_map):
|
162 |
+
ret_angle_map = angle_map.copy()
|
163 |
+
mask = (angle_map > 0) & (angle_map < 180)
|
164 |
+
ret_angle_map[mask] = 180 - angle_map[mask]
|
165 |
+
mask = (angle_map < 0) & (angle_map > -180)
|
166 |
+
ret_angle_map[mask] = - (180 + angle_map[mask])
|
167 |
+
ret_angle_map = cv2.flip(ret_angle_map, 1)
|
168 |
+
return ret_angle_map
|
169 |
+
|
170 |
+
def _process_boundary_map(self, x):
|
171 |
+
return cv2.flip(x, 1)
|
172 |
+
|
173 |
+
def _process_multi_label_direction_map(self, multi_label_direction_map):
|
174 |
+
perm = [4, 3, 2, 1, 0, 7, 6, 5]
|
175 |
+
multi_label_direction_map = cv2.flip(multi_label_direction_map, 1)
|
176 |
+
multi_label_direction_map = multi_label_direction_map[..., perm]
|
177 |
+
return multi_label_direction_map
|
178 |
+
|
179 |
+
# def _process_offsetmap_h(self, x):
|
180 |
+
# return cv2.flip(x, 1)
|
181 |
+
|
182 |
+
# def _process_offsetmap_w(self, x):
|
183 |
+
# return -cv2.flip(x, 1)
|
184 |
+
|
185 |
+
def _process_offsetmap_w(self, x):
|
186 |
+
x = cv2.flip(x, 1)
|
187 |
+
x[..., 1] = -x[..., 1]
|
188 |
+
return x
|
189 |
+
|
190 |
+
def __call__(self, img, **kwargs):
|
191 |
+
img, data_dict = super().__call__(img, **kwargs)
|
192 |
+
|
193 |
+
return self._process(
|
194 |
+
img, data_dict,
|
195 |
+
random.random() > self.ratio
|
196 |
+
)
|
197 |
+
|
198 |
+
|
199 |
+
class RandomSaturation(_BaseTransform):
|
200 |
+
def __init__(self, lower=0.5, upper=1.5, saturation_ratio=0.5):
|
201 |
+
self.lower = lower
|
202 |
+
self.upper = upper
|
203 |
+
self.ratio = saturation_ratio
|
204 |
+
assert self.upper >= self.lower, "saturation upper must be >= lower."
|
205 |
+
assert self.lower >= 0, "saturation lower must be non-negative."
|
206 |
+
|
207 |
+
def _process_img(self, img):
|
208 |
+
img = img.astype(np.float32)
|
209 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
210 |
+
img[:, :, 1] *= random.uniform(self.lower, self.upper)
|
211 |
+
img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
|
212 |
+
img = np.clip(img, 0, 255).astype(np.uint8)
|
213 |
+
return img
|
214 |
+
|
215 |
+
def __call__(self, img, **kwargs):
|
216 |
+
img, data_dict = super().__call__(img, **kwargs)
|
217 |
+
|
218 |
+
return self._process(
|
219 |
+
img, data_dict,
|
220 |
+
random.random() > self.ratio
|
221 |
+
)
|
222 |
+
|
223 |
+
|
224 |
+
class RandomHue(_BaseTransform):
|
225 |
+
def __init__(self, delta=18, hue_ratio=0.5):
|
226 |
+
assert 0 <= delta <= 360
|
227 |
+
self.delta = delta
|
228 |
+
self.ratio = hue_ratio
|
229 |
+
|
230 |
+
def _process_img(self, img):
|
231 |
+
img = img.astype(np.float32)
|
232 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
233 |
+
img[:, :, 0] += random.uniform(-self.delta, self.delta)
|
234 |
+
img[:, :, 0][img[:, :, 0] > 360] -= 360
|
235 |
+
img[:, :, 0][img[:, :, 0] < 0] += 360
|
236 |
+
img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
|
237 |
+
img = np.clip(img, 0, 255).astype(np.uint8)
|
238 |
+
return img
|
239 |
+
|
240 |
+
def __call__(self, img, **kwargs):
|
241 |
+
img, data_dict = super().__call__(img, **kwargs)
|
242 |
+
|
243 |
+
return self._process(
|
244 |
+
img, data_dict,
|
245 |
+
random.random() > self.ratio
|
246 |
+
)
|
247 |
+
|
248 |
+
|
249 |
+
class RandomPerm(_BaseTransform):
|
250 |
+
def __init__(self, perm_ratio=0.5):
|
251 |
+
self.ratio = perm_ratio
|
252 |
+
self.perms = ((0, 1, 2), (0, 2, 1),
|
253 |
+
(1, 0, 2), (1, 2, 0),
|
254 |
+
(2, 0, 1), (2, 1, 0))
|
255 |
+
|
256 |
+
def _process_img(self, img):
|
257 |
+
swap = self.perms[random.randint(0, len(self.perms) - 1)]
|
258 |
+
img = img[:, :, swap].astype(np.uint8)
|
259 |
+
return img
|
260 |
+
|
261 |
+
def __call__(self, img, **kwargs):
|
262 |
+
img, data_dict = super().__call__(img, **kwargs)
|
263 |
+
|
264 |
+
return self._process(
|
265 |
+
img, data_dict,
|
266 |
+
random.random() > self.ratio
|
267 |
+
)
|
268 |
+
|
269 |
+
|
270 |
+
class RandomContrast(_BaseTransform):
|
271 |
+
def __init__(self, lower=0.5, upper=1.5, contrast_ratio=0.5):
|
272 |
+
self.lower = lower
|
273 |
+
self.upper = upper
|
274 |
+
self.ratio = contrast_ratio
|
275 |
+
assert self.upper >= self.lower, "contrast upper must be >= lower."
|
276 |
+
assert self.lower >= 0, "contrast lower must be non-negative."
|
277 |
+
|
278 |
+
def _process_img(self, img):
|
279 |
+
img = img.astype(np.float32)
|
280 |
+
img *= random.uniform(self.lower, self.upper)
|
281 |
+
img = np.clip(img, 0, 255).astype(np.uint8)
|
282 |
+
return img
|
283 |
+
|
284 |
+
def __call__(self, img, **kwargs):
|
285 |
+
img, data_dict = super().__call__(img, **kwargs)
|
286 |
+
|
287 |
+
return self._process(
|
288 |
+
img, data_dict,
|
289 |
+
random.random() > self.ratio
|
290 |
+
)
|
291 |
+
|
292 |
+
|
293 |
+
class RandomBrightness(_BaseTransform):
|
294 |
+
def __init__(self, shift_value=30, brightness_ratio=0.5):
|
295 |
+
self.shift_value = shift_value
|
296 |
+
self.ratio = brightness_ratio
|
297 |
+
|
298 |
+
def _process_img(self, img):
|
299 |
+
img = img.astype(np.float32)
|
300 |
+
shift = random.randint(-self.shift_value, self.shift_value)
|
301 |
+
img[:, :, :] += shift
|
302 |
+
img = np.around(img)
|
303 |
+
img = np.clip(img, 0, 255).astype(np.uint8)
|
304 |
+
return img
|
305 |
+
|
306 |
+
def __call__(self, img, **kwargs):
|
307 |
+
img, data_dict = super().__call__(img, **kwargs)
|
308 |
+
|
309 |
+
return self._process(
|
310 |
+
img, data_dict,
|
311 |
+
random.random() > self.ratio
|
312 |
+
)
|
313 |
+
|
314 |
+
|
315 |
+
class RandomResize(_BaseTransform):
|
316 |
+
"""Resize the given numpy.ndarray to random size and aspect ratio.
|
317 |
+
|
318 |
+
Args:
|
319 |
+
scale_min: the min scale to resize.
|
320 |
+
scale_max: the max scale to resize.
|
321 |
+
"""
|
322 |
+
|
323 |
+
def __init__(self, scale_range=(0.75, 1.25), aspect_range=(0.9, 1.1), target_size=None,
|
324 |
+
resize_bound=None, method='random', max_side_bound=None, scale_list=None, resize_ratio=0.5):
|
325 |
+
self.scale_range = scale_range
|
326 |
+
self.aspect_range = aspect_range
|
327 |
+
self.resize_bound = resize_bound
|
328 |
+
self.max_side_bound = max_side_bound
|
329 |
+
self.scale_list = scale_list
|
330 |
+
self.method = method
|
331 |
+
self.ratio = resize_ratio
|
332 |
+
|
333 |
+
if target_size is not None:
|
334 |
+
if isinstance(target_size, int):
|
335 |
+
self.input_size = (target_size, target_size)
|
336 |
+
elif isinstance(target_size, (list, tuple)) and len(target_size) == 2:
|
337 |
+
self.input_size = target_size
|
338 |
+
else:
|
339 |
+
raise TypeError(
|
340 |
+
'Got inappropriate size arg: {}'.format(target_size))
|
341 |
+
else:
|
342 |
+
self.input_size = None
|
343 |
+
|
344 |
+
def get_scale(self, img_size):
|
345 |
+
if self.method == 'random':
|
346 |
+
scale_ratio = random.uniform(
|
347 |
+
self.scale_range[0], self.scale_range[1])
|
348 |
+
return scale_ratio
|
349 |
+
|
350 |
+
elif self.method == 'bound':
|
351 |
+
scale1 = self.resize_bound[0] / min(img_size)
|
352 |
+
scale2 = self.resize_bound[1] / max(img_size)
|
353 |
+
scale = min(scale1, scale2)
|
354 |
+
return scale
|
355 |
+
|
356 |
+
else:
|
357 |
+
raise ValueError("invalid method")
|
358 |
+
exit(1)
|
359 |
+
|
360 |
+
def _process_img(self, img, converted_size, *args):
|
361 |
+
return cv2.resize(img, converted_size, interpolation=cv2.INTER_CUBIC).astype(np.uint8)
|
362 |
+
|
363 |
+
def _process_labelmap(self, x, converted_size, *args):
|
364 |
+
return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
|
365 |
+
|
366 |
+
def _process_region_indexmap(self, x, converted_size, *args):
|
367 |
+
return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
|
368 |
+
|
369 |
+
def _process_maskmap(self, x, converted_size, *args):
|
370 |
+
return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
|
371 |
+
|
372 |
+
def _process_distance_map(self, x, converted_size, *args):
|
373 |
+
return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
|
374 |
+
|
375 |
+
def _process_angle_map(self, x, converted_size, *args):
|
376 |
+
return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
|
377 |
+
|
378 |
+
def _process_boundary_map(self, x, converted_size, *args):
|
379 |
+
return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
|
380 |
+
|
381 |
+
def _process_multi_label_direction_map(self, x, converted_size, *args):
|
382 |
+
return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
|
383 |
+
|
384 |
+
# def _process_offsetmap_h(self, x, converted_size, h_scale_ratio, w_scale_ratio):
|
385 |
+
# return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST) * h_scale_ratio
|
386 |
+
|
387 |
+
# def _process_offsetmap_w(self, x, converted_size, h_scale_ratio, w_scale_ratio):
|
388 |
+
# return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST) * w_scale_ratio
|
389 |
+
|
390 |
+
def _process_offsetmap(self, x, converted_size, h_scale_ratio, w_scale_ratio):
|
391 |
+
return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
|
392 |
+
|
393 |
+
def __call__(self, img, **kwargs):
|
394 |
+
"""
|
395 |
+
Args:
|
396 |
+
img (Image): Image to be resized.
|
397 |
+
maskmap (Image): Mask to be resized.
|
398 |
+
kpt (list): keypoints to be resized.
|
399 |
+
center: (list): center points to be resized.
|
400 |
+
|
401 |
+
Returns:
|
402 |
+
Image: Randomly resize image.
|
403 |
+
Image: Randomly resize maskmap.
|
404 |
+
list: Randomly resize keypoints.
|
405 |
+
list: Randomly resize center points.
|
406 |
+
"""
|
407 |
+
img, data_dict = super().__call__(img, **kwargs)
|
408 |
+
|
409 |
+
height, width, _ = img.shape
|
410 |
+
if self.scale_list is None:
|
411 |
+
scale_ratio = self.get_scale([width, height])
|
412 |
+
else:
|
413 |
+
scale_ratio = self.scale_list[random.randint(
|
414 |
+
0, len(self.scale_list)-1)]
|
415 |
+
|
416 |
+
aspect_ratio = random.uniform(*self.aspect_range)
|
417 |
+
w_scale_ratio = math.sqrt(aspect_ratio) * scale_ratio
|
418 |
+
h_scale_ratio = math.sqrt(1.0 / aspect_ratio) * scale_ratio
|
419 |
+
if self.max_side_bound is not None and max(height*h_scale_ratio, width*w_scale_ratio) > self.max_side_bound:
|
420 |
+
d_ratio = self.max_side_bound / max(height * h_scale_ratio, width * w_scale_ratio)
|
421 |
+
w_scale_ratio *= d_ratio
|
422 |
+
h_scale_ratio *= d_ratio
|
423 |
+
|
424 |
+
converted_size = (int(width * w_scale_ratio),
|
425 |
+
int(height * h_scale_ratio))
|
426 |
+
return self._process(
|
427 |
+
img, data_dict,
|
428 |
+
random.random() > self.ratio,
|
429 |
+
converted_size, h_scale_ratio, w_scale_ratio
|
430 |
+
)
|
431 |
+
|
432 |
+
|
433 |
+
class RandomRotate(_BaseTransform):
|
434 |
+
"""Rotate the input numpy.ndarray and points to the given degree.
|
435 |
+
|
436 |
+
Args:
|
437 |
+
degree (number): Desired rotate degree.
|
438 |
+
"""
|
439 |
+
|
440 |
+
def __init__(self, max_degree, rotate_ratio=0.5, mean=(104, 117, 123)):
|
441 |
+
assert isinstance(max_degree, int)
|
442 |
+
self.max_degree = max_degree
|
443 |
+
self.ratio = rotate_ratio
|
444 |
+
self.mean = mean
|
445 |
+
|
446 |
+
def _warp(self, x, border_value, rotate_mat, new_width, new_height):
|
447 |
+
return cv2.warpAffine(x, rotate_mat, (new_width, new_height), borderValue=border_value)
|
448 |
+
|
449 |
+
def _process_img(self, x, *args):
|
450 |
+
return self._warp(x, self.mean, *args).astype(np.uint8)
|
451 |
+
|
452 |
+
def _process_labelmap(self, x, *args):
|
453 |
+
return self._warp(x, (255, 255, 255), *args).astype(np.uint8)
|
454 |
+
|
455 |
+
def _process_maskmap(self, x, *args):
|
456 |
+
return self._warp(x, (1, 1, 1), *args).astype(np.uint8)
|
457 |
+
|
458 |
+
def __call__(self, img, **kwargs):
|
459 |
+
"""
|
460 |
+
Args:
|
461 |
+
img (Image): Image to be rotated.
|
462 |
+
maskmap (Image): Mask to be rotated.
|
463 |
+
kpt (list): Keypoints to be rotated.
|
464 |
+
center (list): Center points to be rotated.
|
465 |
+
|
466 |
+
Returns:
|
467 |
+
Image: Rotated image.
|
468 |
+
list: Rotated key points.
|
469 |
+
"""
|
470 |
+
img, data_dict = super().__call__(img, **kwargs)
|
471 |
+
|
472 |
+
rotate_degree = random.uniform(-self.max_degree, self.max_degree)
|
473 |
+
height, width, _ = img.shape
|
474 |
+
img_center = (width / 2.0, height / 2.0)
|
475 |
+
rotate_mat = cv2.getRotationMatrix2D(img_center, rotate_degree, 1.0)
|
476 |
+
cos_val = np.abs(rotate_mat[0, 0])
|
477 |
+
sin_val = np.abs(rotate_mat[0, 1])
|
478 |
+
new_width = int(height * sin_val + width * cos_val)
|
479 |
+
new_height = int(height * cos_val + width * sin_val)
|
480 |
+
rotate_mat[0, 2] += (new_width / 2.) - img_center[0]
|
481 |
+
rotate_mat[1, 2] += (new_height / 2.) - img_center[1]
|
482 |
+
|
483 |
+
return self._process(
|
484 |
+
img, data_dict,
|
485 |
+
random.random() > self.ratio,
|
486 |
+
rotate_mat, new_width, new_height
|
487 |
+
)
|
488 |
+
|
489 |
+
|
490 |
+
class RandomCrop(_BaseTransform):
|
491 |
+
"""Crop the given numpy.ndarray and at a random location.
|
492 |
+
|
493 |
+
Args:
|
494 |
+
size (int or tuple): Desired output size of the crop.(w, h)
|
495 |
+
"""
|
496 |
+
|
497 |
+
def __init__(self, crop_size, crop_ratio=0.5, method='random', grid=None, allow_outside_center=True):
|
498 |
+
self.ratio = crop_ratio
|
499 |
+
self.method = method
|
500 |
+
self.grid = grid
|
501 |
+
self.allow_outside_center = allow_outside_center
|
502 |
+
|
503 |
+
if isinstance(crop_size, float):
|
504 |
+
self.size = (crop_size, crop_size)
|
505 |
+
elif isinstance(crop_size, collections.Iterable) and len(crop_size) == 2:
|
506 |
+
self.size = crop_size
|
507 |
+
else:
|
508 |
+
raise TypeError('Got inappropriate size arg: {}'.format(crop_size))
|
509 |
+
|
510 |
+
def get_lefttop(self, crop_size, img_size):
|
511 |
+
if self.method == 'center':
|
512 |
+
return [(img_size[0] - crop_size[0]) // 2, (img_size[1] - crop_size[1]) // 2]
|
513 |
+
|
514 |
+
elif self.method == 'random':
|
515 |
+
x = random.randint(0, img_size[0] - crop_size[0])
|
516 |
+
y = random.randint(0, img_size[1] - crop_size[1])
|
517 |
+
return [x, y]
|
518 |
+
|
519 |
+
elif self.method == 'grid':
|
520 |
+
grid_x = random.randint(0, self.grid[0] - 1)
|
521 |
+
grid_y = random.randint(0, self.grid[1] - 1)
|
522 |
+
x = grid_x * ((img_size[0] - crop_size[0]) // (self.grid[0] - 1))
|
523 |
+
y = grid_y * ((img_size[1] - crop_size[1]) // (self.grid[1] - 1))
|
524 |
+
return [x, y]
|
525 |
+
|
526 |
+
else:
|
527 |
+
raise ValueError('Crop method invalid')
|
528 |
+
exit(1)
|
529 |
+
|
530 |
+
def _crop(self, x, offset_up, offset_left, target_size):
|
531 |
+
return x[offset_up:offset_up + target_size[1], offset_left:offset_left + target_size[0]]
|
532 |
+
|
533 |
+
def _process_img(self, img, *args):
|
534 |
+
return self._crop(img, *args)
|
535 |
+
|
536 |
+
def _process_labelmap(self, x, *args):
|
537 |
+
return self._crop(x, *args)
|
538 |
+
|
539 |
+
def _process_region_indexmap(self, x, *args):
|
540 |
+
return self._crop(x, *args)
|
541 |
+
|
542 |
+
def _process_maskmap(self, x, *args):
|
543 |
+
return self._crop(x, *args)
|
544 |
+
|
545 |
+
def _process_distance_map(self, x, *args):
|
546 |
+
return self._crop(x, *args)
|
547 |
+
|
548 |
+
def _process_angle_map(self, x, *args):
|
549 |
+
return self._crop(x, *args)
|
550 |
+
|
551 |
+
def _process_boundary_map(self, x, *args):
|
552 |
+
return self._crop(x, *args)
|
553 |
+
|
554 |
+
def _process_multi_label_direction_map(self, x, *args):
|
555 |
+
return self._crop(x, *args)
|
556 |
+
|
557 |
+
# def _process_offsetmap_h(self, x, *args):
|
558 |
+
# return self._crop(x, *args)
|
559 |
+
|
560 |
+
# def _process_offsetmap_w(self, x, *args):
|
561 |
+
# return self._crop(x, *args)
|
562 |
+
|
563 |
+
def _process_offsetmap(self, x, *args):
|
564 |
+
return self._crop(x, *args)
|
565 |
+
|
566 |
+
def __call__(self, img, **kwargs):
|
567 |
+
"""
|
568 |
+
Args:
|
569 |
+
img (Image): Image to be cropped.
|
570 |
+
maskmap (Image): Mask to be cropped.
|
571 |
+
|
572 |
+
Returns:
|
573 |
+
Image: Cropped image.
|
574 |
+
Image: Cropped maskmap.
|
575 |
+
list: Cropped keypoints.
|
576 |
+
list: Cropped center points.
|
577 |
+
"""
|
578 |
+
img, data_dict = super().__call__(img, **kwargs)
|
579 |
+
|
580 |
+
height, width, _ = img.shape
|
581 |
+
target_size = [min(self.size[0], width), min(self.size[1], height)]
|
582 |
+
|
583 |
+
offset_left, offset_up = self.get_lefttop(target_size, [width, height])
|
584 |
+
return self._process(
|
585 |
+
img, data_dict,
|
586 |
+
random.random() > self.ratio,
|
587 |
+
offset_up, offset_left, target_size
|
588 |
+
)
|
589 |
+
|
590 |
+
|
591 |
+
class Resize(RandomResize):
|
592 |
+
"""Resize the given numpy.ndarray to random size and aspect ratio.
|
593 |
+
Args:
|
594 |
+
scale_min: the min scale to resize.
|
595 |
+
scale_max: the max scale to resize.
|
596 |
+
"""
|
597 |
+
|
598 |
+
def __init__(self, target_size=None, min_side_length=None, max_side_length=None, max_side_bound=None):
|
599 |
+
self.target_size = target_size
|
600 |
+
self.min_side_length = min_side_length
|
601 |
+
self.max_side_length = max_side_length
|
602 |
+
self.max_side_bound = max_side_bound
|
603 |
+
|
604 |
+
def __call__(self, img, **kwargs):
|
605 |
+
img, data_dict = super(RandomResize, self).__call__(img, **kwargs)
|
606 |
+
|
607 |
+
height, width, _ = img.shape
|
608 |
+
if self.target_size is not None:
|
609 |
+
target_size = self.target_size
|
610 |
+
w_scale_ratio = self.target_size[0] / width
|
611 |
+
h_scale_ratio = self.target_size[1] / height
|
612 |
+
|
613 |
+
elif self.min_side_length is not None:
|
614 |
+
scale_ratio = self.min_side_length / min(width, height)
|
615 |
+
w_scale_ratio, h_scale_ratio = scale_ratio, scale_ratio
|
616 |
+
target_size = [int(round(width * w_scale_ratio)),
|
617 |
+
int(round(height * h_scale_ratio))]
|
618 |
+
|
619 |
+
else:
|
620 |
+
scale_ratio = self.max_side_length / max(width, height)
|
621 |
+
w_scale_ratio, h_scale_ratio = scale_ratio, scale_ratio
|
622 |
+
target_size = [int(round(width * w_scale_ratio)),
|
623 |
+
int(round(height * h_scale_ratio))]
|
624 |
+
|
625 |
+
if self.max_side_bound is not None and max(target_size) > self.max_side_bound:
|
626 |
+
d_ratio = self.max_side_bound / max(target_size)
|
627 |
+
w_scale_ratio = d_ratio * w_scale_ratio
|
628 |
+
h_scale_ratio = d_ratio * h_scale_ratio
|
629 |
+
target_size = [int(round(width * w_scale_ratio)),
|
630 |
+
int(round(height * h_scale_ratio))]
|
631 |
+
|
632 |
+
target_size = tuple(target_size)
|
633 |
+
return self._process(
|
634 |
+
img, data_dict,
|
635 |
+
False,
|
636 |
+
target_size, h_scale_ratio, w_scale_ratio
|
637 |
+
)
|
638 |
+
|
639 |
+
|
640 |
+
class CV2AugCompose(object):
|
641 |
+
"""Composes several transforms together.
|
642 |
+
|
643 |
+
Args:
|
644 |
+
transforms (list of ``Transform`` objects): list of transforms to compose.
|
645 |
+
|
646 |
+
Example:
|
647 |
+
>>> CV2AugCompose([
|
648 |
+
>>> RandomCrop(),
|
649 |
+
>>> ])
|
650 |
+
"""
|
651 |
+
|
652 |
+
def __init__(self, configer, split='train'):
|
653 |
+
self.configer = configer
|
654 |
+
self.split = split
|
655 |
+
|
656 |
+
if self.split == 'train':
|
657 |
+
shuffle_train_trans = []
|
658 |
+
if self.configer.exists('train_trans', 'shuffle_trans_seq'):
|
659 |
+
if isinstance(self.configer.get('train_trans', 'shuffle_trans_seq')[0], list):
|
660 |
+
train_trans_seq_list = self.configer.get(
|
661 |
+
'train_trans', 'shuffle_trans_seq')
|
662 |
+
for train_trans_seq in train_trans_seq_list:
|
663 |
+
shuffle_train_trans += train_trans_seq
|
664 |
+
|
665 |
+
else:
|
666 |
+
shuffle_train_trans = self.configer.get(
|
667 |
+
'train_trans', 'shuffle_trans_seq')
|
668 |
+
trans_seq = self.configer.get(
|
669 |
+
'train_trans', 'trans_seq') + shuffle_train_trans
|
670 |
+
trans_key = 'train_trans'
|
671 |
+
else:
|
672 |
+
trans_seq = self.configer.get('val_trans', 'trans_seq')
|
673 |
+
trans_key = 'val_trans'
|
674 |
+
|
675 |
+
self.transforms = dict()
|
676 |
+
self.trans_config = self.configer.get(trans_key)
|
677 |
+
for trans_name in trans_seq:
|
678 |
+
specs = TRANSFORM_SPEC[trans_name]
|
679 |
+
config = self.configer.get(trans_key, trans_name)
|
680 |
+
for spec in specs:
|
681 |
+
if 'when' not in spec:
|
682 |
+
break
|
683 |
+
choose_this = True
|
684 |
+
for cond_key, cond_value in spec['when'].items():
|
685 |
+
choose_this = choose_this and (
|
686 |
+
config[cond_key] == cond_value)
|
687 |
+
if choose_this:
|
688 |
+
break
|
689 |
+
else:
|
690 |
+
raise RuntimeError("Not support!")
|
691 |
+
|
692 |
+
kwargs = {}
|
693 |
+
for arg_name, arg_path in spec["args"].items():
|
694 |
+
if isinstance(arg_path, str):
|
695 |
+
arg_value = config.get(arg_path, None)
|
696 |
+
elif isinstance(arg_path, list):
|
697 |
+
arg_value = self.configer.get(*arg_path)
|
698 |
+
kwargs[arg_name] = arg_value
|
699 |
+
|
700 |
+
klass = TRANSFORM_MAPPING[trans_name]
|
701 |
+
self.transforms[trans_name] = klass(**kwargs)
|
702 |
+
|
703 |
+
def __call__(self, img, **data_dict):
|
704 |
+
|
705 |
+
orig_key_list = list(data_dict)
|
706 |
+
|
707 |
+
if self.configer.get('data', 'input_mode') == 'RGB':
|
708 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
709 |
+
|
710 |
+
if self.split == 'train':
|
711 |
+
shuffle_trans_seq = []
|
712 |
+
if self.configer.exists('train_trans', 'shuffle_trans_seq'):
|
713 |
+
if isinstance(self.configer.get('train_trans', 'shuffle_trans_seq')[0], list):
|
714 |
+
shuffle_trans_seq_list = self.configer.get('train_trans', 'shuffle_trans_seq')
|
715 |
+
shuffle_trans_seq = shuffle_trans_seq_list[random.randint(0, len(shuffle_trans_seq_list))]
|
716 |
+
else:
|
717 |
+
shuffle_trans_seq = self.configer.get('train_trans', 'shuffle_trans_seq')
|
718 |
+
random.shuffle(shuffle_trans_seq)
|
719 |
+
trans_seq = shuffle_trans_seq + self.configer.get('train_trans', 'trans_seq')
|
720 |
+
else:
|
721 |
+
trans_seq = self.configer.get('val_trans', 'trans_seq')
|
722 |
+
|
723 |
+
for trans_key in trans_seq:
|
724 |
+
img, data_dict = self.transforms[trans_key](img, **data_dict)
|
725 |
+
|
726 |
+
if self.configer.get('data', 'input_mode') == 'RGB':
|
727 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
728 |
+
|
729 |
+
return (img, *[data_dict[key] for key in orig_key_list])
|
730 |
+
|
731 |
+
def __repr__(self):
|
732 |
+
import pprint
|
733 |
+
return 'CV2AugCompose({})'.format(pprint.pformat(self.trans_config))
|
734 |
+
|
735 |
+
|
736 |
+
TRANSFORM_MAPPING = {
|
737 |
+
"random_saturation": RandomSaturation,
|
738 |
+
"random_hue": RandomHue,
|
739 |
+
"random_perm": RandomPerm,
|
740 |
+
"random_contrast": RandomContrast,
|
741 |
+
"padding": Padding,
|
742 |
+
"random_brightness": RandomBrightness,
|
743 |
+
"random_hflip": RandomHFlip,
|
744 |
+
"random_resize": RandomResize,
|
745 |
+
"random_crop": RandomCrop,
|
746 |
+
"random_rotate": RandomRotate,
|
747 |
+
"resize": Resize,
|
748 |
+
}
|
749 |
+
|
750 |
+
TRANSFORM_SPEC = {
|
751 |
+
"random_style": [{
|
752 |
+
"args": {
|
753 |
+
"style_ratio": "ratio"
|
754 |
+
}
|
755 |
+
}],
|
756 |
+
"random_saturation": [{
|
757 |
+
"args": {
|
758 |
+
"lower": "lower",
|
759 |
+
"upper": "upper",
|
760 |
+
"saturation_ratio": "ratio"
|
761 |
+
}
|
762 |
+
}],
|
763 |
+
"random_hue": [{
|
764 |
+
"args": {
|
765 |
+
"delta": "delta",
|
766 |
+
"hue_ratio": "ratio"
|
767 |
+
}
|
768 |
+
}],
|
769 |
+
"ramdom_perm": [{
|
770 |
+
"args": {
|
771 |
+
"perm_ratio": "ratio"
|
772 |
+
}
|
773 |
+
}],
|
774 |
+
"random_contrast": [{
|
775 |
+
"args": {
|
776 |
+
"lower": "lower",
|
777 |
+
"upper": "upper",
|
778 |
+
"contrast_ratio": "ratio"
|
779 |
+
}
|
780 |
+
}],
|
781 |
+
"padding": [{
|
782 |
+
"args": {
|
783 |
+
"pad": "pad",
|
784 |
+
"pad_ratio": "ratio",
|
785 |
+
"mean": ["normalize", "mean_value"],
|
786 |
+
"allow_outside_center": "allow_outside_center"
|
787 |
+
}
|
788 |
+
}],
|
789 |
+
"random_brightness": [{
|
790 |
+
"args": {
|
791 |
+
"shift_value": "shift_value",
|
792 |
+
"brightness_ratio": "ratio"
|
793 |
+
}
|
794 |
+
}],
|
795 |
+
"random_hflip": [{
|
796 |
+
"args": {
|
797 |
+
"swap_pair": "swap_pair",
|
798 |
+
"flip_ratio": "ratio"
|
799 |
+
}
|
800 |
+
}],
|
801 |
+
"random_resize": [
|
802 |
+
{
|
803 |
+
"args": {
|
804 |
+
"method": "method",
|
805 |
+
"scale_range": "scale_range",
|
806 |
+
"aspect_range": "aspect_range",
|
807 |
+
"max_side_bound": "max_side_bound",
|
808 |
+
"resize_ratio": "ratio"
|
809 |
+
},
|
810 |
+
"when": {
|
811 |
+
"method": "random"
|
812 |
+
}
|
813 |
+
},
|
814 |
+
{
|
815 |
+
"args": {
|
816 |
+
"method": "method",
|
817 |
+
"scale_range": "scale_range",
|
818 |
+
"aspect_range": "aspect_range",
|
819 |
+
"target_size": "target_size",
|
820 |
+
"resize_ratio": "ratio"
|
821 |
+
},
|
822 |
+
"when": {
|
823 |
+
"method": "focus"
|
824 |
+
}
|
825 |
+
},
|
826 |
+
{
|
827 |
+
"args": {
|
828 |
+
"method": "method",
|
829 |
+
"aspect_range": "aspect_range",
|
830 |
+
"resize_bound": "resize_bound",
|
831 |
+
"resize_ratio": "ratio"
|
832 |
+
},
|
833 |
+
"when": {
|
834 |
+
"method": "bound"
|
835 |
+
}
|
836 |
+
},
|
837 |
+
],
|
838 |
+
"random_crop": [
|
839 |
+
{
|
840 |
+
"args": {
|
841 |
+
"crop_size": "crop_size",
|
842 |
+
"method": "method",
|
843 |
+
"crop_ratio": "ratio",
|
844 |
+
"allow_outside_center": "allow_outside_center"
|
845 |
+
},
|
846 |
+
"when": {
|
847 |
+
"method": "random"
|
848 |
+
}
|
849 |
+
},
|
850 |
+
{
|
851 |
+
"args": {
|
852 |
+
"crop_size": "crop_size",
|
853 |
+
"method": "method",
|
854 |
+
"crop_ratio": "ratio",
|
855 |
+
"allow_outside_center": "allow_outside_center"
|
856 |
+
},
|
857 |
+
"when": {
|
858 |
+
"method": "center"
|
859 |
+
}
|
860 |
+
},
|
861 |
+
{
|
862 |
+
"args": {
|
863 |
+
"crop_size": "crop_size",
|
864 |
+
"method": "method",
|
865 |
+
"crop_ratio": "ratio",
|
866 |
+
"grid": "grid",
|
867 |
+
"allow_outside_center": "allow_outside_center"
|
868 |
+
},
|
869 |
+
"when": {
|
870 |
+
"method": "grid"
|
871 |
+
}
|
872 |
+
},
|
873 |
+
],
|
874 |
+
"random_rotate": [{
|
875 |
+
"args": {
|
876 |
+
"max_degree": "rotate_degree",
|
877 |
+
"rotate_ratio": "ratio",
|
878 |
+
"mean": ["normalize", "mean_value"]
|
879 |
+
}
|
880 |
+
}],
|
881 |
+
"resize": [{
|
882 |
+
"args": {
|
883 |
+
"target_size": "target_size",
|
884 |
+
"min_side_length": "min_side_length",
|
885 |
+
"max_side_bound": "max_side_bound",
|
886 |
+
"max_side_length": "max_side_length"
|
887 |
+
}
|
888 |
+
}],
|
889 |
+
}
|
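A minimal sketch of driving the transform classes above directly, without going through CV2AugCompose and its configer; the image paths, the label ids in swap_pair, and the ratios are illustrative assumptions only.

import cv2

img = cv2.imread('example.jpg')                                   # BGR HxWx3 uint8
labelmap = cv2.imread('example_label.png', cv2.IMREAD_GRAYSCALE)  # HxW class ids

hflip = RandomHFlip(swap_pair=[(14, 15)], flip_ratio=0.5)          # swap left/right classes on flip
resize = RandomResize(scale_range=(0.75, 1.25), aspect_range=(0.9, 1.1),
                      method='random', resize_ratio=1.0)

# Each transform returns (img, data_dict); auxiliary maps ride along in data_dict.
img, data = hflip(img, labelmap=labelmap)
img, data = resize(img, labelmap=data['labelmap'])
labelmap = data['labelmap']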
core/data/datasets/images/seg_data_tools/transforms.py
ADDED
@@ -0,0 +1,106 @@
import numpy as np
import torch
from PIL import Image


class Normalize(object):
    """Normalize a ``torch.tensor``

    Args:
        inputs (torch.tensor): tensor to be normalized.
        mean (list): the mean of RGB
        std (list): the std of RGB

    Returns:
        Tensor: Normalized tensor.
    """
    def __init__(self, div_value, mean, std):
        self.div_value = div_value
        self.mean = mean
        self.std = std

    def __call__(self, inputs):
        inputs = inputs.div(self.div_value)
        for t, m, s in zip(inputs, self.mean, self.std):
            t.sub_(m).div_(s)

        return inputs


class DeNormalize(object):
    """De-normalize a ``torch.tensor``

    Args:
        inputs (torch.tensor): tensor to be de-normalized.
        mean (list): the mean of RGB
        std (list): the std of RGB

    Returns:
        Tensor: De-normalized tensor.
    """
    def __init__(self, div_value, mean, std):
        self.div_value = div_value
        self.mean = mean
        self.std = std

    def __call__(self, inputs):
        result = inputs.clone()
        for i in range(result.size(0)):
            result[i, :, :] = result[i, :, :] * self.std[i] + self.mean[i]

        return result.mul_(self.div_value)


class ToTensor(object):
    """Convert a ``numpy.ndarray`` or ``Image`` to tensor.

    Args:
        inputs (numpy.ndarray or Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    def __call__(self, inputs):
        if isinstance(inputs, Image.Image):
            channels = len(inputs.mode)
            inputs = np.array(inputs)
            inputs = inputs.reshape(inputs.shape[0], inputs.shape[1], channels)
            inputs = torch.from_numpy(inputs.transpose(2, 0, 1))
        else:
            inputs = torch.from_numpy(inputs.transpose(2, 0, 1))

        return inputs.float()


class ToLabel(object):
    def __call__(self, inputs):
        return torch.from_numpy(np.array(inputs)).long()


class ReLabel(object):
    """
    255 indicates the background; relabel 255 to some other value.
    """
    def __init__(self, olabel, nlabel):
        self.olabel = olabel
        self.nlabel = nlabel

    def __call__(self, inputs):
        assert isinstance(inputs, torch.LongTensor), 'tensor needs to be LongTensor'

        inputs[inputs == self.olabel] = self.nlabel
        return inputs


class Compose(object):

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, inputs):
        for t in self.transforms:
            inputs = t(inputs)

        return inputs
core/data/datasets/images/seg_dataset_dev.py
ADDED
@@ -0,0 +1,293 @@
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import itertools
|
8 |
+
from typing import Any, Dict, List, Tuple, Union
|
9 |
+
from torch.utils import data
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from PIL import Image
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import core.data.transforms.seg_aug_dev as T
|
15 |
+
from core.data.transforms.seg_transforms_dev import AugInput, apply_transform_gens
|
16 |
+
|
17 |
+
|
18 |
+
class Instances:
|
19 |
+
"""
|
20 |
+
This class represents a list of instances in an image.
|
21 |
+
It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
|
22 |
+
All fields must have the same ``__len__`` which is the number of instances.
|
23 |
+
|
24 |
+
All other (non-field) attributes of this class are considered private:
|
25 |
+
they must start with '_' and are not modifiable by a user.
|
26 |
+
|
27 |
+
Some basic usage:
|
28 |
+
|
29 |
+
1. Set/get/check a field:
|
30 |
+
|
31 |
+
.. code-block:: python
|
32 |
+
|
33 |
+
instances.gt_boxes = Boxes(...)
|
34 |
+
print(instances.pred_masks) # a tensor of shape (N, H, W)
|
35 |
+
print('gt_masks' in instances)
|
36 |
+
|
37 |
+
2. ``len(instances)`` returns the number of instances
|
38 |
+
3. Indexing: ``instances[indices]`` will apply the indexing on all the fields
|
39 |
+
and returns a new :class:`Instances`.
|
40 |
+
Typically, ``indices`` is a integer vector of indices,
|
41 |
+
or a binary mask of length ``num_instances``
|
42 |
+
|
43 |
+
.. code-block:: python
|
44 |
+
|
45 |
+
category_3_detections = instances[instances.pred_classes == 3]
|
46 |
+
confident_detections = instances[instances.scores > 0.9]
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
|
50 |
+
"""
|
51 |
+
Args:
|
52 |
+
image_size (height, width): the spatial size of the image.
|
53 |
+
kwargs: fields to add to this `Instances`.
|
54 |
+
"""
|
55 |
+
self._image_size = image_size
|
56 |
+
self._fields: Dict[str, Any] = {}
|
57 |
+
for k, v in kwargs.items():
|
58 |
+
self.set(k, v)
|
59 |
+
|
60 |
+
@property
|
61 |
+
def image_size(self) -> Tuple[int, int]:
|
62 |
+
"""
|
63 |
+
Returns:
|
64 |
+
tuple: height, width
|
65 |
+
"""
|
66 |
+
return self._image_size
|
67 |
+
|
68 |
+
def __setattr__(self, name: str, val: Any) -> None:
|
69 |
+
if name.startswith("_"):
|
70 |
+
super().__setattr__(name, val)
|
71 |
+
else:
|
72 |
+
self.set(name, val)
|
73 |
+
|
74 |
+
def __getattr__(self, name: str) -> Any:
|
75 |
+
if name == "_fields" or name not in self._fields:
|
76 |
+
raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
|
77 |
+
return self._fields[name]
|
78 |
+
|
79 |
+
def set(self, name: str, value: Any) -> None:
|
80 |
+
"""
|
81 |
+
Set the field named `name` to `value`.
|
82 |
+
The length of `value` must be the number of instances,
|
83 |
+
and must agree with other existing fields in this object.
|
84 |
+
"""
|
85 |
+
data_len = len(value)
|
86 |
+
if len(self._fields):
|
87 |
+
assert (
|
88 |
+
len(self) == data_len
|
89 |
+
), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
|
90 |
+
self._fields[name] = value
|
91 |
+
|
92 |
+
def has(self, name: str) -> bool:
|
93 |
+
"""
|
94 |
+
Returns:
|
95 |
+
bool: whether the field called `name` exists.
|
96 |
+
"""
|
97 |
+
return name in self._fields
|
98 |
+
|
99 |
+
def remove(self, name: str) -> None:
|
100 |
+
"""
|
101 |
+
Remove the field called `name`.
|
102 |
+
"""
|
103 |
+
del self._fields[name]
|
104 |
+
|
105 |
+
def get(self, name: str) -> Any:
|
106 |
+
"""
|
107 |
+
Returns the field called `name`.
|
108 |
+
"""
|
109 |
+
return self._fields[name]
|
110 |
+
|
111 |
+
def get_fields(self) -> Dict[str, Any]:
|
112 |
+
"""
|
113 |
+
Returns:
|
114 |
+
dict: a dict which maps names (str) to data of the fields
|
115 |
+
|
116 |
+
Modifying the returned dict will modify this instance.
|
117 |
+
"""
|
118 |
+
return self._fields
|
119 |
+
|
120 |
+
# Tensor-like methods
|
121 |
+
def cuda(self, *args: Any, **kwargs: Any) -> "Instances":
|
122 |
+
"""
|
123 |
+
Returns:
|
124 |
+
Instances: all fields are called with a `cuda`, if the field has this method.
|
125 |
+
"""
|
126 |
+
ret = Instances(self._image_size)
|
127 |
+
for k, v in self._fields.items():
|
128 |
+
if hasattr(v, "cuda"):
|
129 |
+
v = v.cuda(*args, **kwargs)
|
130 |
+
ret.set(k, v)
|
131 |
+
return ret
|
132 |
+
|
133 |
+
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances":
|
134 |
+
"""
|
135 |
+
Args:
|
136 |
+
item: an index-like object and will be used to index all the fields.
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
If `item` is a string, return the data in the corresponding field.
|
140 |
+
Otherwise, returns an `Instances` where all fields are indexed by `item`.
|
141 |
+
"""
|
142 |
+
if type(item) == int:
|
143 |
+
if item >= len(self) or item < -len(self):
|
144 |
+
raise IndexError("Instances index out of range!")
|
145 |
+
else:
|
146 |
+
item = slice(item, None, len(self))
|
147 |
+
|
148 |
+
ret = Instances(self._image_size)
|
149 |
+
for k, v in self._fields.items():
|
150 |
+
ret.set(k, v[item])
|
151 |
+
return ret
|
152 |
+
|
153 |
+
def __len__(self) -> int:
|
154 |
+
for v in self._fields.values():
|
155 |
+
            # use __len__ because len() has to be int and is not friendly to tracing
            return v.__len__()
        raise NotImplementedError("Empty Instances does not support __len__!")

    def __iter__(self):
        raise NotImplementedError("`Instances` object is not iterable!")

    @staticmethod
    def cat(instance_lists: List["Instances"]) -> "Instances":
        """
        Args:
            instance_lists (list[Instances])

        Returns:
            Instances
        """
        assert all(isinstance(i, Instances) for i in instance_lists)
        assert len(instance_lists) > 0
        if len(instance_lists) == 1:
            return instance_lists[0]

        image_size = instance_lists[0].image_size
        if not isinstance(image_size, torch.Tensor):  # could be a tensor in tracing
            for i in instance_lists[1:]:
                assert i.image_size == image_size
        ret = Instances(image_size)
        for k in instance_lists[0]._fields.keys():
            values = [i.get(k) for i in instance_lists]
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                values = torch.cat(values, dim=0)
            elif isinstance(v0, list):
                values = list(itertools.chain(*values))
            elif hasattr(type(v0), "cat"):
                values = type(v0).cat(values)
            else:
                raise ValueError("Unsupported type {} for concatenation".format(type(v0)))
            ret.set(k, values)
        return ret

    def __str__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self))
        s += "image_height={}, ".format(self._image_size[0])
        s += "image_width={}, ".format(self._image_size[1])
        s += "fields=[{}])".format(", ".join((f"{k}: {v}" for k, v in self._fields.items())))
        return s

    __repr__ = __str__


class BitMasks:
    """
    This class stores the segmentation masks for all objects in one image, in
    the form of bitmaps.

    Attributes:
        tensor: bool Tensor of N,H,W, representing N instances in the image.
    """

    def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
        """
        Args:
            tensor: bool Tensor of N,H,W, representing N instances in the image.
        """
        device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
        tensor = torch.as_tensor(tensor, dtype=torch.bool, device=device)
        assert tensor.dim() == 3, tensor.size()
        self.image_size = tensor.shape[1:]
        self.tensor = tensor

    def to(self, *args: Any, **kwargs: Any) -> "BitMasks":
        return BitMasks(self.tensor.to(*args, **kwargs))

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
        """
        Returns:
            BitMasks: Create a new :class:`BitMasks` by indexing.

        The following usage are allowed:

        1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask.
        2. `new_masks = masks[2:10]`: return a slice of masks.
        3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
           with `length = len(masks)`. Nonzero elements in the vector will be selected.

        Note that the returned object might share storage with this object,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            return BitMasks(self.tensor[item].unsqueeze(0))
        m = self.tensor[item]
        assert m.dim() == 3, "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
            item, m.shape
        )
        return BitMasks(m)

    def __iter__(self) -> torch.Tensor:
        yield from self.tensor

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.tensor))
        return s

    def __len__(self) -> int:
        return self.tensor.shape[0]

    def nonempty(self) -> torch.Tensor:
        """
        Find masks that are non-empty.

        Returns:
            Tensor: a BoolTensor which represents
            whether each mask is empty (False) or non-empty (True).
        """
        return self.tensor.flatten(1).any(dim=1)

    @staticmethod
    def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks":
        """
        Concatenates a list of BitMasks into a single BitMasks

        Arguments:
            bitmasks_list (list[BitMasks])

        Returns:
            BitMasks: the concatenated BitMasks
        """
        assert isinstance(bitmasks_list, (list, tuple))
        assert len(bitmasks_list) > 0
        assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list)

        cat_bitmasks = type(bitmasks_list[0])(torch.cat([bm.tensor for bm in bitmasks_list], dim=0))
        return cat_bitmasks
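
A minimal usage sketch of the BitMasks container defined above (hypothetical tensor shapes; it assumes the module defining BitMasks has already been imported into scope):

import torch

# two hypothetical per-image mask sets with 2 and 3 instances of size 32x32
masks_a = BitMasks(torch.zeros(2, 32, 32, dtype=torch.bool))
masks_b = BitMasks(torch.ones(3, 32, 32, dtype=torch.bool))

merged = BitMasks.cat([masks_a, masks_b])  # 5 masks sharing the same image size
print(len(merged))                         # 5
print(merged.nonempty())                   # tensor([False, False,  True,  True,  True])

Instances.cat dispatches to this BitMasks.cat through its hasattr(type(v0), "cat") branch, so mask fields are merged the same way as plain tensor fields.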
core/data/datasets/images/smpl_data_tools/__pycache__/_smpl.cpython-312.pyc
ADDED
Binary file (22.5 kB).

core/data/datasets/images/smpl_data_tools/__pycache__/config_smpl.cpython-312.pyc
ADDED
Binary file (2.35 kB).

core/data/datasets/images/smpl_data_tools/__pycache__/image_ops.cpython-312.pyc
ADDED
Binary file (12.1 kB).

core/data/datasets/images/smpl_data_tools/__pycache__/tsv_file.cpython-312.pyc
ADDED
Binary file (11.4 kB).
core/data/datasets/images/smpl_data_tools/_smpl.py
ADDED
@@ -0,0 +1,333 @@
"""
This file contains the definition of the SMPL model

It is adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/)
"""
from __future__ import division

import torch
import torch.nn as nn
import numpy as np
import scipy.sparse
import pickle

from . import config_smpl as cfg

def rodrigues(theta):
    """Convert axis-angle representation to rotation matrix.
    Args:
        theta: size = [B, 3]
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
    angle = torch.unsqueeze(l1norm, -1)
    normalized = torch.div(theta, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
    return quat2mat(quat)

def quat2mat(quat):
    """Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [B, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    norm_quat = quat
    norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w*x, w*y, w*z
    xy, xz, yz = x*y, x*z, y*z

    rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
                          2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
                          2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
    return rotMat

def orthographic_projection(X, camera):
    """Perform orthographic projection of 3D points X using the camera parameters
    Args:
        X: size = [B, N, 3]
        camera: size = [B, 3]
    Returns:
        Projected 2D points -- size = [B, N, 2]
    """
    camera = camera.view(-1, 1, 3)
    X_trans = X[:, :, :2] + camera[:, :, 1:]
    shape = X_trans.shape
    X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
    return X_2d

class SMPL(nn.Module):

    def __init__(self, gender='neutral'):
        super(SMPL, self).__init__()

        if gender=='m':
            model_file=cfg.SMPL_Male
        elif gender=='f':
            model_file=cfg.SMPL_Female
        else:
            model_file=cfg.SMPL_FILE

        smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1')
        J_regressor = smpl_model['J_regressor'].tocoo()
        row = J_regressor.row
        col = J_regressor.col
        data = J_regressor.data
        i = torch.LongTensor([row, col])
        v = torch.FloatTensor(data)
        J_regressor_shape = [24, 6890]
        self.register_buffer('J_regressor', torch.sparse.FloatTensor(i, v, J_regressor_shape).to_dense()) # 24*6890
        self.register_buffer('weights', torch.FloatTensor(smpl_model['weights'])) # 6890*24
        self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs'])) # 6890*3*207
        self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template'])) # 6890*3
        self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs']))) # 6890*3*10
        self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64))) # 13776*3
        self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64))) # 2*24
        id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
        self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))

        self.pose_shape = [24, 3]
        self.beta_shape = [10]
        self.translation_shape = [3]

        self.pose = torch.zeros(self.pose_shape)
        self.beta = torch.zeros(self.beta_shape)
        self.translation = torch.zeros(self.translation_shape)

        self.verts = None
        self.J = None
        self.R = None

        J_regressor_extra = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float() # 14*6890
        self.register_buffer('J_regressor_extra', J_regressor_extra)
        self.joints_idx = cfg.JOINTS_IDX

        J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float() # 17*6890
        self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct)


    def forward(self, pose, beta):
        device = pose.device
        batch_size = pose.shape[0]
        v_template = self.v_template[None, :]
        shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1)
        beta = beta[:, :, None]
        # print(f'pose device {pose.device} beta device {beta.device} smpl parameter device {shapedirs.device}')
        v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template
        # batched sparse matmul not supported in pytorch
        J = []
        for i in range(batch_size):
            J.append(torch.matmul(self.J_regressor, v_shaped[i]))
        J = torch.stack(J, dim=0)
        # input is rotmat: (bs,24,3,3)
        if pose.ndimension() == 4:
            R = pose
        # input is axis-angle: (bs,72)
        elif pose.ndimension() == 2:
            pose_cube = pose.view(-1, 3) # (batch_size * 24, 3)
            R = rodrigues(pose_cube).view(batch_size, 24, 3, 3)
            R = R.view(batch_size, 24, 3, 3)
        I_cube = torch.eye(3)[None, None, :].to(device)

        lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1)
        posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1)
        v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3)
        J_ = J.clone()
        J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
        G_ = torch.cat([R, J_[:, :, :, None]], dim=-1)
        pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1)
        G_ = torch.cat([G_, pad_row], dim=2)
        G = [G_[:, 0].clone()]
        for i in range(1, 24):
            G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :]))
        G = torch.stack(G, dim=1)

        rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1)
        zeros = torch.zeros(batch_size, 24, 4, 3).to(device)
        rest = torch.cat([zeros, rest], dim=-1)
        rest = torch.matmul(G, rest)
        G = G - rest
        T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1)
        rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1)
        v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
        return v

    def get_joints(self, vertices):
        """
        This method is used to get the joint locations from the SMPL mesh
        Input:
            vertices: size = (B, 6890, 3)
        Output:
            3D joints: size = (B, 38, 3)
        """
        joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor])
        joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra])
        joints = torch.cat((joints, joints_extra), dim=1)
        joints = joints[:, cfg.JOINTS_IDX]
        return joints

    def get_h36m_joints(self, vertices):
        """
        This method is used to get the joint locations from the SMPL mesh
        Input:
            vertices: size = (B, 6890, 3)
        Output:
            3D joints: size = (B, 17, 3)
        """
        joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
        return joints

class SparseMM(torch.autograd.Function):
    """Redefine sparse @ dense matrix multiplication to enable backpropagation.
    The builtin matrix multiplication operation does not support backpropagation in some cases.
    """
    @staticmethod
    def forward(ctx, sparse, dense):
        ctx.req_grad = dense.requires_grad
        ctx.save_for_backward(sparse)
        return torch.matmul(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        sparse, = ctx.saved_tensors
        if ctx.req_grad:
            grad_input = torch.matmul(sparse.t(), grad_output)
        return None, grad_input

def spmm(sparse, dense):
    return SparseMM.apply(sparse, dense)


def scipy_to_pytorch(A, U, D):
    """Convert scipy sparse matrices to pytorch sparse matrix."""
    ptU = []
    ptD = []

    for i in range(len(U)):
        u = scipy.sparse.coo_matrix(U[i])
        i = torch.LongTensor(np.array([u.row, u.col]))
        v = torch.FloatTensor(u.data)
        # return index value and shape instead of a sparse tensor to avoid bug in multi-worker
        ptU.append([i, v, u.shape])
    for i in range(len(D)):
        d = scipy.sparse.coo_matrix(D[i])
        i = torch.LongTensor(np.array([d.row, d.col]))
        v = torch.FloatTensor(d.data)
        # return index value and shape instead of a sparse tensor to avoid bug in multi-worker
        ptD.append([i, v, d.shape])

    return ptU, ptD


def adjmat_sparse(adjmat, nsize=1):
    """Create row-normalized sparse graph adjacency matrix."""
    adjmat = scipy.sparse.csr_matrix(adjmat)
    if nsize > 1:
        orig_adjmat = adjmat.copy()
        for _ in range(1, nsize):
            adjmat = adjmat * orig_adjmat
    adjmat.data = np.ones_like(adjmat.data)
    for i in range(adjmat.shape[0]):
        adjmat[i,i] = 1
    num_neighbors = np.array(1 / adjmat.sum(axis=-1))
    adjmat = adjmat.multiply(num_neighbors)
    adjmat = scipy.sparse.coo_matrix(adjmat)
    row = adjmat.row
    col = adjmat.col
    data = adjmat.data
    i = torch.LongTensor(np.array([row, col]))
    v = torch.from_numpy(data).float()
    # adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape)
    # return index value and shape instead of a sparse tensor to avoid bug in multi-worker
    return [i, v, adjmat.shape]

def get_graph_params(filename, nsize=1):
    """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
    data = np.load(filename, encoding='latin1', allow_pickle=True)
    A = data['A']
    U = data['U']
    D = data['D']
    U, D = scipy_to_pytorch(A, U, D)
    A = [adjmat_sparse(a, nsize=nsize) for a in A]
    return A, U, D


class Mesh(object):
    """Mesh object that is used for handling certain graph operations."""
    def __init__(self, filename=cfg.SMPL_sampling_matrix,
                 num_downsampling=1, nsize=1, device=torch.device('cuda')):
        super(Mesh, self).__init__()

        self.device = device
        self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
        # self._A = [a.to(device) for a in self._A]
        # self._U = [u.to(device) for u in self._U]
        # self._D = [d.to(device) for d in self._D]

        self.num_downsampling = num_downsampling

        # load template vertices from SMPL and normalize them
        smpl = SMPL()
        ref_vertices = smpl.v_template
        center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
        ref_vertices -= center
        ref_vertices /= ref_vertices.abs().max().item()

        self._ref_vertices = ref_vertices.to(device)
        self.faces = smpl.faces.int().to(device)
        self.sparse = False

    @property
    def ref_vertices(self):
        """Return the template vertices at the specified subsampling level."""
        _D = [torch.sparse.FloatTensor(item[0], item[1], item[2]).to(self.device) for item in self._D]
        ref_vertices = self._ref_vertices
        for i in range(self.num_downsampling):
            ref_vertices = torch.spmm(_D[i], ref_vertices)
        return ref_vertices

    def downsample(self, x, n1=0, n2=None):
        """Downsample mesh."""
        _D = [torch.sparse.FloatTensor(item[0], item[1], item[2]).to(self.device) for item in self._D]
        if n2 is None:
            n2 = self.num_downsampling
        if x.ndimension() < 3:
            for i in range(n1, n2):
                x = spmm(_D[i], x)
        elif x.ndimension() == 3:
            out = []
            for i in range(x.shape[0]):
                y = x[i]
                for j in range(n1, n2):
                    y = spmm(_D[j], y)
                out.append(y)
            x = torch.stack(out, dim=0)
        return x

    def upsample(self, x, n1=1, n2=0):
        """Upsample mesh."""
        _U = [torch.sparse.FloatTensor(item[0], item[1], item[2]).to(self.device) for item in self._U]
        if x.ndimension() < 3:
            for i in reversed(range(n2, n1)):
                x = spmm(_U[i], x)
        elif x.ndimension() == 3:
            out = []
            for i in range(x.shape[0]):
                y = x[i]
                for j in reversed(range(n2, n1)):
                    y = spmm(_U[j], y)
                out.append(y)
            x = torch.stack(out, dim=0)
        return x
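
A minimal usage sketch of the SMPL module defined above (a sketch only: it assumes the SMPL model pickles and joint-regressor files referenced in config_smpl.py are present on disk and that the repository root is on PYTHONPATH):

import torch
from core.data.datasets.images.smpl_data_tools._smpl import SMPL, orthographic_projection

smpl = SMPL()                    # neutral-gender model
pose = torch.zeros(2, 72)        # axis-angle pose (B, 24*3); zeros give the T-pose
beta = torch.zeros(2, 10)        # shape coefficients (B, 10); zeros give the mean shape
vertices = smpl(pose, beta)      # (B, 6890, 3) posed mesh vertices

joints_3d = smpl.get_joints(vertices)           # (B, len(JOINTS_IDX), 3) body joints
camera = torch.tensor([[1.0, 0.0, 0.0],
                       [1.0, 0.0, 0.0]])        # weak-perspective (scale, tx, ty)
joints_2d = orthographic_projection(joints_3d, camera)  # (B, N, 2) projected joints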
core/data/datasets/images/smpl_data_tools/config_smpl.py
ADDED
@@ -0,0 +1,53 @@
"""
This file contains definitions of useful data structures and the paths
for the datasets and data files necessary to run the code.

Adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)

"""

from os.path import join,split,abspath
# from os import getcwd
import sys

dirname, filename = split(abspath(sys.argv[0]))

folder_path = join(dirname,'core/data/datasets/images/smpl_data_tools/smpl_modeling/')
# print("current path {} ".format(folder_path))
JOINT_REGRESSOR_TRAIN_EXTRA = folder_path + 'data/J_regressor_extra.npy'
JOINT_REGRESSOR_H36M_correct = folder_path + 'data/J_regressor_h36m_correct.npy'
SMPL_FILE = folder_path + 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'
SMPL_Male = folder_path + 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl'
SMPL_Female = folder_path + 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl'
SMPL_sampling_matrix = folder_path + 'data/mesh_downsampling.npz'
MANO_FILE = folder_path + 'data/MANO_RIGHT.pkl'
MANO_sampling_matrix = folder_path + 'data/mano_downsampling.npz'

JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27]


"""
We follow the body joint definition, loss functions, and evaluation metrics from
open source project GraphCMR (https://github.com/nkolot/GraphCMR/)

Each dataset uses different sets of joints.
We use a superset of 24 joints such that we include all joints from every dataset.
If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
The joints used here are:
"""
J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
            'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear')
H36M_J17_NAME = ( 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head',
                  'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist')
J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10]

"""
We follow the hand joint definition and mesh topology from
open source project Manopth (https://github.com/hassony2/manopth)

The hand joints used here are:
"""
J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
          'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
ROOT_INDEX = 0
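
A short sketch of how the joint-index tables above are typically used (pure indexing on hypothetical arrays; no model files are required):

import numpy as np
from core.data.datasets.images.smpl_data_tools import config_smpl as cfg

joints_24 = np.random.randn(24, 3)                      # joints in the 24-joint superset order
joints_14 = joints_24[cfg.J24_TO_J14]                   # common 14-joint evaluation subset
names_14 = [cfg.J24_NAME[i] for i in cfg.J24_TO_J14]    # 'R_Ankle', 'R_Knee', ..., 'Head'

h36m_17 = np.random.randn(17, 3)                        # joints in the Human3.6M ordering
h36m_14 = h36m_17[cfg.H36M_J17_TO_J14]                  # the same 14 joints for evaluation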
core/data/datasets/images/smpl_data_tools/image_ops.py
ADDED
@@ -0,0 +1,230 @@
# ----------------------------------------------------------------------------------------------
# METRO (https://github.com/microsoft/MeshTransformer)
# Copyright (c) Microsoft Corporation. All Rights Reserved [see https://github.com/microsoft/MeshTransformer/blob/main/LICENSE for details]
# Licensed under the MIT license.
# ----------------------------------------------------------------------------------------------
"""
Image processing tools
Modified from open source projects:
(https://github.com/nkolot/GraphCMR/)
(https://github.com/open-mmlab/mmdetection)
"""

import numpy as np
import base64
import cv2
import torch
import scipy.misc

def img_from_base64(imagestring):
    try:
        jpgbytestring = base64.b64decode(imagestring)
        nparr = np.frombuffer(jpgbytestring, np.uint8)
        r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        return r
    except ValueError:
        return None

def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False):
    if center is not None and auto_bound:
        raise ValueError('`auto_bound` conflicts with `center`')
    try:
        h, w = img.shape[:2]
    except:
        h, w = img.size[:2]
    if center is None:
        center = ((w - 1) * 0.5, (h - 1) * 0.5)
    assert isinstance(center, tuple)

    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    if auto_bound:
        cos = np.abs(matrix[0, 0])
        sin = np.abs(matrix[0, 1])
        new_w = h * sin + w * cos
        new_h = h * cos + w * sin
        matrix[0, 2] += (new_w - w) * 0.5
        matrix[1, 2] += (new_h - h) * 0.5
        w = int(np.round(new_w))
        h = int(np.round(new_h))
    rotated = cv2.warpAffine(img, matrix, (w, h), borderValue=border_value)
    return rotated

def myimresize(img, size, return_scale=False, interpolation='bilinear'):

    try:
        h, w = img.shape[:2]
    except:
        h, w = img.size[:2]
    resized_img = cv2.resize(
        img, (size[0],size[1]), interpolation=cv2.INTER_LINEAR)
    if not return_scale:
        return resized_img
    else:
        w_scale = size[0] / w
        h_scale = size[1] / h
        return resized_img, w_scale, h_scale


def get_transform(center, scale, res, rot=0):
    """Generate transformation matrix."""
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if not rot == 0:
        rot = -rot # To match direction of rotation from cropping
        rot_mat = np.zeros((3,3))
        rot_rad = rot * np.pi / 180
        sn,cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0,:2] = [cs, -sn]
        rot_mat[1,:2] = [sn, cs]
        rot_mat[2,2] = 1
        # Need to rotate around center
        t_mat = np.eye(3)
        t_mat[0,2] = -res[1]/2
        t_mat[1,2] = -res[0]/2
        t_inv = t_mat.copy()
        t_inv[:2,2] *= -1
        t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
    return t

def transform(pt, center, scale, res, invert=0, rot=0):
    """Transform pixel location to different reference."""
    t = get_transform(center, scale, res, rot=rot)
    if invert:
        # t = np.linalg.inv(t)
        t_torch = torch.from_numpy(t)
        t_torch = torch.inverse(t_torch)
        t = t_torch.numpy()
    new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2].astype(int)+1

def crop(img, center, scale, res, rot=0):
    """Crop image according to the supplied bounding box."""
    # Upper left point
    ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
    # Bottom right point
    br = np.array(transform([res[0]+1,
                             res[1]+1], center, scale, res, invert=1))-1
    # Padding so that when rotated proper amount of context is included
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if not rot == 0:
        ul -= pad
        br += pad
    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    try:
        image_shape = img.shape
    except:
        image_shape = img.size
    if len(image_shape) > 2:
        new_shape += [image_shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])

    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
                                                        old_x[0]:old_x[1]]
    if not rot == 0:
        # Remove padding
        # new_img = scipy.misc.imrotate(new_img, rot)
        new_img = myimrotate(new_img, rot)
        new_img = new_img[pad:-pad, pad:-pad]

    # new_img = scipy.misc.imresize(new_img, res)
    new_img = myimresize(new_img, [res[0], res[1]])
    return new_img

def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
    """'Undo' the image cropping/resizing.
    This function is used when evaluating mask/part segmentation.
    """
    try:
        res = img.shape[:2]
    except:
        res = img.size[:2]
    # Upper left point
    ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
    # Bottom right point
    br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1
    # size of cropped image
    crop_shape = [br[1] - ul[1], br[0] - ul[0]]

    new_shape = [br[1] - ul[1], br[0] - ul[0]]

    try:
        image_shape = img.shape
    except:
        image_shape = img.size

    if len(image_shape) > 2:
        new_shape += [image_shape[2]]
    new_img = np.zeros(orig_shape, dtype=np.uint8)
    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(orig_shape[1], br[0])
    old_y = max(0, ul[1]), min(orig_shape[0], br[1])
    # img = scipy.misc.imresize(img, crop_shape, interp='nearest')
    img = myimresize(img, [crop_shape[0],crop_shape[1]])
    new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
    return new_img

def rot_aa(aa, rot):
    """Rotate axis angle parameters."""
    # pose parameters
    R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
                  [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
                  [0, 0, 1]])
    # find the rotation of the body in camera frame
    per_rdg, _ = cv2.Rodrigues(aa)
    # apply the global rotation to the global orientation
    resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
    aa = (resrot.T)[0]
    return aa

def flip_img(img):
    """Flip rgb images or masks.
    channels come last, e.g. (256,256,3).
    """
    img = np.fliplr(img)
    return img

def flip_kp(kp):
    """Flip keypoints."""
    flipped_parts = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
    kp = kp[flipped_parts]
    kp[:,0] = - kp[:,0]
    return kp

def flip_pose(pose):
    """Flip pose.
    The flipping is based on SMPL parameters.
    """
    flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
                    14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
                    34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
                    45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
                    56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
    pose = pose[flippedParts]
    # we also negate the second and the third dimension of the axis-angle
    pose[1::3] = -pose[1::3]
    pose[2::3] = -pose[2::3]
    return pose

def flip_aa(aa):
    """Flip axis-angle representation.
    We negate the second and the third dimension of the axis-angle.
    """
    aa[1] = -aa[1]
    aa[2] = -aa[2]
    return aa
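
A minimal sketch of the crop/uncrop round trip above (hypothetical image and bounding-box values; the scale argument follows the person-height/200 convention used by get_transform, and it assumes the module's dependencies such as cv2 import cleanly):

import numpy as np
from core.data.datasets.images.smpl_data_tools.image_ops import crop, uncrop

img = np.zeros((480, 640, 3), dtype=np.uint8)       # hypothetical H x W x 3 image
center, scale = [320, 240], 1.2                     # person centre (px) and height / 200

patch = crop(img, center, scale, res=[224, 224])    # (224, 224, 3) person-centred crop
restored = uncrop(patch.astype(np.uint8), center, scale, (480, 640, 3))  # back to full frame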
core/data/datasets/images/smpl_data_tools/smpl_modeling/data/J_regressor_extra.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:40dfaa71fcc7eed6966a6ed046311b7e8ea0eb9a5172b298e3df6fc4b6ec0eb0
size 771808