pg56714's picture
Upload 96 files
9043dc9 verified
import io
from typing import List
import cv2
import numpy as np
import torch
from torch.nn import functional as F
"""
Some functions in this file are modified from https://github.com/SysCV/sam-hq/blob/main/train/utils/misc.py.
"""
def point_sample(input, point_coords, **kwargs):
    """
    Bilinearly sample ``input`` at normalized point locations.

    Thin wrapper over :func:`torch.nn.functional.grid_sample` that additionally
    accepts 3-D ``point_coords`` and expects coordinates in the [0, 1] x [0, 1]
    square (``grid_sample`` itself works in [-1, 1]).

    Args:
        input (Tensor): feature map of shape (N, C, H, W).
        point_coords (Tensor): (N, P, 2) or (N, Hgrid, Wgrid, 2) coordinates in [0, 1].
        **kwargs: forwarded unchanged to ``grid_sample`` (e.g. ``align_corners``).

    Returns:
        Tensor: (N, C, P) when ``point_coords`` is 3-D, otherwise (N, C, Hgrid, Wgrid),
        with values obtained by bilinear interpolation exactly as ``grid_sample`` does.
    """
    is_point_list = point_coords.dim() == 3
    # grid_sample needs a 4-D grid, so view the P points as a P x 1 grid.
    grid = point_coords.unsqueeze(2) if is_point_list else point_coords
    # Map [0, 1] onto grid_sample's [-1, 1] convention.
    sampled = F.grid_sample(input, grid * 2.0 - 1.0, **kwargs)
    return sampled.squeeze(3) if is_point_list else sampled
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Concatenate ``tensors`` along ``dim``, skipping the copy for a single tensor.

    When the list/tuple holds exactly one tensor, that same tensor object is
    returned, so callers must not expect a fresh copy in that case.
    """
    assert isinstance(tensors, (list, tuple))
    return tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim)
def get_uncertain_point_coords_with_randomness(
    coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
):
    """
    Sample points in [0, 1] x [0, 1], biased towards uncertain regions (PointRend).

    ``int(num_points * oversample_ratio)`` candidate points are drawn uniformly and their
    logits are read with :func:`point_sample`. The ``int(importance_sample_ratio *
    num_points)`` candidates with the highest ``uncertainty_func`` score are kept, and the
    remaining slots are filled with fresh uniform samples.

    Args:
        coarse_logits (Tensor): (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) logits for
            class-specific or class-agnostic prediction.
        uncertainty_func: maps (N, C, P) or (N, 1, P) point logits to (N, 1, P) uncertainty
            scores.
        num_points (int): number of points P to return per item.
        oversample_ratio (int): candidate oversampling factor; must be >= 1.
        importance_sample_ratio (float): fraction in [0, 1] of points chosen by uncertainty.

    Returns:
        Tensor: (N, P, 2) point coordinates.
    """
    assert oversample_ratio >= 1
    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
    device = coarse_logits.device
    batch = coarse_logits.shape[0]
    n_candidates = int(num_points * oversample_ratio)
    candidates = torch.rand(batch, n_candidates, 2, device=device)
    scores = uncertainty_func(point_sample(coarse_logits, candidates, align_corners=False))

    n_important = int(importance_sample_ratio * num_points)
    n_random = num_points - n_important
    top_idx = torch.topk(scores[:, 0, :], k=n_important, dim=1)[1]
    # Offset each row so the indices address the flattened (batch * n_candidates, 2) view.
    row_offset = n_candidates * torch.arange(batch, dtype=torch.long, device=device)
    flat_idx = top_idx + row_offset[:, None]
    chosen = candidates.view(-1, 2)[flat_idx.view(-1), :].view(batch, n_important, 2)

    if n_random <= 0:
        return chosen
    return cat([chosen, torch.rand(batch, n_random, 2, device=device)], dim=1)
def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, mode: str):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: A float tensor of arbitrary shape (N, ...) holding per-element logits.
            Each of the N rows is one prediction; it is flattened from dim 1 onwards.
        targets: A float tensor with the same shape as ``inputs`` storing the binary
            label for each element (0 for the negative class, 1 for the positive class).
        num_masks: Normalizer applied when reducing over the batch.
        mode: ``"none"`` returns the per-row loss; any other value returns the summed
            loss divided by ``num_masks``.

    Returns:
        An (N,) tensor when ``mode == "none"``, otherwise a scalar tensor.
    """
    inputs = inputs.sigmoid()
    # Flatten both tensors identically; flattening only ``inputs`` makes the
    # element-wise product below fail (or mis-broadcast) whenever the rank exceeds 2.
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # +1 smoothing keeps the ratio defined for empty masks.
    loss = 1 - (numerator + 1) / (denominator + 1)
    if mode == "none":
        return loss
    else:
        return loss.sum() / num_masks


