# taim-gan/src/models/modules/attention.py
"""Attention modules"""
from typing import Any, Optional
import torch
from torch import nn
from src.models.modules.conv_utils import conv1d
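
# NOTE: `conv1d` is assumed here (inferred from how it is called below) to be
# a factory taking (in_channels, out_channels) and returning a kernel-size-1
# nn.Conv1d-style module, i.e. a per-position linear projection. The shape
# comments and the smoke test at the bottom of the file rely on that reading.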


class ChannelWiseAttention(nn.Module):
    """Channel-wise attention adapted from ControlGAN."""

    def __init__(self, fm_size: int, text_d: int) -> None:
        """
        Initialize the channel-wise attention module.

        :param int fm_size:
            Number of spatial positions (height * width) of the feature map
            on the k-th iteration of the forward pass; H_k * W_k in the paper.
        :param int text_d: Dimensionality of the word embeddings; D in the paper.
        """
        super().__init__()
        # perception layer projecting word features from D to H_k * W_k
        self.text_conv = conv1d(text_d, fm_size)
        # normalizes the correlation matrix over its last (word) dimension
        self.softmax = nn.Softmax(dim=2)

    def forward(self, v_k: torch.Tensor, w_text: torch.Tensor) -> torch.Tensor:
        """
        Attend to the visual features channel-wise, guided by the word features.

        :param torch.Tensor v_k: Visual features, B x C x (H_k * W_k)
        :param torch.Tensor w_text: Word features, B x D x L
        :return: Fused visual and word features, B x C x (H_k * W_k)
        :rtype: torch.Tensor
        """
        # project word features into the spatial dimension: B x (H_k * W_k) x L
        w_hat = self.text_conv(w_text)
        # channel-word correlation matrix: B x C x L
        m_k = v_k @ w_hat
        # attention weights, normalized over the word dimension
        a_k = self.softmax(m_k)
        # B x L x (H_k * W_k)
        w_hat = torch.transpose(w_hat, 1, 2)
        # fused features f_k: B x C x (H_k * W_k)
        return a_k @ w_hat
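
# A minimal usage sketch (not from the original repo); the shapes are inferred
# from the matrix products in ChannelWiseAttention.forward:
#
#     attn = ChannelWiseAttention(fm_size=16 * 16, text_d=256)
#     v_k = torch.randn(4, 32, 16 * 16)   # B x C x (H_k * W_k)
#     w_text = torch.randn(4, 256, 18)    # B x D x L
#     f_k = attn(v_k, w_text)             # B x C x (H_k * W_k)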


class SpatialAttention(nn.Module):
    """Spatial attention module attending textual context to visual features."""

    def __init__(self, d: int, d_hat: int) -> None:
        """
        Set up the softmax and conv layers.

        :param int d: Initial embedding size of the textual features; D in the paper.
        :param int d_hat: Channel dimension of the visual feature map; D_hat in the paper.
        """
        super().__init__()
        # normalizes the similarity matrix over its last (word) dimension
        self.softmax = nn.Softmax(dim=2)
        # projects word features from D into the visual space D_hat
        self.conv = conv1d(d, d_hat)

    def forward(
        self,
        text_context: torch.Tensor,
        image: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Project the textual features into the latent space
        of the visual features and apply attention.

        :param torch.Tensor text_context: B x D x T tensor of hidden textual features
        :param torch.Tensor image: B x D_hat x N tensor of visual features
        :param Optional[torch.Tensor] mask:
            Boolean B x T tensor masking the padded words
        :return: Word-context features for each image sub-region, B x D_hat x N
        :rtype: torch.Tensor
        """
        # number of sub-regions on the image feature map, N = H * W
        feature_num = image.size(2)
        # number of words in the caption, T
        len_caption = text_context.size(2)
        # project word features into the visual space: B x D_hat x T
        text_context = self.conv(text_context)
        # B x N x D_hat
        image = torch.transpose(image, 1, 2)
        # similarity between every sub-region and every word: B x N x T
        s_i_j = image @ text_context
        if mask is not None:
            # tile the mask across sub-regions to align its dims with s_i_j
            mask = mask.repeat(1, feature_num).view(-1, feature_num, len_caption)
            # exclude padded words from the softmax
            s_i_j[mask] = -float("inf")
        # attention weights over words for each sub-region: B x N x T
        b_i_j = self.softmax(s_i_j)
        # word-context vectors: B x N x D_hat
        c_i_j = b_i_j @ torch.transpose(text_context, 1, 2)
        # B x D_hat x N
        return torch.transpose(c_i_j, 1, 2)
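

if __name__ == "__main__":
    # Hedged smoke test (not part of the original file): the shapes below are
    # inferred from the matrix products above, assuming `conv1d` returns a
    # kernel-size-1 convolution as noted near the imports.
    batch, text_dim, vis_dim = 2, 256, 32
    n_regions, caption_len = 16 * 16, 18

    sp_attn = SpatialAttention(d=text_dim, d_hat=vis_dim)
    words = torch.randn(batch, text_dim, caption_len)   # B x D x T
    regions = torch.randn(batch, vis_dim, n_regions)    # B x D_hat x N
    pad_mask = torch.zeros(batch, caption_len, dtype=torch.bool)
    pad_mask[:, -4:] = True  # pretend the last four tokens are padding
    out = sp_attn(words, regions, mask=pad_mask)
    print(out.shape)  # expected: torch.Size([2, 32, 256])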