|
import numpy as np |
|
|
|
from ..builder import PIPELINES |
|
|
|
|
|
@PIPELINES.register_module()
class InstaBoost(object):
    r"""Data augmentation method in `InstaBoost: Boosting Instance
    Segmentation Via Probability Map Guided Copy-Pasting
    <https://arxiv.org/abs/1908.07801>`_.

    Refer to https://github.com/GothicAi/Instaboost for implementation details.

    With probability ``aug_ratio`` a call passes the sample's annotations and
    image through ``instaboostfast.get_new_data``. Otherwise the image is left
    as-is. In both cases the annotations are re-parsed, which drops any box
    whose width or height is not positive.

    Args:
        action_candidate (tuple[str]): Candidate actions, forwarded to
            ``instaboostfast.InstaBoostConfig``.
        action_prob (tuple): Probability of each entry of
            ``action_candidate``, forwarded to the config.
        scale (tuple[float]): Forwarded to the config (presumably the
            instance rescaling range — confirm against instaboostfast).
        dx (int): Forwarded to the config (presumably a horizontal shift
            bound).
        dy (int): Forwarded to the config (presumably a vertical shift
            bound).
        theta (tuple): Forwarded to the config (presumably a rotation range).
        color_prob (float): Forwarded to the config.
        hflag (bool): Forwarded to the config.
        aug_ratio (float): Probability in ``[0, 1]`` that InstaBoost is
            applied to a given sample.

    Raises:
        ValueError: If ``aug_ratio`` is outside ``[0, 1]``.
        ImportError: If ``instaboostfast`` is not installed.
    """

    def __init__(self,
                 action_candidate=('normal', 'horizontal', 'skip'),
                 action_prob=(1, 0, 0),
                 scale=(0.8, 1.2),
                 dx=15,
                 dy=15,
                 theta=(-1, 1),
                 color_prob=0.5,
                 hflag=False,
                 aug_ratio=0.5):
        # Fail at construction time instead of on the first __call__, where
        # np.random.choice would reject the probability vector.
        if not 0 <= aug_ratio <= 1:
            raise ValueError(
                f'aug_ratio must be in [0, 1], got {aug_ratio}')
        try:
            import instaboostfast as instaboost
        except ImportError as err:
            raise ImportError(
                'Please run "pip install instaboostfast" '
                'to install instaboostfast first for instaboost augmentation.'
            ) from err
        self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
                                               scale, dx, dy, theta,
                                               color_prob, hflag)
        self.aug_ratio = aug_ratio

    def _load_anns(self, results):
        """Convert ``results['ann_info']`` into a list of annotation dicts.

        Each dict uses COCO-style keys (``category_id``, ``segmentation``,
        ``bbox``), and each bbox is converted from ``[x1, y1, x2, y2]`` to
        ``[x, y, w, h]`` before being handed to instaboostfast.

        Args:
            results (dict): Pipeline results. ``ann_info`` must hold
                index-aligned ``labels``, ``bboxes`` and ``masks``.

        Returns:
            list[dict]: One dict per label.
        """
        ann_info = results['ann_info']
        labels = ann_info['labels']
        masks = ann_info['masks']
        bboxes = ann_info['bboxes']

        anns = []
        # Indexing (rather than zip) keeps the original failure mode: a
        # shorter masks/bboxes sequence raises instead of silently truncating.
        for i, label in enumerate(labels):
            x1, y1, x2, y2 = bboxes[i]
            anns.append({
                'category_id': label,
                # Mask format is passed through untouched; presumably polygons
                # or RLE as produced by the loader — confirm.
                'segmentation': masks[i],
                'bbox': [x1, y1, x2 - x1, y2 - y1],
            })
        return anns

    def _parse_anns(self, results, anns, img):
        """Write annotation dicts and an image back into ``results``.

        Boxes are converted from ``[x, y, w, h]`` back to ``[x1, y1, x2, y2]``.
        Annotations whose width or height is not positive are dropped together
        with their label and mask.

        Args:
            results (dict): Pipeline results to update in place.
            anns (list[dict]): Annotations with ``bbox``, ``category_id`` and
                ``segmentation`` keys.
            img (np.ndarray): Image to store under ``results['img']``.

        Returns:
            dict: The updated ``results``. ``ann_info['bboxes']`` always has
            shape ``(N, 4)``, including ``N == 0``.
        """
        gt_bboxes = []
        gt_labels = []
        gt_masks_ann = []
        for ann in anns:
            x1, y1, w, h = ann['bbox']
            # Degenerate boxes (possibly produced by the augmentation) are
            # discarded.
            if w <= 0 or h <= 0:
                continue
            gt_bboxes.append([x1, y1, x1 + w, y1 + h])
            gt_labels.append(ann['category_id'])
            gt_masks_ann.append(ann['segmentation'])
        # reshape(-1, 4): an empty list would otherwise yield shape (0,),
        # not (0, 4).
        gt_bboxes = np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4)
        gt_labels = np.array(gt_labels, dtype=np.int64)

        ann_info = results['ann_info']
        ann_info['labels'] = gt_labels
        ann_info['bboxes'] = gt_bboxes
        ann_info['masks'] = gt_masks_ann
        results['img'] = img
        return results

    def __call__(self, results):
        """Apply InstaBoost to ``results`` with probability ``aug_ratio``.

        Args:
            results (dict): Pipeline results holding ``img`` and ``ann_info``.

        Returns:
            dict: ``results`` with the (possibly augmented) image, labels,
            bboxes and masks. The image keeps its original dtype.
        """
        img = results['img']
        orig_type = img.dtype
        anns = self._load_anns(results)
        if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
            # Imported locally instead of storing the module on ``self``:
            # module objects cannot be pickled, and the transform may be
            # shipped to data-loader worker processes.
            try:
                import instaboostfast as instaboost
            except ImportError as err:
                raise ImportError('Please run "pip install instaboostfast" '
                                  'to install instaboostfast first.') from err
            # The image is handed over as uint8; the original dtype is
            # restored below.
            anns, img = instaboost.get_new_data(
                anns, img.astype(np.uint8), self.cfg, background=None)

        # copy=False avoids a needless full-image copy when the dtype already
        # matches (e.g. no augmentation was applied).
        return self._parse_anns(results, anns,
                                img.astype(orig_type, copy=False))

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(cfg={self.cfg}, aug_ratio={self.aug_ratio})'
        return repr_str
|
|