import io
from typing import List

import cv2
import numpy as np
import torch
from torch.nn import functional as F

"""
Some functions in this file are modified from https://github.com/SysCV/sam-hq/blob/main/train/utils/misc.py.
"""


def point_sample(input, point_coords, **kwargs):
    """
    A wrapper around :func:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
    Unlike :func:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside the
    [0, 1] x [0, 1] square.

    Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains a feature map on a H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
            [0, 1] x [0, 1] normalized point coordinates.

    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
            features for points in `point_coords`. The features are obtained via bilinear
            interpolation from `input` the same way as :func:`torch.nn.functional.grid_sample`.
    """
    add_dim = False
    if point_coords.dim() == 3:
        add_dim = True
        point_coords = point_coords.unsqueeze(2)
    # grid_sample expects coordinates in [-1, 1], so rescale from [0, 1].
    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
    if add_dim:
        output = output.squeeze(3)
    return output
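

# A minimal usage sketch for point_sample; the tensor names and shapes below
# are illustrative assumptions, not part of the original utilities.
def _example_point_sample():
    feats = torch.randn(2, 8, 32, 32)  # (N, C, H, W) feature map
    coords = torch.rand(2, 100, 2)  # (N, P, 2), normalized to [0, 1] x [0, 1]
    sampled = point_sample(feats, coords, align_corners=False)
    assert sampled.shape == (2, 8, 100)  # (N, C, P) sampled features
    return sampled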


def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


def get_uncertain_point_coords_with_randomness(
    coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
):
    """
    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
    are calculated for each point using the 'uncertainty_func' function that takes a point's logit
    prediction as input.
    See the PointRend paper for details.

    Args:
        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
            class-specific or class-agnostic prediction.
        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
            contains logit predictions for P points and returns their uncertainties as a Tensor of
            shape (N, 1, P).
        num_points (int): The number of points P to sample.
        oversample_ratio (int): Oversampling parameter.
        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.

    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
            sampled points.
    """
    assert oversample_ratio >= 1
    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
    num_boxes = coarse_logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    # Oversample candidate points uniformly at random, keep the most uncertain
    # ones, and top up with a few fresh random points.
    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
    point_uncertainties = uncertainty_func(point_logits)
    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
    idx += shift[:, None]
    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
    if num_random_points > 0:
        point_coords = cat(
            [
                point_coords,
                torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
            ],
            dim=1,
        )
    return point_coords
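

# A minimal usage sketch for get_uncertain_point_coords_with_randomness; the
# shapes and sampling parameters below are illustrative assumptions.
def _example_uncertain_point_sampling():
    coarse_logits = torch.randn(4, 1, 28, 28)  # (N, 1, Hmask, Wmask) mask logits
    point_coords = get_uncertain_point_coords_with_randomness(
        coarse_logits,
        calculate_uncertainty,
        num_points=16,
        oversample_ratio=3,
        importance_sample_ratio=0.75,
    )
    assert point_coords.shape == (4, 16, 2)  # (N, P, 2) in [0, 1] x [0, 1]
    return point_coords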


def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, mode: str):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: Normalization factor, typically the number of masks in the batch.
        mode: If "none", return the per-mask loss; otherwise return the sum normalized by num_masks.
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    if mode == "none":
        return loss
    else:
        return loss.sum() / num_masks


dice_loss_jit = torch.jit.script(dice_loss)
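

# A minimal usage sketch for the scripted dice loss; the (N, P) point-wise
# layout below mirrors what loss_masks produces and is an assumption here.
def _example_dice_loss():
    point_logits = torch.randn(2, 100)  # (N, P) per-point logits
    point_labels = (torch.rand(2, 100) > 0.5).float()  # (N, P) binary labels
    loss = dice_loss_jit(point_logits, point_labels, num_masks=2.0, mode="mean")
    return loss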


def sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, mode: str):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: Normalization factor, typically the number of masks in the batch.
        mode: If "none", return the per-mask loss; otherwise return the sum normalized by num_masks.

    Returns:
        Loss tensor
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    if mode == "none":
        return loss.mean(1)
    else:
        return loss.mean(1).sum() / num_masks


sigmoid_ce_loss_jit = torch.jit.script(sigmoid_ce_loss)


def calculate_uncertainty(logits):
    """
    We estimate uncertainty as the L1 distance between 0.0 and the logit prediction in 'logits'
    for the foreground class.

    Args:
        logits (Tensor): A tensor of shape (R, 1, ...) of class-agnostic logits,
            where R is the total number of predicted masks in all images.

    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    assert logits.shape[1] == 1
    gt_class_logits = logits.clone()
    # Logits near 0 correspond to probabilities near 0.5, i.e. the most
    # uncertain predictions, so negate the absolute logit.
    return -(torch.abs(gt_class_logits))


def loss_masks(src_masks, target_masks, num_masks, oversample_ratio=3.0, mode="mean"):
    """
    Compute the losses related to the masks: the sigmoid cross-entropy loss and the dice loss.
    `src_masks` are predicted mask logits and `target_masks` are binary ground-truth masks,
    both of shape (N, 1, H, W).
    """
    with torch.no_grad():
        # Sample point coordinates where the prediction is most uncertain,
        # following PointRend-style importance sampling.
        point_coords = get_uncertain_point_coords_with_randomness(
            src_masks,
            lambda logits: calculate_uncertainty(logits),
            112 * 112,
            oversample_ratio,
            0.75,
        )

        point_labels = point_sample(
            target_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

    point_logits = point_sample(
        src_masks,
        point_coords,
        align_corners=False,
    ).squeeze(1)

    loss_mask = sigmoid_ce_loss_jit(point_logits, point_labels, num_masks, mode)
    loss_dice = dice_loss_jit(point_logits, point_labels, num_masks, mode)

    del src_masks
    del target_masks
    return loss_mask, loss_dice
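

# A minimal usage sketch for loss_masks; the batch size, spatial size, and the
# {0, 1} target encoding below are illustrative assumptions.
def _example_loss_masks():
    src_masks = torch.randn(2, 1, 256, 256)  # (N, 1, H, W) predicted logits
    target_masks = (torch.rand(2, 1, 256, 256) > 0.5).float()  # binary targets
    loss_mask, loss_dice = loss_masks(src_masks, target_masks, num_masks=2.0)
    return loss_mask, loss_dice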


def mask_iou(pred_label, label):
    """
    Calculate the mask IoU between pred_label and the ground-truth label.
    pred_label is thresholded at 0 (logits) and label at 128 (0-255 masks).
    """
    pred_label = (pred_label > 0)[0].int()
    label = (label > 128)[0].int()

    intersection = ((label * pred_label) > 0).sum()
    union = ((label + pred_label) > 0).sum()
    return intersection / (union + 1e-6)


def compute_iou(preds, target):
    # Resize predictions to the target resolution if needed before computing IoU.
    if preds.shape[2] != target.shape[2] or preds.shape[3] != target.shape[3]:
        postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode="bilinear", align_corners=False)
    else:
        postprocess_preds = preds
    iou = 0
    for i in range(len(preds)):
        iou = iou + mask_iou(postprocess_preds[i], target[i])
    return iou / len(preds)
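

# A minimal usage sketch for compute_iou; the shapes and the {0, 255} target
# encoding below are illustrative assumptions matching the thresholds used in
# mask_iou (pred > 0, label > 128).
def _example_compute_iou():
    preds = torch.randn(2, 1, 128, 128)  # predicted logits at any spatial size
    target = (torch.rand(2, 1, 256, 256) > 0.5).float() * 255.0  # {0, 255} masks
    return compute_iou(preds, target)  # preds are resized to the target size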


def mask_to_boundary(mask, dilation_ratio=0.02):
    """
    Convert binary mask to boundary mask.
    :param mask (numpy array, uint8): binary mask
    :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
    :return: boundary mask (numpy array)
    """
    h, w = mask.shape
    img_diag = np.sqrt(h**2 + w**2)
    dilation = int(round(dilation_ratio * img_diag))
    if dilation < 1:
        dilation = 1

    # Pad to avoid erosion artifacts at the image border, erode, then crop back.
    new_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
    kernel = np.ones((3, 3), dtype=np.uint8)
    new_mask_erode = cv2.erode(new_mask, kernel, iterations=dilation)
    mask_erode = new_mask_erode[1 : h + 1, 1 : w + 1]

    # The boundary is the set of foreground pixels removed by the erosion.
    return mask - mask_erode
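

# A minimal usage sketch for mask_to_boundary; the square mask below is an
# illustrative assumption.
def _example_mask_to_boundary():
    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[16:48, 16:48] = 1  # a filled square as the binary mask
    boundary = mask_to_boundary(mask, dilation_ratio=0.02)
    assert boundary.shape == mask.shape  # thin band along the square's edge
    return boundary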


def boundary_iou(gt, dt, dilation_ratio=0.02):
    """
    Compute boundary IoU between two masks.
    :param gt (Tensor): ground-truth mask with values in [0, 255]
    :param dt (Tensor): predicted mask logits
    :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
    :return: boundary iou (Tensor, float scalar)
    """
    device = gt.device
    # dt holds logits (threshold at 0) and gt holds 0-255 masks (threshold at 128).
    dt = (dt > 0)[0].cpu().byte().numpy()
    gt = (gt > 128)[0].cpu().byte().numpy()

    gt_boundary = mask_to_boundary(gt, dilation_ratio)
    dt_boundary = mask_to_boundary(dt, dilation_ratio)
    intersection = ((gt_boundary * dt_boundary) > 0).sum()
    union = ((gt_boundary + dt_boundary) > 0).sum()
    boundary_iou = intersection / (union + 1e-6)
    return torch.tensor(boundary_iou).float().to(device)


def compute_boundary_iou(preds, target):
    # Resize predictions to the target resolution if needed before computing boundary IoU.
    if preds.shape[2] != target.shape[2] or preds.shape[3] != target.shape[3]:
        postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode="bilinear", align_corners=False)
    else:
        postprocess_preds = preds
    iou = 0
    for i in range(len(preds)):
        iou = iou + boundary_iou(target[i], postprocess_preds[i])
    return iou / len(preds)
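

# A minimal usage sketch for compute_boundary_iou; the same input conventions
# as in the compute_iou sketch above are assumed (logit predictions and
# {0, 255}-valued ground-truth masks).
def _example_compute_boundary_iou():
    preds = torch.randn(2, 1, 256, 256)
    target = (torch.rand(2, 1, 256, 256) > 0.5).float() * 255.0
    return compute_boundary_iou(preds, target)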


def masks_sample_points(masks, k=10):
    """Sample k (x, y) point coordinates from the foreground of each mask."""
    if masks.numel() == 0:
        return torch.zeros((0, 2), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x, indexing="ij")
    y = y.to(masks)
    x = x.to(masks)

    samples = []
    for b_i in range(len(masks)):
        select_mask = masks[b_i].bool()
        # Coordinates of all foreground pixels in this mask.
        x_idx = torch.masked_select(x, select_mask)
        y_idx = torch.masked_select(y, select_mask)

        # Randomly pick k of them (each mask is assumed to have at least k
        # foreground pixels; otherwise fewer points are returned and the
        # final torch.stack below fails).
        perm = torch.randperm(x_idx.size(0))
        idx = perm[:k]
        samples_x = x_idx[idx]
        samples_y = y_idx[idx]
        samples_xy = torch.cat((samples_x[:, None], samples_y[:, None]), dim=1)
        samples.append(samples_xy)

    samples = torch.stack(samples)

    return samples
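

# A minimal usage sketch for masks_sample_points; the mask contents below are
# illustrative assumptions (every mask must contain at least k foreground
# pixels for torch.stack to succeed).
def _example_masks_sample_points():
    masks = torch.zeros(2, 64, 64)
    masks[:, 16:48, 16:48] = 1.0  # a square of foreground pixels per mask
    points = masks_sample_points(masks, k=10)
    assert points.shape == (2, 10, 2)  # (N, k, 2) as (x, y) pixel coordinates
    return points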


def mask_iou_batch(pred_label, label):
    """
    Calculate the mask IoU between pred_label and label for a whole batch.
    pred_label is thresholded at 0 (logits) and label at 128 (0-255 masks).
    """
    pred_label = (pred_label > 0).int()
    label = (label > 128).int()

    intersection = ((label * pred_label) > 0).sum(dim=(-1, -2))
    union = ((label + pred_label) > 0).sum(dim=(-1, -2))
    return intersection / (union + 1e-6)
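

# A minimal usage sketch for mask_iou_batch; the shapes and {0, 255} target
# encoding below are illustrative assumptions.
def _example_mask_iou_batch():
    pred_label = torch.randn(4, 1, 64, 64)  # predicted logits
    label = (torch.rand(4, 1, 64, 64) > 0.5).float() * 255.0  # {0, 255} masks
    ious = mask_iou_batch(pred_label, label)
    assert ious.shape == (4, 1)  # one IoU per mask (spatial dims reduced)
    return ious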