Spaces:
Build error
Build error
File size: 13,145 Bytes
d7a991a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 |
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
from mmpose.core import OneEuroFilter, oks_iou
def _compute_iou(bboxA, bboxB):
    """Return the Intersection over Union (IoU) of two bounding boxes.

    Only the first four entries of each box are read, so a trailing score
    entry is ignored.

    Args:
        bboxA (list): First box as (left, top, right, bottom[, score]).
        bboxB (list): Second box as (left, top, right, bottom[, score]).

    Returns:
        float: Intersection area divided by union area.
    """
    # Corners of the overlap rectangle; they are "inverted" when the boxes
    # do not overlap, which the max(0, ...) clamps below turn into area 0.
    left = max(bboxA[0], bboxB[0])
    top = max(bboxA[1], bboxB[1])
    right = min(bboxA[2], bboxB[2])
    bottom = min(bboxA[3], bboxB[3])
    overlap = max(0, right - left) * max(0, bottom - top)

    area_a = (bboxA[2] - bboxA[0]) * (bboxA[3] - bboxA[1])
    area_b = (bboxB[2] - bboxB[0]) * (bboxB[3] - bboxB[1])

    total = float(area_a + area_b - overlap)
    if total == 0:
        # Degenerate boxes: substitute a tiny union to avoid dividing by 0.
        total = 1e-5
        warnings.warn('union_area=0 is unexpected')

    return overlap / total
def _track_by_iou(res, results_last, thr):
    """Get track id using IoU tracking greedily.

    The instance of the last frame with the highest IoU against ``res`` is
    selected (the first one wins on ties). If that IoU exceeds ``thr`` the
    instance is removed from ``results_last`` in place and its track id is
    reused.

    Args:
        res (dict): The bbox & pose results of the person instance. Only
            ``res['bbox']`` is read.
        results_last (list[dict]): The bbox & pose & track_id info of the
            last frame (bbox_result, pose_result, track_id). Modified in
            place when a match is found.
        thr (float): The threshold for iou tracking.

    Returns:
        int: The track id for the new person instance, -1 if unmatched.
        list[dict]: The bbox & pose & track_id info of the persons
            that have not been matched on the last frame.
        dict: The matched person instance on the last frame, or an empty
            dict if there is no match.
    """
    bbox = list(res['bbox'])
    match_result = {}

    # Nothing to match against. Mirrors the guard in `_track_by_oks` and
    # prevents the sentinel score below from passing a threshold < -1 and
    # indexing into an empty list.
    if len(results_last) == 0:
        return -1, results_last, match_result

    max_iou_score = -1
    max_index = -1
    for index, res_last in enumerate(results_last):
        iou_score = _compute_iou(bbox, list(res_last['bbox']))
        # Strict '>' keeps the earliest candidate on ties.
        if iou_score > max_iou_score:
            max_iou_score = iou_score
            max_index = index

    if max_iou_score > thr:
        track_id = results_last[max_index]['track_id']
        match_result = results_last[max_index]
        # A matched instance must not be matched again in this frame.
        del results_last[max_index]
    else:
        track_id = -1

    return track_id, results_last, match_result
def _track_by_oks(res, results_last, thr):
    """Greedily pick a track id for ``res`` using OKS similarity.

    Args:
        res (dict): The pose results of the person instance; ``keypoints``
            and ``area`` are read.
        results_last (list[dict]): The pose & track_id info of the
            last frame (pose_result, track_id). The matched entry, if any,
            is removed in place.
        thr (float): The threshold for oks tracking.

    Returns:
        int: The track id for the new person instance, -1 if unmatched.
        list[dict]: The pose & track_id info of the persons
            that have not been matched on the last frame.
        dict: The matched person instance on the last frame, or an empty
            dict if there is no match.
    """
    match_result = {}
    pose = res['keypoints'].reshape((-1))
    area = res['area']

    if len(results_last) == 0:
        return -1, results_last, match_result

    # Stack the previous frame's flattened poses and areas so that all
    # candidates are scored in a single `oks_iou` call.
    poses_prev = np.array(
        [prev['keypoints'].reshape((-1)) for prev in results_last])
    areas_prev = np.array([prev['area'] for prev in results_last])

    scores = oks_iou(pose, poses_prev, area, areas_prev)
    best = np.argmax(scores)

    if not scores[best] > thr:
        return -1, results_last, match_result

    # Remove the matched instance so it cannot be matched a second time.
    match_result = results_last.pop(best)
    return match_result['track_id'], results_last, match_result
