three-model version
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitignore +2 -0
- app.py +117 -0
- requirements.txt +5 -0
- src/__init__.py +2 -0
- src/__pycache__/__init__.cpython-39.pyc +0 -0
- src/__pycache__/config.cpython-39.pyc +0 -0
- src/config.py +47 -0
- src/data/.gitkeep +0 -0
- src/data/__init__.py +5 -0
- src/data/__pycache__/__init__.cpython-39.pyc +0 -0
- src/data/__pycache__/collate.cpython-39.pyc +0 -0
- src/data/__pycache__/datasets.cpython-39.pyc +0 -0
- src/data/__pycache__/tokenizer.cpython-39.pyc +0 -0
- src/data/collate.py +43 -0
- src/data/datasets.py +387 -0
- src/data/stubs/bird.jpg +0 -0
- src/data/stubs/pigeon.jpg +0 -0
- src/data/stubs/rohit.jpeg +0 -0
- src/data/tokenizer.py +23 -0
- src/features/.gitkeep +0 -0
- src/features/__init__.py +0 -0
- src/features/build_features.py +0 -0
- src/models/.gitkeep +0 -0
- src/models/__init__.py +4 -0
- src/models/__pycache__/__init__.cpython-39.pyc +0 -0
- src/models/__pycache__/losses.cpython-39.pyc +0 -0
- src/models/__pycache__/train_model.cpython-39.pyc +0 -0
- src/models/__pycache__/utils.cpython-39.pyc +0 -0
- src/models/losses.py +344 -0
- src/models/modules/__init__.py +12 -0
- src/models/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/acm.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/attention.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/cond_augment.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/conv_utils.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/discriminator.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/downsample.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/generator.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/image_encoder.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/residual.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/text_encoder.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/upsample.cpython-39.pyc +0 -0
- src/models/modules/acm.py +37 -0
- src/models/modules/attention.py +88 -0
- src/models/modules/cond_augment.py +57 -0
- src/models/modules/conv_utils.py +78 -0
- src/models/modules/discriminator.py +144 -0
- src/models/modules/downsample.py +14 -0
- src/models/modules/generator.py +300 -0
- src/models/modules/image_encoder.py +138 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
__pycache__/*
.idea/*
app.py
ADDED
@@ -0,0 +1,117 @@
import numpy as np  # this should come first to mitigate mkl-service bug
from src.models.utils import get_image_arr, load_model
from src.data import TAIMGANTokenizer
from torchvision import transforms
from src.config import config_dict
from pathlib import Path
from enum import IntEnum, auto
from PIL import Image
import gradio as gr
import torch
from src.models.modules import (
    VGGEncoder,
    InceptionEncoder,
    TextEncoder,
    Generator
)

##########
# PARAMS #
##########

IMG_CHANS = 3  # RGB channels for image
IMG_HW = 256  # height and width of images
HIDDEN_DIM = 128  # hidden dimensions of lstm cell in one direction
C = 2 * HIDDEN_DIM  # length of embeddings

Ng = config_dict["Ng"]
cond_dim = config_dict["condition_dim"]
z_dim = config_dict["noise_dim"]


###############
# LOAD MODELS #
###############

models = {
    "COCO": {
        "dir": "weights/coco"
    },
    "Bird": {
        "dir": "weights/bird"
    },
    "UTKFace": {
        "dir": "weights/utkface"
    }
}

for model_name in models:
    # create tokenizer
    models[model_name]["tokenizer"] = TAIMGANTokenizer(captions_path=f"{models[model_name]['dir']}/captions.pickle")
    vocab_size = len(models[model_name]["tokenizer"].word_to_ix)
    # instantiate models
    models[model_name]["generator"] = Generator(Ng=Ng, D=C, conditioning_dim=cond_dim, noise_dim=z_dim).eval()
    models[model_name]["lstm"] = TextEncoder(vocab_size=vocab_size, emb_dim=C, hidden_dim=HIDDEN_DIM).eval()
    models[model_name]["vgg"] = VGGEncoder().eval()
    models[model_name]["inception"] = InceptionEncoder(D=C).eval()
    # load models
    load_model(
        generator=models[model_name]["generator"],
        discriminator=None,
        image_encoder=models[model_name]["inception"],
        text_encoder=models[model_name]["lstm"],
        output_dir=Path(models[model_name]["dir"]),
        device=torch.device("cpu")
    )


def change_image_with_text(image: Image, text: str, model_name: str) -> Image:
    """
    Create an image modified by text from the original image
    and save it with _modified postfix

    :param gr.Image image: Path to the image
    :param str text: Desired caption
    """
    global models
    tokenizer = models[model_name]["tokenizer"]
    G = models[model_name]["generator"]
    lstm = models[model_name]["lstm"]
    inception = models[model_name]["inception"]
    vgg = models[model_name]["vgg"]
    # generate some noise
    noise = torch.rand(z_dim).unsqueeze(0)
    # transform input text and get masks with embeddings
    tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
    mask = (tokens == tokenizer.pad_token_id)
    word_embs, sent_embs = lstm(tokens)
    # open the image and transform it to the tensor
    image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((IMG_HW, IMG_HW)),
        transforms.Normalize(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5)
        )
    ])(image).unsqueeze(0)
    # obtain visual features of the image
    vgg_features = vgg(image)
    local_features, global_features = inception(image)
    # generate new image from the old one
    fake_image, _, _ = G(noise, sent_embs, word_embs, global_features,
                         local_features, vgg_features, mask)
    # denormalize the image
    fake_image = Image.fromarray(get_image_arr(fake_image)[0])
    # return image in gradio format
    return fake_image


##########
# GRADIO #
##########
demo = gr.Interface(
    fn=change_image_with_text,
    inputs=[gr.Image(type="pil"), "text", gr.inputs.Dropdown(list(models.keys()))],
    outputs=gr.Image(type="pil")
)
demo.launch(debug=True)
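Usage note (illustrative sketch, not part of this commit): the pipeline above can be exercised without launching the Gradio UI by calling change_image_with_text directly. The stub image path and the "Bird" weights directory are assumptions taken from the file list in this diff.

# Illustrative sketch only: drives the function defined in app.py directly.
from PIL import Image

src_img = Image.open("src/data/stubs/bird.jpg").convert("RGB")  # stub shipped in this commit
edited = change_image_with_text(src_img, "a small bird with a red head", "Bird")
edited.save("bird_modified.jpg")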
requirements.txt
ADDED
@@ -0,0 +1,5 @@
Pillow
torch
torchvision
torchaudio
nltk
src/__init__.py
ADDED
@@ -0,0 +1,2 @@
"""Config file for the project."""
from .config import config_dict, update_config
src/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (260 Bytes).
src/__pycache__/config.cpython-39.pyc
ADDED
Binary file (1.17 kB).
src/config.py
ADDED
@@ -0,0 +1,47 @@
"""Configurations for the project."""
from pathlib import Path
from typing import Any, Dict

import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

repo_path = Path(__file__).parent.parent.absolute()
output_path = repo_path / "models"

config_dict = {
    "Ng": 32,
    "D": 256,
    "condition_dim": 100,
    "noise_dim": 100,
    "lr_config": {
        "disc_lr": 2e-4,
        "gen_lr": 2e-4,
        "img_encoder_lr": 3e-3,
        "text_encoder_lr": 3e-3,
    },
    "batch_size": 64,
    "device": device,
    "epochs": 200,
    "output_dir": output_path,
    "snapshot": 5,
    "const_dict": {
        "smooth_val_gen": 0.999,
        "lambda1": 1,
        "lambda2": 1,
        "lambda3": 1,
        "lambda4": 1,
        "gamma1": 4,
        "gamma2": 5,
        "gamma3": 10,
    },
}


def update_config(cfg_dict: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    """
    Function to update the configuration dictionary.
    """
    for key, value in kwargs.items():
        cfg_dict[key] = value
    return cfg_dict
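Usage note (illustrative sketch, not part of this commit): update_config mutates and returns the same dictionary, so overrides can be applied in place before training.

# Illustrative sketch: override a couple of fields in config_dict.
from src.config import config_dict, update_config

cfg = update_config(config_dict, batch_size=16, epochs=10)
assert cfg is config_dict and cfg["batch_size"] == 16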
src/data/.gitkeep
ADDED
File without changes
src/data/__init__.py
ADDED
@@ -0,0 +1,5 @@
"""Dataset and custom collate function for loading data."""

from .collate import custom_collate
from .datasets import TextImageDataset
from .tokenizer import TAIMGANTokenizer
src/data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (372 Bytes).
src/data/__pycache__/collate.cpython-39.pyc
ADDED
Binary file (1.3 kB).
src/data/__pycache__/datasets.cpython-39.pyc
ADDED
Binary file (11.8 kB).
src/data/__pycache__/tokenizer.cpython-39.pyc
ADDED
Binary file (1.55 kB).
src/data/collate.py
ADDED
@@ -0,0 +1,43 @@
"""Custom collate function for the data loader."""

from typing import Any, List

import torch
from torch.nn.utils.rnn import pad_sequence


def custom_collate(batch: List[Any], device: Any) -> Any:
    """
    Custom collate function to be used in the data loader.
    :param batch: list, with length equal to number of batches.
    :return: processed batch of data [add padding to text, stack tensors in batch]
    """
    img, correct_capt, curr_class, word_labels = zip(*batch)
    batched_img = torch.stack(img, dim=0).to(
        device
    )  # shape: (batch_size, 3, height, width)
    correct_capt_len = torch.tensor(
        [len(capt) for capt in correct_capt], dtype=torch.int64
    ).unsqueeze(
        1
    )  # shape: (batch_size, 1)
    batched_correct_capt = pad_sequence(
        correct_capt, batch_first=True, padding_value=0
    ).to(
        device
    )  # shape: (batch_size, max_seq_len)
    batched_curr_class = torch.stack(curr_class, dim=0).to(
        device
    )  # shape: (batch_size, 1)
    batched_word_labels = pad_sequence(
        word_labels, batch_first=True, padding_value=0
    ).to(
        device
    )  # shape: (batch_size, max_seq_len)
    return (
        batched_img,
        batched_correct_capt,
        correct_capt_len,
        batched_curr_class,
        batched_word_labels,
    )
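Usage note (illustrative sketch, not part of this commit): torch.utils.data.DataLoader passes only the batch to collate_fn, so the extra device argument of custom_collate has to be bound first, e.g. with functools.partial. The dataset arguments below are assumptions.

# Illustrative sketch: bind the device argument before handing the collate
# function to a DataLoader. Dataset arguments here are placeholders.
from functools import partial

import torch
from torch.utils.data import DataLoader

from src.data import TextImageDataset, custom_collate

dataset = TextImageDataset(data_path="./birds/", split="train", num_captions=10)
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=partial(custom_collate, device=torch.device("cpu")),
)
imgs, caps, cap_lens, class_ids, word_labels = next(iter(loader))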
src/data/datasets.py
ADDED
@@ -0,0 +1,387 @@
"""Pytorch Dataset classes for the datasets used in the project."""

import os
import pickle
from collections import defaultdict
from typing import Any

import nltk
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as F
from nltk.tokenize import RegexpTokenizer
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class TextImageDataset(Dataset):  # type: ignore
    """Custom PyTorch Dataset class to load Image and Text data."""

    # pylint: disable=too-many-instance-attributes
    # pylint: disable=too-many-locals
    # pylint: disable=too-many-function-args

    def __init__(
        self, data_path: str, split: str, num_captions: int, transform: Any = None
    ):
        """
        :param data_path: Path to the data directory. [i.e. can be './birds/', or './coco/']
        :param split: 'train' or 'test' split
        :param num_captions: number of captions present per image.
            [For birds, this is 10, for coco, this is 5]
        :param transform: PyTorch transform to apply to the images.
        """
        self.transform = transform
        self.bound_box_map = None
        self.file_names = self.load_filenames(data_path, split)
        self.data_path = data_path
        self.num_captions_per_image = num_captions
        (
            self.captions,
            self.ix_to_word,
            self.word_to_ix,
            self.vocab_len,
        ) = self.get_capt_and_vocab(data_path, split)
        self.normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        self.class_ids = self.get_class_id(data_path, split, len(self.file_names))
        if self.data_path.endswith("birds/"):
            self.bound_box_map = self.get_bound_box(data_path)

        elif self.data_path.endswith("coco/"):
            pass

        else:
            raise ValueError(
                "Invalid data path. Please ensure the data [CUB/COCO] is stored in correct folders."
            )

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.file_names)

    def __getitem__(self, idx: int) -> Any:
        """
        Return the item at index idx.
        :param idx: index of the item to return
        :return img_tensor: image tensor
        :return correct_caption: correct caption for the image [list of word indices]
        :return curr_class_id: class id of the image
        :return word_labels: POS_tagged word labels [1 for noun and adjective, 0 else]
        """
        file_name = self.file_names[idx]
        curr_class_id = self.class_ids[idx]

        if self.bound_box_map is not None:
            bbox = self.bound_box_map[file_name]
            images_dir = os.path.join(self.data_path, "CUB_200_2011/images")
        else:
            bbox = None
            images_dir = os.path.join(self.data_path, "images")

        img_path = os.path.join(images_dir, file_name + ".jpg")
        img_tensor = self.get_image(img_path, bbox, self.transform)

        rand_sent_idx = np.random.randint(0, self.num_captions_per_image)
        rand_sent_idx = idx * self.num_captions_per_image + rand_sent_idx

        correct_caption = torch.tensor(self.captions[rand_sent_idx], dtype=torch.int64)
        num_words = len(correct_caption)

        capt_token_list = []
        for i in range(num_words):
            capt_token_list.append(self.ix_to_word[correct_caption[i].item()])

        pos_tag_list = nltk.tag.pos_tag(capt_token_list)
        word_labels = []

        for pos_tag in pos_tag_list:
            if (
                "NN" in pos_tag[1] or "JJ" in pos_tag[1]
            ):  # check for Nouns and Adjective only
                word_labels.append(1)
            else:
                word_labels.append(0)

        word_labels = torch.tensor(word_labels).float()  # type: ignore

        curr_class_id = torch.tensor(curr_class_id, dtype=torch.int64).unsqueeze(0)

        return (
            img_tensor,
            correct_caption,
            curr_class_id,
            word_labels,
        )

    def get_capt_and_vocab(self, data_dir: str, split: str) -> Any:
        """
        Helper function to get the captions, vocab dict for each image.
        :param data_dir: path to the data directory [i.e. './birds/' or './coco/']
        :param split: 'train' or 'test' split
        :return captions: list of all captions for each image
        :return ix_to_word: dictionary mapping index to word
        :return word_to_ix: dictionary mapping word to index
        :return num_words: number of unique words in the vocabulary
        """
        captions_ckpt_path = os.path.join(data_dir, "stubs/captions.pickle")
        if os.path.exists(
            captions_ckpt_path
        ):  # check if previously processed captions exist
            with open(captions_ckpt_path, "rb") as ckpt_file:
                captions = pickle.load(ckpt_file)
                train_captions, test_captions = captions[0], captions[1]
                ix_to_word, word_to_ix = captions[2], captions[3]
                num_words = len(ix_to_word)
                del captions
            if split == "train":
                return train_captions, ix_to_word, word_to_ix, num_words
            return test_captions, ix_to_word, word_to_ix, num_words

        else:  # if not, process the captions and save them
            train_files = self.load_filenames(data_dir, "train")
            test_files = self.load_filenames(data_dir, "test")

            train_captions_tokenized = self.get_tokenized_captions(
                data_dir, train_files
            )
            test_captions_tokenized = self.get_tokenized_captions(
                data_dir, test_files
            )  # we need both train and test captions to build the vocab

            (
                train_captions,
                test_captions,
                ix_to_word,
                word_to_ix,
                num_words,
            ) = self.build_vocab(  # type: ignore
                train_captions_tokenized, test_captions_tokenized, split
            )
            vocab_list = [train_captions, test_captions, ix_to_word, word_to_ix]
            with open(captions_ckpt_path, "wb") as ckpt_file:
                pickle.dump(vocab_list, ckpt_file)

            if split == "train":
                return train_captions, ix_to_word, word_to_ix, num_words
            if split == "test":
                return test_captions, ix_to_word, word_to_ix, num_words
            raise ValueError("Invalid split. Please use 'train' or 'test'")

    def build_vocab(
        self, tokenized_captions_train: list, tokenized_captions_test: list  # type: ignore
    ) -> Any:
        """
        Helper function which builds the vocab dicts.
        :param tokenized_captions_train: list containing all the
            train tokenized captions in the dataset. This is list of lists.
        :param tokenized_captions_test: list containing all the
            test tokenized captions in the dataset. This is list of lists.
        :return train_captions_int: list of all captions in training,
            where each word is replaced by its index in the vocab
        :return test_captions_int: list of all captions in test,
            where each word is replaced by its index in the vocab
        :return ix_to_word: dictionary mapping index to word
        :return word_to_ix: dictionary mapping word to index
        :return num_words: number of unique words in the vocabulary
        """
        vocab = defaultdict(int)  # type: ignore
        total_captions = tokenized_captions_train + tokenized_captions_test
        for caption in total_captions:
            for word in caption:
                vocab[word] += 1

        # sort vocab dict by frequency in descending order
        vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)  # type: ignore

        ix_to_word = {}
        word_to_ix = {}
        ix_to_word[0] = "<end>"
        word_to_ix["<end>"] = 0

        word_idx = 1
        for word, _ in vocab:
            word_to_ix[word] = word_idx
            ix_to_word[word_idx] = word
            word_idx += 1

        train_captions_int = []  # we want to convert words to indices in vocab.
        for caption in tokenized_captions_train:
            curr_caption_int = []
            for word in caption:
                curr_caption_int.append(word_to_ix[word])

            train_captions_int.append(curr_caption_int)

        test_captions_int = []
        for caption in tokenized_captions_test:
            curr_caption_int = []
            for word in caption:
                curr_caption_int.append(word_to_ix[word])

            test_captions_int.append(curr_caption_int)

        return (
            train_captions_int,
            test_captions_int,
            ix_to_word,
            word_to_ix,
            len(ix_to_word),
        )

    def get_tokenized_captions(self, data_dir: str, filenames: list) -> Any:  # type: ignore
        """
        Helper function to tokenize and return captions for each image in filenames.
        :param data_dir: path to the data directory [i.e. './birds/' or './coco/']
        :param filenames: list of all filenames corresponding to the split
        :return tokenized_captions: list of all tokenized captions for all files in filenames.
            [this returns a list, where each element is again a list of tokens/words]
        """

        all_captions = []
        for filename in filenames:
            caption_path = os.path.join(data_dir, "text", filename + ".txt")
            with open(caption_path, "r", encoding="utf8") as txt_file:
                captions = txt_file.readlines()
                count = 0
                for caption in captions:
                    if len(caption) == 0:
                        continue

                    caption = caption.replace("\ufffd\ufffd", " ")
                    tokenizer = RegexpTokenizer(r"\w+")
                    tokens = tokenizer.tokenize(
                        caption.lower()
                    )  # splits current caption/line to list of words/tokens
                    if len(tokens) == 0:
                        continue

                    tokens = [
                        t.encode("ascii", "ignore").decode("ascii") for t in tokens
                    ]
                    tokens = [t for t in tokens if len(t) > 0]

                    all_captions.append(tokens)
                    count += 1
                    if count == self.num_captions_per_image:
                        break
                if count < self.num_captions_per_image:
                    raise ValueError(
                        f"Number of captions for {filename} is only {count},\
                        which is less than {self.num_captions_per_image}."
                    )

        return all_captions

    def get_image(self, img_path: str, bbox: list, transform: Any) -> Any:  # type: ignore
        """
        Helper function to load and transform an image.
        :param img_path: path to the image
        :param bbox: bounding box coordinates [x, y, width, height]
        :param transform: PyTorch transform to apply to the image
        :return img_tensor: transformed image tensor
        """
        img = Image.open(img_path).convert("RGB")
        width, height = img.size

        if bbox is not None:
            r_val = int(np.maximum(bbox[2], bbox[3]) * 0.75)

            center_x = int((2 * bbox[0] + bbox[2]) / 2)
            center_y = int((2 * bbox[1] + bbox[3]) / 2)
            y1_coord = np.maximum(0, center_y - r_val)
            y2_coord = np.minimum(height, center_y + r_val)
            x1_coord = np.maximum(0, center_x - r_val)
            x2_coord = np.minimum(width, center_x + r_val)

            img = img.crop(
                [x1_coord, y1_coord, x2_coord, y2_coord]
            )  # This preprocessing step seems to follow from
            # Stackgan: Text to photo-realistic image synthesis

        if transform is not None:
            img_tensor = transform(img)  # this scales to 304x304, i.e. 256 x (76/64).
            x_val = np.random.randint(0, 48)  # 304 - 256 = 48
            y_val = np.random.randint(0, 48)
            flip = np.random.rand() > 0.5

            # crop
            img_tensor = img_tensor.crop(
                [x_val, y_val, x_val + 256, y_val + 256]
            )  # this crops to 256x256
            if flip:
                img_tensor = F.hflip(img_tensor)

        img_tensor = self.normalize(img_tensor)

        return img_tensor

    def load_filenames(self, data_dir: str, split: str) -> Any:
        """
        Helper function to get list of all image filenames.
        :param data_dir: path to the data directory [i.e. './birds/' or './coco/']
        :param split: 'train' or 'test' split
        :return filenames: list of all image filenames
        """
        filepath = f"{data_dir}{split}/filenames.pickle"
        if os.path.isfile(filepath):
            with open(filepath, "rb") as pick_file:
                filenames = pickle.load(pick_file)
        else:
            raise ValueError(
                "Invalid split. Please use 'train' or 'test',\
                or make sure the filenames.pickle file exists."
            )
        return filenames

    def get_class_id(self, data_dir: str, split: str, total_elems: int) -> Any:
        """
        Helper function to get list of all image class ids.
        :param data_dir: path to the data directory [i.e. './birds/' or './coco/']
        :param split: 'train' or 'test' split
        :param total_elems: total number of elements in the dataset
        :return class_ids: list of all image class ids
        """
        filepath = f"{data_dir}{split}/class_info.pickle"
        if os.path.isfile(filepath):
            with open(filepath, "rb") as class_file:
                class_ids = pickle.load(class_file, encoding="latin1")
        else:
            class_ids = np.arange(total_elems)
        return class_ids

    def get_bound_box(self, data_path: str) -> Any:
        """
        Helper function to get the bounding box for birds dataset.
        :param data_path: path to birds data directory [i.e. './data/birds/']
        :return imageToBox: dictionary mapping image name to bounding box coordinates
        """
        bbox_path = os.path.join(data_path, "CUB_200_2011/bounding_boxes.txt")
        df_bounding_boxes = pd.read_csv(
            bbox_path, delim_whitespace=True, header=None
        ).astype(int)

        filepath = os.path.join(data_path, "CUB_200_2011/images.txt")
        df_filenames = pd.read_csv(filepath, delim_whitespace=True, header=None)
        filenames = df_filenames[
            1
        ].tolist()  # df_filenames[0] just contains the index or ID.

        img_to_box = {  # type: ignore
            img_file[:-4]: [] for img_file in filenames
        }  # remove the .jpg extension from the names
        num_imgs = len(filenames)

        for i in range(0, num_imgs):
            bbox = df_bounding_boxes.iloc[i][1:].tolist()
            key = filenames[i][:-4]
            img_to_box[key] = bbox

        return img_to_box
src/data/stubs/bird.jpg
ADDED
src/data/stubs/pigeon.jpg
ADDED
src/data/stubs/rohit.jpeg
ADDED
src/data/tokenizer.py
ADDED
@@ -0,0 +1,23 @@
import pickle
import re
from typing import List


class TAIMGANTokenizer:
    def __init__(self, captions_path):
        with open(captions_path, "rb") as ckpt_file:
            captions = pickle.load(ckpt_file)
        self.ix_to_word = captions[2]
        self.word_to_ix = captions[3]
        self.token_regex = r'\w+'
        self.pad_token_id = self.word_to_ix["<end>"]
        self.pad_repr = "[PAD]"

    def encode(self, text: str) -> List[int]:
        return [self.word_to_ix.get(word, self.pad_token_id)
                for word in re.findall(self.token_regex, text.lower())]

    def decode(self, tokens: List[int]) -> str:
        return ' '.join([self.ix_to_word[token]
                         if token != self.pad_token_id else self.pad_repr
                         for token in tokens])
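Usage note (illustrative sketch, not part of this commit): a round trip through the tokenizer. The pickle path is an assumption; any captions.pickle with the layout described in datasets.py works. Out-of-vocabulary words map to the <end> id and decode back as [PAD].

# Illustrative sketch: encode/decode round trip with TAIMGANTokenizer.
from src.data import TAIMGANTokenizer

tokenizer = TAIMGANTokenizer(captions_path="weights/bird/captions.pickle")  # assumed path
ids = tokenizer.encode("A small yellow bird with black wings")
print(ids)                    # list of vocabulary indices
print(tokenizer.decode(ids))  # lower-cased words; unknown words show as [PAD]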
src/features/.gitkeep
ADDED
File without changes
src/features/__init__.py
ADDED
File without changes
src/features/build_features.py
ADDED
File without changes
src/models/.gitkeep
ADDED
File without changes
src/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
"""Helper functions for training loop."""
from .losses import discriminator_loss, generator_loss, kl_loss
from .train_model import train
from .utils import copy_gen_params, define_optimizers, load_params, prepare_labels
src/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (461 Bytes).
src/models/__pycache__/losses.cpython-39.pyc
ADDED
Binary file (8.36 kB).
src/models/__pycache__/train_model.cpython-39.pyc
ADDED
Binary file (3.82 kB).
src/models/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (8.76 kB).
src/models/losses.py
ADDED
@@ -0,0 +1,344 @@
"""Module containing the loss functions for the GANs."""
from typing import Any, Dict

import torch
from torch import nn

# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals


def generator_loss(
    logits: Dict[str, Dict[str, torch.Tensor]],
    local_fake_incept_feat: torch.Tensor,
    global_fake_incept_feat: torch.Tensor,
    real_labels: torch.Tensor,
    words_emb: torch.Tensor,
    sent_emb: torch.Tensor,
    match_labels: torch.Tensor,
    cap_lens: torch.Tensor,
    class_ids: torch.Tensor,
    real_vgg_feat: torch.Tensor,
    fake_vgg_feat: torch.Tensor,
    const_dict: Dict[str, float],
) -> Any:
    """Calculate the loss for the generator.

    Args:
        logits: Dictionary with fake/real and word-level/uncond/cond logits

        local_fake_incept_feat: The local inception features for the fake images.

        global_fake_incept_feat: The global inception features for the fake images.

        real_labels: Label for "real" image as predicted by discriminator,
            this is a tensor of ones. [shape: (batch_size, 1)].

        word_labels: POS tagged word labels for the captions. [shape: (batch_size, L)]

        words_emb: The embeddings for all the words in the captions.
            shape: (batch_size, embedding_size, max_caption_length)

        sent_emb: The embeddings for the sentences.
            shape: (batch_size, embedding_size)

        match_labels: Tensor of shape: (batch_size, 1).
            This is of the form torch.tensor([0, 1, 2, ..., batch-1])

        cap_lens: The length of the 'actual' captions in the batch [without padding]
            shape: (batch_size, 1)

        class_ids: The class ids for the instance. shape: (batch_size, 1)

        real_vgg_feat: The vgg features for the real images. shape: (batch_size, 128, 128, 128)
        fake_vgg_feat: The vgg features for the fake images. shape: (batch_size, 128, 128, 128)

        const_dict: The dictionary containing the constants.
    """
    lambda1 = const_dict["lambda1"]
    total_error_g = 0.0

    cond_logits = logits["fake"]["cond"]
    cond_err_g = nn.BCEWithLogitsLoss()(cond_logits, real_labels)

    uncond_logits = logits["fake"]["uncond"]
    uncond_err_g = nn.BCEWithLogitsLoss()(uncond_logits, real_labels)

    # add up the conditional and unconditional losses
    loss_g = cond_err_g + uncond_err_g
    total_error_g += loss_g

    # DAMSM Loss from attnGAN.
    loss_damsm = damsm_loss(
        local_fake_incept_feat,
        global_fake_incept_feat,
        words_emb,
        sent_emb,
        match_labels,
        cap_lens,
        class_ids,
        const_dict,
    )

    total_error_g += loss_damsm

    loss_per = 0.5 * nn.MSELoss()(real_vgg_feat, fake_vgg_feat)  # perceptual loss

    total_error_g += lambda1 * loss_per

    return total_error_g


def damsm_loss(
    local_incept_feat: torch.Tensor,
    global_incept_feat: torch.Tensor,
    words_emb: torch.Tensor,
    sent_emb: torch.Tensor,
    match_labels: torch.Tensor,
    cap_lens: torch.Tensor,
    class_ids: torch.Tensor,
    const_dict: Dict[str, float],
) -> Any:
    """Calculate the DAMSM loss from the attnGAN paper.

    Args:
        local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]

        global_incept_feat: The global inception features. [shape: (batch, D)]

        words_emb: The embeddings for all the words in the captions.
            shape: (batch, D, max_caption_length)

        sent_emb: The embeddings for the sentences. shape: (batch_size, D)

        match_labels: Tensor of shape: (batch_size, 1).
            This is of the form torch.tensor([0, 1, 2, ..., batch-1])

        cap_lens: The length of the 'actual' captions in the batch [without padding]
            shape: (batch_size, 1)

        class_ids: The class ids for the instance. shape: (batch, 1)

        const_dict: The dictionary containing the constants.
    """
    batch_size = match_labels.size(0)
    # Mask mis-match samples, that come from the same class as the real sample
    masks = []

    match_scores = []
    gamma1 = const_dict["gamma1"]
    gamma2 = const_dict["gamma2"]
    gamma3 = const_dict["gamma3"]
    lambda3 = const_dict["lambda3"]

    for i in range(batch_size):
        mask = (class_ids == class_ids[i]).int()
        # This ensures that "correct class" index is not included in the mask.
        mask[i] = 0
        masks.append(mask.reshape(1, -1))  # shape: (1, batch)

        numb_words = int(cap_lens[i])
        # shape: (1, D, L), this picks the caption at ith batch index.
        query_words = words_emb[i, :, :numb_words].unsqueeze(0)
        # shape: (batch, D, L), this expands the same caption for all batch indices.
        query_words = query_words.repeat(batch_size, 1, 1)

        c_i = compute_region_context_vector(
            local_incept_feat, query_words, gamma1
        )  # Taken from attnGAN paper. shape: (batch, D, L)

        query_words = query_words.transpose(1, 2)  # shape: (batch, L, D)
        c_i = c_i.transpose(1, 2)  # shape: (batch, L, D)
        query_words = query_words.reshape(
            batch_size * numb_words, -1
        )  # shape: (batch * L, D)
        c_i = c_i.reshape(batch_size * numb_words, -1)  # shape: (batch * L, D)

        r_i = compute_relevance(
            c_i, query_words
        )  # cosine similarity, or R(c_i, e_i) from attnGAN paper. shape: (batch * L, 1)
        r_i = r_i.view(batch_size, numb_words)  # shape: (batch, L)
        r_i = torch.exp(r_i * gamma2)  # shape: (batch, L)
        r_i = r_i.sum(dim=1, keepdim=True)  # shape: (batch, 1)
        r_i = torch.log(
            r_i
        )  # This is image-text matching score b/w whole image and caption, shape: (batch, 1)
        match_scores.append(r_i)

    masks = torch.cat(masks, dim=0).bool()  # type: ignore
    match_scores = torch.cat(match_scores, dim=1)  # type: ignore

    # This corresponds to P(D|Q) from attnGAN.
    match_scores = gamma3 * match_scores  # type: ignore
    match_scores.data.masked_fill_(  # type: ignore
        masks, -float("inf")
    )  # mask out the scores for mis-matched samples

    match_scores_t = match_scores.transpose(  # type: ignore
        0, 1
    )  # This corresponds to P(Q|D) from attnGAN.

    # This corresponds to L1_w from attnGAN.
    l1_w = nn.CrossEntropyLoss()(match_scores, match_labels)
    # This corresponds to L2_w from attnGAN.
    l2_w = nn.CrossEntropyLoss()(match_scores_t, match_labels)

    incept_feat_norm = torch.linalg.norm(global_incept_feat, dim=1)
    sent_emb_norm = torch.linalg.norm(sent_emb, dim=1)

    # shape: (batch, batch)
    global_match_score = global_incept_feat @ (sent_emb.T)

    global_match_score = (
        global_match_score / torch.outer(incept_feat_norm, sent_emb_norm)
    ).clamp(min=1e-8)
    global_match_score = gamma3 * global_match_score

    # mask out the scores for mis-matched samples
    global_match_score.data.masked_fill_(masks, -float("inf"))  # type: ignore

    global_match_t = global_match_score.T  # shape: (batch, batch)

    # This corresponds to L1_s from attnGAN.
    l1_s = nn.CrossEntropyLoss()(global_match_score, match_labels)
    # This corresponds to L2_s from attnGAN.
    l2_s = nn.CrossEntropyLoss()(global_match_t, match_labels)

    loss_damsm = lambda3 * (l1_w + l2_w + l1_s + l2_s)

    return loss_damsm


def compute_relevance(c_i: torch.Tensor, query_words: torch.Tensor) -> Any:
    """Computes the cosine similarity between the region context vector and the query words.

    Args:
        c_i: The region context vector. shape: (batch * L, D)
        query_words: The query words. shape: (batch * L, D)
    """
    prod = c_i * query_words  # shape: (batch * L, D)
    numr = torch.sum(prod, dim=1)  # shape: (batch * L, 1)
    norm_c = torch.linalg.norm(c_i, ord=2, dim=1)
    norm_q = torch.linalg.norm(query_words, ord=2, dim=1)
    denr = norm_c * norm_q
    r_i = (numr / denr).clamp(min=1e-8).squeeze()  # shape: (batch * L, 1)
    return r_i


def compute_region_context_vector(
    local_incept_feat: torch.Tensor, query_words: torch.Tensor, gamma1: float
) -> Any:
    """Compute the region context vector (c_i) from attnGAN paper.

    Args:
        local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]
        query_words: The embeddings for all the words in the captions. shape: (batch, D, L)
        gamma1: The gamma1 value from attnGAN paper.
    """
    batch, L = query_words.size(0), query_words.size(2)  # pylint: disable=invalid-name

    feat_height, feat_width = local_incept_feat.size(2), local_incept_feat.size(3)
    N = feat_height * feat_width  # pylint: disable=invalid-name

    # Reshape the local inception features to (batch, D, N)
    local_incept_feat = local_incept_feat.view(batch, -1, N)
    # shape: (batch, N, D)
    incept_feat_t = local_incept_feat.transpose(1, 2)

    sim_matrix = incept_feat_t @ query_words  # shape: (batch, N, L)
    sim_matrix = sim_matrix.view(batch * N, L)  # shape: (batch * N, L)

    sim_matrix = nn.Softmax(dim=1)(sim_matrix)  # shape: (batch * N, L)
    sim_matrix = sim_matrix.view(batch, N, L)  # shape: (batch, N, L)

    sim_matrix = torch.transpose(sim_matrix, 1, 2)  # shape: (batch, L, N)
    sim_matrix = sim_matrix.reshape(batch * L, N)  # shape: (batch * L, N)

    alpha_j = gamma1 * sim_matrix  # shape: (batch * L, N)
    alpha_j = nn.Softmax(dim=1)(alpha_j)  # shape: (batch * L, N)
    alpha_j = alpha_j.view(batch, L, N)  # shape: (batch, L, N)
    alpha_j_t = torch.transpose(alpha_j, 1, 2)  # shape: (batch, N, L)

    c_i = (
        local_incept_feat @ alpha_j_t
    )  # shape: (batch, D, L) [summing over N dimension in paper, so we multiply like this]
    return c_i


def discriminator_loss(
    logits: Dict[str, Dict[str, torch.Tensor]],
    labels: Dict[str, Dict[str, torch.Tensor]],
) -> Any:
    """
    Calculate discriminator objective

    :param dict[str, dict[str, torch.Tensor]] logits:
        Dictionary with fake/real and word-level/uncond/cond logits

        Example:

        logits = {
            "fake": {
                "word_level": torch.Tensor (BxL)
                "uncond": torch.Tensor (Bx1)
                "cond": torch.Tensor (Bx1)
            },
            "real": {
                "word_level": torch.Tensor (BxL)
                "uncond": torch.Tensor (Bx1)
                "cond": torch.Tensor (Bx1)
            },
        }
    :param dict[str, dict[str, torch.Tensor]] labels:
        Dictionary with fake/real and word-level/image labels

        Example:

        labels = {
            "fake": {
                "word_level": torch.Tensor (BxL)
                "image": torch.Tensor (Bx1)
            },
            "real": {
                "word_level": torch.Tensor (BxL)
                "image": torch.Tensor (Bx1)
            },
        }
    :param float lambda_4: Hyperparameter for word loss in paper
    :return: Discriminator objective loss
    :rtype: Any
    """
    # define main loss functions for logit losses
    tot_loss = 0.0
    bce_logits = nn.BCEWithLogitsLoss()
    bce = nn.BCELoss()
    # calculate word-level loss
    word_loss = bce(logits["real"]["word_level"], labels["real"]["word_level"])
    # calculate unconditional adversarial loss
    uncond_loss = bce_logits(logits["real"]["uncond"], labels["real"]["image"])

    # calculate conditional adversarial loss
    cond_loss = bce_logits(logits["real"]["cond"], labels["real"]["image"])

    tot_loss = (uncond_loss + cond_loss) / 2.0

    fake_uncond_loss = bce_logits(logits["fake"]["uncond"], labels["fake"]["image"])
    fake_cond_loss = bce_logits(logits["fake"]["cond"], labels["fake"]["image"])

    tot_loss += (fake_uncond_loss + fake_cond_loss) / 3.0
    tot_loss += word_loss

    return tot_loss


def kl_loss(mu_tensor: torch.Tensor, logvar: torch.Tensor) -> Any:
    """
    Calculate KL loss

    :param torch.Tensor mu_tensor: Mean of latent distribution
    :param torch.Tensor logvar: Log variance of latent distribution
    :return: KL loss [-0.5 * (1 + log(sigma) - mu^2 - sigma^2)]
    :rtype: Any
    """
    return torch.mean(-0.5 * (1 + 0.5 * logvar - mu_tensor.pow(2) - torch.exp(logvar)))
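Sanity-check note (illustrative sketch, not part of this commit): for a posterior that already matches the standard normal prior, i.e. zero mean and zero log-variance, kl_loss above evaluates to zero.

# Illustrative check: kl_loss is zero when mu = 0 and logvar = 0.
import torch

from src.models import kl_loss

mu = torch.zeros(8, 100)
logvar = torch.zeros(8, 100)
print(kl_loss(mu, logvar))  # tensor(0.)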
src/models/modules/__init__.py
ADDED
@@ -0,0 +1,12 @@
"""All the modules used in creation of Generator and Discriminator"""
from .acm import ACM
from .attention import ChannelWiseAttention, SpatialAttention
from .cond_augment import CondAugmentation
from .conv_utils import calc_out_conv, conv1d, conv2d
from .discriminator import Discriminator, WordLevelLogits
from .downsample import down_sample
from .generator import Generator
from .image_encoder import InceptionEncoder, VGGEncoder
from .residual import ResidualBlock
from .text_encoder import TextEncoder
from .upsample import img_up_block, up_sample
src/models/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (891 Bytes).
src/models/modules/__pycache__/acm.cpython-39.pyc
ADDED
Binary file (1.66 kB).
src/models/modules/__pycache__/attention.cpython-39.pyc
ADDED
Binary file (3.38 kB).
src/models/modules/__pycache__/cond_augment.cpython-39.pyc
ADDED
Binary file (2.52 kB).
src/models/modules/__pycache__/conv_utils.cpython-39.pyc
ADDED
Binary file (2.37 kB).
src/models/modules/__pycache__/discriminator.cpython-39.pyc
ADDED
Binary file (5.1 kB).
src/models/modules/__pycache__/downsample.cpython-39.pyc
ADDED
Binary file (598 Bytes).
src/models/modules/__pycache__/generator.cpython-39.pyc
ADDED
Binary file (9.03 kB).
src/models/modules/__pycache__/image_encoder.cpython-39.pyc
ADDED
Binary file (4.27 kB).
src/models/modules/__pycache__/residual.cpython-39.pyc
ADDED
Binary file (1.31 kB).
src/models/modules/__pycache__/text_encoder.cpython-39.pyc
ADDED
Binary file (1.92 kB).
src/models/modules/__pycache__/upsample.cpython-39.pyc
ADDED
Binary file (983 Bytes).
src/models/modules/acm.py
ADDED
@@ -0,0 +1,37 @@
"""ACM and its variations"""

from typing import Any

import torch
from torch import nn

from .conv_utils import conv2d


class ACM(nn.Module):
    """Affine Combination Module from ManiGAN"""

    def __init__(self, img_chans: int, text_chans: int, inner_dim: int = 64) -> None:
        """
        Initialize the convolutional layers

        :param int img_chans: Channels in visual input
        :param int text_chans: Channels of textual input
        :param int inner_dim: Hyperparameter for inner dimensionality of features
        """
        super().__init__()
        self.conv = conv2d(in_channels=img_chans, out_channels=inner_dim)
        self.weights = conv2d(in_channels=inner_dim, out_channels=text_chans)
        self.biases = conv2d(in_channels=inner_dim, out_channels=text_chans)

    def forward(self, text: torch.Tensor, img: torch.Tensor) -> Any:
        """
        Propagate the textual and visual input through the ACM module

        :param torch.Tensor text: Textual input (can be hidden features)
        :param torch.Tensor img: Image input
        :return: Affine combination of text and image
        :rtype: torch.Tensor
        """
        img_features = self.conv(img)
        return text * self.weights(img_features) + self.biases(img_features)
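Shape note (illustrative sketch, not part of this commit; tensor sizes are assumptions): the image branch is projected to text_chans channels, so the text features and the predicted weights/biases line up element-wise.

# Illustrative sketch: ACM fuses text features (B, text_chans, H, W)
# with image features (B, img_chans, H, W) of the same spatial size.
import torch

from src.models.modules import ACM

acm = ACM(img_chans=3, text_chans=32)
text_feat = torch.randn(2, 32, 64, 64)
img_feat = torch.randn(2, 3, 64, 64)
out = acm(text_feat, img_feat)
print(out.shape)  # torch.Size([2, 32, 64, 64])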
src/models/modules/attention.py
ADDED
@@ -0,0 +1,88 @@
"""Attention modules"""
from typing import Any, Optional

import torch
from torch import nn

from src.models.modules.conv_utils import conv1d


class ChannelWiseAttention(nn.Module):
    """ChannelWise attention adapted from ControlGAN"""

    def __init__(self, fm_size: int, text_d: int) -> None:
        """
        Initialize the Channel-Wise attention module

        :param int fm_size:
            Height and width of feature map on k-th iteration of forward-pass.
            In paper, it's H_k * W_k
        :param int text_d: Dimensionality of sentence. From paper, it's D
        """
        super().__init__()
        # perception layer
        self.text_conv = conv1d(text_d, fm_size)
        # attention across channel dimension
        self.softmax = nn.Softmax(2)

    def forward(self, v_k: torch.Tensor, w_text: torch.Tensor) -> Any:
        """
        Apply attention to visual features taking into account features of words

        :param torch.Tensor v_k: Visual context
        :param torch.Tensor w_text: Textual features
        :return: Fused hidden visual features and word features
        :rtype: Any
        """
        w_hat = self.text_conv(w_text)
        m_k = v_k @ w_hat
        a_k = self.softmax(m_k)
        w_hat = torch.transpose(w_hat, 1, 2)
        return a_k @ w_hat


class SpatialAttention(nn.Module):
    """Spatial attention module for attending textual context to visual features"""

    def __init__(self, d: int, d_hat: int) -> None:
        """
        Set up softmax and conv layers

        :param int d: Initial embedding size for textual features. D from paper
        :param int d_hat: Height of image feature map. D_hat from paper
        """
        super().__init__()
        self.softmax = nn.Softmax(2)
        self.conv = conv1d(d, d_hat)

    def forward(
        self,
        text_context: torch.Tensor,
        image: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Any:
        """
        Project image features into the latent space
        of textual features and apply attention

        :param torch.Tensor text_context: D x T tensor of hidden textual features
        :param torch.Tensor image: D_hat x N visual features
        :param Optional[torch.Tensor] mask:
            Boolean tensor for masking the padded words. BxL
        :return: Word features attended by visual features
        :rtype: Any
        """
        # number of features on image feature map H * W
        feature_num = image.size(2)
        # number of words in caption
        len_caption = text_context.size(2)
        text_context = self.conv(text_context)
        image = torch.transpose(image, 1, 2)
        s_i_j = image @ text_context
        if mask is not None:
            # duplicating mask and aligning dims with s_i_j
            mask = mask.repeat(1, feature_num).view(-1, feature_num, len_caption)
            s_i_j[mask] = -float("inf")
        b_i_j = self.softmax(s_i_j)
        c_i_j = b_i_j @ torch.transpose(text_context, 1, 2)
        return torch.transpose(c_i_j, 1, 2)
src/models/modules/cond_augment.py
ADDED
@@ -0,0 +1,57 @@
"""Conditioning Augmentation Module"""

from typing import Any

import torch
from torch import nn


class CondAugmentation(nn.Module):
    """Conditioning Augmentation Module"""

    def __init__(self, D: int, conditioning_dim: int):
        """
        :param D: Dimension of the text embedding space [D from AttnGAN paper]
        :param conditioning_dim: Dimension of the conditioning space
        """
        super().__init__()
        self.cond_dim = conditioning_dim
        self.cond_augment = nn.Linear(D, conditioning_dim * 4, bias=True)
        self.glu = nn.GLU(dim=1)

    def encode(self, text_embedding: torch.Tensor) -> Any:
        """
        This function encodes the text embedding into the conditioning space
        :param text_embedding: Text embedding
        :return: Conditioning embedding
        """
        x_tensor = self.glu(self.cond_augment(text_embedding))
        mu_tensor = x_tensor[:, : self.cond_dim]
        logvar = x_tensor[:, self.cond_dim :]
        return mu_tensor, logvar

    def sample(self, mu_tensor: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        This function samples from the Gaussian distribution
        :param mu: Mean of the Gaussian distribution
        :param logvar: Log variance of the Gaussian distribution
        :return: Sample from the Gaussian distribution
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(
            std
        )  # check if this should add requires_grad = True to this tensor?
        return mu_tensor + eps * std

    def forward(self, text_embedding: torch.Tensor) -> Any:
        """
        This function encodes the text embedding into the conditioning space,
        and samples from the Gaussian distribution.
        :param text_embedding: Text embedding
        :return c_hat: Conditioning embedding (C^ from StackGAN++ paper)
        :return mu: Mean of the Gaussian distribution
        :return logvar: Log variance of the Gaussian distribution
        """
        mu_tensor, logvar = self.encode(text_embedding)
        c_hat = self.sample(mu_tensor, logvar)
        return c_hat, mu_tensor, logvar
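Usage note (illustrative sketch, not part of this commit; sizes are assumptions): the module applies the reparameterization trick, so its mu/logvar outputs pair naturally with kl_loss from src/models/losses.py.

# Illustrative sketch: conditioning augmentation followed by the KL penalty.
import torch

from src.models import kl_loss
from src.models.modules import CondAugmentation

cond_aug = CondAugmentation(D=256, conditioning_dim=100)
sent_emb = torch.randn(4, 256)           # sentence embeddings (batch, D)
c_hat, mu, logvar = cond_aug(sent_emb)   # c_hat: (batch, conditioning_dim)
penalty = kl_loss(mu, logvar)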
src/models/modules/conv_utils.py
ADDED
@@ -0,0 +1,78 @@
+"""Frequently used convolution modules"""
+
+from typing import Tuple
+
+from torch import nn
+
+
+def conv2d(
+    in_channels: int,
+    out_channels: int,
+    kernel_size: int = 3,
+    stride: int = 1,
+    padding: int = 1,
+) -> nn.Conv2d:
+    """
+    Template convolution which is typically used throughout the project
+
+    :param int in_channels: Number of input channels
+    :param int out_channels: Number of output channels
+    :param int kernel_size: Size of sliding kernel
+    :param int stride: How many steps the kernel takes when sliding
+    :param int padding: How many pixels to pad on each side
+    :return: Convolution layer with parameters
+    :rtype: nn.Conv2d
+    """
+    return nn.Conv2d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+    )
+
+
+def conv1d(
+    in_channels: int,
+    out_channels: int,
+    kernel_size: int = 1,
+    stride: int = 1,
+    padding: int = 0,
+) -> nn.Conv1d:
+    """
+    Template 1d convolution which is typically used throughout the project
+
+    :param int in_channels: Number of input channels
+    :param int out_channels: Number of output channels
+    :param int kernel_size: Size of sliding kernel
+    :param int stride: How many steps the kernel takes when sliding
+    :param int padding: How many elements to pad on each side
+    :return: Convolution layer with parameters
+    :rtype: nn.Conv1d
+    """
+    return nn.Conv1d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+    )
+
+
+def calc_out_conv(
+    h_in: int, w_in: int, kernel_size: int = 3, stride: int = 1, padding: int = 0
+) -> Tuple[int, int]:
+    """
+    Calculate the dimensionalities of images propagated through conv layers
+
+    :param h_in: Height of the image
+    :param w_in: Width of the image
+    :param kernel_size: Size of sliding kernel
+    :param stride: How many steps the kernel takes when sliding
+    :param padding: How many pixels to pad on each side
+    :return: Height and width of the image after the convolution
+    :rtype: Tuple[int, int]
+    """
+    h_out = int((h_in + 2 * padding - kernel_size) / stride + 1)
+    w_out = int((w_in + 2 * padding - kernel_size) / stride + 1)
+    return h_out, w_out
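calc_out_conv applies the standard output-size formula floor((n + 2 * padding - kernel_size) / stride) + 1. A quick sanity check (illustrative, not part of the diff):

from src.models.modules.conv_utils import calc_out_conv, conv2d

# a 3x3 valid convolution with stride 2 takes a 35x35 map down to 17x17
assert calc_out_conv(35, 35, kernel_size=3, stride=2, padding=0) == (17, 17)

# the conv2d defaults (kernel 3, stride 1, padding 1) preserve spatial size
layer = conv2d(in_channels=3, out_channels=16)
assert calc_out_conv(256, 256, kernel_size=3, stride=1, padding=1) == (256, 256)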
src/models/modules/discriminator.py
ADDED
@@ -0,0 +1,144 @@
+"""Discriminator providing word-level feedback"""
+from typing import Any
+
+import torch
+from torch import nn
+
+from src.models.modules.conv_utils import conv1d, conv2d
+from src.models.modules.image_encoder import InceptionEncoder
+
+
+class WordLevelLogits(nn.Module):
+    """API for converting regional feature maps into logits for multi-class classification"""
+
+    def __init__(self) -> None:
+        """
+        Instantiate the module with softmax on channel dimension
+        """
+        super().__init__()
+        self.softmax = nn.Softmax(dim=1)
+        # layer for flattening the feature maps
+        self.flat = nn.Flatten(start_dim=2)
+        # change dims of textual embs to correlate with chans of inception
+        self.chan_reduction = conv1d(256, 128)
+
+    def forward(
+        self, visual_features: torch.Tensor, word_embs: torch.Tensor, mask: torch.Tensor
+    ) -> Any:
+        """
+        Fuse two types of features together to get output for feeding into the classification loss
+
+        :param torch.Tensor visual_features:
+            Feature maps of an image after being processed by Inception encoder. Bx128x17x17
+        :param torch.Tensor word_embs:
+            Word-level embeddings from the text encoder. Bx256xL
+        :param torch.Tensor mask: Mask marking padding tokens in the captions. BxL
+        :return: Logits for each word in the picture. BxL
+        :rtype: Any
+        """
+        # make textual and visual features have the same amount of channels
+        word_embs = self.chan_reduction(word_embs)
+        # flattening the feature maps
+        visual_features = self.flat(visual_features)
+        word_embs = torch.transpose(word_embs, 1, 2)
+        word_region_correlations = word_embs @ visual_features
+        # normalize across L dimension
+        m_norm_l = nn.functional.normalize(word_region_correlations, dim=1)
+        # normalize across H*W dimension
+        m_norm_hw = nn.functional.normalize(m_norm_l, dim=2)
+        m_norm_hw = torch.transpose(m_norm_hw, 1, 2)
+        weighted_img_feats = visual_features @ m_norm_hw
+        weighted_img_feats = torch.sum(weighted_img_feats, dim=1)
+        weighted_img_feats[mask] = -float("inf")
+        deltas = self.softmax(weighted_img_feats)
+        return deltas
+
+
+class UnconditionalLogits(nn.Module):
+    """Head for retrieving logits from an image"""
+
+    def __init__(self) -> None:
+        """Initialize modules that reduce the features down to a set of logits"""
+        super().__init__()
+        self.conv = nn.Conv2d(128, 1, kernel_size=17)
+        # flattening Bx1x1x1 into Bx1
+        self.flat = nn.Flatten()
+
+    def forward(self, visual_features: torch.Tensor) -> Any:
+        """
+        Compute logits for unconditioned adversarial loss
+
+        :param visual_features: Local features from Inception network. Bx128x17x17
+        :return: Logits for unconditioned adversarial loss. Bx1
+        :rtype: Any
+        """
+        # reduce channels and feature maps for visual features
+        visual_features = self.conv(visual_features)
+        # flatten Bx1x1x1 into Bx1
+        logits = self.flat(visual_features)
+        return logits
+
+
+class ConditionalLogits(nn.Module):
+    """Logits extractor for conditioned adversarial loss"""
+
+    def __init__(self) -> None:
+        super().__init__()
+        # layer for forming the feature maps out of textual info
+        self.text_to_fm = conv1d(256, 17 * 17)
+        # fitting the size of text channels to the size of visual channels
+        self.chan_aligner = conv2d(1, 128)
+        # for reducing textual + visual features down to a 1x1 feature map
+        self.joint_conv = nn.Conv2d(2 * 128, 1, kernel_size=17)
+        # converting Bx1x1x1 into Bx1
+        self.flat = nn.Flatten()
+
+    def forward(self, visual_features: torch.Tensor, sent_embs: torch.Tensor) -> Any:
+        """
+        Compute logits for conditional adversarial loss
+
+        :param torch.Tensor visual_features: Features from Inception encoder. Bx128x17x17
+        :param torch.Tensor sent_embs: Sentence embeddings from text encoder. Bx256
+        :return: Logits for conditional adversarial loss. Bx1
+        :rtype: Any
+        """
+        # make text and visual features have the same sizes of feature maps
+        # Bx256 -> Bx256x1 -> Bx289x1
+        sent_embs = sent_embs.view(-1, 256, 1)
+        sent_embs = self.text_to_fm(sent_embs)
+        # transform textual info into shape of visual feature maps
+        # Bx289x1 -> Bx1x17x17
+        sent_embs = sent_embs.view(-1, 1, 17, 17)
+        # propagate text embs through a conv layer to
+        # align channels with visual feature maps
+        sent_embs = self.chan_aligner(sent_embs)
+        # unite textual and visual features across the dim of channels
+        cross_features = torch.cat((visual_features, sent_embs), dim=1)
+        # reduce the joint features down to a single raw logit per image
+        cross_features = self.joint_conv(cross_features)
+        # form logits from Bx1x1x1 into Bx1
+        logits = self.flat(cross_features)
+        return logits
+
+
+class Discriminator(nn.Module):
+    """Simple CNN-based discriminator"""
+
+    def __init__(self) -> None:
+        """Use a pretrained InceptionNet to extract features"""
+        super().__init__()
+        self.encoder = InceptionEncoder(D=128)
+        # define different logit extractors for different losses
+        self.logits_word_level = WordLevelLogits()
+        self.logits_uncond = UnconditionalLogits()
+        self.logits_cond = ConditionalLogits()
+
+    def forward(self, images: torch.Tensor) -> Any:
+        """
+        Retrieves image features encoded by the image encoder
+
+        :param torch.Tensor images: Images to be analyzed. Bx3x256x256
+        :return: Image features encoded by image encoder. Bx128x17x17
+        """
+        # only taking the local features from inception
+        # Bx3x256x256 -> Bx128x17x17
+        img_features, _ = self.encoder(images)
+        return img_features
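A minimal sketch (not part of the commit) of how the three logit heads are meant to be combined around the shared Inception features; the batch size, caption length, and all-False padding mask below are illustrative:

import torch
from src.models.modules.discriminator import Discriminator

disc = Discriminator()
images = torch.randn(4, 3, 256, 256)
sent_embs = torch.randn(4, 256)              # B x 256 sentence embeddings
word_embs = torch.randn(4, 256, 18)          # B x 256 x L word embeddings
mask = torch.zeros(4, 18, dtype=torch.bool)  # True where a token is padding

features = disc(images)                                          # B x 128 x 17 x 17
uncond_logits = disc.logits_uncond(features)                     # B x 1
cond_logits = disc.logits_cond(features, sent_embs)              # B x 1
word_deltas = disc.logits_word_level(features, word_embs, mask)  # B x L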
src/models/modules/downsample.py
ADDED
@@ -0,0 +1,14 @@
+"""downsample module."""
+
+from torch import nn
+
+
+def down_sample(in_planes: int, out_planes: int) -> nn.Module:
+    """DownSample module: halves the spatial resolution with a strided convolution."""
+    return nn.Sequential(
+        nn.Conv2d(
+            in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False
+        ),
+        nn.BatchNorm2d(out_planes),
+        nn.LeakyReLU(0.2, inplace=True),
+    )
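A quick shape check (illustrative, not part of the diff): with kernel 4, stride 2 and padding 1 the block halves the spatial resolution, e.g. (128 + 2 - 4) / 2 + 1 = 64.

import torch
from src.models.modules.downsample import down_sample

block = down_sample(128, 32)
out = block(torch.randn(2, 128, 128, 128))
print(out.shape)  # torch.Size([2, 32, 64, 64])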
src/models/modules/generator.py
ADDED
@@ -0,0 +1,300 @@
+"""Generator Module"""
+
+from typing import Any, Optional
+
+import torch
+from torch import nn
+
+from src.models.modules.acm import ACM
+from src.models.modules.attention import ChannelWiseAttention, SpatialAttention
+from src.models.modules.cond_augment import CondAugmentation
+from src.models.modules.downsample import down_sample
+from src.models.modules.residual import ResidualBlock
+from src.models.modules.upsample import img_up_block, up_sample
+
+
+class InitStageG(nn.Module):
+    """Initial Stage Generator Module"""
+
+    # pylint: disable=too-many-instance-attributes
+    # pylint: disable=too-many-arguments
+    # pylint: disable=invalid-name
+    # pylint: disable=too-many-locals
+
+    def __init__(
+        self, Ng: int, Ng_init: int, conditioning_dim: int, D: int, noise_dim: int
+    ):
+        """
+        :param Ng: Number of channels.
+        :param Ng_init: Initial value of Ng; this is the output channel of the first image upsample.
+        :param conditioning_dim: Dimension of the conditioning space
+        :param D: Dimension of the text embedding space [D from AttnGAN paper]
+        :param noise_dim: Dimension of the noise space
+        """
+        super().__init__()
+        self.gf_dim = Ng
+        self.gf_init = Ng_init
+        self.in_dim = noise_dim + conditioning_dim + D
+        self.text_dim = D
+
+        self.define_module()
+
+    def define_module(self) -> None:
+        """Defines FC, Upsample, Residual, ACM, Attention modules"""
+        nz, ng = self.in_dim, self.gf_dim
+        self.fully_connect = nn.Sequential(
+            nn.Linear(nz, ng * 4 * 4 * 2, bias=False),
+            nn.BatchNorm1d(ng * 4 * 4 * 2),
+            nn.GLU(dim=1),  # we start from a 4 x 4 feat_map and return hidden_64.
+        )
+
+        self.upsample1 = up_sample(ng, ng // 2)
+        self.upsample2 = up_sample(ng // 2, ng // 4)
+        self.upsample3 = up_sample(ng // 4, ng // 8)
+        self.upsample4 = up_sample(
+            ng // 8 * 3, ng // 16
+        )  # multiply channels by 3 because we concat spatial and channel att
+
+        self.residual = self._make_layer(ResidualBlock, ng // 8 * 3)
+        self.acm_module = ACM(self.gf_init, ng // 8 * 3)
+
+        self.spatial_att = SpatialAttention(self.text_dim, ng // 8)
+        self.channel_att = ChannelWiseAttention(
+            32 * 32, self.text_dim
+        )  # 32 x 32 is the feature map size
+
+    def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
+        layers = []
+        for _ in range(2):  # number of residual blocks hardcoded to 2
+            layers.append(block(channel_num))
+        return nn.Sequential(*layers)
+
+    def forward(
+        self,
+        noise: torch.Tensor,
+        condition: torch.Tensor,
+        global_inception: torch.Tensor,
+        local_upsampled_inception: torch.Tensor,
+        word_embeddings: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Any:
+        """
+        :param noise: Noise tensor
+        :param condition: Condition tensor (c^ from StackGAN++ paper)
+        :param global_inception: Global inception feature
+        :param local_upsampled_inception: Local inception feature, upsampled to 32 x 32
+        :param word_embeddings: Word embeddings [shape: D x L or D x T]
+        :param mask: Mask for padding tokens
+        :return: Hidden image feature map tensor of 64 x 64 size
+        """
+        noise_concat = torch.cat((noise, condition), 1)
+        inception_concat = torch.cat((noise_concat, global_inception), 1)
+        hidden = self.fully_connect(inception_concat)
+        hidden = hidden.view(-1, self.gf_dim, 4, 4)  # convert to 4x4 image feature map
+        hidden = self.upsample1(hidden)
+        hidden = self.upsample2(hidden)
+        hidden_32 = self.upsample3(hidden)  # shape: (batch_size, gf_dim // 8, 32, 32)
+        hidden_32_view = hidden_32.view(
+            hidden_32.shape[0], -1, hidden_32.shape[2] * hidden_32.shape[3]
+        )  # this reshaping is done as the attention module expects this shape.
+
+        spatial_att_feat = self.spatial_att(
+            word_embeddings, hidden_32_view, mask
+        )  # spatial att shape: (batch, D^, 32 * 32)
+        channel_att_feat = self.channel_att(
+            spatial_att_feat, word_embeddings
+        )  # channel att shape: (batch, D^, 32 * 32), or (batch, C, Hk * Wk) from ControlGAN paper
+        spatial_att_feat = spatial_att_feat.view(
+            word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
+        )  # reshape to (batch, D^, 32, 32)
+        channel_att_feat = channel_att_feat.view(
+            word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
+        )  # reshape to (batch, D^, 32, 32)
+
+        spatial_concat = torch.cat(
+            (hidden_32, spatial_att_feat), 1
+        )  # concat spatial attention feature with hidden_32
+        attn_concat = torch.cat(
+            (spatial_concat, channel_att_feat), 1
+        )  # concat channel and spatial attention features
+
+        hidden_32 = self.acm_module(attn_concat, local_upsampled_inception)
+        hidden_32 = self.residual(hidden_32)
+        hidden_64 = self.upsample4(hidden_32)
+        return hidden_64
+
+
+class NextStageG(nn.Module):
+    """Next Stage Generator Module"""
+
+    # pylint: disable=too-many-instance-attributes
+    # pylint: disable=too-many-arguments
+    # pylint: disable=invalid-name
+    # pylint: disable=too-many-locals
+
+    def __init__(self, Ng: int, Ng_init: int, D: int, image_size: int):
+        """
+        :param Ng: Number of channels.
+        :param Ng_init: Initial value of Ng.
+        :param D: Dimension of the text embedding space [D from AttnGAN paper]
+        :param image_size: Size of the output image from the previous generator stage.
+        """
+        super().__init__()
+        self.gf_dim = Ng
+        self.gf_init = Ng_init
+        self.text_dim = D
+        self.img_size = image_size
+
+        self.define_module()
+
+    def define_module(self) -> None:
+        """Defines FC, Upsample, Residual, ACM, Attention modules"""
+        ng = self.gf_dim
+        self.spatial_att = SpatialAttention(self.text_dim, ng)
+        self.channel_att = ChannelWiseAttention(
+            self.img_size * self.img_size, self.text_dim
+        )
+
+        self.residual = self._make_layer(ResidualBlock, ng * 3)
+        self.upsample = up_sample(ng * 3, ng)
+        self.acm_module = ACM(self.gf_init, ng * 3)
+        self.upsample2 = up_sample(ng, ng)
+
+    def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
+        layers = []
+        for _ in range(2):  # number of residual blocks hardcoded to 2
+            layers.append(block(channel_num))
+        return nn.Sequential(*layers)
+
+    def forward(
+        self,
+        hidden_feat: Any,
+        word_embeddings: torch.Tensor,
+        vgg64_feat: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Any:
+        """
+        :param hidden_feat: Hidden feature from the previous generator stage [i.e. hidden_64]
+        :param word_embeddings: Word embeddings
+        :param vgg64_feat: VGG feature map of size 64 x 64
+        :param mask: Mask for the padding tokens
+        :return: Image feature map of size 256 x 256
+        """
+        hidden_view = hidden_feat.view(
+            hidden_feat.shape[0], -1, hidden_feat.shape[2] * hidden_feat.shape[3]
+        )  # reshape to pass into attention modules.
+        spatial_att_feat = self.spatial_att(
+            word_embeddings, hidden_view, mask
+        )  # spatial att shape: (batch, D^, 64 * 64), or D^ x N
+        channel_att_feat = self.channel_att(
+            spatial_att_feat, word_embeddings
+        )  # channel att shape: (batch, D^, 64 * 64), or (batch, C, Hk * Wk) from ControlGAN paper
+        spatial_att_feat = spatial_att_feat.view(
+            word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
+        )  # reshape to (batch, D^, 64, 64)
+        channel_att_feat = channel_att_feat.view(
+            word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
+        )  # reshape to (batch, D^, 64, 64)
+
+        spatial_concat = torch.cat(
+            (hidden_feat, spatial_att_feat), 1
+        )  # concat spatial attention feature with hidden_64
+        attn_concat = torch.cat(
+            (spatial_concat, channel_att_feat), 1
+        )  # concat channel and spatial attention features
+
+        hidden_64 = self.acm_module(attn_concat, vgg64_feat)
+        hidden_64 = self.residual(hidden_64)
+        hidden_128 = self.upsample(hidden_64)
+        hidden_256 = self.upsample2(hidden_128)
+        return hidden_256
+
+
+class GetImageG(nn.Module):
+    """Generates the Final Fake Image from the Image Feature Map"""
+
+    def __init__(self, Ng: int):
+        """
+        :param Ng: Number of channels.
+        """
+        super().__init__()
+        self.img = nn.Sequential(
+            nn.Conv2d(Ng, 3, kernel_size=3, stride=1, padding=1, bias=False), nn.Tanh()
+        )
+
+    def forward(self, hidden_feat: torch.Tensor) -> Any:
+        """
+        :param hidden_feat: Image feature map
+        :return: Final fake image
+        """
+        return self.img(hidden_feat)
+
+
+class Generator(nn.Module):
+    """Generator Module"""
+
+    # pylint: disable=too-many-instance-attributes
+    # pylint: disable=too-many-arguments
+    # pylint: disable=invalid-name
+    # pylint: disable=too-many-locals
+
+    def __init__(self, Ng: int, D: int, conditioning_dim: int, noise_dim: int):
+        """
+        :param Ng: Number of channels. [Taken from StackGAN++ paper]
+        :param D: Dimension of the text embedding space
+        :param conditioning_dim: Dimension of the conditioning space
+        :param noise_dim: Dimension of the noise space
+        """
+        super().__init__()
+        self.cond_augment = CondAugmentation(D, conditioning_dim)
+        self.hidden_net1 = InitStageG(Ng * 16, Ng, conditioning_dim, D, noise_dim)
+        self.inception_img_upsample = img_up_block(
+            D, Ng
+        )  # as the channel size returned by the inception encoder is D (default in paper: 256)
+        self.hidden_net2 = NextStageG(Ng, Ng, D, 64)
+        self.generate_img = GetImageG(Ng)
+
+        self.acm_module = ACM(Ng, Ng)
+
+        self.vgg_downsample = down_sample(D // 2, Ng)
+        self.upsample1 = up_sample(Ng, Ng)
+        self.upsample2 = up_sample(Ng, Ng)
+
+    def forward(
+        self,
+        noise: torch.Tensor,
+        sentence_embeddings: torch.Tensor,
+        word_embeddings: torch.Tensor,
+        global_inception_feat: torch.Tensor,
+        local_inception_feat: torch.Tensor,
+        vgg_feat: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Any:
+        """
+        :param noise: Noise vector [shape: (batch, noise_dim)]
+        :param sentence_embeddings: Sentence embeddings [shape: (batch, D)]
+        :param word_embeddings: Word embeddings [shape: D x L, where L is the length of the sentence]
+        :param global_inception_feat: Global Inception feature map [shape: (batch, D)]
+        :param local_inception_feat: Local Inception feature map [shape: (batch, D, 17, 17)]
+        :param vgg_feat: VGG feature map [shape: (batch, D // 2 = 128, 128, 128)]
+        :param mask: Mask for the padding tokens
+        :return: Final fake image
+        """
+        c_hat, mu_tensor, logvar = self.cond_augment(sentence_embeddings)
+        hidden_32 = self.inception_img_upsample(local_inception_feat)
+
+        hidden_64 = self.hidden_net1(
+            noise, c_hat, global_inception_feat, hidden_32, word_embeddings, mask
+        )
+
+        vgg_64 = self.vgg_downsample(vgg_feat)
+
+        hidden_256 = self.hidden_net2(hidden_64, word_embeddings, vgg_64, mask)
+
+        vgg_128 = self.upsample1(vgg_64)
+        vgg_256 = self.upsample2(vgg_128)
+
+        hidden_256 = self.acm_module(hidden_256, vgg_256)
+        fake_img = self.generate_img(hidden_256)
+
+        return fake_img, mu_tensor, logvar
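A rough wiring sketch (not part of the commit) of a full generator forward pass, following the docstrings above; the hyper-parameter values are illustrative assumptions (the values actually used live in src/config.py), and the random tensors stand in for the text-encoder outputs:

import torch
from src.models.modules import Generator, InceptionEncoder, VGGEncoder

Ng, D, cond_dim, noise_dim = 32, 256, 100, 100  # assumed; see src/config.py

gen = Generator(Ng=Ng, D=D, conditioning_dim=cond_dim, noise_dim=noise_dim)
inception = InceptionEncoder(D=D)
vgg = VGGEncoder()

images = torch.randn(2, 3, 256, 256)
local_feat, global_feat = inception(images)     # (2, D, 17, 17), (2, D)
vgg_feat = vgg(images)                          # (2, D // 2, 128, 128)

noise = torch.randn(2, noise_dim)
sent_embs = torch.randn(2, D)                   # from the text encoder
word_embs = torch.randn(2, D, 18)               # B x D x L
mask = torch.zeros(2, 18, dtype=torch.bool)     # padding mask

fake_img, mu, logvar = gen(
    noise, sent_embs, word_embs, global_feat, local_feat, vgg_feat, mask
)  # fake_img: (2, 3, 256, 256)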
src/models/modules/image_encoder.py
ADDED
@@ -0,0 +1,138 @@
+"""Image Encoder Module"""
+from typing import Any
+
+import torch
+from torch import nn
+
+from src.models.modules.conv_utils import conv2d
+
+# build inception v3 image encoder
+
+
+class InceptionEncoder(nn.Module):
+    """Image Encoder Module adapted from AttnGAN"""
+
+    def __init__(self, D: int):
+        """
+        :param D: Dimension of the text embedding space [D from AttnGAN paper]
+        """
+        super().__init__()
+
+        self.text_emb_dim = D
+
+        model = torch.hub.load(
+            "pytorch/vision:v0.10.0", "inception_v3", pretrained=True
+        )
+        for param in model.parameters():
+            param.requires_grad = False
+
+        self.define_module(model)
+        self.init_trainable_weights()
+
+    def define_module(self, model: nn.Module) -> None:
+        """
+        This function defines the modules of the image encoder
+        :param model: Pretrained Inception V3 model
+        """
+        model.cust_upsample = nn.Upsample(size=(299, 299), mode="bilinear")
+        model.cust_maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
+        model.cust_maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
+        model.cust_avgpool = nn.AvgPool2d(kernel_size=8)
+
+        attribute_list = [
+            "cust_upsample",
+            "Conv2d_1a_3x3",
+            "Conv2d_2a_3x3",
+            "Conv2d_2b_3x3",
+            "cust_maxpool1",
+            "Conv2d_3b_1x1",
+            "Conv2d_4a_3x3",
+            "cust_maxpool2",
+            "Mixed_5b",
+            "Mixed_5c",
+            "Mixed_5d",
+            "Mixed_6a",
+            "Mixed_6b",
+            "Mixed_6c",
+            "Mixed_6d",
+            "Mixed_6e",
+        ]
+
+        self.feature_extractor = nn.Sequential(
+            *[getattr(model, name) for name in attribute_list]
+        )
+
+        attribute_list2 = ["Mixed_7a", "Mixed_7b", "Mixed_7c", "cust_avgpool"]
+
+        self.feature_extractor2 = nn.Sequential(
+            *[getattr(model, name) for name in attribute_list2]
+        )
+
+        self.emb_features = conv2d(
+            768, self.text_emb_dim, kernel_size=1, stride=1, padding=0
+        )
+        self.emb_cnn_code = nn.Linear(2048, self.text_emb_dim)
+
+    def init_trainable_weights(self) -> None:
+        """
+        This function initializes the trainable weights of the image encoder
+        """
+        initrange = 0.1
+        self.emb_features.weight.data.uniform_(-initrange, initrange)
+        self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)
+
+    def forward(self, image_tensor: torch.Tensor) -> Any:
+        """
+        :param image_tensor: Input image
+        :return: features: local feature matrix (v from AttnGAN paper) [shape: (batch, D, 17, 17)]
+        :return: cnn_code: global image feature (v^ from AttnGAN paper) [shape: (batch, D)]
+        """
+        # input image size, e.g. x.shape: 10 3 256 256
+
+        features = self.feature_extractor(image_tensor)
+        # 17 x 17 x 768
+
+        image_tensor = self.feature_extractor2(features)
+
+        image_tensor = image_tensor.view(image_tensor.size(0), -1)
+        # 2048
+
+        # global image features
+        cnn_code = self.emb_cnn_code(image_tensor)
+
+        if features is not None:
+            features = self.emb_features(features)
+
+        # features.shape: 10 256 17 17
+        # cnn_code.shape: 10 256
+        return features, cnn_code
+
+
+class VGGEncoder(nn.Module):
+    """Pre-trained VGG Encoder Module"""
+
+    def __init__(self) -> None:
+        """
+        Initialize pre-trained VGG model with frozen parameters
+        """
+        super().__init__()
+        self.select = "8"  # we want the output of the 8th layer in VGG.
+
+        self.model = torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True)
+
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        self.vgg_modules = self.model.features._modules
+
+    def forward(self, image_tensor: torch.Tensor) -> Any:
+        """
+        :param image_tensor: Input image tensor [shape: (batch, 3, 256, 256)]
+        :return: VGG features [shape: (batch, 128, 128, 128)]
+        """
+        for name, layer in self.vgg_modules.items():
+            image_tensor = layer(image_tensor)
+            if name == self.select:
+                return image_tensor
+        return None
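A quick shape check for the two encoders (illustrative, not part of the diff); the first call downloads the pretrained torchvision weights via torch.hub:

import torch
from src.models.modules.image_encoder import InceptionEncoder, VGGEncoder

inception = InceptionEncoder(D=256).eval()
vgg = VGGEncoder().eval()

with torch.no_grad():
    images = torch.randn(2, 3, 256, 256)
    local_feat, global_feat = inception(images)
    vgg_feat = vgg(images)

print(local_feat.shape)   # torch.Size([2, 256, 17, 17])
print(global_feat.shape)  # torch.Size([2, 256])
print(vgg_feat.shape)     # torch.Size([2, 128, 128, 128])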