Spaces:
Running
Running
import torch | |
from torch.nn import Module | |
class BinLoss(Module):
    r"""Mean negative log-probability of soft attention at hard-attention positions.

    The loss is ``-sum(log(p_i)) / N``, where ``p_i`` ranges over the entries of
    ``soft_attention`` at which ``hard_attention == 1`` and ``N`` is the number
    of such entries. Only this positive term is computed. Positions where the
    hard attention is 0 do not contribute, so this is not a full two-term binary
    cross-entropy.

    Attributes:
        None

    Methods:
        forward: Computes the loss for a pair of hard/soft attention tensors.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self, hard_attention: torch.Tensor, soft_attention: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computes the mean negative log soft-attention probability at hard-attention positions.

        Args:
            hard_attention (torch.Tensor): A binary (0/1 or bool) tensor
                indicating the hard attention. It is used as a boolean index
                into ``soft_attention``, so its shape must match
                ``soft_attention``'s shape (or a leading part of it).
            soft_attention (torch.Tensor): A tensor containing the soft
                attention probabilities.

        Returns:
            torch.Tensor: A scalar tensor holding the loss. It is 0 when
            ``hard_attention`` selects no positions.
        """
        selected = soft_attention[hard_attention == 1]
        # Clamp away zeros so log() stays finite (a zero probability gives a
        # large but finite penalty instead of inf).
        log_sum = torch.log(torch.clamp(selected, min=1e-12)).sum()
        # With nothing selected, log_sum is 0 and the unguarded 0 / 0 would be
        # NaN. For a binary mask the sum is an integer count, so clamping it to
        # >= 1 only affects that empty case, which now yields a loss of 0.
        num_selected = hard_attention.sum().clamp(min=1)
        return -log_sum / num_selected