|
|
|
import numpy as np
|
|
from typing import Tuple
|
|
import torch
|
|
from PIL import Image
|
|
from torch.nn import functional as F
|
|
|
|
# Public API of this module; the other helpers are importable but not exported by `import *`.
__all__ = ["paste_masks_in_image"]
|
|
|
|
|
|
# Bytes occupied by one float element; multiplied by N * H * W to estimate the
# memory needed to paste N masks into an H x W image.
BYTES_PER_FLOAT = 4

# Memory budget of 1 GiB (2**30 bytes) for one chunk of masks pasted together on GPU;
# paste_masks_in_image splits the masks into enough chunks to stay under it.
GPU_MEM_LIMIT = 1 << 30
|
|
|
|
|
|
def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
    """
    Resample each mask into image coordinates according to its box, using ``grid_sample``.

    Args:
        masks: N, 1, H, W
        boxes: N, 4, with columns x0, y0, x1, y1 (one box per mask).
        img_h, img_w (int): height and width of the target image.
        skip_empty (bool): only paste masks within the region that
            tightly bounds all boxes, and return the results for this region only.
            An important optimization for CPU. Ignored (treated as False) when scripting.

    Returns:
        A tuple ``(img_masks, spatial_inds)``:
        if skip_empty == False, img_masks has shape (N, img_h, img_w) and
        spatial_inds is ``()``;
        if skip_empty == True, img_masks has shape (N, h', w') and spatial_inds is
        a ``(slice_y, slice_x)`` tuple locating that region in the image.
        img_masks holds the soft (not thresholded) values produced by grid_sample.
    """
    device = masks.device

    if skip_empty and not torch.jit.is_scripting():
        # Integer pixel bounds of the region covering all boxes, grown by 1 pixel
        # on each side and clipped to the image.
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    # Pixel-center coordinates (+0.5) of the rows / columns in the region.
    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    # Map each pixel center into [-1, 1] relative to each box (the normalized
    # coordinate convention of grid_sample); broadcasting against the Nx1 box
    # coordinates gives img_y: (N, h) and img_x: (N, w).
    # NOTE(review): a degenerate box (x1 == x0 or y1 == y0) divides by zero here,
    # producing inf/nan sample coordinates; confirm callers filter such boxes.
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1

    # Sampling grid of shape (N, h, w, 2), last dim ordered (x, y) as grid_sample expects.
    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    if not torch.jit.is_scripting():
        # grid_sample needs a floating-point input. This conversion is skipped when
        # scripting; presumably scripted callers already pass float masks (verify).
        if not masks.dtype.is_floating_point:
            masks = masks.float()
    img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    if skip_empty and not torch.jit.is_scripting():
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()
|
|
|
|
|
|
|
|
@torch.jit.script_if_tracing
def paste_masks_in_image(
    masks: torch.Tensor, boxes: torch.Tensor, image_shape: Tuple[int, int], threshold: float = 0.5
):
    """
    Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask are determined by the
    corresponding bounding boxes in ``boxes``.

    Note:
        This is a complicated but more accurate implementation. In actual deployment, it is
        often enough to use a faster but less accurate implementation.
        See :func:`paste_mask_in_image_old` in this file for an alternative implementation.

    Args:
        masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of
            detected object instances in the image and Hmask, Wmask are the mask height and
            mask width of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].
            Only square masks (Hmask == Wmask) are supported.
        boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4).
            boxes[i] and masks[i] correspond to the same object instance.
        image_shape (tuple): height, width
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks. If negative, no binarization is done and the soft masks are
            returned scaled by 255 as uint8.

    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
        number of detected object instances and Himage, Wimage are the image height
        and width. img_masks[i] is a binary mask for object instance i.
        The dtype is torch.bool when ``threshold >= 0`` and torch.uint8 otherwise;
        this also holds when Bimg == 0.
    """

    assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
    N = len(masks)
    # Binary masks are stored as bool; soft masks (threshold < 0) as uint8.
    out_dtype = torch.bool if threshold >= 0 else torch.uint8
    if N == 0:
        # Same dtype as the non-empty path, so results from different images can be
        # concatenated without dtype mismatches.
        return masks.new_empty((0,) + image_shape, dtype=out_dtype)
    if not isinstance(boxes, torch.Tensor):
        boxes = boxes.tensor
    device = boxes.device
    assert len(boxes) == N, boxes.shape

    img_h, img_w = image_shape

    # On CPU, paste one mask at a time so each paste is restricted to the region its
    # box covers (skip_empty=True below). When scripting, masks are also pasted one
    # at a time.
    if device.type == "cpu" or torch.jit.is_scripting():
        num_chunks = N
    else:
        # On GPU, paste as many masks at once as fit within GPU_MEM_LIMIT.
        # max(1, ...) keeps a zero-area image from requesting 0 chunks, which
        # torch.chunk rejects.
        num_chunks = max(
            1, int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
        )
        assert (
            num_chunks <= N
        ), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it"
    chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

    img_masks = torch.zeros(N, img_h, img_w, device=device, dtype=out_dtype)
    for inds in chunks:
        masks_chunk, spatial_inds = _do_paste_mask(
            masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
        )

        if threshold >= 0:
            masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
        else:
            # For visualization and debugging: keep the soft mask, scaled by 255.
            masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

        if torch.jit.is_scripting():
            # _do_paste_mask ignores skip_empty when scripting, so the chunk spans the
            # whole image.
            img_masks[inds] = masks_chunk
        else:
            img_masks[(inds,) + spatial_inds] = masks_chunk
    return img_masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def paste_mask_in_image_old(mask, box, img_h, img_w, threshold):
    """
    Paste a single mask in an image.
    This is a per-box implementation of :func:`paste_masks_in_image`.
    This function has larger quantization error due to incorrect pixel
    modeling and is not used any more.

    Args:
        mask (Tensor): A tensor of shape (Hmask, Wmask) storing the mask of a single
            object instance. Values are in [0, 1].
        box (Tensor): A tensor of shape (4, ) storing the x0, y0, x1, y1 box corners
            of the object instance.
        img_h, img_w (int): Image height and width.
        threshold (float): Mask binarization threshold in [0, 1]. If negative, the
            resized soft mask is scaled by 255 instead of being binarized.

    Returns:
        im_mask (Tensor):
            The resized and binarized object mask pasted into the original
            image plane (a uint8 tensor of shape (img_h, img_w)). Pixels outside the
            box are 0; a box that does not overlap the image (or is degenerate)
            yields an all-zero mask.
    """
    # Conversion from continuous box coordinates to discrete pixel coordinates via
    # truncation (cast to int32). E.g. a 1D box (x0=0.7, x1=4.3) maps to the discrete
    # coordinates (x0=0, x1=4), i.e. x1 - x0 + 1 = 5 pixels (not x1 - x0 pixels).
    x0, y0, x1, y1 = box.to(dtype=torch.int32).tolist()
    samples_w = x1 - x0 + 1  # Number of pixel samples, *not* geometric width
    samples_h = y1 - y0 + 1  # Number of pixel samples, *not* geometric height

    im_mask = torch.zeros((img_h, img_w), dtype=torch.uint8)

    # Part of the image covered by the box, clipped to the image bounds.
    paste_x0, paste_x1 = max(x0, 0), min(x1 + 1, img_w)
    paste_y0, paste_y1 = max(y0, 0), min(y1 + 1, img_h)
    if paste_x1 <= paste_x0 or paste_y1 <= paste_y0:
        # The box lies entirely outside the image or is degenerate: nothing to paste.
        # Without this guard, negative slice bounds would wrap around and either
        # paste into the wrong pixels or raise a shape mismatch.
        return im_mask

    # Resample the mask from its original grid to the new samples_w x samples_h grid.
    resized = Image.fromarray(mask.cpu().numpy())
    resized = resized.resize((samples_w, samples_h), resample=Image.BILINEAR)
    resized = np.asarray(resized)

    if threshold >= 0:
        pasted = torch.from_numpy(np.array(resized > threshold, dtype=np.uint8))
    else:
        # For visualization and debugging, also allow returning an unbinarized mask.
        pasted = torch.from_numpy(resized * 255).to(torch.uint8)

    im_mask[paste_y0:paste_y1, paste_x0:paste_x1] = pasted[
        (paste_y0 - y0) : (paste_y1 - y0), (paste_x0 - x0) : (paste_x1 - x0)
    ]
    return im_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad_masks(masks, padding):
    """
    Zero-pad square masks on all four sides.

    Args:
        masks (tensor): A tensor of shape (B, M, M) representing B masks.
        padding (int): Number of cells to pad on all sides. May be 0.

    Returns:
        The padded masks, of shape (B, M + 2 * padding, M + 2 * padding), and the
        scale factor of the padding size / original size.
    """
    B = masks.shape[0]
    M = masks.shape[-1]
    pad2 = 2 * padding
    scale = float(M + pad2) / M
    padded_masks = masks.new_zeros((B, M + pad2, M + pad2))
    # Explicit end indices: `padding:-padding` would be an empty slice when padding == 0.
    padded_masks[:, padding : padding + M, padding : padding + M] = masks
    return padded_masks, scale
