Qihang Yu
Add kMaX-DeepLab
a06fad0
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
# Reference: https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py
# Modified by Qihang Yu
import torch
import torch.nn.functional as F
from torch import nn
# Finite stand-in for -inf. Multiplied by the void mask and added to the pixel-sampling
# logits, so void pixels rank below all labelled pixels in the Gumbel top-k sampler.
# NOTE(review): presumably preferred over masked_fill(-inf) to keep the arithmetic finite.
_SOFTMAX_MASKING_CONSTANT = -99999.0
# https://www.tensorflow.org/api_docs/python/tf/math/divide_no_nan
def divide_no_nan(x: torch.Tensor, y: torch.Tensor):
    """Divide ``x`` by ``y`` elementwise, returning 0 where the result is undefined.

    Mirrors ``tf.math.divide_no_nan``: positions where ``y == 0`` yield 0. For backward
    compatibility, any remaining NaN/+inf/-inf in the quotient is also replaced by 0.

    Zero denominators are masked *before* dividing. Dividing first and cleaning up with
    ``nan_to_num`` gives the right forward value but still back-propagates
    ``0 / 0 = NaN`` into ``x`` and ``y``. Masking first keeps those gradients at exactly 0.

    Args:
        x: numerator tensor.
        y: denominator tensor, broadcastable with ``x``.

    Returns:
        The safe quotient, with the broadcast shape of ``x`` and ``y``.
    """
    zero_denominator = y == 0
    # Use 1 as the denominator at zero positions so the division never creates inf/nan there.
    safe_y = torch.where(zero_denominator, torch.ones_like(y), y)
    quotient = x / safe_y
    quotient = torch.where(zero_denominator, torch.zeros_like(quotient), quotient)
    return torch.nan_to_num(quotient, nan=0.0, posinf=0.0, neginf=0.0)
# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L393
def focal_cross_entropy_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    weight: torch.Tensor,  # This is for PQ-loss weighting
    focal_loss_alpha: float = 0.75,
    focal_loss_gamma: float = 0.0,
    background_channel_index: int = -1):
    """Focal softmax cross-entropy over classes, weighted per query.

    Args:
        pred: class logits, B x N x C.
        gt: integer class targets, B x N (values must be valid one-hot indices for C classes).
        weight: per-query multiplicative weight, B x N (the PQ-style class weight).
        focal_loss_alpha: weight for targets that are not the background channel;
            background targets get ``1 - focal_loss_alpha``. Negative disables it.
        focal_loss_gamma: focusing exponent; 0 means plain cross-entropy.
        background_channel_index: class channel treated as background.

    Returns:
        Scalar: per sample, the summed loss divided by the number of non-zero
        entries (clamped to at least 1), then averaged over the batch.
    """
    logits = pred.transpose(1, 2)  # B x C x N
    target = F.one_hot(gt, num_classes=logits.shape[1]).transpose(1, 2).to(logits)  # B x C x N
    per_query = F.cross_entropy(logits, target, reduction="none")  # B x N

    if focal_loss_gamma != 0.0:
        # Down-weight well-classified queries by (1 - p_target) ** gamma.
        prob_of_target = (F.softmax(logits, dim=1) * target).sum(1)  # B x N
        per_query = torch.pow(1.0 - prob_of_target, focal_loss_gamma) * per_query

    if focal_loss_alpha >= 0:
        is_background = target[:, background_channel_index]  # B x N
        per_query = (
            focal_loss_alpha * (1.0 - is_background)
            + (1 - focal_loss_alpha) * is_background) * per_query

    per_query = (per_query * weight).flatten(1)
    # Average only over entries that actually contribute (non-zero loss).
    denom = torch.clamp((per_query != 0.0).to(per_query).sum(-1), min=1.0)  # B
    return divide_no_nan(per_query.sum(-1), denom).mean()  # 1
# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L50
def _gumbel_topk_sample(logits: torch.Tensor, k: int):
"""Samples k points from the softmax distribution with Gumbel-Top-k trick."""
# Note that torch.rand is [0, 1), we need to make it (0, 1) to ensure the log is valid.
gumbel_noise = torch.rand(size=logits.shape, dtype=logits.dtype, device=logits.device)
gumbel_noise = -torch.log(-torch.log(gumbel_noise))
_, indices = torch.topk(logits + gumbel_noise, k)
return indices
# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L576
def pixelwise_insdis_loss(
    pixel_feature: torch.Tensor,
    gt_mask: torch.Tensor,
    sample_temperature: float,
    sample_k: int,
    instance_discrimination_temperature: float,
    pixel_gt_void_mask: torch.Tensor,
    inverse_gt_mask_area: torch.Tensor
):
    """Pixel-wise instance discrimination loss on a random subset of pixels.

    Samples ``sample_k`` pixels with Gumbel top-k. The sampling logits are
    ``log(inverse_gt_mask_area) * sample_temperature``, plus a large negative bias on
    void pixels. For each sampled pixel, the prediction is a softmax over its
    feature-similarity (dot product / ``instance_discrimination_temperature``) to all
    sampled pixels. The target spreads mass evenly over sampled pixels whose
    ground-truth mask rows match (assuming binary, non-overlapping masks — TODO
    confirm).

    Args:
        pixel_feature: B x C x H x W pixel embeddings.
        gt_mask: B x N x H x W ground-truth masks.
        pixel_gt_void_mask: B x H x W, truthy where the pixel is void.
        inverse_gt_mask_area: B x H x W positive sampling weights.

    Returns:
        Scalar: per-sample mean over non-zero pixel losses, averaged over the batch.
    """
    features = pixel_feature.flatten(2)  # B x C x HW
    masks = gt_mask.flatten(2)  # B x N x HW
    void = pixel_gt_void_mask.flatten(1)  # B x HW
    inv_area = inverse_gt_mask_area.flatten(1)  # B x HW

    sampling_logits = torch.log(inv_area) * sample_temperature  # B x HW
    sampling_logits += void.to(sampling_logits) * _SOFTMAX_MASKING_CONSTANT
    picked = _gumbel_topk_sample(sampling_logits, sample_k)  # B x K

    def take_sampled(t):
        # Select the sampled pixel columns of a B x D x HW tensor -> B x D x K.
        return torch.gather(t, dim=2, index=picked.unsqueeze(1).repeat(1, t.shape[1], 1))

    gt_onehot = take_sampled(masks)  # B x N x K
    target = torch.einsum('bnk,bnj->bkj', gt_onehot, gt_onehot)  # B x K x K
    # Normalise each column into a distribution over dim 1 (void columns stay all-zero).
    target = target / torch.clamp(target.sum(dim=1, keepdim=True), min=1.0)

    sampled_features = take_sampled(features)  # B x C x K
    similarity = torch.einsum('bck,bcj->bkj', sampled_features, sampled_features)  # B x K x K
    similarity = similarity / instance_discrimination_temperature

    per_pixel = F.cross_entropy(similarity, target, reduction="none")  # B x K
    denom = torch.clamp((per_pixel != 0.0).to(per_pixel).sum(-1), min=1.0)  # B
    return divide_no_nan(per_pixel.sum(-1), denom).mean()  # 1
