File size: 13,145 Bytes
d7a991a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np

from mmpose.core import OneEuroFilter, oks_iou


def _compute_iou(bboxA, bboxB):
    """Compute the Intersection over Union (IoU) of two boxes.

    Only the first four entries of each box are read; a trailing score
    (or anything else after index 3) is ignored.

    Args:
        bboxA (list): The first bbox info (left, top, right, bottom, score).
        bboxB (list): The second bbox info (left, top, right, bottom, score).

    Returns:
        float: The IoU value. When the union area is exactly zero a
            warning is issued and ``1e-5`` is used as the denominator.
    """
    a_left, a_top, a_right, a_bottom = bboxA[:4]
    b_left, b_top, b_right, b_bottom = bboxB[:4]

    # Extent of the overlap rectangle, clamped to zero for disjoint boxes.
    overlap_w = max(0, min(a_right, b_right) - max(a_left, b_left))
    overlap_h = max(0, min(a_bottom, b_bottom) - max(a_top, b_top))
    inter_area = overlap_w * overlap_h

    area_a = (a_right - a_left) * (a_bottom - a_top)
    area_b = (b_right - b_left) * (b_bottom - b_top)
    union_area = float(area_a + area_b - inter_area)

    if union_area == 0:
        # Degenerate boxes: avoid a division by zero but tell the caller.
        union_area = 1e-5
        warnings.warn('union_area=0 is unexpected')

    return inter_area / union_area


def _track_by_iou(res, results_last, thr):
    """Greedily pick a track id for one instance by bbox IoU.

    The instance is compared against every entry of ``results_last``; the
    entry with the highest IoU wins (the earliest one on ties). A winning
    entry whose IoU exceeds ``thr`` is removed from ``results_last``
    in place.

    Args:
        res (dict): The bbox & pose results of the person instance.
        results_last (list[dict]): The bbox & pose & track_id info of the
            last frame (bbox_result, pose_result, track_id).
        thr (float): The threshold for iou tracking.

    Returns:
        int: The track id for the new person instance, or -1 when no
            entry of the last frame has an IoU above ``thr``.
        list[dict]: The bbox & pose & track_id info of the persons
            that have not been matched on the last frame (the same list
            object that was passed in).
        dict: The matched person instance on the last frame, or an empty
            dict when nothing matched.
    """
    bbox = list(res['bbox'])

    best_iou, best_idx = -1, -1
    for idx, candidate in enumerate(results_last):
        iou = _compute_iou(bbox, list(candidate['bbox']))
        # Strict comparison keeps the first of several equally good boxes.
        if iou > best_iou:
            best_iou, best_idx = iou, idx

    if not best_iou > thr:
        return -1, results_last, {}

    matched = results_last[best_idx]
    track_id = matched['track_id']
    # A matched instance must not be claimed by another detection.
    del results_last[best_idx]
    return track_id, results_last, matched


def _track_by_oks(res, results_last, thr):
    """Greedily pick a track id for one instance by OKS.

    The instance's flattened keypoints are scored against those of every
    entry of ``results_last`` with ``oks_iou``. The highest-scoring entry
    is matched if its score exceeds ``thr`` and is then removed from
    ``results_last`` in place.

    Args:
        res (dict): The pose results of the person instance. Must hold
            ``'keypoints'`` and ``'area'``.
        results_last (list[dict]): The pose & track_id info of the
            last frame (pose_result, track_id).
        thr (float): The threshold for oks tracking.

    Returns:
        int: The track id for the new person instance, or -1 when
            ``results_last`` is empty or no score exceeds ``thr``.
        list[dict]: The pose & track_id info of the persons
            that have not been matched on the last frame (the same list
            object that was passed in).
        dict: The matched person instance on the last frame, or an empty
            dict when nothing matched.
    """
    pose = res['keypoints'].reshape((-1))
    area = res['area']

    # Nothing to compare against on the first frame / after a full match.
    if len(results_last) == 0:
        return -1, results_last, {}

    candidate_poses = np.array(
        [prev['keypoints'].reshape((-1)) for prev in results_last])
    candidate_areas = np.array([prev['area'] for prev in results_last])

    oks_score = oks_iou(pose, candidate_poses, area, candidate_areas)
    best_idx = np.argmax(oks_score)

    if not oks_score[best_idx] > thr:
        return -1, results_last, {}

    matched = results_last[best_idx]
    track_id = matched['track_id']
    # A matched instance must not be claimed by another detection.
    del results_last[best_idx]
    return track_id, results_last, matched


