# RegionSpot / regionspot/test_time_augmentation.py
# Uploaded by bklg ("Upload 114 files"), revision a153c95.
from itertools import count
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from detectron2.modeling import GeneralizedRCNNWithTTA, DatasetMapperTTA
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image
from detectron2.structures import Instances, Boxes
class RegionSpotWithTTA(GeneralizedRCNNWithTTA):
    """Test-time augmentation wrapper for a RegionSpot model.

    Reuses ``GeneralizedRCNNWithTTA`` helpers (``_get_augmented_inputs`` /
    ``_get_augmented_boxes``) and merges the augmented predictions either with
    detectron2's per-class NMS or, when ``cfg.TEST.AUG.CVPODS_TTA`` is set,
    with cvpods-style per-class box voting.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=3):
        """
        Args:
            cfg (CfgNode): config; ``TEST.AUG.CVPODS_TTA``, ``TEST.AUG.SCALE_FILTER``,
                ``TEST.AUG.SCALE_RANGES`` and ``MODEL.RegionSpot.NUM_PROPOSALS``
                are read here.
            model (RegionSpot): a RegionSpot to apply TTA on; a
                ``DistributedDataParallel`` wrapper is unwrapped.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                ``DatasetMapperTTA(cfg)``.
            batch_size (int): batch the augmented images into this batch size
                for inference (the cvpods path forwards one image at a time).
        """
        # The parent __init__ is bypassed; nn.Module.__init__ must still run
        # first, otherwise assigning self.model fails with
        # "cannot assign module before Module.__init__() call".
        nn.Module.__init__(self)
        if isinstance(model, DistributedDataParallel):
            model = model.module
        self.cfg = cfg.clone()
        self.model = model
        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size
        # cvpods-style TTA options.
        self.enable_cvpods_tta = cfg.TEST.AUG.CVPODS_TTA
        self.enable_scale_filter = cfg.TEST.AUG.SCALE_FILTER
        self.scale_ranges = cfg.TEST.AUG.SCALE_RANGES
        self.max_detection = cfg.MODEL.RegionSpot.NUM_PROPOSALS

    def _batch_inference(self, batched_inputs, detected_instances=None):
        """
        Execute inference on a list of augmented inputs.

        Without cvpods TTA, inputs are forwarded in chunks of
        ``self.batch_size``. With cvpods TTA, each input is forwarded on its
        own and, if scale filtering is enabled, its predictions are restricted
        to the box-area range assigned to that input.

        Args:
            batched_inputs (list[dict]): augmented dataset dicts.
            detected_instances (list or None): accepted for interface
                compatibility; it is not passed to the model here.

        Returns:
            list: one prediction (``do_postprocess=False``) per input, in order.
        """
        if detected_instances is None:
            detected_instances = [None] * len(batched_inputs)
        # idx // factors picks the scale range for an input; presumably the
        # mapper emits `factors` consecutive inputs (original + flip) per scale.
        factors = 2 if self.tta_mapper.flip else 1
        if self.enable_scale_filter:
            assert len(batched_inputs) == len(self.scale_ranges) * factors

        outputs = []
        pending = []
        last_idx = len(batched_inputs) - 1
        for idx, (inp, _) in enumerate(zip(batched_inputs, detected_instances)):
            if self.enable_cvpods_tta:
                # Forward this input alone. Forwarding an ever-growing list and
                # taking [0] would return the first augmentation's prediction
                # for every input.
                output = self.model.forward([inp], do_postprocess=False)[0]
                if self.enable_scale_filter:
                    output = self._filter_output_by_scale(
                        output, *self.scale_ranges[idx // factors]
                    )
                outputs.append(output)
                continue

            pending.append(inp)
            if len(pending) == self.batch_size or idx == last_idx:
                outputs.extend(self.model.forward(pending, do_postprocess=False))
                pending = []
        return outputs

    def _filter_output_by_scale(self, output, min_scale, max_scale):
        """Return a copy of ``output`` keeping only boxes whose area is within the scale range."""
        pred_boxes = output.get("pred_boxes")
        keep = self.filter_boxes(pred_boxes.tensor, min_scale, max_scale)
        return Instances(
            image_size=output.image_size,
            pred_boxes=Boxes(pred_boxes.tensor[keep]),
            pred_classes=output.pred_classes[keep],
            scores=output.scores[keep],
        )

    @staticmethod
    def filter_boxes(boxes, min_scale, max_scale):
        """
        Return a boolean mask of boxes whose area lies strictly between
        ``min_scale ** 2`` and ``max_scale ** 2``.

        Args:
            boxes (Tensor): (N, 4) boxes; columns are read as x1, y1, x2, y2.
        """
        w = boxes[:, 2] - boxes[:, 0]
        h = boxes[:, 3] - boxes[:, 1]
        area = w * h
        return (area > min_scale * min_scale) & (area < max_scale * max_scale)

    def _inference_one_image(self, input):
        """
        Run TTA on a single image and merge the augmented predictions.

        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor
                and "height"/"width" giving the original image size.

        Returns:
            dict: ``{"instances": merged_instances}`` sized to (height, width).
        """
        orig_shape = (input["height"], input["width"])
        augmented_inputs, tfms = self._get_augmented_inputs(input)
        # Detect boxes from all augmented versions.
        all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms)
        # Merge all detected boxes to obtain the final predictions.
        if self.enable_cvpods_tta:
            merged_instances = self._merge_detections_cvpods_tta(
                all_boxes, all_scores, all_classes, orig_shape
            )
        else:
            merged_instances = self._merge_detections(
                all_boxes, all_scores, all_classes, orig_shape
            )
        return {"instances": merged_instances}

    def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw):
        """
        Merge augmented predictions via ``fast_rcnn_inference_single_image``.

        Args:
            all_boxes (Tensor): boxes from all augmentations, one row per box.
            all_scores: one score per box.
            all_classes: one class index per box.
            shape_hw (tuple[int, int]): output image size.

        Returns:
            Instances: merged detections.
        """
        num_boxes = len(all_boxes)
        num_classes = self.cfg.MODEL.RegionSpot.NUM_CLASSES
        # +1 because fast_rcnn_inference expects background scores as well.
        all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device)
        # Each box carries a single score; place it in its class column.
        for idx, (cls, score) in enumerate(zip(all_classes, all_scores)):
            all_scores_2d[idx, cls] = score
        merged_instances, _ = fast_rcnn_inference_single_image(
            all_boxes,
            all_scores_2d,
            shape_hw,
            1e-8,  # near-zero score threshold
            self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST,
            self.cfg.TEST.DETECTIONS_PER_IMAGE,
        )
        return merged_instances

    def _merge_detections_cvpods_tta(self, all_boxes, all_scores, all_classes, shape_hw):
        """
        Merge augmented predictions with per-class soft box voting.

        Args:
            all_boxes (Tensor): boxes from all augmentations, one row per box.
            all_scores: sequence of per-box scores (converted with torch.tensor).
            all_classes: sequence of per-box class indices.
            shape_hw (tuple[int, int]): image size; boxes are clipped to it.

        Returns:
            Instances: at most ``self.max_detection`` detections when that is
            positive, with ``pred_boxes``, ``scores`` and integer ``pred_classes``.
        """
        all_scores = torch.tensor(all_scores).to(all_boxes.device)
        all_classes = torch.tensor(all_classes).to(all_boxes.device)
        all_boxes, all_scores, all_classes = self.merge_result_from_multi_scales(
            all_boxes,
            all_scores,
            all_classes,
            nms_type="soft_vote",
            vote_thresh=0.65,
            max_detection=self.max_detection,
        )
        all_boxes = Boxes(all_boxes)
        all_boxes.clip(shape_hw)
        result = Instances(shape_hw)
        result.pred_boxes = all_boxes
        result.scores = all_scores
        result.pred_classes = all_classes.long()
        return result

    def merge_result_from_multi_scales(
        self, boxes, scores, labels, nms_type="soft-vote", vote_thresh=0.65, max_detection=100
    ):
        """
        Vote-merge boxes per class, then keep the top ``max_detection``.

        Args:
            boxes (Tensor): (N, 4) boxes.
            scores (Tensor): (N,) scores.
            labels (Tensor): (N,) class labels.
            nms_type (str): "soft_vote" or "vote"; "soft-vote" and "softvote"
                are accepted as aliases of "soft_vote".
            vote_thresh (float): IoU at or above which boxes are merged.
            max_detection (int): cap on detections over all classes; values
                ``<= 0`` disable the cap.

        Returns:
            tuple[Tensor, Tensor, Tensor]: boxes, scores and (float) labels,
            sorted by descending score.
        """
        boxes, scores, labels = self.batched_vote_nms(
            boxes, scores, labels, nms_type, vote_thresh
        )
        # Results are sorted by score, so a prefix keeps the best detections.
        if boxes.shape[0] > max_detection > 0:
            boxes = boxes[:max_detection]
            scores = scores[:max_detection]
            labels = labels[:max_detection]
        return boxes, scores, labels

    def batched_vote_nms(self, boxes, scores, labels, vote_type, vote_thresh=0.65):
        """
        Per-class box voting using the coordinate-offset trick.

        Every box is shifted by ``label * (boxes.max() + 1)`` so that, for
        non-negative coordinates, boxes of different classes cannot overlap.
        The boxes are then voted jointly and shifted back.

        Returns:
            tuple[Tensor, Tensor, Tensor]: voted boxes, scores and (float) labels.
        """
        labels = labels.float()
        if boxes.numel() == 0:
            # boxes.max() raises on an empty tensor, and there is nothing to merge.
            return boxes, scores, labels
        max_coordinates = boxes.max() + 1
        offsets = labels.reshape(-1, 1) * max_coordinates
        boxes = boxes + offsets
        boxes, scores, labels = self.bbox_vote(boxes, scores, labels, vote_thresh, vote_type)
        boxes -= labels.reshape(-1, 1) * max_coordinates
        return boxes, scores, labels

    @staticmethod
    def _normalize_vote_type(vote_type):
        """Map accepted spellings to "soft_vote" or "vote"; raise ValueError otherwise."""
        normalized = vote_type.replace("-", "_")
        if normalized == "softvote":
            normalized = "soft_vote"
        if normalized not in ("soft_vote", "vote"):
            raise ValueError(
                f"Unknown vote_type {vote_type!r}; expected 'soft_vote' or 'vote'."
            )
        return normalized

    def bbox_vote(self, boxes, scores, labels, vote_thresh, vote_type="softvote"):
        """
        Greedy box voting.

        Repeatedly takes the highest-scoring remaining box and gathers every
        remaining box whose IoU with it is >= ``vote_thresh``. The group is
        replaced by a score-weighted average box (see ``get_dets_sum`` and
        ``get_soft_dets_sum``); groups of one box are kept unchanged.

        Args:
            boxes (Tensor): (N, 4) boxes.
            scores (Tensor): (N,) scores.
            labels (Tensor): (N,) labels.
            vote_thresh (float): IoU merge threshold.
            vote_type (str): "soft_vote" or "vote"; "soft-vote" and "softvote"
                are accepted as aliases of "soft_vote".

        Raises:
            ValueError: if ``vote_type`` is not recognised.

        Returns:
            tuple[Tensor, Tensor, Tensor]: boxes, scores and labels sorted by
            descending score.
        """
        assert boxes.shape[0] == scores.shape[0] == labels.shape[0]
        vote_type = self._normalize_vote_type(vote_type)
        # Rows are [x1, y1, x2, y2, score, label].
        det = torch.cat((boxes, scores.reshape(-1, 1), labels.reshape(-1, 1)), dim=1)
        vote_results = torch.zeros(0, 6, device=det.device)
        if det.numel() == 0:
            return vote_results[:, :4], vote_results[:, 4], vote_results[:, 5]

        order = scores.argsort(descending=True)
        det = det[order]
        while det.shape[0] > 0:
            # IoU of the current top box (row 0) against every remaining row.
            area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
            xx1 = torch.max(det[0, 0], det[:, 0])
            yy1 = torch.max(det[0, 1], det[:, 1])
            xx2 = torch.min(det[0, 2], det[:, 2])
            yy2 = torch.min(det[0, 3], det[:, 3])
            w = torch.clamp(xx2 - xx1, min=0.0)
            h = torch.clamp(yy2 - yy1, min=0.0)
            inter = w * h
            # NOTE(review): a zero-area top box yields 0/0 = NaN IoU with itself,
            # so it falls in neither mask below and is silently dropped.
            iou = inter / (area[0] + area - inter)

            # Split rows into the group merged with row 0 and the remainder.
            merge_index = torch.where(iou >= vote_thresh)[0]
            vote_det = det[merge_index, :]
            det = det[iou < vote_thresh]

            if merge_index.shape[0] <= 1:
                vote_results = torch.cat((vote_results, vote_det), dim=0)
            elif vote_type == "soft_vote":
                det_accu_sum = self.get_soft_dets_sum(vote_det, iou[merge_index])
                vote_results = torch.cat((vote_results, det_accu_sum), dim=0)
            else:
                det_accu_sum = self.get_dets_sum(vote_det)
                vote_results = torch.cat((vote_results, det_accu_sum), dim=0)

        order = vote_results[:, 4].argsort(descending=True)
        vote_results = vote_results[order, :]
        return vote_results[:, :4], vote_results[:, 4], vote_results[:, 5]

    @staticmethod
    def get_dets_sum(vote_det):
        """
        Collapse a voting group into one score-weighted average box.

        Args:
            vote_det (Tensor): (K, 6) rows ``[x1, y1, x2, y2, score, label]``.
                In ``bbox_vote``, row 0 is the group's top box. Modified in place.

        Returns:
            Tensor: a (1, 6) row holding the weighted box, the group's max
            score, and row 0's label.
        """
        vote_det[:, :4] *= vote_det[:, 4:5].repeat(1, 4)
        max_score = vote_det[:, 4].max()
        det_accu_sum = torch.zeros((1, 6), device=vote_det.device)
        det_accu_sum[:, :4] = torch.sum(vote_det[:, :4], dim=0) / torch.sum(vote_det[:, 4])
        det_accu_sum[:, 4] = max_score
        det_accu_sum[:, 5] = vote_det[0, 5]
        return det_accu_sum

    @staticmethod
    def get_soft_dets_sum(vote_det, vote_det_iou, score_thresh=0.05):
        """
        Soft variant of ``get_dets_sum``.

        Returns the merged row from ``get_dets_sum``. It is followed by each
        group member with its score multiplied by ``(1 - IoU)`` against the top
        box, keeping only members whose decayed score is >= ``score_thresh``.

        Args:
            vote_det (Tensor): (K, 6) group rows; modified in place.
            vote_det_iou (Tensor): (K,) IoU of each row with the top box.
            score_thresh (float): minimum decayed score to keep a soft copy.

        Returns:
            Tensor: (1 + M, 6) rows, where M is the number of kept soft copies.
        """
        # Take the soft copies before get_dets_sum rescales vote_det in place.
        soft_vote_det = vote_det.detach().clone()
        soft_vote_det[:, 4] *= (1 - vote_det_iou)
        soft_index = torch.where(soft_vote_det[:, 4] >= score_thresh)[0]
        soft_vote_det = soft_vote_det[soft_index, :]

        det_accu_sum = RegionSpotWithTTA.get_dets_sum(vote_det)
        if soft_vote_det.shape[0] > 0:
            det_accu_sum = torch.cat((det_accu_sum, soft_vote_det), dim=0)
        return det_accu_sum