File size: 4,282 Bytes
c8ddb9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
"""Image Encoder Module"""
from typing import Any
import torch
from torch import nn
from src.models.modules.conv_utils import conv2d
# build inception v3 image encoder
class InceptionEncoder(nn.Module):
    """Image Encoder Module adapted from AttnGAN.

    Wraps a frozen, pre-trained Inception V3 backbone and adds two small
    trainable heads that project image features into the text embedding space.
    """

    def __init__(self, D: int):
        """
        :param D: Dimension of the text embedding space [D from AttnGAN paper]
        """
        super().__init__()
        self.text_emb_dim = D
        backbone = torch.hub.load(
            "pytorch/vision:v0.10.0", "inception_v3", pretrained=True
        )
        # Freeze the entire backbone; only the embedding heads stay trainable.
        for weight in backbone.parameters():
            weight.requires_grad = False
        self.define_module(backbone)
        self.init_trainable_weights()

    def define_module(self, model: nn.Module) -> None:
        """
        This function defines the modules of the image encoder

        :param model: Pretrained Inception V3 model
        """
        # Graft custom resize/pooling layers onto the pretrained model so
        # they can be addressed by name alongside the Inception sub-modules.
        model.cust_upsample = nn.Upsample(size=(299, 299), mode="bilinear")
        model.cust_maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        model.cust_maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        model.cust_avgpool = nn.AvgPool2d(kernel_size=8)

        # Stage 1: everything up to Mixed_6e — yields the 17x17x768 local
        # feature map used for attention.
        stage_one = (
            "cust_upsample",
            "Conv2d_1a_3x3",
            "Conv2d_2a_3x3",
            "Conv2d_2b_3x3",
            "cust_maxpool1",
            "Conv2d_3b_1x1",
            "Conv2d_4a_3x3",
            "cust_maxpool2",
            "Mixed_5b",
            "Mixed_5c",
            "Mixed_5d",
            "Mixed_6a",
            "Mixed_6b",
            "Mixed_6c",
            "Mixed_6d",
            "Mixed_6e",
        )
        self.feature_extractor = nn.Sequential(
            *(getattr(model, layer_name) for layer_name in stage_one)
        )

        # Stage 2: remaining Mixed_7* blocks + average pool — yields the
        # 2048-dim global feature vector.
        stage_two = ("Mixed_7a", "Mixed_7b", "Mixed_7c", "cust_avgpool")
        self.feature_extractor2 = nn.Sequential(
            *(getattr(model, layer_name) for layer_name in stage_two)
        )

        # Trainable heads mapping into the text embedding space.
        self.emb_features = conv2d(
            768, self.text_emb_dim, kernel_size=1, stride=1, padding=0
        )
        self.emb_cnn_code = nn.Linear(2048, self.text_emb_dim)

    def init_trainable_weights(self) -> None:
        """
        This function initializes the trainable weights of the image encoder
        """
        initrange = 0.1
        # Uniform init for both projection heads, as in AttnGAN.
        self.emb_features.weight.data.uniform_(-initrange, initrange)
        self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)

    def forward(self, image_tensor: torch.Tensor) -> Any:
        """
        :param image_tensor: Input image
        :return: features: local feature matrix (v from attnGAN paper) [shape: (batch, D, 17, 17)]
        :return: cnn_code: global image feature (v^ from attnGAN paper) [shape: (batch, D)]
        """
        # Local feature map (batch, 768, 17, 17) from the first stage.
        features = self.feature_extractor(image_tensor)
        # Global path: second stage, then flatten to (batch, 2048).
        pooled = self.feature_extractor2(features)
        flattened = pooled.view(pooled.size(0), -1)
        # Global image embedding (batch, D).
        cnn_code = self.emb_cnn_code(flattened)
        # Project local features into the text embedding space (batch, D, 17, 17).
        if features is not None:
            features = self.emb_features(features)
        return features, cnn_code
class VGGEncoder(nn.Module):
    """Pre Trained VGG Encoder Module"""

    def __init__(self) -> None:
        """
        Initialize pre-trained VGG model with frozen parameters
        """
        super().__init__()
        self.select = "8"  # We want to get the output of the 8th layer in VGG.
        self.model = torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True)
        for param in self.model.parameters():
            # BUG FIX: was `param.resquires_grad = False` (typo), which silently
            # created a new attribute and left the VGG weights trainable.
            param.requires_grad = False
        self.vgg_modules = self.model.features._modules

    def forward(self, image_tensor: torch.Tensor) -> Any:
        """
        :param image_tensor: Input image tensor [shape: (batch, 3, 256, 256)]
        :return: VGG features [shape: (batch, 128, 128, 128)], or None if the
                 selected layer name is never reached
        """
        # Run the feature layers in order, returning the output of the
        # selected layer as soon as it is produced.
        for name, layer in self.vgg_modules.items():
            image_tensor = layer(image_tensor)
            if name == self.select:
                return image_tensor
        return None
|