josedolot committed
Commit c432040 · 1 Parent(s): 309a856

Upload hybridnets/loss.py

Files changed (1):
  1. hybridnets/loss.py  +599 -0
hybridnets/loss.py ADDED
import torch
import torch.nn as nn
import cv2
import numpy as np
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from utils.utils import postprocess, display, BBoxTransform, ClipBoxes
from typing import Optional, List
from functools import partial

BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"

def calc_iou(a, b):
    # a (anchors):              [boxes, (y1, x1, y2, x2)]
    # b (ground truth, coco-style): [boxes, (x1, y1, x2, y2)]

    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    iw = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 1])
    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)
    # union = anchor area + gt area - intersection, clamped to avoid division by zero
    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
    ua = torch.clamp(ua, min=1e-8)
    intersection = iw * ih
    IoU = intersection / ua

    return IoU

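# A quick illustrative check (not part of the original file): identical boxes
# give IoU 1 and disjoint boxes give 0. Note the transposed corner order:
# anchors are (y1, x1, y2, x2) while ground truth is (x1, y1, x2, y2).
#
#   anchors = torch.tensor([[0., 0., 10., 10.]])      # (y1, x1, y2, x2)
#   gt = torch.tensor([[0., 0., 10., 10.],            # (x1, y1, x2, y2)
#                      [20., 20., 30., 30.]])
#   calc_iou(anchors, gt)  # -> tensor([[1., 0.]]), shape [num_anchors, num_gt]
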
class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, classifications, regressions, anchors, annotations, **kwargs):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]  # assuming all image sizes are the same, which it is
        dtype = anchors.dtype

        anchor_widths = anchor[:, 3] - anchor[:, 1]
        anchor_heights = anchor[:, 2] - anchor[:, 0]
        anchor_ctr_x = anchor[:, 1] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 0] + 0.5 * anchor_heights

        for j in range(batch_size):
            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]  # drop -1-padded annotations

            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            if bbox_annotation.shape[0] == 0:
                # No ground-truth boxes: every anchor is a negative, so only the
                # background term of the focal loss contributes.
                alpha_factor = torch.ones_like(classification) * alpha
                if torch.cuda.is_available():
                    alpha_factor = alpha_factor.cuda()
                alpha_factor = 1. - alpha_factor
                focal_weight = alpha_factor * torch.pow(classification, gamma)

                bce = -(torch.log(1.0 - classification))
                cls_loss = focal_weight * bce

                zero = torch.tensor(0).to(dtype)
                if torch.cuda.is_available():
                    zero = zero.cuda()
                regression_losses.append(zero)
                classification_losses.append(cls_loss.sum())
                continue

            IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
            IoU_max, IoU_argmax = torch.max(IoU, dim=1)  # best-matching ground-truth box per anchor

            # compute the loss for classification
            targets = torch.zeros_like(classification)
            if torch.cuda.is_available():
                targets = targets.cuda()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            # Adaptive assignment of positive anchors: anchors matched to boxes
            # larger than 10x10 px must reach IoU >= 0.5, while anchors matched
            # to smaller boxes only need IoU >= 0.15.
            positive_indices = torch.full_like(IoU_max, False, dtype=torch.bool)
            is_large_box = (assigned_annotations[:, 2] - assigned_annotations[:, 0]) * \
                           (assigned_annotations[:, 3] - assigned_annotations[:, 1]) > 10 * 10
            positive_indices[torch.logical_or(torch.logical_and(is_large_box, IoU_max >= 0.5),
                                              torch.logical_and(~is_large_box, IoU_max >= 0.15))] = True

            num_positive_anchors = positive_indices.sum()

            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            alpha_factor = torch.ones_like(targets) * alpha
            if torch.cuda.is_available():
                alpha_factor = alpha_factor.cuda()

            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
            cls_loss = focal_weight * bce

            # targets never take the value -1 here, so this mask is a no-op kept
            # from the reference RetinaNet implementation, where -1 marked
            # ignored anchors.
            zeros = torch.zeros_like(cls_loss)
            if torch.cuda.is_available():
                zeros = zeros.cuda()
            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)

            classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))

            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights

                gt_widths = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                # Regression targets in the usual (dy, dx, dh, dw) anchor-offset encoding
                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack((targets_dy, targets_dx, targets_dh, targets_dw))
                targets = targets.t()

                # Smooth L1 loss with beta = 1/9
                regression_diff = torch.abs(targets - regression[positive_indices, :])
                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0
                )
                regression_losses.append(regression_loss.mean())
            else:
                zero = torch.tensor(0).to(dtype)
                if torch.cuda.is_available():
                    zero = zero.cuda()
                regression_losses.append(zero)

        # debug: visualize current detections when images are passed in
        imgs = kwargs.get('imgs', None)
        if imgs is not None:
            regressBoxes = BBoxTransform()
            clipBoxes = ClipBoxes()
            obj_list = kwargs.get('obj_list', None)
            out = postprocess(imgs.detach(),
                              torch.stack([anchors[0]] * imgs.shape[0], 0).detach(),
                              regressions.detach(), classifications.detach(),
                              regressBoxes, clipBoxes,
                              0.25, 0.3)
            imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
            imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255).astype(np.uint8)
            imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs]
            display(out, imgs, obj_list, imshow=False, imwrite=True)

        # The 50x weight on the box loss follows EfficientDet:
        # https://github.com/google/automl/blob/6fdd1de778408625c1faf368a327fe36ecd41bf7/efficientdet/hparams_config.py#L233
        return torch.stack(classification_losses).mean(dim=0, keepdim=True), \
               torch.stack(regression_losses).mean(dim=0, keepdim=True) * 50

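# Illustrative call shapes (an assumption for this sketch, not stated in the
# file): with B images, A anchors and C classes, the detection head would supply
#   classifications: [B, A, C] sigmoid probabilities (the loss clamps and logs them directly)
#   regressions:     [B, A, 4] predicted (dy, dx, dh, dw) offsets
#   anchors:         [B, A, 4] in (y1, x1, y2, x2) order
#   annotations:     [B, max_boxes, 5] rows of (x1, y1, x2, y2, class), padded with -1
#
#   cls_loss, reg_loss = FocalLoss()(classifications, regressions, anchors, annotations)
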
def focal_loss_with_logits(
    output: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    alpha: Optional[float] = 0.25,
    reduction: str = "mean",
    normalized: bool = False,
    reduced_threshold: Optional[float] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Compute binary focal loss between target and output logits.

    See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.

    Args:
        output: Tensor of arbitrary shape (predictions of the model)
        target: Tensor of the same shape as output
        gamma: Focal loss power factor
        alpha: Weight factor to balance positive and negative samples. Alpha must be
            in the [0, 1] range; high values give more weight to the positive class.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction is applied;
            'mean': the sum of the output is divided by the number of elements in the
            output; 'sum': the output is summed; 'batchwise_mean': computes the mean
            loss per sample in the batch. Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).

    References:
        https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
    """
    target = target.type(output.type())

    logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
    pt = torch.exp(-logpt)  # pt is the probability assigned to the true class

    # compute the modulating (focal) term
    if reduced_threshold is None:
        focal_term = (1.0 - pt).pow(gamma)
    else:
        focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
        focal_term[pt < reduced_threshold] = 1

    loss = focal_term * logpt

    if alpha is not None:
        loss *= alpha * target + (1 - alpha) * (1 - target)

    if normalized:
        norm_factor = focal_term.sum().clamp_min(eps)
        loss /= norm_factor

    if reduction == "mean":
        loss = loss.mean()
    if reduction == "sum":
        loss = loss.sum()
    if reduction == "batchwise_mean":
        loss = loss.sum(0)

    return loss

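# Minimal sketch (illustrative, not part of the original file):
#   logits = torch.tensor([2.0, -1.0, 0.5])
#   target = torch.tensor([1.0, 0.0, 1.0])
#   focal_loss_with_logits(logits, target)                    # scalar, 'mean' reduction
#   focal_loss_with_logits(logits, target, reduction="none")  # per-element losses
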
class FocalLossSeg(_Loss):
    def __init__(
        self,
        mode: str,
        alpha: Optional[float] = None,
        gamma: Optional[float] = 2.0,
        ignore_index: Optional[int] = None,
        reduction: Optional[str] = "mean",
        normalized: bool = False,
        reduced_threshold: Optional[float] = None,
    ):
        """Compute focal loss for segmentation.

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            alpha: Prior probability of having a positive value in the target.
            gamma: Power factor for dampening weight (focal strength).
            ignore_index: If not None, targets may contain values to be ignored.
                Target values equal to ignore_index are excluded from the loss computation.
            normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
            reduced_threshold: Switch to reduced focal loss. Note: when using this mode,
                you should use `reduction="sum"`.

        Shape
            - **y_pred** - torch.Tensor of shape (N, C, H, W)
            - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()

        self.mode = mode
        self.ignore_index = ignore_index
        self.focal_loss_fn = partial(
            focal_loss_with_logits,
            alpha=alpha,
            gamma=gamma,
            reduced_threshold=reduced_threshold,
            reduction=reduction,
            normalized=normalized,
        )

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
            y_true = y_true.view(-1)
            y_pred = y_pred.view(-1)

            if self.ignore_index is not None:
                # Filter predictions with the ignore label from the loss computation
                not_ignored = y_true != self.ignore_index
                y_pred = y_pred[not_ignored]
                y_true = y_true[not_ignored]

            loss = self.focal_loss_fn(y_pred, y_true)

        elif self.mode == MULTICLASS_MODE:
            num_classes = y_pred.size(1)
            loss = 0

            # Filter targets with the ignore label from the loss computation
            if self.ignore_index is not None:
                not_ignored = y_true != self.ignore_index

            # Accumulate binary focal loss channel by channel; y_true is
            # expected to be one-hot encoded with shape (N, C, H, W) here.
            for cls in range(num_classes):
                cls_y_true = y_true[:, cls, ...]
                cls_y_pred = y_pred[:, cls, ...]

                if self.ignore_index is not None:
                    cls_y_true = cls_y_true[not_ignored]
                    cls_y_pred = cls_y_pred[not_ignored]

                loss += self.focal_loss_fn(cls_y_pred, cls_y_true)

        return loss

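# Hedged usage sketch (not in the original file): the multiclass branch above
# indexes y_true[:, cls, ...], so it appears to expect one-hot targets.
#   seg_criterion = FocalLossSeg(mode=MULTICLASS_MODE, alpha=0.25)
#   y_pred = torch.randn(2, 3, 64, 64)  # raw logits, (N, C, H, W)
#   y_true = F.one_hot(torch.randint(0, 3, (2, 64, 64)), 3).permute(0, 3, 1, 2).float()
#   loss = seg_criterion(y_pred, y_true)
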
def to_tensor(x, dtype=None) -> torch.Tensor:
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, (list, tuple)):
        x = np.array(x)
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    # fail loudly instead of silently returning None for unsupported inputs
    raise ValueError("Unsupported input type: " + str(type(x)))

def soft_dice_score(
    output: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims=None,
) -> torch.Tensor:
    """Soft dice score: (2 * |X ∩ Y| + smooth) / (|X| + |Y| + smooth)."""
    assert output.size() == target.size()
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)
        cardinality = torch.sum(output + target, dim=dims)
    else:
        intersection = torch.sum(output * target)
        cardinality = torch.sum(output + target)
    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
    return dice_score

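# Worked check (illustrative, not part of the original file): perfect overlap
# gives a score of 1.
#   p = torch.tensor([1.0, 1.0, 0.0])
#   t = torch.tensor([1.0, 1.0, 0.0])
#   soft_dice_score(p, t)  # -> tensor(1.), since 2*2 / (2+2) = 1
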
class DiceLoss(_Loss):
    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
    ):
        """Dice loss for the image segmentation task.
        Supports binary, multiclass and multilabel cases.

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            classes: List of classes that contribute to the loss computation. By default, all channels are included.
            log_loss: If True, the loss is computed as `-log(dice_coeff)`, otherwise as `1 - dice_coeff`
            from_logits: If True, assumes the input is raw logits
            smooth: Smoothness constant for the dice coefficient
            ignore_index: Label that indicates ignored pixels (they do not contribute to the loss)
            eps: A small epsilon for numerical stability to avoid division by zero
                (the denominator will always be greater than or equal to eps)

        Shape
            - **y_pred** - torch.Tensor of shape (N, C, H, W)
            - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super(DiceLoss, self).__init__()
        self.mode = mode
        if classes is not None:
            assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary"
            classes = to_tensor(classes, dtype=torch.long)

        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.eps = eps
        self.log_loss = log_loss
        self.ignore_index = ignore_index

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            # Apply activations to get [0..1] class probabilities.
            # Exponentiating log-probabilities is more numerically stable and
            # avoids vanishing gradients at the extreme values 0 and 1.
            if self.mode == MULTICLASS_MODE:
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = F.logsigmoid(y_pred).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        dims = (0, 2)  # aggregate over batch and spatial dims -> one score per channel

        if self.mode == BINARY_MODE:
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        if self.mode == MULTICLASS_MODE:
            # Unlike the original pytorch-toolbelt version, y_true is expected
            # to already be one-hot encoded with shape (N, C, H, W), so no
            # F.one_hot conversion is performed here.
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

        if self.mode == MULTILABEL_MODE:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)

        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        # The dice loss is undefined for empty classes, so we zero out the
        # contribution of any channel that has no true pixels.
        # NOTE: A better workaround would be to use the loss term `mean(y_pred)`
        # for this case; however, that would be a modified jaccard loss.
        mask = y_true.sum(dims) > 0
        loss *= mask.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        return self.aggregate_loss(loss)

    def aggregate_loss(self, loss):
        return loss.mean()

    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
        return soft_dice_score(output, target, smooth, eps, dims)

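# Hedged usage sketch (not in the original file): binary dice on raw logits.
#   dice = DiceLoss(mode=BINARY_MODE, from_logits=True)
#   y_pred = torch.randn(2, 1, 8, 8)             # raw logits
#   y_true = torch.randint(0, 2, (2, 1, 8, 8))   # {0, 1} mask
#   loss = dice(y_pred, y_true)                  # scalar, 1 - soft dice
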
def soft_tversky_score(
    output: torch.Tensor,
    target: torch.Tensor,
    alpha: float,
    beta: float,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims=None,
) -> torch.Tensor:
    """Soft Tversky index: (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)."""
    assert output.size() == target.size()
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)  # TP
        fp = torch.sum(output * (1.0 - target), dim=dims)
        fn = torch.sum((1 - output) * target, dim=dims)
    else:
        intersection = torch.sum(output * target)  # TP
        fp = torch.sum(output * (1.0 - target))
        fn = torch.sum((1 - output) * target)

    tversky_score = (intersection + smooth) / (intersection + alpha * fp + beta * fn + smooth).clamp_min(eps)

    return tversky_score

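# Worked check (illustrative, not part of the original file): with
# alpha = beta = 0.5 the Tversky index reduces to the dice score, since
# TP / (TP + 0.5*(FP + FN)) == 2*TP / (2*TP + FP + FN).
#   p = torch.tensor([1.0, 1.0, 0.0, 1.0])
#   t = torch.tensor([1.0, 0.0, 1.0, 1.0])
#   soft_tversky_score(p, t, alpha=0.5, beta=0.5)  # -> tensor(0.6667)
#   soft_dice_score(p, t)                          # -> tensor(0.6667)
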
class TverskyLoss(DiceLoss):
    """Tversky loss for the image segmentation task.
    FP and FN are weighted by the alpha and beta params.
    With alpha == beta == 0.5, this loss becomes equal to DiceLoss.
    Supports binary, multiclass and multilabel cases.

    Args:
        mode: Metric mode {'binary', 'multiclass', 'multilabel'}
        classes: Optional list of classes that contribute to the loss computation;
            by default, all channels are included.
        log_loss: If True, the loss is computed as ``-log(tversky)``, otherwise as ``1 - tversky``
        from_logits: If True, assumes the input is raw logits
        smooth: Smoothness constant for the Tversky index
        ignore_index: Label that indicates ignored pixels (they do not contribute to the loss)
        eps: Small epsilon for numerical stability
        alpha: Weight constant that penalizes the model for FPs (False Positives)
        beta: Weight constant that penalizes the model for FNs (False Negatives)
        gamma: Power to which the aggregated loss is raised. Defaults to ``1.0``

    Return:
        loss: torch.Tensor

    """

    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
        alpha: float = 0.5,
        beta: float = 0.5,
        gamma: float = 1.0
    ):
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps)
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def aggregate_loss(self, loss):
        return loss.mean() ** self.gamma

    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
        return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims)
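
# Hedged sketch (not in the original file): raising the mean Tversky loss to a
# power gamma > 1, as aggregate_loss does, sharpens the focus on hard examples,
# loosely in the spirit of the focal Tversky loss. Choosing beta > alpha
# penalizes false negatives more heavily (favours recall).
#   tversky = TverskyLoss(mode=BINARY_MODE, alpha=0.3, beta=0.7, gamma=4/3)
#   loss = tversky(torch.randn(2, 1, 8, 8), torch.randint(0, 2, (2, 1, 8, 8)))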