def _get_area(results):
    """Attach an ``'area'`` (and, if missing, a ``'bbox'``) to each instance.

    Instances that already carry a ``'bbox'`` get the area of that box.
    For the others a box is derived from the keypoints: the minimum is
    taken over strictly positive coordinates only, the maximum over all
    coordinates. The dicts are updated in place.

    Note:
        When an instance has no strictly positive x (or y) coordinate the
        corresponding minimum falls back to ``1e10``.

    Args:
        results (list[dict]): The pose results of the current frame
            (pose_result).
    Returns:
        list[dict]: The bbox & pose info of the current frame
            (bbox_result, pose_result, area).
    """
    for person in results:
        if 'bbox' in person:
            box = person['bbox']
            person['area'] = (box[2] - box[0]) * (box[3] - box[1])
            continue

        xs = person['keypoints'][:, 0]
        ys = person['keypoints'][:, 1]

        # Non-positive coordinates are skipped for the lower bounds only.
        xmin = np.min(xs[xs > 0], initial=1e10)
        xmax = np.max(xs)
        ymin = np.min(ys[ys > 0], initial=1e10)
        ymax = np.max(ys)

        person['area'] = (xmax - xmin) * (ymax - ymin)
        person['bbox'] = np.array([xmin, ymin, xmax, ymax])

    return results


def _temporal_refine(result, match_result, fps=None):
    """Refine keypoints using the tracked person instance on the last frame.

    If ``match_result`` carries a ``'one_euro'`` filter, the x/y columns of
    the current keypoints are replaced in place by the filter output and
    the filter is handed over to ``result``. Otherwise a new
    ``OneEuroFilter`` is created from the current x/y columns and stored
    on ``result``; the keypoints are left as they are.

    Args:
        result (dict): The pose results of the current frame
                (pose_result).
        match_result (dict): The pose results of the last frame
                (match_result)
        fps (optional): Passed as ``fps`` to ``OneEuroFilter`` when a new
                filter is created.
    Returns:
        (array): The person keypoints after refine.
    """
    keypoints = result['keypoints']

    if 'one_euro' in match_result:
        smoother = match_result['one_euro']
        # Only the first two columns are filtered; the rest is untouched.
        keypoints[:, :2] = smoother(keypoints[:, :2])
        result['one_euro'] = smoother
    else:
        result['one_euro'] = OneEuroFilter(keypoints[:, :2], fps=fps)

    return result['keypoints']


