Spaces:
Running
Running
File size: 967 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
from torch.nn import Module
class BinLoss(Module):
    r"""Binarization loss that pushes soft attention toward a hard attention.

    For every position where ``hard_attention == 1``, the loss takes the
    negative log of the corresponding ``soft_attention`` value (the
    positive-class term of binary cross-entropy). It then averages over those
    positions. Positions where ``hard_attention != 1`` contribute nothing.

    Attributes:
        None.

    Methods:
        forward: Computes the loss for a pair of hard / soft attention tensors.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        hard_attention: torch.Tensor,
        soft_attention: torch.Tensor,
        eps: float = 1e-12,
    ) -> torch.Tensor:
        r"""Computes the binarization loss for hard and soft attention.

        Args:
            hard_attention (torch.Tensor): Tensor whose entries equal to 1 mark
                the selected positions. It is used as a boolean mask on
                ``soft_attention``, so both tensors must have the same shape.
            soft_attention (torch.Tensor): Tensor of attention probabilities.
            eps (float): Lower bound applied to the selected probabilities
                before taking the log, so that a zero probability yields a
                large finite penalty instead of ``inf``. Defaults to ``1e-12``.

        Returns:
            torch.Tensor: Scalar equal to the mean of
            ``-log(clamp(soft_attention, min=eps))`` over positions where
            ``hard_attention == 1``. Returns 0 when no position is selected.
        """
        mask = hard_attention == 1
        log_sum = torch.log(
            torch.clamp(soft_attention[mask], min=eps),
        ).sum()
        # Divide by the number of entries that actually contributed to
        # log_sum, so the result is a true mean. This equals
        # hard_attention.sum() for binary input. Clamping to 1 turns the
        # "nothing selected" case into 0 / 1 = 0 instead of 0 / 0 = NaN.
        num_selected = mask.sum().clamp(min=1)
        return -log_sum / num_selected
|