Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Union | |
from mmengine.structures import BaseDataElement, InstanceData, PixelData | |
from mmpose.structures import MultilevelPixelData | |
class PoseDataSample(BaseDataElement):
    """The base data structure of MMPose that is used as the interface between
    modules.

    The attributes of ``PoseDataSample`` include:

        - ``gt_instances``(InstanceData): Ground truth of instances with
            keypoint annotations
        - ``gt_instance_labels``(InstanceData): Instance-level ground-truth
            labels, kept separately from ``gt_instances``
        - ``pred_instances``(InstanceData): Instances with keypoint
            predictions
        - ``gt_fields``(PixelData | MultilevelPixelData): Ground truth of
            spatial distribution annotations like keypoint heatmaps and part
            affine fields (PAF)
        - ``pred_fields``(PixelData): Predictions of spatial distributions

    Each attribute is exposed as a property backed by a private field named
    ``'_' + <attribute name>`` and assigned through ``set_field``.

    Examples:
        >>> import torch
        >>> from mmengine.structures import InstanceData, PixelData
        >>> from mmpose.structures import PoseDataSample

        >>> pose_meta = dict(img_shape=(800, 1216),
        ...                  crop_size=(256, 192),
        ...                  heatmap_size=(64, 48))
        >>> gt_instances = InstanceData()
        >>> gt_instances.bboxes = torch.rand((1, 4))
        >>> gt_instances.keypoints = torch.rand((1, 17, 2))
        >>> gt_instances.keypoints_visible = torch.rand((1, 17, 1))
        >>> gt_fields = PixelData()
        >>> gt_fields.heatmaps = torch.rand((17, 64, 48))

        >>> data_sample = PoseDataSample(gt_instances=gt_instances,
        ...                              gt_fields=gt_fields,
        ...                              metainfo=pose_meta)
        >>> assert 'img_shape' in data_sample
        >>> len(data_sample.gt_instances)
        1
    """

    # Without the @property / @<name>.setter / @<name>.deleter decorators,
    # each triple of same-named defs would simply rebind the name, leaving
    # only the deleter body as a plain method.

    @property
    def gt_instances(self) -> InstanceData:
        """Ground-truth instances with keypoint annotations."""
        return self._gt_instances

    @gt_instances.setter
    def gt_instances(self, value: InstanceData):
        self.set_field(value, '_gt_instances', dtype=InstanceData)

    @gt_instances.deleter
    def gt_instances(self):
        del self._gt_instances

    @property
    def gt_instance_labels(self) -> InstanceData:
        """Instance-level ground-truth labels."""
        return self._gt_instance_labels

    @gt_instance_labels.setter
    def gt_instance_labels(self, value: InstanceData):
        self.set_field(value, '_gt_instance_labels', dtype=InstanceData)

    @gt_instance_labels.deleter
    def gt_instance_labels(self):
        del self._gt_instance_labels

    @property
    def pred_instances(self) -> InstanceData:
        """Instances carrying keypoint predictions."""
        return self._pred_instances

    @pred_instances.setter
    def pred_instances(self, value: InstanceData):
        self.set_field(value, '_pred_instances', dtype=InstanceData)

    @pred_instances.deleter
    def pred_instances(self):
        del self._pred_instances

    @property
    def gt_fields(self) -> Union[PixelData, MultilevelPixelData]:
        """Ground-truth spatial fields (e.g. heatmaps), single- or
        multi-level."""
        return self._gt_fields

    @gt_fields.setter
    def gt_fields(self, value: Union[PixelData, MultilevelPixelData]):
        # NOTE(review): passing dtype=type(value) means any type check in
        # set_field is trivially satisfied, so non-PixelData values are
        # accepted too. Kept as-is to avoid breaking callers; confirm whether
        # dtype=(PixelData, MultilevelPixelData) should be enforced.
        self.set_field(value, '_gt_fields', dtype=type(value))

    @gt_fields.deleter
    def gt_fields(self):
        del self._gt_fields

    @property
    def pred_fields(self) -> PixelData:
        """Predicted spatial fields (e.g. heatmaps)."""
        return self._pred_fields

    @pred_fields.setter
    def pred_fields(self, value: PixelData):
        # Backing field follows the same '_' + <property name> convention as
        # the other properties (previously '_pred_heatmaps'), so key listing
        # that relies on that convention hides it like the others.
        self.set_field(value, '_pred_fields', dtype=PixelData)

    @pred_fields.deleter
    def pred_fields(self):
        del self._pred_fields