# Spaces: Running on A10G
# Copyright (c) OpenMMLab. All rights reserved. | |
from argparse import ArgumentParser | |
from typing import Dict | |
import sys | |
sys.path.append("./") | |
sys.path.append("./mmpose") | |
from mmpose.apis.inferencers import MMPoseInferencer, get_model_aliases | |
# Baseline prediction-filtering defaults. parse_args() uses these as the
# defaults of --bbox-thr, --nms-thr and --pose-based-nms unless the selected
# pose2d model matches an entry of POSE2D_SPECIFIC_ARGS below.
filter_args = {
    'bbox_thr': 0.3,
    'nms_thr': 0.3,
    'pose_based_nms': False,
}

# Model-specific overrides of ``filter_args``. A pose2d model whose name
# contains one of these keys gets the associated values instead; the first
# matching key (in insertion order) wins.
POSE2D_SPECIFIC_ARGS = {
    'yoloxpose': {'bbox_thr': 0.01, 'nms_thr': 0.65, 'pose_based_nms': True},
    'rtmo': {'bbox_thr': 0.1, 'nms_thr': 0.65, 'pose_based_nms': True},
}
def _add_init_args(parser):
    """Register the positional input and the ``MMPoseInferencer`` init args.

    Args:
        parser (ArgumentParser): Parser to add the arguments to.
    """
    parser.add_argument(
        'inputs',
        type=str,
        nargs='?',
        help='Input image/video path or folder path.')
    # init args
    parser.add_argument(
        '--pose2d',
        type=str,
        default="wholebody",
        help='Pretrained 2D pose estimation algorithm. It\'s the path to the '
        'config file or the model name defined in metafile.')
    parser.add_argument(
        '--pose2d-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected pose model. '
        'If it is not specified and "pose2d" is a model name of metafile, '
        'the weights will be loaded from metafile.')
    parser.add_argument(
        '--pose3d',
        type=str,
        default=None,
        help='Pretrained 3D pose estimation algorithm. It\'s the path to the '
        'config file or the model name defined in metafile.')
    parser.add_argument(
        '--pose3d-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected pose model. '
        'If it is not specified and "pose3d" is a model name of metafile, '
        'the weights will be loaded from metafile.')
    parser.add_argument(
        '--det-model',
        type=str,
        default=None,
        help='Config path or alias of detection model.')
    parser.add_argument(
        '--det-weights',
        type=str,
        default=None,
        help='Path to the checkpoints of detection model.')
    # NOTE(review): with nargs='+' a user-supplied value is a list of ints,
    # while the default is a bare int 0; presumably MMPoseInferencer accepts
    # both forms -- confirm before changing the default.
    parser.add_argument(
        '--det-cat-ids',
        type=int,
        nargs='+',
        default=0,
        help='Category id for detection model.')
    parser.add_argument(
        '--scope',
        type=str,
        default='mmpose',
        help='Scope where modules are defined.')
    parser.add_argument(
        '--device',
        type=str,
        default=None,
        help='Device used for inference. '
        'If not specified, the available device will be automatically used.')
    parser.add_argument(
        '--show-progress',
        action='store_true',
        help='Display the progress bar during inference.')


def _resolve_filter_defaults(pose2d):
    """Return the prediction-filter defaults for the given pose2d model.

    Starts from a copy of the module-level ``filter_args`` and applies the
    overrides of the first ``POSE2D_SPECIFIC_ARGS`` key that occurs as a
    substring of ``pose2d``. The module-level dict is left untouched, so
    repeated calls do not leak one model's overrides into the next.

    Args:
        pose2d (str): The value given for ``--pose2d``.

    Returns:
        dict: Values for ``bbox_thr``, ``nms_thr`` and ``pose_based_nms``.
    """
    defaults = dict(filter_args)
    for model, overrides in POSE2D_SPECIFIC_ARGS.items():
        if model in pose2d:
            defaults.update(overrides)
            break
    return defaults


