taim-gan / src /models /modules /image_encoder.py
Dmmc's picture
three-model version
c8ddb9b
raw
history blame
4.28 kB
"""Image Encoder Module"""
from typing import Any
import torch
from torch import nn
from src.models.modules.conv_utils import conv2d
# build inception v3 image encoder
class InceptionEncoder(nn.Module):
"""Image Encoder Module adapted from AttnGAN"""
def __init__(self, D: int):
"""
:param D: Dimension of the text embedding space [D from AttnGAN paper]
"""
super().__init__()
self.text_emb_dim = D
model = torch.hub.load(
"pytorch/vision:v0.10.0", "inception_v3", pretrained=True
)
for param in model.parameters():
param.requires_grad = False
self.define_module(model)
self.init_trainable_weights()
def define_module(self, model: nn.Module) -> None:
"""
This function defines the modules of the image encoder
:param model: Pretrained Inception V3 model
"""
model.cust_upsample = nn.Upsample(size=(299, 299), mode="bilinear")
model.cust_maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
model.cust_maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
model.cust_avgpool = nn.AvgPool2d(kernel_size=8)
attribute_list = [
"cust_upsample",
"Conv2d_1a_3x3",
"Conv2d_2a_3x3",
"Conv2d_2b_3x3",
"cust_maxpool1",
"Conv2d_3b_1x1",
"Conv2d_4a_3x3",
"cust_maxpool2",
"Mixed_5b",
"Mixed_5c",
"Mixed_5d",
"Mixed_6a",
"Mixed_6b",
"Mixed_6c",
"Mixed_6d",
"Mixed_6e",
]
self.feature_extractor = nn.Sequential(
*[getattr(model, name) for name in attribute_list]
)
attribute_list2 = ["Mixed_7a", "Mixed_7b", "Mixed_7c", "cust_avgpool"]
self.feature_extractor2 = nn.Sequential(
*[getattr(model, name) for name in attribute_list2]
)
self.emb_features = conv2d(
768, self.text_emb_dim, kernel_size=1, stride=1, padding=0
)
self.emb_cnn_code = nn.Linear(2048, self.text_emb_dim)
def init_trainable_weights(self) -> None:
"""
This function initializes the trainable weights of the image encoder
"""
initrange = 0.1
self.emb_features.weight.data.uniform_(-initrange, initrange)
self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)
def forward(self, image_tensor: torch.Tensor) -> Any:
"""
:param image_tensor: Input image
:return: features: local feature matrix (v from attnGAN paper) [shape: (batch, D, 17, 17)]
:return: cnn_code: global image feature (v^ from attnGAN paper) [shape: (batch, D)]
"""
# this is the image size
# x.shape: 10 3 256 256
features = self.feature_extractor(image_tensor)
# 17 x 17 x 768
image_tensor = self.feature_extractor2(features)
image_tensor = image_tensor.view(image_tensor.size(0), -1)
# 2048
# global image features
cnn_code = self.emb_cnn_code(image_tensor)
if features is not None:
features = self.emb_features(features)
# feature.shape: 10 256 17 17
# cnn_code.shape: 10 256
return features, cnn_code
class VGGEncoder(nn.Module):
"""Pre Trained VGG Encoder Module"""
def __init__(self) -> None:
"""
Initialize pre-trained VGG model with frozen parameters
"""
super().__init__()
self.select = "8" ## We want to get the output of the 8th layer in VGG.
self.model = torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True)
for param in self.model.parameters():
param.resquires_grad = False
self.vgg_modules = self.model.features._modules
def forward(self, image_tensor: torch.Tensor) -> Any:
"""
:param x: Input image tensor [shape: (batch, 3, 256, 256)]
:return: VGG features [shape: (batch, 128, 128, 128)]
"""
for name, layer in self.vgg_modules.items():
image_tensor = layer(image_tensor)
if name == self.select:
return image_tensor
return None