def _get_area(results):
    """Fill in ``area`` (and ``bbox`` when missing) for each person instance.

    If an instance already has a ``bbox``, its area is computed from the
    first four bbox entries. Otherwise a bbox is derived from the
    keypoints: the minimum is taken over strictly positive coordinates and
    the maximum over all coordinates, per axis. When an axis has no
    strictly positive coordinate there is nothing to bound, so the area is
    set to 0 and the bbox to all zeros instead of leaking a sentinel value.

    Args:
        results (list[dict]): The pose results of the current frame
            (pose_result). Each dict is updated in place.

    Returns:
        list[dict]: The same list, with ``area`` set on every instance and
            ``bbox`` set on instances that did not have one.
    """
    for result in results:
        if 'bbox' in result:
            bbox = result['bbox']
            result['area'] = ((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))
            continue

        xs = result['keypoints'][:, 0]
        ys = result['keypoints'][:, 1]
        # Only strictly positive coordinates take part in the minimum.
        xs_valid = xs[xs > 0]
        ys_valid = ys[ys > 0]

        if xs_valid.size == 0 or ys_valid.size == 0:
            # No usable keypoints on at least one axis: degenerate box.
            result['area'] = 0.0
            result['bbox'] = np.zeros(4)
            continue

        xmin = np.min(xs_valid)
        xmax = np.max(xs)
        ymin = np.min(ys_valid)
        ymax = np.max(ys)
        result['area'] = (xmax - xmin) * (ymax - ymin)
        result['bbox'] = np.array([xmin, ymin, xmax, ymax])
    return results
def _temporal_refine(result, match_result, fps=None):
    """Refine keypoints using the tracked person instance of the last frame.

    If ``match_result`` carries a ``one_euro`` filter, it is applied to the
    first two keypoint columns of ``result`` and then handed over to
    ``result``. Otherwise a new ``OneEuroFilter`` is created from the
    current keypoints and stored on ``result``; the keypoints are left
    untouched in that case.

    Args:
        result (dict): The pose results of the current frame
            (pose_result). Updated in place.
        match_result (dict): The pose results of the last frame
            (match_result); may be empty.
        fps (optional): Forwarded to ``OneEuroFilter`` when a new filter is
            created.

    Returns:
        (array): The person keypoints after refine.
    """
    coords = result['keypoints'][:, :2]

    if 'one_euro' not in match_result:
        # First sighting of this track: start a fresh filter.
        result['one_euro'] = OneEuroFilter(coords, fps=fps)
        return result['keypoints']

    smoother = match_result['one_euro']
    result['keypoints'][:, :2] = smoother(coords)
    # Carry the filter (and its state) over to the current instance.
    result['one_euro'] = smoother
    return result['keypoints']