def get_track_id(results,
                 results_last,
                 next_id,
                 min_keypoints=3,
                 use_oks=False,
                 tracking_thr=0.3,
                 use_one_euro=False,
                 fps=None):
    """Get track id for each person instance on the current frame.

    Every instance is matched greedily, in list order, against the
    instances of the last frame. Matched instances inherit the old track
    id; unmatched ones either start a new track or are invalidated when
    they have too few keypoints.

    Note:
        ``results_last`` is consumed: each matched entry is removed from
        the list that was passed in.

    Args:
        results (list[dict]): The bbox & pose results of the current frame
            (bbox_result, pose_result).
        results_last (list[dict]): The bbox & pose & track_id info of the
            last frame (bbox_result, pose_result, track_id).
        next_id (int): The track id for the new person instance.
        min_keypoints (int): An unmatched instance starts a new track only
            if the number of keypoints with a non-zero y value is strictly
            greater than this. default: 3.
        use_oks (bool): Flag to using oks tracking. default: False.
        tracking_thr (float): The threshold for tracking.
        use_one_euro (bool): Option to use one-euro-filter. default: False.
        fps (optional): Forwarded to ``_temporal_refine`` when
            one-euro-filter is used.

    Returns:
        tuple:
        - results (list[dict]): The bbox & pose & track_id info of the \
            current frame (bbox_result, pose_result, track_id).
        - next_id (int): The track id for the new person instance.
    """
    results = _get_area(results)

    matcher = _track_by_oks if use_oks else _track_by_iou

    for person in results:
        track_id, results_last, matched = matcher(person, results_last,
                                                  tracking_thr)

        if track_id != -1:
            person['track_id'] = track_id
        elif np.count_nonzero(person['keypoints'][:, 1]) > min_keypoints:
            # Unmatched but well supported: open a new track.
            person['track_id'] = next_id
            next_id += 1
        else:
            # If the number of keypoints detected is small,
            # invalidate that person instance instead of tracking it.
            person['keypoints'][:, 1] = -10
            person['bbox'] *= 0
            person['track_id'] = -1

        if use_one_euro:
            person['keypoints'] = _temporal_refine(person, matched, fps=fps)

    return results, next_id


