# taim-gan/src/models/modules/generator.py
"""Generator Module"""
from typing import Any, Optional
import torch
from torch import nn
from src.models.modules.acm import ACM
from src.models.modules.attention import ChannelWiseAttention, SpatialAttention
from src.models.modules.cond_augment import CondAugmentation
from src.models.modules.downsample import down_sample
from src.models.modules.residual import ResidualBlock
from src.models.modules.upsample import img_up_block, up_sample
class InitStageG(nn.Module):
"""Initial Stage Generator Module"""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
# pylint: disable=too-many-locals
def __init__(
self, Ng: int, Ng_init: int, conditioning_dim: int, D: int, noise_dim: int
):
"""
        :param Ng: Number of channels of the initial 4 x 4 feature map.
        :param Ng_init: Base value of Ng; this is the output channel count of the first image upsample.
:param conditioning_dim: Dimension of the conditioning space
:param D: Dimension of the text embedding space [D from AttnGAN paper]
:param noise_dim: Dimension of the noise space
"""
super().__init__()
self.gf_dim = Ng
self.gf_init = Ng_init
self.in_dim = noise_dim + conditioning_dim + D
self.text_dim = D
self.define_module()
def define_module(self) -> None:
"""Defines FC, Upsample, Residual, ACM, Attention modules"""
nz, ng = self.in_dim, self.gf_dim
self.fully_connect = nn.Sequential(
nn.Linear(nz, ng * 4 * 4 * 2, bias=False),
nn.BatchNorm1d(ng * 4 * 4 * 2),
            nn.GLU(dim=1),  # GLU halves the channels, leaving ng * 4 * 4 features for the 4 x 4 feature map
)
self.upsample1 = up_sample(ng, ng // 2)
self.upsample2 = up_sample(ng // 2, ng // 4)
self.upsample3 = up_sample(ng // 4, ng // 8)
        self.upsample4 = up_sample(
            ng // 8 * 3, ng // 16
        )  # channels multiplied by 3 because the spatial and channel attention features are concatenated
self.residual = self._make_layer(ResidualBlock, ng // 8 * 3)
self.acm_module = ACM(self.gf_init, ng // 8 * 3)
self.spatial_att = SpatialAttention(self.text_dim, ng // 8)
self.channel_att = ChannelWiseAttention(
32 * 32, self.text_dim
) # 32 x 32 is the feature map size
    def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
        """Stacks residual blocks operating on channel_num channels."""
        layers = []
        for _ in range(2):  # number of residual blocks hardcoded to 2
            layers.append(block(channel_num))
        return nn.Sequential(*layers)
def forward(
self,
noise: torch.Tensor,
condition: torch.Tensor,
global_inception: torch.Tensor,
local_upsampled_inception: torch.Tensor,
word_embeddings: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Any:
"""
:param noise: Noise tensor
:param condition: Condition tensor (c^ from stackGAN++ paper)
:param global_inception: Global inception feature
:param local_upsampled_inception: Local inception feature, upsampled to 32 x 32
        :param word_embeddings: Word embeddings [shape: (batch, D, L), where L is the sequence length]
        :param mask: Mask for padding tokens
        :return: Hidden image feature map of spatial size 64 x 64
"""
noise_concat = torch.cat((noise, condition), 1)
inception_concat = torch.cat((noise_concat, global_inception), 1)
hidden = self.fully_connect(inception_concat)
hidden = hidden.view(-1, self.gf_dim, 4, 4) # convert to 4x4 image feature map
hidden = self.upsample1(hidden)
hidden = self.upsample2(hidden)
hidden_32 = self.upsample3(hidden) # shape: (batch_size, gf_dim // 8, 32, 32)
hidden_32_view = hidden_32.view(
hidden_32.shape[0], -1, hidden_32.shape[2] * hidden_32.shape[3]
) # this reshaping is done as attention module expects this shape.
spatial_att_feat = self.spatial_att(
word_embeddings, hidden_32_view, mask
) # spatial att shape: (batch, D^, 32 * 32)
        channel_att_feat = self.channel_att(
            spatial_att_feat, word_embeddings
        )  # channel att shape: (batch, D^, 32 * 32), or (batch, C, Hk * Wk) from the ControlGAN paper
spatial_att_feat = spatial_att_feat.view(
word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
) # reshape to (batch, D^, 32, 32)
channel_att_feat = channel_att_feat.view(
word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
) # reshape to (batch, D^, 32, 32)
spatial_concat = torch.cat(
(hidden_32, spatial_att_feat), 1
) # concat spatial attention feature with hidden_32
attn_concat = torch.cat(
(spatial_concat, channel_att_feat), 1
) # concat channel and spatial attention feature
hidden_32 = self.acm_module(attn_concat, local_upsampled_inception)
hidden_32 = self.residual(hidden_32)
hidden_64 = self.upsample4(hidden_32)
return hidden_64
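# Illustrative shape walkthrough for InitStageG, assuming a base channel count of 32
# (so this stage is built with Ng = 32 * 16 = 512), D = 256 and
# noise_dim = conditioning_dim = 100; these are example values, not read from the repo config:
#   cat(noise, c_hat, global_inception): (B, 456) -> fully_connect + GLU -> (B, 512 * 4 * 4)
#   view -> (B, 512, 4, 4) -> upsample1/2/3 -> (B, 64, 32, 32)
#   spatial and channel attention each yield (B, 64, 32, 32); concat -> (B, 192, 32, 32)
#   ACM + residual keep (B, 192, 32, 32); upsample4 -> (B, 32, 64, 64) = hidden_64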
class NextStageG(nn.Module):
"""Next Stage Generator Module"""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
# pylint: disable=too-many-locals
def __init__(self, Ng: int, Ng_init: int, D: int, image_size: int):
"""
:param Ng: Number of channels.
        :param Ng_init: Base value of Ng; the channel count of the VGG feature map fed to the ACM module.
:param D: Dimension of the text embedding space [D from AttnGAN paper]
:param image_size: Size of the output image from previous generator stage.
"""
super().__init__()
self.gf_dim = Ng
self.gf_init = Ng_init
self.text_dim = D
self.img_size = image_size
self.define_module()
def define_module(self) -> None:
"""Defines FC, Upsample, Residual, ACM, Attention modules"""
ng = self.gf_dim
self.spatial_att = SpatialAttention(self.text_dim, ng)
self.channel_att = ChannelWiseAttention(
self.img_size * self.img_size, self.text_dim
)
self.residual = self._make_layer(ResidualBlock, ng * 3)
self.upsample = up_sample(ng * 3, ng)
self.acm_module = ACM(self.gf_init, ng * 3)
self.upsample2 = up_sample(ng, ng)
    def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
        """Stacks residual blocks operating on channel_num channels."""
        layers = []
        for _ in range(2):  # number of residual blocks hardcoded to 2
            layers.append(block(channel_num))
        return nn.Sequential(*layers)
def forward(
self,
hidden_feat: Any,
word_embeddings: torch.Tensor,
vgg64_feat: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Any:
"""
:param hidden_feat: Hidden feature from previous generator stage [i.e. hidden_64]
:param word_embeddings: Word embeddings
:param vgg64_feat: VGG feature map of size 64 x 64
:param mask: Mask for the padding tokens
        :return: Hidden image feature map of spatial size 256 x 256
"""
hidden_view = hidden_feat.view(
hidden_feat.shape[0], -1, hidden_feat.shape[2] * hidden_feat.shape[3]
) # reshape to pass into attention modules.
spatial_att_feat = self.spatial_att(
word_embeddings, hidden_view, mask
) # spatial att shape: (batch, D^, 64 * 64), or D^ x N
        channel_att_feat = self.channel_att(
            spatial_att_feat, word_embeddings
        )  # channel att shape: (batch, D^, 64 * 64), or (batch, C, Hk * Wk) from the ControlGAN paper
spatial_att_feat = spatial_att_feat.view(
word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
) # reshape to (batch, D^, 64, 64)
channel_att_feat = channel_att_feat.view(
word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
) # reshape to (batch, D^, 64, 64)
spatial_concat = torch.cat(
(hidden_feat, spatial_att_feat), 1
) # concat spatial attention feature with hidden_64
attn_concat = torch.cat(
(spatial_concat, channel_att_feat), 1
) # concat channel and spatial attention feature
hidden_64 = self.acm_module(attn_concat, vgg64_feat)
hidden_64 = self.residual(hidden_64)
hidden_128 = self.upsample(hidden_64)
hidden_256 = self.upsample2(hidden_128)
return hidden_256
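# Illustrative shape walkthrough for NextStageG, assuming Ng = 32 and D = 256
# (example values, not read from the repo config):
#   hidden_feat: (B, 32, 64, 64); spatial and channel attention each yield (B, 32, 64, 64)
#   concat -> (B, 96, 64, 64); ACM with vgg64_feat (B, 32, 64, 64) + residual keep (B, 96, 64, 64)
#   upsample -> (B, 32, 128, 128) -> upsample2 -> (B, 32, 256, 256) = hidden_256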
class GetImageG(nn.Module):
"""Generates the Final Fake Image from the Image Feature Map"""
def __init__(self, Ng: int):
"""
:param Ng: Number of channels.
"""
super().__init__()
self.img = nn.Sequential(
nn.Conv2d(Ng, 3, kernel_size=3, stride=1, padding=1, bias=False), nn.Tanh()
)
def forward(self, hidden_feat: torch.Tensor) -> Any:
"""
:param hidden_feat: Image feature map
:return: Final fake image
"""
return self.img(hidden_feat)
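# GetImageG maps an Ng-channel feature map to a 3-channel RGB image in [-1, 1] via Tanh,
# e.g. (B, Ng, 256, 256) -> (B, 3, 256, 256) at the final stage.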
class Generator(nn.Module):
"""Generator Module"""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
# pylint: disable=too-many-locals
def __init__(self, Ng: int, D: int, conditioning_dim: int, noise_dim: int):
"""
:param Ng: Number of channels. [Taken from StackGAN++ paper]
:param D: Dimension of the text embedding space
:param conditioning_dim: Dimension of the conditioning space
:param noise_dim: Dimension of the noise space
"""
super().__init__()
self.cond_augment = CondAugmentation(D, conditioning_dim)
self.hidden_net1 = InitStageG(Ng * 16, Ng, conditioning_dim, D, noise_dim)
        self.inception_img_upsample = img_up_block(
            D, Ng
        )  # the channel size returned by the inception encoder is D (default in the paper: 256)
self.hidden_net2 = NextStageG(Ng, Ng, D, 64)
self.generate_img = GetImageG(Ng)
self.acm_module = ACM(Ng, Ng)
self.vgg_downsample = down_sample(D // 2, Ng)
self.upsample1 = up_sample(Ng, Ng)
self.upsample2 = up_sample(Ng, Ng)
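    # Rough data flow through the sub-modules, assuming base Ng = 32 and D = 256
    # (example values only): local_inception_feat (B, 256, 17, 17) is upsampled to a
    # (B, 32, 32, 32) image feature for the initial stage; vgg_feat (B, 128, 128, 128)
    # is downsampled to (B, 32, 64, 64) for the second stage, then upsampled back to
    # (B, 32, 256, 256) for the final ACM fusion before image generation.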
def forward(
self,
noise: torch.Tensor,
sentence_embeddings: torch.Tensor,
word_embeddings: torch.Tensor,
global_inception_feat: torch.Tensor,
local_inception_feat: torch.Tensor,
vgg_feat: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Any:
"""
:param noise: Noise vector [shape: (batch, noise_dim)]
:param sentence_embeddings: Sentence embeddings [shape: (batch, D)]
        :param word_embeddings: Word embeddings [shape: (batch, D, L), where L is the sentence length]
:param global_inception_feat: Global Inception feature map [shape: (batch, D)]
:param local_inception_feat: Local Inception feature map [shape: (batch, D, 17, 17)]
:param vgg_feat: VGG feature map [shape: (batch, D // 2 = 128, 128, 128)]
:param mask: Mask for the padding tokens
:return: Final fake image
"""
c_hat, mu_tensor, logvar = self.cond_augment(sentence_embeddings)
hidden_32 = self.inception_img_upsample(local_inception_feat)
hidden_64 = self.hidden_net1(
noise, c_hat, global_inception_feat, hidden_32, word_embeddings, mask
)
vgg_64 = self.vgg_downsample(vgg_feat)
hidden_256 = self.hidden_net2(hidden_64, word_embeddings, vgg_64, mask)
vgg_128 = self.upsample1(vgg_64)
vgg_256 = self.upsample2(vgg_128)
hidden_256 = self.acm_module(hidden_256, vgg_256)
fake_img = self.generate_img(hidden_256)
return fake_img, mu_tensor, logvar
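if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline. The dimensions
    # below (Ng = 32, D = 256, conditioning and noise dims of 100, 18 words) are
    # illustrative assumptions; input shapes follow the forward() docstring above.
    base_ng, text_dim, cond_dim, z_dim = 32, 256, 100, 100
    batch, seq_len = 2, 18
    generator = Generator(base_ng, text_dim, cond_dim, z_dim)
    fake_img, mu_tensor, logvar = generator(
        torch.randn(batch, z_dim),  # noise
        torch.randn(batch, text_dim),  # sentence embeddings
        torch.randn(batch, text_dim, seq_len),  # word embeddings
        torch.randn(batch, text_dim),  # global inception feature
        torch.randn(batch, text_dim, 17, 17),  # local inception feature
        torch.randn(batch, text_dim // 2, 128, 128),  # VGG feature map
    )
    # Expected (under these assumptions): fake_img of shape (batch, 3, 256, 256).
    print(fake_img.shape, mu_tensor.shape, logvar.shape)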