|
import enum |
|
from copy import deepcopy |
|
|
|
import numpy as np |
|
from skimage import img_as_ubyte |
|
from skimage.transform import rescale, resize |
|
# detectron2 is an optional dependency: only SegmentationMask needs it.
# Import failures are tolerated (best effort) and recorded in
# DETECTRON_INSTALLED. ``except Exception`` rather than a bare ``except`` so
# that KeyboardInterrupt / SystemExit raised during the import still propagate.
try:
    from detectron2 import model_zoo
    from detectron2.config import get_cfg
    from detectron2.engine import DefaultPredictor

    DETECTRON_INSTALLED = True
except Exception:
    print("Detectron v2 is not installed")
    DETECTRON_INSTALLED = False
|
|
|
from .countless.countless2d import zero_corrected_countless |
|
|
|
|
|
class ObjectMask():
    """Boolean object mask stored as a tight crop plus its placement on a canvas.

    The canvas is ``height x width`` (the shape of the mask given to the
    constructor). ``self.mask`` holds only the bounding-box crop of the object,
    placed at rows ``up:down`` and columns ``left:right`` of the canvas. After
    ``shift`` / ``rescale`` those bounds may extend past the canvas until
    ``crop_to_canvas`` is applied.
    """

    def __init__(self, mask):
        self.height, self.width = mask.shape
        (self.up, self.down), (self.left, self.right) = self._get_limits(mask)
        self.mask = mask[self.up:self.down, self.left:self.right].copy()

    @staticmethod
    def _get_limits(mask):
        """Return ``((up, down), (left, right))``: the half-open bounding box of
        the nonzero entries of a 2-D ``mask``.

        NOTE: for an all-zero mask ``argmax`` is 0 from both ends, so the full
        extent ``(0, len)`` is returned along each axis.
        """
        def indicator_limits(indicator):
            # Index of the first True entry, and one past the last True entry.
            lower = indicator.argmax()
            upper = len(indicator) - indicator[::-1].argmax()
            return lower, upper

        vertical_limits = indicator_limits(mask.any(axis=1))
        horizontal_limits = indicator_limits(mask.any(axis=0))
        return vertical_limits, horizontal_limits

    def _clean(self):
        """Reset to an empty (zero-area) mask; used once it leaves the canvas."""
        self.up, self.down, self.left, self.right = 0, 0, 0, 0
        # Keep the boolean dtype of every other code path. A float empty array
        # here made ``canvas_slice & self.mask`` raise TypeError, because
        # ``bitwise_and`` has no float loop.
        self.mask = np.zeros((0, 0), dtype=bool)

    def horizontal_flip(self, inplace=False):
        """Mirror the cropped mask left-right; its placement is unchanged.

        Returns the flipped object (``self`` when ``inplace``, else a copy).
        """
        if not inplace:
            flipped = deepcopy(self)
            return flipped.horizontal_flip(inplace=True)

        self.mask = self.mask[:, ::-1]
        return self

    def vertical_flip(self, inplace=False):
        """Mirror the cropped mask top-bottom; its placement is unchanged.

        Returns the flipped object (``self`` when ``inplace``, else a copy).
        """
        if not inplace:
            flipped = deepcopy(self)
            return flipped.vertical_flip(inplace=True)

        self.mask = self.mask[::-1, :]
        return self

    def image_center(self):
        """Return ``(y, x)`` of the bounding-box center in canvas coordinates."""
        y_center = self.up + (self.down - self.up) / 2
        x_center = self.left + (self.right - self.left) / 2
        return y_center, x_center

    def rescale(self, scaling_factor, inplace=False):
        """Rescale the object about its bounding-box center.

        The crop is resampled with ``order=0``, re-binarised at 0.5 and trimmed
        to its new bounding box. It is then re-placed so that its center stays
        at the previous bounding-box center. The canvas size is not changed.
        Returns the rescaled object (``self`` when ``inplace``, else a copy).
        """
        if not inplace:
            scaled = deepcopy(self)
            return scaled.rescale(scaling_factor, inplace=True)

        # ``rescale`` here is the module-level image function, not this method.
        scaled_mask = rescale(self.mask.astype(float), scaling_factor, order=0) > 0.5
        (up, down), (left, right) = self._get_limits(scaled_mask)
        self.mask = scaled_mask[up:down, left:right]

        y_center, x_center = self.image_center()
        mask_height, mask_width = self.mask.shape
        self.up = int(round(y_center - mask_height / 2))
        self.down = self.up + mask_height
        self.left = int(round(x_center - mask_width / 2))
        self.right = self.left + mask_width
        return self

    def crop_to_canvas(self, vertical=True, horizontal=True, inplace=False):
        """Cut away the parts of the mask lying outside the canvas.

        :param vertical: crop along rows.
        :param horizontal: crop along columns.
        :param inplace: modify ``self`` instead of a copy.
        :return: the cropped object. If it lies entirely outside the canvas
            along a cropped axis, it becomes empty (see ``_clean``).
        """
        if not inplace:
            cropped = deepcopy(self)
            cropped.crop_to_canvas(vertical=vertical, horizontal=horizontal, inplace=True)
            return cropped

        if vertical:
            if self.up >= self.height or self.down <= 0:
                self._clean()
            else:
                cut_up, cut_down = max(-self.up, 0), max(self.down - self.height, 0)
                if cut_up != 0:
                    self.mask = self.mask[cut_up:]
                    self.up = 0
                if cut_down != 0:
                    self.mask = self.mask[:-cut_down]
                    self.down = self.height

        if horizontal:
            if self.left >= self.width or self.right <= 0:
                self._clean()
            else:
                cut_left, cut_right = max(-self.left, 0), max(self.right - self.width, 0)
                if cut_left != 0:
                    self.mask = self.mask[:, cut_left:]
                    self.left = 0
                if cut_right != 0:
                    self.mask = self.mask[:, :-cut_right]
                    self.right = self.width

        return self

    def restore_full_mask(self, allow_crop=False):
        """Paint the object onto a fresh ``height x width`` boolean canvas.

        :param allow_crop: if True, ``self`` is cropped to the canvas in place;
            otherwise a cropped copy is used and ``self`` is untouched.
        """
        cropped = self.crop_to_canvas(inplace=allow_crop)
        mask = np.zeros((cropped.height, cropped.width), dtype=bool)
        mask[cropped.up:cropped.down, cropped.left:cropped.right] = cropped.mask
        return mask

    def shift(self, vertical=0, horizontal=0, inplace=False):
        """Translate the placement by ``vertical`` rows and ``horizontal`` columns.

        The crop itself is unchanged and may now extend past the canvas.
        Returns the shifted object (``self`` when ``inplace``, else a copy).
        """
        if not inplace:
            shifted = deepcopy(self)
            return shifted.shift(vertical=vertical, horizontal=horizontal, inplace=True)

        self.up += vertical
        self.down += vertical
        self.left += horizontal
        self.right += horizontal
        return self

    def area(self):
        """Return the number of set pixels in the mask."""
        return self.mask.sum()