dice_loss_jit = torch.jit.script(dice_loss)  # type: torch.jit.ScriptModule
def sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, mode: str):
    """
    Sigmoid binary cross-entropy averaged over every element of each prediction.

    Args:
        inputs: A float tensor of arbitrary shape (N, ...) holding per-element logits.
            Each of the N rows is one prediction; it is flattened from dim 1 onwards.
        targets: A float tensor with the same shape as ``inputs`` storing the binary
            label for each element (0 for the negative class, 1 for the positive class).
        num_masks: Normalizer applied when reducing over the batch.
        mode: ``"none"`` returns the per-row loss; any other value returns the summed
            loss divided by ``num_masks``.

    Returns:
        Loss tensor: (N,) when ``mode == "none"``, otherwise a scalar.
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # Flatten first so the mean covers all non-batch elements; a bare mean(1) would
    # only reduce the second dimension for inputs of rank > 2.
    per_mask = loss.flatten(1).mean(1)
    if mode == "none":
        return per_mask
    else:
        return per_mask.sum() / num_masks


sigmoid_ce_loss_jit = torch.jit.script(sigmoid_ce_loss)  # type: torch.jit.ScriptModule
def calculate_uncertainty(logits):
    """
    Score each location by how close its logit is to the decision boundary 0.0.

    The score is ``-|logit|``, so the most uncertain locations (logits near 0)
    receive the highest score.

    Args:
        logits (Tensor): class-agnostic logits of shape (R, 1, ...), where R is the
            total number of predicted masks.

    Returns:
        Tensor: uncertainty scores with the same shape as ``logits``.
    """
    assert logits.shape[1] == 1
    return -logits.abs()
def loss_masks(
    src_masks,
    target_masks,
    num_masks,
    oversample_ratio=3.0,
    mode="mean",
    num_points=112 * 112,
    importance_sample_ratio=0.75,
):
    """
    Compute point-sampled mask losses: sigmoid cross-entropy and DICE.

    Points are chosen by uncertainty sampling on ``src_masks`` under ``torch.no_grad()``.
    The ground-truth labels are sampled at those points under ``no_grad`` as well,
    while the predicted logits are sampled with gradients enabled.

    Args:
        src_masks (Tensor): predicted mask logits, (N, 1, H, W); the channel dim must
            be 1 (required by :func:`calculate_uncertainty`).
        target_masks (Tensor): float ground-truth masks sampled at the same normalized
            points; presumably (N, 1, H', W') -- TODO confirm with callers.
        num_masks (float): normalizer passed to both losses.
        oversample_ratio (float): candidate oversampling factor for point selection.
        mode (str): ``"none"`` for per-mask losses, anything else for the summed loss
            divided by ``num_masks``.
        num_points (int): number of points sampled per mask (default 112 * 112).
        importance_sample_ratio (float): fraction of points chosen by uncertainty.

    Returns:
        tuple: ``(loss_mask, loss_dice)``.
    """
    with torch.no_grad():
        point_coords = get_uncertain_point_coords_with_randomness(
            src_masks,
            calculate_uncertainty,
            num_points,
            oversample_ratio,
            importance_sample_ratio,
        )
        # Ground-truth labels at the sampled points.
        point_labels = point_sample(target_masks, point_coords, align_corners=False).squeeze(1)

    # Predicted logits at the same points; gradients flow through this sampling.
    point_logits = point_sample(src_masks, point_coords, align_corners=False).squeeze(1)

    loss_mask = sigmoid_ce_loss_jit(point_logits, point_labels, num_masks, mode)
    loss_dice = dice_loss_jit(point_logits, point_labels, num_masks, mode)
    return loss_mask, loss_dice
def mask_iou(pred_label, label):
    """
    IoU between index 0 (along the leading dim) of a prediction and a ground-truth mask.

    ``pred_label`` is binarized with ``> 0`` and ``label`` with ``> 128``.

    Returns:
        Tensor: scalar IoU; 1e-6 is added to the union to avoid division by zero.
    """
    pred_fg = (pred_label > 0)[0].int()
    gt_fg = (label > 128)[0].int()
    inter = (gt_fg * pred_fg > 0).sum()
    union = (gt_fg + pred_fg > 0).sum()
    return inter / (union + 1e-6)
def compute_iou(preds, target):
    """
    Mean :func:`mask_iou` over a batch.

    If the spatial dims (2 and 3) of ``preds`` differ from ``target``'s, ``preds`` is
    first bilinearly resized to the target size.

    Args:
        preds (Tensor): prediction logits, (B, C, H, W).
        target (Tensor): ground-truth masks, (B, C, H', W').

    Returns:
        Tensor: IoU averaged over the ``len(preds)`` items.
    """
    same_size = preds.shape[2] == target.shape[2] and preds.shape[3] == target.shape[3]
    resized = (
        preds
        if same_size
        else F.interpolate(preds, size=target.size()[2:], mode="bilinear", align_corners=False)
    )
    total = sum(mask_iou(resized[i], target[i]) for i in range(len(preds)))
    return total / len(preds)
def mask_to_boundary(mask, dilation_ratio=0.02):
    """
    Extract the inner boundary band of a binary mask.

    The band width is ``round(dilation_ratio * image_diagonal)`` pixels, with a
    minimum of 1. The mask is zero-padded by one pixel before erosion so that
    foreground touching the image border also counts as boundary.

    :param mask (numpy array, uint8): binary 2-D mask of shape (h, w)
    :param dilation_ratio (float): band width as a fraction of the image diagonal
    :return: boundary mask (numpy array): ``mask`` minus its eroded version
    """
    height, width = mask.shape
    diagonal = np.sqrt(height**2 + width**2)
    erode_iters = max(1, int(round(dilation_ratio * diagonal)))
    padded = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
    structuring = np.ones((3, 3), dtype=np.uint8)
    eroded = cv2.erode(padded, structuring, iterations=erode_iters)
    # Drop the padding again before subtracting.
    eroded = eroded[1 : height + 1, 1 : width + 1]
    # Pixels of G that do not survive erosion: G_d intersected with G in the paper.
    return mask - eroded
def boundary_iou(gt, dt, dilation_ratio=0.02):
    """
    Boundary IoU between a ground-truth mask and a prediction.

    ``gt`` is binarized with ``> 128`` and ``dt`` with ``> 0``. For each, index 0
    along the leading dim is moved to the CPU as a uint8 NumPy array, and its
    boundary band is extracted with :func:`mask_to_boundary`.

    :param gt (Tensor): ground-truth mask, thresholded at 128
    :param dt (Tensor): prediction, thresholded at 0
    :param dilation_ratio (float): band width as a fraction of the image diagonal
    :return: boundary IoU as a float32 scalar tensor on ``gt``'s device
    """
    target_device = gt.device
    dt_np = (dt > 0)[0].cpu().byte().numpy()
    gt_np = (gt > 128)[0].cpu().byte().numpy()
    gt_band = mask_to_boundary(gt_np, dilation_ratio)
    dt_band = mask_to_boundary(dt_np, dilation_ratio)
    inter = ((gt_band * dt_band) > 0).sum()
    union = ((gt_band + dt_band) > 0).sum()
    score = inter / (union + 1e-6)
    return torch.tensor(score).float().to(target_device)
def compute_boundary_iou(preds, target):
    """
    Mean :func:`boundary_iou` over a batch.

    If the spatial dims (2 and 3) of ``preds`` differ from ``target``'s, ``preds`` is
    first bilinearly resized to the target size.

    Args:
        preds (Tensor): prediction logits, (B, C, H, W).
        target (Tensor): ground-truth masks, (B, C, H', W').

    Returns:
        Tensor: boundary IoU averaged over the ``len(preds)`` items.
    """
    same_size = preds.shape[2] == target.shape[2] and preds.shape[3] == target.shape[3]
    resized = (
        preds
        if same_size
        else F.interpolate(preds, size=target.size()[2:], mode="bilinear", align_corners=False)
    )
    # boundary_iou takes the ground truth first, then the prediction.
    total = sum(boundary_iou(target[i], resized[i]) for i in range(len(preds)))
    return total / len(preds)
def masks_sample_points(masks, k=10):
    """
    Randomly sample up to ``k`` foreground pixel coordinates from each mask.

    Args:
        masks (Tensor): masks stacked along dim 0; each ``masks[i]`` is converted to
            bool and matched against an (H, W) coordinate grid built from the last two dims.
        k (int): maximum number of points sampled (without replacement) per mask.

    Returns:
        Tensor: (B, k', 2) ``(x, y)`` pixel coordinates. The dtype is ``masks.dtype`` for
        floating-point masks and ``torch.float32`` otherwise. An empty ``masks`` yields
        a (0, 2) tensor.

    Note:
        Each mask contributes ``min(k, #foreground)`` points and ``torch.stack`` needs
        these counts to match, so masks with fewer than ``k`` foreground pixels will
        generally make this raise.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 2), device=masks.device)
    h, w = masks.shape[-2:]
    # Coordinates must stay floating point: casting them to an integer/bool mask dtype
    # (as ``.to(masks)`` used to) wraps values >= 256 for uint8 and collapses them to
    # 0/1 for bool. Float masks keep their own dtype.
    coord_dtype = masks.dtype if masks.is_floating_point() else torch.float
    # Same grids as torch.meshgrid(arange(h), arange(w)) with "ij" indexing, built by
    # broadcasting to avoid meshgrid's indexing-argument deprecation warning.
    y = torch.arange(0, h, dtype=torch.float).view(h, 1).expand(h, w)
    x = torch.arange(0, w, dtype=torch.float).view(1, w).expand(h, w)
    y = y.to(device=masks.device, dtype=coord_dtype)
    x = x.to(device=masks.device, dtype=coord_dtype)

    samples = []
    for b_i in range(len(masks)):
        select_mask = masks[b_i].bool()
        x_idx = torch.masked_select(x, select_mask)
        y_idx = torch.masked_select(y, select_mask)
        idx = torch.randperm(x_idx.size(0))[:k]
        samples.append(torch.stack((x_idx[idx], y_idx[idx]), dim=1))
    return torch.stack(samples)
def mask_iou_batch(pred_label, label):
    """
    Batched IoU reduced over the last two dimensions.

    ``pred_label`` is binarized with ``> 0`` and ``label`` with ``> 128``; all
    leading dimensions are kept.

    Returns:
        Tensor: IoU per leading index; 1e-6 is added to the union to avoid
        division by zero.
    """
    pred_fg = (pred_label > 0).int()
    gt_fg = (label > 128).int()
    spatial = (-1, -2)
    inter = (gt_fg * pred_fg > 0).sum(dim=spatial)
    union = (gt_fg + pred_fg > 0).sum(dim=spatial)
    return inter / (union + 1e-6)