File size: 967 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch.nn import Module


class BinLoss(Module):
    r"""Attention binarization loss between a hard and a soft attention map.

    Computes the mean negative log-probability that ``soft_attention``
    assigns to the positions where ``hard_attention`` equals 1:

    .. math::
        L = -\frac{1}{\max(|S|, 1)} \sum_{(i, j) \in S}
            \log \max(A^{soft}_{ij}, \epsilon),
        \qquad S = \{(i, j) : A^{hard}_{ij} = 1\}

    Only the positive term of a binary cross-entropy is evaluated: positions
    where ``hard_attention`` is not 1 do not contribute. The loss is
    minimized when the soft attention assigns probability 1 to every
    selected position.
    """

    def forward(
        self,
        hard_attention: torch.Tensor,
        soft_attention: torch.Tensor,
        eps: float = 1e-12,
    ) -> torch.Tensor:
        r"""Compute the binarization loss for hard and soft attention.

        Args:
            hard_attention (torch.Tensor): A binary tensor marking the hard
                (selected) attention positions with 1. It must be usable as a
                boolean index into ``soft_attention`` (same shape, or matching
                leading dimensions).
            soft_attention (torch.Tensor): A tensor containing the soft
                attention probabilities.
            eps (float): Lower bound applied to the selected probabilities
                before taking the log, so that zeros give a finite loss.
                Defaults to ``1e-12``.

        Returns:
            torch.Tensor: A scalar tensor with the mean negative log-probability
            over the selected positions. It is ``0`` when no position is
            selected.

        """
        mask = hard_attention == 1
        log_sum = soft_attention[mask].clamp(min=eps).log().sum()
        # Normalize by the number of positions actually selected (this matches
        # the numerator). Clamp to 1 so an all-zero hard attention yields 0
        # instead of 0 / 0 = NaN, which would poison the training loss.
        count = mask.sum().clamp(min=1)
        return -log_sum / count