SignLanguage / inferencer_demo.py
ZiyuG's picture
Update inferencer_demo.py
ccaeafc verified
raw
history blame
7.42 kB
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
from typing import Dict
import sys
sys.path.append("./")
sys.path.append("./mmpose")
from mmpose.apis.inferencers import MMPoseInferencer, get_model_aliases
# Baseline defaults for the prediction-filtering options (--bbox-thr,
# --nms-thr, --pose-based-nms). Treated as read-only by ``parse_args``.
filter_args = dict(bbox_thr=0.3, nms_thr=0.3, pose_based_nms=False)

# Per-model overrides of ``filter_args``. A key applies when it occurs as a
# substring of the ``--pose2d`` value; the first matching key wins.
POSE2D_SPECIFIC_ARGS = dict(
    yoloxpose=dict(bbox_thr=0.01, nms_thr=0.65, pose_based_nms=True),
    rtmo=dict(bbox_thr=0.1, nms_thr=0.65, pose_based_nms=True),
)


def parse_args() -> tuple:
    """Parse the command line into inferencer init and call arguments.

    Parsing happens in two passes. The first pass only knows the init
    options and is used to read ``--pose2d``, because the defaults of the
    filtering options depend on the selected 2D pose model. The second pass
    parses the complete set of options.

    Returns:
        tuple: ``(init_args, call_args, display_alias)`` where ``init_args``
        is a dict holding the options listed in ``init_kws`` below,
        ``call_args`` is a dict holding every remaining option (including
        the positional ``inputs``), and ``display_alias`` is the boolean
        value of ``--show-alias``.
    """
    # Automatic -h/--help is disabled for now: otherwise the preliminary
    # parse below would print a help text that lacks every call option and
    # exit. The help option is registered manually after that parse.
    parser = ArgumentParser(add_help=False)
    parser.add_argument(
        'inputs',
        type=str,
        nargs='?',
        help='Input image/video path or folder path.')

    # init args
    parser.add_argument(
        '--pose2d',
        type=str,
        default="wholebody",
        help='Pretrained 2D pose estimation algorithm. It\'s the path to the '
        'config file or the model name defined in metafile.')
    parser.add_argument(
        '--pose2d-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected pose model. '
        'If it is not specified and "pose2d" is a model name of metafile, '
        'the weights will be loaded from metafile.')
    parser.add_argument(
        '--pose3d',
        type=str,
        default=None,
        help='Pretrained 3D pose estimation algorithm. It\'s the path to the '
        'config file or the model name defined in metafile.')
    parser.add_argument(
        '--pose3d-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected pose model. '
        'If it is not specified and "pose3d" is a model name of metafile, '
        'the weights will be loaded from metafile.')
    parser.add_argument(
        '--det-model',
        type=str,
        default=None,
        help='Config path or alias of detection model.')
    parser.add_argument(
        '--det-weights',
        type=str,
        default=None,
        help='Path to the checkpoints of detection model.')
    # NOTE(review): the default is a bare int while values given on the
    # command line arrive as a list (nargs='+'); confirm that the consumer
    # accepts both forms.
    parser.add_argument(
        '--det-cat-ids',
        type=int,
        nargs='+',
        default=0,
        help='Category id for detection model.')
    parser.add_argument(
        '--scope',
        type=str,
        default='mmpose',
        help='Scope where modules are defined.')
    parser.add_argument(
        '--device',
        type=str,
        default=None,
        help='Device used for inference. '
        'If not specified, the available device will be automatically used.')
    parser.add_argument(
        '--show-progress',
        action='store_true',
        help='Display the progress bar during inference.')

    # The default arguments for prediction filtering differ for top-down
    # and bottom-up models. We assign the default arguments according to the
    # selected pose2d model.
    args, _ = parser.parse_known_args()
    # Work on a copy so the module-level defaults are never modified and
    # repeated calls do not inherit overrides from an earlier call.
    filter_defaults = dict(filter_args)
    for model, overrides in POSE2D_SPECIFIC_ARGS.items():
        if model in args.pose2d:
            filter_defaults.update(overrides)
            break

    # Safe to enable help from here on: the final parse sees every option.
    parser.add_argument(
        '-h',
        '--help',
        action='help',
        help='show this help message and exit')

    # call args
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the image/video in a popup window.')
    parser.add_argument(
        '--draw-bbox',
        action='store_true',
        help='Whether to draw the bounding boxes.')
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        default=False,
        help='Whether to draw the predicted heatmaps.')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=filter_defaults['bbox_thr'],
        help='Bounding box score threshold')
    parser.add_argument(
        '--nms-thr',
        type=float,
        default=filter_defaults['nms_thr'],
        help='IoU threshold for bounding box NMS')
    # Takes an explicit value; any spelling outside the accepted set below
    # is interpreted as False.
    parser.add_argument(
        '--pose-based-nms',
        type=lambda arg: arg.lower() in ('true', 'yes', 't', 'y', '1'),
        default=filter_defaults['pose_based_nms'],
        help='Whether to use pose-based NMS')
    parser.add_argument(
        '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
    parser.add_argument(
        '--tracking-thr', type=float, default=0.3, help='Tracking threshold')
    parser.add_argument(
        '--use-oks-tracking',
        action='store_true',
        help='Whether to use OKS as similarity in tracking')
    parser.add_argument(
        '--disable-norm-pose-2d',
        action='store_true',
        help='Whether to scale the bbox (along with the 2D pose) to the '
        'average bbox scale of the dataset, and move the bbox (along with the '
        '2D pose) to the average bbox center of the dataset. This is useful '
        'when bbox is small, especially in multi-person scenarios.')
    parser.add_argument(
        '--disable-rebase-keypoint',
        action='store_true',
        default=False,
        help='Whether to disable rebasing the predicted 3D pose so its '
        'lowest keypoint has a height of 0 (landing on the ground). Rebase '
        'is useful for visualization when the model do not predict the '
        'global position of the 3D pose.')
    parser.add_argument(
        '--num-instances',
        type=int,
        default=1,
        help='The number of 3D poses to be visualized in every frame. If '
        'less than 0, it will be set to the number of pose results in the '
        'first frame.')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization.')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization.')
    parser.add_argument(
        '--skeleton-style',
        default='mmpose',
        type=str,
        choices=['mmpose', 'openpose'],
        help='Skeleton style selection')
    parser.add_argument(
        '--black-background',
        action='store_true',
        help='Plot predictions on a black image')
    parser.add_argument(
        '--vis-out-dir',
        type=str,
        default='',
        help='Directory for saving visualized results.')
    parser.add_argument(
        '--pred-out-dir',
        type=str,
        default='tmp/',
        help='Directory for saving inference results.')
    parser.add_argument(
        '--show-alias',
        action='store_true',
        help='Display all the available model aliases.')

    call_args = vars(parser.parse_args())

    # Move the constructor options out of ``call_args``; whatever remains is
    # returned as the call arguments.
    init_kws = [
        'pose2d', 'pose2d_weights', 'scope', 'device', 'det_model',
        'det_weights', 'det_cat_ids', 'pose3d', 'pose3d_weights',
        'show_progress'
    ]
    init_args = {init_kw: call_args.pop(init_kw) for init_kw in init_kws}

    display_alias = call_args.pop('show_alias')

    return init_args, call_args, display_alias