def _add_call_args(parser, filter_defaults):
    """Register the arguments passed when calling the inferencer.

    Also registers ``--show-alias``, which ``parse_args`` pops out of the
    call arguments afterwards.

    Args:
        parser (ArgumentParser): Parser to add the arguments to.
        filter_defaults (dict): Defaults for ``bbox_thr``, ``nms_thr`` and
            ``pose_based_nms`` (see ``_resolve_filter_defaults``).
    """
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the image/video in a popup window.')
    parser.add_argument(
        '--draw-bbox',
        action='store_true',
        help='Whether to draw the bounding boxes.')
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        default=False,
        help='Whether to draw the predicted heatmaps.')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=filter_defaults['bbox_thr'],
        help='Bounding box score threshold')
    parser.add_argument(
        '--nms-thr',
        type=float,
        default=filter_defaults['nms_thr'],
        help='IoU threshold for bounding box NMS')
    parser.add_argument(
        '--pose-based-nms',
        # Accept common truthy spellings; any other string parses as False.
        type=lambda arg: arg.lower() in ('true', 'yes', 't', 'y', '1'),
        default=filter_defaults['pose_based_nms'],
        help='Whether to use pose-based NMS')
    parser.add_argument(
        '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
    parser.add_argument(
        '--tracking-thr', type=float, default=0.3, help='Tracking threshold')
    parser.add_argument(
        '--use-oks-tracking',
        action='store_true',
        help='Whether to use OKS as similarity in tracking')
    parser.add_argument(
        '--disable-norm-pose-2d',
        action='store_true',
        help='Whether to scale the bbox (along with the 2D pose) to the '
        'average bbox scale of the dataset, and move the bbox (along with the '
        '2D pose) to the average bbox center of the dataset. This is useful '
        'when bbox is small, especially in multi-person scenarios.')
    parser.add_argument(
        '--disable-rebase-keypoint',
        action='store_true',
        default=False,
        help='Whether to disable rebasing the predicted 3D pose so its '
        'lowest keypoint has a height of 0 (landing on the ground). Rebase '
        'is useful for visualization when the model do not predict the '
        'global position of the 3D pose.')
    parser.add_argument(
        '--num-instances',
        type=int,
        default=1,
        help='The number of 3D poses to be visualized in every frame. If '
        'less than 0, it will be set to the number of pose results in the '
        'first frame.')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization.')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization.')
    parser.add_argument(
        '--skeleton-style',
        default='mmpose',
        type=str,
        choices=['mmpose', 'openpose'],
        help='Skeleton style selection')
    parser.add_argument(
        '--black-background',
        action='store_true',
        help='Plot predictions on a black image')
    parser.add_argument(
        '--vis-out-dir',
        type=str,
        default='',
        help='Directory for saving visualized results.')
    parser.add_argument(
        '--pred-out-dir',
        type=str,
        default='tmp/',
        help='Directory for saving inference results.')
    parser.add_argument(
        '--show-alias',
        action='store_true',
        help='Display all the available model aliases.')


def parse_args():
    """Parse the command line for the MMPose inferencer demo.

    The parsed options are split into keyword arguments for constructing
    ``MMPoseInferencer`` and keyword arguments for calling it. The defaults
    of ``--bbox-thr``, ``--nms-thr`` and ``--pose-based-nms`` depend on the
    selected ``--pose2d`` model (see ``_resolve_filter_defaults``).

    Returns:
        tuple: ``(init_args, call_args, display_alias)`` where ``init_args``
        and ``call_args`` are dicts of keyword arguments and
        ``display_alias`` is the boolean value of ``--show-alias``.
    """
    parser = ArgumentParser()
    _add_init_args(parser)

    # The default arguments for prediction filtering differ for top-down
    # and bottom-up models, so they are derived from the selected pose2d
    # model. A first, partial parse reads --pose2d before the call
    # arguments (whose defaults depend on it) are registered.
    args, _ = parser.parse_known_args()
    filter_defaults = _resolve_filter_defaults(args.pose2d)

    _add_call_args(parser, filter_defaults)
    call_args = vars(parser.parse_args())

    init_kws = [
        'pose2d', 'pose2d_weights', 'scope', 'device', 'det_model',
        'det_weights', 'det_cat_ids', 'pose3d', 'pose3d_weights',
        'show_progress'
    ]
    # Move the constructor options out of call_args; everything left over
    # (including the positional 'inputs') is passed when calling.
    init_args = {}
    for init_kw in init_kws:
        init_args[init_kw] = call_args.pop(init_kw)
    display_alias = call_args.pop('show_alias')
    return init_args, call_args, display_alias
def display_model_aliases(model_aliases: Dict[str, str]) -> None:
    """Print the available model aliases and their model names as a table.

    Aliases are printed in sorted order beneath an ``ALIAS`` / ``MODEL_NAME``
    header. The alias column is padded to the longest entry (header
    included) plus two spaces, so the columns line up even when every alias
    is shorter than the header. An empty mapping prints only the header.

    Args:
        model_aliases: Mapping from alias to model name.
    """
    header = 'ALIAS'
    # Count the header in the width so short aliases don't run into
    # "MODEL_NAME"; ``default=0`` keeps ``max()`` from raising on {}.
    width = max(len(header), max(map(len, model_aliases), default=0)) + 2
    print(f'{header.ljust(width)}MODEL_NAME')
    for alias in sorted(model_aliases):
        print(f'{alias.ljust(width)}{model_aliases[alias]}')
def main():
    """Run the demo: list model aliases, or run pose inference on the inputs.

    With ``--show-alias`` the aliases registered for ``--scope`` are printed
    and nothing else happens. Otherwise an ``MMPoseInferencer`` is built
    from the init arguments and called with the remaining arguments.
    """
    init_args, call_args, display_alias = parse_args()

    if not display_alias:
        pose_inferencer = MMPoseInferencer(**init_args)
        # The per-input results are not used here; the loop only consumes
        # the iterable returned by the inferencer call.
        for _result in pose_inferencer(**call_args):
            continue
        return

    available_aliases = get_model_aliases(init_args['scope'])
    display_model_aliases(available_aliases)


if __name__ == '__main__':
    main()