File size: 11,423 Bytes
80914e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
# Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path
import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class PoseValidator(DetectionValidator):
    """Validator for YOLO pose models.

    Extends ``DetectionValidator`` so that, in addition to bounding boxes, the
    predicted keypoints are matched against ground truth (via ``kpt_iou``),
    plotted, and exported/evaluated in COCO keypoints JSON format.
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.args.task = 'pose'
        self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
        if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
            LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                           'See https://github.com/ultralytics/ultralytics/issues/4031.')

    def preprocess(self, batch):
        """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
        batch = super().preprocess(batch)
        batch['keypoints'] = batch['keypoints'].to(self.device).float()
        return batch

    def get_desc(self):
        """Returns description of evaluation metrics in string format."""
        return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
                                         'R', 'mAP50', 'mAP50-95)')

    def postprocess(self, preds):
        """Apply non-maximum suppression and return detections with high confidence scores."""
        return ops.non_max_suppression(preds,
                                       self.args.conf,
                                       self.args.iou,
                                       labels=self.lb,
                                       multi_label=True,
                                       agnostic=self.args.single_cls,
                                       max_det=self.args.max_det,
                                       nc=self.nc)

    def init_metrics(self, model):
        """Initiate pose estimation metrics for YOLO model."""
        super().init_metrics(model)
        self.kpt_shape = self.data['kpt_shape']
        is_pose = self.kpt_shape == [17, 3]
        nkpt = self.kpt_shape[0]
        # Use the COCO per-keypoint sigmas for the 17x3 layout, uniform weights otherwise.
        self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt

    def update_metrics(self, preds, batch):
        """Accumulate per-image box and keypoint match statistics for one batch.

        For every image the predictions and labels are mapped back to the
        original image space, matched with ``_process_batch`` (once on boxes,
        once on keypoints) and the result is appended to ``self.stats``.
        """
        for si, pred in enumerate(preds):
            idx = batch['batch_idx'] == si
            cls = batch['cls'][idx]
            bbox = batch['bboxes'][idx]
            kpts = batch['keypoints'][idx]
            nl, npr = cls.shape[0], pred.shape[0]  # number of labels, predictions
            nk = kpts.shape[1]  # number of keypoints
            shape = batch['ori_shape'][si]
            correct_kpts = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
            correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
            self.seen += 1

            if npr == 0:
                if nl:
                    self.stats.append((correct_bboxes, correct_kpts, *torch.zeros(
                        (2, 0), device=self.device), cls.squeeze(-1)))
                    if self.args.plots:
                        self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
                # Nothing to scale or match for this image; the reshape below
                # cannot infer a dimension from an empty prediction tensor.
                continue

            # Predictions
            if self.args.single_cls:
                pred[:, 5] = 0
            predn = pred.clone()
            ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
                            ratio_pad=batch['ratio_pad'][si])  # native-space pred
            pred_kpts = predn[:, 6:].view(npr, nk, -1)
            ops.scale_coords(batch['img'][si].shape[1:], pred_kpts, shape, ratio_pad=batch['ratio_pad'][si])

            # Evaluate
            if nl:
                height, width = batch['img'].shape[2:]
                tbox = ops.xywh2xyxy(bbox) * torch.tensor(
                    (width, height, width, height), device=self.device)  # target boxes
                ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
                                ratio_pad=batch['ratio_pad'][si])  # native-space labels
                tkpts = kpts.clone()
                tkpts[..., 0] *= width
                tkpts[..., 1] *= height
                tkpts = ops.scale_coords(batch['img'][si].shape[1:], tkpts, shape, ratio_pad=batch['ratio_pad'][si])
                labelsn = torch.cat((cls, tbox), 1)  # native-space labels
                correct_bboxes = self._process_batch(predn[:, :6], labelsn)
                correct_kpts = self._process_batch(predn[:, :6], labelsn, pred_kpts, tkpts)
                if self.args.plots:
                    self.confusion_matrix.process_batch(predn, labelsn)

            # Append correct_bboxes, correct_kpts, pconf, pcls, tcls
            self.stats.append((correct_bboxes, correct_kpts, pred[:, 4], pred[:, 5], cls.squeeze(-1)))

            # Save
            if self.args.save_json:
                self.pred_to_json(predn, batch['im_file'][si])

    def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None):
        """
        Return correct prediction matrix.

        Arguments:
            detections (array[N, 6]), x1, y1, x2, y2, conf, class
            labels (array[M, 5]), class, x1, y1, x2, y2
            pred_kpts (array[N, 51]), 51 = 17 * 3
            gt_kpts (array[N, 51])

        Returns:
            correct (array[N, 10]), for 10 IoU levels

        When both keypoint arguments are given the similarity is computed with
        ``kpt_iou``; otherwise plain box IoU is used.
        """
        if pred_kpts is not None and gt_kpts is not None:
            # `0.53` is from xtcocotools' cocoeval.py (https://github.com/jin-s13/xtcocoapi)
            area = ops.xyxy2xywh(labels[:, 1:])[:, 2:].prod(1) * 0.53
            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
        else:  # boxes
            iou = box_iou(labels[:, 1:], detections[:, :4])

        correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
        correct_class = labels[:, 0:1] == detections[:, 5]
        for i, threshold in enumerate(self.iouv):
            x = torch.where((iou >= threshold) & correct_class)  # IoU > threshold and classes match
            if x[0].shape[0]:
                matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
                                    1).cpu().numpy()  # [label, detect, iou]
                if x[0].shape[0] > 1:
                    # Sort by IoU (descending), then keep one match per detection and per label.
                    matches = matches[matches[:, 2].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=detections.device)

    def plot_val_samples(self, batch, ni):
        """Plots and saves validation set samples with predicted bounding boxes and keypoints."""
        plot_images(batch['img'],
                    batch['batch_idx'],
                    batch['cls'].squeeze(-1),
                    batch['bboxes'],
                    kpts=batch['keypoints'],
                    paths=batch['im_file'],
                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
                    names=self.names,
                    on_plot=self.on_plot)

    def plot_predictions(self, batch, preds, ni):
        """Plots predictions for YOLO model."""
        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
        plot_images(batch['img'],
                    *output_to_target(preds, max_det=self.args.max_det),
                    kpts=pred_kpts,
                    paths=batch['im_file'],
                    fname=self.save_dir / f'val_batch{ni}_pred.jpg',
                    names=self.names,
                    on_plot=self.on_plot)  # pred

    def pred_to_json(self, predn, filename):
        """Converts YOLO predictions to COCO JSON format."""
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = ops.xyxy2xywh(predn[:, :4])  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
        for p, b in zip(predn.tolist(), box.tolist()):
            self.jdict.append({
                'image_id': image_id,
                'category_id': self.class_map[int(p[5])],
                'bbox': [round(x, 3) for x in b],
                'keypoints': p[6:],
                'score': round(p[4], 5)})

    def eval_json(self, stats):
        """Evaluates object detection model using COCO JSON format.

        Runs pycocotools on the saved predictions for both 'bbox' and
        'keypoints' and overwrites the matching mAP50-95 / mAP50 entries in
        ``stats``. Any failure is logged as a warning and ``stats`` is returned
        unchanged.
        """
        if self.args.save_json and self.is_coco and len(self.jdict):
            anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json'  # annotations
            pred_json = self.save_dir / 'predictions.json'  # predictions
            LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                check_requirements('pycocotools>=2.0.6')
                from pycocotools.coco import COCO  # noqa
                from pycocotools.cocoeval import COCOeval  # noqa

                for x in anno_json, pred_json:
                    assert x.is_file(), f'{x} file not found'
                anno = COCO(str(anno_json))  # init annotations api
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
                evaluators = [COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]
                for i, coco_eval in enumerate(evaluators):
                    # `self.is_coco` is already guaranteed by the enclosing condition.
                    coco_eval.params.imgIds = [int(Path(x).stem)
                                               for x in self.dataloader.dataset.im_files]  # im to eval
                    # `stats` is only populated after these three steps have run.
                    coco_eval.evaluate()
                    coco_eval.accumulate()
                    coco_eval.summarize()
                    idx = i * 4 + 2
                    stats[self.metrics.keys[idx + 1]], stats[
                        self.metrics.keys[idx]] = coco_eval.stats[:2]  # update mAP50-95 and mAP50
            except Exception as e:
                LOGGER.warning(f'pycocotools unable to run: {e}')
        return stats
def val(cfg=DEFAULT_CFG, use_python=False):
    """Performs validation on YOLO model using given data.

    Args:
        cfg: Configuration object providing ``model`` and ``data`` attributes;
            falsy values fall back to the pose defaults below.
        use_python: When True, validate through the high-level ``YOLO`` Python
            API; otherwise build a ``PoseValidator`` and call it directly.
    """
    model = cfg.model or 'yolov8n-pose.pt'
    data = cfg.data or 'coco8-pose.yaml'

    args = dict(model=model, data=data)
    if use_python:
        from ultralytics import YOLO
        YOLO(model).val(**args)
    else:
        validator = PoseValidator(args=args)
        validator(model=args['model'])
if __name__ == '__main__':
    # Script entry point: run validation with the default configuration.
    val()