import copy
import logging
import numpy as np
from typing import List, Optional, Union
import torch
from detectron2.config import configurable
from . import detection_utils as utils
from . import transforms as T
# This file contains the default mapping that's applied to "dataset dicts".
# Public API of this module: only the default mapper class is exported.
__all__ = ["DatasetMapper"]
class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic,
    such as a different way to read or transform images.
    See :doc:`/tutorials/data_loading` for details.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
            use_keypoint: whether to process keypoint annotations if available
            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
                masks into this format.
            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
            precomputed_proposal_topk: if given, will load pre-computed
                proposals from dataset_dict and keep the top k proposals for each image.
            recompute_boxes: whether to overwrite bounding box annotations
                by computing tight bounding boxes from instance mask annotations.

        Raises:
            AssertionError: if ``recompute_boxes`` is set without ``use_instance_mask``.
        """
        if recompute_boxes:
            # Boxes are recomputed from gt_masks, so masks must be kept.
            assert use_instance_mask, "recompute_boxes requires instance masks"
        self.is_train = is_train
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format = image_format
        self.use_instance_mask = use_instance_mask
        self.instance_mask_format = instance_mask_format
        self.use_keypoint = use_keypoint
        self.keypoint_hflip_indices = keypoint_hflip_indices
        self.proposal_topk = precomputed_proposal_topk
        self.recompute_boxes = recompute_boxes
        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        """
        Build the keyword arguments for :meth:`__init__` from a config object.

        Args:
            cfg: the config node; the ``INPUT``, ``MODEL`` and ``DATASETS``
                sections are read.
            is_train: whether the mapper is built for training or inference.

        Returns:
            dict: keyword arguments accepted by :meth:`__init__`.
        """
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            # Random cropping goes first in the augmentation list.
            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
            # Boxes are recomputed from masks only when cropping is active
            # and masks are in use.
            recompute_boxes = cfg.MODEL.MASK_ON
        else:
            recompute_boxes = False

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
            "recompute_boxes": recompute_boxes,
        }

        # The flip indices are only consumed when keypoints are processed;
        # otherwise __init__'s default of None is kept.
        if cfg.MODEL.KEYPOINT_ON:
            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)

        # NOTE(review): the guard and the two top-k config keys below were
        # missing from the reviewed text and are restored from detectron2's
        # default config -- confirm against the project's config definitions.
        if cfg.MODEL.LOAD_PROPOSALS:
            ret["precomputed_proposal_topk"] = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        return ret

    def _transform_annotations(self, dataset_dict, transforms, image_shape):
        """
        Convert ``dataset_dict["annotations"]`` into ``dataset_dict["instances"]``.

        The dict is modified in place: the "annotations" entry is removed and
        an "instances" entry is added. Nothing is returned.

        Args:
            dataset_dict (dict): the per-image dict; must contain "annotations".
            transforms: the transforms returned by applying the augmentations.
            image_shape: shape of the transformed image (first two dims).
        """
        # Drop annotation fields that this mapper is configured not to use.
        for anno in dataset_dict["annotations"]:
            if not self.use_instance_mask:
                anno.pop("segmentation", None)
            if not self.use_keypoint:
                anno.pop("keypoints", None)

        # Annotations flagged as crowd (iscrowd != 0) are skipped.
        # NOTE(review): the callee name was missing from the reviewed text;
        # restored as detection_utils.transform_instance_annotations -- confirm.
        annos = [
            utils.transform_instance_annotations(
                obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]
        instances = utils.annotations_to_instances(
            annos, image_shape, mask_format=self.instance_mask_format
        )

        # Optionally overwrite the annotated boxes with boxes derived from the masks.
        if self.recompute_boxes:
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        dataset_dict["instances"] = utils.filter_empty_instances(instances)

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        # Work on a copy: the dict is modified by the code below.
        dataset_dict = copy.deepcopy(dataset_dict)
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        # Load the semantic segmentation ground truth only when a file is given.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None

        # The augmentations update aug_input in place and return the applied transforms.
        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]  # presumably (height, width) -- TODO confirm
        # Move the last axis to the front before converting to a tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # NOTE(review): the callee name was missing from the reviewed text;
        # restored as detection_utils.transform_proposals -- confirm.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )

        if not self.is_train:
            # Inference: ground-truth fields are dropped and not transformed.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            self._transform_annotations(dataset_dict, transforms, image_shape)

        return dataset_dict