Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
import numpy as np | |
from mmpose.core import OneEuroFilter, oks_iou | |
def _compute_iou(bboxA, bboxB):
    """Return the Intersection over Union (IoU) of two axis-aligned boxes.

    Args:
        bboxA (list): The first bbox info (left, top, right, bottom, score).
        bboxB (list): The second bbox info (left, top, right, bottom, score).
            Only the first four entries of either box are read.

    Returns:
        float: Intersection area divided by union area. If the union is
        exactly zero, a warning is emitted and a tiny denominator (1e-5)
        is used instead of dividing by zero.
    """
    # Overlap extent along each axis, clamped at 0 for disjoint boxes.
    overlap_w = max(0, min(bboxA[2], bboxB[2]) - max(bboxA[0], bboxB[0]))
    overlap_h = max(0, min(bboxA[3], bboxB[3]) - max(bboxA[1], bboxB[1]))
    intersection = overlap_w * overlap_h

    def _box_area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    union = float(_box_area(bboxA) + _box_area(bboxB) - intersection)
    if union == 0:
        # Degenerate boxes: avoid ZeroDivisionError but make it visible.
        union = 1e-5
        warnings.warn('union_area=0 is unexpected')
    return intersection / union
def _track_by_iou(res, results_last, thr):
    """Get a track id for ``res`` by greedy IoU matching.

    The bbox of ``res`` is compared against the bbox of every instance in
    ``results_last``. The candidate with the highest IoU is accepted when
    that IoU is strictly greater than ``thr``. The accepted entry is then
    removed from ``results_last`` in place.

    Args:
        res (dict): The bbox & pose results of the person instance.
        results_last (list[dict]): The bbox & pose & track_id info of the
            last frame (bbox_result, pose_result, track_id).
        thr (float): The threshold for iou tracking.

    Returns:
        int: The matched track id, or -1 when nothing matched.
        list[dict]: ``results_last`` (same list object) with the matched
            instance removed.
        dict: The matched instance of the last frame, or ``{}``.
    """
    query_box = list(res['bbox'])

    best_score = -1
    best_idx = -1
    for idx, candidate in enumerate(results_last):
        score = _compute_iou(query_box, list(candidate['bbox']))
        if score > best_score:
            best_score, best_idx = score, idx

    if best_score > thr:
        match_result = results_last[best_idx]
        track_id = match_result['track_id']
        del results_last[best_idx]
        return track_id, results_last, match_result
    return -1, results_last, {}
def _track_by_oks(res, results_last, thr):
    """Get a track id for ``res`` by greedy OKS matching.

    The flattened keypoints of ``res`` are scored against those of every
    instance in ``results_last`` with ``oks_iou``. The best candidate is
    accepted when its score is strictly greater than ``thr``. The accepted
    entry is then removed from ``results_last`` in place.

    Args:
        res (dict): The pose results of the person instance; must provide
            ``'keypoints'`` (array) and ``'area'``.
        results_last (list[dict]): The pose & track_id info of the
            last frame (pose_result, track_id).
        thr (float): The threshold for oks tracking.

    Returns:
        int: The matched track id, or -1 when nothing matched.
        list[dict]: ``results_last`` (same list object) with the matched
            instance removed.
        dict: The matched instance of the last frame, or ``{}``.
    """
    query_pose = res['keypoints'].reshape(-1)
    query_area = res['area']
    if len(results_last) == 0:
        return -1, results_last, {}

    prev_poses = np.array(
        [prev['keypoints'].reshape(-1) for prev in results_last])
    prev_areas = np.array([prev['area'] for prev in results_last])
    scores = oks_iou(query_pose, prev_poses, query_area, prev_areas)
    best_idx = np.argmax(scores)

    if scores[best_idx] > thr:
        match_result = results_last[best_idx]
        track_id = match_result['track_id']
        del results_last[best_idx]
        return track_id, results_last, match_result
    return -1, results_last, {}
def _get_area(results):
    """Compute the area (and, when missing, the bbox) of each person instance.

    Instances that already have a ``'bbox'`` get ``'area'`` computed from it.
    For the others, a tight box ``[xmin, ymin, xmax, ymax]`` is derived from
    the keypoint coordinates and stored as ``'bbox'``. Coordinates that are
    not strictly positive are excluded from the minimum (presumably they mark
    undetected keypoints). If no coordinate on an axis is positive, the box
    collapses to zero width/height on that axis instead of being inverted.

    Args:
        results (list[dict]): The pose results of the current frame
            (pose_result). The dicts are updated in place.

    Returns:
        list[dict]: The same list; every dict now holds ``'area'`` and
        ``'bbox'`` (bbox_result, pose_result, area).
    """
    for result in results:
        if 'bbox' in result:
            bbox = result['bbox']
            result['area'] = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            continue

        xs = result['keypoints'][:, 0]
        ys = result['keypoints'][:, 1]
        xmax = np.max(xs)
        ymax = np.max(ys)
        # Every positive coordinate is <= the max, so using the max as the
        # reduction's initial value leaves the result unchanged when any
        # positive coordinate exists. With no positive coordinates it gives
        # a zero-size extent rather than the old 1e10 sentinel, which
        # produced an inverted box with a ~1e20 area.
        xmin = np.min(xs[xs > 0], initial=xmax)
        ymin = np.min(ys[ys > 0], initial=ymax)
        result['area'] = (xmax - xmin) * (ymax - ymin)
        result['bbox'] = np.array([xmin, ymin, xmax, ymax])
    return results