|
|
|
|
|
class RigidnessMode(enum.Enum):
    """How strictly a relocated mask must avoid foreground objects.

    ``SegmentationMask`` uses it to choose the set of foreground masks a
    candidate mask is checked against.
    """

    # Check intersection only with the object the mask was produced from.
    soft = 0

    # Check intersection with every detected foreground ("thing") object.
    rigid = 1
|
|
|
|
|
class SegmentationMask:
    """Generate object-shaped masks from a detectron2 panoptic segmentation.

    Each candidate "thing" segment is rescaled, optionally flipped and shifted
    to a new position. The position is kept only if its overlap with the
    original object, with earlier variants and with foreground objects stays
    within the configured limits. ``get_masks`` returns the resulting masks.
    """

    def __init__(self, confidence_threshold=0.5, rigidness_mode=RigidnessMode.rigid,
                 max_object_area=0.3, min_mask_area=0.02, downsample_levels=6, num_variants_per_mask=4,
                 max_mask_intersection=0.5, max_foreground_coverage=0.5, max_foreground_intersection=0.5,
                 max_hidden_area=0.2, max_scale_change=0.25, horizontal_flip=True,
                 max_vertical_shift=0.1, position_shuffle=True):
        """
        :param confidence_threshold: float; minimum confidence of the panoptic segmentator for an
            instance to be kept.
        :param rigidness_mode: RigidnessMode (or its value);
            when soft, checks intersection only with the object from which the mask_object was produced;
            when rigid, checks intersection with every foreground ("thing") object.
        :param max_object_area: float; objects covering at least this fraction of the image are not used
            as mask_object.
        :param min_mask_area: float; generated masks covering at most this fraction of the image are dropped.
        :param downsample_levels: int; defines the width of the downsampled segmentation used to search for
            shifted masks (presumably 2 ** downsample_levels).
        :param num_variants_per_mask: int; maximum number of masks produced from the same object.
        :param max_mask_intersection: float; maximum allowed area fraction of intersection between two masks
            produced by shifting the same mask_object; higher value -> more overlap allowed between variants.
        :param max_foreground_coverage: float; maximum allowed fraction of a foreground object covered by
            the mask; lower value -> objects are covered less.
        :param max_foreground_intersection: float; maximum allowed fraction of the mask lying on a foreground
            object; lower value -> the mask lies more on the background than on objects.
        :param max_hidden_area: float; upper bound on the fraction of the object hidden by shifting it
            outside the image.
        :param max_scale_change: float; the scaling factor is drawn from
            [1 - max_scale_change, 1 + max_scale_change].
        :param horizontal_flip: bool; whether horizontal flips are allowed.
        :param max_vertical_shift: float; the vertical shift is drawn from
            [-max_vertical_shift, max_vertical_shift], as a fraction of the image height.
        :param position_shuffle: bool; whether candidate horizontal positions are tried in random order
            (otherwise left to right).
        :raises ImportError: if detectron2 could not be imported.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not DETECTRON_INSTALLED:
            raise ImportError('Cannot use SegmentationMask without detectron2')
        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
        self.cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = confidence_threshold
        self.predictor = DefaultPredictor(self.cfg)

        self.rigidness_mode = RigidnessMode(rigidness_mode)
        self.max_object_area = max_object_area
        self.min_mask_area = min_mask_area
        self.downsample_levels = downsample_levels
        self.num_variants_per_mask = num_variants_per_mask
        self.max_mask_intersection = max_mask_intersection
        self.max_foreground_coverage = max_foreground_coverage
        self.max_foreground_intersection = max_foreground_intersection
        self.max_hidden_area = max_hidden_area
        self.position_shuffle = position_shuffle

        self.max_scale_change = max_scale_change
        self.horizontal_flip = horizontal_flip
        self.max_vertical_shift = max_vertical_shift

    def get_segmentation(self, img):
        """Run the panoptic predictor on ``img`` (converted to uint8).

        :return: ``(panoptic_seg, segments_info)`` as produced by the predictor.
        """
        im = img_as_ubyte(img)
        panoptic_seg, segment_info = self.predictor(im)["panoptic_seg"]
        return panoptic_seg, segment_info

    @staticmethod
    def _is_power_of_two(n):
        """Return True iff the non-negative integer ``n`` is a power of two."""
        return (n != 0) and (n & (n-1) == 0)

    def identify_candidates(self, panoptic_seg, segments_info):
        """Return ids of "thing" segments covering less than ``max_object_area`` of the image."""
        potential_mask_ids = []
        for segment in segments_info:
            if not segment["isthing"]:
                continue
            mask = (panoptic_seg == segment["id"]).int().detach().cpu().numpy()
            area = mask.sum().item() / np.prod(panoptic_seg.shape)
            if area >= self.max_object_area:
                continue
            potential_mask_ids.append(segment["id"])
        return potential_mask_ids

    def downsample_mask(self, mask):
        """Repeatedly downsample ``mask`` with ``zero_corrected_countless``.

        :raises ValueError: if a side is not a power of two, the width is below
            ``2 ** downsample_levels``, or the height is too small for the
            number of iterations.
        """
        height, width = mask.shape
        if not (self._is_power_of_two(height) and self._is_power_of_two(width)):
            raise ValueError("Image sides are not power of 2.")

        num_iterations = width.bit_length() - 1 - self.downsample_levels
        if num_iterations < 0:
            raise ValueError(f"Width is lower than 2^{self.downsample_levels}.")

        if height.bit_length() - 1 < num_iterations:
            raise ValueError("Height is too low to perform downsampling")

        downsampled = mask
        for _ in range(num_iterations):
            downsampled = zero_corrected_countless(downsampled)

        return downsampled

    def _augmentation_params(self):
        """Draw a random scaling factor, horizontal-flip flag and vertical shift."""
        scaling_factor = np.random.uniform(1 - self.max_scale_change, 1 + self.max_scale_change)
        if self.horizontal_flip:
            horizontal_flip = bool(np.random.choice(2))
        else:
            horizontal_flip = False
        vertical_shift = np.random.uniform(-self.max_vertical_shift, self.max_vertical_shift)

        return {
            "scaling_factor": scaling_factor,
            "horizontal_flip": horizontal_flip,
            "vertical_shift": vertical_shift
        }

    def _get_intersection(self, mask_array, mask_object):
        """Return the overlap of ``mask_object`` with the full-canvas ``mask_array``
        inside ``mask_object``'s bounding box."""
        intersection = mask_array[
            mask_object.up:mask_object.down, mask_object.left:mask_object.right
        ] & mask_object.mask
        return intersection

    def _check_masks_intersection(self, aug_mask, total_mask_area, prev_masks):
        """Return False if ``aug_mask`` overlaps any of ``prev_masks`` too much.

        Overlap is measured as a fraction of the existing mask and as
        ``1 - (visible area outside the overlap) / total_mask_area``. Both are
        compared to ``max_mask_intersection``.
        """
        for existing_mask in prev_masks:
            intersection_area = self._get_intersection(existing_mask, aug_mask).sum()
            intersection_existing = intersection_area / existing_mask.sum()
            intersection_current = 1 - (aug_mask.area() - intersection_area) / total_mask_area
            if (intersection_existing > self.max_mask_intersection) or \
                    (intersection_current > self.max_mask_intersection):
                return False
        return True

    def _check_foreground_intersection(self, aug_mask, foreground):
        """Return False if ``aug_mask`` covers too much of, or lies too much on,
        any foreground mask."""
        for existing_mask in foreground:
            intersection_area = self._get_intersection(existing_mask, aug_mask).sum()
            intersection_existing = intersection_area / existing_mask.sum()
            if intersection_existing > self.max_foreground_coverage:
                return False
            intersection_mask = intersection_area / aug_mask.area()
            if intersection_mask > self.max_foreground_intersection:
                return False
        return True

    def _move_mask(self, mask, foreground):
        """Search for up to ``num_variants_per_mask`` valid relocations of ``mask``.

        :param mask: boolean mask of the object on the downsampled canvas.
        :param foreground: list of boolean masks the relocated mask is checked against.
        :return: list of augmentation-parameter dicts. Shifts are stored as
            fractions of the canvas size, so they can be replayed at full
            resolution. The search stops at the first variant for which no
            valid horizontal position exists.
        """
        orig_mask = ObjectMask(mask)

        chosen_masks = []
        chosen_parameters = []

        # Raised to 1 once an upscale-limited variant vanished, so later
        # variants do not shrink the object again.
        scaling_factor_lower_bound = 0.

        for var_idx in range(self.num_variants_per_mask):
            augmentation_params = self._augmentation_params()
            # Cap the scale by the free margin around the object on each axis.
            augmentation_params["scaling_factor"] = min([
                augmentation_params["scaling_factor"],
                2 * min(orig_mask.up, orig_mask.height - orig_mask.down) / orig_mask.height + 1.,
                2 * min(orig_mask.left, orig_mask.width - orig_mask.right) / orig_mask.width + 1.
            ])
            augmentation_params["scaling_factor"] = max([
                augmentation_params["scaling_factor"], scaling_factor_lower_bound
            ])

            aug_mask = deepcopy(orig_mask)
            aug_mask.rescale(augmentation_params["scaling_factor"], inplace=True)
            if augmentation_params["horizontal_flip"]:
                aug_mask.horizontal_flip(inplace=True)
            total_aug_area = aug_mask.area()
            if total_aug_area == 0:
                scaling_factor_lower_bound = 1.
                continue

            # Vertical shift: allow hiding at most max_hidden_area of the object
            # above the top or below the bottom edge.
            vertical_area = aug_mask.mask.sum(axis=1) / total_aug_area
            max_hidden_up = np.searchsorted(vertical_area.cumsum(), self.max_hidden_area)
            max_hidden_down = np.searchsorted(vertical_area[::-1].cumsum(), self.max_hidden_area)

            augmentation_params["vertical_shift"] = np.clip(
                augmentation_params["vertical_shift"],
                -(aug_mask.up + max_hidden_up) / aug_mask.height,
                (aug_mask.height - aug_mask.down + max_hidden_down) / aug_mask.height
            )

            vertical_shift = int(round(aug_mask.height * augmentation_params["vertical_shift"]))
            aug_mask.shift(vertical=vertical_shift, inplace=True)
            aug_mask.crop_to_canvas(vertical=True, horizontal=False, inplace=True)

            # Horizontal shift: the hidden-area budget left after the vertical crop.
            max_hidden_area = self.max_hidden_area - (1 - aug_mask.area() / total_aug_area)
            horizontal_area = aug_mask.mask.sum(axis=0) / total_aug_area
            max_hidden_left = np.searchsorted(horizontal_area.cumsum(), max_hidden_area)
            max_hidden_right = np.searchsorted(horizontal_area[::-1].cumsum(), max_hidden_area)
            # Candidate new left edges, converted to shifts relative to the current one.
            allowed_shifts = np.arange(-max_hidden_left, aug_mask.width -
                                       (aug_mask.right - aug_mask.left) + max_hidden_right + 1)
            allowed_shifts = - (aug_mask.left - allowed_shifts)

            if self.position_shuffle:
                np.random.shuffle(allowed_shifts)

            # Invariant for this variant: the original object plus variants chosen so far.
            prev_masks = [mask] + chosen_masks

            mask_is_found = False
            for horizontal_shift in allowed_shifts:
                aug_mask_left = deepcopy(aug_mask)
                aug_mask_left.shift(horizontal=horizontal_shift, inplace=True)
                aug_mask_left.crop_to_canvas(inplace=True)

                # Short-circuit: skip the foreground check once the first fails.
                is_mask_suitable = self._check_masks_intersection(aug_mask_left, total_aug_area, prev_masks) and \
                    self._check_foreground_intersection(aug_mask_left, foreground)
                if is_mask_suitable:
                    aug_draw = aug_mask_left.restore_full_mask()
                    chosen_masks.append(aug_draw)
                    augmentation_params["horizontal_shift"] = horizontal_shift / aug_mask_left.width
                    chosen_parameters.append(augmentation_params)
                    mask_is_found = True
                    break

            if not mask_is_found:
                break

        return chosen_parameters

    def _prepare_mask(self, mask):
        """Resize a segment-id map so each side is a power of two (nearest neighbour)."""
        height, width = mask.shape
        target_width = width if self._is_power_of_two(width) else (1 << width.bit_length())
        target_height = height if self._is_power_of_two(height) else (1 << height.bit_length())

        return resize(mask.astype('float32'), (target_height, target_width), order=0, mode='edge').round().astype('int32')

    def get_masks(self, im, return_panoptic=False):
        """Segment ``im`` and return relocated object masks.

        :param im: input image (anything ``img_as_ubyte`` accepts).
        :param return_panoptic: also return the panoptic segmentation as a numpy array.
        :return: list of uint8 masks at the segmentation's resolution, or
            ``(masks, panoptic_seg)`` when ``return_panoptic`` is True.
        :raises ValueError: on an unexpected ``rigidness_mode``.
        """
        panoptic_seg, segments_info = self.get_segmentation(im)
        potential_mask_ids = self.identify_candidates(panoptic_seg, segments_info)
        # Move the segmentation to host memory once; reused below.
        panoptic_seg_np = panoptic_seg.detach().cpu().numpy()

        panoptic_seg_scaled = self._prepare_mask(panoptic_seg_np)
        downsampled = self.downsample_mask(panoptic_seg_scaled)
        scene_objects = []
        for segment in segments_info:
            if not segment["isthing"]:
                continue
            mask = downsampled == segment["id"]
            if not np.any(mask):
                continue
            scene_objects.append(mask)

        mask_set = []
        for mask_id in potential_mask_ids:
            mask = downsampled == mask_id
            if not np.any(mask):
                continue

            if self.rigidness_mode is RigidnessMode.soft:
                foreground = [mask]
            elif self.rigidness_mode is RigidnessMode.rigid:
                foreground = scene_objects
            else:
                raise ValueError(f'Unexpected rigidness_mode: {self.rigidness_mode}')

            masks_params = self._move_mask(mask, foreground)

            # Replay the chosen augmentations on the full-resolution object mask.
            full_mask = ObjectMask(panoptic_seg_np == mask_id)

            for params in masks_params:
                aug_mask = deepcopy(full_mask)
                aug_mask.rescale(params["scaling_factor"], inplace=True)
                if params["horizontal_flip"]:
                    aug_mask.horizontal_flip(inplace=True)

                vertical_shift = int(round(aug_mask.height * params["vertical_shift"]))
                horizontal_shift = int(round(aug_mask.width * params["horizontal_shift"]))
                aug_mask.shift(vertical=vertical_shift, horizontal=horizontal_shift, inplace=True)
                aug_mask = aug_mask.restore_full_mask().astype('uint8')
                if aug_mask.mean() <= self.min_mask_area:
                    continue
                mask_set.append(aug_mask)

        if return_panoptic:
            return mask_set, panoptic_seg_np
        else:
            return mask_set
