|
"""Generator Module""" |
|
|
|
from typing import Any, Optional |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from src.models.modules.acm import ACM |
|
from src.models.modules.attention import ChannelWiseAttention, SpatialAttention |
|
from src.models.modules.cond_augment import CondAugmentation |
|
from src.models.modules.downsample import down_sample |
|
from src.models.modules.residual import ResidualBlock |
|
from src.models.modules.upsample import img_up_block, up_sample |
|
|
|
|
|
class InitStageG(nn.Module):
    """First generator stage: latent vector -> 64 x 64 hidden feature map.

    The noise, the augmented sentence condition and the global image feature
    are concatenated and projected by a fully connected layer to a 4 x 4 map
    with ``Ng`` channels.
    Three upsampling steps then bring the map to 32 x 32. At that resolution:

    * the map is concatenated with word-level spatial attention features and
      channel-wise attention features;
    * it is modulated by an ACM block using the (32 x 32) local image feature;
    * it is refined by two residual blocks;
    * it is upsampled once more to 64 x 64.
    """

    def __init__(
        self, Ng: int, Ng_init: int, conditioning_dim: int, D: int, noise_dim: int
    ):
        """
        :param Ng: Channel count of the 4 x 4 map produced by the FC layer.
        :param Ng_init: Channel count of the local image feature given to the
            ACM block (the output channels of the first image upsample).
        :param conditioning_dim: Size of the augmented sentence condition.
        :param D: Size of the text embedding space [D from AttnGAN paper].
        :param noise_dim: Size of the noise vector.
        """
        super().__init__()
        self.gf_dim = Ng
        self.gf_init = Ng_init
        # FC input is the concatenation [noise, condition, global feature (D)].
        self.in_dim = noise_dim + conditioning_dim + D
        self.text_dim = D

        self.define_module()

    def define_module(self) -> None:
        """Create the FC, upsampling, residual, ACM and attention sub-modules.

        Sub-modules are created in a fixed order. Keep it stable so that seeded
        parameter initialisation and ``state_dict`` keys stay reproducible.
        """
        in_features = self.in_dim
        base = self.gf_dim
        eighth = base // 8
        # Channels after concatenating hidden map + spatial + channel attention.
        fused = eighth * 3
        # GLU(dim=1) halves the feature count, leaving base * 4 * 4 values.
        fc_width = base * 4 * 4 * 2

        self.fully_connect = nn.Sequential(
            nn.Linear(in_features, fc_width, bias=False),
            nn.BatchNorm1d(fc_width),
            nn.GLU(dim=1),
        )

        self.upsample1 = up_sample(base, base // 2)
        self.upsample2 = up_sample(base // 2, base // 4)
        self.upsample3 = up_sample(base // 4, eighth)
        self.upsample4 = up_sample(fused, base // 16)

        self.residual = self._make_layer(ResidualBlock, fused)
        self.acm_module = ACM(self.gf_init, fused)

        self.spatial_att = SpatialAttention(self.text_dim, eighth)
        # Channel-wise attention operates on the flattened 32 x 32 map.
        self.channel_att = ChannelWiseAttention(32 * 32, self.text_dim)

    def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
        """Stack two ``block(channel_num)`` modules in an ``nn.Sequential``."""
        return nn.Sequential(*[block(channel_num) for _ in range(2)])

    def forward(
        self,
        noise: torch.Tensor,
        condition: torch.Tensor,
        global_inception: torch.Tensor,
        local_upsampled_inception: torch.Tensor,
        word_embeddings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Any:
        """
        :param noise: Noise tensor.
        :param condition: Augmented sentence condition (c^ in StackGAN++).
        :param global_inception: Global image feature vector.
        :param local_upsampled_inception: Local image feature, upsampled to 32 x 32.
        :param word_embeddings: Word embeddings [shape: D x L or D x T per sample].
        :param mask: Optional mask for padding tokens.
        :return: Hidden image feature map of size 64 x 64.
        """
        latent = torch.cat((noise, condition, global_inception), 1)
        hidden = self.fully_connect(latent).view(-1, self.gf_dim, 4, 4)
        for step in (self.upsample1, self.upsample2, self.upsample3):
            hidden = step(hidden)

        batch, _, height, width = hidden.shape
        # Attention modules work on the map flattened to (batch, C, H * W).
        flat = hidden.view(batch, -1, height * width)
        spatial = self.spatial_att(word_embeddings, flat, mask)
        channel = self.channel_att(spatial, word_embeddings)

        n_samples = word_embeddings.shape[0]
        spatial = spatial.view(n_samples, -1, height, width)
        channel = channel.view(n_samples, -1, height, width)

        fused = torch.cat((hidden, spatial, channel), 1)
        fused = self.acm_module(fused, local_upsampled_inception)
        fused = self.residual(fused)
        return self.upsample4(fused)
|
|
|
|
|
class NextStageG(nn.Module):
    """Refinement generator stage: hidden feature map -> 4x larger feature map.

    The incoming map is processed as follows:

    * it is concatenated with word-level spatial attention features and
      channel-wise attention features;
    * it is modulated by an ACM block using a VGG feature map;
    * it is refined by two residual blocks;
    * it is upsampled twice (64 x 64 -> 256 x 256 in this generator).
    """

    def __init__(self, Ng: int, Ng_init: int, D: int, image_size: int):
        """
        :param Ng: Channel count of the incoming hidden feature map.
        :param Ng_init: Channel count of the image feature given to the ACM block.
        :param D: Size of the text embedding space [D from AttnGAN paper].
        :param image_size: Spatial size of the feature map from the previous stage.
        """
        super().__init__()
        self.gf_dim = Ng
        self.gf_init = Ng_init
        self.text_dim = D
        self.img_size = image_size

        self.define_module()

    def define_module(self) -> None:
        """Create the attention, residual, ACM and upsampling sub-modules.

        Sub-modules are created in a fixed order. Keep it stable so that seeded
        parameter initialisation and ``state_dict`` keys stay reproducible.
        """
        base = self.gf_dim
        # Channels after concatenating hidden map + spatial + channel attention.
        fused = base * 3
        n_pixels = self.img_size * self.img_size

        self.spatial_att = SpatialAttention(self.text_dim, base)
        self.channel_att = ChannelWiseAttention(n_pixels, self.text_dim)

        self.residual = self._make_layer(ResidualBlock, fused)
        self.upsample = up_sample(fused, base)
        self.acm_module = ACM(self.gf_init, fused)
        self.upsample2 = up_sample(base, base)

    def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
        """Stack two ``block(channel_num)`` modules in an ``nn.Sequential``."""
        return nn.Sequential(*[block(channel_num) for _ in range(2)])

    def forward(
        self,
        hidden_feat: Any,
        word_embeddings: torch.Tensor,
        vgg64_feat: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Any:
        """
        :param hidden_feat: Hidden feature map from the previous stage [i.e. hidden_64].
        :param word_embeddings: Word embeddings.
        :param vgg64_feat: VGG feature map of size 64 x 64.
        :param mask: Optional mask for padding tokens.
        :return: Image feature map of size 256 x 256.
        """
        batch, _, height, width = hidden_feat.shape
        # Attention modules work on the map flattened to (batch, C, H * W).
        flat = hidden_feat.view(batch, -1, height * width)
        spatial = self.spatial_att(word_embeddings, flat, mask)
        channel = self.channel_att(spatial, word_embeddings)

        n_samples = word_embeddings.shape[0]
        spatial = spatial.view(n_samples, -1, height, width)
        channel = channel.view(n_samples, -1, height, width)

        fused = torch.cat((hidden_feat, spatial, channel), 1)
        hidden_64 = self.acm_module(fused, vgg64_feat)
        hidden_64 = self.residual(hidden_64)
        return self.upsample2(self.upsample(hidden_64))
|
|
|
|
|
class GetImageG(nn.Module):
    """Map a hidden feature map to a 3-channel image with values in [-1, 1]."""

    def __init__(self, Ng: int):
        """
        :param Ng: Channel count of the incoming feature map.
        """
        super().__init__()
        # 3x3 convolution with padding 1 keeps the spatial size unchanged.
        to_image = nn.Conv2d(Ng, 3, kernel_size=3, stride=1, padding=1, bias=False)
        self.img = nn.Sequential(to_image, nn.Tanh())

    def forward(self, hidden_feat: torch.Tensor) -> Any:
        """
        :param hidden_feat: Image feature map.
        :return: Final fake image (tanh output, so each value lies in [-1, 1]).
        """
        image = self.img(hidden_feat)
        return image
|
|
|
|
|
class Generator(nn.Module):
    """Text- and image-conditioned generator producing the final fake image.

    Pipeline:

    1. Augment the sentence embedding (``CondAugmentation``).
    2. Stage 1 (``InitStageG``) builds a 64 x 64 hidden map. It uses noise,
       the condition, the global image feature, the upsampled local image
       feature and the word embeddings.
    3. Stage 2 (``NextStageG``) refines it to 256 x 256, guided by VGG
       features downsampled to 64 x 64.
    4. A final ACM block modulates the 256 x 256 map with VGG features
       upsampled back to 256 x 256.
    5. ``GetImageG`` converts the result to a 3-channel image.
    """

    def __init__(self, Ng: int, D: int, conditioning_dim: int, noise_dim: int):
        """
        :param Ng: Base channel count. [Taken from StackGAN++ paper]
        :param D: Size of the text embedding space.
        :param conditioning_dim: Size of the augmented sentence condition.
        :param noise_dim: Size of the noise vector.
        """
        super().__init__()
        # Construction order is kept fixed for reproducible seeded init / state_dict.
        self.cond_augment = CondAugmentation(D, conditioning_dim)
        # Stage 1's FC layer works with 16x the base channel count.
        self.hidden_net1 = InitStageG(Ng * 16, Ng, conditioning_dim, D, noise_dim)
        self.inception_img_upsample = img_up_block(D, Ng)
        self.hidden_net2 = NextStageG(Ng, Ng, D, 64)
        self.generate_img = GetImageG(Ng)

        self.acm_module = ACM(Ng, Ng)

        self.vgg_downsample = down_sample(D // 2, Ng)
        self.upsample1 = up_sample(Ng, Ng)
        self.upsample2 = up_sample(Ng, Ng)

    def forward(
        self,
        noise: torch.Tensor,
        sentence_embeddings: torch.Tensor,
        word_embeddings: torch.Tensor,
        global_inception_feat: torch.Tensor,
        local_inception_feat: torch.Tensor,
        vgg_feat: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Any:
        """
        :param noise: Noise vector [shape: (batch, noise_dim)].
        :param sentence_embeddings: Sentence embeddings [shape: (batch, D)].
        :param word_embeddings: Word embeddings [shape: D x L per sample, L = sentence length].
        :param global_inception_feat: Global Inception feature [shape: (batch, D)].
        :param local_inception_feat: Local Inception feature map [shape: (batch, D, 17, 17)].
        :param vgg_feat: VGG feature map [shape: (batch, D // 2 = 128, 128, 128)].
        :param mask: Optional mask for padding tokens.
        :return: Tuple ``(fake_img, mu_tensor, logvar)``. The last two are the
            auxiliary outputs of ``CondAugmentation`` (presumably the mean and
            log-variance of the conditioning distribution; verify there).
        """
        c_hat, mu_tensor, logvar = self.cond_augment(sentence_embeddings)
        local_32 = self.inception_img_upsample(local_inception_feat)

        hidden_64 = self.hidden_net1(
            noise, c_hat, global_inception_feat, local_32, word_embeddings, mask
        )

        vgg_64 = self.vgg_downsample(vgg_feat)
        hidden_256 = self.hidden_net2(hidden_64, word_embeddings, vgg_64, mask)

        # Bring the VGG guidance back up to the final resolution for the last ACM.
        vgg_256 = self.upsample2(self.upsample1(vgg_64))
        fused_256 = self.acm_module(hidden_256, vgg_256)

        fake_img = self.generate_img(fused_256)
        return fake_img, mu_tensor, logvar
|
|