def vis_pose_tracking_result(model,
                             img,
                             result,
                             radius=4,
                             thickness=1,
                             kpt_score_thr=0.3,
                             dataset='TopDownCocoDataset',
                             dataset_info=None,
                             show=False,
                             out_file=None):
    """Visualize the pose tracking results on the image.

    Every instance is drawn with one color for its bbox, keypoints and
    links; the color is taken from a fixed palette at index
    ``track_id % len(palette)``. Instances are drawn one after another by
    calling ``model.show_result`` and feeding each returned image into the
    next call.

    Args:
        model (nn.Module): The loaded detector. If it has a ``module``
            attribute, that attribute is used instead.
        img (str | np.ndarray): Image filename or loaded image.
        result (list[dict]): The results to draw over `img`
            (bbox_result, pose_result). Each item must hold a
            ``'track_id'``.
        radius (int): Radius of circles. Overridden with 1 for the
            deprecated ``'TopDownCocoWholeBodyDataset'`` name.
        thickness (int): Thickness of lines.
        kpt_score_thr (float): The threshold to visualize the keypoints.
        dataset (str | None): Deprecated dataset name used to look up a
            built-in skeleton. Only consulted when ``dataset_info`` is
            None. Default: 'TopDownCocoDataset'.
        dataset_info: Object providing ``keypoint_num`` and ``skeleton``.
            Takes precedence over ``dataset``. Default: None.
        show (bool): Whether to show the image. Default: False.
        out_file (str|None): The filename of the output visualization image.

    Returns:
        The value returned by the last ``model.show_result`` call, or
        ``img`` unchanged when ``result`` is empty.

    Raises:
        NotImplementedError: If ``dataset_info`` is None and ``dataset``
            is not one of the built-in names.
        ValueError: If both ``dataset`` and ``dataset_info`` are None and
            there is at least one instance to draw.
    """
    if hasattr(model, 'module'):
        model = model.module

    palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
                        [230, 230, 0], [255, 153, 255], [153, 204, 255],
                        [255, 102, 255], [255, 51, 255], [102, 178, 255],
                        [51, 153, 255], [255, 153, 153], [255, 102, 102],
                        [255, 51, 51], [153, 255, 153], [102, 255, 102],
                        [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0],
                        [255, 255, 255]])

    # Without either source there is no skeleton to draw with; this is
    # reported lazily so that an empty `result` still returns `img`.
    has_skeleton_source = dataset_info is not None or dataset is not None
    kpt_num = None
    skeleton = None

    if dataset_info is None and dataset is not None:
        warnings.warn(
            'dataset is deprecated. '
            'Please set `dataset_info` in the config. '
            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
            DeprecationWarning)
        # TODO: These will be removed in the later versions.
        if dataset in ('TopDownCocoDataset', 'BottomUpCocoDataset',
                       'TopDownOCHumanDataset'):
            kpt_num = 17
            skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
                        [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
                        [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4],
                        [3, 5], [4, 6]]

        elif dataset == 'TopDownCocoWholeBodyDataset':
            kpt_num = 133
            skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
                        [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
                        [8, 10], [1, 2], [0, 1], [0, 2],
                        [1, 3], [2, 4], [3, 5], [4, 6], [15, 17], [15, 18],
                        [15, 19], [16, 20], [16, 21], [16, 22], [91, 92],
                        [92, 93], [93, 94], [94, 95], [91, 96], [96, 97],
                        [97, 98], [98, 99], [91, 100], [100, 101], [101, 102],
                        [102, 103], [91, 104], [104, 105], [105, 106],
                        [106, 107], [91, 108], [108, 109], [109, 110],
                        [110, 111], [112, 113], [113, 114], [114, 115],
                        [115, 116], [112, 117], [117, 118], [118, 119],
                        [119, 120], [112, 121], [121, 122], [122, 123],
                        [123, 124], [112, 125], [125, 126], [126, 127],
                        [127, 128], [112, 129], [129, 130], [130, 131],
                        [131, 132]]
            # The caller-supplied radius is replaced for this dataset.
            radius = 1

        elif dataset == 'TopDownAicDataset':
            kpt_num = 14
            skeleton = [[2, 1], [1, 0], [0, 13], [13, 3], [3, 4], [4, 5],
                        [8, 7], [7, 6], [6, 9], [9, 10], [10, 11], [12, 13],
                        [0, 6], [3, 9]]

        elif dataset == 'TopDownMpiiDataset':
            kpt_num = 16
            skeleton = [[0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7],
                        [7, 8], [8, 9], [8, 12], [12, 11], [11, 10], [8, 13],
                        [13, 14], [14, 15]]

        elif dataset in ('OneHand10KDataset', 'FreiHandDataset',
                         'PanopticDataset'):
            kpt_num = 21
            skeleton = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7],
                        [7, 8], [0, 9], [9, 10], [10, 11], [11, 12], [0, 13],
                        [13, 14], [14, 15], [15, 16], [0, 17], [17, 18],
                        [18, 19], [19, 20]]

        elif dataset == 'InterHand2DDataset':
            kpt_num = 21
            skeleton = [[0, 1], [1, 2], [2, 3], [4, 5], [5, 6], [6, 7], [8, 9],
                        [9, 10], [10, 11], [12, 13], [13, 14], [14, 15],
                        [16, 17], [17, 18], [18, 19], [3, 20], [7, 20],
                        [11, 20], [15, 20], [19, 20]]

        else:
            raise NotImplementedError()

    elif dataset_info is not None:
        kpt_num = dataset_info.keypoint_num
        skeleton = dataset_info.skeleton

    for res in result:
        if not has_skeleton_source:
            raise ValueError(
                'Either `dataset_info` or `dataset` must be provided to '
                'visualize pose tracking results.')

        # One palette entry per track, repeated for every keypoint / link.
        color_idx = res['track_id'] % len(palette)
        bbox_color = palette[color_idx]
        pose_kpt_color = palette[[color_idx] * kpt_num]
        pose_link_color = palette[[color_idx] * len(skeleton)]
        img = model.show_result(
            img, [res],
            skeleton,
            radius=radius,
            thickness=thickness,
            pose_kpt_color=pose_kpt_color,
            pose_link_color=pose_link_color,
            bbox_color=tuple(bbox_color.tolist()),
            kpt_score_thr=kpt_score_thr,
            show=show,
            out_file=out_file)

    return img