|
|
|
|
|
|
def scale_boxes(boxes, scale):
    """
    Grow or shrink boxes about their centers by a constant factor.

    Args:
        boxes (tensor): A tensor of shape (B, 4) representing B boxes with 4
            coords representing the corners x0, y0, x1, y1,
        scale (float): The box scaling factor.

    Returns:
        A new tensor with the same shape and dtype as ``boxes`` holding the scaled boxes.
    """
    center_x = (boxes[:, 2] + boxes[:, 0]) * 0.5
    center_y = (boxes[:, 3] + boxes[:, 1]) * 0.5
    half_width = (boxes[:, 2] - boxes[:, 0]) * 0.5 * scale
    half_height = (boxes[:, 3] - boxes[:, 1]) * 0.5 * scale

    result = torch.zeros_like(boxes)
    result[:, 0], result[:, 2] = center_x - half_width, center_x + half_width
    result[:, 1], result[:, 3] = center_y - half_height, center_y + half_height
    return result
|
|
|
|
|
|
@torch.jit.script_if_tracing
def _paste_masks_tensor_shape(
    masks: torch.Tensor,
    boxes: torch.Tensor,
    image_shape: Tuple[torch.Tensor, torch.Tensor],
    threshold: float = 0.5,
):
    """
    Forward to paste_masks_in_image when ``image_shape`` is given as a pair of tensors.

    While tracing, image sizes may be tensors rather than Python ints. The
    Tensor -> int conversion below is meant to be scripted (via the decorator)
    rather than traced.
    """
    height = int(image_shape[0])
    width = int(image_shape[1])
    return paste_masks_in_image(masks, boxes, (height, width), threshold)
|
|
|