File size: 6,716 Bytes
96c0ca2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
#!/usr/bin/env python
# encoding: utf-8
'''
@license: (C) Copyright 2021, Hey.
@author: Hey
@email: [email protected]
@tel: 137****6540
@datetime: 2023/6/28 10:25
@project: LucaOne
@file: masked_loss.py
@desc: masked loss
'''
import warnings
import torch
import torch.nn as nn
class _MaskedLoss(nn.Module):
    """Base class for masked losses.

    Subclasses are expected to provide ``self.criterion``, an element-wise
    loss created with ``reduction='none'`` (every branch below consumes a
    per-element loss tensor and performs the final reduction itself).

    :param reduction: one of 'none' | 'sum' | 'mean' | 'meanmean' |
        'summean'/'meansum'. 'meanmean' averages over the innermost masked
        dimension first, then over the outer ones, so short sequences weigh
        as much as long ones.
    :param ignore_nans: if True, NaN targets are zeroed on a clone and
        excluded from the mask instead of propagating through the loss.
    :param ignore_value: target value that marks positions to exclude when
        no explicit mask is passed (None disables this).
    """

    def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
        super().__init__()
        self.reduction = reduction
        self.ignore_nans = ignore_nans
        self.ignore_value = ignore_value

    def forward(self, pred, target, mask=None):
        """Compute a loss between ``pred`` and ``target`` for a given mask.

        Faster than ``loss(pred[mask], target[mask])`` and NaN-proof: the
        element-wise loss is computed once, masked positions are zeroed,
        and the reduction divides by the number of valid elements.

        :param pred: predictions; the expected shape depends on the task
            (see the per-branch comments below).
        :param target: ground truth; when ``mask`` is None, positions equal
            to ``ignore_value`` are excluded.
        :param mask: optional boolean tensor, True where the loss counts.
            NOTE(review): when a mask is passed explicitly, ``ignore_value``
            is NOT additionally applied — this preserves the original
            behavior.
        :return: the reduced loss tensor, or the element-wise loss when
            ``reduction == 'none'``.
        """
        if mask is None and self.ignore_value is not None:
            mask = target != self.ignore_value
        elif mask is None:
            mask = torch.ones_like(target, dtype=torch.bool)

        target_proxy = target
        if self.ignore_nans:
            # Zero NaN targets on a clone so the criterion cannot produce
            # NaNs; those positions are removed from the mask instead.
            target_proxy = target.clone()
            nans = torch.isnan(target)
            if nans.any():
                with torch.no_grad():
                    mask = mask & ~nans
                target_proxy[nans] = 0

        # BUGFIX: the original assigned ``self.reduction = "mean"`` in the
        # 'meanmean' fallback branch, permanently mutating the module after
        # the first call that hit it. Use a local effective reduction.
        reduction = self.reduction

        if reduction == 'meanmean' and pred.ndim == 3 and pred.shape[-1] == 1:
            # Token-level binary classification:
            # pred: (n, seq_len, 1) -> (n * seq_len); target: (n, seq_len)
            full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
            full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
        elif reduction == 'meanmean' and pred.ndim == 3:
            if target.ndim == 3:
                # Token-level regression:
                # pred/target: (n, seq_len, label_size), matched element-wise
                full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
                full_loss = torch.reshape(full_loss, (-1, pred.shape[1], pred.shape[-1]))
            else:
                # Token-level multi-class classification:
                # pred: (n, seq_len, label_size) -> (n * seq_len, label_size)
                # target: (n, seq_len) -> (n * seq_len)
                full_loss = self.criterion(pred.view(-1, pred.shape[-1]), target_proxy.view(-1))
                full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
        elif reduction == 'meanmean' and pred.ndim == 2 and target.ndim == 2:
            # Sequence-level multi-label: pred/target: (n, label_size)
            full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
            full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
        elif reduction == 'meanmean':
            # Shapes not covered above: fall back to a plain masked mean
            # for this call only (the module's configuration is untouched).
            reduction = 'mean'
            full_loss = self.criterion(pred, target_proxy)
        else:
            full_loss = self.criterion(pred, target_proxy)

        # Zero the loss at masked positions so the sums below ignore them.
        full_loss[~mask] = 0

        if reduction == 'none':
            return full_loss
        if reduction == 'sum':
            return full_loss.sum()
        if reduction == 'mean':
            # Divide by the number of valid elements; epsilon guards against
            # an all-False mask.
            return full_loss.sum() / (mask.to(full_loss.dtype).sum() + 1e-12)
        if reduction == 'meanmean':
            if mask.ndim == 3:
                # Mean over labels, then tokens, then samples; each level
                # divides by the count of non-empty entries at that level.
                mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
                full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
                mask_sum = mask_sum.to(torch.bool).sum(dim=-1)
                full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
                mask_sum = mask_sum.to(torch.bool).sum()
                loss = full_loss.sum() / (mask_sum + 1e-12)
            else:
                # Per-row mean, then mean over the rows with any valid entry.
                mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
                loss = torch.sum(full_loss.sum(dim=-1) / (mask_sum + 1e-12)) / (mask_sum.to(torch.bool).sum() + 1e-12)
            return loss
        if reduction in ["summean", "meansum"]:
            if mask.ndim == 3:
                # Sum over labels, then mean over tokens and samples.
                mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
                full_loss = full_loss.sum(dim=-1)
                mask_sum = mask_sum.to(torch.bool).sum(dim=-1)
                full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
                mask_sum = mask_sum.to(torch.bool).sum()
                loss = full_loss.sum() / (mask_sum + 1e-12)
            else:
                mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
                loss = full_loss.sum() / (mask_sum.to(torch.bool).sum() + 1e-12)
            return loss
        # Unknown reduction: return the element-wise (masked) loss as-is.
        return full_loss
|