Upload hybridnets/loss.py
hybridnets/loss.py
ADDED
@@ -0,0 +1,599 @@
import torch
import torch.nn as nn
import cv2
import numpy as np
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from utils.utils import postprocess, display, BBoxTransform, ClipBoxes
from typing import Optional, List
from functools import partial

BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"


def calc_iou(a, b):
    # a(anchor) [boxes, (y1, x1, y2, x2)]
    # b(gt, coco-style) [boxes, (x1, y1, x2, y2)]

    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    iw = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 1])
    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)
    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
    ua = torch.clamp(ua, min=1e-8)
    intersection = iw * ih
    IoU = intersection / ua

    return IoU
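
# Hedged sanity check, not part of the original file: the `_demo_calc_iou` name
# and the box values below are illustrative assumptions. It shows that an anchor
# identical to a ground-truth box scores IoU = 1, and highlights the mixed
# layouts: anchors are (y1, x1, y2, x2) while ground truth is (x1, y1, x2, y2).
def _demo_calc_iou():
    anchors = torch.tensor([[0., 0., 10., 10.],   # identical to the gt box
                            [0., 0., 5., 5.]])    # quarter-area overlap
    gt = torch.tensor([[0., 0., 10., 10.]])
    print(calc_iou(anchors, gt))  # tensor([[1.0000], [0.2500]])
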

class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, classifications, regressions, anchors, annotations, **kwargs):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]  # assuming all image sizes are the same, which they are
        dtype = anchors.dtype

        anchor_widths = anchor[:, 3] - anchor[:, 1]
        anchor_heights = anchor[:, 2] - anchor[:, 0]
        anchor_ctr_x = anchor[:, 1] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 0] + 0.5 * anchor_heights

        for j in range(batch_size):

            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            if bbox_annotation.shape[0] == 0:
                # No ground-truth boxes: every anchor is a negative, so only the
                # background term of the focal loss applies.
                if torch.cuda.is_available():
                    alpha_factor = torch.ones_like(classification) * alpha
                    alpha_factor = alpha_factor.cuda()
                    alpha_factor = 1. - alpha_factor
                    focal_weight = classification
                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

                    bce = -(torch.log(1.0 - classification))

                    cls_loss = focal_weight * bce

                    regression_losses.append(torch.tensor(0).to(dtype).cuda())
                    classification_losses.append(cls_loss.sum())
                else:
                    alpha_factor = torch.ones_like(classification) * alpha
                    alpha_factor = 1. - alpha_factor
                    focal_weight = classification
                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

                    bce = -(torch.log(1.0 - classification))

                    cls_loss = focal_weight * bce

                    regression_losses.append(torch.tensor(0).to(dtype))
                    classification_losses.append(cls_loss.sum())

                continue

            IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])

            IoU_max, IoU_argmax = torch.max(IoU, dim=1)

            # compute the loss for classification
            targets = torch.zeros_like(classification)

            if torch.cuda.is_available():
                targets = targets.cuda()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            positive_indices = torch.full_like(IoU_max, False, dtype=torch.bool)

            # Size-aware matching: boxes larger than 10x10 px need IoU >= 0.5 to
            # count as positives, while smaller boxes only need IoU >= 0.15.
            tensorA = (assigned_annotations[:, 2] - assigned_annotations[:, 0]) * (assigned_annotations[:, 3] - assigned_annotations[:, 1]) > 10 * 10

            positive_indices[torch.logical_or(torch.logical_and(tensorA, IoU_max >= 0.5),
                                              torch.logical_and(~tensorA, IoU_max >= 0.15))] = True

            num_positive_anchors = positive_indices.sum()

            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            alpha_factor = torch.ones_like(targets) * alpha
            if torch.cuda.is_available():
                alpha_factor = alpha_factor.cuda()

            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))

            cls_loss = focal_weight * bce

            zeros = torch.zeros_like(cls_loss)
            if torch.cuda.is_available():
                zeros = zeros.cuda()
            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)

            classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))

            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights

                gt_widths = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack((targets_dy, targets_dx, targets_dh, targets_dw))
                targets = targets.t()

                regression_diff = torch.abs(targets - regression[positive_indices, :])

                # Smooth L1 loss with the quadratic/linear transition at 1/9
                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0
                )
                regression_losses.append(regression_loss.mean())
            else:
                if torch.cuda.is_available():
                    regression_losses.append(torch.tensor(0).to(dtype).cuda())
                else:
                    regression_losses.append(torch.tensor(0).to(dtype))

        # debug visualization: draw the current detections if images were passed in
        imgs = kwargs.get('imgs', None)
        if imgs is not None:
            regressBoxes = BBoxTransform()
            clipBoxes = ClipBoxes()
            obj_list = kwargs.get('obj_list', None)
            out = postprocess(imgs.detach(),
                              torch.stack([anchors[0]] * imgs.shape[0], 0).detach(),
                              regressions.detach(), classifications.detach(),
                              regressBoxes, clipBoxes,
                              0.25, 0.3)
            imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
            imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255).astype(np.uint8)
            imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs]
            display(out, imgs, obj_list, imshow=False, imwrite=True)

        return torch.stack(classification_losses).mean(dim=0, keepdim=True), \
               torch.stack(regression_losses).mean(dim=0, keepdim=True) * 50  # https://github.com/google/automl/blob/6fdd1de778408625c1faf368a327fe36ecd41bf7/efficientdet/hparams_config.py#L233
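
# Hedged usage sketch, not part of the original file: the `_demo_focal_loss`
# name and all tensor shapes/values below are illustrative assumptions. It
# drives FocalLoss with random tensors shaped like the detection head outputs:
# anchors (1, A, 4) in (y1, x1, y2, x2), classifications (B, A, num_classes) in
# [0, 1], regressions (B, A, 4), annotations (B, max_objs, 5) padded with -1.
def _demo_focal_loss():
    torch.manual_seed(0)
    # The loss moves tensors with .cuda() whenever CUDA is available, so the
    # demo inputs must already live on the same device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_anchors, num_classes = 16, 1
    yx = torch.rand(1, num_anchors, 2, device=device) * 50
    hw = torch.rand(1, num_anchors, 2, device=device) * 50 + 1
    anchors = torch.cat([yx, yx + hw], dim=-1)          # valid (y1, x1, y2, x2)
    classifications = torch.rand(2, num_anchors, num_classes, device=device)
    regressions = torch.randn(2, num_anchors, 4, device=device)
    # image 0 has one gt box (x1, y1, x2, y2, class); image 1 is all padding
    annotations = torch.tensor([[[10., 10., 60., 60., 0.]],
                                [[-1., -1., -1., -1., -1.]]], device=device)
    cls_loss, reg_loss = FocalLoss()(classifications, regressions, anchors, annotations)
    print(cls_loss.item(), reg_loss.item())
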

def focal_loss_with_logits(
    output: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    alpha: Optional[float] = 0.25,
    reduction: str = "mean",
    normalized: bool = False,
    reduced_threshold: Optional[float] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Compute binary focal loss between target and output logits.
    See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.

    Args:
        output: Tensor of arbitrary shape (predictions of the model)
        target: Tensor of the same shape as input
        gamma: Focal loss power factor
        alpha: Weight factor to balance positive and negative samples. Alpha must be
            in the [0, 1] range; higher values give more weight to the positive class.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be
            applied, 'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed,
            'batchwise_mean': computes the mean loss per sample in the batch.
            Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).

    References:
        https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
    """
    target = target.type(output.type())

    logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
    pt = torch.exp(-logpt)

    # compute the loss
    if reduced_threshold is None:
        focal_term = (1.0 - pt).pow(gamma)
    else:
        focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
        focal_term[pt < reduced_threshold] = 1

    loss = focal_term * logpt

    if alpha is not None:
        loss *= alpha * target + (1 - alpha) * (1 - target)

    if normalized:
        norm_factor = focal_term.sum().clamp_min(eps)
        loss /= norm_factor

    if reduction == "mean":
        loss = loss.mean()
    if reduction == "sum":
        loss = loss.sum()
    if reduction == "batchwise_mean":
        loss = loss.sum(0)

    return loss
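
# Hedged sketch, not part of the original file: the `_demo_focal_loss_with_logits`
# name and the random shapes are illustrative assumptions. It compares the focal
# loss against plain BCE on the same logits; the focal term down-weights easy
# examples, so the focal value should come out smaller.
def _demo_focal_loss_with_logits():
    torch.manual_seed(0)
    logits = torch.randn(4, 1, 8, 8)
    target = (torch.rand(4, 1, 8, 8) > 0.5).float()
    fl = focal_loss_with_logits(logits, target, gamma=2.0, alpha=0.25)
    bce = F.binary_cross_entropy_with_logits(logits, target)
    print(f"focal={fl.item():.4f}  bce={bce.item():.4f}")
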

class FocalLossSeg(_Loss):
    def __init__(
        self,
        mode: str,
        alpha: Optional[float] = None,
        gamma: Optional[float] = 2.0,
        ignore_index: Optional[int] = None,
        reduction: Optional[str] = "mean",
        normalized: bool = False,
        reduced_threshold: Optional[float] = None,
    ):
        """Compute Focal loss

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            alpha: Prior probability of having a positive value in the target.
            gamma: Power factor for dampening weight (focal strength).
            ignore_index: If not None, targets may contain values to be ignored.
                Target values equal to ignore_index will be ignored from loss computation.
            normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
            reduced_threshold: Switch to reduced focal loss. Note, when using this mode
                you should use `reduction="sum"`.

        Shape
            - **y_pred** - torch.Tensor of shape (N, C, H, W)
            - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()

        self.mode = mode
        self.ignore_index = ignore_index
        self.focal_loss_fn = partial(
            focal_loss_with_logits,
            alpha=alpha,
            gamma=gamma,
            reduced_threshold=reduced_threshold,
            reduction=reduction,
            normalized=normalized,
        )

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
            y_true = y_true.view(-1)
            y_pred = y_pred.view(-1)

            if self.ignore_index is not None:
                # Filter predictions with ignore label from loss computation
                not_ignored = y_true != self.ignore_index
                y_pred = y_pred[not_ignored]
                y_true = y_true[not_ignored]

            loss = self.focal_loss_fn(y_pred, y_true)

        elif self.mode == MULTICLASS_MODE:
            num_classes = y_pred.size(1)
            loss = 0

            # Filter anchors with -1 label from loss computation
            if self.ignore_index is not None:
                not_ignored = y_true != self.ignore_index

            # y_true is expected one-hot here (N, C, H, W), so each channel is
            # treated as an independent binary problem
            for cls in range(num_classes):
                cls_y_true = y_true[:, cls, ...]
                cls_y_pred = y_pred[:, cls, ...]

                if self.ignore_index is not None:
                    cls_y_true = cls_y_true[not_ignored]
                    cls_y_pred = cls_y_pred[not_ignored]

                loss += self.focal_loss_fn(cls_y_pred, cls_y_true)

        return loss
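
# Hedged sketch, not part of the original file: the `_demo_focal_loss_seg` name
# and the 2-class random batch are illustrative assumptions. It shows the input
# convention this file's MULTICLASS branch expects: one-hot y_true with the
# same (N, C, H, W) layout as the logits.
def _demo_focal_loss_seg():
    torch.manual_seed(0)
    y_pred = torch.randn(2, 2, 16, 16)                       # raw logits
    y_true = F.one_hot(torch.randint(0, 2, (2, 16, 16)), 2)  # (N, H, W, C)
    y_true = y_true.permute(0, 3, 1, 2).float()              # -> (N, C, H, W)
    loss = FocalLossSeg(mode=MULTICLASS_MODE, alpha=0.25)(y_pred, y_true)
    print(loss.item())
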

def to_tensor(x, dtype=None) -> torch.Tensor:
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, (list, tuple)):
        x = np.array(x)
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    # Fail loudly instead of silently returning None for unsupported inputs
    raise ValueError(f"Unsupported input type: {type(x)}")

def soft_dice_score(
    output: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims=None,
) -> torch.Tensor:
    assert output.size() == target.size()
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)
        cardinality = torch.sum(output + target, dim=dims)
    else:
        intersection = torch.sum(output * target)
        cardinality = torch.sum(output + target)
    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
    return dice_score
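
# Hedged worked check, not part of the original file: the `_demo_soft_dice_score`
# name is an illustrative assumption. With identical prediction and target,
# intersection = 2 and cardinality = 4, so dice = (2 * 2) / 4 = 1.
def _demo_soft_dice_score():
    t = torch.tensor([[0., 1., 1., 0.]])
    print(soft_dice_score(t, t).item())  # 1.0
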

class DiceLoss(_Loss):
    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
    ):
        """Dice loss for image segmentation task.
        It supports binary, multiclass and multilabel cases

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            classes: List of classes that contribute to loss computation. By default, all channels are included.
            log_loss: If True, loss computed as `-log(dice_coeff)`, otherwise `1 - dice_coeff`
            from_logits: If True, assumes input is raw logits
            smooth: Smoothness constant for the dice coefficient
            ignore_index: Label that indicates ignored pixels (does not contribute to loss)
            eps: A small epsilon for numerical stability to avoid zero division error
                (denominator will always be greater than or equal to eps)

        Shape
            - **y_pred** - torch.Tensor of shape (N, C, H, W)
            - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super(DiceLoss, self).__init__()
        self.mode = mode
        if classes is not None:
            assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary"
            classes = to_tensor(classes, dtype=torch.long)

        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.eps = eps
        self.log_loss = log_loss
        self.ignore_index = ignore_index

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            # Apply activations to get [0..1] class probabilities.
            # Using Log-Exp gives a more numerically stable result and does not
            # cause vanishing gradients at the extreme values 0 and 1.
            if self.mode == MULTICLASS_MODE:
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = F.logsigmoid(y_pred).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        dims = (0, 2)

        if self.mode == BINARY_MODE:
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        if self.mode == MULTICLASS_MODE:
            # y_true is expected already one-hot, (N, C, H, W), matching y_pred
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

        if self.mode == MULTILABEL_MODE:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)

        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        # Dice loss is undefined for empty classes, so we zero the contribution
        # of channels that have no true pixels.
        # NOTE: A better workaround would be to use the loss term `mean(y_pred)`
        # for this case; however, that would be a modified jaccard loss.
        mask = y_true.sum(dims) > 0
        loss *= mask.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        return self.aggregate_loss(loss)

    def aggregate_loss(self, loss):
        return loss.mean()

    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
        return soft_dice_score(output, target, smooth, eps, dims)
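
# Hedged sketch, not part of the original file: the `_demo_dice_loss` name and
# the random binary batch are illustrative assumptions. It runs the binary path
# end to end: sigmoid via logsigmoid().exp(), flattening to (N, 1, H*W), and
# the soft dice score reduced over dims (0, 2).
def _demo_dice_loss():
    torch.manual_seed(0)
    y_pred = torch.randn(2, 1, 16, 16)                 # raw logits
    y_true = (torch.rand(2, 1, 16, 16) > 0.5).long()   # binary mask
    loss = DiceLoss(mode=BINARY_MODE, from_logits=True)(y_pred, y_true)
    print(loss.item())
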

def soft_tversky_score(
    output: torch.Tensor,
    target: torch.Tensor,
    alpha: float,
    beta: float,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims=None,
) -> torch.Tensor:
    assert output.size() == target.size()
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)  # TP
        fp = torch.sum(output * (1.0 - target), dim=dims)
        fn = torch.sum((1 - output) * target, dim=dims)
    else:
        intersection = torch.sum(output * target)  # TP
        fp = torch.sum(output * (1.0 - target))
        fn = torch.sum((1 - output) * target)

    tversky_score = (intersection + smooth) / (intersection + alpha * fp + beta * fn + smooth).clamp_min(eps)

    return tversky_score

class TverskyLoss(DiceLoss):
    """Tversky loss for image segmentation task.
    FP and FN are weighted by the alpha and beta params.
    With alpha == beta == 0.5, this loss becomes equal to DiceLoss.
    It supports binary, multiclass and multilabel cases

    Args:
        mode: Metric mode {'binary', 'multiclass', 'multilabel'}
        classes: Optional list of classes that contribute to loss computation;
            By default, all channels are included.
        log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky``
        from_logits: If True, assumes input is raw logits
        smooth: Smoothness constant for the Tversky score
        ignore_index: Label that indicates ignored pixels (does not contribute to loss)
        eps: Small epsilon for numerical stability
        alpha: Weight constant that penalizes the model for FPs (False Positives)
        beta: Weight constant that penalizes the model for FNs (False Negatives)
        gamma: Constant that raises the aggregated loss to a power. Defaults to ``1.0``

    Return:
        loss: torch.Tensor

    """

    def __init__(
        self,
        mode: str,
        classes: List[int] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
        alpha: float = 0.5,
        beta: float = 0.5,
        gamma: float = 1.0
    ):

        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps)
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def aggregate_loss(self, loss):
        return loss.mean() ** self.gamma

    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
        return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims)
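
# Hedged sketch, not part of the original file: the `_demo_tversky_loss` name
# and the alpha/beta choice are illustrative assumptions. With beta > alpha the
# Tversky score penalizes false negatives harder than false positives, which is
# the usual setting for recall-sensitive segmentation.
def _demo_tversky_loss():
    torch.manual_seed(0)
    y_pred = torch.randn(2, 1, 16, 16)
    y_true = (torch.rand(2, 1, 16, 16) > 0.5).long()
    loss = TverskyLoss(mode=BINARY_MODE, alpha=0.3, beta=0.7)(y_pred, y_true)
    print(loss.item())


if __name__ == "__main__":
    # Smoke-test the hedged demos above; run from the repo root (e.g.
    # `python -m hybridnets.loss`) so the `utils.utils` import resolves.
    _demo_calc_iou()
    _demo_focal_loss()
    _demo_focal_loss_with_logits()
    _demo_focal_loss_seg()
    _demo_soft_dice_score()
    _demo_dice_loss()
    _demo_tversky_loss()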