def get_track_id(results,
                 results_last,
                 next_id,
                 min_keypoints=3,
                 use_oks=False,
                 tracking_thr=0.3,
                 use_one_euro=False,
                 fps=None):
    """Assign a track id to every person instance of the current frame.

    Each instance is matched greedily against the remaining instances of
    the last frame (by OKS or by bbox IoU). Unmatched instances get a fresh
    id if more than ``min_keypoints`` of their keypoint y-coordinates are
    non-zero; otherwise they are blanked out and given track id -1.

    Args:
        results (list[dict]): The bbox & pose results of the current frame
            (bbox_result, pose_result). Updated in place.
        results_last (list[dict]): The bbox & pose & track_id info of the
            last frame (bbox_result, pose_result, track_id).
        next_id (int): The track id for the new person instance.
        min_keypoints (int): Minimum number of keypoints recognized as
            person. default: 3.
        use_oks (bool): Flag to using oks tracking. default: False.
        tracking_thr (float): The threshold for tracking.
        use_one_euro (bool): Option to use one-euro-filter. default: False.
        fps (optional): Forwarded to the one-euro-filter refinement when
            ``use_one_euro`` is set.

    Returns:
        tuple:
        - results (list[dict]): The bbox & pose & track_id info of the \
            current frame (bbox_result, pose_result, track_id).
        - next_id (int): The track id for the new person instance.
    """
    results = _get_area(results)
    _track = _track_by_oks if use_oks else _track_by_iou

    for person in results:
        track_id, results_last, match_result = _track(
            person, results_last, tracking_thr)

        if track_id != -1:
            person['track_id'] = track_id
        elif np.count_nonzero(person['keypoints'][:, 1]) > min_keypoints:
            # Unmatched but well supported: open a new track.
            person['track_id'] = next_id
            next_id += 1
        else:
            # If the number of keypoints detected is small,
            # delete that person instance.
            person['keypoints'][:, 1] = -10
            person['bbox'] *= 0
            person['track_id'] = -1

        if use_one_euro:
            person['keypoints'] = _temporal_refine(
                person, match_result, fps=fps)

    return results, next_id
def _legacy_dataset_layout(dataset):
    """Return the hard-coded layout for a deprecated ``dataset`` name.

    Args:
        dataset (str): Legacy dataset class name.

    Returns:
        tuple: ``(kpt_num, skeleton, radius)`` where ``radius`` is the
            keypoint radius forced by the dataset, or None when the
            caller's radius should be kept.

    Raises:
        NotImplementedError: If ``dataset`` is not a known legacy name.
    """
    if dataset in ('TopDownCocoDataset', 'BottomUpCocoDataset',
                   'TopDownOCHumanDataset'):
        skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
                    [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
                    [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4],
                    [3, 5], [4, 6]]
        return 17, skeleton, None

    if dataset == 'TopDownCocoWholeBodyDataset':
        skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
                    [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
                    [8, 10], [1, 2], [0, 1], [0, 2],
                    [1, 3], [2, 4], [3, 5], [4, 6], [15, 17], [15, 18],
                    [15, 19], [16, 20], [16, 21], [16, 22], [91, 92],
                    [92, 93], [93, 94], [94, 95], [91, 96], [96, 97],
                    [97, 98], [98, 99], [91, 100], [100, 101], [101, 102],
                    [102, 103], [91, 104], [104, 105], [105, 106],
                    [106, 107], [91, 108], [108, 109], [109, 110],
                    [110, 111], [112, 113], [113, 114], [114, 115],
                    [115, 116], [112, 117], [117, 118], [118, 119],
                    [119, 120], [112, 121], [121, 122], [122, 123],
                    [123, 124], [112, 125], [125, 126], [126, 127],
                    [127, 128], [112, 129], [129, 130], [130, 131],
                    [131, 132]]
        # This dataset always draws with radius 1, whatever was requested.
        return 133, skeleton, 1

    if dataset == 'TopDownAicDataset':
        skeleton = [[2, 1], [1, 0], [0, 13], [13, 3], [3, 4], [4, 5],
                    [8, 7], [7, 6], [6, 9], [9, 10], [10, 11], [12, 13],
                    [0, 6], [3, 9]]
        return 14, skeleton, None

    if dataset == 'TopDownMpiiDataset':
        skeleton = [[0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7],
                    [7, 8], [8, 9], [8, 12], [12, 11], [11, 10], [8, 13],
                    [13, 14], [14, 15]]
        return 16, skeleton, None

    if dataset in ('OneHand10KDataset', 'FreiHandDataset',
                   'PanopticDataset'):
        skeleton = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7],
                    [7, 8], [0, 9], [9, 10], [10, 11], [11, 12], [0, 13],
                    [13, 14], [14, 15], [15, 16], [0, 17], [17, 18],
                    [18, 19], [19, 20]]
        return 21, skeleton, None

    if dataset == 'InterHand2DDataset':
        skeleton = [[0, 1], [1, 2], [2, 3], [4, 5], [5, 6], [6, 7], [8, 9],
                    [9, 10], [10, 11], [12, 13], [13, 14], [14, 15],
                    [16, 17], [17, 18], [18, 19], [3, 20], [7, 20],
                    [11, 20], [15, 20], [19, 20]]
        return 21, skeleton, None

    raise NotImplementedError()


