File size: 8,241 Bytes
3bbb319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 |
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings
from argparse import ArgumentParser
import cv2
import mmcv
from mmpose.apis import (collect_multi_frames, get_track_id,
inference_top_down_pose_model, init_pose_model,
process_mmdet_results, vis_pose_tracking_result)
from mmpose.core import Smoother
from mmpose.datasets import DatasetInfo
# mmdet is an optional dependency: record its availability here so main()
# can fail with a clear installation hint instead of a raw ImportError.
try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False
def main():
    """Visualize pose tracking results on a video.

    Pipeline per frame: detect humans with an mmdet detector, estimate
    keypoints with a top-down mmpose model, assign track ids across
    frames, optionally smooth the keypoints temporally, then show the
    visualization in a window and/or write it to an output video.
    """
    parser = ArgumentParser()
    parser.add_argument('det_config', help='Config file for detection')
    parser.add_argument('det_checkpoint', help='Checkpoint file for detection')
    parser.add_argument('pose_config', help='Config file for pose')
    parser.add_argument('pose_checkpoint', help='Checkpoint file for pose')
    parser.add_argument('--video-path', type=str, help='Video path')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show visualizations.')
    parser.add_argument(
        '--out-video-root',
        default='',
        help='Root of the output video file. '
        'Default not saving the visualization video.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--det-cat-id',
        type=int,
        default=1,
        help='Category id for bounding box detection model')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=0.3,
        help='Bounding box score threshold')
    parser.add_argument(
        '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
    parser.add_argument(
        '--use-oks-tracking', action='store_true', help='Using OKS tracking')
    parser.add_argument(
        '--tracking-thr', type=float, default=0.3, help='Tracking threshold')
    parser.add_argument(
        '--euro',
        action='store_true',
        help='(Deprecated, please use --smooth and --smooth-filter-cfg) '
        'Using One_Euro_Filter for smoothing.')
    parser.add_argument(
        '--smooth',
        action='store_true',
        help='Apply a temporal filter to smooth the pose estimation results. '
        'See also --smooth-filter-cfg.')
    parser.add_argument(
        '--smooth-filter-cfg',
        type=str,
        default='configs/_base_/filters/one_euro.py',
        help='Config file of the filter to smooth the pose estimation '
        'results. See also --smooth.')
    parser.add_argument(
        '--radius',
        type=int,
        default=4,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')
    parser.add_argument(
        '--use-multi-frames',
        action='store_true',
        default=False,
        help='whether to use multi frames for inference in the pose'
        'estimation stage. Default: False.')
    parser.add_argument(
        '--online',
        action='store_true',
        default=False,
        help='inference mode. If set to True, can not use future frame'
        'information when using multi frames for inference in the pose'
        'estimation stage. Default: False.')

    # The detector is mandatory for this demo; fail early with a clear hint.
    assert has_mmdet, 'Please install mmdet to run the demo.'

    args = parser.parse_args()

    # At least one output sink (window or file) must be requested.
    assert args.show or (args.out_video_root != '')
    assert args.det_config is not None
    assert args.det_checkpoint is not None

    print('Initializing model...')
    # build the detection model from a config file and a checkpoint file
    det_model = init_detector(
        args.det_config, args.det_checkpoint, device=args.device.lower())
    # build the pose model from a config file and a checkpoint file
    pose_model = init_pose_model(
        args.pose_config, args.pose_checkpoint, device=args.device.lower())

    dataset = pose_model.cfg.data['test']['type']
    dataset_info = pose_model.cfg.data['test'].get('dataset_info', None)
    if dataset_info is None:
        warnings.warn(
            'Please set `dataset_info` in the config.'
            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
            DeprecationWarning)
    else:
        dataset_info = DatasetInfo(dataset_info)

    # read video
    video = mmcv.VideoReader(args.video_path)
    assert video.opened, f'Failed to load video file {args.video_path}'

    if args.out_video_root == '':
        save_out_video = False
    else:
        os.makedirs(args.out_video_root, exist_ok=True)
        save_out_video = True

    if save_out_video:
        # Mirror the input video's frame rate and resolution in the output.
        fps = video.fps
        size = (video.width, video.height)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(
            os.path.join(args.out_video_root,
                         f'vis_{os.path.basename(args.video_path)}'), fourcc,
            fps, size)

    # frame index offsets for inference, used in multi-frame inference setting
    if args.use_multi_frames:
        assert 'frame_indices_test' in pose_model.cfg.data.test.data_cfg
        indices = pose_model.cfg.data.test.data_cfg['frame_indices_test']

    # build pose smoother for temporal refinement
    if args.euro:
        warnings.warn(
            'Argument --euro will be deprecated in the future. '
            'Please use --smooth to enable temporal smoothing, and '
            '--smooth-filter-cfg to set the filter config.',
            DeprecationWarning)
        smoother = Smoother(
            filter_cfg='configs/_base_/filters/one_euro.py', keypoint_dim=2)
    elif args.smooth:
        smoother = Smoother(filter_cfg=args.smooth_filter_cfg, keypoint_dim=2)
    else:
        smoother = None

    # whether to return heatmap, optional
    return_heatmap = False

    # return the output of some desired layers,
    # e.g. use ('backbone', ) to return backbone feature
    output_layer_names = None

    # id to assign to the next newly-appearing person instance
    next_id = 0
    pose_results = []
    print('Running inference...')
    for frame_id, cur_frame in enumerate(mmcv.track_iter_progress(video)):
        # keep the previous frame's results so the tracker can match ids
        pose_results_last = pose_results

        # get the detection results of current frame
        # the resulting box is (x1, y1, x2, y2)
        mmdet_results = inference_detector(det_model, cur_frame)

        # keep the person class bounding boxes.
        person_results = process_mmdet_results(mmdet_results, args.det_cat_id)

        if args.use_multi_frames:
            frames = collect_multi_frames(video, frame_id, indices,
                                          args.online)

        # test a single image, with a list of bboxes.
        pose_results, _ = inference_top_down_pose_model(
            pose_model,
            frames if args.use_multi_frames else cur_frame,
            person_results,
            bbox_thr=args.bbox_thr,
            format='xyxy',
            dataset=dataset,
            dataset_info=dataset_info,
            return_heatmap=return_heatmap,
            outputs=output_layer_names)

        # get track id for each person instance
        pose_results, next_id = get_track_id(
            pose_results,
            pose_results_last,
            next_id,
            use_oks=args.use_oks_tracking,
            tracking_thr=args.tracking_thr)

        # post-process the pose results with smoother
        if smoother:
            pose_results = smoother.smooth(pose_results)

        # draw the tracking results on the current frame
        vis_frame = vis_pose_tracking_result(
            pose_model,
            cur_frame,
            pose_results,
            radius=args.radius,
            thickness=args.thickness,
            dataset=dataset,
            dataset_info=dataset_info,
            kpt_score_thr=args.kpt_thr,
            show=False)

        if args.show:
            cv2.imshow('Frame', vis_frame)

        if save_out_video:
            video_writer.write(vis_frame)

        # allow the user to quit early with 'q' when the window is shown
        if args.show and cv2.waitKey(1) & 0xFF == ord('q'):
            break

    if save_out_video:
        video_writer.release()
    if args.show:
        cv2.destroyAllWindows()
# Script entry point: run the demo only when executed directly.
if __name__ == '__main__':
    main()
|