File size: 6,716 Bytes
96c0ca2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#!/usr/bin/env python
# encoding: utf-8
'''

@license: (C) Copyright 2021, Hey.

@author: Hey

@email: [email protected]

@tel: 137****6540

@datetime: 2023/6/28 10:25

@project: LucaOne

@file: masked_loss.py

@desc: masked loss

'''
import warnings
import torch
import torch.nn as nn


class _MaskedLoss(nn.Module):
    """Base class for masked losses.

    Subclasses are expected to set ``self.criterion`` to an element-wise
    loss module constructed with ``reduction='none'`` (e.g.
    ``nn.MSELoss(reduction='none')``); this class applies the mask and the
    requested reduction on top of that element-wise loss.

    Supported reduction modes:
        'none'      return the element-wise masked loss tensor.
        'sum'       sum of the loss over valid (non-masked-out) positions.
        'mean'      sum of the loss divided by the number of valid positions.
        'meanmean'  hierarchical mean: average within each group (innermost
                    dim / tokens) first, then across groups — only groups
                    containing at least one valid position count.
        'summean' / 'meansum'
                    sum within groups, then mean across groups with valid
                    positions.
    """

    def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
        """
        :param reduction: reduction mode (see class docstring).
        :param ignore_nans: if True, NaN entries in ``target`` are excluded
            from the loss and the mask.
        :param ignore_value: target value marking positions to ignore when no
            explicit mask is given; pass None to disable this behavior.
        """
        super().__init__()
        self.reduction = reduction
        self.ignore_nans = ignore_nans
        self.ignore_value = ignore_value

    def forward(self, pred, target, mask=None):
        """Compute the masked, reduced loss between ``pred`` and ``target``.

        Faster than ``loss(pred[mask], target[mask])`` for a given loss,
        and NaN-proof.

        :param pred: prediction tensor; 3-D / 2-D shapes are handled
            specially under 'meanmean' (see inline comments).
        :param target: label tensor; entries equal to ``self.ignore_value``
            are masked out when ``mask`` is None.
        :param mask: optional boolean tensor of valid positions (same shape
            as ``target``); overrides ignore-value masking.
        :return: a scalar loss tensor, or the element-wise masked loss when
            reduction is 'none' or unrecognized.
        """
        # BUGFIX: work on a per-call local copy of the reduction mode. The
        # previous implementation assigned self.reduction = "mean" inside the
        # 'meanmean' fallback branch, permanently changing the module's
        # configured reduction after the first call that hit that branch.
        reduction = self.reduction
        if mask is None and self.ignore_value is not None:
            mask = target != self.ignore_value
        elif mask is None:
            mask = torch.ones_like(target, dtype=bool)
        target_proxy = target
        if self.ignore_nans:
            # Replace NaN targets with 0 so the criterion computes finite
            # values; the affected positions are simultaneously removed from
            # the mask, so the placeholder never contributes to the loss.
            target_proxy = target.clone()
            nans = torch.isnan(target)
            if nans.any():
                with torch.no_grad():
                    mask = mask & ~nans
                    target_proxy[nans] = 0
        if reduction == 'meanmean' and pred.ndim == 3 and pred.shape[-1] == 1:
            # Token-level binary classification:
            # pred (n, seq_len, 1) and target (n, seq_len) are flattened for
            # the criterion, and the loss is reshaped back to (n, seq_len).
            full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
            full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
        elif reduction == 'meanmean' and pred.ndim == 3:
            if target.ndim == 3:
                # Token-level regression:
                # pred/target (n, seq_len, label_size) flattened, loss
                # reshaped back to (n, seq_len, label_size).
                full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
                full_loss = torch.reshape(full_loss, (-1, pred.shape[1], pred.shape[-1]))
            else:
                # Token-level multi-class classification:
                # pred (n, seq_len, label_size) -> (n*seq_len, label_size),
                # target (n, seq_len) -> (n*seq_len,); loss back to (n, seq_len).
                full_loss = self.criterion(pred.view(-1, pred.shape[-1]), target_proxy.view(-1))
                full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
        elif reduction == 'meanmean' and pred.ndim == 2 and target.ndim == 2:
            # Sequence-level multi-label classification:
            # pred/target (n, label_size) flattened, loss back to (n, label_size).
            full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
            full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
        elif reduction == 'meanmean':
            # Shapes do not support a hierarchical mean; fall back to plain
            # 'mean' for THIS call only (no persistent state change).
            reduction = 'mean'
            full_loss = self.criterion(pred, target_proxy)
        else:
            full_loss = self.criterion(pred, target_proxy)

        # Zero out the contribution of every ignored position.
        full_loss[~mask] = 0
        if reduction == 'none':
            return full_loss
        if reduction == 'sum':
            return full_loss.sum()
        if reduction == 'mean':
            # Average over valid positions only; the epsilon guards against
            # a fully-masked batch (division by zero).
            return full_loss.sum() / (mask.to(full_loss.dtype).sum() + 1e-12)
        if reduction == 'meanmean':
            if mask.ndim == 3:
                # Mean over the innermost dim, then over tokens, then over
                # samples; at each stage only groups with at least one valid
                # entry are counted in the denominator.
                mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
                full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
                mask_sum = mask_sum.to(torch.bool).sum(dim=-1)
                full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
                mask_sum = mask_sum.to(torch.bool).sum()
                loss = full_loss.sum() / (mask_sum + 1e-12)
            else:
                # Per-row mean, then mean over rows that have at least one
                # valid entry.
                mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
                loss = torch.sum(full_loss.sum(dim=-1) / (mask_sum + 1e-12)) / (mask_sum.to(torch.bool).sum() + 1e-12)
            return loss
        if reduction in ["summean", "meansum"]:
            if mask.ndim == 3:
                # Sum over the innermost dim (no per-element averaging), then
                # hierarchical means over tokens and samples as above.
                mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
                full_loss = full_loss.sum(dim=-1)
                mask_sum = mask_sum.to(torch.bool).sum(dim=-1)
                full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
                mask_sum = mask_sum.to(torch.bool).sum()
                loss = full_loss.sum() / (mask_sum + 1e-12)
            else:
                # Total sum divided by the number of rows with valid entries.
                mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
                loss = full_loss.sum() / (mask_sum.to(torch.bool).sum() + 1e-12)
            return loss
        # Unrecognized reduction: return the element-wise masked loss.
        return full_loss