"""Attention modules"""
from typing import Optional

import torch
from torch import nn

from src.models.modules.conv_utils import conv1d


class ChannelWiseAttention(nn.Module):
    """Channel-wise attention adapted from ControlGAN"""

    def __init__(self, fm_size: int, text_d: int) -> None:
        """
        Initialize the channel-wise attention module

        :param int fm_size:
            Number of spatial locations (height * width) of the feature map
            on the k-th iteration of the forward pass. In the paper, it's H_k * W_k
        :param int text_d: Dimensionality of the word embeddings. In the paper, it's D
        """
        super().__init__()
        # perception layer: projects word features from D to H_k * W_k
        self.text_conv = conv1d(text_d, fm_size)
        # normalizes attention scores over the word dimension
        self.softmax = nn.Softmax(2)

    def forward(self, v_k: torch.Tensor, w_text: torch.Tensor) -> torch.Tensor:
        """
        Apply attention to visual features taking into account features of words

        :param torch.Tensor v_k: Visual context, B x C x (H_k * W_k)
        :param torch.Tensor w_text: Textual features, B x D x L
        :return: Fused hidden visual features and word features, B x C x (H_k * W_k)
        :rtype: torch.Tensor
        """
        # project word features into the visual space: B x (H_k * W_k) x L
        w_hat = self.text_conv(w_text)
        # correlation of every channel with every word: B x C x L
        m_k = v_k @ w_hat
        # attention weights over words for every channel
        a_k = self.softmax(m_k)
        # weighted sum of projected word features: B x C x (H_k * W_k)
        w_hat = torch.transpose(w_hat, 1, 2)
        return a_k @ w_hat
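
# Usage sketch for ChannelWiseAttention (illustrative; the sizes below are
# assumptions, not values from the paper):
#
#     attn = ChannelWiseAttention(fm_size=64, text_d=256)  # 8x8 feature map
#     v_k = torch.randn(2, 128, 64)     # B x C x (H_k * W_k)
#     w_text = torch.randn(2, 256, 16)  # B x D x L
#     out = attn(v_k, w_text)           # -> torch.Size([2, 128, 64])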


class SpatialAttention(nn.Module):
    """Spatial attention module for attending textual context to visual features"""

    def __init__(self, d: int, d_hat: int) -> None:
        """
        Set up softmax and conv layers

        :param int d: Initial embedding size of textual features. D from paper
        :param int d_hat: Depth (channel dimension) of the image feature map. D_hat from paper
        """
        super().__init__()
        # normalizes attention scores over the word dimension
        self.softmax = nn.Softmax(2)
        # perception layer: projects word features from D to D_hat
        self.conv = conv1d(d, d_hat)

    def forward(
        self,
        text_context: torch.Tensor,
        image: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Project textual features into the latent space
        of visual features and apply attention

        :param torch.Tensor text_context: B x D x L tensor of hidden textual features
        :param torch.Tensor image: B x D_hat x N visual features
        :param Optional[torch.Tensor] mask:
            Boolean tensor for masking the padded words. B x L
        :return: Word-context features for every spatial location, B x D_hat x N
        :rtype: torch.Tensor
        """
        # number of locations on the image feature map: N = H * W
        feature_num = image.size(2)
        # number of words in the caption: L
        len_caption = text_context.size(2)
        # project word features into the visual space: B x D_hat x L
        text_context = self.conv(text_context)
        image = torch.transpose(image, 1, 2)
        # similarity of every location to every word: B x N x L
        s_i_j = image @ text_context
        if mask is not None:
            # tile the B x L mask to B x N x L so it aligns with s_i_j, then
            # set padded-word scores to -inf so they get zero weight after softmax
            mask = mask.repeat(1, feature_num).view(-1, feature_num, len_caption)
            s_i_j[mask] = -float("inf")
        # attention weights over words for every location
        b_i_j = self.softmax(s_i_j)
        # word-context vectors for every location: B x N x D_hat
        c_i_j = b_i_j @ torch.transpose(text_context, 1, 2)
        return torch.transpose(c_i_j, 1, 2)
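
# Minimal smoke test for both modules (illustrative only; all sizes here are
# assumptions, and conv1d is assumed to map channels in -> out, as its use
# above suggests):
if __name__ == "__main__":
    batch, channels, d, d_hat, words = 2, 128, 256, 32, 16
    fm_size = 8 * 8  # H_k * W_k for an 8x8 feature map

    cw_attn = ChannelWiseAttention(fm_size=fm_size, text_d=d)
    v_k = torch.randn(batch, channels, fm_size)  # B x C x (H_k * W_k)
    w_text = torch.randn(batch, d, words)        # B x D x L
    print(cw_attn(v_k, w_text).shape)            # torch.Size([2, 128, 64])

    sp_attn = SpatialAttention(d=d, d_hat=d_hat)
    image = torch.randn(batch, d_hat, fm_size)   # B x D_hat x N
    text = torch.randn(batch, d, words)          # B x D x L
    mask = torch.zeros(batch, words, dtype=torch.bool)
    mask[:, -4:] = True                          # treat last 4 words as padding
    print(sp_attn(text, image, mask).shape)      # torch.Size([2, 32, 64])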