def vis_pose_tracking_result(model,
                             img,
                             result,
                             radius=4,
                             thickness=1,
                             kpt_score_thr=0.3,
                             dataset='TopDownCocoDataset',
                             dataset_info=None,
                             show=False,
                             out_file=None):
    """Visualize the pose tracking results on the image.

    Every instance is drawn by a separate ``model.show_result`` call, with
    one palette colour per track id (``track_id % len(palette)``) used for
    its bbox, keypoints and links. ``show`` and ``out_file`` are forwarded
    to each of those calls.

    Args:
        model (nn.Module): Model providing ``show_result``. If it has a
            ``module`` attribute, that attribute is used instead.
        img (str | np.ndarray): Image filename or loaded image.
        result (list[dict]): The results to draw over `img`
            (bbox_result, pose_result). Each dict needs a ``track_id``.
        radius (int): Radius of circles. Overridden to 1 for the legacy
            ``'TopDownCocoWholeBodyDataset'`` name.
        thickness (int): Thickness of lines.
        kpt_score_thr (float): The threshold to visualize the keypoints.
        dataset (str | None): Deprecated legacy dataset name, only used
            when ``dataset_info`` is None.
        dataset_info (optional): Object exposing ``keypoint_num`` and
            ``skeleton``. Takes precedence over ``dataset``.
        show (bool): Whether to show the image. Default False.
        out_file (str|None): The filename of the output visualization image.

    Returns:
        The image returned by the last ``show_result`` call, or ``img``
        unchanged when ``result`` is empty.

    Raises:
        NotImplementedError: If ``dataset`` is an unknown legacy name.
        ValueError: If there are instances to draw but both ``dataset`` and
            ``dataset_info`` are None.
    """
    if hasattr(model, 'module'):
        model = model.module

    palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
                        [230, 230, 0], [255, 153, 255], [153, 204, 255],
                        [255, 102, 255], [255, 51, 255], [102, 178, 255],
                        [51, 153, 255], [255, 153, 153], [255, 102, 102],
                        [255, 51, 51], [153, 255, 153], [102, 255, 102],
                        [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0],
                        [255, 255, 255]])

    kpt_num = None
    skeleton = None
    if dataset_info is not None:
        kpt_num = dataset_info.keypoint_num
        skeleton = dataset_info.skeleton
    elif dataset is not None:
        warnings.warn(
            'dataset is deprecated. '
            'Please set `dataset_info` in the config. '
            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
            DeprecationWarning)
        # TODO: These will be removed in the later versions.
        kpt_num, skeleton, forced_radius = _legacy_dataset_layout(dataset)
        if forced_radius is not None:
            radius = forced_radius
    elif len(result) > 0:
        # Without a layout the loop below could not build the colours;
        # fail with an explicit message instead of an unbound-name error.
        raise ValueError(
            'Either `dataset_info` or `dataset` must be provided to '
            'visualize pose tracking results.')

    for res in result:
        color_index = res['track_id'] % len(palette)
        bbox_color = palette[color_index]
        # Repeat the track colour once per keypoint / per skeleton link.
        pose_kpt_color = palette[[color_index] * kpt_num]
        pose_link_color = palette[[color_index] * len(skeleton)]
        img = model.show_result(
            img, [res],
            skeleton,
            radius=radius,
            thickness=thickness,
            pose_kpt_color=pose_kpt_color,
            pose_link_color=pose_link_color,
            bbox_color=tuple(bbox_color.tolist()),
            kpt_score_thr=kpt_score_thr,
            show=show,
            out_file=out_file)

    return img
|