def _temporal_refine(result, match_result, fps=None):
    """Refine keypoints with the one-euro filter of the tracked instance.

    If ``match_result`` (the matched instance of the last frame) carries a
    filter under ``'one_euro'``, that filter is applied to the x/y columns of
    ``result['keypoints']`` (written back in place). The filter is then
    handed over to ``result``. Otherwise a new ``OneEuroFilter`` is created
    from the current x/y coordinates and stored on ``result``, and the
    keypoints are left unchanged.

    Args:
        result (dict): The pose results of the current frame
            (pose_result).
        match_result (dict): The pose results of the last frame
            (match_result); may be empty.
        fps (optional): Forwarded to ``OneEuroFilter`` when a new filter is
            created.

    Returns:
        (array): ``result['keypoints']`` after refinement.
    """
    if 'one_euro' not in match_result:
        result['one_euro'] = OneEuroFilter(result['keypoints'][:, :2], fps=fps)
        return result['keypoints']

    smoother = match_result['one_euro']
    result['keypoints'][:, :2] = smoother(result['keypoints'][:, :2])
    result['one_euro'] = smoother
    return result['keypoints']
def get_track_id(results,
                 results_last,
                 next_id,
                 min_keypoints=3,
                 use_oks=False,
                 tracking_thr=0.3,
                 use_one_euro=False,
                 fps=None):
    """Get track id for each person instance on the current frame.

    Each instance is greedily matched against the instances of the last
    frame, by OKS if ``use_oks`` else by bbox IoU. Matched entries are
    removed from ``results_last``. Unmatched instances get a fresh id
    (``next_id``, then incremented) when more than ``min_keypoints``
    keypoints have a non-zero y coordinate. Otherwise they are suppressed:
    their keypoint y values are set to -10, their bbox is zeroed and their
    track id is -1.

    Args:
        results (list[dict]): The bbox & pose results of the current frame
            (bbox_result, pose_result). Updated in place.
        results_last (list[dict]): The bbox & pose & track_id info of the
            last frame (bbox_result, pose_result, track_id).
        next_id (int): The track id for the new person instance.
        min_keypoints (int): An unmatched instance needs strictly more than
            this many keypoints (non-zero y) to start a new track.
            default: 3.
        use_oks (bool): Flag to using oks tracking. default: False.
        tracking_thr (float): The threshold for tracking.
        use_one_euro (bool): Option to use one-euro-filter. default: False.
        fps (optional): Forwarded to ``OneEuroFilter`` when a new filter is
            created for an instance (only used with ``use_one_euro``).

    Returns:
        tuple:
        - results (list[dict]): The bbox & pose & track_id info of the \
            current frame (bbox_result, pose_result, track_id).
        - next_id (int): The track id for the new person instance.
    """
    results = _get_area(results)
    track_fn = _track_by_oks if use_oks else _track_by_iou

    for result in results:
        track_id, results_last, match_result = track_fn(
            result, results_last, tracking_thr)

        if track_id != -1:
            result['track_id'] = track_id
        elif np.count_nonzero(result['keypoints'][:, 1]) > min_keypoints:
            result['track_id'] = next_id
            next_id += 1
        else:
            # Too few keypoints were detected: suppress this instance.
            result['keypoints'][:, 1] = -10
            result['bbox'] *= 0
            result['track_id'] = -1

        if use_one_euro:
            result['keypoints'] = _temporal_refine(
                result, match_result, fps=fps)

    return results, next_id
