File size: 3,315 Bytes
5d756f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from PIL import Image
import numpy as np
import multiprocessing
import io
from tops import logger
from torch.utils.data._utils.collate import default_collate

try:
    import pyspng

    PYSPNG_IMPORTED = True
except ImportError:
    PYSPNG_IMPORTED = False
    print("Could not load pyspng. Defaulting to pillow image backend.")
    from PIL import Image


def get_fdf_keypoints():
    """Return the FDF keypoint names.

    These are the leading seven entries of the COCO keypoint list
    (nose, both eyes, both ears, both shoulders), in COCO order.
    """
    num_fdf_keypoints = 7
    coco_names = get_coco_keypoints()
    return coco_names[:num_fdf_keypoints]


def get_fdf_flipmap():
    """Return the horizontal-flip permutation for FDF keypoints.

    Entry ``i`` of the returned list is the index of the keypoint that
    keypoint ``i`` turns into when the image is mirrored left/right
    (e.g. left_eye <-> right_eye); the nose maps to itself.
    """
    names = get_fdf_keypoints()
    mirrored_pairs = (
        ("left_eye", "right_eye"),
        ("left_ear", "right_ear"),
        ("left_shoulder", "right_shoulder"),
    )
    mirror_of = {}
    for left, right in mirrored_pairs:
        mirror_of[left] = right
        mirror_of[right] = left
    mirror_of["nose"] = "nose"
    return [names.index(mirror_of[name]) for name in names]


def get_coco_keypoints():
    """Return the 17 COCO keypoint names in canonical order.

    Index 0 is the nose; every following left/right pair occupies the
    indices (odd, even), e.g. left_eye=1 / right_eye=2, ...,
    left_ankle=15 / right_ankle=16. A fresh list is built on every call,
    so callers may mutate the result safely.
    """
    bilateral = [
        ("left_eye", "right_eye"),
        ("left_ear", "right_ear"),
        ("left_shoulder", "right_shoulder"),
        ("left_elbow", "right_elbow"),
        ("left_wrist", "right_wrist"),
        ("left_hip", "right_hip"),
        ("left_knee", "right_knee"),
        ("left_ankle", "right_ankle"),
    ]
    names = ["nose"]
    for left_name, right_name in bilateral:
        names.extend((left_name, right_name))
    return names


def get_coco_flipmap():
    """Return the horizontal-flip permutation for COCO keypoints.

    Entry ``i`` of the returned list is the index of the keypoint that
    keypoint ``i`` becomes after a left/right mirror of the image; each
    ``left_*`` keypoint swaps with its ``right_*`` partner and the nose
    maps to itself.
    """
    names = get_coco_keypoints()
    mirrored_pairs = (
        ("left_eye", "right_eye"),
        ("left_ear", "right_ear"),
        ("left_shoulder", "right_shoulder"),
        ("left_elbow", "right_elbow"),
        ("left_wrist", "right_wrist"),
        ("left_hip", "right_hip"),
        ("left_knee", "right_knee"),
        ("left_ankle", "right_ankle"),
    )
    mirror_of = {}
    for left, right in mirrored_pairs:
        mirror_of[left] = right
        mirror_of[right] = left
    mirror_of["nose"] = "nose"
    return [names.index(mirror_of[name]) for name in names]


def mask_decoder(x):
    """Decode encoded image bytes into a boolean mask tensor.

    The decoded array has all size-1 dimensions squeezed out and a new
    leading axis of length 1 prepended, so a single-channel (H, W) or
    (H, W, 1) image becomes a (1, H, W) tensor.

    Args:
        x: Raw encoded image bytes (any format Pillow can open).

    Returns:
        A ``torch.bool`` tensor that is True wherever the stored pixel
        value is non-zero.
    """
    # Use a context manager so the Pillow image is closed once its pixels
    # have been copied into the numpy array.
    with Image.open(io.BytesIO(x)) as im:
        mask = torch.from_numpy(np.array(im)).squeeze()[None]
    # Binarize: stored masks may use 255 for foreground, which otherwise
    # makes mask.float().max() == 255 instead of 1.
    return mask > 0


def png_decoder(x):
    """Decode PNG bytes into a channel-first image tensor of shape (3, H, W).

    Uses pyspng when it could be imported and Pillow otherwise. The Pillow
    path always converts to RGB. The pyspng fast path is taken only when it
    already yields an (H, W, 3) array. Any other result falls back to
    Pillow, so both backends return the same 3-channel layout. Examples are
    a 2-D grayscale array, which ``np.rollaxis(..., 2)`` cannot handle, or
    an array with an alpha channel.

    Args:
        x: Raw PNG-encoded bytes.

    Returns:
        A torch tensor in (C, H, W) layout with C == 3.
    """
    if PYSPNG_IMPORTED:
        arr = pyspng.load(x)
        if arr.ndim == 3 and arr.shape[2] == 3:
            return torch.from_numpy(np.rollaxis(arr, 2))
        # NOTE(review): assumes pyspng returns a 2-D array for grayscale
        # PNGs and 4 channels for RGBA; such inputs are re-decoded by
        # Pillow below so the output matches the Pillow-only code path.
    with Image.open(io.BytesIO(x)) as im:
        rgb = np.array(im.convert("RGB"))
    return torch.from_numpy(np.rollaxis(rgb, 2))


def jpg_decoder(x):
    """Decode JPEG bytes into an RGB tensor in channel-first (C, H, W) layout.

    The image is converted to RGB with Pillow before the channel axis is
    moved to the front.
    """
    buffer = io.BytesIO(x)
    with Image.open(buffer) as picture:
        hwc = np.array(picture.convert("RGB"))
    chw = np.rollaxis(hwc, 2)
    return torch.from_numpy(chw)


def get_num_workers(num_workers: int):
    """Clamp a requested worker count to the number of available CPUs.

    Returns ``num_workers`` unchanged when it does not exceed
    ``multiprocessing.cpu_count()``; otherwise logs a warning and returns
    the CPU count.
    """
    n_cpus = multiprocessing.cpu_count()
    if num_workers <= n_cpus:
        return num_workers
    logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}")
    return n_cpus


def collate_fn(batch):
    """Collate a list of sample dicts into a single batch dict.

    Every key of the first sample except ``"embed_map"`` and
    ``"vertx2cat"`` is collated with ``default_collate`` across all
    samples. Those two keys, when present in the first sample, are copied
    from that sample unchanged rather than stacked. This presumably holds
    data shared by the whole batch; confirm with the dataset.
    """
    passthrough_keys = ("embed_map", "vertx2cat")
    first = batch[0]
    collated = {}
    for key in first:
        if key in passthrough_keys:
            continue
        collated[key] = default_collate([sample[key] for sample in batch])
    for key in passthrough_keys:
        if key in first:
            collated[key] = first[key]
    return collated