from typing import Optional

import torch
from torch.nn import functional as F


def compute_loss_with_mask(
    logits: torch.Tensor, target: torch.Tensor, target_mask: Optional[torch.Tensor]
) -> torch.Tensor:
    """Return the cross-entropy loss averaged over the masked-in elements.

    Args:
        logits: Unnormalized class scores, passed straight to
            ``F.cross_entropy``.
        target: Targets in any form accepted by ``F.cross_entropy``.
        target_mask: Per-element weights (bool or numeric) multiplied into
            the unreduced loss; it must broadcast against that loss (for
            class-index targets the unreduced loss has ``target``'s shape).
            ``None`` means a plain mean over all elements.

    Returns:
        A scalar tensor equal to ``sum(loss * mask) / sum(mask)``. If the
        mask sums to zero, the result is ``0`` (with a zero gradient)
        instead of NaN.

    Note:
        Entries whose target equals ``F.cross_entropy``'s default
        ``ignore_index`` (-100) produce zero loss but still count in the
        denominator when their mask weight is non-zero.
    """
    if target_mask is None:
        return F.cross_entropy(logits, target, reduction="mean")

    mb_loss = F.cross_entropy(logits, target, reduction="none")
    mask_sum = torch.sum(target_mask)
    # An all-zero mask (e.g. a fully padded batch) would divide 0 by 0 and
    # yield NaN. Substitute 1 for the denominator so the loss is 0 instead;
    # torch.where keeps this on-device without a host synchronization.
    denom = torch.where(mask_sum != 0, mask_sum, torch.ones_like(mask_sum))
    return torch.sum(mb_loss * target_mask) / denom