def vis_pose_tracking_result(model,
                             img,
                             result,
                             radius=4,
                             thickness=1,
                             kpt_score_thr=0.3,
                             dataset='TopDownCocoDataset',
                             dataset_info=None,
                             show=False,
                             out_file=None):
    """Visualize the pose tracking results on the image.

    Each person instance is drawn in a colour chosen from a fixed palette by
    ``track_id % len(palette)``, so a track keeps its colour across frames.
    Instances are drawn one at a time onto the same canvas. The image is
    shown and/or written to ``out_file`` only once, after the last instance
    has been drawn.

    Args:
        model (nn.Module): The loaded pose model. If it has a ``module``
            attribute (a wrapped model), the wrapped model is used.
        img (str | np.ndarray): Image filename or loaded image.
        result (list[dict]): The results to draw over `img`
            (bbox_result, pose_result); each dict needs a ``'track_id'``.
        radius (int): Radius of circles. Overridden to 1 for
            ``'TopDownCocoWholeBodyDataset'`` on the legacy ``dataset`` path.
        thickness (int): Thickness of lines.
        kpt_score_thr (float): The threshold to visualize the keypoints.
        dataset (str | None): Deprecated dataset name, used only when
            ``dataset_info`` is None. Default: 'TopDownCocoDataset'.
        dataset_info (optional): Object exposing ``keypoint_num`` and
            ``skeleton``; takes precedence over ``dataset``. Default: None.
        show (bool): Whether to show the image. Default False.
        out_file (str|None): The filename of the output visualization image.

    Returns:
        The image returned by ``model.show_result`` for the last instance,
        or ``img`` unchanged when ``result`` is empty.

    Raises:
        ValueError: If both ``dataset`` and ``dataset_info`` are None.
        NotImplementedError: If ``dataset`` is not a supported name.
    """
    if hasattr(model, 'module'):
        model = model.module

    palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
                        [230, 230, 0], [255, 153, 255], [153, 204, 255],
                        [255, 102, 255], [255, 51, 255], [102, 178, 255],
                        [51, 153, 255], [255, 153, 153], [255, 102, 102],
                        [255, 51, 51], [153, 255, 153], [102, 255, 102],
                        [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0],
                        [255, 255, 255]])

    if dataset_info is None and dataset is not None:
        warnings.warn(
            'dataset is deprecated.'
            'Please set `dataset_info` in the config.'
            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
            DeprecationWarning)
        # TODO: These will be removed in the later versions.
        if dataset in ('TopDownCocoDataset', 'BottomUpCocoDataset',
                       'TopDownOCHumanDataset'):
            kpt_num = 17
            skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
                        [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
                        [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4],
                        [3, 5], [4, 6]]
        elif dataset == 'TopDownCocoWholeBodyDataset':
            kpt_num = 133
            skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
                        [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
                        [8, 10], [1, 2], [0, 1], [0, 2],
                        [1, 3], [2, 4], [3, 5], [4, 6], [15, 17], [15, 18],
                        [15, 19], [16, 20], [16, 21], [16, 22], [91, 92],
                        [92, 93], [93, 94], [94, 95], [91, 96], [96, 97],
                        [97, 98], [98, 99], [91, 100], [100, 101], [101, 102],
                        [102, 103], [91, 104], [104, 105], [105, 106],
                        [106, 107], [91, 108], [108, 109], [109, 110],
                        [110, 111], [112, 113], [113, 114], [114, 115],
                        [115, 116], [112, 117], [117, 118], [118, 119],
                        [119, 120], [112, 121], [121, 122], [122, 123],
                        [123, 124], [112, 125], [125, 126], [126, 127],
                        [127, 128], [112, 129], [129, 130], [130, 131],
                        [131, 132]]
            radius = 1
        elif dataset == 'TopDownAicDataset':
            kpt_num = 14
            skeleton = [[2, 1], [1, 0], [0, 13], [13, 3], [3, 4], [4, 5],
                        [8, 7], [7, 6], [6, 9], [9, 10], [10, 11], [12, 13],
                        [0, 6], [3, 9]]
        elif dataset == 'TopDownMpiiDataset':
            kpt_num = 16
            skeleton = [[0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7],
                        [7, 8], [8, 9], [8, 12], [12, 11], [11, 10], [8, 13],
                        [13, 14], [14, 15]]
        elif dataset in ('OneHand10KDataset', 'FreiHandDataset',
                         'PanopticDataset'):
            kpt_num = 21
            skeleton = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7],
                        [7, 8], [0, 9], [9, 10], [10, 11], [11, 12], [0, 13],
                        [13, 14], [14, 15], [15, 16], [0, 17], [17, 18],
                        [18, 19], [19, 20]]
        elif dataset == 'InterHand2DDataset':
            kpt_num = 21
            skeleton = [[0, 1], [1, 2], [2, 3], [4, 5], [5, 6], [6, 7], [8, 9],
                        [9, 10], [10, 11], [12, 13], [13, 14], [14, 15],
                        [16, 17], [17, 18], [18, 19], [3, 20], [7, 20],
                        [11, 20], [15, 20], [19, 20]]
        else:
            raise NotImplementedError()

    elif dataset_info is not None:
        kpt_num = dataset_info.keypoint_num
        skeleton = dataset_info.skeleton

    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on `kpt_num`/`skeleton`.
        raise ValueError('Either `dataset` or `dataset_info` must be given.')

    num_colors = len(palette)
    last_idx = len(result) - 1
    for idx, res in enumerate(result):
        color_idx = res['track_id'] % num_colors
        bbox_color = palette[color_idx]
        pose_kpt_color = palette[[color_idx] * kpt_num]
        pose_link_color = palette[[color_idx] * len(skeleton)]
        # All instances are drawn onto the same canvas. Display/save only
        # after the final one, instead of once per person.
        is_last = idx == last_idx
        img = model.show_result(
            img, [res],
            skeleton,
            radius=radius,
            thickness=thickness,
            pose_kpt_color=pose_kpt_color,
            pose_link_color=pose_link_color,
            bbox_color=tuple(bbox_color.tolist()),
            kpt_score_thr=kpt_score_thr,
            show=show and is_last,
            out_file=out_file if is_last else None)

    return img