def aux_semantic_loss(
    pred_semantic_logits: torch.Tensor,
    ground_truth_semantic: torch.Tensor,
    sample_temperature: float,
    sample_k: int,
    pixel_gt_void_mask: torch.Tensor,
    inverse_gt_mask_area: torch.Tensor,
    num_classes: int):
    """Auxiliary semantic cross-entropy on a Gumbel-top-k sample of pixels.

    Pixel sampling follows ``pixelwise_insdis_loss``: the logits are
    ``log(inverse_gt_mask_area) * sample_temperature``, biased away from void pixels.
    Pixels whose label equals ``num_classes`` are ignored.

    Args:
        pred_semantic_logits: B x C x H x W semantic logits.
        ground_truth_semantic: B x H x W integer labels (``num_classes`` = ignore).
        pixel_gt_void_mask, inverse_gt_mask_area: per-pixel tensors flattened to B x HW.
            NOTE(review): their spatial size must match the logits — confirm at caller.

    Returns:
        Scalar: per-sample mean over non-zero pixel losses, averaged over the batch.
    """
    logits_flat = pred_semantic_logits.flatten(2)  # B x C x HW
    labels_flat = ground_truth_semantic.flatten(1)  # B x HW
    void = pixel_gt_void_mask.flatten(1)  # B x HW
    inv_area = inverse_gt_mask_area.flatten(1)  # B x HW

    sampling_logits = torch.log(inv_area) * sample_temperature  # B x HW
    sampling_logits += void.to(sampling_logits) * _SOFTMAX_MASKING_CONSTANT
    picked = _gumbel_topk_sample(sampling_logits, sample_k)  # B x K

    picked_labels = labels_flat.gather(1, picked)  # B x K
    picked_logits = logits_flat.gather(
        2, picked.unsqueeze(1).repeat(1, logits_flat.shape[1], 1))  # B x C x K

    per_pixel = F.cross_entropy(picked_logits, picked_labels, ignore_index=num_classes, reduction='none')  # B x K
    # Label num_classes marks ignored pixels; zero them explicitly as well.
    per_pixel = per_pixel * (picked_labels != num_classes).to(per_pixel)
    denom = torch.clamp((per_pixel != 0.0).to(per_pixel).sum(-1), min=1.0)  # B
    return divide_no_nan(per_pixel.sum(-1), denom).mean()  # 1
# https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L56
# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L510
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    pixel_gt_void_mask: torch.Tensor,
    matched_cls_prob: torch.Tensor
):
    """Weighted soft Dice loss between predicted mask probabilities and ground truth.

    Args:
        inputs: mask logits, B x N x HW; softmax is taken across the N masks.
        targets: binary ground-truth masks, same shape as ``inputs``.
        pixel_gt_void_mask: boolean B x HW, True at unlabelled pixels; the predicted
            probabilities there are zeroed before computing Dice.
        matched_cls_prob: B x N per-mask weight multiplied into the loss.

    Returns:
        Scalar: the weighted per-mask loss summed over masks, scaled by 0.75 / 128,
        then averaged over the batch.
    """
    probs = inputs.softmax(1)  # B x N x HW
    # Remove void pixels (cf. deeplab2 base_loss.py#L111).
    probs = probs.masked_fill(pixel_gt_void_mask.unsqueeze(1), 0)
    smooth = 1.0
    numerator = 2 * (probs * targets).sum(-1) + smooth  # B x N
    denominator = probs.sum(-1) + targets.sum(-1) + smooth  # B x N
    per_mask = (1.0 - divide_no_nan(numerator, denominator)) * matched_cls_prob  # B x N
    # The reference kMaX-DeepLab implementation sums over masks and averages over the
    # batch, with a 0.75 loss multiplier. NOTE(review): the 1/128 factor presumably
    # normalises by the reference number of mask queries — confirm before changing it.
    return (per_mask.sum(1) * 0.75 / 128).mean()  # sum over masks and mean over batches.
def softmax_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    pixel_gt_void_mask: torch.Tensor,
):
    """Per-pixel softmax cross-entropy across masks, ignoring void pixels.

    Args:
        inputs: mask logits, B x N x HW (the softmax runs over the N masks).
        targets: ground-truth masks of the same shape, used as class-probability targets.
        pixel_gt_void_mask: boolean B x HW, True at pixels whose loss is zeroed.

    Returns:
        Scalar: per-sample mean over non-zero pixel losses, averaged over the batch.
    """
    per_pixel = F.cross_entropy(inputs, targets, reduction="none")  # B x HW
    per_pixel = per_pixel.masked_fill(pixel_gt_void_mask, 0)
    denom = torch.clamp((per_pixel != 0.0).to(per_pixel).sum(-1), min=1.0)  # B
    return divide_no_nan(per_pixel.sum(-1), denom).mean()  # 1
