# Copyright (c) OpenMMLab. All rights reserved. from argparse import ArgumentParser from typing import Dict import sys from mmpose.apis.inferencers import MMPoseInferencer, get_model_aliases filter_args = dict(bbox_thr=0.3, nms_thr=0.3, pose_based_nms=False) POSE2D_SPECIFIC_ARGS = dict( yoloxpose=dict(bbox_thr=0.01, nms_thr=0.65, pose_based_nms=True), rtmo=dict(bbox_thr=0.1, nms_thr=0.65, pose_based_nms=True), ) def parse_args(): parser = ArgumentParser() parser.add_argument( 'inputs', type=str, nargs='?', help='Input image/video path or folder path.') # init args parser.add_argument( '--pose2d', type=str, default="wholebody", help='Pretrained 2D pose estimation algorithm. It\'s the path to the ' 'config file or the model name defined in metafile.') parser.add_argument( '--pose2d-weights', type=str, default=None, help='Path to the custom checkpoint file of the selected pose model. ' 'If it is not specified and "pose2d" is a model name of metafile, ' 'the weights will be loaded from metafile.') parser.add_argument( '--pose3d', type=str, default=None, help='Pretrained 3D pose estimation algorithm. It\'s the path to the ' 'config file or the model name defined in metafile.') parser.add_argument( '--pose3d-weights', type=str, default=None, help='Path to the custom checkpoint file of the selected pose model. ' 'If it is not specified and "pose3d" is a model name of metafile, ' 'the weights will be loaded from metafile.') parser.add_argument( '--det-model', type=str, default=None, help='Config path or alias of detection model.') parser.add_argument( '--det-weights', type=str, default=None, help='Path to the checkpoints of detection model.') parser.add_argument( '--det-cat-ids', type=int, nargs='+', default=0, help='Category id for detection model.') parser.add_argument( '--scope', type=str, default='mmpose', help='Scope where modules are defined.') parser.add_argument( '--device', type=str, default=None, help='Device used for inference. ' 'If not specified, the available device will be automatically used.') parser.add_argument( '--show-progress', action='store_true', help='Display the progress bar during inference.') # The default arguments for prediction filtering differ for top-down # and bottom-up models. We assign the default arguments according to the # selected pose2d model args, _ = parser.parse_known_args() for model in POSE2D_SPECIFIC_ARGS: if model in args.pose2d: filter_args.update(POSE2D_SPECIFIC_ARGS[model]) break # call args parser.add_argument( '--show', action='store_true', help='Display the image/video in a popup window.') parser.add_argument( '--draw-bbox', action='store_true', help='Whether to draw the bounding boxes.') parser.add_argument( '--draw-heatmap', action='store_true', default=False, help='Whether to draw the predicted heatmaps.') parser.add_argument( '--bbox-thr', type=float, default=filter_args['bbox_thr'], help='Bounding box score threshold') parser.add_argument( '--nms-thr', type=float, default=filter_args['nms_thr'], help='IoU threshold for bounding box NMS') parser.add_argument( '--pose-based-nms', type=lambda arg: arg.lower() in ('true', 'yes', 't', 'y', '1'), default=filter_args['pose_based_nms'], help='Whether to use pose-based NMS') parser.add_argument( '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold') parser.add_argument( '--tracking-thr', type=float, default=0.3, help='Tracking threshold') parser.add_argument( '--use-oks-tracking', action='store_true', help='Whether to use OKS as similarity in tracking') parser.add_argument( '--disable-norm-pose-2d', action='store_true', help='Whether to scale the bbox (along with the 2D pose) to the ' 'average bbox scale of the dataset, and move the bbox (along with the ' '2D pose) to the average bbox center of the dataset. This is useful ' 'when bbox is small, especially in multi-person scenarios.') parser.add_argument( '--disable-rebase-keypoint', action='store_true', default=False, help='Whether to disable rebasing the predicted 3D pose so its ' 'lowest keypoint has a height of 0 (landing on the ground). Rebase ' 'is useful for visualization when the model do not predict the ' 'global position of the 3D pose.') parser.add_argument( '--num-instances', type=int, default=1, help='The number of 3D poses to be visualized in every frame. If ' 'less than 0, it will be set to the number of pose results in the ' 'first frame.') parser.add_argument( '--radius', type=int, default=3, help='Keypoint radius for visualization.') parser.add_argument( '--thickness', type=int, default=1, help='Link thickness for visualization.') parser.add_argument( '--skeleton-style', default='mmpose', type=str, choices=['mmpose', 'openpose'], help='Skeleton style selection') parser.add_argument( '--black-background', action='store_true', help='Plot predictions on a black image') parser.add_argument( '--vis-out-dir', type=str, default='', #'tmp/nouse/', help='Directory for saving visualized results.') parser.add_argument( '--pred-out-dir', type=str, default='tmp/', help='Directory for saving inference results.') parser.add_argument( '--show-alias', action='store_true', help='Display all the available model aliases.') call_args = vars(parser.parse_args()) init_kws = [ 'pose2d', 'pose2d_weights', 'scope', 'device', 'det_model', 'det_weights', 'det_cat_ids', 'pose3d', 'pose3d_weights', 'show_progress' ] init_args = {} for init_kw in init_kws: init_args[init_kw] = call_args.pop(init_kw) display_alias = call_args.pop('show_alias') return init_args, call_args, display_alias def display_model_aliases(model_aliases: Dict[str, str]) -> None: """Display the available model aliases and their corresponding model names.""" aliases = list(model_aliases.keys()) max_alias_length = max(map(len, aliases)) print(f'{"ALIAS".ljust(max_alias_length+2)}MODEL_NAME') for alias in sorted(aliases): print(f'{alias.ljust(max_alias_length+2)}{model_aliases[alias]}') def main(): init_args, call_args, display_alias = parse_args() if display_alias: model_alises = get_model_aliases(init_args['scope']) display_model_aliases(model_alises) else: inferencer = MMPoseInferencer(**init_args) for _ in inferencer(**call_args): pass if __name__ == '__main__': main()