import torch
import torch.nn as nn
import cv2
import numpy as np
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from utils.utils import postprocess, display, BBoxTransform, ClipBoxes
from typing import Optional, List
from functools import partial
BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"
def calc_iou(a, b):
# a(anchor) [boxes, (y1, x1, y2, x2)]
# b(gt, coco-style) [boxes, (x1, y1, x2, y2)]
area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
iw = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 0])
ih = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 1])
iw = torch.clamp(iw, min=0)
ih = torch.clamp(ih, min=0)
ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
ua = torch.clamp(ua, min=1e-8)
intersection = iw * ih
IoU = intersection / ua
return IoU
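
# Usage sketch (illustrative): a square box reads the same in (y1, x1, y2, x2)
# and (x1, y1, x2, y2) order, so the matching pair below gives IoU 1 and the
# disjoint pair gives 0.
#   >>> a = torch.tensor([[0., 0., 10., 10.]])                        # anchor, (y1, x1, y2, x2)
#   >>> b = torch.tensor([[0., 0., 10., 10.], [20., 20., 30., 30.]])  # gt, (x1, y1, x2, y2)
#   >>> calc_iou(a, b)                                                # tensor([[1., 0.]])
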
class FocalLoss(nn.Module):
def __init__(self):
super(FocalLoss, self).__init__()
def forward(self, classifications, regressions, anchors, annotations, **kwargs):
alpha = 0.25
gamma = 2.0
batch_size = classifications.shape[0]
classification_losses = []
regression_losses = []
        anchor = anchors[0, :, :]  # assuming all image sizes are the same, which they are
dtype = anchors.dtype
anchor_widths = anchor[:, 3] - anchor[:, 1]
anchor_heights = anchor[:, 2] - anchor[:, 0]
anchor_ctr_x = anchor[:, 1] + 0.5 * anchor_widths
anchor_ctr_y = anchor[:, 0] + 0.5 * anchor_heights
for j in range(batch_size):
classification = classifications[j, :, :]
regression = regressions[j, :, :]
bbox_annotation = annotations[j]
bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
            if bbox_annotation.shape[0] == 0:
                # no ground truth in this image: every anchor is background, so
                # only the negative part of the focal loss applies
                alpha_factor = torch.ones_like(classification) * alpha
                regression_loss = torch.tensor(0).to(dtype)
                if torch.cuda.is_available():
                    alpha_factor = alpha_factor.cuda()
                    regression_loss = regression_loss.cuda()
                alpha_factor = 1. - alpha_factor
                focal_weight = classification
                focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
                bce = -(torch.log(1.0 - classification))
                cls_loss = focal_weight * bce
                regression_losses.append(regression_loss)
                classification_losses.append(cls_loss.sum())
                continue
IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
IoU_max, IoU_argmax = torch.max(IoU, dim=1)
# compute the loss for classification
#targets = torch.ones_like(classification) * -1
targets = torch.zeros_like(classification)
if torch.cuda.is_available():
targets = targets.cuda()
assigned_annotations = bbox_annotation[IoU_argmax, :]
            positive_indices = torch.full_like(IoU_max, False, dtype=torch.bool)
            # size-aware IoU thresholds: ground-truth boxes larger than 10x10 px
            # need IoU >= 0.5 with an anchor to count as positive, while smaller
            # boxes only need IoU >= 0.15
            large_boxes = (assigned_annotations[:, 2] - assigned_annotations[:, 0]) * (assigned_annotations[:, 3] - assigned_annotations[:, 1]) > 10 * 10
            positive_indices[torch.logical_or(torch.logical_and(large_boxes, IoU_max >= 0.5),
                                              torch.logical_and(~large_boxes, IoU_max >= 0.15))] = True
num_positive_anchors = positive_indices.sum()
targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1
alpha_factor = torch.ones_like(targets) * alpha
if torch.cuda.is_available():
alpha_factor = alpha_factor.cuda()
alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
cls_loss = focal_weight * bce
zeros = torch.zeros_like(cls_loss)
if torch.cuda.is_available():
zeros = zeros.cuda()
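            # with zero-initialized targets this mask is a no-op; it is kept from
            # the usual convention where anchors to ignore are labelled -1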
cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)
classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
if positive_indices.sum() > 0:
assigned_annotations = assigned_annotations[positive_indices, :]
anchor_widths_pi = anchor_widths[positive_indices]
anchor_heights_pi = anchor_heights[positive_indices]
anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
anchor_ctr_y_pi = anchor_ctr_y[positive_indices]
gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights
gt_widths = torch.clamp(gt_widths, min=1)
gt_heights = torch.clamp(gt_heights, min=1)
targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
targets_dw = torch.log(gt_widths / anchor_widths_pi)
targets_dh = torch.log(gt_heights / anchor_heights_pi)
targets = torch.stack((targets_dy, targets_dx, targets_dh, targets_dw))
targets = targets.t()
regression_diff = torch.abs(targets - regression[positive_indices, :])
regression_loss = torch.where(
torch.le(regression_diff, 1.0 / 9.0),
0.5 * 9.0 * torch.pow(regression_diff, 2),
regression_diff - 0.5 / 9.0
)
regression_losses.append(regression_loss.mean())
else:
if torch.cuda.is_available():
regression_losses.append(torch.tensor(0).to(dtype).cuda())
else:
regression_losses.append(torch.tensor(0).to(dtype))
# debug
imgs = kwargs.get('imgs', None)
if imgs is not None:
regressBoxes = BBoxTransform()
clipBoxes = ClipBoxes()
obj_list = kwargs.get('obj_list', None)
out = postprocess(imgs.detach(),
torch.stack([anchors[0]] * imgs.shape[0], 0).detach(), regressions.detach(), classifications.detach(),
regressBoxes, clipBoxes,
0.25, 0.3)
imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255).astype(np.uint8)
imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs]
display(out, imgs, obj_list, imshow=False, imwrite=True)
return torch.stack(classification_losses).mean(dim=0, keepdim=True), \
torch.stack(regression_losses).mean(dim=0, keepdim=True) * 50 # https://github.com/google/automl/blob/6fdd1de778408625c1faf368a327fe36ecd41bf7/efficientdet/hparams_config.py#L233
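
# Usage sketch (illustrative; the shapes are assumptions inferred from the call
# sites above, where annotations[j] is a (num_gt, 5) tensor of
# (x1, y1, x2, y2, cls) rows padded with cls == -1):
#   >>> criterion = FocalLoss()
#   >>> cls_loss, reg_loss = criterion(classifications,  # (N, num_anchors, num_classes)
#   ...                                regressions,      # (N, num_anchors, 4)
#   ...                                anchors,          # (N, num_anchors, 4) in (y1, x1, y2, x2)
#   ...                                annotations)
#   >>> loss = cls_loss.mean() + reg_loss.mean()
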
def focal_loss_with_logits(
output: torch.Tensor,
target: torch.Tensor,
gamma: float = 2.0,
alpha: Optional[float] = 0.25,
reduction: str = "mean",
normalized: bool = False,
reduced_threshold: Optional[float] = None,
eps: float = 1e-6,
) -> torch.Tensor:
"""Compute binary focal loss between target and output logits.
See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.
Args:
output: Tensor of arbitrary shape (predictions of the model)
target: Tensor of the same shape as input
gamma: Focal loss power factor
        alpha: Weight factor to balance positive and negative samples. Alpha must be
            in the [0, 1] range; higher values give more weight to the positive class.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be
            applied, 'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed,
            'batchwise_mean': the loss is summed over the batch dimension. Default: 'mean'
normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
References:
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
"""
target = target.type(output.type())
logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
pt = torch.exp(-logpt)
# compute the loss
if reduced_threshold is None:
focal_term = (1.0 - pt).pow(gamma)
else:
focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
focal_term[pt < reduced_threshold] = 1
loss = focal_term * logpt
if alpha is not None:
loss *= alpha * target + (1 - alpha) * (1 - target)
if normalized:
norm_factor = focal_term.sum().clamp_min(eps)
loss /= norm_factor
if reduction == "mean":
loss = loss.mean()
if reduction == "sum":
loss = loss.sum()
if reduction == "batchwise_mean":
loss = loss.sum(0)
return loss
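
# Usage sketch (illustrative):
#   >>> logits = torch.randn(4, 1, 64, 64)
#   >>> target = torch.randint(0, 2, (4, 1, 64, 64)).float()
#   >>> focal_loss_with_logits(logits, target, gamma=2.0, alpha=0.25)  # scalar (reduction="mean")
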
class FocalLossSeg(_Loss):
def __init__(
self,
mode: str,
alpha: Optional[float] = None,
gamma: Optional[float] = 2.0,
ignore_index: Optional[int] = None,
reduction: Optional[str] = "mean",
normalized: bool = False,
reduced_threshold: Optional[float] = None,
):
"""Compute Focal loss
Args:
mode: Loss mode 'binary', 'multiclass' or 'multilabel'
alpha: Prior probability of having positive value in target.
gamma: Power factor for dampening weight (focal strength).
ignore_index: If not None, targets may contain values to be ignored.
Target values equal to ignore_index will be ignored from loss computation.
normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
reduced_threshold: Switch to reduced focal loss. Note, when using this mode you
should use `reduction="sum"`.
Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
- **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)
Reference
https://github.com/BloodAxe/pytorch-toolbelt
"""
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super().__init__()
self.mode = mode
self.ignore_index = ignore_index
self.focal_loss_fn = partial(
focal_loss_with_logits,
alpha=alpha,
gamma=gamma,
reduced_threshold=reduced_threshold,
reduction=reduction,
normalized=normalized,
)
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
y_true = y_true.view(-1)
y_pred = y_pred.view(-1)
if self.ignore_index is not None:
# Filter predictions with ignore label from loss computation
not_ignored = y_true != self.ignore_index
y_pred = y_pred[not_ignored]
y_true = y_true[not_ignored]
loss = self.focal_loss_fn(y_pred, y_true)
elif self.mode == MULTICLASS_MODE:
num_classes = y_pred.size(1)
loss = 0
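            # NOTE: this differs from upstream pytorch-toolbelt, which builds
            # cls_y_true = (y_true == cls).long() from integer class indices;
            # here y_true is expected channel-wise with shape (N, C, ...)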
# Filter anchors with -1 label from loss computation
if self.ignore_index is not None:
not_ignored = y_true != self.ignore_index
for cls in range(num_classes):
# cls_y_true = (y_true == cls).long()
cls_y_true = y_true[:, cls, ...]
cls_y_pred = y_pred[:, cls, ...]
if self.ignore_index is not None:
cls_y_true = cls_y_true[not_ignored]
cls_y_pred = cls_y_pred[not_ignored]
loss += self.focal_loss_fn(cls_y_pred, cls_y_true)
return loss
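
# Usage sketch (illustrative): multilabel mode with per-channel binary masks.
#   >>> criterion = FocalLossSeg(mode=MULTILABEL_MODE, alpha=0.25)
#   >>> y_pred = torch.randn(2, 3, 128, 128)                    # raw logits
#   >>> y_true = torch.randint(0, 2, (2, 3, 128, 128)).float()
#   >>> criterion(y_pred, y_true)
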
def to_tensor(x, dtype=None) -> torch.Tensor:
if isinstance(x, torch.Tensor):
if dtype is not None:
x = x.type(dtype)
return x
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)
if dtype is not None:
x = x.type(dtype)
return x
    if isinstance(x, (list, tuple)):
        x = np.array(x)
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    raise ValueError(f"Unsupported input type: {type(x)}")
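
# Usage sketch (illustrative):
#   >>> to_tensor([0, 2], dtype=torch.long)  # tensor([0, 2])
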
def soft_dice_score(
output: torch.Tensor,
target: torch.Tensor,
smooth: float = 0.0,
eps: float = 1e-7,
dims=None,
) -> torch.Tensor:
assert output.size() == target.size()
if dims is not None:
intersection = torch.sum(output * target, dim=dims)
cardinality = torch.sum(output + target, dim=dims)
else:
intersection = torch.sum(output * target)
cardinality = torch.sum(output + target)
dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
return dice_score
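
# Worked example (illustrative): with 3 of 4 target pixels predicted,
# intersection = 3 and cardinality = 3 + 4 = 7, so dice = 2 * 3 / 7 ≈ 0.857.
#   >>> p = torch.tensor([[[1., 1., 1., 0.]]])  # (bs, C, H*W)
#   >>> t = torch.tensor([[[1., 1., 1., 1.]]])
#   >>> soft_dice_score(p, t, dims=(0, 2))      # tensor([0.8571])
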
class DiceLoss(_Loss):
def __init__(
self,
mode: str,
classes: Optional[List[int]] = None,
log_loss: bool = False,
from_logits: bool = True,
smooth: float = 0.0,
ignore_index: Optional[int] = None,
eps: float = 1e-7,
):
"""Dice loss for image segmentation task.
It supports binary, multiclass and multilabel cases
Args:
mode: Loss mode 'binary', 'multiclass' or 'multilabel'
        classes: List of classes that contribute to loss computation. By default, all channels are included.
log_loss: If True, loss computed as `- log(dice_coeff)`, otherwise `1 - dice_coeff`
from_logits: If True, assumes input is raw logits
        smooth: Smoothness constant for dice coefficient
ignore_index: Label that indicates ignored pixels (does not contribute to loss)
eps: A small epsilon for numerical stability to avoid zero division error
(denominator will be always greater or equal to eps)
Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
- **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)
Reference
https://github.com/BloodAxe/pytorch-toolbelt
"""
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super(DiceLoss, self).__init__()
self.mode = mode
if classes is not None:
assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary"
classes = to_tensor(classes, dtype=torch.long)
self.classes = classes
self.from_logits = from_logits
self.smooth = smooth
self.eps = eps
self.log_loss = log_loss
self.ignore_index = ignore_index
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
assert y_true.size(0) == y_pred.size(0)
if self.from_logits:
            # Apply activations to get [0..1] class probabilities.
            # Using log-exp as it gives a more numerically stable result and does
            # not cause vanishing gradients at the extreme values 0 and 1.
if self.mode == MULTICLASS_MODE:
y_pred = y_pred.log_softmax(dim=1).exp()
else:
y_pred = F.logsigmoid(y_pred).exp()
# print("AFTER: ", y_pred)
bs = y_true.size(0)
num_classes = y_pred.size(1)
dims = (0, 2)
if self.mode == BINARY_MODE:
y_true = y_true.view(bs, 1, -1)
y_pred = y_pred.view(bs, 1, -1)
if self.ignore_index is not None:
mask = y_true != self.ignore_index
y_pred = y_pred * mask
y_true = y_true * mask
        if self.mode == MULTICLASS_MODE:
            # NOTE: unlike the upstream pytorch-toolbelt implementation, which
            # one-hot encodes integer targets (and applies ignore_index) here,
            # this variant expects y_true already channel-wise as (N, C, H, W)
            # and does not apply ignore_index in this branch
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)
if self.mode == MULTILABEL_MODE:
y_true = y_true.view(bs, num_classes, -1)
y_pred = y_pred.view(bs, num_classes, -1)
if self.ignore_index is not None:
mask = y_true != self.ignore_index
y_pred = y_pred * mask
y_true = y_true * mask
scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)
if self.log_loss:
loss = -torch.log(scores.clamp_min(self.eps))
else:
loss = 1.0 - scores
        # Dice loss is undefined for empty classes, so we zero out the
        # contribution of channels that have no true pixels.
        # NOTE: A better workaround would be to use the loss term `mean(y_pred)`
        # for this case; however, that would make it a modified jaccard loss.
mask = y_true.sum(dims) > 0
loss *= mask.to(loss.dtype)
if self.classes is not None:
loss = loss[self.classes]
return self.aggregate_loss(loss)
def aggregate_loss(self, loss):
return loss.mean()
def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
return soft_dice_score(output, target, smooth, eps, dims)
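
# Usage sketch (illustrative): binary segmentation from raw logits.
#   >>> criterion = DiceLoss(mode=BINARY_MODE)
#   >>> y_pred = torch.randn(2, 1, 128, 128)             # raw logits
#   >>> y_true = torch.randint(0, 2, (2, 1, 128, 128))
#   >>> criterion(y_pred, y_true)
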
def soft_tversky_score(
output: torch.Tensor,
target: torch.Tensor,
alpha: float,
beta: float,
smooth: float = 0.0,
eps: float = 1e-7,
dims=None,
) -> torch.Tensor:
assert output.size() == target.size()
if dims is not None:
intersection = torch.sum(output * target, dim=dims) # TP
fp = torch.sum(output * (1.0 - target), dim=dims)
fn = torch.sum((1 - output) * target, dim=dims)
else:
intersection = torch.sum(output * target) # TP
fp = torch.sum(output * (1.0 - target))
fn = torch.sum((1 - output) * target)
tversky_score = (intersection + smooth) / (intersection + alpha * fp + beta * fn + smooth).clamp_min(eps)
return tversky_score
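
# Note: with alpha == beta == 0.5 and smooth == 0 the denominator becomes
# TP + (FP + FN) / 2 = (2*TP + FP + FN) / 2, so the score reduces to the soft
# dice score above.
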
class TverskyLoss(DiceLoss):
"""Tversky loss for image segmentation task.
Where TP and FP is weighted by alpha and beta params.
With alpha == beta == 0.5, this loss becomes equal DiceLoss.
It supports binary, multiclass and multilabel cases
Args:
mode: Metric mode {'binary', 'multiclass', 'multilabel'}
        classes: Optional list of classes that contribute to loss computation;
            by default, all channels are included.
        log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky``
        from_logits: If True, assumes input is raw logits
        smooth: Smoothness constant for tversky coefficient
ignore_index: Label that indicates ignored pixels (does not contribute to loss)
eps: Small epsilon for numerical stability
        alpha: Weight constant that penalizes the model for FPs (False Positives)
        beta: Weight constant that penalizes the model for FNs (False Negatives)
        gamma: Power to which the aggregated loss is raised. Defaults to ``1.0``
Return:
loss: torch.Tensor
"""
def __init__(
self,
mode: str,
classes: List[int] = None,
log_loss: bool = False,
from_logits: bool = True,
smooth: float = 0.0,
ignore_index: Optional[int] = None,
eps: float = 1e-7,
alpha: float = 0.5,
beta: float = 0.5,
gamma: float = 1.0
):
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps)
self.alpha = alpha
self.beta = beta
self.gamma = gamma
def aggregate_loss(self, loss):
return loss.mean() ** self.gamma
def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims)
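
# Usage sketch (illustrative): weighting false negatives more heavily than
# false positives (beta > alpha) to favour recall.
#   >>> criterion = TverskyLoss(mode=MULTILABEL_MODE, alpha=0.3, beta=0.7)
#   >>> y_pred = torch.randn(2, 3, 128, 128)
#   >>> y_true = torch.randint(0, 2, (2, 3, 128, 128)).float()
#   >>> criterion(y_pred, y_true)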