josedolot committed
Commit 0dde52f · 1 Parent(s): 15eca91

Upload utils/smp_metrics.py

Files changed (1)
  1. utils/smp_metrics.py +758 -0
utils/smp_metrics.py ADDED
@@ -0,0 +1,758 @@
+ """Various metrics based on Type I and Type II errors.
+
+ References:
+     https://en.wikipedia.org/wiki/Confusion_matrix
+
+
+ Example:
+
+     .. code-block:: python
+
+         import torch
+         import segmentation_models_pytorch as smp
+
+         # let's assume we have a multilabel prediction for 3 classes
+         output = torch.rand([10, 3, 256, 256])
+         target = torch.rand([10, 3, 256, 256]).round().long()
+
+         # first compute statistics for true positive, false positive, false negative
+         # and true negative "pixels"
+         tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multilabel', threshold=0.5)
+
+         # then compute metrics with the required reduction (see metric docs)
+         iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
+         f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
+         f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
+         accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
+         recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
+
+ """
+ import torch
+ import warnings
+ from typing import Optional, List, Tuple, Union
+
+
+ __all__ = [
+     "get_stats",
+     "fbeta_score",
+     "f1_score",
+     "iou_score",
+     "accuracy",
+     "precision",
+     "recall",
+     "sensitivity",
+     "specificity",
+     "balanced_accuracy",
+     "positive_predictive_value",
+     "negative_predictive_value",
+     "false_negative_rate",
+     "false_positive_rate",
+     "false_discovery_rate",
+     "false_omission_rate",
+     "positive_likelihood_ratio",
+     "negative_likelihood_ratio",
+ ]
+
+
+ ###################################################################################################
+ # Statistics computation (true positives, false positives, false negatives, true negatives)
+ ###################################################################################################
+
+
+ def get_stats(
+     output: Union[torch.LongTensor, torch.FloatTensor],
+     target: torch.LongTensor,
+     mode: str,
+     ignore_index: Optional[int] = None,
+     threshold: Optional[Union[float, List[float]]] = None,
+     num_classes: Optional[int] = None,
+ ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
+     """Compute true positive, false positive, false negative, true negative 'pixels'
+     for each image and each class.
+
+     Args:
+         output (Union[torch.LongTensor, torch.FloatTensor]): Model output with the following
+             shapes and types, depending on the specified ``mode``:
+
+             'binary'
+                 shape (N, 1, ...) and ``torch.LongTensor`` or ``torch.FloatTensor``
+
+             'multilabel'
+                 shape (N, C, ...) and ``torch.LongTensor`` or ``torch.FloatTensor``
+
+             'multiclass'
+                 shape (N, ...) and ``torch.LongTensor``
+
+         target (torch.LongTensor): Targets with the following shapes, depending on the specified ``mode``:
+
+             'binary'
+                 shape (N, 1, ...)
+
+             'multilabel'
+                 shape (N, C, ...)
+
+             'multiclass'
+                 shape (N, ...)
+
+         mode (str): One of ``'binary'`` | ``'multilabel'`` | ``'multiclass'``
+         ignore_index (Optional[int]): Label to ignore for metric computation.
+             **Not** supported for ``'binary'`` and ``'multilabel'`` modes. Defaults to None.
+         threshold (Optional[Union[float, List[float]]]): Binarization threshold for
+             ``output`` in case of ``'binary'`` or ``'multilabel'`` modes. Defaults to None.
+         num_classes (Optional[int]): Number of classes; required only for
+             ``'multiclass'`` mode.
+
+     Raises:
+         ValueError: in case of misconfiguration.
+
+     Returns:
+         Tuple of true_positive, false_positive, false_negative,
+         true_negative tensors, each of shape (N, C).
+
+     """
+
+     if torch.is_floating_point(target):
+         raise ValueError(f"Target should be one of the integer types, got {target.dtype}.")
+
+     if torch.is_floating_point(output) and threshold is None:
+         raise ValueError(
+             f"Output should be one of the integer types if ``threshold`` is not provided, got {output.dtype}."
+         )
+
+     if torch.is_floating_point(output) and mode == "multiclass":
+         raise ValueError(f"For ``multiclass`` mode ``output`` should be one of the integer types, got {output.dtype}.")
+
+     if mode not in {"binary", "multiclass", "multilabel"}:
+         raise ValueError(f"``mode`` should be in ['binary', 'multiclass', 'multilabel'], got mode={mode}.")
+
+     if mode == "multiclass" and threshold is not None:
+         raise ValueError("``threshold`` parameter is not supported for 'multiclass' mode")
+
+     if output.shape != target.shape:
+         raise ValueError(
+             "Dimensions should match, but ``output`` shape is not equal to ``target`` "
+             + f"shape, {output.shape} != {target.shape}"
+         )
+
+     if mode != "multiclass" and ignore_index is not None:
+         raise ValueError(f"``ignore_index`` parameter is not supported for '{mode}' mode")
+
+     if mode == "multiclass" and num_classes is None:
+         raise ValueError("``num_classes`` attribute should not be ``None`` for 'multiclass' mode.")
+
+     if mode == "multiclass":
+         if ignore_index is not None:
+             ignore = target == ignore_index
+             output = torch.where(ignore, -1, output)
+             target = torch.where(ignore, -1, target)
+         tp, fp, fn, tn = _get_stats_multiclass(output, target, num_classes)
+     else:
+         if threshold is not None:
+             output = torch.where(output >= threshold, 1, 0)
+             target = torch.where(target >= threshold, 1, 0)
+         tp, fp, fn, tn = _get_stats_multilabel(output, target)
+
+     return tp, fp, fn, tn
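+
+ # A minimal usage sketch for 'multiclass' mode (the shapes, class count and ignore
+ # value below are illustrative assumptions, not part of this module):
+ #
+ #     preds = torch.randint(0, 5, (8, 64, 64))     # (N, H, W) integer class map
+ #     labels = torch.randint(0, 5, (8, 64, 64))    # (N, H, W) integer targets
+ #     tp, fp, fn, tn = get_stats(preds, labels, mode="multiclass",
+ #                                num_classes=5, ignore_index=255)
+ #     # each returned tensor has shape (8, 5): one row per image, one column per class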
+
+
+ @torch.no_grad()
+ def _get_stats_multiclass(
+     output: torch.LongTensor,
+     target: torch.LongTensor,
+     num_classes: int,
+ ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
+
+     batch_size, *dims = output.shape
+     num_elements = torch.prod(torch.tensor(dims)).long()
+
+     tp_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
+     fp_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
+     fn_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
+     tn_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
+
+     for i in range(batch_size):
+         target_i = target[i]
+         output_i = output[i]
+         # keep the class label where the prediction is correct and map mismatches
+         # to -1, which falls outside the histogram range below; multiplying by the
+         # match mask instead would wrongly count every mismatch as a class-0 hit
+         matched = torch.where(output_i == target_i, target_i, -1)
+         tp = torch.histc(matched.float(), bins=num_classes, min=0, max=num_classes - 1)
+         fp = torch.histc(output_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
+         fn = torch.histc(target_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
+         tn = num_elements - tp - fp - fn
+         tp_count[i] = tp.long()
+         fp_count[i] = fp.long()
+         fn_count[i] = fn.long()
+         tn_count[i] = tn.long()
+
+     return tp_count, fp_count, fn_count, tn_count
+
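+ # Worked toy example for the histogram trick above (values are illustrative):
+ # with num_classes=3, output_i = [0, 1, 2, 2] and target_i = [0, 1, 1, 2],
+ # matched = [0, 1, -1, 2] and the mismatch (-1) is ignored by the histogram,
+ # so tp = [1, 1, 1]; the output histogram [1, 1, 2] gives fp = [0, 0, 1], the
+ # target histogram [1, 2, 1] gives fn = [0, 1, 0], and tn = 4 - tp - fp - fn = [3, 2, 2].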
+
+ @torch.no_grad()
+ def _get_stats_multilabel(
+     output: torch.LongTensor,
+     target: torch.LongTensor,
+ ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
+
+     batch_size, num_classes, *dims = target.shape
+     output = output.view(batch_size, num_classes, -1)
+     target = target.view(batch_size, num_classes, -1)
+
+     tp = (output * target).sum(2)
+     fp = output.sum(2) - tp
+     fn = target.sum(2) - tp
+     tn = torch.prod(torch.tensor(dims)) - (tp + fp + fn)
+
+     return tp, fp, fn, tn
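+
+ # Toy check for the multilabel counts above (values are illustrative): for one
+ # image and one class with output = [1, 0, 1, 1] and target = [1, 1, 0, 1],
+ # tp = 2, fp = 3 - 2 = 1, fn = 3 - 2 = 1, and tn = 4 - (2 + 1 + 1) = 0.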
+
+
+ ###################################################################################################
+ # Metrics computation
+ ###################################################################################################
+
+
+ def _handle_zero_division(x, zero_division):
+     nans = torch.isnan(x)
+     if torch.any(nans) and zero_division == "warn":
+         warnings.warn("Zero division in metric calculation!")
+     value = zero_division if zero_division != "warn" else 0
+     value = torch.tensor(value, dtype=x.dtype).to(x.device)
+     x = torch.where(nans, value, x)
+     return x
+
+
+ def _compute_metric(
+     metric_fn,
+     tp,
+     fp,
+     fn,
+     tn,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division="warn",
+     **metric_kwargs,
+ ) -> torch.Tensor:
+
+     if class_weights is None and reduction is not None and "weighted" in reduction:
+         raise ValueError(f"Class weights should be provided for `{reduction}` reduction")
+
+     class_weights = class_weights if class_weights is not None else 1.0
+     class_weights = torch.tensor(class_weights).to(tp.device)
+     class_weights = class_weights / class_weights.sum()
+
+     if reduction == "micro":
+         tp = tp.sum()
+         fp = fp.sum()
+         fn = fn.sum()
+         tn = tn.sum()
+         score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
+
+     elif reduction == "macro" or reduction == "weighted":
+         tp = tp.sum(0)
+         fp = fp.sum(0)
+         fn = fn.sum(0)
+         tn = tn.sum(0)
+         score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
+         score = _handle_zero_division(score, zero_division)
+         score = (score * class_weights).mean()
+
+     elif reduction == "micro-imagewise":
+         tp = tp.sum(1)
+         fp = fp.sum(1)
+         fn = fn.sum(1)
+         tn = tn.sum(1)
+         score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
+         score = _handle_zero_division(score, zero_division)
+         score = score.mean()
+
+     elif reduction == "macro-imagewise" or reduction == "weighted-imagewise":
+         score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
+         score = _handle_zero_division(score, zero_division)
+         score = (score.mean(0) * class_weights).mean()
+
+     elif reduction == "none" or reduction is None:
+         score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
+         score = _handle_zero_division(score, zero_division)
+
+     else:
+         raise ValueError(
+             "`reduction` should be in [micro, macro, weighted, micro-imagewise, "
+             + "macro-imagewise, weighted-imagewise, none, None]"
+         )
+
+     return score
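+
+ # Toy illustration of how the reductions differ (numbers are made up): with
+ # per-class IoU stats tp = [9, 1], fp = [1, 9], fn = [1, 9], 'micro' pools the
+ # counts first, (9 + 1) / (10 + 10 + 10) = 1/3, while 'macro' averages the
+ # per-class scores, (9/11 + 1/19) / 2 ~= 0.44, so a rare class weighs more
+ # under 'macro' than under 'micro'.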
+
+
+ # Logic for metric computation; all metrics share the same interface
+
+
+ def _fbeta_score(tp, fp, fn, tn, beta=1):
+     beta_tp = (1 + beta ** 2) * tp
+     beta_fn = (beta ** 2) * fn
+     score = beta_tp / (beta_tp + beta_fn + fp)
+     return score
+
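+ # Note on the F-beta form above: F_beta = (1 + beta^2) * tp / ((1 + beta^2) * tp
+ # + beta^2 * fn + fp), which reduces to the familiar F1 = 2 * tp / (2 * tp + fn + fp)
+ # at beta = 1.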
+
+ def _iou_score(tp, fp, fn, tn):
+     return tp / (tp + fp + fn)
+
+
+ def _accuracy(tp, fp, fn, tn):
+     # standard accuracy over the confusion matrix: correct pixels / all pixels
+     return (tp + tn) / (tp + fp + fn + tn)
+
+
+ def _sensitivity(tp, fp, fn, tn):
+     return tp / (tp + fn)
+
+
+ def _specificity(tp, fp, fn, tn):
+     return tn / (tn + fp)
+
+
+ def _balanced_accuracy(tp, fp, fn, tn):
+     return (_sensitivity(tp, fp, fn, tn) + _specificity(tp, fp, fn, tn)) / 2
+
+
+ def _positive_predictive_value(tp, fp, fn, tn):
+     return tp / (tp + fp)
+
+
+ def _negative_predictive_value(tp, fp, fn, tn):
+     return tn / (tn + fn)
+
+
+ def _false_negative_rate(tp, fp, fn, tn):
+     return fn / (fn + tp)
+
+
+ def _false_positive_rate(tp, fp, fn, tn):
+     return fp / (fp + tn)
+
+
+ def _false_discovery_rate(tp, fp, fn, tn):
+     return 1 - _positive_predictive_value(tp, fp, fn, tn)
+
+
+ def _false_omission_rate(tp, fp, fn, tn):
+     return 1 - _negative_predictive_value(tp, fp, fn, tn)
+
+
+ def _positive_likelihood_ratio(tp, fp, fn, tn):
+     return _sensitivity(tp, fp, fn, tn) / _false_positive_rate(tp, fp, fn, tn)
+
+
+ def _negative_likelihood_ratio(tp, fp, fn, tn):
+     return _false_negative_rate(tp, fp, fn, tn) / _specificity(tp, fp, fn, tn)
+
+
+ def fbeta_score(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     beta: float = 1.0,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """F beta score"""
+     return _compute_metric(
+         _fbeta_score,
+         tp,
+         fp,
+         fn,
+         tn,
+         beta=beta,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def f1_score(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """F1 score"""
+     return _compute_metric(
+         _fbeta_score,
+         tp,
+         fp,
+         fn,
+         tn,
+         beta=1.0,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def iou_score(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """IoU score or Jaccard index"""  # noqa
+     return _compute_metric(
+         _iou_score,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def accuracy(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """Accuracy"""
+     return _compute_metric(
+         _accuracy,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def sensitivity(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """Sensitivity, recall, hit rate, or true positive rate (TPR)"""
+     return _compute_metric(
+         _sensitivity,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def specificity(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """Specificity, selectivity or true negative rate (TNR)"""
+     return _compute_metric(
+         _specificity,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def balanced_accuracy(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """Balanced accuracy"""
+     return _compute_metric(
+         _balanced_accuracy,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def positive_predictive_value(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """Precision or positive predictive value (PPV)"""
+     return _compute_metric(
+         _positive_predictive_value,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def negative_predictive_value(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """Negative predictive value (NPV)"""
+     return _compute_metric(
+         _negative_predictive_value,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def false_negative_rate(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """Miss rate or false negative rate (FNR)"""
+     return _compute_metric(
+         _false_negative_rate,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def false_positive_rate(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """Fall-out or false positive rate (FPR)"""
+     return _compute_metric(
+         _false_positive_rate,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def false_discovery_rate(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """False discovery rate (FDR)"""  # noqa
+     return _compute_metric(
+         _false_discovery_rate,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def false_omission_rate(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """False omission rate (FOR)"""  # noqa
+     return _compute_metric(
+         _false_omission_rate,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def positive_likelihood_ratio(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """Positive likelihood ratio (LR+)"""
+     return _compute_metric(
+         _positive_likelihood_ratio,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ def negative_likelihood_ratio(
+     tp: torch.LongTensor,
+     fp: torch.LongTensor,
+     fn: torch.LongTensor,
+     tn: torch.LongTensor,
+     reduction: Optional[str] = None,
+     class_weights: Optional[List[float]] = None,
+     zero_division: Union[str, float] = 1.0,
+ ) -> torch.Tensor:
+     """Negative likelihood ratio (LR-)"""
+     return _compute_metric(
+         _negative_likelihood_ratio,
+         tp,
+         fp,
+         fn,
+         tn,
+         reduction=reduction,
+         class_weights=class_weights,
+         zero_division=zero_division,
+     )
+
+
+ _doc = """
+
+     Args:
+         tp (torch.LongTensor): tensor of shape (N, C), true positive cases
+         fp (torch.LongTensor): tensor of shape (N, C), false positive cases
+         fn (torch.LongTensor): tensor of shape (N, C), false negative cases
+         tn (torch.LongTensor): tensor of shape (N, C), true negative cases
+         reduction (Optional[str]): Defines how to aggregate the metric between classes and images:
+
+             - 'micro'
+                 Sum true positive, false positive, false negative and true negative pixels over
+                 all images and all classes, then compute the score.
+
+             - 'macro'
+                 Sum true positive, false positive, false negative and true negative pixels over
+                 all images for each label, compute the score for each label separately, and average
+                 the label scores. This does not take label imbalance into account.
+
+             - 'weighted'
+                 Sum true positive, false positive, false negative and true negative pixels over
+                 all images for each label, compute the score for each label separately, and return
+                 the weighted average of the label scores.
+
+             - 'micro-imagewise'
+                 Sum true positive, false positive, false negative and true negative pixels for **each image**,
+                 then compute the score for **each image** and average the scores over the dataset. All images
+                 contribute equally to the final score; however, class imbalance is taken into account for
+                 each image.
+
+             - 'macro-imagewise'
+                 Compute the score for each image and each class on that image separately, then average the
+                 class scores on each image and average the image scores over the dataset. Does not take
+                 label imbalance on each image into account.
+
+             - 'weighted-imagewise'
+                 Compute the score for each image and each class on that image separately, then compute the
+                 weighted average of class scores on each image and average the image scores over the dataset.
+
+             - 'none' or ``None``
+                 Same as ``'macro-imagewise'``, but without any reduction.
+
+             For the ``'binary'`` case ``'micro' = 'macro' = 'weighted'`` and
+             ``'micro-imagewise' = 'macro-imagewise' = 'weighted-imagewise'``.
+
+             Prefixes ``'micro'``, ``'macro'`` and ``'weighted'`` define how the scores for classes will be
+             aggregated, while the postfix ``'imagewise'`` defines how scores between the images will be aggregated.
+
+         class_weights (Optional[List[float]]): list of class weights for metric
+             aggregation; used when a ``weighted*`` reduction is chosen. Defaults to None.
+         zero_division (Union[str, float]): Sets the value to return when there is a zero division,
+             i.e. when all predictions and labels are negative. If set to "warn", this acts as 0,
+             but a warning is also raised. Defaults to 1.
+
+     Returns:
+         torch.Tensor: if ``reduction`` is not ``None`` or ``'none'``, returns a scalar metric,
+             else returns a tensor of shape (N, C)
+
+     References:
+         https://en.wikipedia.org/wiki/Confusion_matrix
+ """
+
+ fbeta_score.__doc__ += _doc
+ f1_score.__doc__ += _doc
+ iou_score.__doc__ += _doc
+ accuracy.__doc__ += _doc
+ sensitivity.__doc__ += _doc
+ specificity.__doc__ += _doc
+ balanced_accuracy.__doc__ += _doc
+ positive_predictive_value.__doc__ += _doc
+ negative_predictive_value.__doc__ += _doc
+ false_negative_rate.__doc__ += _doc
+ false_positive_rate.__doc__ += _doc
+ false_discovery_rate.__doc__ += _doc
+ false_omission_rate.__doc__ += _doc
+ positive_likelihood_ratio.__doc__ += _doc
+ negative_likelihood_ratio.__doc__ += _doc
+
+ precision = positive_predictive_value
+ recall = sensitivity
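+
+ # Usage sketch for the aliases above (assuming tp, fp, fn, tn come from ``get_stats``):
+ #
+ #     precision_value = precision(tp, fp, fn, tn, reduction="micro")
+ #     recall_value = recall(tp, fp, fn, tn, reduction="micro")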