def display_model_aliases(model_aliases: Dict[str, str]) -> None:
    """Print a two-column table of model aliases and their model names.

    A header row (``ALIAS`` / ``MODEL_NAME``) is printed first, followed by
    one row per alias in sorted order. An empty mapping prints only the
    header.

    Args:
        model_aliases (Dict[str, str]): Mapping from alias to model name.
    """
    header = 'ALIAS'
    # The header takes part in the width computation so that the columns
    # stay aligned when every alias is shorter than the header, and so that
    # an empty mapping does not make ``max`` raise ``ValueError``.
    column_width = max(map(len, [header, *model_aliases])) + 2

    print(f'{header.ljust(column_width)}MODEL_NAME')
    for alias in sorted(model_aliases):
        print(f'{alias.ljust(column_width)}{model_aliases[alias]}')
def main():
    """Script entry point.

    Parses the command line, then either lists the model aliases registered
    for the selected scope (when ``--show-alias`` is given) or builds an
    ``MMPoseInferencer`` from the init arguments and runs it with the call
    arguments.
    """
    init_args, call_args, display_alias = parse_args()

    if not display_alias:
        inferencer = MMPoseInferencer(**init_args)
        # The call result is iterated to the end and each item is discarded;
        # presumably the work is performed lazily during iteration.
        for _ in inferencer(**call_args):
            pass
        return

    display_model_aliases(get_model_aliases(init_args['scope']))


if __name__ == '__main__':
    main()