File size: 446 Bytes
cb9e677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from typing import Optional

import torch
from torch.nn import functional as F


def compute_loss_with_mask(
    logits: torch.Tensor, target: torch.Tensor, target_mask: Optional[torch.Tensor]
) -> torch.Tensor:
    """Return the cross-entropy loss, optionally averaged under a mask.

    Args:
        logits: Unnormalised class scores, passed unchanged to
            ``F.cross_entropy``.
        target: Target class indices (or class probabilities), passed
            unchanged to ``F.cross_entropy``.
        target_mask: Per-element weights multiplied into the unreduced loss,
            so it must broadcast against the output of
            ``F.cross_entropy(..., reduction="none")``. When ``None`` the
            plain mean over all elements is returned.

    Returns:
        A scalar tensor. With a mask this is
        ``sum(loss * target_mask) / sum(target_mask)``; if the mask sums to
        zero the result is ``0`` rather than NaN.
    """
    if target_mask is None:
        return F.cross_entropy(logits, target, reduction="mean")

    per_element_loss = F.cross_entropy(logits, target, reduction="none")
    weighted_total = torch.sum(per_element_loss * target_mask)
    mask_total = torch.sum(target_mask)

    # A fully masked batch would otherwise evaluate 0 / 0 and poison the
    # gradients with NaN. Swapping a zero denominator for one keeps the
    # result (and its gradient) at zero, and torch.where avoids a Python-level
    # branch on a tensor value, so no host/device sync is forced.
    safe_total = torch.where(
        mask_total == 0, torch.ones_like(mask_total), mask_total
    )

    return weighted_total / safe_total