# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
    """
    Function to compute the cross entropy loss. The default value of
    ignore_index is the same as the default value for F.cross_entropy in
    pytorch.
    """
    assert logits.size(0) == targets.size(
        -1
    ), "Logits and Targets tensor shapes don't match up"

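    # log_softmax is computed in float32 for numerical stability (relevant when
    # the logits are fp16); nll_loss over these log-probabilities is the summed
    # cross entropy over all non-ignored target positions.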
    loss = F.nll_loss(
        F.log_softmax(logits, -1, dtype=torch.float32),
        targets,
        reduction="sum",
        ignore_index=ignore_index,
    )
    return loss


@register_criterion("legacy_masked_lm_loss")
class LegacyMaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in masked language model (MLM) training.
    This optionally also computes the next sentence prediction (NSP) loss and
    adds it to the overall loss based on the specified args. There are three
    cases to consider:
        1) Generic MLM training without NSP loss. In this case sentence_targets
           and sentence_logits are both None.
        2) BERT training without NSP loss. In this case sentence_targets is
           not None but sentence_logits is None, and no sentence-level loss
           should be computed.
        3) BERT training with NSP loss. In this case both sentence_targets and
           sentence_logits are not None, and a sentence-level loss should be
           computed. The weight of the sentence-level loss is specified as
           an argument.
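
    In case (3) the overall loss computed in ``forward`` is::

        loss = lm_loss / ntokens + nsp_loss_weight * (sentence_loss / nsentences)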
    """

    def __init__(self, task, masked_lm_only, nsp_loss_weight):
        super().__init__(task)
        self.masked_lm_only = masked_lm_only
        self.nsp_loss_weight = nsp_loss_weight

    @staticmethod
    def add_args(parser):
        """Args for MaskedLM Loss"""
        # Default for masked_lm_only is False so as to not break BERT training
        parser.add_argument(
            "--masked-lm-only",
            default=False,
            action="store_true",
            help="compute MLM loss only",
        )
        parser.add_argument(
            "--nsp-loss-weight",
            default=1.0,
            type=float,
            help="weight for next sentence prediction loss (default 1)",
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        lm_logits, output_metadata = model(**sample["net_input"])

        # reshape lm_logits from (N,T,C) to (N*T,C)
        lm_logits = lm_logits.view(-1, lm_logits.size(-1))
        lm_targets = sample["lm_target"].view(-1)
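        # self.padding_idx comes from FairseqCriterion (typically the pad index
        # of the task's target dictionary); target positions carrying this
        # index are excluded from the MLM loss.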
        lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)

        # Compute the number of tokens for which the loss is computed; this is
        # used to normalize the loss.
        ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
        loss = lm_loss / ntokens
        nsentences = sample["nsentences"]

        # Compute sentence loss if masked_lm_only is False
        sentence_loss = None
        if not self.masked_lm_only:
            sentence_logits = output_metadata["sentence_logits"]
            sentence_targets = sample["sentence_target"].view(-1)
            # This needs to be recomputed due to some differences between the
            # TokenBlock and BlockPair datasets. This can be resolved with a
            # refactor of the BERT model, which we will do in the future.
            # TODO: remove this after the BERTModel refactor
            nsentences = sentence_targets.size(0)

            # Check for sentence_logits being None, which can happen when
            # remove_heads is set to True in the BERT model. Ideally we should
            # set masked_lm_only to True in this case, but that requires some
            # refactoring of the BERT model.
            if sentence_logits is not None:
                sentence_loss = compute_cross_entropy_loss(
                    sentence_logits, sentence_targets
                )

                loss += self.nsp_loss_weight * (sentence_loss / nsentences)

        # NOTE: since the per-token MLM loss and the per-sentence NSP loss are
        # already normalized above, sample_size is not needed as a denominator
        # for the gradient; it is only used for logging.
        sample_size = 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
            # sentence loss is not always computed
            "sentence_loss": (
                (utils.item(sentence_loss.data) if reduce else sentence_loss.data)
                if sentence_loss is not None
                else 0.0
            ),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
        sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        agg_loss = sum(log.get("loss", 0) for log in logging_outputs)

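        # Dividing by math.log(2) converts the natural-log losses to base 2,
        # following fairseq's convention of logging losses in bits.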
        metrics.log_scalar(
            "loss",
            agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
            sample_size,
            round=3,
        )
        metrics.log_scalar(
            "lm_loss",
            lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
            ntokens,
            round=3,
        )
        metrics.log_scalar(
            "sentence_loss",
            sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
            nsentences,
            round=3,
        )
        metrics.log_scalar(
            "nll_loss",
            lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
            ntokens,
            round=3,
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
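

# ---------------------------------------------------------------------------
# Minimal usage sketch for compute_cross_entropy_loss, kept behind a __main__
# guard so that importing this module is unaffected. The tensor shapes and the
# pad index below are illustrative assumptions, not values taken from a
# fairseq task.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, seq_len, vocab = 2, 8, 32
    pad_idx = 1  # assumed pad index for this sketch

    # Flattened (batch * seq_len, vocab) logits and matching integer targets.
    logits = torch.randn(batch * seq_len, vocab)
    targets = torch.randint(0, vocab, (batch * seq_len,))
    # Mark a few positions as padding so they are ignored by the loss.
    targets[:3] = pad_idx

    loss = compute_cross_entropy_loss(logits, targets, ignore_index=pad_idx)
    ntokens = utils.strip_pad(targets, pad_idx).numel()
    per_token = loss.item() / ntokens
    print(f"summed loss: {loss.item():.4f}, per-token loss: {per_token:.4f}")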