import copy
import numpy as np
from contextlib import contextmanager
from itertools import count
from typing import List
import torch
from fvcore.transforms import HFlipTransform, NoOpTransform
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from detectron2.config import configurable
from detectron2.data.detection_utils import read_image
from detectron2.data.transforms import (
    RandomFlip,
    ResizeShortestEdge,
    ResizeTransform,
    apply_augmentations,
)
from detectron2.structures import Boxes, Instances

from .meta_arch import GeneralizedRCNN
from .postprocessing import detector_postprocess
from .roi_heads.fast_rcnn import fast_rcnn_inference_single_image

__all__ = ["DatasetMapperTTA", "GeneralizedRCNNWithTTA"]


class DatasetMapperTTA:
    """
    Implement test-time augmentation for detection data.
    It is a callable which takes a dataset dict from a detection dataset
    and returns a list of dataset dicts whose images are augmented from the
    input image by the transformations defined in the config.
    """

    @configurable
    def __init__(self, min_sizes: List[int], max_size: int, flip: bool):
        """
        Args:
            min_sizes: list of short-edge sizes to resize the image to
            max_size: maximum height or width of resized images
            flip: whether to apply flipping augmentation
        """
        self.min_sizes = min_sizes
        self.max_size = max_size
        self.flip = flip

    @classmethod
    def from_config(cls, cfg):
        return {
            "min_sizes": cfg.TEST.AUG.MIN_SIZES,
            "max_size": cfg.TEST.AUG.MAX_SIZE,
            "flip": cfg.TEST.AUG.FLIP,
        }

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict: a dict in standard model input format. See tutorials for details.

        Returns:
            list[dict]:
                a list of dicts, which contain augmented versions of the input image.
                The total number of dicts is ``len(min_sizes) * (2 if flip else 1)``.
                Each dict has field "transforms" which is a TransformList,
                containing the transforms that are used to generate this image.
        """
        numpy_image = dataset_dict["image"].permute(1, 2, 0).numpy()
        shape = numpy_image.shape
        orig_shape = (dataset_dict["height"], dataset_dict["width"])
        if shape[:2] != orig_shape:
            # The input image is not in its original shape; record the resize
            # so it can be inverted together with the TTA transforms.
            pre_tfm = ResizeTransform(orig_shape[0], orig_shape[1], shape[0], shape[1])
        else:
            pre_tfm = NoOpTransform()

        # Create all combinations of augmentations to use.
        aug_candidates = []  # each element is a list[Augmentation]
        for min_size in self.min_sizes:
            resize = ResizeShortestEdge(min_size, self.max_size)
            aug_candidates.append([resize])  # resize only
            if self.flip:
                flip = RandomFlip(prob=1.0)
                aug_candidates.append([resize, flip])  # resize + flip

        # Apply all the augmentations.
        ret = []
        for aug in aug_candidates:
            new_image, tfms = apply_augmentations(aug, np.copy(numpy_image))
            torch_image = torch.from_numpy(np.ascontiguousarray(new_image.transpose(2, 0, 1)))

            dic = copy.deepcopy(dataset_dict)
            dic["transforms"] = pre_tfm + tfms
            dic["image"] = torch_image
            ret.append(dic)
        return ret
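

# A minimal usage sketch for DatasetMapperTTA (an illustration, not part of the
# module; assumes a standard detectron2 `cfg` and a `dataset_dict` carrying an
# "image" CHW tensor plus "height"/"width" fields):
#
#     mapper = DatasetMapperTTA(cfg)
#     augmented = mapper(dataset_dict)
#     assert len(augmented) == len(cfg.TEST.AUG.MIN_SIZES) * (2 if cfg.TEST.AUG.FLIP else 1)
#     # each dict carries a "transforms" TransformList, used later to invert predictions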


class GeneralizedRCNNWithTTA(nn.Module):
    """
    A GeneralizedRCNN with test-time augmentation enabled.
    Its :meth:`__call__` method has the same interface as :meth:`GeneralizedRCNN.forward`.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=3):
        """
        Args:
            cfg (CfgNode):
            model (GeneralizedRCNN): a GeneralizedRCNN to apply TTA on.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size for inference.
        """
        super().__init__()
        if isinstance(model, DistributedDataParallel):
            model = model.module
        assert isinstance(
            model, GeneralizedRCNN
        ), "TTA is only supported on GeneralizedRCNN. Got a model of type {}".format(type(model))
        self.cfg = cfg.clone()
        assert not self.cfg.MODEL.KEYPOINT_ON, "TTA for keypoint is not supported yet"
        assert (
            not self.cfg.MODEL.LOAD_PROPOSALS
        ), "TTA for pre-computed proposals is not supported yet"

        self.model = model

        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size

    @contextmanager
    def _turn_off_roi_heads(self, attrs):
        """
        Open a context where some heads in `model.roi_heads` are temporarily turned off.
        Args:
            attrs (list[str]): the attributes in `model.roi_heads` which can be used
                to turn off a specific head, e.g., "mask_on", "keypoint_on".
        """
        roi_heads = self.model.roi_heads
        old = {}
        for attr in attrs:
            try:
                old[attr] = getattr(roi_heads, attr)
            except AttributeError:
                # The attribute may not exist on some heads.
                pass

        if len(old) == 0:
            yield
        else:
            for attr in old.keys():
                setattr(roi_heads, attr, False)
            yield
            for attr in old.keys():
                setattr(roi_heads, attr, old[attr])

    def _batch_inference(self, batched_inputs, detected_instances=None):
        """
        Execute inference on a list of inputs,
        using batch size = self.batch_size, instead of the length of the list.

        Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference`
        """
        if detected_instances is None:
            detected_instances = [None] * len(batched_inputs)

        outputs = []
        inputs, instances = [], []
        for idx, input, instance in zip(count(), batched_inputs, detected_instances):
            inputs.append(input)
            instances.append(instance)
            # Run the model once a full chunk is accumulated, or at the last input.
            if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
                outputs.extend(
                    self.model.inference(
                        inputs,
                        instances if instances[0] is not None else None,
                        do_postprocess=False,
                    )
                )
                inputs, instances = [], []
        return outputs
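
    # For example, with batch_size=3 and 7 augmented inputs, _batch_inference
    # runs the model on chunks of 3, 3, and 1 inputs and concatenates the
    # outputs (an illustrative note on the chunking above).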

    def __call__(self, batched_inputs):
        """
        Same input/output format as :meth:`GeneralizedRCNN.forward`
        """

        def _maybe_read_image(dataset_dict):
            ret = copy.copy(dataset_dict)
            if "image" not in ret:
                image = read_image(ret.pop("file_name"), self.model.input_format)
                image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))  # CHW
                ret["image"] = image
            if "height" not in ret and "width" not in ret:
                # Use ret["image"] so this also works when the image was provided
                # by the caller rather than read from "file_name" above.
                ret["height"] = ret["image"].shape[1]
                ret["width"] = ret["image"].shape[2]
            return ret

        return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]

    def _inference_one_image(self, input):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor

        Returns:
            dict: one output dict
        """
        orig_shape = (input["height"], input["width"])
        augmented_inputs, tfms = self._get_augmented_inputs(input)

        # Detect boxes from all augmented versions, with mask/keypoint heads
        # temporarily turned off.
        with self._turn_off_roi_heads(["mask_on", "keypoint_on"]):
            all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms)
        # Merge all detected boxes to obtain final predictions for boxes.
        merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape)

        if self.cfg.MODEL.MASK_ON:
            # Use the detected boxes to obtain masks.
            augmented_instances = self._rescale_detected_boxes(
                augmented_inputs, merged_instances, tfms
            )
            # Run forward on the detected boxes.
            outputs = self._batch_inference(augmented_inputs, augmented_instances)
            # Delete now-useless variables to avoid running out of memory.
            del augmented_inputs, augmented_instances
            # Average the mask predictions.
            merged_instances.pred_masks = self._reduce_pred_masks(outputs, tfms)
            merged_instances = detector_postprocess(merged_instances, *orig_shape)
            return {"instances": merged_instances}
        else:
            return {"instances": merged_instances}

    def _get_augmented_inputs(self, input):
        augmented_inputs = self.tta_mapper(input)
        tfms = [x.pop("transforms") for x in augmented_inputs]
        return augmented_inputs, tfms

    def _get_augmented_boxes(self, augmented_inputs, tfms):
        # 1: forward with all augmented images
        outputs = self._batch_inference(augmented_inputs)
        # 2: union the results
        all_boxes = []
        all_scores = []
        all_classes = []
        for output, tfm in zip(outputs, tfms):
            # Need to invert the transforms on boxes, to obtain results on the original image.
            pred_boxes = output.pred_boxes.tensor
            original_pred_boxes = tfm.inverse().apply_box(pred_boxes.cpu().numpy())
            all_boxes.append(torch.from_numpy(original_pred_boxes).to(pred_boxes.device))

            all_scores.extend(output.scores)
            all_classes.extend(output.pred_classes)
        all_boxes = torch.cat(all_boxes, dim=0)
        return all_boxes, all_scores, all_classes

    def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw):
        # Select from the union of all results: scatter the per-box scores into
        # an R x (K+1) matrix so standard Fast R-CNN inference can be reused.
        num_boxes = len(all_boxes)
        num_classes = self.cfg.MODEL.ROI_HEADS.NUM_CLASSES
        # +1 because fast_rcnn_inference expects background scores as well.
        all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device)
        for idx, cls, score in zip(count(), all_classes, all_scores):
            all_scores_2d[idx, cls] = score

        # A tiny score threshold (1e-8) keeps essentially all boxes; NMS and the
        # per-image detection limit do the actual merging.
        merged_instances, _ = fast_rcnn_inference_single_image(
            all_boxes,
            all_scores_2d,
            shape_hw,
            1e-8,
            self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST,
            self.cfg.TEST.DETECTIONS_PER_IMAGE,
        )

        return merged_instances

    def _rescale_detected_boxes(self, augmented_inputs, merged_instances, tfms):
        augmented_instances = []
        for input, tfm in zip(augmented_inputs, tfms):
            # Transform the merged boxes into the coordinate space of each augmented image.
            pred_boxes = merged_instances.pred_boxes.tensor.cpu().numpy()
            pred_boxes = torch.from_numpy(tfm.apply_box(pred_boxes))

            aug_instances = Instances(
                image_size=input["image"].shape[1:3],
                pred_boxes=Boxes(pred_boxes),
                pred_classes=merged_instances.pred_classes,
                scores=merged_instances.scores,
            )
            augmented_instances.append(aug_instances)
        return augmented_instances

    def _reduce_pred_masks(self, outputs, tfms):
        # Should apply inverse transforms on masks. We assume only resize & flip
        # are used: pred_masks is a scale-invariant representation, so only the
        # flip needs to be undone before averaging.
        for output, tfm in zip(outputs, tfms):
            if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
                output.pred_masks = output.pred_masks.flip(dims=[3])
        all_pred_masks = torch.stack([o.pred_masks for o in outputs], dim=0)
        avg_pred_masks = torch.mean(all_pred_masks, dim=0)
        return avg_pred_masks