class SetCriterion(nn.Module):
    """Computes the kMaX-DeepLab training losses.

    The process happens in two steps:
        1) ``self.matcher`` assigns ground-truth masks to predicted masks, returning per
           image the matching indices plus the matched Dice scores and class probabilities;
        2) ``process_gt`` pads/permutes the matched ground truth into the prediction
           layout (B x N ...) and every loss listed in ``self.losses`` is evaluated.

    Supported loss names (see ``get_loss``): ``'labels'``, ``'masks'``, ``'pixels'`` and
    ``'aux_semantic'``. Auxiliary outputs get every loss except ``'aux_semantic'``, with
    ``_{i}`` appended to the loss keys.
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, share_final_matching,
                 pixel_insdis_temperature=1.5, pixel_insdis_sample_k=4096,
                 aux_semantic_temperature=2.0, aux_semantic_sample_k=4096,
                 pixel_insdis_discrimination_temperature=0.3):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object
                category. The value ``num_classes`` is used as the label of unmatched
                queries and of ignored semantic pixels.
            matcher: callable ``matcher(outputs, targets)`` returning
                ``(indices, matched_dice, matched_cls_prob)``. ``indices`` holds one
                ``(pred_idx, gt_idx)`` pair per image.
            weight_dict: mapping of loss names to weights. It is stored and shown in
                ``repr`` but not applied by this class.
            eos_coef: lower bound (clamp) applied to the matched class probabilities and
                to the Dice-based class-loss weights.
            losses: list of loss names to compute; see ``get_loss``.
            share_final_matching: if True, auxiliary outputs reuse the final output's
                matching and processed targets instead of re-running the matcher.
            pixel_insdis_temperature: multiplier on the log inverse mask area used as
                pixel-sampling logits for the instance-discrimination loss.
            pixel_insdis_sample_k: number of pixels sampled for that loss.
            aux_semantic_temperature: the same sampling multiplier for the auxiliary
                semantic loss.
            aux_semantic_sample_k: number of pixels sampled for the auxiliary semantic loss.
            pixel_insdis_discrimination_temperature: divisor of the pixel feature
                similarities in the instance-discrimination loss.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.share_final_matching = share_final_matching
        self.pixel_insdis_temperature = pixel_insdis_temperature
        self.pixel_insdis_sample_k = pixel_insdis_sample_k
        self.aux_semantic_temperature = aux_semantic_temperature
        self.aux_semantic_sample_k = aux_semantic_sample_k
        self.pixel_insdis_discrimination_temperature = pixel_insdis_discrimination_temperature

    def loss_labels(self, outputs, targets):
        """Focal classification loss, weighted per query by the matched Dice score.

        ``targets`` is the dict built by ``process_gt``. It must contain ``"labels"``
        (B x N, unmatched queries labelled ``num_classes``) and
        ``"pq_loss_class_weight"`` (B x N).
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]  # B x N x C
        target_classes = targets["labels"]  # B x N
        pq_loss_class_weight = targets["pq_loss_class_weight"]
        return {"loss_ce": focal_cross_entropy_loss(src_logits, target_classes, pq_loss_class_weight)}

    def loss_masks(self, outputs, targets):
        """Mask losses: per-pixel softmax cross-entropy and class-prob-weighted Dice.

        ``targets`` is the dict built by ``process_gt``. It must contain ``"masks"``
        (B x N x H x W), ``"pq_loss_mask_weight"`` (B x N) and
        ``"pixel_gt_void_mask"`` (B x H x W).
        """
        src_masks = outputs["pred_masks"].flatten(2)  # B x N x HW
        target_masks = targets["masks"].flatten(2)  # B x N x HW
        pq_loss_mask_weight = targets["pq_loss_mask_weight"]
        pixel_gt_void_mask = targets["pixel_gt_void_mask"].flatten(1)  # B x HW
        return {
            "loss_mask": softmax_ce_loss(src_masks, target_masks, pixel_gt_void_mask),
            "loss_dice": dice_loss(src_masks, target_masks, pixel_gt_void_mask, pq_loss_mask_weight),
        }

    def loss_pixels(self, outputs, targets):
        """Pixel-wise instance discrimination loss on ``outputs["pixel_feature"]``."""
        return {"loss_pixel_insdis": pixelwise_insdis_loss(
            pixel_feature=outputs["pixel_feature"],
            gt_mask=targets["masks"],
            sample_temperature=self.pixel_insdis_temperature,
            sample_k=self.pixel_insdis_sample_k,
            instance_discrimination_temperature=self.pixel_insdis_discrimination_temperature,
            pixel_gt_void_mask=targets["pixel_gt_void_mask"],
            inverse_gt_mask_area=targets["inverse_gt_mask_area"],
        )}

    def loss_semantic(self, outputs, targets):
        """Auxiliary semantic loss on ``outputs["aux_semantic_pred"]``.

        Requires ``targets["ground_truth_semantic"]`` (built with ``process_semantic=True``).
        """
        return {"loss_aux_semantic": aux_semantic_loss(
            pred_semantic_logits=outputs["aux_semantic_pred"],
            ground_truth_semantic=targets["ground_truth_semantic"],
            sample_temperature=self.aux_semantic_temperature,
            sample_k=self.aux_semantic_sample_k,
            pixel_gt_void_mask=targets["pixel_gt_void_mask"].flatten(1),
            inverse_gt_mask_area=targets["inverse_gt_mask_area"].flatten(1),
            num_classes=self.num_classes,
        )}

    @torch.no_grad()
    def _get_src_permutation_idx(self, indices):
        """Flatten per-image matched prediction indices into advanced-indexing form.

        Returns ``(batch_idx, src_idx)``, both of length sum_i len(src_i), where
        ``batch_idx[j]`` is the image of the j-th matched prediction and ``src_idx[j]``
        its query index. ``tensor[batch_idx, src_idx]`` then selects the matched queries.
        """
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def get_loss(self, loss, outputs, targets):
        """Dispatch loss name ``loss`` to its implementation and return its dict."""
        loss_map = {
            'labels': self.loss_labels,
            'masks': self.loss_masks,
            'pixels': self.loss_pixels,
            'aux_semantic': self.loss_semantic,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets)

    @torch.no_grad()
    def process_gt(self, outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=False):
        """Pad and permute the matched ground truth into the prediction layout.

        This is kept separate from the matcher so several predictions can share one
        matching result (see ``share_final_matching``).

        Args:
            outputs: dict with ``"pred_masks"`` (B x N x H x W) and ``"pred_logits"`` (B x N x C).
            targets: list (one dict per image) with ``"masks"`` and ``"labels"``, plus
                ``"semantic_masks"`` when ``process_semantic`` is True.
            indices: per-image ``(pred_idx, gt_idx)`` pairs from the matcher.
            matched_dice: per-image Dice scores of the matched pairs.
            matched_cls_prob: per-image class probabilities of the matched pairs.
            process_semantic: also build ``"ground_truth_semantic"`` (B x H x W), with
                label -1 remapped to ``num_classes``.

        Returns:
            dict with ``"masks"``, ``"labels"``, ``"pq_loss_mask_weight"``,
            ``"pq_loss_class_weight"``, ``"pixel_gt_void_mask"``,
            ``"inverse_gt_mask_area"`` and, optionally, ``"ground_truth_semantic"``.
        """
        src_idx = self._get_src_permutation_idx(indices)
        src_masks = outputs["pred_masks"].detach()  # B x N x H x W

        # Ground-truth masks in prediction order; all-zero for unmatched queries.
        target_masks = torch.zeros_like(src_masks)
        target_masks_o = torch.cat([t["masks"][J] for t, (_, J) in zip(targets, indices)]).to(target_masks)
        target_masks[src_idx] = target_masks_o

        # Mask-loss weight: matched class probability floored at eos_coef, and 0
        # (no penalty) for unmatched masks.
        # https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L1034
        matched_cls_prob_o = torch.clamp(torch.cat(list(matched_cls_prob)), min=self.eos_coef)
        mask_loss_weight = torch.zeros(
            src_masks.shape[:2], dtype=src_masks.dtype, device=src_masks.device)  # B x N
        mask_loss_weight[src_idx] = matched_cls_prob_o.to(mask_loss_weight)

        # Pixels covered by no ground-truth mask are void.
        pixel_gt_void_mask = target_masks.sum(1) < 1  # B x H x W

        # Per-pixel inverse area of the ground-truth mask containing it (HW / area),
        # used as pixel-sampling weights; area is clamped to 1 for void pixels.
        mask_gt_area = target_masks.sum(2).sum(2)  # B x N
        pixel_gt_area = torch.einsum('bnhw,bn->bhw', target_masks, mask_gt_area)  # B x H x W
        num_pixels = pixel_gt_area.shape[1] * pixel_gt_area.shape[2]
        inverse_gt_mask_area = num_pixels / torch.clamp(pixel_gt_area, min=1.0)  # B x H x W

        # Class targets: real labels for matched queries, num_classes (void) elsewhere.
        src_logits = outputs["pred_logits"]  # B x N x C
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[src_idx] = target_classes_o

        def void_overlap_ratio(x, y):
            # x: B x N x H x W mask probabilities; y: B x H x W void indicator.
            # Fraction of each mask's probability mass lying on void pixels -> B x N.
            x = x.flatten(2)  # B x N x L
            y = y.flatten(1)  # B x L
            intersection = torch.einsum('bnl,bl->bn', x, y)  # B x N
            return intersection / (x.sum(-1) + 1e-5)  # B x N

        # Class-loss weight: matched Dice for matched queries and void overlap for
        # unmatched ones, floored at eos_coef.
        src_masks_prob = src_masks.softmax(1)
        void_mask = pixel_gt_void_mask.to(src_masks_prob)  # B x H x W
        class_loss_weight = void_overlap_ratio(src_masks_prob, void_mask)
        class_loss_weight[src_idx] = torch.cat(list(matched_dice)).to(class_loss_weight)
        class_loss_weight = torch.clamp(class_loss_weight, min=self.eos_coef)

        processed_gt = {
            "masks": target_masks,
            "labels": target_classes,
            "pq_loss_mask_weight": mask_loss_weight,
            "pq_loss_class_weight": class_loss_weight,
            "pixel_gt_void_mask": pixel_gt_void_mask,
            "inverse_gt_mask_area": inverse_gt_mask_area,
        }
        if process_semantic:
            ground_truth_semantic = torch.stack([t["semantic_masks"] for t in targets], dim=0)  # B x H x W
            # num_classes doubles as the ignore label for the auxiliary semantic loss.
            ground_truth_semantic[ground_truth_semantic == -1] = self.num_classes
            processed_gt["ground_truth_semantic"] = ground_truth_semantic
        return processed_gt

    def forward(self, outputs, targets):
        """This performs the loss computation.

        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format.
                An optional ``"aux_outputs"`` entry holds a sequence of per-layer output dicts.
            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied, see each loss' doc.

        Returns:
            dict of scalar losses. Auxiliary losses are suffixed with ``_{i}``.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
        indices, matched_dice, matched_cls_prob = self.matcher(outputs_without_aux, targets)
        # Pad GT to the same number of predictions.
        processed_targets = self.process_gt(
            outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=True)

        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, processed_targets))

        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                # With share_final_matching, the final output's matching/targets are reused.
                if not self.share_final_matching:
                    indices, matched_dice, matched_cls_prob = self.matcher(aux_outputs, targets)
                    processed_targets = self.process_gt(
                        aux_outputs, targets, indices, matched_dice, matched_cls_prob)
                for loss in self.losses:
                    if loss == 'aux_semantic':
                        # Only for final output.
                        continue
                    l_dict = self.get_loss(loss, aux_outputs, processed_targets)
                    losses.update({k + f"_{i}": v for k, v in l_dict.items()})
        return losses

    def __repr__(self):
        head = "Criterion " + self.__class__.__name__
        try:
            # Mask2Former-style matchers accept an indentation keyword in __repr__.
            matcher_repr = self.matcher.__repr__(_repr_indent=8)
        except TypeError:
            # Any other matcher (e.g. a plain nn.Module) uses its ordinary repr.
            matcher_repr = repr(self.matcher)
        body = [
            "matcher: {}".format(matcher_repr),
            "losses: {}".format(self.losses),
            "weight_dict: {}".format(self.weight_dict),
            "num_classes: {}".format(self.num_classes),
            "eos_coef: {}".format(self.eos_coef),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)