tuandunghcmut committed
Commit 345ee20 · verified · 1 Parent(s): 1c3e162

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +8 -0
  2. LICENSE +21 -0
  3. assets/framework.png +0 -0
  4. assets/teaser.png +0 -0
  5. core/__init__.py +0 -0
  6. core/__pycache__/__init__.cpython-312.pyc +0 -0
  7. core/__pycache__/comm_.cpython-312.pyc +0 -0
  8. core/__pycache__/config.cpython-312.pyc +0 -0
  9. core/__pycache__/distributed_utils.cpython-312.pyc +0 -0
  10. core/__pycache__/make_param_group.cpython-312.pyc +0 -0
  11. core/__pycache__/memory.cpython-312.pyc +0 -0
  12. core/__pycache__/utils.cpython-312.pyc +0 -0
  13. core/clipping.py +92 -0
  14. core/comm_.py +307 -0
  15. core/config.py +600 -0
  16. core/data/__init__.py +0 -0
  17. core/data/__pycache__/__init__.cpython-312.pyc +0 -0
  18. core/data/datasets/__init__.py +15 -0
  19. core/data/datasets/__pycache__/__init__.cpython-312.pyc +0 -0
  20. core/data/datasets/images/__pycache__/image_caption_dataset.cpython-312.pyc +0 -0
  21. core/data/datasets/images/__pycache__/multi_posedataset.cpython-312.pyc +0 -0
  22. core/data/datasets/images/__pycache__/parsing_dataset.cpython-312.pyc +0 -0
  23. core/data/datasets/images/__pycache__/pedattr_dataset.cpython-312.pyc +0 -0
  24. core/data/datasets/images/__pycache__/peddet_dataset_v2.cpython-312.pyc +0 -0
  25. core/data/datasets/images/__pycache__/pos_dataset_dev.cpython-312.pyc +0 -0
  26. core/data/datasets/images/__pycache__/seg_dataset_dev.cpython-312.pyc +0 -0
  27. core/data/datasets/images/__pycache__/smpl_dataset_v2.cpython-312.pyc +0 -0
  28. core/data/datasets/images/image_caption_dataset.py +261 -0
  29. core/data/datasets/images/multi_posedataset.py +413 -0
  30. core/data/datasets/images/parsing_dataset.py +1084 -0
  31. core/data/datasets/images/pedattr_dataset.py +665 -0
  32. core/data/datasets/images/peddet_dataset_v2.py +578 -0
  33. core/data/datasets/images/pos_dataset_dev.py +713 -0
  34. core/data/datasets/images/resources/CHval.odgt +3 -0
  35. core/data/datasets/images/resources/COCO_val2017_detections_AP_H_56_person.json +3 -0
  36. core/data/datasets/images/resources/mpii_gt_val.mat +3 -0
  37. core/data/datasets/images/resources/test_caltech_heavy_1xnew.odgt +0 -0
  38. core/data/datasets/images/seg_data_tools/__init__.py +0 -0
  39. core/data/datasets/images/seg_data_tools/collate.py +143 -0
  40. core/data/datasets/images/seg_data_tools/cv2_aug_transforms.py +889 -0
  41. core/data/datasets/images/seg_data_tools/transforms.py +106 -0
  42. core/data/datasets/images/seg_dataset_dev.py +293 -0
  43. core/data/datasets/images/smpl_data_tools/__pycache__/_smpl.cpython-312.pyc +0 -0
  44. core/data/datasets/images/smpl_data_tools/__pycache__/config_smpl.cpython-312.pyc +0 -0
  45. core/data/datasets/images/smpl_data_tools/__pycache__/image_ops.cpython-312.pyc +0 -0
  46. core/data/datasets/images/smpl_data_tools/__pycache__/tsv_file.cpython-312.pyc +0 -0
  47. core/data/datasets/images/smpl_data_tools/_smpl.py +333 -0
  48. core/data/datasets/images/smpl_data_tools/config_smpl.py +53 -0
  49. core/data/datasets/images/smpl_data_tools/image_ops.py +230 -0
  50. core/data/datasets/images/smpl_data_tools/smpl_modeling/data/J_regressor_extra.npy +3 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ core/data/datasets/images/resources/CHval.odgt filter=lfs diff=lfs merge=lfs -text
+ core/data/datasets/images/resources/COCO_val2017_detections_AP_H_56_person.json filter=lfs diff=lfs merge=lfs -text
+ core/data/datasets/images/resources/mpii_gt_val.mat filter=lfs diff=lfs merge=lfs -text
+ core/solvers/utils/pycocoevalcap/meteor/meteor-1.5.jar filter=lfs diff=lfs merge=lfs -text
+ core/solvers/utils/pycocoevalcap/spice/lib/Meteor-1.5.jar filter=lfs diff=lfs merge=lfs -text
+ core/solvers/utils/pycocoevalcap/spice/lib/guava-19.0.jar filter=lfs diff=lfs merge=lfs -text
+ core/solvers/utils/pycocoevalcap/spice/spice-1.0.jar filter=lfs diff=lfs merge=lfs -text
+ core/solvers/utils/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Shanghai AI Laboratory
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
assets/framework.png ADDED
assets/teaser.png ADDED
core/__init__.py ADDED
File without changes
core/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (183 Bytes).
core/__pycache__/comm_.cpython-312.pyc ADDED
Binary file (11.9 kB).
core/__pycache__/config.cpython-312.pyc ADDED
Binary file (26 kB).
core/__pycache__/distributed_utils.cpython-312.pyc ADDED
Binary file (91.8 kB).
core/__pycache__/make_param_group.cpython-312.pyc ADDED
Binary file (4.17 kB).
core/__pycache__/memory.cpython-312.pyc ADDED
Binary file (3.77 kB).
core/__pycache__/utils.cpython-312.pyc ADDED
Binary file (57.9 kB).
core/clipping.py ADDED
@@ -0,0 +1,92 @@
+ import time
+
+ import numpy as np
+ import torch
+
+ from easydict import EasyDict as edict
+
+ import warnings
+ from torch._six import inf
+ from core.utils import sync_print
+
+ # return if any inf/nan
+ # div norm by loss_scale, for 'real' norm
+ # if auto_clipper provided, compute max_norm using auto_clipper
+ # else, use the given max_norm
+ def clip_grad_norm_(parameters, max_norm=1000000, norm_type=2, auto_clipper=None, loss_scale=1.0):
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     parameters = list(filter(lambda p: p[1].grad is not None, parameters))
+
+     if len(parameters) == 0: return None
+
+     max_norm = float(max_norm)
+     norm_type = float(norm_type)
+     if norm_type == inf:
+         total_norm = max(p.grad.data.abs().max() for _, p in parameters)
+     else:
+         total_norm = 0
+         for name, p in parameters:
+             param_norm = p.grad.data.norm(norm_type)
+             total_norm += param_norm.item() ** norm_type
+
+         total_norm = total_norm ** (1. / norm_type)
+
+     # check inf/nan
+     overflow_num = torch.zeros(1)
+     if np.isinf(total_norm) or np.isnan(total_norm):
+         overflow_num[0] = 1
+     torch.distributed.all_reduce(overflow_num)
+
+     if overflow_num > 0:
+         for name, p in parameters:
+             p.grad.data.fill_(float('nan'))
+         sync_print('total_norm is inf({})/nan({}), skip clipping!!!'.format(np.isinf(total_norm), np.isnan(total_norm)))
+         return total_norm
+
+     # rescale the total_norm by loss_scale
+     total_norm /= loss_scale
+
+     # update auto_clipper, compute max_norm
+     if auto_clipper is not None:
+         max_norm = auto_clipper.update(total_norm)
+
+     # do clipping
+     clip_coef = max_norm / (total_norm + 1e-6)
+     if clip_coef < 1:
+         # sync_print('clip_coef: {}'.format(clip_coef))
+         for _, p in parameters:
+             p.grad.data.mul_(clip_coef)
+
+     return total_norm
+
+ class ClipMeter(object):
+     def __init__(self, mom=None, thresh=None, min_max=False, mean=False, init=False):
+         self.thresh = thresh
+         self.mom = mom
+         self.min_max = min_max
+         self.mean = mean
+         self.val = 1.0
+         self.init = init
+
+     def get_mean(self):
+         return self.val
+
+     def get_clip_val(self):
+         if self.mean:
+             return self.get_mean()
+         else:
+             return self.get_mean() * (1+self.thresh)
+
+     def update(self, x):
+         if self.init:
+             self.val = x
+             self.init = False
+         mean = self.get_mean()
+         if self.min_max:
+             x = max(min(x, mean*(1+self.thresh)), mean*(1-self.thresh))
+         else:
+             x = min(x, mean*(1+self.thresh))
+
+         self.val = self.mom * self.val + (1-self.mom)*x
+         return self.get_clip_val()
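A minimal sketch (not part of the commit) of how clip_grad_norm_ and ClipMeter might be wired into a training step. The model, optimizer, and hyper-parameters are made up, and an initialized torch.distributed process group is assumed, since the overflow check all-reduces a flag across ranks:

import torch
from core.clipping import clip_grad_norm_, ClipMeter

model = torch.nn.Linear(16, 4)                          # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# ClipMeter keeps a running mean of the gradient norm and returns
# mean * (1 + thresh) as the adaptive clipping threshold
auto_clipper = ClipMeter(mom=0.9, thresh=0.5, min_max=True, init=True)

def train_step(inputs, targets):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    # clip_grad_norm_ expects (name, parameter) pairs, hence named_parameters()
    total_norm = clip_grad_norm_(list(model.named_parameters()),
                                 auto_clipper=auto_clipper, loss_scale=1.0)
    optimizer.step()
    return loss.item(), total_norm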
core/comm_.py ADDED
@@ -0,0 +1,307 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ """
3
+ This file contains primitives for multi-gpu communication.
4
+ This is useful when doing distributed training.
5
+ """
6
+
7
+ import functools
8
+ import logging
9
+ import numpy as np
10
+ import pickle
11
+ import torch
12
+ import torch.distributed as dist
13
+
14
+ _LOCAL_PROCESS_GROUP = None
15
+
16
+ _CAPTION_GEN_MODE = False
17
+
18
+ temp_dir = TEMP_DIR = './data/temp'
19
+ IDS = 'IDS'
20
+ image_features = 'image_features'
21
+ text_features = 'text_features'
22
+
23
+ old_checkpoint = True
24
+
25
+ """
26
+ A torch process group which only includes processes that are on the same machine as the current process.
27
+ This variable is set when processes are spawned by `launch()` in "engine/launch.py".
28
+ """
29
+
30
+
31
+ def is_dist_avail_and_initialized():
32
+ if not dist.is_available():
33
+ return False
34
+ if not dist.is_initialized():
35
+ return False
36
+ return True
37
+
38
+
39
+ def get_world_size() -> int:
40
+ if not dist.is_available():
41
+ return 1
42
+ if not dist.is_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank() -> int:
48
+ if not dist.is_available():
49
+ return 0
50
+ if not dist.is_initialized():
51
+ return 0
52
+ return dist.get_rank()
53
+
54
+
55
+ def get_local_rank() -> int:
56
+ """
57
+ Returns:
58
+ The rank of the current process within the local (per-machine) process group.
59
+ """
60
+ if not dist.is_available():
61
+ return 0
62
+ if not dist.is_initialized():
63
+ return 0
64
+ # assert _LOCAL_PROCESS_GROUP is not None
65
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
66
+
67
+
68
+ def get_local_size() -> int:
69
+ """
70
+ Returns:
71
+ The size of the per-machine process group,
72
+ i.e. the number of processes per machine.
73
+ """
74
+ if not dist.is_available():
75
+ return 1
76
+ if not dist.is_initialized():
77
+ return 1
78
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
79
+
80
+
81
+ def is_main_process() -> bool:
82
+ return get_rank() == 0
83
+
84
+
85
+ def synchronize():
86
+ """
87
+ Helper function to synchronize (barrier) among all processes when
88
+ using distributed training
89
+ """
90
+ if not dist.is_available():
91
+ return
92
+ if not dist.is_initialized():
93
+ return
94
+ world_size = dist.get_world_size()
95
+ if world_size == 1:
96
+ return
97
+ dist.barrier()
98
+
99
+
100
+ @functools.lru_cache()
101
+ def _get_global_gloo_group():
102
+ """
103
+ Return a process group based on the gloo backend, containing all the ranks.
104
+ The result is cached.
105
+ """
106
+ if dist.get_backend() == "nccl":
107
+ return dist.new_group(backend="gloo")
108
+ else:
109
+ return dist.group.WORLD
110
+
111
+
112
+ def _serialize_to_tensor(data, group):
113
+ backend = dist.get_backend(group)
114
+ assert backend in ["gloo", "nccl"]
115
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
116
+
117
+ buffer = pickle.dumps(data)
118
+ if len(buffer) > 1024 ** 3:
119
+ logger = logging.getLogger(__name__)
120
+ logger.warning(
121
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
122
+ get_rank(), len(buffer) / (1024 ** 3), device
123
+ )
124
+ )
125
+ storage = torch.ByteStorage.from_buffer(buffer)
126
+ tensor = torch.ByteTensor(storage).to(device=device)
127
+ return tensor
128
+
129
+
130
+ def _pad_to_largest_tensor(tensor, group):
131
+ """
132
+ Returns:
133
+ list[int]: size of the tensor, on each rank
134
+ Tensor: padded tensor that has the max size
135
+ """
136
+ world_size = dist.get_world_size(group=group)
137
+ assert (
138
+ world_size >= 1
139
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
140
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
141
+ size_list = [
142
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
143
+ ]
144
+ dist.all_gather(size_list, local_size, group=group)
145
+ size_list = [int(size.item()) for size in size_list]
146
+
147
+ max_size = max(size_list)
148
+
149
+ # we pad the tensor because torch all_gather does not support
150
+ # gathering tensors of different shapes
151
+ if local_size != max_size:
152
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
153
+ tensor = torch.cat((tensor, padding), dim=0)
154
+ return size_list, tensor
155
+
156
+
157
+ def all_gather(data, group=None):
158
+ """
159
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
160
+ Args:
161
+ data: any picklable object
162
+ group: a torch process group. By default, will use a group which
163
+ contains all ranks on gloo backend.
164
+ Returns:
165
+ list[data]: list of data gathered from each rank
166
+ """
167
+ if get_world_size() == 1:
168
+ return [data]
169
+ if group is None:
170
+ group = _get_global_gloo_group()
171
+ if dist.get_world_size(group) == 1:
172
+ return [data]
173
+
174
+ tensor = _serialize_to_tensor(data, group)
175
+
176
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
177
+ max_size = max(size_list)
178
+
179
+ # receiving Tensor from all ranks
180
+ tensor_list = [
181
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
182
+ ]
183
+ dist.all_gather(tensor_list, tensor, group=group)
184
+
185
+ data_list = []
186
+ for size, tensor in zip(size_list, tensor_list):
187
+ buffer = tensor.cpu().numpy().tobytes()[:size]
188
+ data_list.append(pickle.loads(buffer))
189
+
190
+ return data_list
191
+
192
+
193
+ def gather(data, dst=0, group=None):
194
+ """
195
+ Run gather on arbitrary picklable data (not necessarily tensors).
196
+ Args:
197
+ data: any picklable object
198
+ dst (int): destination rank
199
+ group: a torch process group. By default, will use a group which
200
+ contains all ranks on gloo backend.
201
+ Returns:
202
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
203
+ an empty list.
204
+ """
205
+ if get_world_size() == 1:
206
+ return [data]
207
+ if group is None:
208
+ group = _get_global_gloo_group()
209
+ if dist.get_world_size(group=group) == 1:
210
+ return [data]
211
+ rank = dist.get_rank(group=group)
212
+
213
+ tensor = _serialize_to_tensor(data, group)
214
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
215
+
216
+ # receiving Tensor from all ranks
217
+ if rank == dst:
218
+ max_size = max(size_list)
219
+ tensor_list = [
220
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
221
+ ]
222
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
223
+
224
+ data_list = []
225
+ for size, tensor in zip(size_list, tensor_list):
226
+ buffer = tensor.cpu().numpy().tobytes()[:size]
227
+ data_list.append(pickle.loads(buffer))
228
+ return data_list
229
+ else:
230
+ dist.gather(tensor, [], dst=dst, group=group)
231
+ return []
232
+
233
+
234
+ def broadcast_object(data, src=0, group=None):
235
+ """
236
+ Run gather on arbitrary picklable data (not necessarily tensors).
237
+ Args:
238
+ data: any picklable object
239
+ dst (int): destination rank
240
+ group: a torch process group. By default, will use a group which
241
+ contains all ranks on gloo backend.
242
+ Returns:
243
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
244
+ an empty list.
245
+ """
246
+ # if get_world_size() == 1:
247
+ # return data
248
+ # if group is None:
249
+ # group = _get_global_gloo_group()
250
+ # if dist.get_world_size(group=group) == 1:
251
+ # return data
252
+
253
+ if not isinstance(data, list):
254
+ data_list = [data]
255
+ dist.broadcast_object_list(data_list, src=src, group=group)
256
+ return data_list[0]
257
+ else:
258
+ dist.broadcast_object_list(data, src=src, group=group)
259
+ return data
260
+ return data  # unreachable: both branches above already return
261
+
262
+
263
+ def shared_random_seed():
264
+ """
265
+ Returns:
266
+ int: a random number that is the same across all workers.
267
+ If workers need a shared RNG, they can use this shared seed to
268
+ create one.
269
+ All workers must call this function, otherwise it will deadlock.
270
+ """
271
+ ints = np.random.randint(2 ** 31)
272
+ all_ints = all_gather(ints)
273
+ return all_ints[0]
274
+
275
+
276
+ def reduce_dict(input_dict, average=True):
277
+ """
278
+ Reduce the values in the dictionary from all processes so that process with rank
279
+ 0 has the reduced results.
280
+ Args:
281
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
282
+ average (bool): whether to do average or sum
283
+ Returns:
284
+ a dict with the same keys as input_dict, after reduction.
285
+ """
286
+ world_size = get_world_size()
287
+ if world_size < 2:
288
+ return input_dict
289
+ with torch.no_grad():
290
+ names = []
291
+ values = []
292
+ # sort the keys so that they are consistent across processes
293
+ for k in sorted(input_dict.keys()):
294
+ names.append(k)
295
+ values.append(input_dict[k])
296
+ values = torch.stack(values, dim=0)
297
+ dist.reduce(values, dst=0)
298
+ if dist.get_rank() == 0 and average:
299
+ # only main process gets accumulated, so only divide by
300
+ # world_size in this case
301
+ values /= world_size
302
+ reduced_dict = {k: v for k, v in zip(names, values)}
303
+ return reduced_dict
304
+
305
+
306
+ def unwrap_model(model):
307
+ return model.module if hasattr(model, 'module') else model
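A small sketch (not part of the commit) showing typical use of these communication helpers. It spins up a one-process gloo group purely for illustration, so all_gather returns a single-element list and reduce_dict returns its input unchanged; the import path assumes the repository root is on PYTHONPATH:

import os
import torch
import torch.distributed as dist
from core import comm_ as comm   # core/comm_.py

def main():
    # one-process "distributed" setup, purely for illustration
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    # all_gather works on arbitrary picklable objects, not just tensors
    stats = {"rank": comm.get_rank(), "n_samples": 128}
    gathered = comm.all_gather(stats)              # one entry per rank

    # reduce_dict reduces scalar tensors onto rank 0; with a single
    # process it simply returns the input dict unchanged
    losses = {"loss_cls": torch.tensor(0.7), "loss_box": torch.tensor(1.3)}
    reduced = comm.reduce_dict(losses)

    if comm.is_main_process():
        print(gathered, reduced)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()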
core/config.py ADDED
@@ -0,0 +1,600 @@
1
+ import yaml
2
+ import logging
3
+ import numpy as np
4
+ from easydict import EasyDict as edict
5
+ import copy
6
+ import re
7
+ import torch.distributed as dist
8
+
9
+ from .utils import printlog
10
+ from torch.distributed.distributed_c10d import _get_global_rank
11
+
12
+
13
+ task_specific_param = ['backbone', 'neck', 'decoder', 'dataset', 'sampler', 'lr_scheduler', 'optimizer',
14
+ 'extra', 'evaluation', 'model_entry_type', 'load_ignore', 'ckpt_task_id',
15
+ 'patch_neck','patch_adapter', 'patch_proj', 'label_neck', 'label_adapter', 'label_proj',]
16
+
17
+ loader = yaml.SafeLoader
18
+ loader.add_implicit_resolver(
19
+ u'tag:yaml.org,2002:float',
20
+ re.compile(u'''^(?:
21
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
22
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
23
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
24
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
25
+ |[-+]?\\.(?:inf|Inf|INF)
26
+ |\\.(?:nan|NaN|NAN))$''', re.X),
27
+ list(u'-+0123456789.'))
28
+
29
+ def flat(nums):
30
+ res = []
31
+ for i in nums:
32
+ if isinstance(i, list):
33
+ res.extend(flat(i))
34
+ else:
35
+ res.append(i)
36
+ return res
37
+
38
+ def specific_group_split_modality_groups(group_spec, share_backbone_group_ids,
39
+ share_decoder_group_ids, share_rgb_group_ids,
40
+ share_video_group_ids, share_dense_labeling_group_ids,
41
+ share_sparse_labeling_group_ids, share_text_group_ids, share_modality_group_ids=None):
42
+ ## sanity check
43
+ assert type(group_spec) is list
44
+ assert all(map(lambda x: type(x) is int, group_spec))
45
+
46
+ num_groups = len(group_spec)
47
+ splits = np.sum(group_spec)
48
+
49
+ if dist.is_initialized():
50
+ world_size = dist.get_world_size()
51
+ rank = dist.get_rank()
52
+ else:
53
+ world_size = 1
54
+ rank = 0
55
+
56
+ assert world_size % splits == 0, f"{world_size} % {splits}"
57
+ unit = int(world_size / splits)
58
+
59
+ ## split
60
+ group_sizes = [x*unit for x in group_spec] # [8,8,8] / [32, 16]
61
+ groups = []
62
+ roots = []
63
+ last = 0
64
+ task_info = edict()
65
+ all_ranks = []
66
+
67
+ for i,gs in enumerate(group_sizes):
68
+ ranks = list(map(int, np.arange(last, last+gs))) #[0...8], [9...15], ...
69
+ groups.append(dist.new_group(ranks=ranks))
70
+ roots.append(last) # 0, 8, 16
71
+ all_ranks.append(ranks)
72
+ if rank in ranks: # if current gpu rank in traversed rank task group
73
+ printlog(f">> task_info.group[{i}] ranks {ranks}")
74
+ task_info.group = groups[-1] # subordinate to what group
75
+ task_info.task_size = gs # 8
76
+ task_info.task_id = i
77
+ task_info.task_rank = rank - last
78
+ task_info.task_root_rank = last
79
+ last += gs
80
+ task_info.root_group = dist.new_group(ranks=roots)
81
+ printlog(f">> task_info.root_group ranks {roots}")
82
+ task_info.task_sizes = group_sizes
83
+ task_info.task_root_ranks = roots
84
+ task_info.task_num = num_groups
85
+
86
+ ## share_backbone_group spec
87
+ if share_backbone_group_ids is not None: # *[0,0,0]*(default) | [0,1,0]task ids
88
+ # group size must equal within a share_group
89
+ backboneshareid2idx = {}
90
+ for idx, this_id in enumerate(share_backbone_group_ids):
91
+ if this_id not in backboneshareid2idx:
92
+ backboneshareid2idx[this_id] = list()
93
+ backboneshareid2idx[this_id].append(idx) # {0: [0,1,2]}| {0: [0,2], 1: [1]}
94
+
95
+ ## create backbone share group
96
+ for idxs in backboneshareid2idx.values(): # idxs = [0, 1, 2]
97
+ this_group_ranks = flat([all_ranks[i] for i in idxs])
98
+ this_share_group = dist.new_group(ranks=this_group_ranks)
99
+ this_group_size = len(this_group_ranks)
100
+ if rank in this_group_ranks:
101
+ task_info.backbone_share_group = this_share_group
102
+ printlog(f">> task_info.backbone_share_group[{idxs}] ranks {this_group_ranks}")
103
+ task_info.backbone_group_size = len(backboneshareid2idx)
104
+ task_info.backbone_task_size = len(backboneshareid2idx) * this_group_size
105
+ task_info.backbone_task_rank = np.sum(rank < np.array(this_group_ranks))
106
+
107
+ ## share_decoder_group spec
108
+ if share_decoder_group_ids is not None:
109
+ # group size must equal within a share_group
110
+ decodershareid2idx = {}
111
+ for idx, this_id in enumerate(share_decoder_group_ids):
112
+ if this_id not in decodershareid2idx:
113
+ decodershareid2idx[this_id] = list()
114
+ decodershareid2idx[this_id].append(idx)
115
+
116
+ ## create decoder share group
117
+ for idxs in decodershareid2idx.values():
118
+ this_group_ranks = flat([all_ranks[i] for i in idxs])
119
+ this_share_group = dist.new_group(ranks=this_group_ranks)
120
+ this_group_size = len(this_group_ranks)
121
+ if rank in this_group_ranks:
122
+ task_info.decoder_share_group = this_share_group
123
+ printlog(f">> task_info.decoder_share_group[{idxs}] ranks {this_group_ranks}")
124
+ task_info.decoder_group_size = len(decodershareid2idx)
125
+ task_info.decoder_task_size = len(decodershareid2idx) * this_group_size
126
+ task_info.decoder_task_rank = np.sum(rank < np.array(this_group_ranks))
127
+
128
+
129
+ # Now, only for sparse labeling to deal with the modality sharing problem,
130
+ # which is not a good solution, but it works.
131
+ # parameters that have grads in [0,1,2] are in modality share group,
132
+ # parameters that do not have grads in [3,4] should be set in the task-specific group.
133
+ if share_modality_group_ids is not None:
134
+ # group size must equal within a share_group
135
+ modalityshareid2idx = {}
136
+ for idx, this_id in enumerate(share_modality_group_ids):
137
+ # -1 denotes that this modality does not appear in the current task
138
+ # if this_id == -1:
139
+ # continue
140
+ if this_id not in modalityshareid2idx:
141
+ modalityshareid2idx[this_id] = list()
142
+ modalityshareid2idx[this_id].append(idx)
143
+
144
+ ## create modality share group
145
+ for idxs in modalityshareid2idx.values(): # 0: [1,2] 1: [3]
146
+ this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
147
+ this_share_group = dist.new_group(ranks=this_group_ranks)
148
+ this_group_size = len(this_group_ranks) # 2
149
+ if rank in this_group_ranks:
150
+ task_info.modality_share_group = this_share_group
151
+ printlog(f">> task_info.modality_share_group[{idxs}] ranks {this_group_ranks}")
152
+ task_info.modality_group_size = len(modalityshareid2idx)
153
+
154
+ if share_rgb_group_ids is not None:
155
+ # group size must equal within a share_group
156
+ rgbshareid2idx = {}
157
+ for idx, this_id in enumerate(share_rgb_group_ids):
158
+ # -1 denotes that this modality does not appear in the current task
159
+ # if this_id == -1:
160
+ # continue
161
+ if this_id not in rgbshareid2idx:
162
+ rgbshareid2idx[this_id] = list()
163
+ rgbshareid2idx[this_id].append(idx)
164
+
165
+ ## create rgb share group
166
+ for idxs in rgbshareid2idx.values(): # 0: [1,2] 1: [3]
167
+ this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
168
+ this_share_group = dist.new_group(ranks=this_group_ranks)
169
+ this_group_size = len(this_group_ranks) # 2
170
+ if rank in this_group_ranks:
171
+ task_info.rgb_share_group = this_share_group
172
+ printlog(f">> task_info.rgb_share_group[{idxs}] ranks {this_group_ranks}")
173
+ task_info.rgb_group_size = len(rgbshareid2idx)
174
+ # task_info.rgb_task_size = len(rgbshareid2idx) * this_group_size
175
+ # task_info.rgb_task_rank = np.sum(rank < np.array(this_group_ranks))
176
+ # all_group_ranks = flat(rgbshareid2idx.values())
177
+ # if not len(rgbshareid2idx.values()) or dist.get_rank() not in all_group_ranks:
178
+ # task_info.rgb_share_group = None
179
+
180
+ if share_dense_labeling_group_ids is not None:
181
+ # group size must equal within a share_group
182
+ dense_labelingshareid2idx = {}
183
+ for idx, this_id in enumerate(share_dense_labeling_group_ids):
184
+ # -1 denotes that this modality does not appear in the current task
185
+ # if this_id == -1:
186
+ # continue
187
+ if this_id not in dense_labelingshareid2idx:
188
+ dense_labelingshareid2idx[this_id] = list()
189
+ dense_labelingshareid2idx[this_id].append(idx)
190
+
191
+ ## create dense share group
192
+ for idxs in dense_labelingshareid2idx.values(): # 0: [1,2] 1: [3]
193
+ this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
194
+ this_share_group = dist.new_group(ranks=this_group_ranks)
195
+ this_group_size = len(this_group_ranks) # 2
196
+ if rank in this_group_ranks:
197
+ task_info.dense_labeling_share_group = this_share_group
198
+ printlog(f">> task_info.dense_labeling_share_group[{idxs}] ranks {this_group_ranks}")
199
+ task_info.dense_labeling_group_size = len(dense_labelingshareid2idx)
200
+
201
+
202
+ if share_sparse_labeling_group_ids is not None:
203
+ # group size must equal within a share_group
204
+ sparse_labelingshareid2idx = {}
205
+ for idx, this_id in enumerate(share_sparse_labeling_group_ids):
206
+ # -1 denotes that this modality does not appear in the current task
207
+ # if this_id == -1:
208
+ # continue
209
+ if this_id not in sparse_labelingshareid2idx:
210
+ sparse_labelingshareid2idx[this_id] = list()
211
+ sparse_labelingshareid2idx[this_id].append(idx)
212
+
213
+ ## create sparse share group
214
+ for idxs in sparse_labelingshareid2idx.values(): # 0: [1,2] 1: [3]
215
+ this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
216
+ this_share_group = dist.new_group(ranks=this_group_ranks)
217
+ this_group_size = len(this_group_ranks) # 2
218
+ if rank in this_group_ranks:
219
+ task_info.sparse_labeling_share_group = this_share_group
220
+ printlog(f">> task_info.sparse_labeling_share_group[{idxs}] ranks {this_group_ranks}")
221
+ task_info.sparse_labeling_group_size = len(sparse_labelingshareid2idx)
222
+
223
+
224
+ if share_text_group_ids is not None:
225
+ # group size must equal within a share_group
226
+ textshareid2idx = {}
227
+ for idx, this_id in enumerate(share_text_group_ids):
228
+ # -1 denotes that this modality does not appear in the current task
229
+ if this_id not in textshareid2idx:
230
+ textshareid2idx[this_id] = list()
231
+ textshareid2idx[this_id].append(idx)
232
+
233
+ ## create text share group
234
+ for idxs in textshareid2idx.values(): # 0: [1,2] 1: [3]
235
+ this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
236
+ this_share_group = dist.new_group(ranks=this_group_ranks)
237
+ this_group_size = len(this_group_ranks) # 2
238
+ if rank in this_group_ranks:
239
+ task_info.text_share_group = this_share_group
240
+ printlog(f">> task_info.text_share_group[{idxs}] ranks {this_group_ranks}")
241
+ task_info.text_group_size = len(textshareid2idx)
242
+
243
+
244
+ if share_video_group_ids is not None:
245
+ # group size must equal within a share_group
246
+ videoshareid2idx = {}
247
+ for idx, this_id in enumerate(share_video_group_ids):
248
+ # -1 denotes that this modality does not appear in the current task
249
+ # if this_id == -1:
250
+ # continue
251
+ if this_id not in videoshareid2idx:
252
+ videoshareid2idx[this_id] = list()
253
+ videoshareid2idx[this_id].append(idx)
254
+
255
+ ## create video share group
256
+ for idxs in videoshareid2idx.values(): # 0: [1,2] 1: [3]
257
+ this_group_ranks = flat([all_ranks[i] for i in idxs]) # 1 2
258
+ this_share_group = dist.new_group(ranks=this_group_ranks)
259
+ this_group_size = len(this_group_ranks) # 2
260
+ if rank in this_group_ranks:
261
+ task_info.video_share_group = this_share_group
262
+ printlog(f">> task_info.video_share_group[{idxs}] ranks {this_group_ranks}")
263
+ task_info.video_group_size = len(videoshareid2idx)
264
+
265
+ return task_info
266
+
267
+ def specific_group_split(group_spec, share_backbone_group_ids, \
268
+ share_neck_group_ids, share_decoder_group_ids, share_adapter_group_ids):
269
+ ## sanity check
270
+ assert type(group_spec) is list
271
+ assert all(map(lambda x: type(x) is int, group_spec))
272
+
273
+ num_groups = len(group_spec)
274
+ splits = np.sum(group_spec)
275
+
276
+ world_size = dist.get_world_size()
277
+ rank = dist.get_rank()
278
+
279
+ assert world_size % splits == 0, f"{world_size} % {splits}"
280
+ unit = int(world_size / splits)
281
+
282
+ ## split
283
+ group_sizes = [x*unit for x in group_spec] # [8,8,8] / [32, 16]
284
+ groups = []
285
+ roots = []
286
+ last = 0
287
+ task_info = edict()
288
+ all_ranks = []
289
+ # import pdb;
290
+ # pdb.set_trace()
291
+ for i,gs in enumerate(group_sizes):
292
+ ranks = list(map(int, np.arange(last, last+gs))) #[0...8], [9...15], ...
293
+ groups.append(dist.new_group(ranks=ranks))
294
+ roots.append(last) # 0, 8, 16
295
+ all_ranks.append(ranks)
296
+ if rank in ranks: # if current gpu rank in traversed rank task group
297
+ printlog(f">> task_info.group[{i}] ranks {ranks}")
298
+ task_info.group = groups[-1] # subordinate to what group
299
+ task_info.task_size = gs # 8
300
+ task_info.task_id = i
301
+ task_info.task_rank = rank - last
302
+ task_info.task_root_rank = last
303
+ last += gs
304
+ task_info.root_group = dist.new_group(ranks=roots)
305
+ printlog(f">> task_info.root_group ranks {roots}")
306
+ task_info.task_sizes = group_sizes
307
+ task_info.task_root_ranks = roots
308
+ task_info.task_num = num_groups
309
+ # pdb.set_trace()
310
+ ## share_backbone_group spec
311
+ if share_backbone_group_ids is not None: # *[0,0,0]*(default) | [0,1,0]task ids
312
+ # group size must equal within a share_group
313
+ backboneshareid2idx = {}
314
+ for idx, this_id in enumerate(share_backbone_group_ids):
315
+ if this_id not in backboneshareid2idx:
316
+ backboneshareid2idx[this_id] = list()
317
+ backboneshareid2idx[this_id].append(idx) # {0: [0,1,2]}| {0: [0,2], 1: [1]}
318
+
319
+ ## create backbone share group
320
+ for idxs in backboneshareid2idx.values(): # idxs = [0, 1, 2]
321
+ this_group_ranks = flat([all_ranks[i] for i in idxs])
322
+ this_share_group = dist.new_group(ranks=this_group_ranks)
323
+ this_group_size = len(this_group_ranks)
324
+ if rank in this_group_ranks:
325
+ task_info.backbone_share_group = this_share_group
326
+ printlog(f">> task_info.backbone_share_group[{idxs}] ranks {this_group_ranks}")
327
+ task_info.backbone_group_size = len(backboneshareid2idx)
328
+ task_info.backbone_task_size = len(backboneshareid2idx) * this_group_size
329
+ task_info.backbone_task_rank = np.sum(rank < np.array(this_group_ranks))
330
+ ## share_adapter_group spec
331
+ if share_adapter_group_ids is not None: # *[0,0,0]*(default) | [0,1,0]task ids
332
+ # group size must equal within a share_group
333
+ adaptershareid2idx = {}
334
+ for idx, this_id in enumerate(share_adapter_group_ids):
335
+ if this_id not in adaptershareid2idx:
336
+ adaptershareid2idx[this_id] = list()
337
+ adaptershareid2idx[this_id].append(idx) # {0: [0,1,2]}| {0: [0,2], 1: [1]}
338
+
339
+ ## create adapter share group
340
+ for idxs in adaptershareid2idx.values(): # idxs = [0, 1, 2]
341
+ this_group_ranks = flat([all_ranks[i] for i in idxs])
342
+ this_share_group = dist.new_group(ranks=this_group_ranks)
343
+ this_group_size = len(this_group_ranks)
344
+ if rank in this_group_ranks:
345
+ task_info.adapter_share_group = this_share_group
346
+ printlog(f">> task_info.adapter_share_group[{idxs}] ranks {this_group_ranks}")
347
+ task_info.adapter_group_size = len(adaptershareid2idx)
348
+ task_info.adapter_task_size = len(adaptershareid2idx) * this_group_size
349
+ task_info.adapter_task_rank = np.sum(rank < np.array(this_group_ranks))
350
+
351
+ # pdb.set_trace()
352
+ ## share_neck_group spec
353
+ if share_neck_group_ids is not None:
354
+ # group size must equal within a share_group
355
+ neckshareid2idx = {}
356
+ for idx, this_id in enumerate(share_neck_group_ids):
357
+ if this_id not in neckshareid2idx:
358
+ neckshareid2idx[this_id] = list()
359
+ neckshareid2idx[this_id].append(idx)
360
+
361
+ ## create neck share group
362
+ for idxs in neckshareid2idx.values():
363
+ this_group_ranks = flat([all_ranks[i] for i in idxs])
364
+ this_share_group = dist.new_group(ranks=this_group_ranks)
365
+ this_group_size = len(this_group_ranks)
366
+ if rank in this_group_ranks:
367
+ task_info.neck_share_group = this_share_group
368
+ printlog(f">> task_info.neck_share_group[{idxs}] ranks {this_group_ranks}")
369
+ task_info.neck_group_size = len(neckshareid2idx)
370
+ task_info.neck_task_size = len(neckshareid2idx) * this_group_size
371
+ task_info.neck_task_rank = np.sum(rank < np.array(this_group_ranks))
372
+
373
+ ## share_decoder_group spec
374
+ if share_decoder_group_ids is not None:
375
+ # group size must equal within a share_group
376
+ decodershareid2idx = {}
377
+ for idx, this_id in enumerate(share_decoder_group_ids):
378
+ if this_id not in decodershareid2idx:
379
+ decodershareid2idx[this_id] = list()
380
+ decodershareid2idx[this_id].append(idx)
381
+
382
+ ## create decoder share group
383
+ for idxs in decodershareid2idx.values():
384
+ this_group_ranks = flat([all_ranks[i] for i in idxs])
385
+ this_share_group = dist.new_group(ranks=this_group_ranks)
386
+ this_group_size = len(this_group_ranks)
387
+ if rank in this_group_ranks:
388
+ task_info.decoder_share_group = this_share_group
389
+ printlog(f">> task_info.decoder_share_group[{idxs}] ranks {this_group_ranks}")
390
+ task_info.decoder_group_size = len(decodershareid2idx)
391
+ task_info.decoder_task_size = len(decodershareid2idx) * this_group_size
392
+ task_info.decoder_task_rank = np.sum(rank < np.array(this_group_ranks))
393
+ return task_info
394
+
395
+ class Config(object):
396
+
397
+ def __init__(self, config_file, noginfo=False, spec_ginfo_index=None):
398
+
399
+ with open(config_file) as f:
400
+ config = yaml.load(f, Loader=loader)
401
+ # print('config',config)
402
+ self.config_path = config_file
403
+
404
+ world_size = dist.get_world_size()
405
+ rank = dist.get_rank()
406
+
407
+ if noginfo:
408
+ ginfo = None
409
+ else: # cherrypick from tasks
410
+ tasks = config['tasks']
411
+ num_tasks = len(tasks)
412
+ if spec_ginfo_index is not None:
413
+ assert spec_ginfo_index < len(tasks), \
414
+ 'spec_ginfo_index={} is larger than num_tasks={}'.format(spec_ginfo_index, len(tasks))
415
+ tmp_config = copy.deepcopy(config)
416
+ config['tasks'] = dict()
417
+ config['tasks'][0] = tmp_config['tasks'][spec_ginfo_index]
418
+ config['tasks'][0]['gres_ratio'] = 1
419
+ tasks = config['tasks']
420
+ num_tasks = len(tasks)
421
+
422
+ # parse task_common and assign to each task
423
+ task_common = config.get('task_common', None)
424
+ if task_common is not None:
425
+ for i in range(num_tasks):
426
+ for k,v in task_common.items():
427
+ if not k in tasks[i]:
428
+ printlog('setting {} to {} for task {}'.format(k, v, i))
429
+ tasks[i][k] = v
430
+
431
+ group_spec = [tasks[i].get('gres_ratio',1) for i in range(num_tasks)]
432
+
433
+ ## share group spec
434
+ if config['common'].get('share_backbone_group', False):
435
+ share_backbone_group_ids = config['common']['share_backbone_group'][:num_tasks]
436
+ else:
437
+ share_backbone_group_ids = [0 for i in range(num_tasks)] # hardcoded prior
438
+ if config['common'].get('share_adapter_group', False):
439
+ if len(config['common']['share_adapter_group']) == 1:
440
+ adapter_list = []
441
+ share_adapter_group_ids = config['common']['share_adapter_group'][:num_tasks]
442
+ else:
443
+ share_adapter_group_ids = [0 for i in range(num_tasks)] # hardcoded prior
444
+
445
+ if config['common'].get('share_neck_group', False):
446
+ share_neck_group_ids = config['common']['share_neck_group'][:num_tasks]
447
+ else:
448
+ share_neck_group_ids = [0 for i in range(num_tasks)] # hardcoded prior
449
+
450
+ if config['common'].get('share_decoder_group', False):
451
+ share_decoder_group_ids = config['common']['share_decoder_group'][:num_tasks]
452
+ else:
453
+ share_decoder_group_ids = [i for i in range(num_tasks)] # hardcoded prior
454
+ ginfo = specific_group_split(group_spec, share_backbone_group_ids, share_neck_group_ids,
455
+ share_decoder_group_ids, share_adapter_group_ids)
456
+ loss_weight_sum = float(np.sum(np.array([task['loss_weight'] for task in tasks.values()])))
457
+ ginfo.task_name = tasks[ginfo.task_id]['name']
458
+ ginfo.task_names = [tasks[i]['name'] for i in range(ginfo.task_num)]
459
+ ginfo.task_weight = float(tasks[ginfo.task_id]['loss_weight']) / loss_weight_sum
460
+ ginfo.task_type = tasks[ginfo.task_id].get('type', 'normal')
461
+ ginfo.task_types = [tasks[i].get('type', 'normal') for i in range(ginfo.task_num)]
462
+ ginfo.task_random_seed = tasks[ginfo.task_id].get('random_seed', 0)
463
+
464
+ for p in task_specific_param:
465
+ if p in config['tasks'][ginfo.task_id]:
466
+ config['common'][p] = config['tasks'][ginfo.task_id][p]
467
+ printlog('{} of task{} has been overridden to {}'.format(p, ginfo.task_id, config['common'][p]))
468
+
469
+ logger = logging.getLogger('global_logger')
470
+
471
+ self.world_size = world_size
472
+ self.rank = rank
473
+ self.ginfo = ginfo
474
+ self.config = config
475
+ self.config_file = config_file
476
+
477
+ class Config_Hulk(object):
478
+
479
+ def __init__(self, config_file, noginfo=False, spec_ginfo_index=None):
480
+
481
+ with open(config_file) as f:
482
+ config = yaml.load(f, Loader=loader)
483
+ # print('config',config)
484
+ self.config_path = config_file
485
+
486
+
487
+ if dist.is_initialized():
488
+ world_size = dist.get_world_size()
489
+ rank = dist.get_rank()
490
+ else:
491
+ world_size = 1
492
+ rank = 0
493
+
494
+ if noginfo:
495
+ ginfo = None
496
+ else: # cherrypick from tasks
497
+ tasks = config['tasks']
498
+ num_tasks = len(tasks)
499
+ if spec_ginfo_index is not None:
500
+ assert spec_ginfo_index < len(tasks), \
501
+ 'spec_ginfo_index={} is larger than num_tasks={}'.format(spec_ginfo_index, len(tasks))
502
+ tmp_config = copy.deepcopy(config)
503
+ config['tasks'] = dict()
504
+ config['tasks'][0] = tmp_config['tasks'][spec_ginfo_index]
505
+ config['tasks'][0]['gres_ratio'] = 1
506
+ tasks = config['tasks']
507
+ num_tasks = len(tasks)
508
+
509
+ # parse task_common and assign to each task
510
+ task_common = config.get('task_common', None)
511
+ if task_common is not None:
512
+ for i in range(num_tasks):
513
+ for k,v in task_common.items():
514
+ if not k in tasks[i]:
515
+ printlog('setting {} to {} for task {}'.format(k, v, i))
516
+ tasks[i][k] = v
517
+
518
+ group_spec = [tasks[i].get('gres_ratio',1) for i in range(num_tasks)]
519
+
520
+ ## share group spec
521
+ if config['common'].get('share_backbone_group', False):
522
+ share_backbone_group_ids = config['common']['share_backbone_group'][:num_tasks]
523
+ else:
524
+ share_backbone_group_ids = [0 for i in range(num_tasks)] # hardcoded prior
525
+
526
+
527
+ if config['common'].get('share_decoder_group', False):
528
+ share_decoder_group_ids = config['common']['share_decoder_group'][:num_tasks]
529
+ else:
530
+ share_decoder_group_ids = [i for i in range(num_tasks)] # hardcoded prior
531
+
532
+ # use modality groups to control the communication of neck, adapter, and output proj
533
+
534
+ if config['common'].get('share_rgb_group', False):
535
+ share_rgb_group_ids = config['common']['share_rgb_group'][:num_tasks]
536
+ else:
537
+ share_rgb_group_ids = [i for i in range(num_tasks)] # hardcoded prior
538
+
539
+ if config['common'].get('share_dense_labeling_group', False):
540
+ share_dense_labeling_group_ids = config['common']['share_dense_labeling_group'][:num_tasks]
541
+ else:
542
+ share_dense_labeling_group_ids = [i for i in range(num_tasks)]
543
+
544
+ if config['common'].get('share_sparse_labeling_group', False):
545
+ share_sparse_labeling_group_ids = config['common']['share_sparse_labeling_group'][:num_tasks]
546
+ else:
547
+ share_sparse_labeling_group_ids = [i for i in range(num_tasks)]
548
+
549
+ if config['common'].get('share_text_group', False):
550
+ share_text_group_ids = config['common']['share_text_group'][:num_tasks]
551
+ else:
552
+ share_text_group_ids = [i for i in range(num_tasks)]
553
+
554
+ if config['common'].get('share_video_group', False):
555
+ share_video_group_ids = config['common']['share_video_group'][:num_tasks]
556
+ else:
557
+ share_video_group_ids = [i for i in range(num_tasks)]
558
+
559
+ if config['common'].get('share_modality_group', False):
560
+ share_modality_group_ids = config['common']['share_modality_group'][:num_tasks]
561
+ else:
562
+ share_modality_group_ids = [i for i in range(num_tasks)]
563
+
564
+ # ginfo = specific_group_split_modality_groups(group_spec, share_backbone_group_ids,
565
+ # share_decoder_group_ids, share_rgb_group_ids,
566
+ # share_video_group_ids, share_dense_labeling_group_ids,
567
+ # share_sparse_labeling_group_ids, share_text_group_ids,
568
+ # share_modality_group_ids)
569
+ import easydict
570
+ ginfo = easydict.EasyDict()
571
+ ginfo.task_id = 5
572
+ ginfo.task_num = 5
573
+ ginfo.backbone_share_group = None
574
+ ginfo.task_rank = 0
575
+
576
+ loss_weight_sum = float(np.sum(np.array([task['loss_weight'] for task in tasks.values()])))
577
+ ginfo.task_name = tasks[ginfo.task_id]['name']
578
+ ginfo.task_names = [tasks[i]['name'] for i in range(ginfo.task_num)]
579
+ # ginfo.task_weight = float(tasks[ginfo.task_id]['loss_weight']) / loss_weight_sum
580
+ ginfo.task_weight = float(tasks[ginfo.task_id]['loss_weight'])
581
+ ginfo.task_type = tasks[ginfo.task_id].get('type', 'normal')
582
+ ginfo.task_types = [tasks[i].get('type', 'normal') for i in range(ginfo.task_num)]
583
+ ginfo.task_random_seed = tasks[ginfo.task_id].get('random_seed', 0)
584
+
585
+ for p in task_specific_param:
586
+ if p in config['tasks'][ginfo.task_id]:
587
+ config['common'][p] = config['tasks'][ginfo.task_id][p]
588
+ printlog('{} of task{} has been overridden to {}'.format(p, ginfo.task_id, config['common'][p]))
589
+
590
+ logger = logging.getLogger('global_logger')
591
+
592
+ self.world_size = world_size
593
+ self.rank = rank
594
+ self.ginfo = ginfo
595
+ self.config = config
596
+ self.config_file = config_file
597
+
598
+ # def __repr__(self) -> str:
599
+ # return str(self.config)
600
+
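For orientation, a sketch (not part of the commit) of the YAML layout that Config / Config_Hulk consume, limited to keys actually referenced above; the task names, loss weights, and backbone value are invented. It also mirrors how task_common entries and gres_ratio are folded into each task:

import yaml
import numpy as np

cfg_text = """
common:
  share_backbone_group: [0, 0]      # both tasks share one backbone group
  share_decoder_group:  [0, 1]      # decoders stay task-specific
task_common:
  random_seed: 233
tasks:
  0: {name: pose,    loss_weight: 1.0, gres_ratio: 1, backbone: vit_base}
  1: {name: parsing, loss_weight: 0.5, gres_ratio: 1}
"""
config = yaml.safe_load(cfg_text)
tasks = config["tasks"]

# task_common entries are copied into every task that does not override them
for i, task in tasks.items():
    for k, v in config["task_common"].items():
        task.setdefault(k, v)

# gres_ratio decides how many GPU "units" each task group receives
group_spec = [tasks[i].get("gres_ratio", 1) for i in range(len(tasks))]
loss_weight_sum = float(np.sum([t["loss_weight"] for t in tasks.values()]))
print(group_spec, loss_weight_sum)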
core/data/__init__.py ADDED
File without changes
core/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (188 Bytes).
core/data/datasets/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from .images.pedattr_dataset import MultiAttrDataset
+ from .images.pos_dataset_dev import COCOPosDatasetDev, MPIIPosDatasetDev
+ from .images.parsing_dataset import (Human3M6ParsingDataset, LIPParsingDataset, CIHPParsingDataset, ATRParsingDataset,
+                                      DeepFashionParsingDataset, VIPParsingDataset, ModaNetParsingDataset,
+                                      PaperDollParsingDataset)
+ from .images.multi_posedataset import MultiPoseDatasetDev
+ from .images.peddet_dataset_v2 import PedestrainDetectionDataset_v2, PedestrainDetectionDataset_v2demo
+ from .images.image_caption_dataset import CocoCaption, CocoCaptiondemo
+ from .sequences.skeleton_action_dataset import mmSkeletonDataset
+ from .images.smpl_dataset_v2 import MeshTSVYamlDataset
+ from core.utils import printlog
+
+ def dataset_entry(config):
+     printlog('config[kwargs]', config['kwargs'])
+     return globals()[config['type']](**config['kwargs'])
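dataset_entry is a name-based factory: it looks the requested class up in this module's globals() and forwards the configured kwargs. A self-contained sketch of the same pattern (DummyDataset and its arguments are hypothetical stand-ins for the real dataset classes and their YAML-provided kwargs):

class DummyDataset:
    def __init__(self, root, split='train'):
        self.root, self.split = root, split

def dataset_entry(config):
    # same dispatch as core/data/datasets/__init__.py, minus the printlog call
    return globals()[config['type']](**config['kwargs'])

ds = dataset_entry({'type': 'DummyDataset',
                    'kwargs': {'root': '/tmp/data', 'split': 'train'}})
print(type(ds).__name__, ds.root, ds.split)   # DummyDataset /tmp/data train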
core/data/datasets/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.41 kB).
core/data/datasets/images/__pycache__/image_caption_dataset.cpython-312.pyc ADDED
Binary file (13.9 kB).
core/data/datasets/images/__pycache__/multi_posedataset.cpython-312.pyc ADDED
Binary file (18.9 kB).
core/data/datasets/images/__pycache__/parsing_dataset.cpython-312.pyc ADDED
Binary file (40.6 kB).
core/data/datasets/images/__pycache__/pedattr_dataset.cpython-312.pyc ADDED
Binary file (34.7 kB).
core/data/datasets/images/__pycache__/peddet_dataset_v2.cpython-312.pyc ADDED
Binary file (28.9 kB).
core/data/datasets/images/__pycache__/pos_dataset_dev.cpython-312.pyc ADDED
Binary file (33.8 kB).
core/data/datasets/images/__pycache__/seg_dataset_dev.cpython-312.pyc ADDED
Binary file (16.2 kB).
core/data/datasets/images/__pycache__/smpl_dataset_v2.cpython-312.pyc ADDED
Binary file (18 kB).
core/data/datasets/images/image_caption_dataset.py ADDED
@@ -0,0 +1,261 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import random
5
+ import torch
6
+ import torchvision
7
+ import numpy as np
8
+ import pandas as pd
9
+ import os.path as osp
10
+ from PIL import Image
11
+ from collections import defaultdict
12
+ from transformers import BertTokenizer
13
+ from torch.utils.data import Dataset
14
+ import torch.distributed as dist
15
+ from torchvision import transforms
16
+ from torchvision.transforms import PILToTensor, ToTensor
17
+ from core.data.transforms.caption_transforms import RandomAugment
18
+
19
+ def pre_caption(caption, max_words=30):
20
+ caption = re.sub(
21
+ r"([.!\"()*#:;~])",
22
+ ' ',
23
+ caption.lower(),
24
+ )
25
+ caption = re.sub(
26
+ r"\s{2,}",
27
+ ' ',
28
+ caption,
29
+ )
30
+ caption = caption.rstrip('\n')
31
+ caption = caption.strip(' ')
32
+
33
+ #truncate caption
34
+ caption_words = caption.split(' ')
35
+ if len(caption_words) > max_words:
36
+ caption = ' '.join(caption_words[ :max_words])
37
+
38
+ return caption
39
+
40
+ def data_transforms(split_type='train', img_size=384, min_scale=0.5):
41
+ if split_type == 'train':
42
+ data_transforms = transforms.Compose([
43
+ transforms.RandomResizedCrop(img_size, scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
44
+ transforms.RandomHorizontalFlip(),
45
+ RandomAugment(2, 5, isPIL=True,
46
+ augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
47
+ 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
48
+ PILToTensor()
49
+ # ToTensor()
50
+ ])
51
+ else:
52
+ data_transforms = transforms.Compose([
53
+ transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),
54
+ PILToTensor(),
55
+ # ToTensor()
56
+ ])
57
+ return data_transforms
58
+
59
+
60
+ class CocoCaption(Dataset):
61
+ """
62
+ Implementation of the dataloader for coco_caption.
63
+ Mainly used in the model training and evaluation.
64
+ Params:
65
+ ginfo: group information for Multitask learning.
66
+ coco_root: root path of coco2014 dataset.
67
+ anno_root: annotation path of coco captions.
68
+ bert_dir: path of bert-base-uncased for loading tokenizer.
69
+ max_words: max length of input captions.
70
+ img_size: image size.
71
+ prompt: given prompt to add before captions.
72
+ """
73
+ def __init__(self, ginfo, max_words=30, img_size=384, beam_size=1, prompt='', split_type='train',
74
+ cuhk_peds=False, cuhk_peds_root=None, cuhk_peds_anno_root=None, cuhk_peds_gt_root=None,
75
+ joint_train=False, synth_peds_root=None, joint_train_anno_root=None, coco_train=False,
76
+ coco_root=None, anno_root=None, bert_dir='', mals_root=None, luperson_root=None):
77
+ self.task_name = ginfo.task_name
78
+ self.rank = dist.get_rank()
79
+ self.prompt = prompt
80
+
81
+ # plus one for bos token
82
+ self.max_words = max_words + 1
83
+ self.img_size = img_size
84
+ self.split_type = split_type
85
+ self.beam_size = beam_size
86
+
87
+ self.transforms = data_transforms(split_type, img_size)
88
+ self.tokenizer = BertTokenizer.from_pretrained(bert_dir, do_lower_case=True)
89
+ self.cuhk_peds = cuhk_peds
90
+ self.joint_train = joint_train
91
+ self.coco_train = coco_train
92
+ if joint_train:
93
+ self.annotation = json.load(open(joint_train_anno_root, 'r'))
94
+ self.cuhk_peds_root = cuhk_peds_root
95
+ self.synth_peds_root = synth_peds_root
96
+ self.mals_root = mals_root
97
+ self.luperson_root = luperson_root
98
+ self.coco_gt_file = cuhk_peds_gt_root
99
+ elif cuhk_peds:
100
+ self.annotation = json.load(open(cuhk_peds_anno_root, 'r'))
101
+ self.coco_gt_file = cuhk_peds_gt_root
102
+ self.coco_root = cuhk_peds_root
103
+ elif coco_train:
104
+ self.coco_root = coco_root
105
+ self.coco_gt_file = osp.join(anno_root, 'coco_gt', 'coco_karpathy_' + split_type + '_gt.json')
106
+ self.annotation = json.load(open(osp.join(anno_root, 'coco_karpathy_' + split_type + '.json'), 'r'))
107
+
108
+
109
+ def __len__(self):
110
+ return len(self.annotation)
111
+
112
+ def __getitem__(self, index):
113
+ sample = self.annotation[index]
114
+ if self.joint_train and self.split_type == 'train':
115
+ if sample['split'] == 'cuhk_peds':
116
+ image_path = osp.join(self.cuhk_peds_root, sample['image'])
117
+ elif sample['split'] == 'mals':
118
+ image_path = osp.join(self.mals_root, sample['image'])
119
+ elif sample['split'] == 'luperson':
120
+ image_path = osp.join(self.luperson_root, sample['image'])
121
+ else:
122
+ image_path = osp.join(self.synth_peds_root, sample['image'])
123
+ else:
124
+ image_path = osp.join(self.coco_root, sample['image'])
125
+ image = Image.open(image_path).convert('RGB')
126
+ image = self.transforms(image)
127
+ if self.split_type != 'train':
128
+ caption_id = np.zeros(self.max_words - 1, dtype=np.int32)
129
+ token_type_id = np.zeros(self.max_words - 1, dtype=np.int32)
130
+ caption_pad_mask = np.zeros(self.max_words - 1, dtype=np.int32)
131
+ if self.cuhk_peds:
132
+ img_id = sample['image'].split('.')[0]
133
+ else:
134
+ img_id = sample['image'].split('/')[-1].strip('.jpg').split('_')[-1]
135
+ coco_gt_file = self.coco_gt_file
136
+ beam_size = self.beam_size
137
+ return {'image': image, 'input_id': caption_id, 'image_id': int(img_id) if not self.cuhk_peds else img_id,
138
+ 'coco_gt_file': coco_gt_file, 'beam_size': beam_size,
139
+ 'token_type_id': token_type_id, 'padding_mask': caption_pad_mask}
140
+ caption = self.prompt + pre_caption(sample['caption'], self.max_words)
141
+ caption_encode = self.tokenizer.encode_plus(caption, max_length=self.max_words, pad_to_max_length=True,
142
+ return_attention_mask=True, return_token_type_ids=True,
143
+ truncation=True)
144
+ caption_id, caption_pad_mask, token_type_id = caption_encode['input_ids'], caption_encode['attention_mask'], caption_encode['token_type_ids']
145
+ caption_id = np.array(caption_id)
146
+ token_type_id = np.array(token_type_id)
147
+ caption_pad_mask = np.array(caption_pad_mask)
148
+ # caption_pad_mask = (1 - np.array(caption_pad_mask)).astype(bool)
149
+ caption = [caption]
150
+ output = {'image': image, 'input_id': caption_id, 'token_type_id': token_type_id, 'padding_mask': caption_pad_mask, 'label': caption_id}
151
+ return output
152
+
153
+ def __repr__(self):
154
+ return self.__class__.__name__ + \
155
+ f'rank: {self.rank} task: {self.task_name} mode:{"training" if self.split_type == "train" else "inference"} ' \
156
+ f'dataset_len:{len(self.annotation)} augmentation: {self.transforms}'
157
+
158
+
159
+ class CocoCaptiondemo(Dataset):
160
+ """
161
+ Implementation of the dataloader for coco_caption.
162
+ Mainly used in the model training and evaluation.
163
+ Params:
164
+ ginfo: group information for Multitask learning.
165
+ coco_root: root path of coco2014 dataset.
166
+ anno_root: annotation path of coco captions.
167
+ bert_dir: path of bert-base-uncased for loading tokenizer.
168
+ max_words: max length of input captions.
169
+ img_size: image size.
170
+ prompt: given prompt to add before captions.
171
+ """
172
+
173
+ def __init__(self, ginfo, max_words=30, img_size=384, beam_size=1, prompt='', split_type='train', demo_dir='/mnt/cache/tangshixiang/wyz_proj/demo_video_unihcpv2/folder0',
174
+ cuhk_peds=False, cuhk_peds_root=None, cuhk_peds_anno_root=None, cuhk_peds_gt_root=None,
175
+ joint_train=False, synth_peds_root=None, joint_train_anno_root=None, coco_train=False,
176
+ coco_root=None, anno_root=None, bert_dir='', mals_root=None, luperson_root=None):
177
+ self.task_name = ginfo.task_name
178
+ self.rank = dist.get_rank()
179
+ self.prompt = prompt
180
+
181
+ # plus one for bos token
182
+ self.max_words = max_words + 1
183
+ self.img_size = img_size
184
+ self.split_type = split_type
185
+ self.beam_size = beam_size
186
+
187
+ self.transforms = data_transforms(split_type, img_size)
188
+ self.tokenizer = BertTokenizer.from_pretrained(bert_dir, do_lower_case=True)
189
+ self.cuhk_peds = cuhk_peds
190
+ self.joint_train = joint_train
191
+ self.coco_train = coco_train
192
+ if joint_train:
193
+ self.annotation = json.load(open(joint_train_anno_root, 'r'))
194
+ self.cuhk_peds_root = cuhk_peds_root
195
+ self.synth_peds_root = synth_peds_root
196
+ self.mals_root = mals_root
197
+ self.luperson_root = luperson_root
198
+ self.coco_gt_file = cuhk_peds_gt_root
199
+ elif cuhk_peds:
200
+ self.annotation = json.load(open(cuhk_peds_anno_root, 'r'))
201
+ self.coco_gt_file = cuhk_peds_gt_root
202
+ self.coco_root = cuhk_peds_root
203
+ elif coco_train:
204
+ self.coco_root = coco_root
205
+ self.coco_gt_file = osp.join(anno_root, 'coco_gt', 'coco_karpathy_' + split_type + '_gt.json')
206
+ self.annotation = json.load(open(osp.join(anno_root, 'coco_karpathy_' + split_type + '.json'), 'r'))
207
+ self.demo_dir = demo_dir
208
+
209
+
210
+ def __len__(self):
211
+ return len(os.listdir(self.demo_dir))
212
+
213
+ def __getitem__(self, index):
214
+ # import pdb; pdb.set_trace()
215
+ sample = self.annotation[index]
216
+ if self.joint_train and self.split_type == 'train':
217
+ if sample['split'] == 'cuhk_peds':
218
+ image_path = osp.join(self.cuhk_peds_root, sample['image'])
219
+ elif sample['split'] == 'mals':
220
+ image_path = osp.join(self.mals_root, sample['image'])
221
+ elif sample['split'] == 'luperson':
222
+ image_path = osp.join(self.luperson_root, sample['image'])
223
+ else:
224
+ image_path = osp.join(self.synth_peds_root, sample['image'])
225
+ else:
226
+ image_path = osp.join(self.coco_root, sample['image'])
227
+ filename = os.path.join(self.demo_dir, f'frame_{index}.jpg')
228
+ image = Image.open(filename).convert('RGB')
229
+ image = self.transforms(image)
230
+ if self.split_type != 'train':
231
+ caption_id = np.zeros(self.max_words - 1, dtype=np.int32)
232
+ token_type_id = np.zeros(self.max_words - 1, dtype=np.int32)
233
+ caption_pad_mask = np.zeros(self.max_words - 1, dtype=np.int32)
234
+ if self.cuhk_peds:
235
+ img_id = sample['image'].split('.')[0]
236
+ else:
237
+ img_id = sample['image'].split('/')[-1].strip('.jpg').split('_')[-1]
238
+ coco_gt_file = self.coco_gt_file
239
+ beam_size = self.beam_size
240
+ return {'image': image, 'input_id': caption_id, 'image_id': filename,
241
+ 'coco_gt_file': coco_gt_file, 'beam_size': beam_size,
242
+ 'token_type_id': token_type_id, 'padding_mask': caption_pad_mask}
243
+ caption = self.prompt + pre_caption(sample['caption'], self.max_words)
244
+ caption_encode = self.tokenizer.encode_plus(caption, max_length=self.max_words, pad_to_max_length=True,
245
+ return_attention_mask=True, return_token_type_ids=True,
246
+ truncation=True)
247
+ caption_id, caption_pad_mask, token_type_id = caption_encode['input_ids'], caption_encode['attention_mask'], \
248
+ caption_encode['token_type_ids']
249
+ caption_id = np.array(caption_id)
250
+ token_type_id = np.array(token_type_id)
251
+ caption_pad_mask = np.array(caption_pad_mask)
252
+ # caption_pad_mask = (1 - np.array(caption_pad_mask)).astype(bool)
253
+ caption = [caption]
254
+ output = {'image': image, 'input_id': filename, 'token_type_id': token_type_id,
255
+ 'padding_mask': caption_pad_mask, 'label': caption_id}
256
+ return output
257
+
258
+ def __repr__(self):
259
+ return self.__class__.__name__ + \
260
+ f'rank: {self.rank} task: {self.task_name} mode:{"training" if self.split_type == "train" else "inference"} ' \
261
+ f'dataset_len:{len(self.annotation)} augmentation: {self.transforms}'
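For reference, a minimal standalone sketch of the caption encoding performed in `__getitem__` above; the tokenizer directory and example caption are placeholders, and the `encode_plus` arguments mirror the dataset code:

    # Hypothetical check of the caption -> (input_id, padding_mask, token_type_id) encoding.
    from transformers import BertTokenizer

    max_words = 30 + 1  # +1 for the BOS token, as in the dataset constructor
    tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased', do_lower_case=True)
    enc = tokenizer.encode_plus('a person in a red jacket walking down the street',
                                max_length=max_words, pad_to_max_length=True,
                                return_attention_mask=True, return_token_type_ids=True,
                                truncation=True)
    # enc['input_ids'], enc['attention_mask'] and enc['token_type_ids'] are lists of length
    # max_words; the dataset wraps them in numpy arrays and returns them as
    # 'input_id', 'padding_mask' and 'token_type_id' (plus 'label' = the input ids).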
core/data/datasets/images/multi_posedataset.py ADDED
@@ -0,0 +1,413 @@
1
+ import copy
2
+ from abc import ABCMeta, abstractmethod
3
+ import numpy as np
4
+ from torch.utils.data import Dataset
5
+ from pathlib import Path
6
+
7
+ import os
8
+ import cv2
9
+ import random
10
+ import time
11
+ import os.path as osp
13
+ import warnings
14
+ from collections import OrderedDict, defaultdict
15
+ from core.data.transforms.pose_transforms import *
16
+ import json  # the standard json module is sufficient here (json_tricks is not required)
18
+ from xtcocotools.coco import COCO
19
+ from xtcocotools.cocoeval import COCOeval
20
+ import torch.distributed as dist
21
+
22
+
23
+ from core.utils import sync_print
24
+
25
+
26
+ class PetrelCOCO(COCO):
27
+ def __init__(self, annotation_file=None, test_index=None, ann_data=None):
28
+ """
29
+ Constructor of Microsoft COCO helper class for reading and visualizing annotations.
30
+ :param annotation_file (str): location of annotation file
31
+ :param test_index (list): optional keypoint indices to keep for evaluation.
+ :param ann_data (dict): optional pre-loaded annotation dict, used instead of reading annotation_file.
32
+ :return:
33
+ """
34
+ self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
35
+ self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
36
+ self.anno_file = [annotation_file]
37
+ self.test_index = test_index
38
+ if annotation_file is not None:
39
+ print('loading annotations into memory...')
40
+ tic = time.time()
41
+ # https://github.com/cocodataset/cocoapi/pull/453/
42
+ if ann_data is None:
43
+ with open(annotation_file, 'r') as f:
44
+ dataset = json.load(f)
45
+ else:
46
+ dataset = ann_data
47
+ assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
48
+ print('Done (t={:0.2f}s)'.format(time.time()- tic))
49
+ self.dataset = dataset
50
+ self.createIndex()
51
+ if 'annotations' in self.dataset:
52
+ for i in range(len(self.dataset['annotations'])):
53
+ if self.test_index is not None:
54
+ keypoints = np.array(self.dataset['annotations'][i]['keypoints']).reshape([-1, 3])
55
+ keypoints = keypoints[self.test_index, :]
56
+ self.dataset['annotations'][i]['keypoints'] = keypoints.reshape([-1]).tolist()
57
+ if 'iscrowd' not in self.dataset['annotations'][i]:
58
+ self.dataset['annotations'][i]['iscrowd'] = False
59
+
60
+
61
+ class MultiPoseDatasetDev(Dataset):
62
+ def __init__(self,
63
+ ginfo,
64
+ ann_file,
65
+ img_prefix,
66
+ data_cfg,
67
+ test_mode=False,
68
+ use_udp=False,
69
+ dataset_name='coco',
70
+ use_ceph=False,
71
+ simp_aug=False,
72
+ **kwargs):
73
+
74
+ assert dataset_name in ['coco', 'aic', 'posetrack', 'halpe', 'JRDB2022', 'h36m', 'mhp', 'penn_action', '3DPW', '3DHP', 'AIST'], "invalid dataset name input"
75
+ self.dataset_name = dataset_name
76
+ self.image_info = {}
77
+ self.ann_info = {}
78
+ self.initialized = False
79
+
80
+ self.use_ceph = True
81
+ self.annotations_path = ann_file
82
+ self.img_prefix = img_prefix
83
+ self.test_mode = test_mode
84
+ print('data_cfg0',data_cfg)
85
+ # data_cfg=demjson.decode(data_cfg)
86
+ # print('data_cfg',data_cfg)
87
+ self.ann_info['image_size'] = np.array(data_cfg['image_size'])
88
+ self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
89
+ self.ann_info['num_joints'] = data_cfg['num_joints']
90
+
91
+ self.ann_info['inference_channel'] = data_cfg['inference_channel']
92
+ self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
93
+ self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
94
+
95
+ self.db = []
96
+ self.task_name = ginfo.task_name
97
+
98
+ if test_mode:
99
+ pipeline = [
100
+ LoadImageFromFile(use_ceph=use_ceph),
101
+ TopDownAffine(use_udp=use_udp),
102
+ ToUNTensor(),
103
+ Collect(keys=['image'],
104
+ meta_keys=['image_file', 'center', 'scale', 'rotation', 'bbox_score', 'flip_pairs'])
105
+ ]
106
+ else:
107
+ if self.dataset_name in ['coco', 'aic'] or simp_aug:
108
+ pipeline = [
109
+ LoadImageFromFile(use_ceph=use_ceph),
110
+ TopDownRandomFlip(flip_prob=0.5),
111
+ TopDownHalfBodyTransform(num_joints_half_body=8, prob_half_body=0.3),
112
+ TopDownGetRandomScaleRotation(rot_factor=40, scale_factor=0.5),
113
+ TopDownAffine(use_udp=use_udp),
114
+ ToUNTensor(),
115
+ TopDownGenerateTarget(sigma=2, encoding='UDP' if use_udp else 'MSRA'),
116
+ Collect(keys=['image', 'label', 'target_weight'],
117
+ meta_keys=['image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale','rotation',
118
+ 'bbox_score', 'flip_pairs'])
119
+ ]
120
+ elif self.dataset_name in ['posetrack', 'halpe', 'penn_action', '3DPW', 'mhp']:
121
+ pipeline = [
122
+ LoadImageFromFile(),
123
+ TopDownGetBboxCenterScale(padding=1.25),
124
+ TopDownRandomShiftBboxCenter(shift_factor=0.16, prob=0.3),
125
+ TopDownRandomFlip(flip_prob=0.5),
126
+ TopDownHalfBodyTransform(num_joints_half_body=8, prob_half_body=0.3),
127
+ TopDownGetRandomScaleRotation(rot_factor=40, scale_factor=0.5),
128
+ TopDownAffine(use_udp=use_udp),
129
+ ToUNTensor(),
130
+ TopDownGenerateTarget(sigma=2, encoding='UDP' if use_udp else 'MSRA'),
131
+ Collect(keys=['image', 'label', 'target_weight'],
132
+ meta_keys=['image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale','rotation',
133
+ 'bbox_score', 'flip_pairs'])
134
+ ]
135
+ else:
136
+ pipeline = [
137
+ LoadImageFromFile(),
138
+ # TopDownGetBboxCenterScale(padding=1.25),
139
+ TopDownRandomShiftBboxCenter(shift_factor=0.16, prob=0.3),
140
+ TopDownRandomFlip(flip_prob=0.5),
141
+ TopDownHalfBodyTransform(num_joints_half_body=8, prob_half_body=0.3),
142
+ TopDownGetRandomScaleRotation(rot_factor=40, scale_factor=0.5),
143
+ TopDownAffine(use_udp=use_udp),
144
+ ToUNTensor(),
145
+ TopDownGenerateTarget(sigma=2, encoding='UDP' if use_udp else 'MSRA'),
146
+ Collect(keys=['image', 'label', 'target_weight'],
147
+ meta_keys=['image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation',
148
+ 'bbox_score', 'flip_pairs'])
149
+ ]
150
+
151
+
152
+ self.pipeline = ComposeX(pipeline)
153
+ # dict_keys(['image_file', 'center', 'scale', 'bbox', 'rotation', 'joints_3d', 'joints_3d_visible', 'dataset',
154
+ # 'bbox_score', 'bbox_id', 'ann_info', 'image', 'flipped',
155
+ # 'label', ****'target' as in mmlab****
156
+ # 'target_weight'])
157
+
158
+
159
+ self.use_gt_bbox = data_cfg['use_gt_bbox']
160
+ self.bbox_file = data_cfg['bbox_file'] if data_cfg['bbox_file'].startswith('/mnt') else (Path(__file__).parent / 'resources' / data_cfg['bbox_file']).resolve()
161
+ self.det_bbox_thr = data_cfg.get('det_bbox_thr', 0.0)
162
+ if 'image_thr' in data_cfg:
163
+ warnings.warn(
164
+ 'image_thr is deprecated, '
165
+ 'please use det_bbox_thr instead', DeprecationWarning)
166
+ self.det_bbox_thr = data_cfg['image_thr']
167
+
168
+
169
+ self.ann_info['flip_pairs'] = data_cfg['flip_pairs']
170
+
171
+ self.ann_info['upper_body_ids'] = data_cfg['upper_body_ids']
172
+ self.ann_info['lower_body_ids'] = data_cfg['lower_body_ids']
173
+
174
+ self.ann_info['use_different_joint_weights'] = False
175
+ self.ann_info['joint_weights'] = np.array(
176
+ data_cfg['joint_weights'],
177
+ dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
178
+
179
+ # 'https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/'
180
+ # 'pycocotools/cocoeval.py#L523'
181
+
182
+ self.coco = PetrelCOCO(ann_file)
183
+
184
+ cats = [
185
+ cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
186
+ ]
187
+ self.classes = ['__background__'] + cats
188
+ self.num_classes = len(self.classes)
189
+ self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
190
+ self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds()))
191
+ self._coco_ind_to_class_ind = dict(
192
+ (self._class_to_coco_ind[cls], self._class_to_ind[cls])
193
+ for cls in self.classes[1:])
194
+ self.img_ids = self.coco.getImgIds()
195
+ self.num_images = len(self.img_ids)
196
+ self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
197
+
198
+
199
+ self.db = self._get_db()
200
+ print(f'=> dataset: {self.dataset_name} num_images: {self.num_images}')
201
+ print(f'=> dataset: {self.dataset_name} load {len(self.db)} samples')
202
+
203
+ @staticmethod
204
+ def _get_mapping_id_name(imgs):
205
+ """
206
+ Args:
207
+ imgs (dict): dict of image info.
208
+
209
+ Returns:
210
+ tuple: Image name & id mapping dicts.
211
+
212
+ - id2name (dict): Mapping image id to name.
213
+ - name2id (dict): Mapping image name to id.
214
+ """
215
+ id2name = {}
216
+ name2id = {}
217
+ for image_id, image in imgs.items():
218
+ file_name = image['file_name']
219
+ id2name[image_id] = file_name
220
+ name2id[file_name] = image_id
221
+
222
+ return id2name, name2id
223
+
224
+ def _get_db(self):
225
+ """Load dataset."""
226
+ if (not self.test_mode) or self.use_gt_bbox:
227
+ # use ground truth bbox
228
+ gt_db = self._load_coco_keypoint_annotations()
229
+ else:
230
+ # use bbox from detection
231
+ gt_db = self._load_coco_person_detection_results()
232
+ return gt_db
233
+
234
+ def _load_coco_keypoint_annotations(self):
235
+ """Ground truth bbox and keypoints."""
236
+ gt_db = []
237
+ for img_id in self.img_ids:
238
+ gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
239
+ return gt_db
240
+
241
+ def _load_coco_keypoint_annotation_kernel(self, img_id):
242
+ """load annotation from COCOAPI.
243
+
244
+ Note:
245
+ bbox:[x1, y1, w, h]
246
+ Args:
247
+ img_id: coco image id
248
+ Returns:
249
+ dict: db entry
250
+ """
251
+ img_ann = self.coco.loadImgs(img_id)[0]
252
+ width = img_ann['width']
253
+ height = img_ann['height']
254
+ num_joints = self.ann_info['num_joints']
255
+
256
+ ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
257
+ objs = self.coco.loadAnns(ann_ids)
258
+
259
+ # sanitize bboxes
260
+ valid_objs = []
261
+ for obj in objs:
262
+ if 'bbox' not in obj:
263
+ continue
264
+ x, y, w, h = obj['bbox']
265
+ x1 = max(0, x)
266
+ y1 = max(0, y)
267
+ x2 = min(width - 1, x1 + max(0, w - 1))
268
+ y2 = min(height - 1, y1 + max(0, h - 1))
269
+ if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
270
+ obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
271
+ valid_objs.append(obj)
272
+ objs = valid_objs
273
+
274
+ bbox_id = 0
275
+ rec = []
276
+ for obj in objs:
277
+ if 'keypoints' not in obj:
278
+ continue
279
+ if max(obj['keypoints']) == 0:
280
+ continue
281
+ if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
282
+ continue
283
+ joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
284
+ joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32)
285
+
286
+ keypoints = np.array(obj['keypoints']).reshape(-1, 3)
287
+
288
+ if self.dataset_name == 'posetrack':
289
+ keypoints = np.delete(keypoints, [3, 4], axis=0) # keypoint idx == 3 and 4 not annot
290
+ elif self.dataset_name == 'halpe':
291
+ keypoints = keypoints[:17,:] # halpe has only 17 valid kp
292
+
293
+ joints_3d[:, :2] = keypoints[:, :2]
294
+ joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
295
+
296
+ center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
297
+
298
+ image_file = os.path.join(self.img_prefix, self.id2name[img_id])
299
+ rec.append({
300
+ 'image_file': image_file,
301
+ 'center': center,
302
+ 'scale': scale,
303
+ 'bbox': obj['clean_bbox'][:4],
304
+ 'rotation': 0,
305
+ 'joints_3d': joints_3d,
306
+ 'joints_3d_visible': joints_3d_visible,
307
+ 'dataset': self.dataset_name,
308
+ 'bbox_score': 1,
309
+ 'bbox_id': bbox_id
310
+ })
311
+ bbox_id = bbox_id + 1
312
+
313
+ return rec
314
+
315
+ def _xywh2cs(self, x, y, w, h):
316
+ """This encodes bbox(x,y,w,w) into (center, scale)
317
+
318
+ Args:
319
+ x, y, w, h
320
+
321
+ Returns:
322
+ tuple: A tuple containing center and scale.
323
+
324
+ - center (np.ndarray[float32](2,)): center of the bbox (x, y).
325
+ - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
326
+ """
327
+ aspect_ratio = self.ann_info['image_size'][0] / self.ann_info[
328
+ 'image_size'][1]
329
+ center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
330
+
331
+ if (not self.test_mode) and np.random.rand() < 0.3:
332
+ center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
333
+
334
+ if w > aspect_ratio * h:
335
+ h = w * 1.0 / aspect_ratio
336
+ elif w < aspect_ratio * h:
337
+ w = h * aspect_ratio
338
+
339
+ # pixel std is 200.0
340
+ scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
341
+ # padding to include proper amount of context
342
+ scale = scale * 1.25
343
+
344
+ return center, scale
345
+
346
+ def _load_coco_person_detection_results(self):
347
+ """Load coco person detection results."""
348
+ num_joints = self.ann_info['num_joints']
349
+ all_boxes = None
350
+ with open(self.bbox_file, 'r') as f:
351
+ all_boxes = json.load(f)
352
+
353
+ if not all_boxes:
354
+ raise ValueError('=> Load %s fail!' % self.bbox_file)
355
+
356
+ print(f'=> Total boxes: {len(all_boxes)}')
357
+
358
+ kpt_db = []
359
+ bbox_id = 0
360
+ for det_res in all_boxes:
361
+ if det_res['category_id'] != 1:
362
+ continue
363
+
364
+ image_file = os.path.join(self.img_prefix,
365
+ self.id2name[det_res['image_id']])
366
+ box = det_res['bbox']
367
+ score = det_res['score']
368
+
369
+ if score < self.det_bbox_thr:
370
+ continue
371
+
372
+ center, scale = self._xywh2cs(*box[:4])
373
+ joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
374
+ joints_3d_visible = np.ones((num_joints, 3), dtype=np.float32)
375
+ kpt_db.append({
376
+ 'image_file': image_file,
377
+ 'center': center,
378
+ 'scale': scale,
379
+ 'rotation': 0,
380
+ 'bbox': box[:4],
381
+ 'bbox_score': score,
382
+ 'dataset': self.dataset_name,
383
+ 'joints_3d': joints_3d,
384
+ 'joints_3d_visible': joints_3d_visible,
385
+ 'bbox_id': bbox_id
386
+ })
387
+ bbox_id = bbox_id + 1
388
+ print(f'=> Total boxes after filter '
389
+ f'low score@{self.det_bbox_thr}: {bbox_id}')
390
+ return kpt_db
391
+
392
+ def __len__(self):
393
+ """Get the size of the dataset."""
394
+ return len(self.db)
395
+
396
+ def __getitem__(self, idx):
397
+ """Get the sample given index."""
398
+ results = copy.deepcopy(self.db[idx])
399
+ results['ann_info'] = self.ann_info
400
+ out = self.pipeline(results)
401
+
402
+ C = self.ann_info['num_joints']
403
+
404
+ if 'label' in out:
405
+ out['dense_labeling'] = np.resize(out['label'],
406
+ (C, self.ann_info['image_size'][0], self.ann_info['image_size'][1]))
407
+ else:
408
+ out['dense_labeling'] = np.zeros((C, self.ann_info['image_size'][0], self.ann_info['image_size'][1]))
409
+
410
+ # del out['ann_info']
411
+ return out # dict_keys(['image_file', 'center', 'scale', 'bbox', 'rotation', 'joints_3d', 'joints_3d_visible',
412
+ # 'dataset', 'bbox_score', 'bbox_id', 'ann_info', 'image', 'flipped', 'label',
413
+ # 'target_weight'])
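A quick worked example of the `_xywh2cs` conversion defined above, assuming `image_size = [192, 256]` (a common top-down input size) and `test_mode=True` so the random center shift is skipped; the bbox values are illustrative:

    # Mirrors MultiPoseDatasetDev._xywh2cs for an illustrative bbox.
    x, y, w, h = 10.0, 20.0, 100.0, 200.0
    aspect_ratio = 192 / 256                      # 0.75
    center = [x + w * 0.5, y + h * 0.5]           # [60.0, 120.0]
    # w (100) < aspect_ratio * h (150), so the width is padded to match the aspect ratio
    w = h * aspect_ratio                          # 150.0
    scale = [w / 200.0 * 1.25, h / 200.0 * 1.25]  # [0.9375, 1.25] (pixel std 200, 1.25x context)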
core/data/datasets/images/parsing_dataset.py ADDED
@@ -0,0 +1,1084 @@
1
+ import os
2
+ import os.path as osp
3
+ import cv2
4
+ import torch
5
+ import io
6
+ import numpy as np
7
+ import itertools
8
+ from typing import Any, Dict, List, Tuple, Union
9
+ from torch.utils import data
10
+ from torch.nn import functional as F
11
+ from PIL import Image
12
+ from core.utils import cv2_loader, pil_loader
13
+ from core.data.datasets.images.seg_dataset_dev import Instances, BitMasks
14
+ import random
+ import hashlib  # used by get_unk_mask_indices to seed deterministically from the image
15
+ import torch.distributed as dist
16
+
17
+ import core.data.transforms.parsing_transforms as T
18
+ from core.data.transforms.pose_transforms import DataContainer
19
+
20
+ try:
21
+ from petrel_client.client import Client as Client
22
+
23
+ s3client = Client(boto=True,
24
+ enable_multi_cluster=True,
25
+ enable_mc=True)
26
+ except:
27
+ print("ceph can not be used")
28
+
29
+ palette_dict = {
30
+ 'human3m6_parsing': np.array(
31
+ [[0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51], [255, 85, 0], [0, 0, 85], [0, 119, 221],
32
+ [85, 85, 0], [0, 85, 85],
33
+ [85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 255], [0, 255, 0],
34
+ [51, 170, 221], [0, 255, 255], [255, 170, 85], [85, 255, 170], [170, 85, 52],
35
+ [170, 255, 85], [255, 255, 0], [255, 170, 0], [255, 0, 170], [170, 0, 255]]),
36
+ 'LIP_parsing': np.array([[0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51],
37
+ [255, 85, 0], [0, 0, 85], [0, 119, 221], [85, 85,
38
+ 0], [0, 85, 85],
39
+ [85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 255],
40
+ [51, 170, 221], [0, 255, 255], [85, 255, 170], [170, 255, 85],
41
+ [255, 255, 0], [255, 170, 0]]),
42
+ 'CIHP_parsing': np.array([[0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51],
43
+ [255, 85, 0], [0, 0, 85], [0, 119, 221], [85, 85,
44
+ 0], [0, 85, 85],
45
+ [85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 255],
46
+ [51, 170, 221], [0, 255, 255], [85, 255, 170], [170, 255, 85],
47
+ [255, 255, 0], [255, 170, 0]]),
48
+ 'ATR_parsing': np.array([[0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51],
49
+ [255, 85, 0], [0, 0, 85], [0, 119, 221], [85, 85,
50
+ 0], [0, 85, 85],
51
+ [85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 255],
52
+ [51, 170, 221], [0, 255, 255], [85, 255, 170], [170, 255, 85]]),
53
+
54
+ }
55
+
56
+ def get_unk_mask_indices(image,num_labels,known_labels,epoch=1,testing=False,):
57
+ if testing:
58
+ # for consistency across epochs and experiments, seed using hashed image array
59
+ random.seed(hashlib.sha1(np.array(image)).hexdigest())
60
+ unk_mask_indices = random.sample(range(num_labels), (num_labels-int(known_labels)))
61
+ else:
62
+ # sample random number of known labels during training
63
+ if known_labels>0:
64
+ random.seed()
65
+ num_known = random.randint(0,int(num_labels*0.75))
66
+ else:
67
+ num_known = 0
68
+
69
+ unk_mask_indices = random.sample(range(num_labels), (num_labels-num_known))
70
+
71
+ return unk_mask_indices
72
+
73
+ class Human3M6ParsingDataset(data.Dataset):
74
+ task_name = 'human3m6_parsing'
75
+ left_right_pairs = np.array([[1, 6],
76
+ [2, 7],
77
+ [3, 8],
78
+ [17, 25],
79
+ [18, 26],
80
+ [19, 27],
81
+ [33, 38],
82
+ [34, 39],
83
+ [49, 56],
84
+ [50, 58]])
85
+
86
+ label_mapper = np.arange(60)
87
+
88
+ evaluate_size = (1000, 1000)
89
+
90
+ def __init__(self,
91
+ ginfo,
92
+ data_path,
93
+ dataset='train',
94
+ data_use_ratio=1,
95
+ is_train=True,
96
+ cfg=None,
97
+ **kwargs):
98
+ """human3.6m dataset for human parsing
99
+ Args:
100
+ root_dir ([str]): where dataset
101
+ dataset: train / val
102
+ cfg: yaml format config
103
+
104
+ # 0 : background
105
+ # 1 : right hip
106
+ # 2 : right knee
107
+ # 3 : right foot
108
+ # 6 : left hip
109
+ # 7 : left knee
110
+ # 8 : left foot
111
+ # 17 : left shoulder
112
+ # 18 : left elbow
113
+ # 19 : left hand
114
+ # 25 : right shoulder
115
+ # 26 : right elbow
116
+ # 27 : right hand
117
+ # 32 : crotch
118
+ # 33 : right thigh
119
+ # 34 : right calf
120
+ # 38 : left thigh
121
+ # 39 : left calf
122
+ # 43 : lower spine
123
+ # 44 : upper spine
124
+ # 46 : head
125
+ # 49 : left arm
126
+ # 50 : left forearm
127
+ # 56 : right arm
128
+ # 58 : right forearm
129
+
130
+ """
131
+ # self.task_name = 'human3m6_parsing'
132
+ self.cfg = cfg
133
+ self.dataset = dataset
134
+ self.is_train = is_train
135
+ self.data_use_ratio = data_use_ratio
136
+ self.pseudo_labels = self.cfg.get('Pseudo_labels', False)
137
+ self.stride_level = self.cfg.get('stride_level', 1)
138
+ # self.palette = palette_dict[self.task_name]
139
+ self.pseudo_labels_palette = palette_dict[self.cfg.get('Pseudo_labels_palette','human3m6_parsing')]
140
+ self.ignore2endclass = self.cfg.get('ignore2endclass', False)
141
+
142
+ self.img_list, self.label_list, self.name_list = self._list_dirs(data_path)
143
+
144
+ index = np.arange(0, len(self.img_list))
145
+ random.shuffle(index)
146
+ self.img_list = np.array(self.img_list)
147
+ self.label_list = np.array(self.label_list)
148
+ self.name_list = np.array(self.name_list)
149
+
150
+ self.img_list = self.img_list[index].tolist()
151
+ self.label_list = self.label_list[index].tolist()
152
+ self.name_list = self.name_list[index].tolist()
153
+
154
+ self.images = self.img_list
155
+ self.labels = self.label_list
156
+ self.ignore_label = cfg.ignore_value
157
+ self.num = len(self.images)
158
+ self.num_classes = len(self.cfg.label_list) # - 1
159
+ assert self.num_classes == self.cfg.num_classes, f"num of class mismatch, len(label_list)={self.num_classes}, num_classes:{self.cfg.num_classes}"
160
+
161
+ self.rank = dist.get_rank()
162
+ self.world_size = dist.get_world_size()
163
+
164
+ self.original_label = np.array(self.cfg.label_list)
165
+
166
+ for i, l in enumerate(self.original_label):
167
+ self.label_mapper[l] = i
168
+ self.mapped_left_right_pairs = self.label_mapper[self.left_right_pairs] if self.left_right_pairs is not None else None
169
+
170
+ if self.is_train:
171
+ augs = T.compose([T.hflip(cfg.get("is_flip", False), self.mapped_left_right_pairs),
172
+ T.resize_image(cfg.crop_size),
173
+ T.multi_scale(cfg.get("is_multi_scale", False), scale_factor=cfg.get("scale_factor", 11),
174
+ center_crop_test=cfg.get("center_crop_test", False),
175
+ base_size=cfg.base_size,
176
+ crop_size=cfg.crop_size,
177
+ ignore_label=cfg.get("ignore_value", 255)),
178
+ T.rotate(cfg.get("is_rotate", False), degree=cfg.get("degree", 30),
179
+ p=cfg.get("possibility", 0.6), pad_val=cfg.get("pad_val", 0),
180
+ seg_pad_val=cfg.get("ignore_value", 255)),
181
+ T.PhotoMetricDistortion(cfg.get('is_photometricdistortion', False),
182
+ brightness_delta=cfg.get('brightness', 32),
183
+ contrast_range=cfg.get('contrast_range', [0.5, 1.5]),
184
+ saturation_range=cfg.get("saturation_range", [0.5, 1.5]),
185
+ hue_delta=cfg.get('hue', 18)
186
+ ),
187
+ T.transpose()])
188
+ else:
189
+ augs = T.compose([T.resize_image_eval(cfg.eval_crop_size),
190
+ T.transpose()])
191
+ self.augs = augs
192
+
193
+ self.initialized = False
194
+ self.use_ceph = True
195
+
196
+ def __len__(self):
197
+ return len(self.img_list)
198
+
199
+ def _ignore_to_endclass(self, label):
200
+ label[label==self.ignore_label] = self.num_classes
201
+ return label
202
+
203
+ def _read_one(self, index=None):
204
+ if index == None:
205
+ index = np.random.randint(self.num)
206
+
207
+ filename = self.img_list[index]
208
+ try:
209
+ img = Image.open(filename).convert('RGB')
210
+ img = np.array(img)[:,:,::-1]
211
+
212
+ except:
213
+ outputName = "failed_to_read_in_train.txt"
214
+ with open(outputName, "a") as g:
215
+ g.write("%s\n" % (filename))
216
+ print('Read image[{}] failed ({})'.format(index, filename))
217
+ ## if fail then recursive call _read_one without idx
218
+ return self._read_one()
219
+
220
+ gt_label = self.label_list[index]
221
+
222
+ try:
223
+ label = np.array(Image.open(gt_label))
224
+ except:
225
+ outputName = "failed_to_read_in_train_labels.txt"
226
+ with open(outputName, "a") as g:
227
+ g.write("%s\n" % (filename))
228
+ print('Read image[{}] failed ({})'.format(index, gt_label))
229
+ ## if fail then recursive call _read_one without idx
230
+ return self._read_one()
231
+
232
+ return img, label
233
+
234
+ def __getitem__(self, index):
235
+
236
+ dataset_dict = {}
237
+ dataset_dict["filename"] = self.name_list[index]
238
+
239
+ image, parsing_seg_gt = self._read_one(index)
240
+
241
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
242
+
243
+ self._record_image_size(dataset_dict, image)
244
+
245
+ if self.pseudo_labels:
246
+ image, parsing_seg_gt = self.augs(image, parsing_seg_gt)
247
+ image = torch.as_tensor(np.ascontiguousarray(image))
248
+ parsing_seg_gt = torch.as_tensor(np.ascontiguousarray(parsing_seg_gt))
249
+ dataset_dict["image"] = image
250
+ dataset_dict['PL_gt'] = parsing_seg_gt
251
+ return dataset_dict
252
+
253
+ parsing_seg_gt = self._encode_label(parsing_seg_gt) # - 1 no need to filter background in human parsing
254
+
255
+ size = parsing_seg_gt.size
256
+
257
+ if not self.is_train:
258
+ if len(self.evaluate_size) == 2:
259
+ dataset_dict["gt"] = np.copy(
260
+ cv2.resize(parsing_seg_gt, self.evaluate_size, interpolation=cv2.INTER_LINEAR_EXACT).astype(np.int_))
261
+ else:
262
+ # use DataContainer type to avoid being batched as tensors
263
+ dataset_dict["gt"] = DataContainer(np.copy(parsing_seg_gt.astype(np.int_)))
264
+
265
+ parsing_seg_gt = parsing_seg_gt.astype("double")
266
+ assert len(parsing_seg_gt), "parsing needs gt to train"
267
+ image, parsing_seg_gt = self.augs(image, parsing_seg_gt)
268
+ if self.stride_level>1:
269
+ temp_h, temp_w = parsing_seg_gt.shape
270
+ parsing_seg_gt = np.asarray(Image.fromarray(parsing_seg_gt).convert('P').resize((int(temp_h/self.stride_level), int(temp_w/self.stride_level))))
271
+ if self.ignore2endclass:
272
+ parsing_seg_gt = self._ignore_to_endclass(parsing_seg_gt)
273
+ image = torch.as_tensor(np.ascontiguousarray(image))
274
+ parsing_seg_gt = torch.as_tensor(parsing_seg_gt.astype("long"))
275
+
276
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
277
+
278
+ dataset_dict["image"] = image
279
+ if not self.is_train:
280
+ if self.cfg.get('label_mask', False):
281
+ m = torch.ones(self.num_classes,dtype=torch.int64)*-1 # mask all labels
282
+ dataset_dict['mask'] = m
283
+ return dataset_dict
284
+
285
+ dataset_dict["label"] = parsing_seg_gt.long() # not used in test
286
+
287
+ # Prepare per-category binary masks
288
+ parsing_seg_gt = parsing_seg_gt.numpy()
289
+ instances = Instances(image_shape)
290
+ classes = np.unique(parsing_seg_gt)
291
+ # remove ignored region
292
+ if self.cfg.get('add_zero_mask',False):
293
+ classes = np.array(list(range(self.num_classes)))
294
+ classes = classes[classes != self.ignore_label]
295
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
296
+
297
+ if self.cfg.get('label_mask', False):
298
+ m = np.zeros(self.num_classes)
299
+ m[classes] = 1
300
+ mask = torch.tensor(m, dtype=torch.int64).clone()
301
+ unk_mask_indices = get_unk_mask_indices(image, self.num_classes, known_labels=100,) # set known_labels>1 to use label masking training
302
+ mask.scatter_(0, torch.Tensor(unk_mask_indices).long(), -1)
303
+ dataset_dict['mask'] = mask
304
+
305
+ masks = []
306
+ for class_id in classes:
307
+ masks.append(parsing_seg_gt == class_id)
308
+
309
+ if len(masks) == 0:
310
+ # Some image does not have annotation (all ignored)
311
+ instances.gt_masks = torch.zeros((0, parsing_seg_gt.shape[-2], parsing_seg_gt.shape[-1]))
312
+ else:
313
+ masks = BitMasks(
314
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
315
+ )
316
+ instances.gt_masks = masks.tensor
317
+
318
+ dataset_dict["instances"] = instances # not used in test
319
+
320
+ return dataset_dict # {'image': img_mask, 'label': target_mask, 'instances': xxx, 'filename': img_name}
321
+
322
+ @staticmethod
323
+ def _record_image_size(dataset_dict, image):
324
+ """
325
+ Record the original image width/height in the dataset dict if they are not already set.
326
+ """
327
+ # To ensure bbox always remap to original image size
328
+ if "width" not in dataset_dict:
329
+ dataset_dict["width"] = image.shape[1]
330
+ if "height" not in dataset_dict:
331
+ dataset_dict["height"] = image.shape[0]
332
+
333
+ def _list_dirs(self, data_path):
334
+ img_list = list()
335
+ label_list = list()
336
+ name_list = list()
337
+
338
+ if self.dataset == 'train':
339
+ train_type = 'train'
340
+ elif self.dataset == 'val':
341
+ train_type = 'eval'
342
+ list_txt = osp.join(data_path, f'flist_2hz_{train_type}.txt')
343
+
344
+ # with open(list_txt, 'r') as f:
345
+ # data = f.readlines()
346
+ # data = [d.strip() for d in data]
347
+ list_lines = s3client.Get(list_txt)
348
+ if not list_lines:
349
+ print('File not exist', list_txt)
350
+ raise IOError('File not exist', list_txt)
353
+ list_lines = list_lines.decode('ascii')
354
+ data = list_lines.split('\n')
355
+ data = [d for d in data if len(d)]
356
+
357
+
358
+ if self.data_use_ratio != 1:
359
+ data = random.sample(data, int(len(data) * self.data_use_ratio))
360
+
361
+ for d in data:
362
+ img_path = osp.join(data_path, d)
363
+ image_name = '/'.join(d.split('/')[2:])
364
+ label_path = img_path.replace('rgb', 'seg', 1)
365
+
366
+ img_list.append(img_path)
367
+ label_list.append(label_path)
368
+ name_list.append(image_name)
369
+
370
+ return img_list, label_list, name_list
371
+
372
+ def _encode_label(self, labelmap):
373
+ shape = labelmap.shape
374
+ encoded_labelmap = np.zeros(shape=(shape[0], shape[1]), dtype=np.uint8)
375
+ for i, class_id in enumerate(self.cfg.label_list):
376
+ encoded_labelmap[labelmap == class_id] = i
377
+
378
+ return encoded_labelmap
379
+
380
+ def __repr__(self):
381
+ return self.__class__.__name__ + \
382
+ f'rank: {self.rank} task: {self.task_name} mode:{"training" if self.is_train else "inference"} ' \
383
+ f'dataset_len:{len(self.img_list)} id_num:{self.cfg["num_classes"]} augmentation: {self.augs}'
384
+
385
+ class LIPParsingDataset(Human3M6ParsingDataset):
386
+ """
387
+ 0:'background',
388
+ 1:'hat',
389
+ 2:'hair',
390
+ 3:'glove',
391
+ 4:'sunglasses',
392
+ 5:'upperclothes',
393
+ 6:'dress',
394
+ 7:'coat',
395
+ 8:'socks',
396
+ 9:'pants',
397
+ 10:'jumpsuits',
398
+ 11:'scarf',
399
+ 12:'skirt',
400
+ 13:'face',
401
+ 14:'leftArm',
402
+ 15:'rightArm',
403
+ 16:'leftLeg',
404
+ 17:'rightLeg',
405
+ 18:'leftShoe',
406
+ 19:'rightShoe'
407
+ """
408
+ task_name = 'LIP_parsing'
409
+
410
+ left_right_pairs = np.array(
411
+ [[14, 15], [16, 17], [18, 19]]
412
+ )
413
+
414
+ label_mapper = np.arange(60)
415
+
416
+ evaluate_size = ()
417
+
418
+ def __init__(self,
419
+ ginfo,
420
+ data_path,
421
+ dataset='train',
422
+ data_use_ratio=1,
423
+ is_train=True,
424
+ cfg=None,
425
+ **kwargs):
426
+ super(LIPParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,
427
+ data_use_ratio=data_use_ratio,
428
+ dataset=dataset, is_train=is_train,
429
+ cfg=cfg, **kwargs)
430
+
431
+ def _list_dirs(self, data_path):
432
+ img_list = list()
433
+ label_list = list()
434
+ name_list = list()
435
+
436
+ if self.dataset == 'train':
437
+ train_type = 'train'
438
+ elif self.dataset == 'val':
439
+ train_type = 'val'
440
+ """
441
+ - LIP
442
+ -data
443
+ -train_id.txt
444
+ -train_images
445
+ -1000_1234574.jpg
446
+ -val_images
447
+ -val_id.txt
448
+ -Trainval_parsing_annotations
449
+ -train_segmentations
450
+ -1000_1234574.png
451
+ """
452
+ list_txt = osp.join(data_path, 'data', f'{train_type}_id.txt')
453
+
454
+ with open(list_txt, 'r') as f:
455
+ data = f.readlines()
456
+ data = [d.strip() for d in data]
457
+
458
+
459
+ if self.data_use_ratio != 1:
460
+ data = random.sample(data, int(len(data) * self.data_use_ratio))
461
+
462
+ postfix_img = '.jpg'
463
+ postfix_ann = '.png'
464
+ for d in data:
465
+ img_path = osp.join(data_path, f'data/{train_type}_images', d + postfix_img)
466
+ image_name = d
467
+ label_path = osp.join(data_path, f'TrainVal_parsing_annotations/{train_type}_segmentations',
468
+ d + postfix_ann)
469
+
470
+ img_list.append(img_path)
471
+ label_list.append(label_path)
472
+ name_list.append(image_name)
473
+
474
+ return img_list, label_list, name_list
475
+
476
+ class CIHPParsingDataset(Human3M6ParsingDataset):
477
+ """
478
+ 0:'background',
479
+ 1:'hat',
480
+ 2:'hair',
481
+ 3:'glove',
482
+ 4:'sunglasses',
483
+ 5:'upperclothes',
484
+ 6:'dress',
485
+ 7:'coat',
486
+ 8:'socks',
487
+ 9:'pants',
488
+ 10:'torsoSkin',
489
+ 11:'scarf',
490
+ 12:'skirt',
491
+ 13:'face',
492
+ 14:'leftArm',
493
+ 15:'rightArm',
494
+ 16:'leftLeg',
495
+ 17:'rightLeg',
496
+ 18:'leftShoe',
497
+ 19:'rightShoe'
498
+ """
499
+ task_name = 'CIHP_parsing'
500
+
501
+ left_right_pairs = np.array(
502
+ [[14, 15], [16, 17], [18, 19]]
503
+ )
504
+
505
+ label_mapper = np.arange(60)
506
+
507
+ evaluate_size = ()
508
+
509
+ def __init__(self,
510
+ ginfo,
511
+ data_path,
512
+ dataset='train',
513
+ data_use_ratio=1,
514
+ is_train=True,
515
+ cfg=None,
516
+ **kwargs):
517
+ super(CIHPParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,data_use_ratio=data_use_ratio,
518
+ dataset=dataset, is_train=is_train,
519
+ cfg=cfg, **kwargs)
520
+
521
+ def _list_dirs(self, data_path):
522
+ img_list = list()
523
+ label_list = list()
524
+ name_list = list()
525
+
526
+ if self.dataset == 'train':
527
+ train_type = 'train'
528
+ elif self.dataset == 'val':
529
+ train_type = 'val'
530
+ """
531
+ - CHIP
532
+ -instance-level_human_parsing
533
+ -Training
534
+ -Images
535
+ -0008522.jpg
536
+ -Category_ids
537
+ -0008522.png
538
+ -train_id.txt
539
+ -Validation
540
+ -val_id.txt
541
+ """
542
+ Infix = 'Training' if train_type == 'train' else 'Validation'
543
+ list_txt = osp.join(data_path, 'instance-level_human_parsing', Infix, f'{train_type}_id.txt')
544
+
545
+ with open(list_txt, 'r') as f:
546
+ data = f.readlines()
547
+ data = [d.strip() for d in data]
548
+
549
+ if self.data_use_ratio != 1:
550
+ data = random.sample(data, int(len(data) * self.data_use_ratio))
551
+
552
+ postfix_img = '.jpg'
553
+ postfix_ann = '.png'
554
+ for d in data:
555
+ img_path = osp.join(data_path, 'instance-level_human_parsing', Infix, f'Images', d + postfix_img)
556
+ image_name = d
557
+ label_path = osp.join(data_path, 'instance-level_human_parsing', Infix, 'Category_ids', d + postfix_ann)
558
+
559
+ img_list.append(img_path)
560
+ label_list.append(label_path)
561
+ name_list.append(image_name)
562
+
563
+ return img_list, label_list, name_list
564
+
565
+
566
+ class ATRParsingDataset(Human3M6ParsingDataset):
567
+ """
568
+ 0:'background', #
569
+ 1:'hat', #
570
+ 2:'hair',#
571
+ 3:'sunglasses',#
572
+ 4:'upperclothes',#
573
+ 5:'skirt',
574
+ 6:'pants',#
575
+ 7:'dress',#
576
+ 8:'belt',
577
+ 9:'leftshoe',#
578
+ 10:'rightshoe',#
579
+ 11:'face',#
580
+ 12:'leftleg',#
581
+ 13:'rightleg',#
582
+ 14:'leftarm',#
583
+ 15:'rightarm',#
584
+ 16:'bag',#
585
+ 17:'scarf',#
586
+ """
587
+ task_name = 'ATR_parsing'
588
+
589
+ left_right_pairs = np.array(
590
+ [[9,10], [12,13], [14,15]]
591
+ )
592
+
593
+ label_mapper = np.arange(60)
594
+
595
+ evaluate_size = ()
596
+
597
+ def __init__(self,
598
+ ginfo,
599
+ data_path,
600
+ dataset='train',
601
+ data_use_ratio=1,
602
+ is_train=True,
603
+ cfg=None,
604
+ **kwargs):
605
+ super(ATRParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,
606
+ data_use_ratio=data_use_ratio,
607
+ dataset=dataset, is_train=is_train,
608
+ cfg=cfg, **kwargs)
609
+
610
+ def _list_dirs(self, data_path):
611
+ img_list = list()
612
+ label_list = list()
613
+ name_list = list()
614
+
615
+ if self.dataset == 'train':
616
+ train_type = 'train'
617
+ elif self.dataset == 'val':
618
+ train_type = 'val'
619
+ """
620
+ - ATR
621
+ -humanparsing
622
+ -JPEGImages
623
+ -SegmentationClassAug
624
+ -train_id.txt
625
+ -val_id.txt
626
+ """
627
+ list_txt = osp.join(data_path, f'{train_type}_id.txt')
628
+ with open(list_txt, 'r') as f:
629
+ data = f.readlines()
630
+ data = [d.strip() for d in data]
631
+
632
+ if self.data_use_ratio != 1:
633
+ data = random.sample(data, int(len(data) * self.data_use_ratio))
634
+
635
+ postfix_img = '.jpg'
636
+ postfix_ann = '.png'
637
+ for d in data:
638
+ img_path = osp.join(data_path, f'humanparsing/JPEGImages', d + postfix_img)
639
+ image_name = d
640
+ label_path = osp.join(data_path, f'humanparsing/SegmentationClassAug',
641
+ d + postfix_ann)
642
+
643
+ img_list.append(img_path)
644
+ label_list.append(label_path)
645
+ name_list.append(image_name)
646
+
647
+ return img_list, label_list, name_list
648
+
649
+
650
+ class DeepFashionParsingDataset(Human3M6ParsingDataset):
651
+ """
652
+ 0:'background', #
653
+ 1:'hat', #
654
+ 2:'hair',#
655
+ 3:'sunglasses',#
656
+ 4:'upperclothes',#
657
+ 5:'skirt',
658
+ 6:'pants',#
659
+ 7:'dress',#
660
+ 8:'belt',
661
+ 9:'leftshoe',#
662
+ 10:'rightshoe',#
663
+ 11:'face',#
664
+ 12:'leftleg',#
665
+ 13:'rightleg',#
666
+ 14:'leftarm',#
667
+ 15:'rightarm',#
668
+ 16:'bag',#
669
+ 17:'scarf',#
670
+ """
671
+ task_name = 'DeepFashion_parsing'
672
+ label_mapper = np.arange(60)
673
+ left_right_pairs = None
674
+ evaluate_size = ()
675
+
676
+ def __init__(self,
677
+ ginfo,
678
+ data_path,
679
+ dataset='train',
680
+ data_use_ratio=1,
681
+ is_train=True,
682
+ cfg=None,
683
+ **kwargs):
684
+ super(DeepFashionParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,
685
+ data_use_ratio=data_use_ratio,
686
+ dataset=dataset, is_train=is_train,
687
+ cfg=cfg, **kwargs)
688
+
689
+ def _list_dirs(self, data_path):
690
+ img_list = list()
691
+ label_list = list()
692
+ name_list = list()
693
+
694
+ if self.dataset == 'train':
695
+ train_type = 'train'
696
+ elif self.dataset == 'val':
697
+ train_type = 'val'
698
+ """
699
+ - DeepFashion
700
+ -humanparsing
701
+ -JPEGImages
702
+ -SegmentationClassAug
703
+ -train_id.txt
704
+ -val_id.txt
705
+ """
706
+ list_txt = osp.join(data_path, f'{train_type}_id.txt')
707
+
708
+ with open(list_txt, 'r') as f:
709
+ data = f.readlines()
710
+ data = [d.strip() for d in data]
711
+
712
+
713
+ if self.data_use_ratio != 1:
714
+ data = random.sample(data, int(len(data) * self.data_use_ratio))
715
+
716
+ postfix_img = '.jpg'
717
+ postfix_ann = '.png'
718
+
719
+ if train_type == 'train':
720
+ for d in data:
721
+ img_path = osp.join(data_path, f'train/image', d + postfix_img)
722
+ image_name = d
723
+ label_path = osp.join(data_path, f'train/seg',
724
+ d + postfix_ann)
725
+
726
+ img_list.append(img_path)
727
+ label_list.append(label_path)
728
+ name_list.append(image_name)
729
+
730
+ return img_list, label_list, name_list
731
+ else:
732
+ raise ValueError("not implement")
733
+
734
+
735
+ class VIPParsingDataset(Human3M6ParsingDataset):
736
+ task_name = 'VIP_parsing'
737
+ left_right_pairs = np.array(
738
+ [[14, 15], [16, 17], [18, 19]]
739
+ )
740
+
741
+ label_mapper = np.arange(60)
742
+
743
+ evaluate_size = ()
744
+
745
+ def __init__(self,
746
+ ginfo,
747
+ data_path,
748
+ dataset='train',
749
+ data_use_ratio=1,
750
+ is_train=True,
751
+ cfg=None,
752
+ **kwargs):
753
+ super(VIPParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,data_use_ratio=data_use_ratio,
754
+ dataset=dataset, is_train=is_train,
755
+ cfg=cfg, **kwargs)
756
+
757
+ def _list_dirs(self, data_path):
758
+ img_list = list()
759
+ label_list = list()
760
+ name_list = list()
761
+
762
+ if self.dataset == 'train':
763
+ train_type = 'train'
764
+ elif self.dataset == 'val':
765
+ train_type = 'val'
766
+
767
+ list_txt = osp.join(data_path, f'{train_type}_id.txt')
768
+
769
+ with open(list_txt, 'r') as f:
770
+ data = f.readlines()
771
+ data = [d.strip() for d in data]
772
+
773
+
774
+ postfix_img = '.jpg'
775
+ postfix_ann = '.png'
776
+
777
+ if self.data_use_ratio != 1:
778
+ data = random.sample(data, int(len(data) * self.data_use_ratio))
779
+
780
+ if train_type == 'train':
781
+ for d in data:
782
+ img_path = osp.join(data_path, f'Images', d + postfix_img)
783
+ image_name = d
784
+ label_path = osp.join(data_path, f'Annotations/Category_ids',
785
+ d + postfix_ann)
786
+
787
+ img_list.append(img_path)
788
+ label_list.append(label_path)
789
+ name_list.append(image_name)
790
+
791
+ return img_list, label_list, name_list
792
+ else:
793
+ raise ValueError("not implement")
794
+
795
+ class PaperDollParsingDataset(Human3M6ParsingDataset):
796
+ """
797
+ 0:'background',
798
+ 1:'hat',
799
+ 2:'hair',
800
+ 3:'glove',
801
+ 4:'sunglasses',
802
+ 5:'upperclothes',
803
+ 6:'dress',
804
+ 7:'coat',
805
+ 8:'socks',
806
+ 9:'pants',
807
+ 10:'torsoSkin',
808
+ 11:'scarf',
809
+ 12:'skirt',
810
+ 13:'face',
811
+ 14:'leftArm',
812
+ 15:'rightArm',
813
+ 16:'leftLeg',
814
+ 17:'rightLeg',
815
+ 18:'leftShoe',
816
+ 19:'rightShoe'
817
+ """
818
+ task_name = 'PaperDoll_parsing'
819
+
820
+ left_right_pairs = np.array(
821
+ [[14, 15], [16, 17], [18, 19]]
822
+ )
823
+
824
+ label_mapper = np.arange(60)
825
+
826
+ evaluate_size = ()
827
+
828
+ def __init__(self,
829
+ ginfo,
830
+ data_path,
831
+ data_use_ratio=1,
832
+ dataset='train',
833
+ is_train=True,
834
+ cfg=None,
835
+ **kwargs):
836
+ super(PaperDollParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path, data_use_ratio=data_use_ratio,
837
+ dataset=dataset, is_train=is_train,
838
+ cfg=cfg, **kwargs)
839
+
840
+ def _list_dirs(self, data_path):
841
+ img_list = list()
842
+ label_list = list()
843
+ name_list = list()
844
+
845
+ if self.dataset == 'train':
846
+ train_type = 'train'
847
+ elif self.dataset == 'val':
848
+ train_type = 'val'
849
+ """
850
+ - PaperDoll_folder
851
+ - TrainVal_parsing_annotations/
852
+ - 0000000.png
853
+ - images
854
+ - 0000000.jpg
855
+ """
856
+ list_txt = osp.join(data_path, f'{train_type}_id.txt')
857
+
858
+ with open(list_txt, 'r') as f:
859
+ data = f.readlines()
860
+ data = [d.strip() for d in data]
861
+
862
+
863
+ postfix_img = '.jpg'
864
+ postfix_ann = '.png'
865
+
866
+ if self.data_use_ratio != 1:
867
+ data = random.sample(data, int(len(data) * self.data_use_ratio))
868
+
869
+ for d in data:
870
+ img_path = osp.join(data_path, 'images', d + postfix_img)
871
+ image_name = d
872
+ label_path = osp.join(data_path, 'TrainVal_parsing_annotations/', d + postfix_ann)
873
+
874
+ img_list.append(img_path)
875
+ label_list.append(label_path)
876
+ name_list.append(image_name)
877
+
878
+ return img_list, label_list, name_list
879
+
880
+
881
+ class FashionPediaParsingDataset(Human3M6ParsingDataset):
882
+ task_name = 'FashionPedia_parsing'
883
+
884
+ label_mapper = np.arange(60)
885
+
886
+ evaluate_size = ()
887
+
888
+ def __init__(self,
889
+ ginfo,
890
+ data_path,
891
+ data_use_ratio=1,
892
+ dataset='train',
893
+ is_train=True,
894
+ cfg=None,
895
+ **kwargs):
896
+ super(FashionPediaParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,
897
+ data_use_ratio=data_use_ratio,
898
+ dataset=dataset, is_train=is_train,
899
+ cfg=cfg, **kwargs)
900
+
901
+ def _list_dirs(self, data_path):
902
+ img_list = list()
903
+ label_list = list()
904
+ name_list = list()
905
+
906
+ if self.dataset == 'train':
907
+ train_type = 'train'
908
+ elif self.dataset == 'val':
909
+ train_type = 'val'
910
+
911
+ list_txt = osp.join(data_path, f'{train_type}_id.txt')
912
+
913
+ with open(list_txt, 'r') as f:
914
+ data = f.readlines()
915
+ data = [d.strip() for d in data]
916
+
917
+
918
+ postfix_img = '.jpg'
919
+ postfix_ann = '.png'
920
+
921
+ if self.data_use_ratio != 1:
922
+ data = random.sample(data, int(len(data) * self.data_use_ratio))
923
+
924
+ if train_type == 'train':
925
+ for d in data:
926
+ img_path = osp.join(data_path, 'train/', d + postfix_img)
927
+ image_name = d
928
+ label_path = osp.join(data_path, 'train_annotation/', d + postfix_ann)
929
+
930
+ img_list.append(img_path)
931
+ label_list.append(label_path)
932
+ name_list.append(image_name)
933
+
934
+ return img_list, label_list, name_list
935
+
936
+ elif train_type == 'val':
937
+ for d in data:
938
+ img_path = osp.join(data_path, 'test/', d + postfix_img)
939
+ image_name = d
940
+ label_path = osp.join(data_path, 'test_annotation/', d + postfix_ann)
941
+
942
+ img_list.append(img_path)
943
+ label_list.append(label_path)
944
+ name_list.append(image_name)
945
+
946
+ return img_list, label_list, name_list
947
+ else:
948
+ raise
949
+
950
+
951
+ class ModaNetParsingDataset(Human3M6ParsingDataset):
952
+ """
953
+ modanet_par = {
954
+ 0: 'Background',
955
+ 1: 'Bag',
956
+ 2: 'Belt',
957
+ 3: 'Boots',
958
+ 4: 'Footwear',
959
+ 5: 'Outer',
960
+ 6: 'Dress',
961
+ 7: 'Sunglasses',
962
+ 8: 'Pants',
963
+ 9: 'Top',
964
+ 10: 'Shorts',
965
+ 11: 'Skirt',
966
+ 12: 'Headwear',
967
+ 13: 'Scarf & Tie'
968
+ }
969
+ """
970
+
971
+ task_name = 'ModaNet_parsing'
972
+ label_mapper = np.arange(60)
973
+ left_right_pairs = None
974
+ evaluate_size = ()
975
+
976
+ def __init__(self,
977
+ ginfo,
978
+ data_path,
979
+ dataset='train',
980
+ data_use_ratio=1,
981
+ is_train=True,
982
+ cfg=None,
983
+ **kwargs):
984
+ super(ModaNetParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path,data_use_ratio=data_use_ratio,
985
+ dataset=dataset, is_train=is_train,
986
+ cfg=cfg, **kwargs)
987
+
988
+ def _list_dirs(self, data_path):
989
+ img_list = list()
990
+ label_list = list()
991
+ name_list = list()
992
+ # image_dir = osp.join(data_path, 'protocal_1', 'rgb')
993
+ # label_dir = osp.join(data_path, 'protocal_1', 'seg')
994
+
995
+ if self.dataset == 'train':
996
+ train_type = 'train'
997
+ elif self.dataset == 'val':
998
+ train_type = 'val'
999
+
1000
+ list_txt = osp.join(data_path, f'{train_type}_id.txt')
1001
+
1002
+ with open(list_txt, 'r') as f:
1003
+ data = f.readlines()
1004
+ data = [d.strip() for d in data]
1005
+
1006
+
1007
+ if self.data_use_ratio != 1:
1008
+ data = random.sample(data, int(len(data) * self.data_use_ratio))
1009
+
1010
+ postfix_img = '.jpg'
1011
+ postfix_ann = '.png'
1012
+
1013
+ if train_type == 'train':
1014
+ for d in data:
1015
+ img_path = osp.join(data_path, f'images', d + postfix_img)
1016
+ image_name = d
1017
+ label_path = osp.join(data_path, f'seg',
1018
+ d + postfix_ann)
1019
+
1020
+ img_list.append(img_path)
1021
+ label_list.append(label_path)
1022
+ name_list.append(image_name)
1023
+
1024
+ return img_list, label_list, name_list
1025
+ else:
1026
+ raise ValueError("not implement")
1027
+
1028
+
1029
+ class MHPParsingDataset(Human3M6ParsingDataset):
1030
+ task_name = 'MHP_parsing'
1031
+ label_mapper = np.arange(60)
1032
+ left_right_pairs = None
1033
+ evaluate_size = ()
1034
+
1035
+ def __init__(self,
1036
+ ginfo,
1037
+ data_path,
1038
+ dataset='train',
1039
+ data_use_ratio=1,
1040
+ is_train=True,
1041
+ cfg=None,
1042
+ **kwargs):
1043
+ super(MHPParsingDataset, self).__init__(ginfo=ginfo, data_path=data_path, data_use_ratio=data_use_ratio,
1044
+ dataset=dataset, is_train=is_train,
1045
+ cfg=cfg, **kwargs)
1046
+
1047
+ def _list_dirs(self, data_path):
1048
+ img_list = list()
1049
+ label_list = list()
1050
+ name_list = list()
1051
+
1052
+ if self.dataset == 'train':
1053
+ train_type = 'train'
1054
+ elif self.dataset == 'val':
1055
+ train_type = 'val'
1056
+
1057
+ list_txt = osp.join(data_path, f'{train_type}_id.txt')
1058
+
1059
+ with open(list_txt, 'r') as f:
1060
+ data = f.readlines()
1061
+ data = [d.strip() for d in data]
1062
+
1063
+
1064
+ if self.data_use_ratio != 1:
1065
+ data = random.sample(data, int(len(data) * self.data_use_ratio))
1066
+
1067
+ postfix_img = '.jpg'
1068
+ postfix_ann = '.png'
1069
+
1070
+ if train_type == 'train':
1071
+ for d in data:
1072
+ img_path = osp.join(data_path, f'images/', d + postfix_img)
1073
+ image_name = d
1074
+ label_path = osp.join(data_path, f'processed_label/',
1075
+ d + postfix_ann)
1076
+
1077
+ img_list.append(img_path)
1078
+ label_list.append(label_path)
1079
+ name_list.append(image_name)
1080
+
1081
+ return img_list, label_list, name_list
1082
+ else:
1083
+ raise ValueError("not implement")
1084
+
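A minimal sketch of the partial-label mask built in the `label_mask` branch of `Human3M6ParsingDataset.__getitem__` above; the class count, present classes, and number of masked entries are illustrative:

    # Mirrors the mask construction: 1 = class present, 0 = absent, -1 = masked/unknown.
    import random
    import torch

    num_classes = 6
    classes_present = [0, 2, 5]                   # classes found in this ground-truth map

    mask = torch.zeros(num_classes, dtype=torch.int64)
    mask[classes_present] = 1
    unk_mask_indices = random.sample(range(num_classes), k=3)   # stand-in for get_unk_mask_indices
    mask.scatter_(0, torch.tensor(unk_mask_indices, dtype=torch.long), -1)
    # mask now mixes {1, 0, -1} and is returned as dataset_dict['mask'].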
core/data/datasets/images/pedattr_dataset.py ADDED
@@ -0,0 +1,665 @@
1
+ import os
2
+ import time
3
+ import pickle
4
+ import random
5
+ from easydict import EasyDict as edict
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ from core.data.transforms.pedattr_transforms import PedAttrAugmentation, PedAttrTestAugmentation, PedAttrRandomAugmentation
10
+ import torch.distributed as dist
11
+
12
+
13
+ __all__ = ['AttrDataset', 'MultiAttrDataset']
14
+
15
+ def merge_pedattr_datasets(data_path_list, root_path_list, dataset_name_list, train,
16
+ data_use_ratio, text_label_return, select_data, ignore_other_attrs=True):
17
+ total_img_id = []
18
+ total_attr_num = 0
19
+ total_img_num = 0
20
+ total_attr_begin = []
21
+ total_attr_end = []
22
+ total_img_begin = []
23
+ total_img_end = []
24
+ total_text_dict = {}
25
+ attr_begin = []
26
+ attr_end = []
27
+
28
+ for data_path, root_path, dataset_name in zip(data_path_list, root_path_list, dataset_name_list):
29
+ assert dataset_name in ['peta', 'PA_100k', 'rap', 'rap2', 'uavhuman', 'HARDHC',
30
+ 'ClothingAttribute', 'parse27k', 'duke', 'market','lup_0_200w', 'lup_0_600w', 'lup_600_1200w'], \
31
+ 'dataset name {} does not exist'.format(dataset_name)
32
+
33
+
34
+ with open(data_path, 'rb') as f:
35
+ dataset_info = pickle.load(f)
36
+ dataset_info = edict(dataset_info)
37
+ img_id = dataset_info.image_name
38
+ attr_label = dataset_info.label
39
+
40
+ if train:
41
+ split = 'trainval'
42
+ else:
43
+ split = 'test'
44
+
45
+ attr_id = dataset_info.attr_name
46
+ attr_num = len(attr_id)
47
+
48
+ total_attr_begin.append(total_attr_num)
49
+ total_attr_num = total_attr_num + attr_num
50
+ total_attr_end.append(total_attr_num)
51
+
52
+ if select_data is None or dataset_name == select_data:
53
+ assert split in dataset_info.partition.keys(), f'split {split} does not exist'
54
+ img_idx = dataset_info.partition[split]
55
+
56
+ if isinstance(img_idx, list):
57
+ img_idx = img_idx[0] # default partition 0
58
+
59
+ if data_use_ratio != 1:
60
+ img_idx = random.sample(list(img_idx), int(len(img_idx) * data_use_ratio))
61
+
62
+ img_num = len(img_idx)
63
+ img_idx = np.array(img_idx)
64
+
65
+ img_id = [os.path.join(root_path, img_id[i]) for i in img_idx]
66
+ label = attr_label[img_idx]
67
+
68
+
69
+ total_img_begin.append(total_img_num)
70
+ total_img_num = total_img_num + len(img_id)
71
+ total_img_end.append(total_img_num)
72
+ else:
73
+ # When testing on a single dataset, the requested split may not exist in the other datasets, so we register
+ # a fake, empty split (0 images) for them to keep the merging code running.
+ # TODO: find a cleaner way to handle this, e.g. only loop over the selected dataset.
76
+ img_id = []
77
+ label = []
78
+ img_num = 0
79
+ total_img_begin.append(total_img_num)
80
+ total_img_num = total_img_num + len(img_id)
81
+ total_img_end.append(total_img_num)
82
+
83
+ infilling_class = -1 if ignore_other_attrs else 0
84
+ total_label = np.full((total_img_num, total_attr_num), infilling_class, dtype=np.int32)
85
+ select_attr_begin = 0
86
+ select_attr_end = total_attr_num
87
+
88
+ for index, (data_path, root_path, dataset_name) in enumerate(zip(data_path_list, root_path_list, dataset_name_list)):
89
+
90
+ assert dataset_name in ['peta', 'PA_100k', 'rap', 'rap2', 'uavhuman', 'HARDHC',
91
+ 'ClothingAttribute', 'parse27k', 'duke', 'market','lup_0_200w', 'lup_0_600w', 'lup_600_1200w'], \
92
+ 'dataset name {} does not exist'.format(dataset_name)
93
+ with open(data_path, 'rb') as f:
94
+ dataset_info = pickle.load(f)
95
+ dataset_info = edict(dataset_info)
96
+
97
+ img_id = dataset_info.image_name
98
+ attr_label = dataset_info.label
99
+
100
+ if train:
101
+ split = 'trainval'
102
+ else:
103
+ split = 'test'
104
+
105
+ if not train and dataset_name != select_data:
106
+ continue
107
+
108
+ assert split in dataset_info.partition.keys(), f'split {split} does not exist'
109
+
110
+ attr_id = dataset_info.attr_name
111
+ attr_num = len(attr_id)
112
+
113
+ img_idx = dataset_info.partition[split]
114
+
115
+ if isinstance(img_idx, list):
116
+ img_idx = img_idx[0] # default partition 0
117
+
118
+ if data_use_ratio != 1:
119
+ img_idx = random.sample(list(img_idx), int(len(img_idx) * data_use_ratio))
120
+
121
+ img_num = len(img_idx)
122
+ img_idx = np.array(img_idx)
123
+
124
+ img_id = [os.path.join(root_path, img_id[i]) for i in img_idx]
125
+ label = attr_label[img_idx]
126
+ # import pdb;pdb.set_trace()
127
+ if text_label_return:
128
+ for idx in range(attr_num):
129
+ total_text_dict[total_attr_begin[index] + idx] = eval(f"{dataset_name}_attr_name")[idx]
130
+
131
+ if not train:
132
+ if dataset_name == select_data:
133
+ total_label[total_img_begin[index]: total_img_end[index], total_attr_begin[index]: total_attr_end[index]] = label
134
+ total_img_id.extend(img_id)
135
+ attr_begin.extend([total_attr_begin[index] for i in img_idx])
136
+ attr_end.extend([total_attr_end[index] for i in img_idx])
137
+ else:
138
+ total_label[total_img_begin[index]: total_img_end[index], total_attr_begin[index]: total_attr_end[index]] = label
139
+ total_img_id.extend(img_id)
140
+ attr_begin.extend([total_attr_begin[index] for i in img_idx])
141
+ attr_end.extend([total_attr_end[index] for i in img_idx])
142
+
143
+ # import pdb;pdb.set_trace()
144
+ return total_img_id, total_label, total_text_dict, attr_begin, attr_end
145
+
146
+
147
+
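# Editor's sketch (not part of the original file): merge_pedattr_datasets() above lays the
# per-dataset labels out block-diagonally in one big (total_img_num, total_attr_num) matrix,
# filling every attribute a dataset does not annotate with -1 (or 0 when ignore_other_attrs
# is False). A toy numpy illustration of that layout with two fake datasets:
import numpy as np

_toy_sizes = [(3, 4), (2, 5)]                     # (num_images, num_attrs) per fake dataset
_total_imgs = sum(n for n, _ in _toy_sizes)
_total_attrs = sum(a for _, a in _toy_sizes)
_merged = np.full((_total_imgs, _total_attrs), -1, dtype=np.int32)
_row = _col = 0
for _n, _a in _toy_sizes:
    _merged[_row:_row + _n, _col:_col + _a] = 1   # stand-in for that dataset's real labels
    _row += _n
    _col += _a
# _merged has shape (5, 9): rows 0-2 only carry labels in columns 0-3, rows 3-4 only in
# columns 4-8, and -1 marks "attribute not annotated for this image".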
148
+ class MultiAttrDataset(data.Dataset):
149
+
150
+ def __init__(self, ginfo, augmentation, task_spec, train=True, data_use_ratio=1, text_label_return=False,
151
+ select_data=None, ignore_other_attrs=True,
152
+ **kwargs):
153
+ data_path = task_spec.data_path
154
+ root_path = task_spec.root_path
155
+ dataset_name = task_spec.dataset
156
+ # import pdb; pdb.set_trace()
157
+ self.rank = dist.get_rank()
158
+ self.train = train
159
+ self.img_id, self.label, self.text_dict, self.attr_begin, self.attr_end = \
160
+ merge_pedattr_datasets(data_path, root_path, dataset_name, train,
161
+ data_use_ratio, text_label_return, select_data, ignore_other_attrs)
162
+ height = augmentation.height
163
+ width = augmentation.width
164
+ self.img_num = len(self.img_id)
165
+
166
+ if train:
167
+ self.transform = PedAttrAugmentation(height, width)
168
+ if augmentation.get('use_random_aug', False):
169
+ self.transform = PedAttrRandomAugmentation(height, width, \
170
+ augmentation.use_random_aug.m, augmentation.use_random_aug.n)
171
+ else:
172
+ self.transform = PedAttrTestAugmentation(height, width)
173
+
174
+
175
+ self.task_name = ginfo.task_name
176
+
177
+ def __getitem__(self, index):
178
+ return self.read_one(index)
179
+
180
+ def __len__(self):
181
+ return len(self.img_id)
182
+
183
+ def read_one(self, idx=None):
184
+ if idx is None:
185
+ idx = np.random.randint(self.img_num)
186
+
187
+ imgname, gt_label = self.img_id[idx], self.label[idx]
188
+ imgpath = imgname
189
+
190
+ try:
191
+ img = Image.open(imgpath).convert('RGB')
192
+ if self.transform is not None:
193
+ img = self.transform(img)
194
+
195
+ gt_label = gt_label.astype(np.float32)
196
+
197
+ output = {'image': img, 'label': gt_label, 'filename': imgname, 'attr_begin': self.attr_begin[idx], 'attr_end': self.attr_end[idx]}
199
+
200
+ return output
201
+ except Exception:
202
+ print('{} load failed'.format(imgpath))
203
+ return self.read_one()
204
+
205
+ def __repr__(self):
206
+ return self.__class__.__name__ + \
207
+ f' rank: {self.rank} task: {self.task_name} mode: {"training" if self.train else "inference"} ' \
208
+ f'dataset_len:{len(self.img_id)} augmentation: {self.transform}'
209
+
210
+ rap2_attr_name = {
211
+ 0: {0:'without a bald head',1:'with a bald head'},
212
+ 1: {0:'with short hair',1:'with long hair'},
213
+ 2: {0:'with non-black hair',1:'with black hair'},
214
+ 3: {0:'without a hat',1:'with a hat'},
215
+ 4: {0:'without glasses',1:'with glasses'},
216
+ 5: {0:'without a shirt',1:'with a shirt'},
217
+ 6: {0:'without a sweater',1:'with a sweater'},
218
+ 7: {0:'without a vest',1:'with a vest'},
219
+ 8: {0:'without a t-shirt',1:'with a t-shirt'},
220
+ 9: {0:'without cotton',1:'with cotton'},
221
+ 10: {0:'without a jacket',1:'with a jacket'},
222
+ 11: {0:'without formal wear',1:'with formal wear'},
223
+ 12: {0:'without tight clothes',1:'with tight clothes'},
224
+ 13: {0:'without short sleeves',1:'with short sleeves'},
225
+ 14: {0:'without other upper-body clothing',1:'with other upper-body clothing'},
226
+ 15: {0:'without long trousers',1:'with long trousers'},
227
+ 16: {0:'without a skirt',1:'with a skirt'},
228
+ 17: {0:'without a short skirt',1:'with a short skirt'},
229
+ 18: {0:'without a dress',1:'with a dress'},
230
+ 19: {0:'without jeans',1:'with jeans'},
231
+ 20: {0:'without tight trousers',1:'with tight trousers'},
232
+ 21: {0:'without leather shoes',1:'with leather shoes'},
233
+ 22: {0:'without sport shoes',1:'with sport shoes'},
234
+ 23: {0:'without boots',1:'with boots'},
235
+ 24: {0:'without cloth shoes',1:'with cloth shoes'},
236
+ 25: {0:'without casual shoes',1:'with casual shoes'},
237
+ 26: {0:'without other shoes',1:'with other shoes'},
238
+ 27: {0:'without a backpack',1:'with a backpack'},
239
+ 28: {0:'without a shoulder bag',1:'with a shoulder bag'},
240
+ 29: {0:'without a handbag',1:'with a handbag'},
241
+ 30: {0:'without a box',1:'with a box'},
242
+ 31: {0:'without a plastic bag',1:'with a plastic bag'},
243
+ 32: {0:'without a paper bag',1:'with a paper bag'},
244
+ 33: {0:'without a hand trunk',1:'with a hand trunk'},
245
+ 34: {0:'without other attachments',1:'with other attachments'},
246
+ 35: {0:'age greater than 16',1:'age less than or equal to 16'},
247
+ 36: {0:'age less than 17 or greater than 30',1:'age between 17 and 30'},
248
+ 37: {0:'age less than 31 or greater than 45',1:'age between 31 and 45'},
249
+ 38: {0:'age less than 46 or greater than 60',1:'age between 46 and 60'},
250
+ 39: {0:'male',1:'female', 2:'gender unknown'},
251
+ 40: {0:'without excess body fat',1:'with excess body fat'},
252
+ 41: {0:'without normal body shape',1:'with normal body shape'},
253
+ 42: {0:'without thin body shape',1:'with thin body shape'},
254
+ 43: {0:'not a customer',1:'is a customer'},
255
+ 44: {0:'not an employee',1:'is an employee'},
256
+ 45: {0:'not calling',1:'calling'},
257
+ 46: {0:'not talking',1:'talking'},
258
+ 47: {0:'not gathering',1:'gathering'},
259
+ 48: {0:'not holding anything',1:'holding something'},
260
+ 49: {0:'not pushing anything',1:'pushing something'},
261
+ 50: {0:'not pulling anything',1:'pulling something'},
262
+ 51: {0:'not carrying anything in arms',1:'carrying something in arms'},
263
+ 52: {0:'not carrying anything in hands',1:'carrying something in hands'},
264
+ 53: {0:'no other actions',1:'performing other actions'}
265
+ }
266
+
267
+ PA_100k_attr_name = {
268
+ 0: {0:'without a hat',1:'with a hat'},
269
+ 1: {0:'without glasses',1:'with glasses'},
270
+ 2: {0:'without short sleeves',1:'with short sleeves'},
271
+ 3: {0:'without long sleeves',1:'with long sleeves'},
272
+ 4: {0:'without stripe upper-clothes',1:'with stripe upper-clothes'},
273
+ 5: {0:'without logo upper-clothes',1:'with logo upper-clothes'},
274
+ 6: {0:'without plaid upper-clothes',1:'with plaid upper-clothes'},
275
+ 7: {0:'without splice upper-clothes',1:'with splice upper-clothes'},
276
+ 8: {0:'without stripe lower-clothes',1:'with stripe lower-clothes'},
277
+ 9: {0:'without pattern lower-clothes',1:'with pattern lower-clothes'},
278
+ 10: {0:'without long coat',1:'with long coat'},
279
+ 11: {0:'without long trousers',1:'with long trousers'},
280
+ 12: {0:'without short trousers',1:'with short trousers'},
281
+ 13: {0:'without skirt or dress',1:'with skirt or dress'},
282
+ 14: {0:'without boots',1:'with boots'},
283
+ 15: {0:'without a handbag',1:'with a handbag'},
284
+ 16: {0:'without a shoulder bag',1:'with a shoulder bag'},
285
+ 17: {0:'without a backpack',1:'with a backpack'},
286
+ 18: {0:'not hold objects in front',1:'hold objects in front'},
287
+ 19: {0:'age less than or equal to 60',1:'age greater than 60'},
288
+ 20: {0:'age less than 18 or greater than 60',1:'age between 18 and 60'},
289
+ 21: {0:'age greater than or equal to 18',1:'age less than 18'},
290
+ 22: {0:'male',1:'female', 2:'gender unknown'},
291
+ 23: {0:'not in the front position',1:'in the front position'},
292
+ 24: {0:'not in the side position',1:'in the side position'},
293
+ 25: {0:'not in the back position',1:'in the back position'},
294
+ }
295
+
296
+ HARDHC_attr_name = {
297
+ 0: {0:'female', 1:'male', -1:'gender unknown'},
298
+ 1: {0:'with short hair',1:'with long hair'},
299
+ 2: {0:'without sunglass',1:'with sunglass'},
300
+ 3: {0:'without a hat',1:'with a hat'},
301
+ 4: {0:'without T-skirt',1:'with T-skirt'},
302
+ 5: {0:'without long sleeves',1:'with long sleeves'},
303
+ 6: {0:'without formal clothes',1:'with formal clothes'},
304
+ 7: {0:'without short trousers',1:'with short trousers'},
305
+ 8: {0:'without jeans',1:'with jeans'},
306
+ 9: {0:'without long pants',1:'with long pants'},
307
+ 10: {0:'without skirt',1:'with skirt'},
308
+ 11: {0:'without face mask',1:'with face mask'},
309
+ 12: {0:'without logo clothes',1:'with logo clothes'},
310
+ 13: {0:'without stripe clothes',1:'with stripe clothes'},
311
+ }
312
+
313
+ parse27k_attr_name = {
314
+ 0: {0:'without a bald head',1:'with a bald head'},
315
+ 1: {0:'with short hair',1:'with long hair'},
316
+ 2: {0:'with non-black hair',1:'with black hair'},
317
+ 3: {0:'without a hat',1:'with a hat'},
318
+ 4: {0:'without glasses',1:'with glasses'},
319
+ 5: {0:'without a shirt',1:'with a shirt'},
320
+ 6: {0:'without a sweater',1:'with a sweater'},
321
+ 7: {0:'without a vest',1:'with a vest'},
322
+ 8: {0:'without a t-shirt',1:'with a t-shirt'},
323
+ 9: {0:'without cotton',1:'with cotton'},
324
+ 10: {0:'without a jacket',1:'with a jacket'},
325
+ 11: {0:'without formal wear',1:'with formal wear'},
326
+ 12: {0:'without tight clothes',1:'with tight clothes'},
327
+ 13: {0:'without short sleeves',1:'with short sleeves'},
328
+ 14: {0:'without other upper-body clothing',1:'with other upper-body clothing'},
329
+ 15: {0:'without long trousers',1:'with long trousers'},
330
+ 16: {0:'without a skirt',1:'with a skirt'},
331
+ 17: {0:'without a short skirt',1:'with a short skirt'},
332
+ 18: {0:'without a dress',1:'with a dress'},
333
+ 19: {0:'without jeans',1:'with jeans'},
334
+ 20: {0:'without tight trousers',1:'with tight trousers'},
335
+ 21: {0:'without leather shoes',1:'with leather shoes'},
336
+ 22: {0:'without sport shoes',1:'with sport shoes'},
337
+ 23: {0:'without boots',1:'with boots'},
338
+ 24: {0:'without cloth shoes',1:'with cloth shoes'},
339
+ 25: {0:'without casual shoes',1:'with casual shoes'},
340
+ 26: {0:'without other shoes',1:'with other shoes'},
341
+ 27: {0:'without a backpack',1:'with a backpack'},
342
+ 28: {0:'without a shoulder bag',1:'with a shoulder bag'},
343
+ 29: {0:'without a handbag',1:'with a handbag'},
344
+ 30: {0:'without a box',1:'with a box'},
345
+ 31: {0:'without a plastic bag',1:'with a plastic bag'},
346
+ 32: {0:'without a paper bag',1:'with a paper bag'},
347
+ 33: {0:'without a hand trunk',1:'with a hand trunk'},
348
+ 34: {0:'without other attachments',1:'with other attachments'},
349
+ 35: {0:'age greater than 16',1:'age less than or equal to 16'},
350
+ 36: {0:'age less than 17 or greater than 30',1:'age between 17 and 30'},
351
+ 37: {0:'age less than 31 or greater than 45',1:'age between 31 and 45'},
352
+ 38: {0:'age less than 46 or greater than 60',1:'age between 46 and 60'},
353
+ 39: {0:'male',1:'female', 2:'gender unknown'},
354
+ 40: {0:'without excess body fat',1:'with excess body fat'},
355
+ 41: {0:'without normal body shape',1:'with normal body shape'},
356
+ 42: {0:'without thin body shape',1:'with thin body shape'},
357
+ 43: {0:'not a customer',1:'is a customer'}
358
+ }
359
+
360
+ uavhuman_attr_name = {
361
+ 0: {0:'female',1:'male'},
362
+ 1: {0:'without red backpack',1:'with red backpack'},
363
+ 2: {0:'without black backpack',1:'with black backpack'},
364
+ 3: {0:'without green backpack',1:'with green backpack'},
365
+ 4: {0:'without yellow backpack',1:'with yellow backpack'},
366
+ 5: {0:'without other backpack',1:'with other backpack'},
367
+ 6: {0:'without red hat',1:'with red hat'},
368
+ 7: {0:'without black hat',1:'with black hat'},
369
+ 8: {0:'without yellow hat',1:'with yellow hat'},
370
+ 9: {0:'without white hat',1:'with white hat'},
371
+ 10: {0:'without other hat',1:'with other hat'},
372
+ 11: {0:'without red upper-clothes',1:'with red upper-clothes'},
373
+ 12: {0:'without black upper-clothes',1:'with black upper-clothes'},
374
+ 13: {0:'without blue upper-clothes',1:'with blue upper-clothes'},
375
+ 14: {0:'without green upper-clothes',1:'with green upper-clothes'},
376
+ 15: {0:'without multicolor upper-clothes',1:'with multicolor upper-clothes'},
377
+ 16: {0:'without grey upper-clothes',1:'with grey upper-clothes'},
378
+ 17: {0:'without white upper-clothes',1:'with white upper-clothes'},
379
+ 18: {0:'without yellow upper-clothes',1:'with yellow upper-clothes'},
380
+ 19: {0:'without dark brown upper-clothes',1:'with dark brown upper-clothes'},
381
+ 20: {0:'without purple upper-clothes',1:'with purple upper-clothes'},
382
+ 21: {0:'without pink upper-clothes',1:'with pink upper-clothes'},
383
+ 22: {0:'without other upper-clothes',1:'with other upper-clothes'},
384
+ 23: {0:'without long upper-clothes style',1:'with long upper-clothes style'},
385
+ 24: {0:'without short upper-clothes style',1:'with short upper-clothes style'},
386
+ 25: {0:'without skirt upper-clothes style',1:'with skirt upper-clothes style'},
387
+ 26: {0:'without other upper-clothes style',1:'with other upper-clothes style'},
388
+ 27: {0:'without red lower clothes',1:'with red lower clothes'},
389
+ 28: {0:'without black lower clothes',1:'with black lower clothes'},
390
+ 29: {0:'without blue lower clothes',1:'with blue lower clothes'},
391
+ 30: {0:'without green lower clothes',1:'with green lower clothes'},
392
+ 31: {0:'without multicolor lower clothes',1:'with multicolor lower clothes'},
393
+ 32: {0:'without grey lower clothes',1:'with grey lower clothes'},
394
+ 33: {0:'without white lower-clothes',1:'with white lower-clothes'},
395
+ 34: {0:'without yellow lower-clothes',1:'with yellow lower-clothes'},
396
+ 35: {0:'without dark brown lower-clothes',1:'with dark brown lower-clothes'},
397
+ 36: {0:'without purple lower-clothes',1:'with purple lower-clothes'},
398
+ 37: {0:'without pink lower-clothes',1:'with pink lower-clothes'},
399
+ 38: {0:'without other lower-clothes',1:'with other lower-clothes'},
400
+ 39: {0:'without long lower-clothes style',1:'with long lower-clothes style'},
401
+ 40: {0:'without short lower-clothes style',1:'with short lower-clothes style'},
402
+ 41: {0:'without skirt lower-clothes style',1:'with skirt lower-clothes style'},
403
+ 42: {0:'without other lower-clothes style',1:'with other lower-clothes style'}
404
+ }
405
+
406
+ market_attr_name = {
407
+ 0: {0:'without a backpack',1:'with a backpack'},
408
+ 1: {0:'without a bag',1:'with a bag'},
409
+ 2: {0:'without a handbag',1:'with a handbag'},
410
+ 3: {0:'without black lower-clothes',1:'with black lower-clothes'},
411
+ 4: {0:'without blue lower-clothes',1:'with blue lower-clothes'},
412
+ 5: {0:'without brown lower-clothes',1:'with brown lower-clothes'},
413
+ 6: {0:'without gray lower-clothes',1:'with gray lower-clothes'},
414
+ 7: {0:'without green lower-clothes',1:'with green lower-clothes'},
415
+ 8: {0:'without pink lower-clothes',1:'with pink lower-clothes'},
416
+ 9: {0:'without purple lower-clothes',1:'with purple lower-clothes'},
417
+ 10: {0:'without white lower-clothes',1:'with white lower-clothes'},
418
+ 11: {0:'without yellow lower-clothes',1:'with yellow lower-clothes'},
419
+ 12: {0:'without black upper-clothes',1:'with black upper-clothes'},
420
+ 13: {0:'without blue upper-clothes',1:'with blue upper-clothes'},
421
+ 14: {0:'without green upper-clothes',1:'with green upper-clothes'},
422
+ 15: {0:'without gray upper-clothes',1:'with gray upper-clothes'},
423
+ 16: {0:'without purple upper-clothes',1:'with purple upper-clothes'},
424
+ 17: {0:'without red upper-clothes',1:'with red upper-clothes'},
425
+ 18: {0:'without white upper-clothes',1:'with white upper-clothes'},
426
+ 19: {0:'without yellow upper-clothes',1:'with yellow upper-clothes'},
427
+ 20: {0:'with dress',1:'with pants'},
428
+ 21: {0:'with long lower body clothing',1:'with short lower body clothing'},
429
+ 22: {0:'with long sleeve upper body clothing',1:'with short upper body clothing'},
430
+ 23: {0:'with short hair',1:'with long hair'},
431
+ 24: {0:'without a hat',1:'with a hat'},
432
+ 25: {0:'male',1:'female'},
433
+ 26: {0:'not a young person',1:'a young person'},
434
+ 27: {0:'not a teenager',1:'a teenager'},
435
+ 28: {0:'not an adult',1:'an adult'},
436
+ 29: {0:'not an old person',1:'an old person'}
437
+ }
438
+ # the peta attribute names below still have some bugs
439
+ peta_attr_name = {
440
+ 0: {0:'without hat accessory',1:'with hat accessory'},
441
+ 1: {0:'without muffler accessory',1:'with muffler accessory'},
442
+ 2: {0:'with accessory',1:'with nothing accessory'},
443
+ 3: {0:'without sunglasses accessory',1:'with sunglasses accessory'},
444
+ 4: {0:'with short hair',1:'with long hair'},
445
+ 5: {0:'without casual upper body wear',1:'with casual upper body wear'},
446
+ 6: {0:'without formal upper body wear',1:'with formal upper body wear'},
447
+ 7: {0:'without jacket upper body wear',1:'with jacket upper body wear'},
448
+ 8: {0:'without logo upper body wear',1:'with logo upper body wear'},
449
+ 9: {0:'without plaid upper body wear',1:'with plaid upper body wear'},
450
+ 10: {0:'without short sleeve upper body wear',1:'with short sleeve upper body wear'},
451
+ 11: {0:'without thin stripes upper body wear',1:'with thin stripes upper body wear'},
452
+ 12: {0:'without t-shirt upper body wear',1:'with t-shirt upper body wear'},
453
+ 13: {0:'without other upper body wear',1:'with other upper body wear'},
454
+ 14: {0:'without vneck upper body wear',1:'with vneck upper body wear'},
455
+ 15: {0:'without casual lower body wear',1:'with casual lower body wear'},
456
+ 16: {0:'without formal lower body wear',1:'with formal lower body wear'},
457
+ 17: {0:'without jeans lower body wear',1:'with jeans lower body wear'},
458
+ 18: {0:'without shorts lower body wear',1:'with shorts lower body wear'},
459
+ 19: {0:'without shortskirt lower body wear',1:'with shortskirt lower body wear'},
460
+ 20: {0:'without trousers lower body wear',1:'with trousers lower body wear'},
461
+ 21: {0: 'without leather shoes', 1: 'with leather shoes'},
462
+ 22: {0: 'without sandals', 1: 'with sandals'},
463
+ 23: {0: 'without shoes', 1: 'with shoes'},
464
+ 24: {0: 'without sneaker', 1: 'with sneaker'},
465
+ 25: {0: 'without carrying backpack', 1: 'carrying backpack'},
466
+ 26: {0: 'without carrying other things', 1: 'carrying other things'},
467
+ 27: {0: 'without carrying messengerbag', 1: 'carrying messengerbag'},
468
+ 28: {0: 'carrying something', 1: 'carrying nothing'},
469
+ 29: {0: 'without carrying plasticbags', 1: 'carrying plasticbags'},
470
+ 30: {0:'age greater than or equal to 30',1:'age less than 30'},
471
+ 31: {0:'age less than 31 or greater than 45',1:'age between 31 and 45'},
472
+ 32: {0:'age less than 46 or greater than 60',1:'age between 46 and 60'},
473
+ 33: {0:'age less than or equal to 60',1:'age larger than 60'},
474
+ 34: {0:'female',1:'male'}
475
+ }
476
+
477
+ duke_attr_name = {
478
+ 0: {0:'without a backpack',1:'with a backpack'},
479
+ 1: {0:'without a bag',1:'with a bag'},
480
+ 2: {0:'without boots',1:'with boots'},
481
+ 3: {0:'without black lower-clothes',1:'with black lower-clothes'},
482
+ 4: {0:'without blue lower-clothes',1:'with blue lower-clothes'},
483
+ 5: {0:'without brown lower-clothes',1:'with brown lower-clothes'},
484
+ 6: {0:'without gray lower-clothes',1:'with gray lower-clothes'},
485
+ 7: {0:'without green lower-clothes',1:'with green lower-clothes'},
486
+ 8: {0:'without red lower-clothes',1:'with red lower-clothes'},
487
+ 9: {0:'without white lower-clothes',1:'with white lower-clothes'},
488
+ 10: {0:'male',1:'female'},
489
+ 11: {0:'without a handbag',1:'with a handbag'},
490
+ 12: {0:'without a hat',1:'with a hat'},
491
+ 13: {0:'with dark shoes',1:'with light shoes'},
492
+ 14: {0:'short top clothing',1:'long top clothing'},
493
+ 15: {0:'without black upper-clothes',1:'with black upper-clothes'},
494
+ 16: {0:'without blue upper-clothes',1:'with blue upper-clothes'},
495
+ 17: {0:'without brown upper-clothes',1:'with brown upper-clothes'},
496
+ 18: {0:'without gray upper-clothes',1:'with gray upper-clothes'},
497
+ 19: {0:'without green upper-clothes',1:'with green upper-clothes'},
498
+ 20: {0:'without purple upper-clothes',1:'with purple upper-clothes'},
499
+ 21: {0:'without red upper-clothes',1:'with red upper-clothes'},
500
+ 22: {0:'without white upper-clothes',1:'with white upper-clothes'},
501
+ }
502
+
503
+ ClothingAttribute_attr_name = ['pattern_spot', 'cyan', 'brown', 'v_shape_neckline', 'round_neckline', 'other_neckline', 'no_sleevelength', 'short_sleevelength', 'long_sleevelength', 'pattern_graphics', 'gender', 'black', 'many_colors', 'white', 'pattern_floral', 'collar', 'blue', 'necktie', 'pattern_stripe', 'pattern_solid', 'gray', 'shirt_category', 'sweater_category', 't_shirt_category', 'outerwear_category', 'suit_category', 'tank_top_category', 'dress_category', 'placket', 'pattern_plaid', 'purple', 'scarf', 'green', 'yellow', 'skin_exposure', 'red']
504
+
505
+ lup_0_200w_attr_name = {
506
+ 0: {0: 'male', 1: 'female', -1:'gender unknown'},
507
+ 1: {0: 'age greater than 6', 1: "age less than or equal to 6", -1: 'age unknown'},
508
+ 2: {0: 'age less than 7 or greater than 18', 1: "age between 7 and 18", -1: 'age unknown'},
509
+ 3: {0: 'age less than 19 or greater than 65', 1: "age between 19 and 65", -1: 'age unknown'},
510
+ 4: {0: 'age less than 66', 1: "age greater than or equal to 66", -1: 'age unknown'},
511
+ 5: {0: 'with short sleeve coat', 1: 'with long sleeves', -1: 'coat length unknown'},
512
+ 6: {0: 'with shorts trousers', 1: 'with long trousers'},
513
+ 7: {0: 'without a skirt', 1:'with a skirt'},
514
+ 8: {0: 'without a pure pattern coat', 1: 'with a pure upper-clothes'},
515
+ 9: {0: 'without a stripe pattern coat', 1: 'with a stripe upper-clothes'},
516
+ 10: {0: 'without a design pattern coat', 1: 'with a design upper-clothes'},
517
+ 11: {0: 'without a joint pattern coat', 1: 'with a joint upper-clothes'},
518
+ 12: {0: 'without a lattic pattern coat', 1: 'with a lattic upper-clothes'},
519
+ 13: {0: 'without a black color trousers', 1: 'with black lower-clothes'},
520
+ 14: {0: 'without a white color trousers', 1: 'with white lower-clothes'},
521
+ 15: {0: 'without a gray color trousers', 1: 'with a gray color trousers'},
522
+ 16: {0: 'without a red color trousers', 1: 'with a red color trousers'},
523
+ 17: {0: 'without a yellow color trousers', 1: 'with a yellow color trousers'},
524
+ 18: {0: 'without a blue color trousers', 1: 'with a blue color trousers'},
525
+ 19: {0: 'without a green color trousers', 1: 'with a green color trousers'},
526
+ 20: {0: 'without a purple color trousers', 1: 'with a purple color trousers'},
527
+ 21: {0: 'without a pure pattern trousers', 1: 'with a pure lower-clothes'},
528
+ 22: {0: 'without a stripe pattern trousers', 1: 'with a stripe lower-clothes'},
529
+ 23: {0: 'without a design pattern trousers', 1: 'with a design lower-clothes'},
530
+ 24: {0: 'without a joint pattern trousers', 1: 'with a joint lower-clothes'},
531
+ 25: {0: 'without a lattic pattern trousers', 1: 'with a lattic lower-clothes'},
532
+ 26: {0: 'without a hat', 1: 'with a hat', -1: 'hat unknown'},
533
+ 27: {0: 'without a jacket', 1: 'with a jacket'},
534
+ 28: {0: 'without a sweater', 1: 'with a sweater'},
535
+ 29: {0: 'without a long coat', 1: 'with a long coat'},
536
+ 30: {0: 'without a shirt', 1: 'with a shirt'},
537
+ 31: {0: 'without a dress', 1: 'with a dress'},
538
+ 32: {0: 'without a business suit', 1: 'with a business suit'},
539
+ 33: {0: 'without a black color coat', 1: 'with a black color coat', -1:'unknown coat color'},
540
+ 34: {0: 'without a white color coat', 1: 'with a white color coat', -1:'unknown coat color'},
541
+ 35: {0: 'without a gray color coat', 1: 'with a gray color coat', -1:'unknown coat color'},
542
+ 36: {0: 'without a red color coat', 1: 'with a red color coat', -1:'unknown coat color'},
543
+ 37: {0: 'without a yellow color coat', 1: 'with a yellow color coat', -1:'unknown coat color'},
544
+ 38: {0: 'without a blue color coat', 1: 'with a blue color coat', -1:'unknown coat color'},
545
+ 39: {0: 'without a green color coat', 1: 'with a green color coat', -1:'unknown coat color'},
546
+ 40: {0: 'without a purple color coat', 1: 'with a purple color coat', -1:'unknown coat color'},
547
+ 41: {0: 'with short hair', 1: 'with long hair', -1: 'unknown hair style'},
548
+ 42: {0: 'without leather shoes', 1: 'with leather shoes'},
549
+ 43: {0: 'without boots', 1: 'with boots'},
550
+ 44: {0: 'without walking shoes', 1: 'with walking shoes'},
551
+ 45: {0: 'without sandal', 1: 'with sandal'},
552
+ 46: {0: 'without a bag', 1: 'with a bag', -1: 'unknown bag style'},
553
+ 47: {0: 'without glasses', 1: 'with glasses'},
554
+ 48: {0: 'not stand', 1: 'stand', -1: 'unknown pose'},
555
+ 49: {0: 'not sit', 1: 'sit', -1: 'unknown pose'},
556
+ 50: {0: 'not lie', 1: 'lie', -1: 'unknown pose'},
557
+ 51: {0: 'not stoop', 1: 'stoop', -1: 'unknown pose'}}
558
+
559
+ lup_0_600w_attr_name = {
560
+ 0: {0: 'male', 1: 'female', -1:'gender unknown'},
561
+ 1: {0: 'age greater than 6', 1: "age less than or equal to 6", -1: 'age unknown'},
562
+ 2: {0: 'age less than 7 or greater than 18', 1: "age between 7 and 18", -1: 'age unknown'},
563
+ 3: {0: 'age less than 19 or greater than 65', 1: "age between 19 and 65", -1: 'age unknown'},
564
+ 4: {0: 'age less than 66', 1: "age greater than or equal to 66", -1: 'age unknown'},
565
+ 5: {0: 'with short sleeve coat', 1: 'with long sleeves', -1: 'coat length unknown'},
566
+ 6: {0: 'with shorts trousers', 1: 'with long trousers'},
567
+ 7: {0: 'without a skirt', 1:'with a skirt'},
568
+ 8: {0: 'without a pure pattern coat', 1: 'with a pure upper-clothes'},
569
+ 9: {0: 'without a stripe pattern coat', 1: 'with a stripe upper-clothes'},
570
+ 10: {0: 'without a design pattern coat', 1: 'with a design upper-clothes'},
571
+ 11: {0: 'without a joint pattern coat', 1: 'with a joint upper-clothes'},
572
+ 12: {0: 'without a lattic pattern coat', 1: 'with a lattic upper-clothes'},
573
+ 13: {0: 'without a black color trousers', 1: 'with black lower-clothes'},
574
+ 14: {0: 'without a white color trousers', 1: 'with white lower-clothes'},
575
+ 15: {0: 'without a gray color trousers', 1: 'with a gray color trousers'},
576
+ 16: {0: 'without a red color trousers', 1: 'with a red color trousers'},
577
+ 17: {0: 'without a yellow color trousers', 1: 'with a yellow color trousers'},
578
+ 18: {0: 'without a blue color trousers', 1: 'with a blue color trousers'},
579
+ 19: {0: 'without a green color trousers', 1: 'with a green color trousers'},
580
+ 20: {0: 'without a purple color trousers', 1: 'with a purple color trousers'},
581
+ 21: {0: 'without a pure pattern trousers', 1: 'with a pure lower-clothes'},
582
+ 22: {0: 'without a stripe pattern trousers', 1: 'with a stripe lower-clothes'},
583
+ 23: {0: 'without a design pattern trousers', 1: 'with a design lower-clothes'},
584
+ 24: {0: 'without a joint pattern trousers', 1: 'with a joint lower-clothes'},
585
+ 25: {0: 'without a lattic pattern trousers', 1: 'with a lattic lower-clothes'},
586
+ 26: {0: 'without a hat', 1: 'with a hat', -1: 'hat unknown'},
587
+ 27: {0: 'without a jacket', 1: 'with a jacket'},
588
+ 28: {0: 'without a sweater', 1: 'with a sweater'},
589
+ 29: {0: 'without a long coat', 1: 'with a long coat'},
590
+ 30: {0: 'without a shirt', 1: 'with a shirt'},
591
+ 31: {0: 'without a dress', 1: 'with a dress'},
592
+ 32: {0: 'without a business suit', 1: 'with a business suit'},
593
+ 33: {0: 'without a black color coat', 1: 'with a black color coat', -1:'unknown coat color'},
594
+ 34: {0: 'without a white color coat', 1: 'with a white color coat', -1:'unknown coat color'},
595
+ 35: {0: 'without a gray color coat', 1: 'with a gray color coat', -1:'unknown coat color'},
596
+ 36: {0: 'without a red color coat', 1: 'with a red color coat', -1:'unknown coat color'},
597
+ 37: {0: 'without a yellow color coat', 1: 'with a yellow color coat', -1:'unknown coat color'},
598
+ 38: {0: 'without a blue color coat', 1: 'with a blue color coat', -1:'unknown coat color'},
599
+ 39: {0: 'without a green color coat', 1: 'with a green color coat', -1:'unknown coat color'},
600
+ 40: {0: 'without a purple color coat', 1: 'with a purple color coat', -1:'unknown coat color'},
601
+ 41: {0: 'with short hair', 1: 'with long hair', -1: 'unknown hair style'},
602
+ 42: {0: 'without leather shoes', 1: 'with leather shoes'},
603
+ 43: {0: 'without boots', 1: 'with boots'},
604
+ 44: {0: 'without walking shoes', 1: 'with walking shoes'},
605
+ 45: {0: 'without sandal', 1: 'with sandal'},
606
+ 46: {0: 'without a bag', 1: 'with a bag', -1: 'unknown bag style'},
607
+ 47: {0: 'without glasses', 1: 'with glasses'},
608
+ 48: {0: 'not stand', 1: 'stand', -1: 'unknown pose'},
609
+ 49: {0: 'not sit', 1: 'sit', -1: 'unknown pose'},
610
+ 50: {0: 'not lie', 1: 'lie', -1: 'unknown pose'},
611
+ 51: {0: 'not stoop', 1: 'stoop', -1: 'unknown pose'}}
612
+
613
+ lup_600_1200w_attr_name = {
614
+ 0: {0: 'male', 1: 'female', -1:'gender unknown'},
615
+ 1: {0: 'age greater than 6', 1: "age less than or equal to 6", -1: 'age unknown'},
616
+ 2: {0: 'age less than 7 or greater than 18', 1: "age between 7 and 18", -1: 'age unknown'},
617
+ 3: {0: 'age less than 19 or greater than 65', 1: "age between 19 and 65", -1: 'age unknown'},
618
+ 4: {0: 'age less than 66', 1: "age greater than or equal to 66", -1: 'age unknown'},
619
+ 5: {0: 'with short sleeve coat', 1: 'with long sleeves', -1: 'coat length unknown'},
620
+ 6: {0: 'with shorts trousers', 1: 'with long trousers'},
621
+ 7: {0: 'without a skirt', 1:'with a skirt'},
622
+ 8: {0: 'without a pure pattern coat', 1: 'with a pure upper-clothes'},
623
+ 9: {0: 'without a stripe pattern coat', 1: 'with a stripe upper-clothes'},
624
+ 10: {0: 'without a design pattern coat', 1: 'with a design upper-clothes'},
625
+ 11: {0: 'without a joint pattern coat', 1: 'with a joint upper-clothes'},
626
+ 12: {0: 'without a lattic pattern coat', 1: 'with a lattic upper-clothes'},
627
+ 13: {0: 'without a black color trousers', 1: 'with black lower-clothes'},
628
+ 14: {0: 'without a white color trousers', 1: 'with white lower-clothes'},
629
+ 15: {0: 'without a gray color trousers', 1: 'with a gray color trousers'},
630
+ 16: {0: 'without a red color trousers', 1: 'with a red color trousers'},
631
+ 17: {0: 'without a yellow color trousers', 1: 'with a yellow color trousers'},
632
+ 18: {0: 'without a blue color trousers', 1: 'with a blue color trousers'},
633
+ 19: {0: 'without a green color trousers', 1: 'with a green color trousers'},
634
+ 20: {0: 'without a purple color trousers', 1: 'with a purple color trousers'},
635
+ 21: {0: 'without a pure pattern trousers', 1: 'with a pure lower-clothes'},
636
+ 22: {0: 'without a stripe pattern trousers', 1: 'with a stripe lower-clothes'},
637
+ 23: {0: 'without a design pattern trousers', 1: 'with a design lower-clothes'},
638
+ 24: {0: 'without a joint pattern trousers', 1: 'with a joint lower-clothes'},
639
+ 25: {0: 'without a lattic pattern trousers', 1: 'with a lattic lower-clothes'},
640
+ 26: {0: 'without a hat', 1: 'with a hat', -1: 'hat unknown'},
641
+ 27: {0: 'without a jacket', 1: 'with a jacket'},
642
+ 28: {0: 'without a sweater', 1: 'with a sweater'},
643
+ 29: {0: 'without a long coat', 1: 'with a long coat'},
644
+ 30: {0: 'without a shirt', 1: 'with a shirt'},
645
+ 31: {0: 'without a dress', 1: 'with a dress'},
646
+ 32: {0: 'without a business suit', 1: 'with a business suit'},
647
+ 33: {0: 'without a black color coat', 1: 'with a black color coat', -1:'unknown coat color'},
648
+ 34: {0: 'without a white color coat', 1: 'with a white color coat', -1:'unknown coat color'},
649
+ 35: {0: 'without a gray color coat', 1: 'with a gray color coat', -1:'unknown coat color'},
650
+ 36: {0: 'without a red color coat', 1: 'with a red color coat', -1:'unknown coat color'},
651
+ 37: {0: 'without a yellow color coat', 1: 'with a yellow color coat', -1:'unknown coat color'},
652
+ 38: {0: 'without a blue color coat', 1: 'with a blue color coat', -1:'unknown coat color'},
653
+ 39: {0: 'without a green color coat', 1: 'with a green color coat', -1:'unknown coat color'},
654
+ 40: {0: 'without a purple color coat', 1: 'with a purple color coat', -1:'unknown coat color'},
655
+ 41: {0: 'with short hair', 1: 'with long hair', -1: 'unknown hair style'},
656
+ 42: {0: 'without leather shoes', 1: 'with leather shoes'},
657
+ 43: {0: 'without boots', 1: 'with boots'},
658
+ 44: {0: 'without walking shoes', 1: 'with walking shoes'},
659
+ 45: {0: 'without sandal', 1: 'with sandal'},
660
+ 46: {0: 'without a bag', 1: 'with a bag', -1: 'unknown bag style'},
661
+ 47: {0: 'without glasses', 1: 'with glasses'},
662
+ 48: {0: 'not stand', 1: 'stand', -1: 'unknown pose'},
663
+ 49: {0: 'not sit', 1: 'sit', -1: 'unknown pose'},
664
+ 50: {0: 'not lie', 1: 'lie', -1: 'unknown pose'},
665
+ 51: {0: 'not stoop', 1: 'stoop', -1: 'unknown pose'}}
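A hedged usage sketch of the MultiAttrDataset defined above, assuming torch.distributed has already been initialised and that the annotation pickle and image root below exist; every literal value (paths, input size, task name) is a placeholder, not something taken from this repository's configs.

    from easydict import EasyDict as edict
    from core.data.datasets.images.pedattr_dataset import MultiAttrDataset

    task_spec = edict(data_path=['/data/peta/annotation.pkl'],   # hypothetical paths
                      root_path=['/data/peta/images'],
                      dataset=['peta'])
    augmentation = edict(height=256, width=192)
    ginfo = edict(task_name='pedattr_example')

    dataset = MultiAttrDataset(ginfo, augmentation, task_spec, train=True)
    sample = dataset[0]
    # sample is a dict with 'image', 'label', 'filename', 'attr_begin' and 'attr_end';
    # 'label' is a float vector over the merged attribute space, with -1 for attributes
    # the sample's source dataset does not annotate.

When text_label_return=True, merge_pedattr_datasets also fills dataset.text_dict, mapping each global attribute index to the corresponding phrase dict above (e.g. peta_attr_name[idx]) via eval(f"{dataset_name}_attr_name").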
core/data/datasets/images/peddet_dataset_v2.py ADDED
@@ -0,0 +1,578 @@
1
+ import torch.utils.data as data
2
+ from PIL import ImageFile
3
+ ImageFile.LOAD_TRUNCATED_IMAGES=True
4
+
5
+ import os
6
+ import os.path
7
+
8
+ import random
9
+ import torch
10
+ import numpy as np
11
+ import copy
12
+
13
+ import time
14
+ from core.data.transforms.peddet_transforms import PedestrainDetectionAugmentation
15
+
16
+ from core.data.datasets.images.seg_dataset_dev import Instances
17
+ from typing import *
18
+ import torch.distributed as dist
19
+ from PIL import Image
20
+ import json
21
+ from pycocotools.coco import COCO
22
+
23
+ from collections import defaultdict
24
+
25
+ __all__ = ['PedestrainDetectionDataset_v2']
26
+
27
+ class PetrelCOCO(COCO):
28
+ def __init__(self, annotation_file=None, annotation=None):
29
+ """
30
+ Constructor of Microsoft COCO helper class for reading and visualizing annotations.
31
+ :param annotation_file (str): location of annotation file
32
+ :param annotation (?): partially processed annotation file
33
+ :return:
34
+ """
35
+ # load dataset
36
+ self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict()
37
+ self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
38
+ assert annotation_file is None or annotation is None
39
+ if annotation_file is not None:
40
+ print('loading annotations into memory...')
41
+ tic = time.time()
42
+ with open(annotation_file, 'r') as f:
43
+ dataset = json.load(f)
44
+ assert type(dataset) == dict, 'annotation file format {} not supported'.format(type(dataset))
45
+ print('Done (t={:0.2f}s)'.format(time.time() - tic))
46
+ self.dataset = dataset
47
+ self.createIndex()
48
+
49
+ if annotation is not None:
50
+ print('adding annotations into memory...')
51
+ tic = time.time()
52
+ dataset = annotation
53
+ self.dataset = dataset
54
+ self.createIndex()
55
+
56
+ def convert_coco_poly_to_mask(segmentations, height, width):
57
+ masks = []
58
+ for polygons in segmentations:
59
+ rles = coco_mask.frPyObjects(polygons, height, width)
60
+ mask = coco_mask.decode(rles)
61
+ if len(mask.shape) < 3:
62
+ mask = mask[..., None]
63
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
64
+ mask = mask.any(dim=2)
65
+ masks.append(mask)
66
+ if masks:
67
+ masks = torch.stack(masks, dim=0)
68
+ else:
69
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
70
+ return masks
71
+
72
+ class ConvertCocoPolysToMask(object):
73
+ def __init__(self, return_masks=False):
74
+ self.return_masks = return_masks
75
+
76
+ def __call__(self, image, target):
77
+ w, h = image.size
78
+
79
+ image_id = target["image_id"]
80
+ image_id = torch.tensor([image_id])
81
+
82
+ anno = target["annotations"]
83
+
84
+ boxes = [obj["bbox"] for obj in anno]
85
+ # guard against no boxes via resizing
86
+ boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
87
+ boxes[:, 2:] += boxes[:, :2]
88
+ boxes[:, 0::2].clamp_(min=0, max=w)
89
+ boxes[:, 1::2].clamp_(min=0, max=h)
90
+
91
+ classes = [obj["category_id"] for obj in anno]
92
+ classes = torch.tensor(classes, dtype=torch.int64)
93
+
94
+ if self.return_masks:
95
+ segmentations = [obj["segmentation"] for obj in anno]
96
+ masks = convert_coco_poly_to_mask(segmentations, h, w)
97
+
98
+ keypoints = None
99
+ if anno and "keypoints" in anno[0]:
100
+ keypoints = [obj["keypoints"] for obj in anno]
101
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
102
+ num_keypoints = keypoints.shape[0]
103
+ if num_keypoints:
104
+ keypoints = keypoints.view(num_keypoints, -1, 3)
105
+
106
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
107
+
108
+ # for conversion to coco api
109
+ area = torch.tensor([obj["area"] for obj in anno])
110
+ iscrowd = torch.BoolTensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
111
+ iscrowd |= classes != 0
112
+
113
+ target = {}
114
+ target["boxes"] = boxes[keep]
115
+ target["labels"] = classes[keep]
116
+ if self.return_masks:
117
+ target["masks"] = masks[keep]
118
+ target["image_id"] = image_id
119
+ if keypoints is not None:
120
+ target["keypoints"] = keypoints[keep]
121
+
122
+ target["area"] = area[keep]
123
+ target["iscrowd"] = iscrowd[keep]
124
+
125
+ target["orig_size"] = torch.as_tensor([int(h), int(w)])
126
+ target["size"] = torch.as_tensor([int(h), int(w)])
127
+
128
+ return image, target
129
+
130
+
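# Editor's sketch (not part of the original file): the box handling in
# ConvertCocoPolysToMask.__call__ above converts COCO [x, y, w, h] boxes to
# [x1, y1, x2, y2], clamps them to the image, and keeps only boxes with positive
# width and height. A tiny worked example with made-up boxes:
import torch as _torch

_w, _h = 100, 80                                    # pretend image size
_demo_boxes = _torch.tensor([[10., 20., 30., 40.],  # ordinary box
                             [95., 70., 30., 30.],  # spills outside the image
                             [50., 50., 0., 10.]])  # zero width -> dropped
_demo_boxes[:, 2:] += _demo_boxes[:, :2]            # xywh -> xyxy
_demo_boxes[:, 0::2].clamp_(min=0, max=_w)
_demo_boxes[:, 1::2].clamp_(min=0, max=_h)
_keep = (_demo_boxes[:, 3] > _demo_boxes[:, 1]) & (_demo_boxes[:, 2] > _demo_boxes[:, 0])
# _demo_boxes[_keep] -> [[10, 20, 40, 60], [95, 70, 100, 80]]; the degenerate box is removed.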
131
+ class CocoDetection(data.Dataset):
132
+ """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
133
+
134
+ Args:
135
+ ann (dict): merged COCO-style annotation dict (see ``coco_merge``).
+ phase (string): either ``'train'`` or ``'val'``.
+ transform (callable, optional): A function/transform that takes in a PIL image
138
+ and returns a transformed version. E.g, ``transforms.ToTensor``
139
+ target_transform (callable, optional): A function/transform that takes in the
140
+ target and transforms it.
141
+ """
142
+
143
+ def __init__(self, ann, phase, transform=None, target_transform=None):
144
+ self.coco = PetrelCOCO(annotation=ann)
145
+
146
+ self.ids = list(self.coco.imgs.keys())
147
+ assert phase in ['train', 'val']
148
+ self.transform = transform
149
+ self.phase = phase
150
+ self.target_transform = target_transform
151
+
152
+ self.rank = dist.get_rank()
153
+ self.world_size = dist.get_world_size()
154
+
155
+ self.initialized = True
156
+
157
+ def _init_memcached(self):
158
+ if not self.initialized:
159
+ ## only use mc default
160
+ print("==> will load files from local machine")
161
+ server_list_config_file = "/mnt/lustre/share/memcached_client/server_list.conf"
162
+ client_config_file = "/mnt/lustre/share/memcached_client/client.conf"
163
+ self.memcached_mclient = mc.MemcachedClient.GetInstance(server_list_config_file, client_config_file)
164
+ ## mc-support-ceph
165
+ print('mc-support-ceph')
166
+ self.ceph_mclient = s3client
167
+
168
+ self.initialized = True
169
+
170
+ def _read_one(self, index=None):
171
+ """
172
+ Args:
173
+ index (int): Index
174
+
175
+ Returns:
176
+ tuple: Tuple (image, target, imgname). target is the object returned by ``coco.loadAnns``.
177
+ """
178
+ if index is None:
179
+ index = np.random.randint(len(self.ids))
180
+
181
+ coco = self.coco
182
+ img_id = self.ids[index]
183
+
184
+ ann_ids = coco.getAnnIds(imgIds=img_id)
185
+ target = copy.deepcopy(coco.loadAnns(ann_ids))
186
+
187
+ for one_target in target:
188
+ if 'segmentation' in one_target: del one_target['segmentation']
189
+ if 'keypoints' in one_target: del one_target['keypoints']
190
+
191
+ path = coco.loadImgs(img_id)[0]['file_name']
192
+ img_root = coco.loadImgs(img_id)[0]['img_root']
193
+ imgname = os.path.splitext(path)[0]
194
+
195
+ ## for code in lab, we use jpg (CrowdHuman images are stored as .jpg)
+ if 'CrowdHuman' in img_root:
+ path = path.replace('.png', '.jpg')
201
+ filename = os.path.join(img_root, path)
202
+ try:
203
+ img = Image.open(filename).convert('RGB')
204
+ if img is None:
205
+ raise Exception("None Image")
206
+ except Exception:
207
+ outputName = "failed_to_read_in_train.txt"
208
+ with open(outputName,"a") as g:
209
+ g.write("%s\n"%(filename))
210
+ print('Read image[{}] failed ({})'.format(index, filename))
211
+ ## if reading fails, retry _read_one with a random index
212
+ return self._read_one()
213
+ else:
214
+ output = dict()
215
+ ##set random_seed with img idx
216
+ random.seed(index+self.rank)
217
+ np.random.seed(index+self.rank)
218
+
219
+ if self.transform is not None:
220
+ img = self.transform(img)
221
+
222
+ if self.target_transform is not None:
223
+ target = self.target_transform(target)
224
+
225
+ return img, target, imgname
226
+
227
+ def __getitem__(self, index):
228
+ """
229
+ Args:
230
+ index (int): Index
231
+
232
+ Returns:
233
+ tuple: Tuple (image, target, imgname). target is the object returned by ``coco.loadAnns``.
234
+ """
235
+ self._init_memcached()
236
+ img, target, imgname = self._read_one(index)
237
+
238
+ return img, target, imgname
239
+
240
+ def __len__(self):
241
+ return len(self.ids)
242
+
243
+ def __repr__(self):
244
+ fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
245
+ fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
246
+ fmt_str += ' Root Location: {}\n'.format(self.root)
247
+ tmp = ' Transforms (if any): '
248
+ fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
249
+ tmp = ' Target Transforms (if any): '
250
+ fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
251
+ return fmt_str
252
+
253
+
254
+ def coco_merge(
255
+ img_root_list: List[str], input_list: List[str],
256
+ indent: Optional[int] = None,
257
+ ) -> dict:
258
+ """Merge COCO annotation files.
259
+
260
+ Args:
261
+ img_root_list: Image root directory for each annotation file; stored on every image entry as ``img_root``.
+ input_list: Paths to the COCO annotation json files to be merged.
+ indent: Argument intended for `json.dump` (see https://docs.python.org/3/library/json.html#json.dump); currently unused.
265
+ """
266
+ data_list = []
267
+
268
+ for input in input_list:
269
+ with open(input, 'r') as f:
270
+ data_extend = json.load(f)
271
+
272
+ data_list.append(data_extend)
273
+
274
+ output= {'categories': data_list[0]['categories']}
275
+
276
+ output["images"], output["annotations"] = [], []
277
+
278
+ for i, (data, img_root) in enumerate(zip(data_list, img_root_list)):
279
+ print(
280
+ "Input {}: {} images, {} annotations".format(
281
+ i + 1, len(data["images"]), len(data["annotations"])
282
+ )
283
+ )
284
+
285
+ cat_id_map = {}
286
+ for new_cat in data["categories"]:
287
+ new_id = None
288
+ for output_cat in output["categories"]:
289
+ if new_cat["name"] == output_cat["name"]:
290
+ new_id = output_cat["id"]
291
+ break
292
+
293
+ if new_id is not None:
294
+ cat_id_map[new_cat["id"]] = new_id
295
+ else:
296
+ new_cat_id = max(c["id"] for c in output["categories"]) + 1
297
+ cat_id_map[new_cat["id"]] = new_cat_id
298
+ new_cat["id"] = new_cat_id
299
+ output["categories"].append(new_cat)
300
+
301
+ img_id_map = {}
302
+ for image in data["images"]:
303
+ n_imgs = len(output["images"])
304
+ img_id_map[image["id"]] = n_imgs
305
+ image["id"] = n_imgs
306
+ image["img_root"] = img_root
307
+
308
+ output["images"].append(image)
309
+
310
+ for annotation in data["annotations"]:
311
+ n_anns = len(output["annotations"])
312
+ annotation["id"] = n_anns
313
+ annotation["image_id"] = img_id_map[annotation["image_id"]]
314
+ annotation["category_id"] = cat_id_map[annotation["category_id"]]
315
+
316
+ output["annotations"].append(annotation)
317
+
318
+ print(
319
+ "Result: {} images, {} annotations".format(
320
+ len(output["images"]), len(output["annotations"])
321
+ )
322
+ )
323
+ return output
324
+
325
+
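# Editor's sketch (not part of the original file): coco_merge() above concatenates several
# COCO annotation jsons, re-assigning image/annotation/category ids so they stay unique and
# tagging every image with the img_root it should be loaded from. The paths below are
# hypothetical; the call is guarded so this module stays import-safe.
if False:
    _img_roots = ['/data/CrowdHuman/Images', '/data/CityPersons/leftImg8bit']
    _ann_files = ['/data/CrowdHuman/annotation_train.json',
                  '/data/CityPersons/train.json']
    _merged = coco_merge(_img_roots, _ann_files)
    # _merged['images'][i]['img_root'] records which root each image comes from, and
    # PetrelCOCO(annotation=_merged) then builds the in-memory index without re-reading json.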
326
+ class PedestrainDetectionDataset_v2(CocoDetection):
327
+ def __init__(self, ginfo, augmentation, task_spec, train=True, vit=False,
328
+ num_append_fake_boxes=0,
329
+ # append to 900 for a fixed length gt input in the sparse labeling (label) branch
330
+ return_box_xyxy=False,
331
+ append_z=True,
332
+ test_trainset=False,
333
+ **kwargs):
334
+ img_folder = task_spec['img_folder'] if isinstance(task_spec['img_folder'], list) else [task_spec['img_folder']]
335
+ ann_file = task_spec['ann_file'] if isinstance(task_spec['ann_file'], list) else [task_spec['ann_file']]
336
+ self.root = img_folder
337
+
338
+ ann = coco_merge(img_folder, ann_file)
339
+
340
+ return_masks = task_spec['return_masks']
341
+ phase = 'train' if train else 'val'
342
+
343
+ super(PedestrainDetectionDataset_v2, self).__init__(ann=ann, phase=phase)
344
+
345
+ self.return_box_xyxy = return_box_xyxy
346
+ transforms = PedestrainDetectionAugmentation(phase=phase if not test_trainset else 'val', vit=vit, return_box_xyxy=self.return_box_xyxy,
347
+ max_size=augmentation.get('max_size',1333),)
348
+
349
+ name2wh = {}
350
+ for img_id in self.ids:
351
+ img_name = self.coco.loadImgs(img_id)[0]['file_name'].split('.')[0]
352
+ height = self.coco.loadImgs(img_id)[0]['height']
353
+ width = self.coco.loadImgs(img_id)[0]['width']
354
+ name2wh[img_name]={'width':width, 'height': height}
355
+
356
+ self.flag = np.zeros(len(self.ids), dtype=np.uint8)
357
+ for i, img_id in enumerate(self.ids):
358
+ img_info = self.coco.loadImgs(img_id)[0]['file_name'].split('.')[0]
359
+ if name2wh[img_info]['width'] / name2wh[img_info]['height'] > 1:
360
+ self.flag[i] = 1
361
+
362
+ self._transforms = transforms
363
+ self.phase = phase
364
+ self.prepare = ConvertCocoPolysToMask(return_masks)
365
+ self.task_name = ginfo.task_name
366
+
367
+ self.num_append_fake_boxes = num_append_fake_boxes
368
+ self.append_z = append_z
369
+
370
+ def _filter_ignores(self, target):
371
+ target = list(filter(lambda rb: rb['category_id'] > -1, target))
372
+
373
+ return target
374
+
375
+ def _minus_target_label(self, target, value):
376
+
377
+ results = []
378
+ for t in target:
379
+ t['category_id'] -= value
380
+ results.append(t)
381
+ return results
382
+
383
+ def __getitem__(self, idx):
384
+ dataset_dict = {}
385
+ img, target, imgname = super(PedestrainDetectionDataset_v2, self).__getitem__(idx)
386
+ target = self._minus_target_label(target, 1)
387
+ total = len(target)
388
+ image_id = self.ids[idx]
389
+
390
+ target = {'image_id': image_id, 'annotations': target}
391
+ img, target = self.prepare(img, target)
392
+ image_shape = (img.size[-1], img.size[-2]) # h, w
393
+ self._record_image_size(dataset_dict, img)
394
+
395
+ if self._transforms is not None:
396
+ img, target = self._transforms(img, target)
397
+
398
+ if self.num_append_fake_boxes > 0:
399
+ # not take iscrowded boxes into consideration
400
+ len_target = target['labels'].shape[0]
401
+ len_append = self.num_append_fake_boxes - len_target
402
+ target['boxes'] = torch.cat([target['boxes'], torch.zeros([len_append, 4])], dim=0)
403
+ # the appended label is set to 1(background), as ped det only has one class 0 for pedestrian
404
+ append_label = 1
405
+ target['labels'] = torch.cat([target['labels'], torch.ones([len_append]).long()*append_label], dim=0)
406
+ target['iscrowd'] = torch.cat([target['iscrowd'], torch.ones([len_append]).bool()], dim=0)
407
+ target['area'] = torch.cat([target['area'], torch.zeros([len_append])], dim=0)
408
+
409
+ dataset_dict['orig_size'] = target['orig_size']
410
+ dataset_dict['size'] = target['size']
411
+ del target['image_id']
412
+ del target['orig_size']
413
+ del target['size']
414
+
415
+ instances = Instances(image_shape, **target)
416
+
417
+ # sparse_labeling should have a shape of [xyz, T(temperal)=2, V=num_append_fake_boxes, M(num_peopoe)=1]
418
+ # T=2, as we consider x1y1, x2y2 as two points. Info in two points will be integrated in conv to
419
+ # have a token representing a box.
420
+ # import pdb;
421
+ # pdb.set_trace()
422
+ sparse_labeling = target['boxes'].reshape(target['boxes'].shape[0], 2, 2).contiguous()
423
+ if self.append_z:
424
+ append_z = torch.zeros([target['boxes'].shape[0], 2, 1])
425
+ sparse_labeling = torch.cat([sparse_labeling, append_z], dim=2) # num_append_fake_boxes, T, xyz
426
+ sparse_labeling = sparse_labeling.unsqueeze(-1).permute(2, 1, 0, 3).contiguous()
427
+
428
+ dataset_dict['sparse_labeling'] = sparse_labeling
429
+ dataset_dict["image"] = img
430
+ dataset_dict["image_id"] = image_id
431
+ dataset_dict["label"] = -1
432
+ dataset_dict["instances"] = instances
433
+ dataset_dict["filename"] = imgname
434
+
435
+ return dataset_dict
436
+
437
+ @staticmethod
438
+ def _record_image_size(dataset_dict, image):
439
+ """
440
+ Raise an error if the image does not match the size specified in the dict.
441
+ """
442
+ # To ensure bbox always remap to original image size # when in PIL, reversed.
443
+ if "width" not in dataset_dict:
444
+ dataset_dict["width"] = image.size[1]
445
+ if "height" not in dataset_dict:
446
+ dataset_dict["height"] = image.size[0]
447
+
448
+
449
+ class PedestrainDetectionDataset_v2demo(CocoDetection):
450
+ def __init__(self, ginfo, augmentation, task_spec, train=True, vit=False,
451
+ num_append_fake_boxes=0,
452
+ # append to 900 for a fixed length gt input in the sparse labeling (label) branch
453
+ return_box_xyxy=False,
454
+ append_z=True,
455
+ test_trainset=False,
456
+ demo_dir='/mnt/cache/tangshixiang/wyz_proj/demo_video_unihcpv2/folder0',
457
+ **kwargs):
458
+ img_folder = task_spec['img_folder'] if isinstance(task_spec['img_folder'], list) else [task_spec['img_folder']]
459
+ ann_file = task_spec['ann_file'] if isinstance(task_spec['ann_file'], list) else [task_spec['ann_file']]
460
+ self.root = img_folder
461
+
462
+ ann = coco_merge(img_folder, ann_file)
463
+
464
+ return_masks = task_spec['return_masks']
465
+ phase = 'train' if train else 'val'
466
+
467
+ super(PedestrainDetectionDataset_v2demo, self).__init__(ann=ann, phase=phase)
468
+
469
+ self.return_box_xyxy = return_box_xyxy
470
+ transforms = PedestrainDetectionAugmentation(phase=phase if not test_trainset else 'val', vit=vit, return_box_xyxy=self.return_box_xyxy,
471
+ max_size=augmentation.get('max_size',1333),)
472
+
473
+ name2wh = {}
474
+ for img_id in self.ids:
475
+ img_name = self.coco.loadImgs(img_id)[0]['file_name'].split('.')[0]
476
+ height = self.coco.loadImgs(img_id)[0]['height']
477
+ width = self.coco.loadImgs(img_id)[0]['width']
478
+ name2wh[img_name]={'width':width, 'height': height}
479
+
480
+ self.flag = np.zeros(len(self.ids), dtype=np.uint8)
481
+ for i, img_id in enumerate(self.ids):
482
+ img_info = self.coco.loadImgs(img_id)[0]['file_name'].split('.')[0]
483
+ if name2wh[img_info]['width'] / name2wh[img_info]['height'] > 1:
484
+ self.flag[i] = 1
485
+
486
+ self._transforms = transforms
487
+ self.phase = phase
488
+ self.prepare = ConvertCocoPolysToMask(return_masks)
489
+ self.task_name = ginfo.task_name
490
+
491
+ self.num_append_fake_boxes = num_append_fake_boxes
492
+ self.append_z = append_z
493
+ self.demo_dir = demo_dir
494
+ self.listdir = os.listdir(self.demo_dir)
495
+
496
+ def _filter_ignores(self, target):
497
+ target = list(filter(lambda rb: rb['category_id'] > -1, target))
498
+
499
+ return target
500
+
501
+ def _minus_target_label(self, target, value):
502
+
503
+ results = []
504
+ for t in target:
505
+ t['category_id'] -= value
506
+ results.append(t)
507
+ return results
508
+
509
+ def __len__(self):
510
+ return len(os.listdir(self.demo_dir))
511
+
512
+ def __getitem__(self, idx):
513
+ dataset_dict = {}
514
+ img, target, imgname = super(PedestrainDetectionDataset_v2demo, self).__getitem__(0)
515
+ demo_dir = self.demo_dir
516
+ filename = os.path.join(demo_dir, self.listdir[idx])
517
+ img = Image.open(filename).convert('RGB')
518
+ target = self._minus_target_label(target, 1)
519
+ total = len(target)
520
+ image_id = self.ids[0]
521
+
522
+ target = {'image_id': image_id, 'annotations': target}
523
+ img, target = self.prepare(img, target)
524
+ image_shape = (img.size[-1], img.size[-2]) # h, w
525
+ self._record_image_size(dataset_dict, img)
526
+
527
+ if self._transforms is not None:
528
+ img, target = self._transforms(img, target)
529
+
530
+ if self.num_append_fake_boxes > 0:
531
+ # iscrowd boxes are not taken into consideration here
532
+ len_target = target['labels'].shape[0]
533
+ len_append = self.num_append_fake_boxes - len_target
534
+ target['boxes'] = torch.cat([target['boxes'], torch.zeros([len_append, 4])], dim=0)
535
+ # the appended label is set to 1 (background), since pedestrian detection has a single foreground class (0 = pedestrian)
536
+ append_label = 1
537
+ target['labels'] = torch.cat([target['labels'], torch.ones([len_append]).long()*append_label], dim=0)
538
+ target['iscrowd'] = torch.cat([target['iscrowd'], torch.ones([len_append]).bool()], dim=0)
539
+ target['area'] = torch.cat([target['area'], torch.zeros([len_append])], dim=0)
540
+
541
+ dataset_dict['orig_size'] = target['orig_size']
542
+ dataset_dict['size'] = target['size']
543
+ del target['image_id']
544
+ del target['orig_size']
545
+ del target['size']
546
+
547
+ instances = Instances(image_shape, **target)
548
+
549
+ # sparse_labeling should have a shape of [xyz, T(temporal)=2, V=num_append_fake_boxes, M(num_people)=1].
+ # T=2 because (x1, y1) and (x2, y2) are treated as two points; the information in the two points
+ # is later integrated by a conv layer into a single token representing a box.
554
+ sparse_labeling = target['boxes'].reshape(target['boxes'].shape[0], 2, 2).contiguous()
555
+ if self.append_z:
556
+ append_z = torch.zeros([target['boxes'].shape[0], 2, 1])
557
+ sparse_labeling = torch.cat([sparse_labeling, append_z], dim=2) # num_append_fake_boxes, T, xyz
558
+ sparse_labeling = sparse_labeling.unsqueeze(-1).permute(2, 1, 0, 3).contiguous()
559
+
560
+ dataset_dict['sparse_labeling'] = sparse_labeling
561
+ dataset_dict["image"] = img
562
+ dataset_dict["image_id"] = image_id
563
+ dataset_dict["label"] = -1
564
+ dataset_dict["instances"] = instances
565
+ dataset_dict["filename"] = filename
566
+
567
+ return dataset_dict
568
+
569
+ @staticmethod
570
+ def _record_image_size(dataset_dict, image):
571
+ """
572
+ Record the image size into the dataset dict if "width"/"height" are not already present.
573
+ """
574
+ # To ensure boxes can always be remapped to the original image size. Note that PIL's Image.size is (width, height), hence the reversed indexing below.
575
+ if "width" not in dataset_dict:
576
+ dataset_dict["width"] = image.size[1]
577
+ if "height" not in dataset_dict:
578
+ dataset_dict["height"] = image.size[0]
core/data/datasets/images/pos_dataset_dev.py ADDED
@@ -0,0 +1,713 @@
1
+ import copy
+
+ from torch.utils.data import Dataset
+ from pathlib import Path
+
+ from abc import ABCMeta, abstractmethod
+ import numpy as np
+ import os
+ import os.path as osp
+ import cv2
+ import time
+ import random
+ import torch
+ import warnings
+ from collections import OrderedDict, defaultdict
+ from core.data.transforms.pose_transforms import *
+ import json  # originally: import json_tricks as json
+ from xtcocotools.coco import COCO
+ from xtcocotools.cocoeval import COCOeval
+ import torch.distributed as dist
+
+
+ from core.utils import sync_print
27
+
28
+
29
+ class PetrelCOCO(COCO):
30
+ def __init__(self, annotation_file=None, test_index=None, ann_data=None):
31
+ """
32
+ Constructor of Microsoft COCO helper class for reading and visualizing annotations.
33
+ :param annotation_file (str): location of annotation file
34
+ :param image_folder (str): location to the folder that hosts images.
35
+ :return:
36
+ """
37
+ self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
38
+ self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
39
+ self.anno_file = [annotation_file]
40
+ self.test_index = test_index
41
+ if annotation_file is not None:
42
+ print('loading annotations into memory...')
43
+ tic = time.time()
44
+ # https://github.com/cocodataset/cocoapi/pull/453/
45
+ if ann_data is None:
46
+ with open(annotation_file, 'r') as f:
47
+ dataset = json.load(f)
48
+ else:
49
+ dataset = ann_data
50
+ assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
51
+ print('Done (t={:0.2f}s)'.format(time.time()- tic))
52
+ self.dataset = dataset
53
+ self.createIndex()
54
+ if 'annotations' in self.dataset:
55
+ for i in range(len(self.dataset['annotations'])):
56
+ if self.test_index is not None:
57
+ keypoints = np.array(self.dataset['annotations'][i]['keypoints']).reshape([-1, 3])
58
+ keypoints = keypoints[self.test_index, :]
59
+ self.dataset['annotations'][i]['keypoints'] = keypoints.reshape([-1]).tolist()
60
+ if 'iscrowd' not in self.dataset['annotations'][i]:
61
+ self.dataset['annotations'][i]['iscrowd'] = False
62
+
63
+
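A brief usage sketch for PetrelCOCO: it behaves like xtcocotools' COCO, but can also be built from an annotation dict that is already in memory via ann_data, skipping the file read. The annotation path below is a placeholder.

import json
from core.data.datasets.images.pos_dataset_dev import PetrelCOCO

ann_file = 'annotations/person_keypoints_val2017.json'   # placeholder path

# Let PetrelCOCO read the file itself ...
coco = PetrelCOCO(ann_file)

# ... or hand it a pre-loaded annotation dict (e.g. fetched from a remote store);
# ann_file is still passed so that the "annotation_file is not None" branch runs.
with open(ann_file, 'r') as f:
    ann_data = json.load(f)
coco_in_memory = PetrelCOCO(ann_file, ann_data=ann_data)
print(len(coco_in_memory.getImgIds()))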
64
+ class COCOPosDatasetDev(Dataset):
65
+ """CocoDataset dataset for top-down pose estimation.
66
+
67
+ "Microsoft COCO: Common Objects in Context", ECCV'2014.
68
+ More details can be found in the `paper
69
+ <https://arxiv.org/abs/1405.0312>`__ .
70
+
71
+ The dataset loads raw features and apply specified transforms
72
+ to return a dict containing the image tensors and other information.
73
+
74
+ COCO keypoint indexes::
75
+
76
+ 0: 'nose',
77
+ 1: 'left_eye',
78
+ 2: 'right_eye',
79
+ 3: 'left_ear',
80
+ 4: 'right_ear',
81
+ 5: 'left_shoulder',
82
+ 6: 'right_shoulder',
83
+ 7: 'left_elbow',
84
+ 8: 'right_elbow',
85
+ 9: 'left_wrist',
86
+ 10: 'right_wrist',
87
+ 11: 'left_hip',
88
+ 12: 'right_hip',
89
+ 13: 'left_knee',
90
+ 14: 'right_knee',
91
+ 15: 'left_ankle',
92
+ 16: 'right_ankle'
93
+
94
+ Args:
95
+ ann_file (str): Path to the annotation file.
96
+ img_prefix (str): Path to a directory where images are held.
97
+ Default: None.
98
+ data_cfg (dict): config
99
+ pipeline (list[dict | callable]): A sequence of data transforms.
100
+ dataset_info (DatasetInfo): A class containing all dataset info.
101
+ test_mode (bool): Store True when building test or
102
+ validation dataset. Default: False.
103
+ """
104
+ def __init__(self,
105
+ ginfo,
106
+ ann_file,
107
+ img_prefix,
108
+ data_cfg,
109
+ test_mode=False,
110
+ use_udp=False,
111
+ use_ceph=False,
112
+ data_use_ratio=1,
113
+ **kwargs):
114
+ self.image_info = {}
115
+ self.ann_info = {}
116
+ self.initialized = False
117
+
118
+ self.use_ceph = True
119
+ self.annotations_path = ann_file
120
+ self.img_prefix = img_prefix
121
+ self.test_mode = test_mode
122
+ print('data_cfg0', data_cfg)
123
+ # data_cfg=demjson.decode(data_cfg)
124
+ # print('data_cfg',data_cfg)
125
+ self.ann_info['image_size'] = np.array(data_cfg['image_size'])
126
+ self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
127
+ self.ann_info['num_joints'] = data_cfg['num_joints']
128
+
129
+ self.ann_info['inference_channel'] = data_cfg['inference_channel']
130
+ self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
131
+ self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
132
+
133
+ self.db = []
134
+ self.task_name = ginfo.task_name
135
+
136
+ if test_mode:
137
+ pipeline = [
138
+ LoadImageFromFile(use_ceph=use_ceph),
139
+ TopDownAffine(use_udp=use_udp),
140
+ ToUNTensor(),
141
+ Collect(keys=['image'],
142
+ meta_keys=['image_file', 'center', 'bbox', 'scale', 'rotation', 'bbox_score', 'flip_pairs'])
143
+ ]
144
+ else:
145
+ pipeline = [
146
+ LoadImageFromFile(use_ceph=use_ceph),
147
+ TopDownRandomFlip(flip_prob=0.5),
148
+ TopDownHalfBodyTransform(num_joints_half_body=8,prob_half_body=0.3),
149
+ TopDownGetRandomScaleRotation(rot_factor=40, scale_factor=0.5),
150
+ TopDownAffine(use_udp=use_udp),
151
+ ToUNTensor(),
152
+ TopDownGenerateTarget(sigma=2, encoding='UDP' if use_udp else 'MSRA'),
153
+ Collect(keys=['image', 'label', 'target_weight'],
154
+ meta_keys=['image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale','rotation',
155
+ 'bbox_score', 'flip_pairs'])
156
+ ]
157
+
158
+ self.pipeline = ComposeX(pipeline)
159
+ self.use_gt_bbox = data_cfg['use_gt_bbox']
160
+ self.bbox_file = data_cfg['bbox_file'] if data_cfg['bbox_file'].startswith('/mnt') else (Path(__file__).parent / 'resources' / data_cfg['bbox_file']).resolve()
161
+ self.det_bbox_thr = data_cfg.get('det_bbox_thr', 0.0)
162
+ if 'image_thr' in data_cfg:
163
+ warnings.warn(
164
+ 'image_thr is deprecated, '
165
+ 'please use det_bbox_thr instead', DeprecationWarning)
166
+ self.det_bbox_thr = data_cfg['image_thr']
167
+
168
+
169
+ self.ann_info['flip_pairs'] = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
170
+ [11, 12], [13, 14], [15, 16]]
171
+
172
+ self.ann_info['upper_body_ids'] = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
173
+ self.ann_info['lower_body_ids'] = (11, 12, 13, 14, 15, 16)
174
+
175
+ self.ann_info['use_different_joint_weights'] = False
176
+ self.ann_info['joint_weights'] = np.array(
177
+ [
178
+ 1., 1., 1., 1., 1., 1., 1., 1.2, 1.2, 1.5, 1.5, 1., 1., 1.2,
179
+ 1.2, 1.5, 1.5
180
+ ],
181
+ dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
182
+
183
+ # 'https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/'
184
+ # 'pycocotools/cocoeval.py#L523'
185
+
186
+ self.coco = PetrelCOCO(ann_file)
187
+
188
+ cats = [
189
+ cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
190
+ ]
191
+ self.classes = ['__background__'] + cats
192
+ self.num_classes = len(self.classes)
193
+ self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
194
+ self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds()))
195
+ self._coco_ind_to_class_ind = dict(
196
+ (self._class_to_coco_ind[cls], self._class_to_ind[cls])
197
+ for cls in self.classes[1:])
198
+ self.img_ids = self.coco.getImgIds()
199
+ self.num_images = len(self.img_ids)
200
+ self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
201
+ self.dataset_name = 'coco'
202
+
203
+ self.db = self._get_db()
204
+ if data_use_ratio != 1:
205
+ self.db = random.sample(self.db, int(len(self.db) * data_use_ratio))
206
+
207
+ print(f'=> COCOPosDatasetDev num_images: {self.num_images}')
208
+ print(f'=> COCOPosDatasetDev load {len(self.db)} samples')
209
+
210
+ @staticmethod
211
+ def _get_mapping_id_name(imgs):
212
+ """
213
+ Args:
214
+ imgs (dict): dict of image info.
215
+
216
+ Returns:
217
+ tuple: Image name & id mapping dicts.
218
+
219
+ - id2name (dict): Mapping image id to name.
220
+ - name2id (dict): Mapping image name to id.
221
+ """
222
+ id2name = {}
223
+ name2id = {}
224
+ for image_id, image in imgs.items():
225
+ file_name = image['file_name']
226
+ id2name[image_id] = file_name
227
+ name2id[file_name] = image_id
228
+
229
+ return id2name, name2id
230
+
231
+ def _get_db(self):
232
+ """Load dataset."""
233
+ if (not self.test_mode) or self.use_gt_bbox:
234
+ # use ground truth bbox
235
+ gt_db = self._load_coco_keypoint_annotations()
236
+ else:
237
+ # use bbox from detection
238
+ gt_db = self._load_coco_person_detection_results()
239
+ return gt_db
240
+
241
+ def _load_coco_keypoint_annotations(self):
242
+ """Ground truth bbox and keypoints."""
243
+ gt_db = []
244
+ for img_id in self.img_ids:
245
+ gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
246
+ return gt_db
247
+
248
+ def _load_coco_keypoint_annotation_kernel(self, img_id):
249
+ """load annotation from COCOAPI.
250
+
251
+ Note:
252
+ bbox:[x1, y1, w, h]
253
+ Args:
254
+ img_id: coco image id
255
+ Returns:
256
+ dict: db entry
257
+ """
258
+ img_ann = self.coco.loadImgs(img_id)[0]
259
+ width = img_ann['width']
260
+ height = img_ann['height']
261
+ num_joints = self.ann_info['num_joints']
262
+
263
+ ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
264
+ objs = self.coco.loadAnns(ann_ids)
265
+
266
+ # sanitize bboxes
267
+ valid_objs = []
268
+ for obj in objs:
269
+ if 'bbox' not in obj:
270
+ continue
271
+ x, y, w, h = obj['bbox']
272
+ x1 = max(0, x)
273
+ y1 = max(0, y)
274
+ x2 = min(width - 1, x1 + max(0, w - 1))
275
+ y2 = min(height - 1, y1 + max(0, h - 1))
276
+ if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
277
+ obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
278
+ valid_objs.append(obj)
279
+ objs = valid_objs
280
+
281
+ bbox_id = 0
282
+ rec = []
283
+ for obj in objs:
284
+ if 'keypoints' not in obj:
285
+ continue
286
+ if max(obj['keypoints']) == 0:
287
+ continue
288
+ if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
289
+ continue
290
+ joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
291
+ joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32)
292
+
293
+ keypoints = np.array(obj['keypoints']).reshape(-1, 3)
294
+ joints_3d[:, :2] = keypoints[:, :2]
295
+ joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
296
+
297
+ center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
298
+
299
+ image_file = os.path.join(self.img_prefix, self.id2name[img_id])
300
+ rec.append({
301
+ 'image_file': image_file,
302
+ 'center': center,
303
+ 'scale': scale,
304
+ 'bbox': obj['clean_bbox'][:4],
305
+ 'rotation': 0,
306
+ 'joints_3d': joints_3d,
307
+ 'joints_3d_visible': joints_3d_visible,
308
+ 'dataset': self.dataset_name,
309
+ 'bbox_score': 1,
310
+ 'bbox_id': bbox_id
311
+ })
312
+ bbox_id = bbox_id + 1
313
+
314
+ return rec
315
+
316
+ def _xywh2cs(self, x, y, w, h):
317
+ """This encodes bbox(x,y,w,w) into (center, scale)
318
+
319
+ Args:
320
+ x, y, w, h
321
+
322
+ Returns:
323
+ tuple: A tuple containing center and scale.
324
+
325
+ - center (np.ndarray[float32](2,)): center of the bbox (x, y).
326
+ - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
327
+ """
328
+ aspect_ratio = self.ann_info['image_size'][0] / self.ann_info[
329
+ 'image_size'][1]
330
+ center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
331
+
332
+ if (not self.test_mode) and np.random.rand() < 0.3:
333
+ center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
334
+
335
+ if w > aspect_ratio * h:
336
+ h = w * 1.0 / aspect_ratio
337
+ elif w < aspect_ratio * h:
338
+ w = h * aspect_ratio
339
+
340
+ # pixel std is 200.0
341
+ scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
342
+ # padding to include proper amount of context
343
+ scale = scale * 1.25
344
+
345
+ return center, scale
346
+
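A standalone version of the center/scale encoding above (without the random center shift applied during training), together with a small worked example; the helper name is only for illustration.

import numpy as np

def xywh2cs(x, y, w, h, image_size=(192, 256), padding=1.25):
    aspect_ratio = image_size[0] / image_size[1]
    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
    if w > aspect_ratio * h:
        h = w * 1.0 / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio
    scale = np.array([w / 200.0, h / 200.0], dtype=np.float32) * padding   # pixel std is 200
    return center, scale

# A 100x80 box at (50, 60) with a 192x256 input:
# center = [100. 100.], scale = [0.625 0.8333] (box widened to the 3:4 aspect ratio, then padded by 1.25)
print(xywh2cs(50, 60, 100, 80))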
347
+ def _load_coco_person_detection_results(self):
348
+ """Load coco person detection results."""
349
+ num_joints = self.ann_info['num_joints']
350
+
351
+ with open(self.bbox_file, 'r') as f:
352
+ all_boxes = json.load(f)
353
+
354
+ if not all_boxes:
355
+ raise ValueError('=> Load %s fail!' % self.bbox_file)
356
+
357
+ print(f'=> Total boxes: {len(all_boxes)}')
358
+
359
+ kpt_db = []
360
+ bbox_id = 0
361
+ for det_res in all_boxes:
362
+ if det_res['category_id'] != 1:
363
+ continue
364
+
365
+ image_file = os.path.join(self.img_prefix,
366
+ self.id2name[det_res['image_id']])
367
+ box = det_res['bbox']
368
+ score = det_res['score']
369
+
370
+ if score < self.det_bbox_thr:
371
+ continue
372
+
373
+ center, scale = self._xywh2cs(*box[:4])
374
+ joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
375
+ joints_3d_visible = np.ones((num_joints, 3), dtype=np.float32)
376
+ kpt_db.append({
377
+ 'image_file': image_file,
378
+ 'center': center,
379
+ 'scale': scale,
380
+ 'rotation': 0,
381
+ 'bbox': box[:4],
382
+ 'bbox_score': score,
383
+ 'dataset': self.dataset_name,
384
+ 'joints_3d': joints_3d,
385
+ 'joints_3d_visible': joints_3d_visible,
386
+ 'bbox_id': bbox_id
387
+ })
388
+ bbox_id = bbox_id + 1
389
+ print(f'=> Total boxes after filter '
390
+ f'low score@{self.det_bbox_thr}: {bbox_id}')
391
+ return kpt_db
392
+
393
+ def __len__(self):
394
+ """Get the size of the dataset."""
395
+ return len(self.db)
396
+
397
+ def __getitem__(self, idx):
398
+ """Get the sample given index."""
399
+ results = copy.deepcopy(self.db[idx])
400
+ results['ann_info'] = self.ann_info
401
+ out = self.pipeline(results)
402
+ C = self.num_classes - 1 # delete the background class
403
+ if 'label' in out:
404
+ out['dense_labeling'] = np.resize(out['label'], (C, self.ann_info['image_size'][0], self.ann_info['image_size'][1]))
405
+ else:
406
+ out['dense_labeling'] = np.zeros((C, self.ann_info['image_size'][0], self.ann_info['image_size'][1]))
407
+ # del out['ann_info']
408
+ return out # dict_keys(['image_file', 'center', 'scale', 'bbox', 'rotation', 'joints_3d', 'joints_3d_visible',
409
+ # 'dataset', 'bbox_score', 'bbox_id', 'ann_info', 'image', 'flipped', 'label',
410
+ # 'target_weight'])
411
+
412
+
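As a hedged construction example, the fields below cover what the COCOPosDatasetDev constructor reads from data_cfg and ginfo. The concrete values and the ann_file/img_prefix paths are assumptions; only num_joints=17 is fixed by the COCO keypoint definition, and a bare bbox_file name is resolved under the resources/ folder by the class itself.

from types import SimpleNamespace

data_cfg = dict(
    image_size=[192, 256],
    heatmap_size=[48, 64],
    num_joints=17,
    num_output_channels=17,
    dataset_channel=list(range(17)),
    inference_channel=list(range(17)),
    use_gt_bbox=True,
    bbox_file='COCO_val2017_detections_AP_H_56_person.json',
    det_bbox_thr=0.0,
)
ginfo = SimpleNamespace(task_name='coco_pose')   # only .task_name is accessed

dataset = COCOPosDatasetDev(ginfo,
                            ann_file='annotations/person_keypoints_val2017.json',  # placeholder
                            img_prefix='data/coco/val2017/',                       # placeholder
                            data_cfg=data_cfg,
                            test_mode=True)
sample = dataset[0]   # dict with 'image' plus the meta keys collected by the test pipeline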
413
+ class MPIIPosDatasetDev(Dataset):
414
+ def __init__(self,
415
+ ginfo,
416
+ ann_file,
417
+ img_prefix,
418
+ data_cfg,
419
+ test_mode=False,
420
+ use_udp=False,
421
+ data_use_ratio=1,
422
+ **kwargs):
423
+
424
+ self.image_info = {}
425
+ self.ann_info = {}
426
+
427
+ self.ann_file = ann_file
428
+ self.img_prefix = img_prefix
429
+
430
+ self.test_mode = test_mode
431
+
432
+ self.ann_info['image_size'] = np.array(data_cfg['image_size'])
433
+ self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
434
+ self.ann_info['num_joints'] = data_cfg['num_joints']
435
+
436
+ self.ann_info['inference_channel'] = data_cfg['inference_channel']
437
+ self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
438
+ self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
439
+
440
+ self.ann_info['use_different_joint_weights'] = data_cfg.get(
441
+ 'use_different_joint_weights', False)
442
+
443
+ assert self.ann_info['num_joints'] == 16
444
+ self.ann_info['flip_pairs'] = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
445
+ self.ann_info['flip_index'] = [5, 4, 3, 2, 1, 0, 6, 7, 8, 9, 15, 14, 13, 12, 11, 10]
446
+ self.ann_info['upper_body_ids'] = [7, 8, 9, 10, 11, 12, 13, 14, 15]
447
+ self.ann_info['lower_body_ids'] = [0, 1, 2, 3, 4, 5, 6]
448
+ self.ann_info['joint_weights'] = np.array([
449
+ 1.5, 1.2, 1., 1., 1.2, 1.5, 1., 1., 1., 1., 1.5, 1.2, 1., 1., 1.2, 1.5
450
+ ])
451
+ self.ann_info['skeleton'] = [[0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7],
452
+ [7, 8], [8, 9], [8, 12], [12, 11], [11, 10], [8, 13], [13, 14], [14, 15]]
453
+ self.sigmas = np.array(
454
+ [0.089, 0.083, 0.107, 0.107, 0.083, 0.089, 0.026, 0.026, 0.026, 0.026, 0.062, 0.072, 0.179, 0.179, 0.072,
455
+ 0.062])
456
+ self.dataset_name = 'mpii'
457
+
458
+ self.db = self._get_db()
459
+ if data_use_ratio != 1:
460
+ self.db = random.sample(self.db, int(len(self.db) * data_use_ratio))
461
+
462
+ self.image_set = set(x['image_file'] for x in self.db)
463
+ self.num_images = len(self.image_set)
464
+
465
+ print(f'=> num_images: {self.num_images}')
466
+ print(f'=> load {len(self.db)} samples')
467
+
468
+ if test_mode:
469
+ pipeline = [
470
+ LoadImageFromFile(),
471
+ TopDownAffine(use_udp=use_udp),
472
+ ToUNTensor(),
473
+ Collect(keys=['image'],
474
+ meta_keys=['image_file', 'center', 'scale', 'rotation', 'flip_pairs'])
475
+ ]
476
+ else:
477
+ pipeline = [
478
+ LoadImageFromFile(),
479
+ TopDownRandomFlip(flip_prob=0.5),
480
+ TopDownGetRandomScaleRotation(rot_factor=40, scale_factor=0.5),
481
+ TopDownAffine(use_udp=use_udp),
482
+ ToUNTensor(),
483
+ TopDownGenerateTarget(sigma=2, encoding='UDP' if use_udp else 'MSRA'),
484
+ Collect(keys=['image', 'label', 'target_weight'],
485
+ meta_keys=['image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', 'rotation',
486
+ 'flip_pairs'])
487
+ ]
488
+ self.pipeline = ComposeX(pipeline)
489
+
490
+ self.task_name = ginfo.task_name
491
+
492
+ @staticmethod
493
+ def _get_mapping_id_name(imgs):
494
+ """
495
+ Args:
496
+ imgs (dict): dict of image info.
497
+ Returns:
498
+ tuple: Image name & id mapping dicts.
499
+ - id2name (dict): Mapping image id to name.
500
+ - name2id (dict): Mapping image name to id.
501
+ """
502
+ id2name = {}
503
+ name2id = {}
504
+ for image_id, image in imgs.items():
505
+ file_name = image['file_name']
506
+ id2name[image_id] = file_name
507
+ name2id[file_name] = image_id
508
+
509
+ return id2name, name2id
510
+
511
+ def _xywh2cs(self, x, y, w, h, padding=1.25):
512
+ """This encodes bbox(x,y,w,h) into (center, scale)
513
+ Args:
514
+ x, y, w, h (float): left, top, width and height
515
+ padding (float): bounding box padding factor
516
+ Returns:
517
+ center (np.ndarray[float32](2,)): center of the bbox (x, y).
518
+ scale (np.ndarray[float32](2,)): scale of the bbox w & h.
519
+ """
520
+ aspect_ratio = self.ann_info['image_size'][0] / self.ann_info[
521
+ 'image_size'][1]
522
+ center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
523
+
524
+ if (not self.test_mode) and np.random.rand() < 0.3:
525
+ center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
526
+
527
+ if w > aspect_ratio * h:
528
+ h = w * 1.0 / aspect_ratio
529
+ elif w < aspect_ratio * h:
530
+ w = h * aspect_ratio
531
+
532
+ # pixel std is 200.0
533
+ scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
534
+ # padding to include proper amount of context
535
+ scale = scale * padding
536
+
537
+ return center, scale
538
+
539
+ def _get_normalize_factor(self, gts, *args, **kwargs):
540
+ """Get the normalize factor. generally inter-ocular distance measured
541
+ as the Euclidean distance between the outer corners of the eyes is
542
+ used. This function should be overridden to measure NME.
543
+ Args:
544
+ gts (np.ndarray[N, K, 2]): Groundtruth keypoint location.
545
+ Returns:
546
+ np.ndarray[N, 2]: normalized factor
547
+ """
548
+ return np.ones([gts.shape[0], 2], dtype=np.float32)
549
+
550
+ def _get_db(self):
551
+ """Load dataset."""
552
+ # create train/val split
553
+ with open(self.ann_file, 'r') as f:
554
+ anno = json.load(f)
555
+
556
+ gt_db = []
557
+ bbox_id = 0
558
+ for a in anno:
559
+ image_name = a['image']
560
+
561
+ center = np.array(a['center'], dtype=np.float32)
562
+ scale = np.array([a['scale'], a['scale']], dtype=np.float32)
563
+
564
+ # Adjust center/scale slightly to avoid cropping limbs
565
+ if center[0] != -1:
566
+ center[1] = center[1] + 15 * scale[1]
567
+ # padding to include proper amount of context
568
+ scale = scale * 1.25
569
+
570
+ # MPII uses matlab format, index is 1-based,
571
+ # we should first convert to 0-based index
572
+ center = center - 1
573
+
574
+ joints_3d = np.zeros((self.ann_info['num_joints'], 3),
575
+ dtype=np.float32)
576
+ joints_3d_visible = np.zeros((self.ann_info['num_joints'], 3),
577
+ dtype=np.float32)
578
+ if not self.test_mode:
579
+ joints = np.array(a['joints'])
580
+ joints_vis = np.array(a['joints_vis'])
581
+ assert len(joints) == self.ann_info['num_joints'], \
582
+ f'joint num diff: {len(joints)}' + \
583
+ f' vs {self.ann_info["num_joints"]}'
584
+
585
+ joints_3d[:, 0:2] = joints[:, 0:2] - 1
586
+ joints_3d_visible[:, :2] = joints_vis[:, None]
587
+ image_file = osp.join(self.img_prefix, image_name)
588
+ gt_db.append({
589
+ 'image_file': image_file,
590
+ 'bbox_id': bbox_id,
591
+ 'center': center,
592
+ 'scale': scale,
593
+ 'rotation': 0,
594
+ 'joints_3d': joints_3d,
595
+ 'joints_3d_visible': joints_3d_visible,
596
+ 'dataset': self.dataset_name,
597
+ 'bbox_score': 1
598
+ })
599
+ bbox_id = bbox_id + 1
600
+ gt_db = sorted(gt_db, key=lambda x: x['bbox_id'])
601
+
602
+ return gt_db
603
+
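To make the ground-truth preprocessing above concrete, here is a small worked example of the MPII center/scale adjustment (the numbers are illustrative):

import numpy as np

center = np.array([594.0, 257.0], dtype=np.float32)   # MPII center, 1-based (matlab) indexing
scale = np.array([3.021, 3.021], dtype=np.float32)    # in units of 200 pixels

center[1] = center[1] + 15 * scale[1]   # shift downward to avoid cropping limbs
scale = scale * 1.25                    # extra context padding
center = center - 1                     # matlab 1-based -> 0-based index
print(center, scale)                    # [593. 301.315] [3.77625 3.77625]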
604
+ @staticmethod
605
+ def _write_keypoint_results(keypoints, res_file):
606
+ """Write results into a json file."""
607
+
608
+ with open(res_file, 'w') as f:
609
+ json.dump(keypoints, f, sort_keys=True, indent=4)
610
+
611
+ def _report_metric(self,
612
+ res_file,
613
+ metrics,
614
+ pck_thr=0.2,
615
+ pckh_thr=0.7,
616
+ auc_nor=30):
617
+ """Keypoint evaluation.
618
+ Args:
619
+ res_file (str): Json file stored prediction results.
620
+ metrics (str | list[str]): Metric to be performed.
621
+ Options: 'PCK', 'PCKh', 'AUC', 'EPE', 'NME'.
622
+ pck_thr (float): PCK threshold, default as 0.2.
623
+ pckh_thr (float): PCKh threshold, default as 0.7.
624
+ auc_nor (float): AUC normalization factor, default as 30 pixel.
625
+ Returns:
626
+ List: Evaluation results for evaluation metric.
627
+ """
628
+ info_str = []
629
+
630
+ with open(res_file, 'r') as fin:
631
+ preds = json.load(fin)
632
+ assert len(preds) == len(self.db)
633
+
634
+ outputs = []
635
+ gts = []
636
+ masks = []
637
+ box_sizes = []
638
+ threshold_bbox = []
639
+ threshold_head_box = []
640
+
641
+ for pred, item in zip(preds, self.db):
642
+ outputs.append(np.array(pred['keypoints'])[:, :-1])
643
+ gts.append(np.array(item['joints_3d'])[:, :-1])
644
+ masks.append((np.array(item['joints_3d_visible'])[:, 0]) > 0)
645
+ if 'PCK' in metrics:
646
+ bbox = np.array(item['bbox'])
647
+ bbox_thr = np.max(bbox[2:])
648
+ threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
649
+ if 'PCKh' in metrics:
650
+ head_box_thr = item['head_size']
651
+ threshold_head_box.append(
652
+ np.array([head_box_thr, head_box_thr]))
653
+ box_sizes.append(item.get('box_size', 1))
654
+
655
+ outputs = np.array(outputs)
656
+ gts = np.array(gts)
657
+ masks = np.array(masks)
658
+ threshold_bbox = np.array(threshold_bbox)
659
+ threshold_head_box = np.array(threshold_head_box)
660
+ box_sizes = np.array(box_sizes).reshape([-1, 1])
661
+
662
+ if 'PCK' in metrics:
663
+ _, pck, _ = keypoint_pck_accuracy(outputs, gts, masks, pck_thr,
664
+ threshold_bbox)
665
+ info_str.append(('PCK', pck))
666
+
667
+ if 'PCKh' in metrics:
668
+ _, pckh, _ = keypoint_pck_accuracy(outputs, gts, masks, pckh_thr,
669
+ threshold_head_box)
670
+ info_str.append(('PCKh', pckh))
671
+
672
+ if 'AUC' in metrics:
673
+ info_str.append(('AUC', keypoint_auc(outputs, gts, masks,
674
+ auc_nor)))
675
+
676
+ if 'EPE' in metrics:
677
+ info_str.append(('EPE', keypoint_epe(outputs, gts, masks)))
678
+
679
+ if 'NME' in metrics:
680
+ normalize_factor = self._get_normalize_factor(
681
+ gts=gts, box_sizes=box_sizes)
682
+ info_str.append(
683
+ ('NME', keypoint_nme(outputs, gts, masks, normalize_factor)))
684
+
685
+ return info_str
686
+
687
+ def __len__(self):
688
+ """Get the size of the dataset."""
689
+ return len(self.db)
690
+
691
+ def __getitem__(self, idx):
692
+ """Get the sample given index."""
693
+ results = copy.deepcopy(self.db[idx])
694
+ results['ann_info'] = self.ann_info
695
+ out = self.pipeline(results)
696
+ C = self.ann_info['num_joints']
697
+ if 'label' in out:
698
+ out['dense_labeling'] = np.resize(out['label'],
699
+ (C, self.ann_info['image_size'][1], self.ann_info['image_size'][0]))
700
+ else:
701
+ out['dense_labeling'] = np.zeros((C, self.ann_info['image_size'][1], self.ann_info['image_size'][0]))
702
+ # import pdb;pdb.set_trace()
703
+ return out
704
+
705
+ def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
706
+ """sort kpts and remove the repeated ones."""
707
+ kpts = sorted(kpts, key=lambda x: x[key])
708
+ num = len(kpts)
709
+ for i in range(num - 1, 0, -1):
710
+ if kpts[i][key] == kpts[i - 1][key]:
711
+ del kpts[i]
712
+
713
+ return kpts
core/data/datasets/images/resources/CHval.odgt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8f9c1b0cb455d6b0fb53b73d1dd92fbb3b3b02bfd897661ceddc246e4991b5e
3
+ size 19994003
core/data/datasets/images/resources/COCO_val2017_detections_AP_H_56_person.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53ba0ad8d0fd461c5a000cd90797fa8c39cd8c38cd125125c0412626ff592d59
3
+ size 16383781
core/data/datasets/images/resources/mpii_gt_val.mat ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ab6874f858046c74acdd1b9dacb8746a2ddddc331952487a9774e3ee0c2b075
3
+ size 1257356
core/data/datasets/images/resources/test_caltech_heavy_1xnew.odgt ADDED
The diff for this file is too large to render. See raw diff
 
core/data/datasets/images/seg_data_tools/__init__.py ADDED
File without changes
core/data/datasets/images/seg_data_tools/collate.py ADDED
@@ -0,0 +1,143 @@
1
+ import random
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data.dataloader import default_collate
6
+
7
+ from lib.extensions.parallel.data_container import DataContainer
8
+
9
+
10
+ def stack(batch, data_key=None, return_dc=False):
11
+ if isinstance(batch[0][data_key], DataContainer):
12
+ if batch[0][data_key].stack:
13
+ assert isinstance(batch[0][data_key].data, torch.Tensor)
14
+ samples = [sample[data_key].data for sample in batch]
15
+ return default_collate(samples)
16
+
17
+ elif not return_dc:
18
+ return [sample[data_key].data for sample in batch]
19
+
20
+ else:
21
+ return DataContainer([sample[data_key].data for sample in batch])
22
+
23
+ else:
24
+ return default_collate([sample[data_key] for sample in batch])
25
+
26
+
27
+ def collate(batch, trans_dict):
28
+ data_keys = batch[0].keys()
29
+
30
+ target_width, target_height = trans_dict['input_size']
31
+ target_widths, target_heights = [target_width] * len(batch), [target_height] * len(batch)
32
+
33
+
34
+ for i in range(len(batch)):
35
+ target_width, target_height = target_widths[i], target_heights[i]
36
+
37
+ if 'meta' in data_keys:
38
+ batch[i]['meta'].data['input_size'] = [target_width, target_height]
39
+
40
+ channels, height, width = batch[i]['img'].size()
41
+ if height == target_height and width == target_width:
42
+ continue
43
+
44
+ scaled_size = [width, height]
45
+
46
+ if trans_dict['align_method'] in ['only_scale', 'scale_and_pad']:
47
+ w_scale_ratio = target_width / width
48
+ h_scale_ratio = target_height / height
49
+ if trans_dict['align_method'] == 'scale_and_pad':
50
+ w_scale_ratio = min(w_scale_ratio, h_scale_ratio)
51
+ h_scale_ratio = w_scale_ratio
52
+
53
+ scaled_size = (int(round(width * w_scale_ratio)), int(round(height * h_scale_ratio)))
54
+ if 'meta' in data_keys and 'border_size' in batch[i]['meta'].data:
55
+ batch[i]['meta'].data['border_size'] = scaled_size
56
+
57
+ scaled_size_hw = (scaled_size[1], scaled_size[0])
58
+ batch[i]['img'] = DataContainer(F.interpolate(batch[i]['img'].data.unsqueeze(0),
59
+ scaled_size_hw, mode='bilinear', align_corners=True).squeeze(0), stack=True)
60
+ if 'labelmap' in data_keys:
61
+ labelmap = batch[i]['labelmap'].data.unsqueeze(0).unsqueeze(0).float()
62
+ labelmap = F.interpolate(labelmap, scaled_size_hw, mode='nearest').long().squeeze(0).squeeze(0)
63
+ batch[i]['labelmap'] = DataContainer(labelmap, stack=True)
64
+
65
+ if 'maskmap' in data_keys:
66
+ maskmap = batch[i]['maskmap'].data.unsqueeze(0).unsqueeze(0).float()
67
+ maskmap = F.interpolate(maskmap, scaled_size_hw, mode='nearest').long().squeeze(0).squeeze(0)
68
+ batch[i]['maskmap'].data = DataContainer(maskmap, stack=True)
69
+
70
+ pad_width = target_width - scaled_size[0]
71
+ pad_height = target_height - scaled_size[1]
72
+ assert pad_height >= 0 and pad_width >= 0
73
+ if pad_width > 0 or pad_height > 0:
74
+ assert trans_dict['align_method'] in ['only_pad', 'scale_and_pad']
75
+ left_pad = 0
76
+ up_pad = 0
77
+ if 'pad_mode' not in trans_dict or trans_dict['pad_mode'] == 'random':
78
+ left_pad = random.randint(0, pad_width) # pad_left
79
+ up_pad = random.randint(0, pad_height) # pad_up
80
+
81
+ elif trans_dict['pad_mode'] == 'pad_left_up':
82
+ left_pad = pad_width
83
+ up_pad = pad_height
84
+
85
+ elif trans_dict['pad_mode'] == 'pad_right_down':
86
+ left_pad = 0
87
+ up_pad = 0
88
+
89
+ elif trans_dict['pad_mode'] == 'pad_center':
90
+ left_pad = pad_width // 2
91
+ up_pad = pad_height // 2
92
+
93
+ elif trans_dict['pad_mode'] == 'pad_border':
94
+ if random.randint(0, 1) == 0:
95
+ left_pad = pad_width
96
+ up_pad = pad_height
97
+ else:
98
+ left_pad = 0
99
+ up_pad = 0
100
+ else:
101
+ raise ValueError("mode not define")
102
+ exit(1)
103
+
104
+ pad = (left_pad, pad_width-left_pad, up_pad, pad_height-up_pad)
105
+
106
+ batch[i]['img'] = DataContainer(F.pad(batch[i]['img'].data, pad=pad, value=0), stack=batch[i]['img'].stack)
107
+
108
+ if 'labelmap' in data_keys:
109
+ batch[i]['labelmap'] = DataContainer(F.pad(batch[i]['labelmap'].data, pad=pad, value=-1), stack=batch[i]['labelmap'].stack)
110
+
111
+ if 'maskmap' in data_keys:
112
+ batch[i]['maskmap'] = DataContainer(F.pad(batch[i]['maskmap'].data, pad=pad, value=0), stack=batch[i]['maskmap'].stack)
113
+
114
+ if 'distance_map' in data_keys:
115
+ batch[i]['distance_map'] = DataContainer(F.pad(batch[i]['distance_map'].data, pad=pad, value=255), stack=batch[i]['distance_map'].stack)
116
+
117
+ if 'angle_map' in data_keys:
118
+ batch[i]['angle_map'] = DataContainer(F.pad(batch[i]['angle_map'].data, pad=pad, value=0), stack=batch[i]['angle_map'].stack)
119
+
120
+ if 'mask_label_map' in data_keys:
121
+ batch[i]['mask_label_map'] = DataContainer(F.pad(batch[i]['mask_label_map'].data, pad=pad, value=-1), stack=batch[i]['mask_label_map'].stack)
122
+
123
+ if 'direction_label_map' in data_keys:
124
+ batch[i]['direction_label_map'] = DataContainer(F.pad(batch[i]['direction_label_map'].data, pad=pad, value=-1), stack=batch[i]['direction_label_map'].stack)
125
+
126
+ if 'multi_label_direction_map' in data_keys:
127
+ batch[i]['multi_label_direction_map'] = DataContainer(F.pad(batch[i]['multi_label_direction_map'].data, pad=pad, value=-1), stack=batch[i]['multi_label_direction_map'].stack)
128
+
129
+ if 'energy_label_map' in data_keys:
130
+ batch[i]['energy_label_map'] = DataContainer(F.pad(batch[i]['energy_label_map'].data, pad=pad, value=-1), stack=batch[i]['energy_label_map'].stack)
131
+
132
+ if 'offsetmap_h' in data_keys:
133
+ batch[i]['offsetmap_h'] = DataContainer(F.pad(batch[i]['offsetmap_h'].data, pad=pad, value=0), stack=batch[i]['offsetmap_h'].stack)
134
+
135
+ if 'offsetmap_w' in data_keys:
136
+ batch[i]['offsetmap_w'] = DataContainer(F.pad(batch[i]['offsetmap_w'].data, pad=pad, value=0), stack=batch[i]['offsetmap_w'].stack)
137
+
138
+ return dict({key: stack(batch, data_key=key) for key in data_keys})
139
+
140
+
141
+
142
+
143
+
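A sketch of how this collate() is meant to be wired into a PyTorch DataLoader through functools.partial; seg_dataset is a placeholder for any dataset that yields dicts of DataContainer-wrapped 'img'/'labelmap' entries, and the trans_dict values are illustrative.

from functools import partial
from torch.utils.data import DataLoader

trans_dict = dict(input_size=[512, 512],         # target (width, height)
                  align_method='scale_and_pad',  # 'only_scale', 'only_pad' or 'scale_and_pad'
                  pad_mode='pad_center')

loader = DataLoader(seg_dataset, batch_size=8, shuffle=True,
                    collate_fn=partial(collate, trans_dict=trans_dict))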
core/data/datasets/images/seg_data_tools/cv2_aug_transforms.py ADDED
@@ -0,0 +1,889 @@
1
+ import collections
2
+ import math
3
+ import random
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ class _BaseTransform(object):
9
+
10
+ DATA_ITEMS = (
11
+ 'labelmap', 'maskmap',
12
+ 'distance_map', 'angle_map', 'multi_label_direction_map',
13
+ 'boundary_map', 'offsetmap',
14
+ # 'offsetmap_h', 'offsetmap_w',
15
+ 'region_indexmap'
16
+ )
17
+
18
+ def __call__(self, img, **kwargs):
19
+
20
+ data_dict = collections.defaultdict(lambda: None)
21
+ data_dict.update(kwargs)
22
+
23
+ return img, data_dict
24
+
25
+ def _process(self, img, data_dict, skip_condition, *args, **kwargs):
26
+ assert isinstance(img, np.ndarray), \
27
+ "img should be numpy array, got {}.".format(type(img))
28
+ if not skip_condition:
29
+ img = self._process_img(img, *args, **kwargs)
30
+
31
+ ret_dict = collections.defaultdict(lambda: None)
32
+ for name in self.DATA_ITEMS:
33
+ func_name = '_process_' + name
34
+ x = data_dict[name]
35
+
36
+ assert isinstance(x, np.ndarray) or x is None, \
37
+ "{} should be numpy array or None, got {}.".format(
38
+ name, type(x))
39
+
40
+ if hasattr(self, func_name) and x is not None and not skip_condition:
41
+ ret_dict[name] = getattr(self, func_name)(x, *args, **kwargs)
42
+ else:
43
+ ret_dict[name] = x
44
+
45
+ return img, ret_dict
46
+
47
+
48
+ class Padding(_BaseTransform):
49
+ """ Padding the Image to proper size.
50
+ Args:
51
+ stride: the stride of the network.
52
+ pad_value: the value that pad to the image border.
53
+ img: Image object as input.
54
+ Returns::
55
+ img: Image object.
56
+ """
57
+
58
+ def __init__(self, pad=None, pad_ratio=0.5, mean=(104, 117, 123), allow_outside_center=True):
59
+ self.pad = pad
60
+ self.ratio = pad_ratio
61
+ self.mean = mean
62
+ self.allow_outside_center = allow_outside_center
63
+
64
+ def _pad(self, x, pad_value, height, width, target_size, offset_left, offset_up):
65
+ expand_x = np.zeros((
66
+ max(height, target_size[1]) + abs(offset_up),
67
+ max(width, target_size[0]) + abs(offset_left),
68
+ *x.shape[2:]
69
+ ), dtype=x.dtype)
70
+ expand_x[:, :] = pad_value
71
+ expand_x[
72
+ abs(min(offset_up, 0)):abs(min(offset_up, 0)) + height,
73
+ abs(min(offset_left, 0)):abs(min(offset_left, 0)) + width] = x
74
+ x = expand_x[
75
+ max(offset_up, 0):max(offset_up, 0) + target_size[1],
76
+ max(offset_left, 0):max(offset_left, 0) + target_size[0]
77
+ ]
78
+ return x
79
+
80
+ def _process_img(self, img, *args):
81
+ return self._pad(img, self.mean, *args)
82
+
83
+ def _process_labelmap(self, x, *args):
84
+ return self._pad(x, 255, *args)
85
+
86
+ def _process_region_indexmap(self, x, *args):
87
+ return self._pad(x, 0, *args)
88
+
89
+ def _process_maskmap(self, x, *args):
90
+ return self._pad(x, 1, *args)
91
+
92
+ def _process_distance_map(self, x, *args):
93
+ return self._pad(x, 255, *args)
94
+
95
+ def _process_angle_map(self, x, *args):
96
+ return self._pad(x, 0, *args)
97
+
98
+ def _process_boundary_map(self, x, *args):
99
+ return self._pad(x, 0, *args)
100
+
101
+ def _process_multi_label_direction_map(self, x, *args):
102
+ return self._pad(x, 0, *args)
103
+
104
+ # def _process_offsetmap_h(self, x, *args):
105
+ # return self._pad(x, 0, *args)
106
+
107
+ # def _process_offsetmap_w(self, x, *args):
108
+ # return self._pad(x, 0, *args)
109
+
110
+ def _process_offsetmap(self, x, *args):
111
+ return self._pad(x, 0, *args)
112
+
113
+ def __call__(self, img, **kwargs):
114
+ img, data_dict = super().__call__(img, **kwargs)
115
+
116
+ height, width, channels = img.shape
117
+ left_pad, up_pad, right_pad, down_pad = self.pad
118
+
119
+ target_size = [width + left_pad +
120
+ right_pad, height + up_pad + down_pad]
121
+ offset_left = -left_pad
122
+ offset_up = -up_pad
123
+
124
+ return self._process(
125
+ img, data_dict,
126
+ random.random() > self.ratio,
127
+ height, width, target_size, offset_left, offset_up
128
+ )
129
+
130
+
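A quick sanity check of the Padding geometry above (pad is (left, up, right, down); images are filled with mean, labelmaps with the ignore value 255); the sizes are illustrative:

import numpy as np

pad = Padding(pad=(10, 20, 30, 40), pad_ratio=1.0, mean=(104, 117, 123))
img = np.zeros((100, 200, 3), dtype=np.uint8)
labelmap = np.zeros((100, 200), dtype=np.uint8)
img, out = pad(img, labelmap=labelmap)
print(img.shape, out['labelmap'].shape)   # (160, 240, 3) (160, 240)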
131
+ class RandomHFlip(_BaseTransform):
132
+ def __init__(self, swap_pair=None, flip_ratio=0.5):
133
+ self.swap_pair = swap_pair
134
+ self.ratio = flip_ratio
135
+
136
+ def _process_img(self, img):
137
+ return cv2.flip(img, 1)
138
+
139
+ def _process_labelmap(self, labelmap):
140
+ labelmap = cv2.flip(labelmap, 1)
141
+ # to handle datasets with left/right annotations
142
+ if self.swap_pair is not None:
143
+ assert isinstance(self.swap_pair, (tuple, list))
144
+ temp = labelmap.copy()
145
+ for pair in self.swap_pair:
146
+ assert isinstance(pair, (tuple, list)) and len(pair) == 2
147
+ labelmap[temp == pair[0]] = pair[1]
148
+ labelmap[temp == pair[1]] = pair[0]
149
+
150
+ return labelmap
151
+
152
+ def _process_region_indexmap(self, labelmap):
153
+ return cv2.flip(labelmap, 1)
154
+
155
+ def _process_maskmap(self, x):
156
+ return cv2.flip(x, 1)
157
+
158
+ def _process_distance_map(self, x):
159
+ return cv2.flip(x, 1)
160
+
161
+ def _process_angle_map(self, angle_map):
162
+ ret_angle_map = angle_map.copy()
163
+ mask = (angle_map > 0) & (angle_map < 180)
164
+ ret_angle_map[mask] = 180 - angle_map[mask]
165
+ mask = (angle_map < 0) & (angle_map > -180)
166
+ ret_angle_map[mask] = - (180 + angle_map[mask])
167
+ ret_angle_map = cv2.flip(ret_angle_map, 1)
168
+ return ret_angle_map
169
+
170
+ def _process_boundary_map(self, x):
171
+ return cv2.flip(x, 1)
172
+
173
+ def _process_multi_label_direction_map(self, multi_label_direction_map):
174
+ perm = [4, 3, 2, 1, 0, 7, 6, 5]
175
+ multi_label_direction_map = cv2.flip(multi_label_direction_map, 1)
176
+ multi_label_direction_map = multi_label_direction_map[..., perm]
177
+ return multi_label_direction_map
178
+
179
+ # def _process_offsetmap_h(self, x):
180
+ # return cv2.flip(x, 1)
181
+
182
+ # def _process_offsetmap_w(self, x):
183
+ # return -cv2.flip(x, 1)
184
+
185
+ def _process_offsetmap_w(self, x):
186
+ x = cv2.flip(x, 1)
187
+ x[..., 1] = -x[..., 1]
188
+ return x
189
+
190
+ def __call__(self, img, **kwargs):
191
+ img, data_dict = super().__call__(img, **kwargs)
192
+
193
+ return self._process(
194
+ img, data_dict,
195
+ random.random() > self.ratio
196
+ )
197
+
198
+
199
+ class RandomSaturation(_BaseTransform):
200
+ def __init__(self, lower=0.5, upper=1.5, saturation_ratio=0.5):
201
+ self.lower = lower
202
+ self.upper = upper
203
+ self.ratio = saturation_ratio
204
+ assert self.upper >= self.lower, "saturation upper must be >= lower."
205
+ assert self.lower >= 0, "saturation lower must be non-negative."
206
+
207
+ def _process_img(self, img):
208
+ img = img.astype(np.float32)
209
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
210
+ img[:, :, 1] *= random.uniform(self.lower, self.upper)
211
+ img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
212
+ img = np.clip(img, 0, 255).astype(np.uint8)
213
+ return img
214
+
215
+ def __call__(self, img, **kwargs):
216
+ img, data_dict = super().__call__(img, **kwargs)
217
+
218
+ return self._process(
219
+ img, data_dict,
220
+ random.random() > self.ratio
221
+ )
222
+
223
+
224
+ class RandomHue(_BaseTransform):
225
+ def __init__(self, delta=18, hue_ratio=0.5):
226
+ assert 0 <= delta <= 360
227
+ self.delta = delta
228
+ self.ratio = hue_ratio
229
+
230
+ def _process_img(self, img):
231
+ img = img.astype(np.float32)
232
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
233
+ img[:, :, 0] += random.uniform(-self.delta, self.delta)
234
+ img[:, :, 0][img[:, :, 0] > 360] -= 360
235
+ img[:, :, 0][img[:, :, 0] < 0] += 360
236
+ img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
237
+ img = np.clip(img, 0, 255).astype(np.uint8)
238
+ return img
239
+
240
+ def __call__(self, img, **kwargs):
241
+ img, data_dict = super().__call__(img, **kwargs)
242
+
243
+ return self._process(
244
+ img, data_dict,
245
+ random.random() > self.ratio
246
+ )
247
+
248
+
249
+ class RandomPerm(_BaseTransform):
250
+ def __init__(self, perm_ratio=0.5):
251
+ self.ratio = perm_ratio
252
+ self.perms = ((0, 1, 2), (0, 2, 1),
253
+ (1, 0, 2), (1, 2, 0),
254
+ (2, 0, 1), (2, 1, 0))
255
+
256
+ def _process_img(self, img):
257
+ swap = self.perms[random.randint(0, len(self.perms) - 1)]
258
+ img = img[:, :, swap].astype(np.uint8)
259
+ return img
260
+
261
+ def __call__(self, img, **kwargs):
262
+ img, data_dict = super().__call__(img, **kwargs)
263
+
264
+ return self._process(
265
+ img, data_dict,
266
+ random.random() > self.ratio
267
+ )
268
+
269
+
270
+ class RandomContrast(_BaseTransform):
271
+ def __init__(self, lower=0.5, upper=1.5, contrast_ratio=0.5):
272
+ self.lower = lower
273
+ self.upper = upper
274
+ self.ratio = contrast_ratio
275
+ assert self.upper >= self.lower, "contrast upper must be >= lower."
276
+ assert self.lower >= 0, "contrast lower must be non-negative."
277
+
278
+ def _process_img(self, img):
279
+ img = img.astype(np.float32)
280
+ img *= random.uniform(self.lower, self.upper)
281
+ img = np.clip(img, 0, 255).astype(np.uint8)
282
+ return img
283
+
284
+ def __call__(self, img, **kwargs):
285
+ img, data_dict = super().__call__(img, **kwargs)
286
+
287
+ return self._process(
288
+ img, data_dict,
289
+ random.random() > self.ratio
290
+ )
291
+
292
+
293
+ class RandomBrightness(_BaseTransform):
294
+ def __init__(self, shift_value=30, brightness_ratio=0.5):
295
+ self.shift_value = shift_value
296
+ self.ratio = brightness_ratio
297
+
298
+ def _process_img(self, img):
299
+ img = img.astype(np.float32)
300
+ shift = random.randint(-self.shift_value, self.shift_value)
301
+ img[:, :, :] += shift
302
+ img = np.around(img)
303
+ img = np.clip(img, 0, 255).astype(np.uint8)
304
+ return img
305
+
306
+ def __call__(self, img, **kwargs):
307
+ img, data_dict = super().__call__(img, **kwargs)
308
+
309
+ return self._process(
310
+ img, data_dict,
311
+ random.random() > self.ratio
312
+ )
313
+
314
+
315
+ class RandomResize(_BaseTransform):
316
+ """Resize the given numpy.ndarray to random size and aspect ratio.
317
+
318
+ Args:
319
+ scale_min: the min scale to resize.
320
+ scale_max: the max scale to resize.
321
+ """
322
+
323
+ def __init__(self, scale_range=(0.75, 1.25), aspect_range=(0.9, 1.1), target_size=None,
324
+ resize_bound=None, method='random', max_side_bound=None, scale_list=None, resize_ratio=0.5):
325
+ self.scale_range = scale_range
326
+ self.aspect_range = aspect_range
327
+ self.resize_bound = resize_bound
328
+ self.max_side_bound = max_side_bound
329
+ self.scale_list = scale_list
330
+ self.method = method
331
+ self.ratio = resize_ratio
332
+
333
+ if target_size is not None:
334
+ if isinstance(target_size, int):
335
+ self.input_size = (target_size, target_size)
336
+ elif isinstance(target_size, (list, tuple)) and len(target_size) == 2:
337
+ self.input_size = target_size
338
+ else:
339
+ raise TypeError(
340
+ 'Got inappropriate size arg: {}'.format(target_size))
341
+ else:
342
+ self.input_size = None
343
+
344
+ def get_scale(self, img_size):
345
+ if self.method == 'random':
346
+ scale_ratio = random.uniform(
347
+ self.scale_range[0], self.scale_range[1])
348
+ return scale_ratio
349
+
350
+ elif self.method == 'bound':
351
+ scale1 = self.resize_bound[0] / min(img_size)
352
+ scale2 = self.resize_bound[1] / max(img_size)
353
+ scale = min(scale1, scale2)
354
+ return scale
355
+
356
+ else:
357
+ raise ValueError("invalid method")
358
+ exit(1)
359
+
360
+ def _process_img(self, img, converted_size, *args):
361
+ return cv2.resize(img, converted_size, interpolation=cv2.INTER_CUBIC).astype(np.uint8)
362
+
363
+ def _process_labelmap(self, x, converted_size, *args):
364
+ return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
365
+
366
+ def _process_region_indexmap(self, x, converted_size, *args):
367
+ return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
368
+
369
+ def _process_maskmap(self, x, converted_size, *args):
370
+ return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
371
+
372
+ def _process_distance_map(self, x, converted_size, *args):
373
+ return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
374
+
375
+ def _process_angle_map(self, x, converted_size, *args):
376
+ return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
377
+
378
+ def _process_boundary_map(self, x, converted_size, *args):
379
+ return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
380
+
381
+ def _process_multi_label_direction_map(self, x, converted_size, *args):
382
+ return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
383
+
384
+ # def _process_offsetmap_h(self, x, converted_size, h_scale_ratio, w_scale_ratio):
385
+ # return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST) * h_scale_ratio
386
+
387
+ # def _process_offsetmap_w(self, x, converted_size, h_scale_ratio, w_scale_ratio):
388
+ # return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST) * w_scale_ratio
389
+
390
+ def _process_offsetmap(self, x, converted_size, h_scale_ratio, w_scale_ratio):
391
+ return cv2.resize(x, converted_size, interpolation=cv2.INTER_NEAREST)
392
+
393
+ def __call__(self, img, **kwargs):
394
+ """
395
+ Args:
396
+ img (Image): Image to be resized.
397
+ maskmap (Image): Mask to be resized.
398
+ kpt (list): keypoints to be resized.
399
+ center: (list): center points to be resized.
400
+
401
+ Returns:
402
+ Image: Randomly resize image.
403
+ Image: Randomly resize maskmap.
404
+ list: Randomly resize keypoints.
405
+ list: Randomly resize center points.
406
+ """
407
+ img, data_dict = super().__call__(img, **kwargs)
408
+
409
+ height, width, _ = img.shape
410
+ if self.scale_list is None:
411
+ scale_ratio = self.get_scale([width, height])
412
+ else:
413
+ scale_ratio = self.scale_list[random.randint(
414
+ 0, len(self.scale_list)-1)]
415
+
416
+ aspect_ratio = random.uniform(*self.aspect_range)
417
+ w_scale_ratio = math.sqrt(aspect_ratio) * scale_ratio
418
+ h_scale_ratio = math.sqrt(1.0 / aspect_ratio) * scale_ratio
419
+ if self.max_side_bound is not None and max(height*h_scale_ratio, width*w_scale_ratio) > self.max_side_bound:
420
+ d_ratio = self.max_side_bound / max(height * h_scale_ratio, width * w_scale_ratio)
421
+ w_scale_ratio *= d_ratio
422
+ h_scale_ratio *= d_ratio
423
+
424
+ converted_size = (int(width * w_scale_ratio),
425
+ int(height * h_scale_ratio))
426
+ return self._process(
427
+ img, data_dict,
428
+ random.random() > self.ratio,
429
+ converted_size, h_scale_ratio, w_scale_ratio
430
+ )
431
+
432
+
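A compact numeric sketch of how RandomResize turns its scale_range/aspect_range samples into the size handed to cv2.resize (values illustrative):

import math
import random

random.seed(0)
scale_ratio = random.uniform(0.75, 1.25)                # method='random'
aspect_ratio = random.uniform(0.9, 1.1)
w_scale = math.sqrt(aspect_ratio) * scale_ratio
h_scale = math.sqrt(1.0 / aspect_ratio) * scale_ratio   # w_scale * h_scale == scale_ratio ** 2
width, height = 2048, 1024
converted_size = (int(width * w_scale), int(height * h_scale))
print(converted_size)                                   # fed to cv2.resize as (w, h)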
433
+ class RandomRotate(_BaseTransform):
434
+ """Rotate the input numpy.ndarray and points to the given degree.
435
+
436
+ Args:
437
+ degree (number): Desired rotate degree.
438
+ """
439
+
440
+ def __init__(self, max_degree, rotate_ratio=0.5, mean=(104, 117, 123)):
441
+ assert isinstance(max_degree, int)
442
+ self.max_degree = max_degree
443
+ self.ratio = rotate_ratio
444
+ self.mean = mean
445
+
446
+ def _warp(self, x, border_value, rotate_mat, new_width, new_height):
447
+ return cv2.warpAffine(x, rotate_mat, (new_width, new_height), borderValue=border_value)
448
+
449
+ def _process_img(self, x, *args):
450
+ return self._warp(x, self.mean, *args).astype(np.uint8)
451
+
452
+ def _process_labelmap(self, x, *args):
453
+ return self._warp(x, (255, 255, 255), *args).astype(np.uint8)
454
+
455
+ def _process_maskmap(self, x, *args):
456
+ return self._warp(x, (1, 1, 1), *args).astype(np.uint8)
457
+
458
+ def __call__(self, img, **kwargs):
459
+ """
460
+ Args:
461
+ img (Image): Image to be rotated.
462
+ maskmap (Image): Mask to be rotated.
463
+ kpt (list): Keypoints to be rotated.
464
+ center (list): Center points to be rotated.
465
+
466
+ Returns:
467
+ Image: Rotated image.
468
+ list: Rotated key points.
469
+ """
470
+ img, data_dict = super().__call__(img, **kwargs)
471
+
472
+ rotate_degree = random.uniform(-self.max_degree, self.max_degree)
473
+ height, width, _ = img.shape
474
+ img_center = (width / 2.0, height / 2.0)
475
+ rotate_mat = cv2.getRotationMatrix2D(img_center, rotate_degree, 1.0)
476
+ cos_val = np.abs(rotate_mat[0, 0])
477
+ sin_val = np.abs(rotate_mat[0, 1])
478
+ new_width = int(height * sin_val + width * cos_val)
479
+ new_height = int(height * cos_val + width * sin_val)
480
+ rotate_mat[0, 2] += (new_width / 2.) - img_center[0]
481
+ rotate_mat[1, 2] += (new_height / 2.) - img_center[1]
482
+
483
+ return self._process(
484
+ img, data_dict,
485
+ random.random() > self.ratio,
486
+ rotate_mat, new_width, new_height
487
+ )
488
+
489
+
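For reference, the canvas-expansion math in RandomRotate above works out as follows for a 640x480 image rotated by 30 degrees (a standalone sketch):

import cv2
import numpy as np

height, width, degree = 480, 640, 30
center = (width / 2.0, height / 2.0)
rotate_mat = cv2.getRotationMatrix2D(center, degree, 1.0)
cos_val, sin_val = np.abs(rotate_mat[0, 0]), np.abs(rotate_mat[0, 1])
new_width = int(height * sin_val + width * cos_val)     # ~794
new_height = int(height * cos_val + width * sin_val)    # ~735
rotate_mat[0, 2] += new_width / 2.0 - center[0]         # recenter so nothing is cut off
rotate_mat[1, 2] += new_height / 2.0 - center[1]
print(new_width, new_height)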
490
+ class RandomCrop(_BaseTransform):
491
+ """Crop the given numpy.ndarray and at a random location.
492
+
493
+ Args:
494
+ size (int or tuple): Desired output size of the crop.(w, h)
495
+ """
496
+
497
+ def __init__(self, crop_size, crop_ratio=0.5, method='random', grid=None, allow_outside_center=True):
498
+ self.ratio = crop_ratio
499
+ self.method = method
500
+ self.grid = grid
501
+ self.allow_outside_center = allow_outside_center
502
+
503
+ if isinstance(crop_size, float):
504
+ self.size = (crop_size, crop_size)
505
+ elif isinstance(crop_size, (list, tuple)) and len(crop_size) == 2:  # collections.Iterable was removed in Python 3.10
506
+ self.size = crop_size
507
+ else:
508
+ raise TypeError('Got inappropriate size arg: {}'.format(crop_size))
509
+
510
+ def get_lefttop(self, crop_size, img_size):
511
+ if self.method == 'center':
512
+ return [(img_size[0] - crop_size[0]) // 2, (img_size[1] - crop_size[1]) // 2]
513
+
514
+ elif self.method == 'random':
515
+ x = random.randint(0, img_size[0] - crop_size[0])
516
+ y = random.randint(0, img_size[1] - crop_size[1])
517
+ return [x, y]
518
+
519
+ elif self.method == 'grid':
520
+ grid_x = random.randint(0, self.grid[0] - 1)
521
+ grid_y = random.randint(0, self.grid[1] - 1)
522
+ x = grid_x * ((img_size[0] - crop_size[0]) // (self.grid[0] - 1))
523
+ y = grid_y * ((img_size[1] - crop_size[1]) // (self.grid[1] - 1))
524
+ return [x, y]
525
+
526
+ else:
527
+ raise ValueError('invalid crop method: {}'.format(self.method))
529
+
530
+ def _crop(self, x, offset_up, offset_left, target_size):
531
+ return x[offset_up:offset_up + target_size[1], offset_left:offset_left + target_size[0]]
532
+
533
+ def _process_img(self, img, *args):
534
+ return self._crop(img, *args)
535
+
536
+ def _process_labelmap(self, x, *args):
537
+ return self._crop(x, *args)
538
+
539
+ def _process_region_indexmap(self, x, *args):
540
+ return self._crop(x, *args)
541
+
542
+ def _process_maskmap(self, x, *args):
543
+ return self._crop(x, *args)
544
+
545
+ def _process_distance_map(self, x, *args):
546
+ return self._crop(x, *args)
547
+
548
+ def _process_angle_map(self, x, *args):
549
+ return self._crop(x, *args)
550
+
551
+ def _process_boundary_map(self, x, *args):
552
+ return self._crop(x, *args)
553
+
554
+ def _process_multi_label_direction_map(self, x, *args):
555
+ return self._crop(x, *args)
556
+
557
+ # def _process_offsetmap_h(self, x, *args):
558
+ # return self._crop(x, *args)
559
+
560
+ # def _process_offsetmap_w(self, x, *args):
561
+ # return self._crop(x, *args)
562
+
563
+ def _process_offsetmap(self, x, *args):
564
+ return self._crop(x, *args)
565
+
566
+ def __call__(self, img, **kwargs):
567
+ """
568
+ Args:
569
+ img (Image): Image to be cropped.
570
+ maskmap (Image): Mask to be cropped.
571
+
572
+ Returns:
573
+ Image: Cropped image.
574
+ Image: Cropped maskmap.
575
+ list: Cropped keypoints.
576
+ list: Cropped center points.
577
+ """
578
+ img, data_dict = super().__call__(img, **kwargs)
579
+
580
+ height, width, _ = img.shape
581
+ target_size = [min(self.size[0], width), min(self.size[1], height)]
582
+
583
+ offset_left, offset_up = self.get_lefttop(target_size, [width, height])
584
+ return self._process(
585
+ img, data_dict,
586
+ random.random() > self.ratio,
587
+ offset_up, offset_left, target_size
588
+ )
589
+
590
+
591
+ class Resize(RandomResize):
592
+ """Resize the given numpy.ndarray to random size and aspect ratio.
593
+ Args:
594
+ scale_min: the min scale to resize.
595
+ scale_max: the max scale to resize.
596
+ """
597
+
598
+ def __init__(self, target_size=None, min_side_length=None, max_side_length=None, max_side_bound=None):
599
+ self.target_size = target_size
600
+ self.min_side_length = min_side_length
601
+ self.max_side_length = max_side_length
602
+ self.max_side_bound = max_side_bound
603
+
604
+ def __call__(self, img, **kwargs):
605
+ img, data_dict = super(RandomResize, self).__call__(img, **kwargs)
606
+
607
+ height, width, _ = img.shape
608
+ if self.target_size is not None:
609
+ target_size = self.target_size
610
+ w_scale_ratio = self.target_size[0] / width
611
+ h_scale_ratio = self.target_size[1] / height
612
+
613
+ elif self.min_side_length is not None:
614
+ scale_ratio = self.min_side_length / min(width, height)
615
+ w_scale_ratio, h_scale_ratio = scale_ratio, scale_ratio
616
+ target_size = [int(round(width * w_scale_ratio)),
617
+ int(round(height * h_scale_ratio))]
618
+
619
+ else:
620
+ scale_ratio = self.max_side_length / max(width, height)
621
+ w_scale_ratio, h_scale_ratio = scale_ratio, scale_ratio
622
+ target_size = [int(round(width * w_scale_ratio)),
623
+ int(round(height * h_scale_ratio))]
624
+
625
+ if self.max_side_bound is not None and max(target_size) > self.max_side_bound:
626
+ d_ratio = self.max_side_bound / max(target_size)
627
+ w_scale_ratio = d_ratio * w_scale_ratio
628
+ h_scale_ratio = d_ratio * h_scale_ratio
629
+ target_size = [int(round(width * w_scale_ratio)),
630
+ int(round(height * h_scale_ratio))]
631
+
632
+ target_size = tuple(target_size)
633
+ return self._process(
634
+ img, data_dict,
635
+ False,
636
+ target_size, h_scale_ratio, w_scale_ratio
637
+ )
638
+
639
+
640
+ class CV2AugCompose(object):
641
+ """Composes several transforms together.
642
+
643
+ Args:
644
+ transforms (list of ``Transform`` objects): list of transforms to compose.
645
+
646
+ Example:
647
+ >>> composed = CV2AugCompose(configer, split='train')
+ >>> img, labelmap = composed(img, labelmap=labelmap)
+ >>> # the transform sequence is read from the 'train_trans' / 'val_trans' config entries
650
+ """
651
+
652
+ def __init__(self, configer, split='train'):
653
+ self.configer = configer
654
+ self.split = split
655
+
656
+ if self.split == 'train':
657
+ shuffle_train_trans = []
658
+ if self.configer.exists('train_trans', 'shuffle_trans_seq'):
659
+ if isinstance(self.configer.get('train_trans', 'shuffle_trans_seq')[0], list):
660
+ train_trans_seq_list = self.configer.get(
661
+ 'train_trans', 'shuffle_trans_seq')
662
+ for train_trans_seq in train_trans_seq_list:
663
+ shuffle_train_trans += train_trans_seq
664
+
665
+ else:
666
+ shuffle_train_trans = self.configer.get(
667
+ 'train_trans', 'shuffle_trans_seq')
668
+ trans_seq = self.configer.get(
669
+ 'train_trans', 'trans_seq') + shuffle_train_trans
670
+ trans_key = 'train_trans'
671
+ else:
672
+ trans_seq = self.configer.get('val_trans', 'trans_seq')
673
+ trans_key = 'val_trans'
674
+
675
+ self.transforms = dict()
676
+ self.trans_config = self.configer.get(trans_key)
677
+ for trans_name in trans_seq:
678
+ specs = TRANSFORM_SPEC[trans_name]
679
+ config = self.configer.get(trans_key, trans_name)
680
+ for spec in specs:
681
+ if 'when' not in spec:
682
+ break
683
+ choose_this = True
684
+ for cond_key, cond_value in spec['when'].items():
685
+ choose_this = choose_this and (
686
+ config[cond_key] == cond_value)
687
+ if choose_this:
688
+ break
689
+ else:
690
+ raise RuntimeError("Not support!")
691
+
692
+ kwargs = {}
693
+ for arg_name, arg_path in spec["args"].items():
694
+ if isinstance(arg_path, str):
695
+ arg_value = config.get(arg_path, None)
696
+ elif isinstance(arg_path, list):
697
+ arg_value = self.configer.get(*arg_path)
698
+ kwargs[arg_name] = arg_value
699
+
700
+ klass = TRANSFORM_MAPPING[trans_name]
701
+ self.transforms[trans_name] = klass(**kwargs)
702
+
703
+ def __call__(self, img, **data_dict):
704
+
705
+ orig_key_list = list(data_dict)
706
+
707
+ if self.configer.get('data', 'input_mode') == 'RGB':
708
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
709
+
710
+ if self.split == 'train':
711
+ shuffle_trans_seq = []
712
+ if self.configer.exists('train_trans', 'shuffle_trans_seq'):
713
+ if isinstance(self.configer.get('train_trans', 'shuffle_trans_seq')[0], list):
714
+ shuffle_trans_seq_list = self.configer.get('train_trans', 'shuffle_trans_seq')
715
+ shuffle_trans_seq = shuffle_trans_seq_list[random.randint(0, len(shuffle_trans_seq_list) - 1)]
716
+ else:
717
+ shuffle_trans_seq = self.configer.get('train_trans', 'shuffle_trans_seq')
718
+ random.shuffle(shuffle_trans_seq)
719
+ trans_seq = shuffle_trans_seq + self.configer.get('train_trans', 'trans_seq')
720
+ else:
721
+ trans_seq = self.configer.get('val_trans', 'trans_seq')
722
+
723
+ for trans_key in trans_seq:
724
+ img, data_dict = self.transforms[trans_key](img, **data_dict)
725
+
726
+ if self.configer.get('data', 'input_mode') == 'RGB':
727
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
728
+
729
+ return (img, *[data_dict[key] for key in orig_key_list])
730
+
731
+ def __repr__(self):
732
+ import pprint
733
+ return 'CV2AugCompose({})'.format(pprint.pformat(self.trans_config))
734
+
735
+
736
+ TRANSFORM_MAPPING = {
737
+ "random_saturation": RandomSaturation,
738
+ "random_hue": RandomHue,
739
+ "random_perm": RandomPerm,
740
+ "random_contrast": RandomContrast,
741
+ "padding": Padding,
742
+ "random_brightness": RandomBrightness,
743
+ "random_hflip": RandomHFlip,
744
+ "random_resize": RandomResize,
745
+ "random_crop": RandomCrop,
746
+ "random_rotate": RandomRotate,
747
+ "resize": Resize,
748
+ }
749
+
750
+ TRANSFORM_SPEC = {
751
+ "random_style": [{
752
+ "args": {
753
+ "style_ratio": "ratio"
754
+ }
755
+ }],
756
+ "random_saturation": [{
757
+ "args": {
758
+ "lower": "lower",
759
+ "upper": "upper",
760
+ "saturation_ratio": "ratio"
761
+ }
762
+ }],
763
+ "random_hue": [{
764
+ "args": {
765
+ "delta": "delta",
766
+ "hue_ratio": "ratio"
767
+ }
768
+ }],
769
+ "ramdom_perm": [{
770
+ "args": {
771
+ "perm_ratio": "ratio"
772
+ }
773
+ }],
774
+ "random_contrast": [{
775
+ "args": {
776
+ "lower": "lower",
777
+ "upper": "upper",
778
+ "contrast_ratio": "ratio"
779
+ }
780
+ }],
781
+ "padding": [{
782
+ "args": {
783
+ "pad": "pad",
784
+ "pad_ratio": "ratio",
785
+ "mean": ["normalize", "mean_value"],
786
+ "allow_outside_center": "allow_outside_center"
787
+ }
788
+ }],
789
+ "random_brightness": [{
790
+ "args": {
791
+ "shift_value": "shift_value",
792
+ "brightness_ratio": "ratio"
793
+ }
794
+ }],
795
+ "random_hflip": [{
796
+ "args": {
797
+ "swap_pair": "swap_pair",
798
+ "flip_ratio": "ratio"
799
+ }
800
+ }],
801
+ "random_resize": [
802
+ {
803
+ "args": {
804
+ "method": "method",
805
+ "scale_range": "scale_range",
806
+ "aspect_range": "aspect_range",
807
+ "max_side_bound": "max_side_bound",
808
+ "resize_ratio": "ratio"
809
+ },
810
+ "when": {
811
+ "method": "random"
812
+ }
813
+ },
814
+ {
815
+ "args": {
816
+ "method": "method",
817
+ "scale_range": "scale_range",
818
+ "aspect_range": "aspect_range",
819
+ "target_size": "target_size",
820
+ "resize_ratio": "ratio"
821
+ },
822
+ "when": {
823
+ "method": "focus"
824
+ }
825
+ },
826
+ {
827
+ "args": {
828
+ "method": "method",
829
+ "aspect_range": "aspect_range",
830
+ "resize_bound": "resize_bound",
831
+ "resize_ratio": "ratio"
832
+ },
833
+ "when": {
834
+ "method": "bound"
835
+ }
836
+ },
837
+ ],
838
+ "random_crop": [
839
+ {
840
+ "args": {
841
+ "crop_size": "crop_size",
842
+ "method": "method",
843
+ "crop_ratio": "ratio",
844
+ "allow_outside_center": "allow_outside_center"
845
+ },
846
+ "when": {
847
+ "method": "random"
848
+ }
849
+ },
850
+ {
851
+ "args": {
852
+ "crop_size": "crop_size",
853
+ "method": "method",
854
+ "crop_ratio": "ratio",
855
+ "allow_outside_center": "allow_outside_center"
856
+ },
857
+ "when": {
858
+ "method": "center"
859
+ }
860
+ },
861
+ {
862
+ "args": {
863
+ "crop_size": "crop_size",
864
+ "method": "method",
865
+ "crop_ratio": "ratio",
866
+ "grid": "grid",
867
+ "allow_outside_center": "allow_outside_center"
868
+ },
869
+ "when": {
870
+ "method": "grid"
871
+ }
872
+ },
873
+ ],
874
+ "random_rotate": [{
875
+ "args": {
876
+ "max_degree": "rotate_degree",
877
+ "rotate_ratio": "ratio",
878
+ "mean": ["normalize", "mean_value"]
879
+ }
880
+ }],
881
+ "resize": [{
882
+ "args": {
883
+ "target_size": "target_size",
884
+ "min_side_length": "min_side_length",
885
+ "max_side_bound": "max_side_bound",
886
+ "max_side_length": "max_side_length"
887
+ }
888
+ }],
889
+ }
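The two tables above are what drive CV2AugCompose.__init__: each TRANSFORM_SPEC entry maps constructor argument names to config keys, and TRANSFORM_MAPPING supplies the class. A minimal sketch of that resolution outside the configer, with made-up config values (the `config` dict below is illustrative only, not taken from this repo):

from core.data.datasets.images.seg_data_tools.cv2_aug_transforms import (
    TRANSFORM_MAPPING, TRANSFORM_SPEC)

# Hypothetical values; in the real pipeline they come from configer.get('train_trans', 'random_crop').
config = {"crop_size": [512, 512], "method": "random",
          "ratio": 0.5, "allow_outside_center": True}

spec = TRANSFORM_SPEC["random_crop"][0]            # the entry whose "when" matches method == "random"
kwargs = {arg: config.get(path) for arg, path in spec["args"].items()
          if isinstance(path, str)}                # list-valued paths are resolved via configer.get(*path)
crop = TRANSFORM_MAPPING["random_crop"](**kwargs)  # RandomCrop(crop_size=[512, 512], crop_ratio=0.5, ...)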
core/data/datasets/images/seg_data_tools/transforms.py ADDED
@@ -0,0 +1,106 @@
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+
5
+
6
+ class Normalize(object):
7
+ """Normalize a ``torch.tensor``
8
+
9
+ Args:
10
+ inputs (torch.tensor): tensor to be normalized.
11
+ mean: (list): the mean of RGB
12
+ std: (list): the std of RGB
13
+
14
+ Returns:
15
+ Tensor: Normalized tensor.
16
+ """
17
+ def __init__(self, div_value, mean, std):
18
+ self.div_value = div_value
19
+ self.mean = mean
20
+ self.std = std
21
+
22
+ def __call__(self, inputs):
23
+ inputs = inputs.div(self.div_value)
24
+ for t, m, s in zip(inputs, self.mean, self.std):
25
+ t.sub_(m).div_(s)
26
+
27
+ return inputs
28
+
29
+
30
+ class DeNormalize(object):
31
+ """DeNormalize a ``torch.tensor``
32
+
33
+ Args:
34
+ inputs (torch.tensor): tensor to be normalized.
35
+ mean: (list): the mean of RGB
36
+ std: (list): the std of RGB
37
+
38
+ Returns:
39
+ Tensor: Denormalized tensor.
40
+ """
41
+ def __init__(self, div_value, mean, std):
42
+ self.div_value = div_value
43
+ self.mean = mean
44
+ self.std = std
45
+
46
+ def __call__(self, inputs):
47
+ result = inputs.clone()
48
+ for i in range(result.size(0)):
49
+ result[i, :, :] = result[i, :, :] * self.std[i] + self.mean[i]
50
+
51
+ return result.mul_(self.div_value)
52
+
53
+
54
+ class ToTensor(object):
55
+ """Convert a ``numpy.ndarray or Image`` to tensor.
56
+
57
+ See ``ToTensor`` for more details.
58
+
59
+ Args:
60
+ inputs (numpy.ndarray or Image): Image to be converted to tensor.
61
+
62
+ Returns:
63
+ Tensor: Converted image.
64
+ """
65
+ def __call__(self, inputs):
66
+ if isinstance(inputs, Image.Image):
67
+ channels = len(inputs.mode)
68
+ inputs = np.array(inputs)
69
+ inputs = inputs.reshape(inputs.shape[0], inputs.shape[1], channels)
70
+ inputs = torch.from_numpy(inputs.transpose(2, 0, 1))
71
+ else:
72
+ inputs = torch.from_numpy(inputs.transpose(2, 0, 1))
73
+
74
+ return inputs.float()
75
+
76
+
77
+ class ToLabel(object):
78
+ def __call__(self, inputs):
79
+ return torch.from_numpy(np.array(inputs)).long()
80
+
81
+
82
+ class ReLabel(object):
83
+ """
84
+ Relabel pixels equal to ``olabel`` (e.g. 255 for the background) to ``nlabel``.
85
+ """
86
+ def __init__(self, olabel, nlabel):
87
+ self.olabel = olabel
88
+ self.nlabel = nlabel
89
+
90
+ def __call__(self, inputs):
91
+ assert isinstance(inputs, torch.LongTensor), 'tensor needs to be LongTensor'
92
+
93
+ inputs[inputs == self.olabel] = self.nlabel
94
+ return inputs
95
+
96
+
97
+ class Compose(object):
98
+
99
+ def __init__(self, transforms):
100
+ self.transforms = transforms
101
+
102
+ def __call__(self, inputs):
103
+ for t in self.transforms:
104
+ inputs = t(inputs)
105
+
106
+ return inputs
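A hypothetical end-to-end use of the tensor transforms defined above (the image size, div_value, mean and std below are illustrative, not taken from any config in this repo):

import numpy as np
from core.data.datasets.images.seg_data_tools.transforms import Compose, Normalize, ToTensor

img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)   # fake HWC image
pipeline = Compose([
    ToTensor(),                                  # HWC uint8 -> CHW float tensor
    Normalize(div_value=255.0,
              mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225]),
])
tensor = pipeline(img)                           # float tensor of shape (3, 256, 256)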
core/data/datasets/images/seg_dataset_dev.py ADDED
@@ -0,0 +1,293 @@
1
+ import os
2
+ import os.path as osp
3
+ import cv2
4
+ import torch
5
+
6
+ import numpy as np
7
+ import itertools
8
+ from typing import Any, Dict, List, Tuple, Union
9
+ from torch.utils import data
10
+ from torch.nn import functional as F
11
+ from PIL import Image
12
+ from pathlib import Path
13
+
14
+ import core.data.transforms.seg_aug_dev as T
15
+ from core.data.transforms.seg_transforms_dev import AugInput, apply_transform_gens
16
+
17
+
18
+ class Instances:
19
+ """
20
+ This class represents a list of instances in an image.
21
+ It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
22
+ All fields must have the same ``__len__`` which is the number of instances.
23
+
24
+ All other (non-field) attributes of this class are considered private:
25
+ they must start with '_' and are not modifiable by a user.
26
+
27
+ Some basic usage:
28
+
29
+ 1. Set/get/check a field:
30
+
31
+ .. code-block:: python
32
+
33
+ instances.gt_boxes = Boxes(...)
34
+ print(instances.pred_masks) # a tensor of shape (N, H, W)
35
+ print('gt_masks' in instances)
36
+
37
+ 2. ``len(instances)`` returns the number of instances
38
+ 3. Indexing: ``instances[indices]`` will apply the indexing on all the fields
39
+ and returns a new :class:`Instances`.
40
+ Typically, ``indices`` is an integer vector of indices,
41
+ or a binary mask of length ``num_instances``
42
+
43
+ .. code-block:: python
44
+
45
+ category_3_detections = instances[instances.pred_classes == 3]
46
+ confident_detections = instances[instances.scores > 0.9]
47
+ """
48
+
49
+ def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
50
+ """
51
+ Args:
52
+ image_size (height, width): the spatial size of the image.
53
+ kwargs: fields to add to this `Instances`.
54
+ """
55
+ self._image_size = image_size
56
+ self._fields: Dict[str, Any] = {}
57
+ for k, v in kwargs.items():
58
+ self.set(k, v)
59
+
60
+ @property
61
+ def image_size(self) -> Tuple[int, int]:
62
+ """
63
+ Returns:
64
+ tuple: height, width
65
+ """
66
+ return self._image_size
67
+
68
+ def __setattr__(self, name: str, val: Any) -> None:
69
+ if name.startswith("_"):
70
+ super().__setattr__(name, val)
71
+ else:
72
+ self.set(name, val)
73
+
74
+ def __getattr__(self, name: str) -> Any:
75
+ if name == "_fields" or name not in self._fields:
76
+ raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
77
+ return self._fields[name]
78
+
79
+ def set(self, name: str, value: Any) -> None:
80
+ """
81
+ Set the field named `name` to `value`.
82
+ The length of `value` must be the number of instances,
83
+ and must agree with other existing fields in this object.
84
+ """
85
+ data_len = len(value)
86
+ if len(self._fields):
87
+ assert (
88
+ len(self) == data_len
89
+ ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
90
+ self._fields[name] = value
91
+
92
+ def has(self, name: str) -> bool:
93
+ """
94
+ Returns:
95
+ bool: whether the field called `name` exists.
96
+ """
97
+ return name in self._fields
98
+
99
+ def remove(self, name: str) -> None:
100
+ """
101
+ Remove the field called `name`.
102
+ """
103
+ del self._fields[name]
104
+
105
+ def get(self, name: str) -> Any:
106
+ """
107
+ Returns the field called `name`.
108
+ """
109
+ return self._fields[name]
110
+
111
+ def get_fields(self) -> Dict[str, Any]:
112
+ """
113
+ Returns:
114
+ dict: a dict which maps names (str) to data of the fields
115
+
116
+ Modifying the returned dict will modify this instance.
117
+ """
118
+ return self._fields
119
+
120
+ # Tensor-like methods
121
+ def cuda(self, *args: Any, **kwargs: Any) -> "Instances":
122
+ """
123
+ Returns:
124
+ Instances: all fields are called with a `cuda`, if the field has this method.
125
+ """
126
+ ret = Instances(self._image_size)
127
+ for k, v in self._fields.items():
128
+ if hasattr(v, "cuda"):
129
+ v = v.cuda(*args, **kwargs)
130
+ ret.set(k, v)
131
+ return ret
132
+
133
+ def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances":
134
+ """
135
+ Args:
136
+ item: an index-like object and will be used to index all the fields.
137
+
138
+ Returns:
139
+ If `item` is a string, return the data in the corresponding field.
140
+ Otherwise, returns an `Instances` where all fields are indexed by `item`.
141
+ """
142
+ if type(item) == int:
143
+ if item >= len(self) or item < -len(self):
144
+ raise IndexError("Instances index out of range!")
145
+ else:
146
+ item = slice(item, None, len(self))
147
+
148
+ ret = Instances(self._image_size)
149
+ for k, v in self._fields.items():
150
+ ret.set(k, v[item])
151
+ return ret
152
+
153
+ def __len__(self) -> int:
154
+ for v in self._fields.values():
155
+ # use __len__ because len() has to be int and is not friendly to tracing
156
+ return v.__len__()
157
+ raise NotImplementedError("Empty Instances does not support __len__!")
158
+
159
+ def __iter__(self):
160
+ raise NotImplementedError("`Instances` object is not iterable!")
161
+
162
+ @staticmethod
163
+ def cat(instance_lists: List["Instances"]) -> "Instances":
164
+ """
165
+ Args:
166
+ instance_lists (list[Instances])
167
+
168
+ Returns:
169
+ Instances
170
+ """
171
+ assert all(isinstance(i, Instances) for i in instance_lists)
172
+ assert len(instance_lists) > 0
173
+ if len(instance_lists) == 1:
174
+ return instance_lists[0]
175
+
176
+ image_size = instance_lists[0].image_size
177
+ if not isinstance(image_size, torch.Tensor): # could be a tensor in tracing
178
+ for i in instance_lists[1:]:
179
+ assert i.image_size == image_size
180
+ ret = Instances(image_size)
181
+ for k in instance_lists[0]._fields.keys():
182
+ values = [i.get(k) for i in instance_lists]
183
+ v0 = values[0]
184
+ if isinstance(v0, torch.Tensor):
185
+ values = torch.cat(values, dim=0)
186
+ elif isinstance(v0, list):
187
+ values = list(itertools.chain(*values))
188
+ elif hasattr(type(v0), "cat"):
189
+ values = type(v0).cat(values)
190
+ else:
191
+ raise ValueError("Unsupported type {} for concatenation".format(type(v0)))
192
+ ret.set(k, values)
193
+ return ret
194
+
195
+ def __str__(self) -> str:
196
+ s = self.__class__.__name__ + "("
197
+ s += "num_instances={}, ".format(len(self))
198
+ s += "image_height={}, ".format(self._image_size[0])
199
+ s += "image_width={}, ".format(self._image_size[1])
200
+ s += "fields=[{}])".format(", ".join((f"{k}: {v}" for k, v in self._fields.items())))
201
+ return s
202
+
203
+ __repr__ = __str__
204
+
205
+
206
+ class BitMasks:
207
+ """
208
+ This class stores the segmentation masks for all objects in one image, in
209
+ the form of bitmaps.
210
+
211
+ Attributes:
212
+ tensor: bool Tensor of N,H,W, representing N instances in the image.
213
+ """
214
+
215
+ def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
216
+ """
217
+ Args:
218
+ tensor: bool Tensor of N,H,W, representing N instances in the image.
219
+ """
220
+ device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
221
+ tensor = torch.as_tensor(tensor, dtype=torch.bool, device=device)
222
+ assert tensor.dim() == 3, tensor.size()
223
+ self.image_size = tensor.shape[1:]
224
+ self.tensor = tensor
225
+
226
+ def to(self, *args: Any, **kwargs: Any) -> "BitMasks":
227
+ return BitMasks(self.tensor.to(*args, **kwargs))
228
+
229
+ @property
230
+ def device(self) -> torch.device:
231
+ return self.tensor.device
232
+
233
+ def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
234
+ """
235
+ Returns:
236
+ BitMasks: Create a new :class:`BitMasks` by indexing.
237
+
238
+ The following usage are allowed:
239
+
240
+ 1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask.
241
+ 2. `new_masks = masks[2:10]`: return a slice of masks.
242
+ 3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
243
+ with `length = len(masks)`. Nonzero elements in the vector will be selected.
244
+
245
+ Note that the returned object might share storage with this object,
246
+ subject to Pytorch's indexing semantics.
247
+ """
248
+ if isinstance(item, int):
249
+ return BitMasks(self.tensor[item].unsqueeze(0))
250
+ m = self.tensor[item]
251
+ assert m.dim() == 3, "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
252
+ item, m.shape
253
+ )
254
+ return BitMasks(m)
255
+
256
+ def __iter__(self) -> torch.Tensor:
257
+ yield from self.tensor
258
+
259
+ def __repr__(self) -> str:
260
+ s = self.__class__.__name__ + "("
261
+ s += "num_instances={})".format(len(self.tensor))
262
+ return s
263
+
264
+ def __len__(self) -> int:
265
+ return self.tensor.shape[0]
266
+
267
+ def nonempty(self) -> torch.Tensor:
268
+ """
269
+ Find masks that are non-empty.
270
+
271
+ Returns:
272
+ Tensor: a BoolTensor which represents
273
+ whether each mask is empty (False) or non-empty (True).
274
+ """
275
+ return self.tensor.flatten(1).any(dim=1)
276
+
277
+ @staticmethod
278
+ def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks":
279
+ """
280
+ Concatenates a list of BitMasks into a single BitMasks
281
+
282
+ Arguments:
283
+ bitmasks_list (list[BitMasks])
284
+
285
+ Returns:
286
+ BitMasks: the concatenated BitMasks
287
+ """
288
+ assert isinstance(bitmasks_list, (list, tuple))
289
+ assert len(bitmasks_list) > 0
290
+ assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list)
291
+
292
+ cat_bitmasks = type(bitmasks_list[0])(torch.cat([bm.tensor for bm in bitmasks_list], dim=0))
293
+ return cat_bitmasks
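For reference, a small hypothetical example of the Instances / BitMasks containers above (the field names gt_classes and gt_masks are illustrative; any per-instance data of matching length works):

import torch
from core.data.datasets.images.seg_dataset_dev import Instances, BitMasks

inst = Instances((480, 640))                          # image height, width
inst.gt_classes = torch.tensor([1, 1, 2, 3])          # one label per instance
inst.gt_masks = BitMasks(torch.zeros(4, 480, 640, dtype=torch.bool))
subset = inst[inst.gt_classes == 1]                   # boolean indexing slices every field
assert len(subset) == 2 and len(subset.gt_masks) == 2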
core/data/datasets/images/smpl_data_tools/__pycache__/_smpl.cpython-312.pyc ADDED
Binary file (22.5 kB).
 
core/data/datasets/images/smpl_data_tools/__pycache__/config_smpl.cpython-312.pyc ADDED
Binary file (2.35 kB).
 
core/data/datasets/images/smpl_data_tools/__pycache__/image_ops.cpython-312.pyc ADDED
Binary file (12.1 kB).
 
core/data/datasets/images/smpl_data_tools/__pycache__/tsv_file.cpython-312.pyc ADDED
Binary file (11.4 kB).
 
core/data/datasets/images/smpl_data_tools/_smpl.py ADDED
@@ -0,0 +1,333 @@
1
+ """
2
+ This file contains the definition of the SMPL model
3
+
4
+ It is adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/)
5
+ """
6
+ from __future__ import division
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+ import scipy.sparse
12
+ import pickle
13
+
14
+ from . import config_smpl as cfg
15
+
16
+ def rodrigues(theta):
17
+ """Convert axis-angle representation to rotation matrix.
18
+ Args:
19
+ theta: size = [B, 3]
20
+ Returns:
21
+ Rotation matrix corresponding to the axis-angle input -- size = [B, 3, 3]
22
+ """
23
+ l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
24
+ angle = torch.unsqueeze(l1norm, -1)
25
+ normalized = torch.div(theta, angle)
26
+ angle = angle * 0.5
27
+ v_cos = torch.cos(angle)
28
+ v_sin = torch.sin(angle)
29
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
30
+ return quat2mat(quat)
31
+
32
+ def quat2mat(quat):
33
+ """Convert quaternion coefficients to rotation matrix.
34
+ Args:
35
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
36
+ Returns:
37
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
38
+ """
39
+ norm_quat = quat
40
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
41
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
42
+
43
+ B = quat.size(0)
44
+
45
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
46
+ wx, wy, wz = w*x, w*y, w*z
47
+ xy, xz, yz = x*y, x*z, y*z
48
+
49
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
50
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
51
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
52
+ return rotMat
53
+
54
+ def orthographic_projection(X, camera):
55
+ """Perform orthographic projection of 3D points X using the camera parameters
56
+ Args:
57
+ X: size = [B, N, 3]
58
+ camera: size = [B, 3]
59
+ Returns:
60
+ Projected 2D points -- size = [B, N, 2]
61
+ """
62
+ camera = camera.view(-1, 1, 3)
63
+ X_trans = X[:, :, :2] + camera[:, :, 1:]
64
+ shape = X_trans.shape
65
+ X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
66
+ return X_2d
67
+
68
+ class SMPL(nn.Module):
69
+
70
+ def __init__(self, gender='neutral'):
71
+ super(SMPL, self).__init__()
72
+
73
+ if gender=='m':
74
+ model_file=cfg.SMPL_Male
75
+ elif gender=='f':
76
+ model_file=cfg.SMPL_Female
77
+ else:
78
+ model_file=cfg.SMPL_FILE
79
+
80
+ smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1')
81
+ J_regressor = smpl_model['J_regressor'].tocoo()
82
+ row = J_regressor.row
83
+ col = J_regressor.col
84
+ data = J_regressor.data
85
+ i = torch.LongTensor([row, col])
86
+ v = torch.FloatTensor(data)
87
+ J_regressor_shape = [24, 6890]
88
+ self.register_buffer('J_regressor', torch.sparse.FloatTensor(i, v, J_regressor_shape).to_dense()) # 24*6890
89
+ self.register_buffer('weights', torch.FloatTensor(smpl_model['weights'])) # 6890 * 24
90
+ self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs'])) # 6890*3*207
91
+ self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template'])) # 6890*3
92
+ self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs']))) # # 6890*3*10
93
+ self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64))) # 13776 * 3
94
+ self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64))) # 2*24
95
+ id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
96
+ self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))
97
+
98
+ self.pose_shape = [24, 3]
99
+ self.beta_shape = [10]
100
+ self.translation_shape = [3]
101
+
102
+ self.pose = torch.zeros(self.pose_shape)
103
+ self.beta = torch.zeros(self.beta_shape)
104
+ self.translation = torch.zeros(self.translation_shape)
105
+
106
+ self.verts = None
107
+ self.J = None
108
+ self.R = None
109
+
110
+ J_regressor_extra = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float() # 14*6890
111
+ self.register_buffer('J_regressor_extra', J_regressor_extra)
112
+ self.joints_idx = cfg.JOINTS_IDX
113
+
114
+ J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float() # 17*6890
115
+ self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct)
116
+
117
+
118
+ def forward(self, pose, beta):
119
+ device = pose.device
120
+ batch_size = pose.shape[0]
121
+ v_template = self.v_template[None, :]
122
+ shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1)
123
+ beta = beta[:, :, None]
124
+ # print(f'pose device {pose.device} beta device {beta.device} smpl parameter device {shapedirs.device}')
125
+ v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template
126
+ # batched sparse matmul not supported in pytorch
127
+ J = []
128
+ for i in range(batch_size):
129
+ J.append(torch.matmul(self.J_regressor, v_shaped[i]))
130
+ J = torch.stack(J, dim=0)
131
+ # input is rotation matrices: (bs, 24, 3, 3)
132
+ if pose.ndimension() == 4:
133
+ R = pose
134
+ # input is axis-angle: (bs, 72)
135
+ elif pose.ndimension() == 2:
136
+ pose_cube = pose.view(-1, 3) # (batch_size * 24, 3)
137
+ R = rodrigues(pose_cube).view(batch_size, 24, 3, 3)
138
+ R = R.view(batch_size, 24, 3, 3)
139
+ I_cube = torch.eye(3)[None, None, :].to(device)
140
+
141
+ lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1)
142
+ posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1)
143
+ v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3)
144
+ J_ = J.clone()
145
+ J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
146
+ G_ = torch.cat([R, J_[:, :, :, None]], dim=-1)
147
+ pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1)
148
+ G_ = torch.cat([G_, pad_row], dim=2)
149
+ G = [G_[:, 0].clone()]
150
+ for i in range(1, 24):
151
+ G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :]))
152
+ G = torch.stack(G, dim=1)
153
+
154
+ rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1)
155
+ zeros = torch.zeros(batch_size, 24, 4, 3).to(device)
156
+ rest = torch.cat([zeros, rest], dim=-1)
157
+ rest = torch.matmul(G, rest)
158
+ G = G - rest
159
+ T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1)
160
+ rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1)
161
+ v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
162
+ return v
163
+
164
+ def get_joints(self, vertices):
165
+ """
166
+ This method is used to get the joint locations from the SMPL mesh
167
+ Input:
168
+ vertices: size = (B, 6890, 3)
169
+ Output:
170
+ 3D joints: size = (B, 24, 3) after the JOINTS_IDX selection
171
+ """
172
+ joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor])
173
+ joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra])
174
+ joints = torch.cat((joints, joints_extra), dim=1)
175
+ joints = joints[:, cfg.JOINTS_IDX]
176
+ return joints
177
+
178
+ def get_h36m_joints(self, vertices):
179
+ """
180
+ This method is used to get the joint locations from the SMPL mesh
181
+ Input:
182
+ vertices: size = (B, 6890, 3)
183
+ Output:
184
+ 3D joints: size = (B, 17, 3)
185
+ """
186
+ joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
187
+ return joints
188
+
189
+ class SparseMM(torch.autograd.Function):
190
+ """Redefine sparse @ dense matrix multiplication to enable backpropagation.
191
+ The builtin matrix multiplication operation does not support backpropagation in some cases.
192
+ """
193
+ @staticmethod
194
+ def forward(ctx, sparse, dense):
195
+ ctx.req_grad = dense.requires_grad
196
+ ctx.save_for_backward(sparse)
197
+ return torch.matmul(sparse, dense)
198
+
199
+ @staticmethod
200
+ def backward(ctx, grad_output):
201
+ grad_input = None
202
+ sparse, = ctx.saved_tensors
203
+ if ctx.req_grad:
204
+ grad_input = torch.matmul(sparse.t(), grad_output)
205
+ return None, grad_input
206
+
207
+ def spmm(sparse, dense):
208
+ return SparseMM.apply(sparse, dense)
209
+
210
+
211
+ def scipy_to_pytorch(A, U, D):
212
+ """Convert scipy sparse matrices to pytorch sparse matrix."""
213
+ ptU = []
214
+ ptD = []
215
+
216
+ for i in range(len(U)):
217
+ u = scipy.sparse.coo_matrix(U[i])
218
+ i = torch.LongTensor(np.array([u.row, u.col]))
219
+ v = torch.FloatTensor(u.data)
220
+ # return index value and shape instead of a sparse tensor to avoid bug in multi-worker
221
+ ptU.append([i, v, u.shape])
222
+ for i in range(len(D)):
223
+ d = scipy.sparse.coo_matrix(D[i])
224
+ i = torch.LongTensor(np.array([d.row, d.col]))
225
+ v = torch.FloatTensor(d.data)
226
+ # return index value and shape instead of a sparse tensor to avoid bug in multi-worker
227
+ ptD.append([i, v, d.shape])
228
+
229
+ return ptU, ptD
230
+
231
+
232
+ def adjmat_sparse(adjmat, nsize=1):
233
+ """Create row-normalized sparse graph adjacency matrix."""
234
+ adjmat = scipy.sparse.csr_matrix(adjmat)
235
+ if nsize > 1:
236
+ orig_adjmat = adjmat.copy()
237
+ for _ in range(1, nsize):
238
+ adjmat = adjmat * orig_adjmat
239
+ adjmat.data = np.ones_like(adjmat.data)
240
+ for i in range(adjmat.shape[0]):
241
+ adjmat[i,i] = 1
242
+ num_neighbors = np.array(1 / adjmat.sum(axis=-1))
243
+ adjmat = adjmat.multiply(num_neighbors)
244
+ adjmat = scipy.sparse.coo_matrix(adjmat)
245
+ row = adjmat.row
246
+ col = adjmat.col
247
+ data = adjmat.data
248
+ i = torch.LongTensor(np.array([row, col]))
249
+ v = torch.from_numpy(data).float()
250
+ # adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape)
251
+ # return index value and shape instead of a sparse tensor to avoid bug in multi-worker
252
+
253
+
254
+ return [i, v, adjmat.shape]
255
+
256
+ def get_graph_params(filename, nsize=1):
257
+ """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
258
+ data = np.load(filename, encoding='latin1', allow_pickle=True)
259
+ A = data['A']
260
+ U = data['U']
261
+ D = data['D']
262
+ U, D = scipy_to_pytorch(A, U, D)
263
+ A = [adjmat_sparse(a, nsize=nsize) for a in A]
264
+ return A, U, D
265
+
266
+
267
+ class Mesh(object):
268
+ """Mesh object that is used for handling certain graph operations."""
269
+ def __init__(self, filename=cfg.SMPL_sampling_matrix,
270
+ num_downsampling=1, nsize=1, device=torch.device('cuda')):
271
+ super(Mesh, self).__init__()
272
+
273
+ self.device = device
274
+ self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
275
+ # self._A = [a.to(device) for a in self._A]
276
+ # self._U = [u.to(device) for u in self._U]
277
+ # self._D = [d.to(device) for d in self._D]
278
+
279
+ self.num_downsampling = num_downsampling
280
+
281
+ # load template vertices from SMPL and normalize them
282
+ smpl = SMPL()
283
+ ref_vertices = smpl.v_template
284
+ center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
285
+ ref_vertices -= center
286
+ ref_vertices /= ref_vertices.abs().max().item()
287
+
288
+ self._ref_vertices = ref_vertices.to(device)
289
+ self.faces = smpl.faces.int().to(device)
290
+ self.sparse = False
291
+
292
+ @property
293
+ def ref_vertices(self):
294
+ """Return the template vertices at the specified subsampling level."""
295
+ _D = [torch.sparse.FloatTensor(item[0], item[1], item[2]).to(self.device) for item in self._D]
296
+ ref_vertices = self._ref_vertices
297
+ for i in range(self.num_downsampling):
298
+ ref_vertices = torch.spmm(_D[i], ref_vertices)
299
+ return ref_vertices
300
+
301
+ def downsample(self, x, n1=0, n2=None):
302
+ _D = [torch.sparse.FloatTensor(item[0], item[1], item[2]).to(self.device) for item in self._D]
303
+ """Downsample mesh."""
304
+ if n2 is None:
305
+ n2 = self.num_downsampling
306
+ if x.ndimension() < 3:
307
+ for i in range(n1, n2):
308
+ x = spmm(_D[i], x)
309
+ elif x.ndimension() == 3:
310
+ out = []
311
+ for i in range(x.shape[0]):
312
+ y = x[i]
313
+ for j in range(n1, n2):
314
+ y = spmm(_D[j], y)
315
+ out.append(y)
316
+ x = torch.stack(out, dim=0)
317
+ return x
318
+
319
+ def upsample(self, x, n1=1, n2=0):
320
+ _U = [torch.sparse.FloatTensor(item[0], item[1], item[2]).to(self.device) for item in self._U]
321
+ """Upsample mesh."""
322
+ if x.ndimension() < 3:
323
+ for i in reversed(range(n2, n1)):
324
+ x = spmm(_U[i], x)
325
+ elif x.ndimension() == 3:
326
+ out = []
327
+ for i in range(x.shape[0]):
328
+ y = x[i]
329
+ for j in reversed(range(n2, n1)):
330
+ y = spmm(_U[j], y)
331
+ out.append(y)
332
+ x = torch.stack(out, dim=0)
333
+ return x
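A hypothetical sanity check of the SMPL module above; it assumes the model files referenced in config_smpl (e.g. basicModel_neutral_lbs_10_207_0_v1.0.0.pkl) are present on disk:

import torch
from core.data.datasets.images.smpl_data_tools._smpl import SMPL

smpl = SMPL()                          # neutral model
pose = torch.zeros(2, 72)              # axis-angle pose, batch of 2 (zero pose = T-pose)
beta = torch.zeros(2, 10)              # shape coefficients
verts = smpl(pose, beta)               # (2, 6890, 3) mesh vertices
joints = smpl.get_joints(verts)        # (2, 24, 3) joints after the JOINTS_IDX selection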
core/data/datasets/images/smpl_data_tools/config_smpl.py ADDED
@@ -0,0 +1,53 @@
1
+ """
2
+ This file contains definitions of useful data structures and the paths
3
+ for the datasets and data files necessary to run the code.
4
+
5
+ Adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
6
+
7
+ """
8
+
9
+ from os.path import join,split,abspath
10
+ # from os import getcwd
11
+ import sys
12
+
13
+ dirname, filename = split(abspath(sys.argv[0]))
14
+
15
+ folder_path = join(dirname,'core/data/datasets/images/smpl_data_tools/smpl_modeling/')
16
+ # print("current path {} ".format(folder_path))
17
+ JOINT_REGRESSOR_TRAIN_EXTRA = folder_path + 'data/J_regressor_extra.npy'
18
+ JOINT_REGRESSOR_H36M_correct = folder_path + 'data/J_regressor_h36m_correct.npy'
19
+ SMPL_FILE = folder_path + 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'
20
+ SMPL_Male = folder_path + 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl'
21
+ SMPL_Female = folder_path + 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl'
22
+ SMPL_sampling_matrix = folder_path + 'data/mesh_downsampling.npz'
23
+ MANO_FILE = folder_path + 'data/MANO_RIGHT.pkl'
24
+ MANO_sampling_matrix = folder_path + 'data/mano_downsampling.npz'
25
+
26
+ JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27]
27
+
28
+
29
+ """
30
+ We follow the body joint definition, loss functions, and evaluation metrics from
31
+ open source project GraphCMR (https://github.com/nkolot/GraphCMR/)
32
+
33
+ Each dataset uses different sets of joints.
34
+ We use a superset of 24 joints such that we include all joints from every dataset.
35
+ If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
36
+ The joints used here are:
37
+ """
38
+ J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
39
+ 'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear')
40
+ H36M_J17_NAME = ( 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head',
41
+ 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist')
42
+ J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
43
+ H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10]
44
+
45
+ """
46
+ We follow the hand joint definition and mesh topology from
47
+ open source project Manopth (https://github.com/hassony2/manopth)
48
+
49
+ The hand joints used here are:
50
+ """
51
+ J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
52
+ 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
53
+ ROOT_INDEX = 0
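As a small illustration (not part of the repo) of how the index lists above are meant to be used, mapping a joint array in the J24 convention down to the 14-joint subset:

import numpy as np
from core.data.datasets.images.smpl_data_tools import config_smpl as cfg

joints_j24 = np.zeros((24, 3))               # hypothetical joints in the J24 ordering
joints_j14 = joints_j24[cfg.J24_TO_J14]      # shape (14, 3), the reduced joint set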
core/data/datasets/images/smpl_data_tools/image_ops.py ADDED
@@ -0,0 +1,230 @@
1
+ # ----------------------------------------------------------------------------------------------
2
+ # METRO (https://github.com/microsoft/MeshTransformer)
3
+ # Copyright (c) Microsoft Corporation. All Rights Reserved [see https://github.com/microsoft/MeshTransformer/blob/main/LICENSE for details]
4
+ # Licensed under the MIT license.
5
+ # ----------------------------------------------------------------------------------------------
6
+ """
7
+ Image processing tools
8
+ Modified from open source projects:
9
+ (https://github.com/nkolot/GraphCMR/)
10
+ (https://github.com/open-mmlab/mmdetection)
11
+ """
12
+
13
+ import numpy as np
14
+ import base64
15
+ import cv2
16
+ import torch
17
+ import scipy.misc
18
+
19
+ def img_from_base64(imagestring):
20
+ try:
21
+ jpgbytestring = base64.b64decode(imagestring)
22
+ nparr = np.frombuffer(jpgbytestring, np.uint8)
23
+ r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
24
+ return r
25
+ except ValueError:
26
+ return None
27
+
28
+ def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False):
29
+ if center is not None and auto_bound:
30
+ raise ValueError('`auto_bound` conflicts with `center`')
31
+ try:
32
+ h, w = img.shape[:2]
33
+ except:
34
+ h, w = img.size[:2]
35
+ if center is None:
36
+ center = ((w - 1) * 0.5, (h - 1) * 0.5)
37
+ assert isinstance(center, tuple)
38
+
39
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
40
+ if auto_bound:
41
+ cos = np.abs(matrix[0, 0])
42
+ sin = np.abs(matrix[0, 1])
43
+ new_w = h * sin + w * cos
44
+ new_h = h * cos + w * sin
45
+ matrix[0, 2] += (new_w - w) * 0.5
46
+ matrix[1, 2] += (new_h - h) * 0.5
47
+ w = int(np.round(new_w))
48
+ h = int(np.round(new_h))
49
+ rotated = cv2.warpAffine(img, matrix, (w, h), borderValue=border_value)
50
+ return rotated
51
+
52
+ def myimresize(img, size, return_scale=False, interpolation='bilinear'):
53
+
54
+ try:
55
+ h, w = img.shape[:2]
56
+ except:
57
+ h, w = img.size[:2]
58
+ resized_img = cv2.resize(
59
+ img, (size[0],size[1]), interpolation=cv2.INTER_LINEAR)
60
+ if not return_scale:
61
+ return resized_img
62
+ else:
63
+ w_scale = size[0] / w
64
+ h_scale = size[1] / h
65
+ return resized_img, w_scale, h_scale
66
+
67
+
68
+ def get_transform(center, scale, res, rot=0):
69
+ """Generate transformation matrix."""
70
+ h = 200 * scale
71
+ t = np.zeros((3, 3))
72
+ t[0, 0] = float(res[1]) / h
73
+ t[1, 1] = float(res[0]) / h
74
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
75
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
76
+ t[2, 2] = 1
77
+ if not rot == 0:
78
+ rot = -rot # To match direction of rotation from cropping
79
+ rot_mat = np.zeros((3,3))
80
+ rot_rad = rot * np.pi / 180
81
+ sn,cs = np.sin(rot_rad), np.cos(rot_rad)
82
+ rot_mat[0,:2] = [cs, -sn]
83
+ rot_mat[1,:2] = [sn, cs]
84
+ rot_mat[2,2] = 1
85
+ # Need to rotate around center
86
+ t_mat = np.eye(3)
87
+ t_mat[0,2] = -res[1]/2
88
+ t_mat[1,2] = -res[0]/2
89
+ t_inv = t_mat.copy()
90
+ t_inv[:2,2] *= -1
91
+ t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
92
+ return t
93
+
94
+ def transform(pt, center, scale, res, invert=0, rot=0):
95
+ """Transform pixel location to different reference."""
96
+ t = get_transform(center, scale, res, rot=rot)
97
+ if invert:
98
+ # t = np.linalg.inv(t)
99
+ t_torch = torch.from_numpy(t)
100
+ t_torch = torch.inverse(t_torch)
101
+ t = t_torch.numpy()
102
+ new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T
103
+ new_pt = np.dot(t, new_pt)
104
+ return new_pt[:2].astype(int)+1
105
+
106
+ def crop(img, center, scale, res, rot=0):
107
+ """Crop image according to the supplied bounding box."""
108
+ # Upper left point
109
+ ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
110
+ # Bottom right point
111
+ br = np.array(transform([res[0]+1,
112
+ res[1]+1], center, scale, res, invert=1))-1
113
+ # Padding so that when rotated proper amount of context is included
114
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
115
+ if not rot == 0:
116
+ ul -= pad
117
+ br += pad
118
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
119
+ try:
120
+ image_shape = img.shape
121
+ except:
122
+ image_shape = img.size
123
+ if len(image_shape) > 2:
124
+ new_shape += [image_shape[2]]
125
+ new_img = np.zeros(new_shape)
126
+
127
+ # Range to fill new array
128
+ new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
129
+ new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
130
+ # Range to sample from original image
131
+ old_x = max(0, ul[0]), min(len(img[0]), br[0])
132
+ old_y = max(0, ul[1]), min(len(img), br[1])
133
+
134
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
135
+ old_x[0]:old_x[1]]
136
+ if not rot == 0:
137
+ # Remove padding
138
+ # new_img = scipy.misc.imrotate(new_img, rot)
139
+ new_img = myimrotate(new_img, rot)
140
+ new_img = new_img[pad:-pad, pad:-pad]
141
+
142
+ # new_img = scipy.misc.imresize(new_img, res)
143
+ new_img = myimresize(new_img, [res[0], res[1]])
144
+ return new_img
145
+
146
+ def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
147
+ """'Undo' the image cropping/resizing.
148
+ This function is used when evaluating mask/part segmentation.
149
+ """
150
+ try:
151
+ res = img.shape[:2]
152
+ except:
153
+ res = img.size[:2]
154
+ # Upper left point
155
+ ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
156
+ # Bottom right point
157
+ br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1
158
+ # size of cropped image
159
+ crop_shape = [br[1] - ul[1], br[0] - ul[0]]
160
+
161
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
162
+
163
+ try:
164
+ image_shape = img.shape
165
+ except:
166
+ image_shape = img.size
167
+
168
+ if len(image_shape) > 2:
169
+ new_shape += [image_shape[2]]
170
+ new_img = np.zeros(orig_shape, dtype=np.uint8)
171
+ # Range to fill new array
172
+ new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
173
+ new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
174
+ # Range to sample from original image
175
+ old_x = max(0, ul[0]), min(orig_shape[1], br[0])
176
+ old_y = max(0, ul[1]), min(orig_shape[0], br[1])
177
+ # img = scipy.misc.imresize(img, crop_shape, interp='nearest')
178
+ img = myimresize(img, [crop_shape[0],crop_shape[1]])
179
+ new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
180
+ return new_img
181
+
182
+ def rot_aa(aa, rot):
183
+ """Rotate axis angle parameters."""
184
+ # pose parameters
185
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
186
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
187
+ [0, 0, 1]])
188
+ # find the rotation of the body in camera frame
189
+ per_rdg, _ = cv2.Rodrigues(aa)
190
+ # apply the global rotation to the global orientation
191
+ resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
192
+ aa = (resrot.T)[0]
193
+ return aa
194
+
195
+ def flip_img(img):
196
+ """Flip rgb images or masks.
197
+ channels come last, e.g. (256,256,3).
198
+ """
199
+ img = np.fliplr(img)
200
+ return img
201
+
202
+ def flip_kp(kp):
203
+ """Flip keypoints."""
204
+ flipped_parts = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
205
+ kp = kp[flipped_parts]
206
+ kp[:,0] = - kp[:,0]
207
+ return kp
208
+
209
+ def flip_pose(pose):
210
+ """Flip pose.
211
+ The flipping is based on SMPL parameters.
212
+ """
213
+ flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
214
+ 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
215
+ 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
216
+ 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
217
+ 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
218
+ pose = pose[flippedParts]
219
+ # we also negate the second and the third dimension of the axis-angle
220
+ pose[1::3] = -pose[1::3]
221
+ pose[2::3] = -pose[2::3]
222
+ return pose
223
+
224
+ def flip_aa(aa):
225
+ """Flip axis-angle representation.
226
+ We negate the second and the third dimension of the axis-angle.
227
+ """
228
+ aa[1] = -aa[1]
229
+ aa[2] = -aa[2]
230
+ return aa
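A hypothetical use of the crop helpers above, cutting a fixed-resolution patch around a made-up person center and mapping one keypoint into crop coordinates (all values below are illustrative):

import numpy as np
from core.data.datasets.images.smpl_data_tools.image_ops import crop, transform

img = np.zeros((480, 640, 3), dtype=np.uint8)             # fake input image
center, scale, res = np.array([320, 240]), 1.2, (224, 224)
patch = crop(img, center, scale, res)                     # (224, 224, 3) patch around center
kp_in_crop = transform([320, 240], center, scale, res)    # that pixel in crop coordinates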
core/data/datasets/images/smpl_data_tools/smpl_modeling/data/J_regressor_extra.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40dfaa71fcc7eed6966a6ed046311b7e8ea0eb9a5172b298e3df6fc4b6ec0eb0
3
+ size 771808