|
|
|
|
|
def _sample_crop_start(extent, crop_size, obj_coords, min_overlap):
    """Draw the start of a ``crop_size`` window along one axis of length ``extent``,
    biased towards the object whose pixel coordinates on that axis are ``obj_coords``."""
    obj_start, obj_end = obj_coords.min(), obj_coords.max()
    obj_length = obj_end - obj_start
    anchor = obj_start + obj_length * min_overlap
    lower = max(0, min(extent - crop_size - 1, anchor - crop_size))
    # At least one candidate position is always available.
    upper = max(lower + 1, min(extent - crop_size, anchor))
    # np.random.randint truncates float bounds with int(); do it explicitly.
    return np.random.randint(int(lower), int(upper))


def propose_random_square_crop(mask, min_overlap=0.5):
    """Propose a random square crop of a 2-D ``mask`` near its object.

    The crop side equals the shorter image side, so the crop only slides
    along the longer side. Its offset is drawn around the pixels where
    ``mask > 0.5``.

    :param mask: 2-D array; pixels with value > 0.5 belong to the object.
    :param min_overlap: fraction of the object's extent used to bias the offset.
    :return: ``(x0, y0, x1, y1)`` crop box, with ``x1``/``y1`` exclusive.
    :raises ValueError: if ``mask`` has no pixel above 0.5.
    """
    height, width = mask.shape
    mask_ys, mask_xs = np.where(mask > 0.5)
    if mask_ys.size == 0:
        raise ValueError("mask has no pixels above 0.5; cannot place a crop around it")

    if height < width:
        crop_size = height
        start_x = _sample_crop_start(width, crop_size, mask_xs, min_overlap)
        return start_x, 0, start_x + crop_size, height

    crop_size = width
    start_y = _sample_crop_start(height, crop_size, mask_ys, min_overlap)
    return 0, start_y, width, start_y + crop_size
|
|