"""ACM and its variations""" from typing import Any import torch from torch import nn from .conv_utils import conv2d class ACM(nn.Module): """Affine Combination Module from ManiGAN""" def __init__(self, img_chans: int, text_chans: int, inner_dim: int = 64) -> None: """ Initialize the convolutional layers :param int img_chans: Channels in visual input :param int text_chans: Channels of textual input :param int inner_dim: Hyperparameters for inner dimensionality of features """ super().__init__() self.conv = conv2d(in_channels=img_chans, out_channels=inner_dim) self.weights = conv2d(in_channels=inner_dim, out_channels=text_chans) self.biases = conv2d(in_channels=inner_dim, out_channels=text_chans) def forward(self, text: torch.Tensor, img: torch.Tensor) -> Any: """ Propagate the textual and visual input through the ACM module :param torch.Tensor text: Textual input (can be hidden features) :param torch.Tensor img: Image input :return: Affine combination of text and image :rtype: torch.Tensor """ img_features = self.conv(img) return text * self.weights(img_features) + self.biases(img_features)