diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d2ad9ef9c925b734c417b8036bd6c8c2017f8975 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/* +.idea/* diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..b5aaa38c2925865aacefc06a6085ac0b1e93d04b --- /dev/null +++ b/app.py @@ -0,0 +1,117 @@ +import numpy as np # this should come first to mitigate mlk-service bug +from src.models.utils import get_image_arr, load_model +from src.data import TAIMGANTokenizer +from torchvision import transforms +from src.config import config_dict +from pathlib import Path +from enum import IntEnum, auto +from PIL import Image +import gradio as gr +import torch +from src.models.modules import ( + VGGEncoder, + InceptionEncoder, + TextEncoder, + Generator +) + +########## +# PARAMS # +########## + +IMG_CHANS = 3 # RGB channels for image +IMG_HW = 256 # height and width of images +HIDDEN_DIM = 128 # hidden dimensions of lstm cell in one direction +C = 2 * HIDDEN_DIM # length of embeddings + +Ng = config_dict["Ng"] +cond_dim = config_dict["condition_dim"] +z_dim = config_dict["noise_dim"] + + +############### +# LOAD MODELS # +############### + +models = { + "COCO": { + "dir": "weights/coco" + }, + "Bird": { + "dir": "weights/bird" + }, + "UTKFace": { + "dir": "weights/utkface" + } +} + +for model_name in models: + # create tokenizer + models[model_name]["tokenizer"] = TAIMGANTokenizer(captions_path=f"{models[model_name]['dir']}/captions.pickle") + vocab_size = len(models[model_name]["tokenizer"].word_to_ix) + # instantiate models + models[model_name]["generator"] = Generator(Ng=Ng, D=C, conditioning_dim=cond_dim, noise_dim=z_dim).eval() + models[model_name]["lstm"] = TextEncoder(vocab_size=vocab_size, emb_dim=C, hidden_dim=HIDDEN_DIM).eval() + models[model_name]["vgg"] = VGGEncoder().eval() + models[model_name]["inception"] = InceptionEncoder(D=C).eval() + # load models + load_model( + generator=models[model_name]["generator"], + discriminator=None, + image_encoder=models[model_name]["inception"], + text_encoder=models[model_name]["lstm"], + output_dir=Path(models[model_name]["dir"]), + device=torch.device("cpu") + ) + + +def change_image_with_text(image: Image, text: str, model_name: str) -> Image: + """ + Create an image modified by text from the original image + and save it with _modified postfix + + :param gr.Image image: Path to the image + :param str text: Desired caption + """ + global models + tokenizer = models[model_name]["tokenizer"] + G = models[model_name]["generator"] + lstm = models[model_name]["lstm"] + inception = models[model_name]["inception"] + vgg = models[model_name]["vgg"] + # generate some noise + noise = torch.rand(z_dim).unsqueeze(0) + # transform input text and get masks with embeddings + tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0) + mask = (tokens == tokenizer.pad_token_id) + word_embs, sent_embs = lstm(tokens) + # open the image and transform it to the tensor + image = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((IMG_HW, IMG_HW)), + transforms.Normalize( + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5) + ) + ])(image).unsqueeze(0) + # obtain visual features of the image + vgg_features = vgg(image) + local_features, global_features = inception(image) + # generate new image from the old one + fake_image, _, _ = G(noise, sent_embs, word_embs, global_features, + local_features, vgg_features, mask) + # denormalize the image + fake_image = 
Image.fromarray(get_image_arr(fake_image)[0]) + # return image in gradio format + return fake_image + + +########## +# GRADIO # +########## +demo = gr.Interface( + fn=change_image_with_text, + inputs=[gr.Image(type="pil"), "text", gr.inputs.Dropdown(list(models.keys()))], + outputs=gr.Image(type="pil") +) +demo.launch(debug=True) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..21680fc103ba0ce14713d919dfb2cf318866673f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +Pillow +torch +torchvision +torchaudio +nltk \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff3ca376bd0589f94cef21d77f7c5727ec24c383 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,2 @@ +"""Config file for the project.""" +from .config import config_dict, update_config diff --git a/src/__pycache__/__init__.cpython-39.pyc b/src/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ad50e1ee18e96375aba110703e8292c2f24bb01 Binary files /dev/null and b/src/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/__pycache__/config.cpython-39.pyc b/src/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..801f56b3008934b1b58f9548a802b8fabc16f522 Binary files /dev/null and b/src/__pycache__/config.cpython-39.pyc differ diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000000000000000000000000000000000000..3e62f69b18f24ea4126d62d5a593ccd49f62af91 --- /dev/null +++ b/src/config.py @@ -0,0 +1,47 @@ +"""Configurations for the project.""" +from pathlib import Path +from typing import Any, Dict + +import torch + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +repo_path = Path(__file__).parent.parent.absolute() +output_path = repo_path / "models" + +config_dict = { + "Ng": 32, + "D": 256, + "condition_dim": 100, + "noise_dim": 100, + "lr_config": { + "disc_lr": 2e-4, + "gen_lr": 2e-4, + "img_encoder_lr": 3e-3, + "text_encoder_lr": 3e-3, + }, + "batch_size": 64, + "device": device, + "epochs": 200, + "output_dir": output_path, + "snapshot": 5, + "const_dict": { + "smooth_val_gen": 0.999, + "lambda1": 1, + "lambda2": 1, + "lambda3": 1, + "lambda4": 1, + "gamma1": 4, + "gamma2": 5, + "gamma3": 10, + }, +} + + +def update_config(cfg_dict: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + """ + Function to update the configuration dictionary. 
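+    :param cfg_dict: configuration dictionary to update in place
+    :param kwargs: key/value pairs to add or overwrite in the dictionary
+    :return: the updated configuration dictionary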
+ """ + for key, value in kwargs.items(): + cfg_dict[key] = value + return cfg_dict diff --git a/src/data/.gitkeep b/src/data/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa29aef109004daac476c8f18628dd3525599103 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1,5 @@ +"""Dataset and custom collate function to load""" + +from .collate import custom_collate +from .datasets import TextImageDataset +from .tokenizer import TAIMGANTokenizer diff --git a/src/data/__pycache__/__init__.cpython-39.pyc b/src/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d42ddc2e4f19915e18ae2014b180881b12189a24 Binary files /dev/null and b/src/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/data/__pycache__/collate.cpython-39.pyc b/src/data/__pycache__/collate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d008b98d2a8ae19e3ea5797b637a166fc92234fc Binary files /dev/null and b/src/data/__pycache__/collate.cpython-39.pyc differ diff --git a/src/data/__pycache__/datasets.cpython-39.pyc b/src/data/__pycache__/datasets.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea1825be5ed30c71947209b67808ddc4f7e2e86b Binary files /dev/null and b/src/data/__pycache__/datasets.cpython-39.pyc differ diff --git a/src/data/__pycache__/tokenizer.cpython-39.pyc b/src/data/__pycache__/tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0f085a785a595eac65f8b0f06a75a0eeb792138 Binary files /dev/null and b/src/data/__pycache__/tokenizer.cpython-39.pyc differ diff --git a/src/data/collate.py b/src/data/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..220060f52bc6f915875a78b3c973ae288435968e --- /dev/null +++ b/src/data/collate.py @@ -0,0 +1,43 @@ +"""Custom collate function for the data loader.""" + +from typing import Any, List + +import torch +from torch.nn.utils.rnn import pad_sequence + + +def custom_collate(batch: List[Any], device: Any) -> Any: + """ + Custom collate function to be used in the data loader. + :param batch: list, with length equal to number of batches. 
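+    :param device: torch device onto which the batched tensors are moved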
+ :return: processed batch of data [add padding to text, stack tensors in batch] + """ + img, correct_capt, curr_class, word_labels = zip(*batch) + batched_img = torch.stack(img, dim=0).to( + device + ) # shape: (batch_size, 3, height, width) + correct_capt_len = torch.tensor( + [len(capt) for capt in correct_capt], dtype=torch.int64 + ).unsqueeze( + 1 + ) # shape: (batch_size, 1) + batched_correct_capt = pad_sequence( + correct_capt, batch_first=True, padding_value=0 + ).to( + device + ) # shape: (batch_size, max_seq_len) + batched_curr_class = torch.stack(curr_class, dim=0).to( + device + ) # shape: (batch_size, 1) + batched_word_labels = pad_sequence( + word_labels, batch_first=True, padding_value=0 + ).to( + device + ) # shape: (batch_size, max_seq_len) + return ( + batched_img, + batched_correct_capt, + correct_capt_len, + batched_curr_class, + batched_word_labels, + ) diff --git a/src/data/datasets.py b/src/data/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d0879cae8681beb289c3a6912fb9101b949002 --- /dev/null +++ b/src/data/datasets.py @@ -0,0 +1,387 @@ +"""Pytorch Dataset classes for the datasets used in the project.""" + +import os +import pickle +from collections import defaultdict +from typing import Any + +import nltk +import numpy as np +import pandas as pd +import torch +import torchvision.transforms.functional as F +from nltk.tokenize import RegexpTokenizer +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class TextImageDataset(Dataset): # type: ignore + """Custom PyTorch Dataset class to load Image and Text data.""" + + # pylint: disable=too-many-instance-attributes + # pylint: disable=too-many-locals + # pylint: disable=too-many-function-args + + def __init__( + self, data_path: str, split: str, num_captions: int, transform: Any = None + ): + """ + :param data_path: Path to the data directory. [i.e. can be './birds/', or './coco/] + :param split: 'train' or 'test' split + :param num_captions: number of captions present per image. + [For birds, this is 10, for coco, this is 5] + :param transform: PyTorch transform to apply to the images. + """ + self.transform = transform + self.bound_box_map = None + self.file_names = self.load_filenames(data_path, split) + self.data_path = data_path + self.num_captions_per_image = num_captions + ( + self.captions, + self.ix_to_word, + self.word_to_ix, + self.vocab_len, + ) = self.get_capt_and_vocab(data_path, split) + self.normalize = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + self.class_ids = self.get_class_id(data_path, split, len(self.file_names)) + if self.data_path.endswith("birds/"): + self.bound_box_map = self.get_bound_box(data_path) + + elif self.data_path.endswith("coco/"): + pass + + else: + raise ValueError( + "Invalid data path. Please ensure the data [CUB/COCO] is stored in correct folders." + ) + + def __len__(self) -> int: + """Return the length of the dataset.""" + return len(self.file_names) + + def __getitem__(self, idx: int) -> Any: + """ + Return the item at index idx. 
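+        A caption is sampled uniformly at random from the captions available for
+        the image, and word labels are derived from the NLTK POS tags of that caption.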
+ :param idx: index of the item to return + :return img_tensor: image tensor + :return correct_caption: correct caption for the image [list of word indices] + :return curr_class_id: class id of the image + :return word_labels: POS_tagged word labels [1 for noun and adjective, 0 else] + + """ + file_name = self.file_names[idx] + curr_class_id = self.class_ids[idx] + + if self.bound_box_map is not None: + bbox = self.bound_box_map[file_name] + images_dir = os.path.join(self.data_path, "CUB_200_2011/images") + else: + bbox = None + images_dir = os.path.join(self.data_path, "images") + + img_path = os.path.join(images_dir, file_name + ".jpg") + img_tensor = self.get_image(img_path, bbox, self.transform) + + rand_sent_idx = np.random.randint(0, self.num_captions_per_image) + rand_sent_idx = idx * self.num_captions_per_image + rand_sent_idx + + correct_caption = torch.tensor(self.captions[rand_sent_idx], dtype=torch.int64) + num_words = len(correct_caption) + + capt_token_list = [] + for i in range(num_words): + capt_token_list.append(self.ix_to_word[correct_caption[i].item()]) + + pos_tag_list = nltk.tag.pos_tag(capt_token_list) + word_labels = [] + + for pos_tag in pos_tag_list: + if ( + "NN" in pos_tag[1] or "JJ" in pos_tag[1] + ): # check for Nouns and Adjective only + word_labels.append(1) + else: + word_labels.append(0) + + word_labels = torch.tensor(word_labels).float() # type: ignore + + curr_class_id = torch.tensor(curr_class_id, dtype=torch.int64).unsqueeze(0) + + return ( + img_tensor, + correct_caption, + curr_class_id, + word_labels, + ) + + def get_capt_and_vocab(self, data_dir: str, split: str) -> Any: + """ + Helper function to get the captions, vocab dict for each image. + :param data_dir: path to the data directory [i.e. './birds/' or './coco/'] + :param split: 'train' or 'test' split + :return captions: list of all captions for each image + :return ix_to_word: dictionary mapping index to word + :return word_to_ix: dictionary mapping word to index + :return num_words: number of unique words in the vocabulary + """ + captions_ckpt_path = os.path.join(data_dir, "stubs/captions.pickle") + if os.path.exists( + captions_ckpt_path + ): # check if previously processed captions exist + with open(captions_ckpt_path, "rb") as ckpt_file: + captions = pickle.load(ckpt_file) + train_captions, test_captions = captions[0], captions[1] + ix_to_word, word_to_ix = captions[2], captions[3] + num_words = len(ix_to_word) + del captions + if split == "train": + return train_captions, ix_to_word, word_to_ix, num_words + return test_captions, ix_to_word, word_to_ix, num_words + + else: # if not, process the captions and save them + train_files = self.load_filenames(data_dir, "train") + test_files = self.load_filenames(data_dir, "test") + + train_captions_tokenized = self.get_tokenized_captions( + data_dir, train_files + ) + test_captions_tokenized = self.get_tokenized_captions( + data_dir, test_files + ) # we need both train and test captions to build the vocab + + ( + train_captions, + test_captions, + ix_to_word, + word_to_ix, + num_words, + ) = self.build_vocab( # type: ignore + train_captions_tokenized, test_captions_tokenized, split + ) + vocab_list = [train_captions, test_captions, ix_to_word, word_to_ix] + with open(captions_ckpt_path, "wb") as ckpt_file: + pickle.dump(vocab_list, ckpt_file) + + if split == "train": + return train_captions, ix_to_word, word_to_ix, num_words + if split == "test": + return test_captions, ix_to_word, word_to_ix, num_words + raise ValueError("Invalid split. 
Please use 'train' or 'test'") + + def build_vocab( + self, tokenized_captions_train: list, tokenized_captions_test: list # type: ignore + ) -> Any: + """ + Helper function which builds the vocab dicts. + :param tokenized_captions_train: list containing all the + train tokenized captions in the dataset. This is list of lists. + :param tokenized_captions_test: list containing all the + test tokenized captions in the dataset. This is list of lists. + :return train_captions_int: list of all captions in training, + where each word is replaced by its index in the vocab + :return test_captions_int: list of all captions in test, + where each word is replaced by its index in the vocab + :return ix_to_word: dictionary mapping index to word + :return word_to_ix: dictionary mapping word to index + :return num_words: number of unique words in the vocabulary + """ + vocab = defaultdict(int) # type: ignore + total_captions = tokenized_captions_train + tokenized_captions_test + for caption in total_captions: + for word in caption: + vocab[word] += 1 + + # sort vocab dict by frequency in descending order + vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True) # type: ignore + + ix_to_word = {} + word_to_ix = {} + ix_to_word[0] = "" + word_to_ix[""] = 0 + + word_idx = 1 + for word, _ in vocab: + word_to_ix[word] = word_idx + ix_to_word[word_idx] = word + word_idx += 1 + + train_captions_int = [] # we want to convert words to indices in vocab. + for caption in tokenized_captions_train: + curr_caption_int = [] + for word in caption: + curr_caption_int.append(word_to_ix[word]) + + train_captions_int.append(curr_caption_int) + + test_captions_int = [] + for caption in tokenized_captions_test: + curr_caption_int = [] + for word in caption: + curr_caption_int.append(word_to_ix[word]) + + test_captions_int.append(curr_caption_int) + + return ( + train_captions_int, + test_captions_int, + ix_to_word, + word_to_ix, + len(ix_to_word), + ) + + def get_tokenized_captions(self, data_dir: str, filenames: list) -> Any: # type: ignore + """ + Helper function to tokenize and return captions for each image in filenames. + :param data_dir: path to the data directory [i.e. './birds/' or './coco/'] + :param filenames: list of all filenames corresponding to the split + :return tokenized_captions: list of all tokenized captions for all files in filenames. + [this returns a list, where each element is again a list of tokens/words] + """ + + all_captions = [] + for filename in filenames: + caption_path = os.path.join(data_dir, "text", filename + ".txt") + with open(caption_path, "r", encoding="utf8") as txt_file: + captions = txt_file.readlines() + count = 0 + for caption in captions: + if len(caption) == 0: + continue + + caption = caption.replace("\ufffd\ufffd", " ") + tokenizer = RegexpTokenizer(r"\w+") + tokens = tokenizer.tokenize( + caption.lower() + ) # splits current caption/line to list of words/tokens + if len(tokens) == 0: + continue + + tokens = [ + t.encode("ascii", "ignore").decode("ascii") for t in tokens + ] + tokens = [t for t in tokens if len(t) > 0] + + all_captions.append(tokens) + count += 1 + if count == self.num_captions_per_image: + break + if count < self.num_captions_per_image: + raise ValueError( + f"Number of captions for {filename} is only {count},\ + which is less than {self.num_captions_per_image}." + ) + + return all_captions + + def get_image(self, img_path: str, bbox: list, transform: Any) -> Any: # type: ignore + """ + Helper function to load and transform an image. 
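+        If a bounding box is given, the image is first cropped to a square region
+        centred on the box (half-size 0.75 * max(box width, box height)) before the
+        supplied transform, random 256x256 crop, random horizontal flip and
+        normalization are applied.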
+ :param img_path: path to the image + :param bbox: bounding box coordinates [x, y, width, height] + :param transform: PyTorch transform to apply to the image + :return img_tensor: transformed image tensor + """ + img = Image.open(img_path).convert("RGB") + width, height = img.size + + if bbox is not None: + r_val = int(np.maximum(bbox[2], bbox[3]) * 0.75) + + center_x = int((2 * bbox[0] + bbox[2]) / 2) + center_y = int((2 * bbox[1] + bbox[3]) / 2) + y1_coord = np.maximum(0, center_y - r_val) + y2_coord = np.minimum(height, center_y + r_val) + x1_coord = np.maximum(0, center_x - r_val) + x2_coord = np.minimum(width, center_x + r_val) + + img = img.crop( + [x1_coord, y1_coord, x2_coord, y2_coord] + ) # This preprocessing steps seems to follow from + # Stackgan: Text to photo-realistic image synthesis + + if transform is not None: + img_tensor = transform(img) # this scales to 304x304, i.e. 256 x (76/64). + x_val = np.random.randint(0, 48) # 304 - 256 = 48 + y_val = np.random.randint(0, 48) + flip = np.random.rand() > 0.5 + + # crop + img_tensor = img_tensor.crop( + [x_val, y_val, x_val + 256, y_val + 256] + ) # this crops to 256x256 + if flip: + img_tensor = F.hflip(img_tensor) + + img_tensor = self.normalize(img_tensor) + + return img_tensor + + def load_filenames(self, data_dir: str, split: str) -> Any: + """ + Helper function to get list of all image filenames. + :param data_dir: path to the data directory [i.e. './birds/' or './coco/'] + :param split: 'train' or 'test' split + :return filenames: list of all image filenames + """ + filepath = f"{data_dir}{split}/filenames.pickle" + if os.path.isfile(filepath): + with open(filepath, "rb") as pick_file: + filenames = pickle.load(pick_file) + else: + raise ValueError( + "Invalid split. Please use 'train' or 'test',\ + or make sure the filenames.pickle file exists." + ) + return filenames + + def get_class_id(self, data_dir: str, split: str, total_elems: int) -> Any: + """ + Helper function to get list of all image class ids. + :param data_dir: path to the data directory [i.e. './birds/' or './coco/'] + :param split: 'train' or 'test' split + :param total_elems: total number of elements in the dataset + :return class_ids: list of all image class ids + """ + filepath = f"{data_dir}{split}/class_info.pickle" + if os.path.isfile(filepath): + with open(filepath, "rb") as class_file: + class_ids = pickle.load(class_file, encoding="latin1") + else: + class_ids = np.arange(total_elems) + return class_ids + + def get_bound_box(self, data_path: str) -> Any: + """ + Helper function to get the bounding box for birds dataset. + :param data_path: path to birds data directory [i.e. './data/birds/'] + :return imageToBox: dictionary mapping image name to bounding box coordinates + """ + bbox_path = os.path.join(data_path, "CUB_200_2011/bounding_boxes.txt") + df_bounding_boxes = pd.read_csv( + bbox_path, delim_whitespace=True, header=None + ).astype(int) + + filepath = os.path.join(data_path, "CUB_200_2011/images.txt") + df_filenames = pd.read_csv(filepath, delim_whitespace=True, header=None) + filenames = df_filenames[ + 1 + ].tolist() # df_filenames[0] just contains the index or ID. 
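+        # NOTE: this assumes bounding_boxes.txt and images.txt list images in the
+        # same order, so positional indexing keeps file names and boxes in sync.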
+ + img_to_box = { # type: ignore + img_file[:-4]: [] for img_file in filenames + } # remove the .jpg extension from the names + num_imgs = len(filenames) + + for i in range(0, num_imgs): + bbox = df_bounding_boxes.iloc[i][1:].tolist() + key = filenames[i][:-4] + img_to_box[key] = bbox + + return img_to_box diff --git a/src/data/stubs/bird.jpg b/src/data/stubs/bird.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5bd625eb1162cc34e286f77c8367819d5fb67c62 Binary files /dev/null and b/src/data/stubs/bird.jpg differ diff --git a/src/data/stubs/pigeon.jpg b/src/data/stubs/pigeon.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fb6e1274a47a5a684f6169723ba2f4e8bf3a3be0 Binary files /dev/null and b/src/data/stubs/pigeon.jpg differ diff --git a/src/data/stubs/rohit.jpeg b/src/data/stubs/rohit.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..6d745af0d263a8621ed844ac1ecccd302325ea61 Binary files /dev/null and b/src/data/stubs/rohit.jpeg differ diff --git a/src/data/tokenizer.py b/src/data/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8ef6a40e9a92d4a1851cc3c3277d5f566ec9cef4 --- /dev/null +++ b/src/data/tokenizer.py @@ -0,0 +1,23 @@ +import pickle +import re +from typing import List + + +class TAIMGANTokenizer: + def __init__(self, captions_path): + with open(captions_path, "rb") as ckpt_file: + captions = pickle.load(ckpt_file) + self.ix_to_word = captions[2] + self.word_to_ix = captions[3] + self.token_regex = r'\w+' + self.pad_token_id = self.word_to_ix[""] + self.pad_repr = "[PAD]" + + def encode(self, text: str) -> List[int]: + return [self.word_to_ix.get(word, self.pad_token_id) + for word in re.findall(self.token_regex, text.lower())] + + def decode(self, tokens: List[int]) -> str: + return ' '.join([self.ix_to_word[token] + if token != self.pad_token_id else self.pad_repr + for token in tokens]) diff --git a/src/features/.gitkeep b/src/features/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/features/__init__.py b/src/features/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/features/build_features.py b/src/features/build_features.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/.gitkeep b/src/models/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b058f6d3433c5d50041f5df3e69a163cdb8062cd --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,4 @@ +"""Helper functions for training loop.""" +from .losses import discriminator_loss, generator_loss, kl_loss +from .train_model import train +from .utils import copy_gen_params, define_optimizers, load_params, prepare_labels diff --git a/src/models/__pycache__/__init__.cpython-39.pyc b/src/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..765c9f4022695c3e61e68848af1bebdcbb33b433 Binary files /dev/null and b/src/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/models/__pycache__/losses.cpython-39.pyc b/src/models/__pycache__/losses.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..cbd7c0a99734dc6a9fbbe2a62b5e69966e9890f5 Binary files /dev/null and b/src/models/__pycache__/losses.cpython-39.pyc differ diff --git a/src/models/__pycache__/train_model.cpython-39.pyc b/src/models/__pycache__/train_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29ac7284721a002519ed26ffff6855acb3faf083 Binary files /dev/null and b/src/models/__pycache__/train_model.cpython-39.pyc differ diff --git a/src/models/__pycache__/utils.cpython-39.pyc b/src/models/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22bdacc388721249e3d9b121b10db868671b3bc2 Binary files /dev/null and b/src/models/__pycache__/utils.cpython-39.pyc differ diff --git a/src/models/losses.py b/src/models/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..6579694ae4e96921a49a53689efad6fc9b6f2bd4 --- /dev/null +++ b/src/models/losses.py @@ -0,0 +1,344 @@ +"""Module containing the loss functions for the GANs.""" +from typing import Any, Dict + +import torch +from torch import nn + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-locals + + +def generator_loss( + logits: Dict[str, Dict[str, torch.Tensor]], + local_fake_incept_feat: torch.Tensor, + global_fake_incept_feat: torch.Tensor, + real_labels: torch.Tensor, + words_emb: torch.Tensor, + sent_emb: torch.Tensor, + match_labels: torch.Tensor, + cap_lens: torch.Tensor, + class_ids: torch.Tensor, + real_vgg_feat: torch.Tensor, + fake_vgg_feat: torch.Tensor, + const_dict: Dict[str, float], +) -> Any: + """Calculate the loss for the generator. + + Args: + logits: Dictionary with fake/real and word-level/uncond/cond logits + + local_fake_incept_feat: The local inception features for the fake images. + + global_fake_incept_feat: The global inception features for the fake images. + + real_labels: Label for "real" image as predicted by discriminator, + this is a tensor of ones. [shape: (batch_size, 1)]. + + word_labels: POS tagged word labels for the captions. [shape: (batch_size, L)] + + words_emb: The embeddings for all the words in the captions. + shape: (batch_size, embedding_size, max_caption_length) + + sent_emb: The embeddings for the sentences. + shape: (batch_size, embedding_size) + + match_labels: Tensor of shape: (batch_size, 1). + This is of the form torch.tensor([0, 1, 2, ..., batch-1]) + + cap_lens: The length of the 'actual' captions in the batch [without padding] + shape: (batch_size, 1) + + class_ids: The class ids for the instance. shape: (batch_size, 1) + + real_vgg_feat: The vgg features for the real images. shape: (batch_size, 128, 128, 128) + fake_vgg_feat: The vgg features for the fake images. shape: (batch_size, 128, 128, 128) + + const_dict: The dictionary containing the constants. + """ + lambda1 = const_dict["lambda1"] + total_error_g = 0.0 + + cond_logits = logits["fake"]["cond"] + cond_err_g = nn.BCEWithLogitsLoss()(cond_logits, real_labels) + + uncond_logits = logits["fake"]["uncond"] + uncond_err_g = nn.BCEWithLogitsLoss()(uncond_logits, real_labels) + + # add up the conditional and unconditional losses + loss_g = cond_err_g + uncond_err_g + total_error_g += loss_g + + # DAMSM Loss from attnGAN. 
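+    # DAMSM scores the fake image against its caption at the word level (regions
+    # attended by each word) and at the sentence level (global Inception feature vs.
+    # sentence embedding), encouraging G to produce images that match the text.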
+ loss_damsm = damsm_loss( + local_fake_incept_feat, + global_fake_incept_feat, + words_emb, + sent_emb, + match_labels, + cap_lens, + class_ids, + const_dict, + ) + + total_error_g += loss_damsm + + loss_per = 0.5 * nn.MSELoss()(real_vgg_feat, fake_vgg_feat) # perceptual loss + + total_error_g += lambda1 * loss_per + + return total_error_g + + +def damsm_loss( + local_incept_feat: torch.Tensor, + global_incept_feat: torch.Tensor, + words_emb: torch.Tensor, + sent_emb: torch.Tensor, + match_labels: torch.Tensor, + cap_lens: torch.Tensor, + class_ids: torch.Tensor, + const_dict: Dict[str, float], +) -> Any: + """Calculate the DAMSM loss from the attnGAN paper. + + Args: + local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)] + + global_incept_feat: The global inception features. [shape: (batch, D)] + + words_emb: The embeddings for all the words in the captions. + + shape: (batch, D, max_caption_length) + + sent_emb: The embeddings for the sentences. shape: (batch_size, D) + + match_labels: Tensor of shape: (batch_size, 1). + This is of the form torch.tensor([0, 1, 2, ..., batch-1]) + + cap_lens: The length of the 'actual' captions in the batch [without padding] + shape: (batch_size, 1) + + class_ids: The class ids for the instance. shape: (batch, 1) + + const_dict: The dictionary containing the constants. + """ + batch_size = match_labels.size(0) + # Mask mis-match samples, that come from the same class as the real sample + masks = [] + + match_scores = [] + gamma1 = const_dict["gamma1"] + gamma2 = const_dict["gamma2"] + gamma3 = const_dict["gamma3"] + lambda3 = const_dict["lambda3"] + + for i in range(batch_size): + mask = (class_ids == class_ids[i]).int() + # This ensures that "correct class" index is not included in the mask. + mask[i] = 0 + masks.append(mask.reshape(1, -1)) # shape: (1, batch) + + numb_words = int(cap_lens[i]) + # shape: (1, D, L), this picks the caption at ith batch index. + query_words = words_emb[i, :, :numb_words].unsqueeze(0) + # shape: (batch, D, L), this expands the same caption for all batch indices. + query_words = query_words.repeat(batch_size, 1, 1) + + c_i = compute_region_context_vector( + local_incept_feat, query_words, gamma1 + ) # Taken from attnGAN paper. shape: (batch, D, L) + + query_words = query_words.transpose(1, 2) # shape: (batch, L, D) + c_i = c_i.transpose(1, 2) # shape: (batch, L, D) + query_words = query_words.reshape( + batch_size * numb_words, -1 + ) # shape: (batch * L, D) + c_i = c_i.reshape(batch_size * numb_words, -1) # shape: (batch * L, D) + + r_i = compute_relevance( + c_i, query_words + ) # cosine similarity, or R(c_i, e_i) from attnGAN paper. shape: (batch * L, 1) + r_i = r_i.view(batch_size, numb_words) # shape: (batch, L) + r_i = torch.exp(r_i * gamma2) # shape: (batch, L) + r_i = r_i.sum(dim=1, keepdim=True) # shape: (batch, 1) + r_i = torch.log( + r_i + ) # This is image-text matching score b/w whole image and caption, shape: (batch, 1) + match_scores.append(r_i) + + masks = torch.cat(masks, dim=0).bool() # type: ignore + match_scores = torch.cat(match_scores, dim=1) # type: ignore + + # This corresponds to P(D|Q) from attnGAN. + match_scores = gamma3 * match_scores # type: ignore + match_scores.data.masked_fill_( # type: ignore + masks, -float("inf") + ) # mask out the scores for mis-matched samples + + match_scores_t = match_scores.transpose( # type: ignore + 0, 1 + ) # This corresponds to P(Q|D) from attnGAN. + + # This corresponds to L1_w from attnGAN. 
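+    # match_scores is a (batch, batch) matrix of image-caption scores whose diagonal
+    # holds the correct pairs, so cross-entropy against match_labels ([0, ..., batch-1])
+    # maximizes the posterior of the matching caption for each image (and vice versa
+    # for the transposed matrix below).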
+ l1_w = nn.CrossEntropyLoss()(match_scores, match_labels) + # This corresponds to L2_w from attnGAN. + l2_w = nn.CrossEntropyLoss()(match_scores_t, match_labels) + + incept_feat_norm = torch.linalg.norm(global_incept_feat, dim=1) + sent_emb_norm = torch.linalg.norm(sent_emb, dim=1) + + # shape: (batch, batch) + global_match_score = global_incept_feat @ (sent_emb.T) + + global_match_score = ( + global_match_score / torch.outer(incept_feat_norm, sent_emb_norm) + ).clamp(min=1e-8) + global_match_score = gamma3 * global_match_score + + # mask out the scores for mis-matched samples + global_match_score.data.masked_fill_(masks, -float("inf")) # type: ignore + + global_match_t = global_match_score.T # shape: (batch, batch) + + # This corresponds to L1_s from attnGAN. + l1_s = nn.CrossEntropyLoss()(global_match_score, match_labels) + # This corresponds to L2_s from attnGAN. + l2_s = nn.CrossEntropyLoss()(global_match_t, match_labels) + + loss_damsm = lambda3 * (l1_w + l2_w + l1_s + l2_s) + + return loss_damsm + + +def compute_relevance(c_i: torch.Tensor, query_words: torch.Tensor) -> Any: + """Computes the cosine similarity between the region context vector and the query words. + + Args: + c_i: The region context vector. shape: (batch * L, D) + query_words: The query words. shape: (batch * L, D) + """ + prod = c_i * query_words # shape: (batch * L, D) + numr = torch.sum(prod, dim=1) # shape: (batch * L, 1) + norm_c = torch.linalg.norm(c_i, ord=2, dim=1) + norm_q = torch.linalg.norm(query_words, ord=2, dim=1) + denr = norm_c * norm_q + r_i = (numr / denr).clamp(min=1e-8).squeeze() # shape: (batch * L, 1) + return r_i + + +def compute_region_context_vector( + local_incept_feat: torch.Tensor, query_words: torch.Tensor, gamma1: float +) -> Any: + """Compute the region context vector (c_i) from attnGAN paper. + + Args: + local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)] + query_words: The embeddings for all the words in the captions. shape: (batch, D, L) + gamma1: The gamma1 value from attnGAN paper. 
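+
+    Returns:
+        c_i of shape (batch, D, L): for each word, a weighted sum over the N = 17 * 17
+        image regions with weights alpha_j = softmax(gamma1 * s_j), where s_j is the
+        word-region similarity normalized over words.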
+ """ + batch, L = query_words.size(0), query_words.size(2) # pylint: disable=invalid-name + + feat_height, feat_width = local_incept_feat.size(2), local_incept_feat.size(3) + N = feat_height * feat_width # pylint: disable=invalid-name + + # Reshape the local inception features to (batch, D, N) + local_incept_feat = local_incept_feat.view(batch, -1, N) + # shape: (batch, N, D) + incept_feat_t = local_incept_feat.transpose(1, 2) + + sim_matrix = incept_feat_t @ query_words # shape: (batch, N, L) + sim_matrix = sim_matrix.view(batch * N, L) # shape: (batch * N, L) + + sim_matrix = nn.Softmax(dim=1)(sim_matrix) # shape: (batch * N, L) + sim_matrix = sim_matrix.view(batch, N, L) # shape: (batch, N, L) + + sim_matrix = torch.transpose(sim_matrix, 1, 2) # shape: (batch, L, N) + sim_matrix = sim_matrix.reshape(batch * L, N) # shape: (batch * L, N) + + alpha_j = gamma1 * sim_matrix # shape: (batch * L, N) + alpha_j = nn.Softmax(dim=1)(alpha_j) # shape: (batch * L, N) + alpha_j = alpha_j.view(batch, L, N) # shape: (batch, L, N) + alpha_j_t = torch.transpose(alpha_j, 1, 2) # shape: (batch, N, L) + + c_i = ( + local_incept_feat @ alpha_j_t + ) # shape: (batch, D, L) [summing over N dimension in paper, so we multiply like this] + return c_i + + +def discriminator_loss( + logits: Dict[str, Dict[str, torch.Tensor]], + labels: Dict[str, Dict[str, torch.Tensor]], +) -> Any: + """ + Calculate discriminator objective + + :param dict[str, dict[str, torch.Tensor]] logits: + Dictionary with fake/real and word-level/uncond/cond logits + + Example: + + logits = { + "fake": { + "word_level": torch.Tensor (BxL) + "uncond": torch.Tensor (Bx1) + "cond": torch.Tensor (Bx1) + }, + "real": { + "word_level": torch.Tensor (BxL) + "uncond": torch.Tensor (Bx1) + "cond": torch.Tensor (Bx1) + }, + } + :param dict[str, dict[str, torch.Tensor]] labels: + Dictionary with fake/real and word-level/image labels + + Example: + + labels = { + "fake": { + "word_level": torch.Tensor (BxL) + "image": torch.Tensor (Bx1) + }, + "real": { + "word_level": torch.Tensor (BxL) + "image": torch.Tensor (Bx1) + }, + } + :param float lambda_4: Hyperparameter for word loss in paper + :return: Discriminator objective loss + :rtype: Any + """ + # define main loss functions for logit losses + tot_loss = 0.0 + bce_logits = nn.BCEWithLogitsLoss() + bce = nn.BCELoss() + # calculate word-level loss + word_loss = bce(logits["real"]["word_level"], labels["real"]["word_level"]) + # calculate unconditional adversarial loss + uncond_loss = bce_logits(logits["real"]["uncond"], labels["real"]["image"]) + + # calculate conditional adversarial loss + cond_loss = bce_logits(logits["real"]["cond"], labels["real"]["image"]) + + tot_loss = (uncond_loss + cond_loss) / 2.0 + + fake_uncond_loss = bce_logits(logits["fake"]["uncond"], labels["fake"]["image"]) + fake_cond_loss = bce_logits(logits["fake"]["cond"], labels["fake"]["image"]) + + tot_loss += (fake_uncond_loss + fake_cond_loss) / 3.0 + tot_loss += word_loss + + return tot_loss + + +def kl_loss(mu_tensor: torch.Tensor, logvar: torch.Tensor) -> Any: + """ + Calculate KL loss + + :param torch.Tensor mu_tensor: Mean of latent distribution + :param torch.Tensor logvar: Log variance of latent distribution + :return: KL loss [-0.5 * (1 + log(sigma) - mu^2 - sigma^2)] + :rtype: Any + """ + return torch.mean(-0.5 * (1 + 0.5 * logvar - mu_tensor.pow(2) - torch.exp(logvar))) diff --git a/src/models/modules/__init__.py b/src/models/modules/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..2c7953928c945ae3d3c6669a732cc72d0a301ebd --- /dev/null +++ b/src/models/modules/__init__.py @@ -0,0 +1,12 @@ +"""All the modules used in creation of Generator and Discriminator""" +from .acm import ACM +from .attention import ChannelWiseAttention, SpatialAttention +from .cond_augment import CondAugmentation +from .conv_utils import calc_out_conv, conv1d, conv2d +from .discriminator import Discriminator, WordLevelLogits +from .downsample import down_sample +from .generator import Generator +from .image_encoder import InceptionEncoder, VGGEncoder +from .residual import ResidualBlock +from .text_encoder import TextEncoder +from .upsample import img_up_block, up_sample diff --git a/src/models/modules/__pycache__/__init__.cpython-39.pyc b/src/models/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4234fd3c86a244e413a270db8c479634d0ffa4e4 Binary files /dev/null and b/src/models/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/acm.cpython-39.pyc b/src/models/modules/__pycache__/acm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cda2e7ff57131dd36503fbfff9a59707e740d337 Binary files /dev/null and b/src/models/modules/__pycache__/acm.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/attention.cpython-39.pyc b/src/models/modules/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a90fda3050c1c1a73932c8f63ff04f38bb61f9b Binary files /dev/null and b/src/models/modules/__pycache__/attention.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/cond_augment.cpython-39.pyc b/src/models/modules/__pycache__/cond_augment.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddf12a697c9b50020587a6e8f54f0539fd9aa79d Binary files /dev/null and b/src/models/modules/__pycache__/cond_augment.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/conv_utils.cpython-39.pyc b/src/models/modules/__pycache__/conv_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f004e91e1bc04f9f0fdb14ddb9446170c702c153 Binary files /dev/null and b/src/models/modules/__pycache__/conv_utils.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/discriminator.cpython-39.pyc b/src/models/modules/__pycache__/discriminator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80f7ff7aca9b677e30e9673647dc8bf61d753615 Binary files /dev/null and b/src/models/modules/__pycache__/discriminator.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/downsample.cpython-39.pyc b/src/models/modules/__pycache__/downsample.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9772d800ad50236cc4e6c825ceae21134ded2c7b Binary files /dev/null and b/src/models/modules/__pycache__/downsample.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/generator.cpython-39.pyc b/src/models/modules/__pycache__/generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89717d465680a3167d205c465b0966ebe67959bf Binary files /dev/null and b/src/models/modules/__pycache__/generator.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/image_encoder.cpython-39.pyc b/src/models/modules/__pycache__/image_encoder.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0d37dbf7aca8d5154007d28534d7011e661e829a Binary files /dev/null and b/src/models/modules/__pycache__/image_encoder.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/residual.cpython-39.pyc b/src/models/modules/__pycache__/residual.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e493073358b90d0a0f28799bb65483a608315721 Binary files /dev/null and b/src/models/modules/__pycache__/residual.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/text_encoder.cpython-39.pyc b/src/models/modules/__pycache__/text_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..696620dea90574c3efad1af7e36358a8b608605d Binary files /dev/null and b/src/models/modules/__pycache__/text_encoder.cpython-39.pyc differ diff --git a/src/models/modules/__pycache__/upsample.cpython-39.pyc b/src/models/modules/__pycache__/upsample.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07f209aa1b0e1075e1e4d6b51c923a1fa956d95e Binary files /dev/null and b/src/models/modules/__pycache__/upsample.cpython-39.pyc differ diff --git a/src/models/modules/acm.py b/src/models/modules/acm.py new file mode 100644 index 0000000000000000000000000000000000000000..9e33110f3cd01ca180cd23f88998e1db4e5b74be --- /dev/null +++ b/src/models/modules/acm.py @@ -0,0 +1,37 @@ +"""ACM and its variations""" + +from typing import Any + +import torch +from torch import nn + +from .conv_utils import conv2d + + +class ACM(nn.Module): + """Affine Combination Module from ManiGAN""" + + def __init__(self, img_chans: int, text_chans: int, inner_dim: int = 64) -> None: + """ + Initialize the convolutional layers + + :param int img_chans: Channels in visual input + :param int text_chans: Channels of textual input + :param int inner_dim: Hyperparameters for inner dimensionality of features + """ + super().__init__() + self.conv = conv2d(in_channels=img_chans, out_channels=inner_dim) + self.weights = conv2d(in_channels=inner_dim, out_channels=text_chans) + self.biases = conv2d(in_channels=inner_dim, out_channels=text_chans) + + def forward(self, text: torch.Tensor, img: torch.Tensor) -> Any: + """ + Propagate the textual and visual input through the ACM module + + :param torch.Tensor text: Textual input (can be hidden features) + :param torch.Tensor img: Image input + :return: Affine combination of text and image + :rtype: torch.Tensor + """ + img_features = self.conv(img) + return text * self.weights(img_features) + self.biases(img_features) diff --git a/src/models/modules/attention.py b/src/models/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f5e990d7967397602b2099544e3e1e63631025 --- /dev/null +++ b/src/models/modules/attention.py @@ -0,0 +1,88 @@ +"""Attention modules""" +from typing import Any, Optional + +import torch +from torch import nn + +from src.models.modules.conv_utils import conv1d + + +class ChannelWiseAttention(nn.Module): + """ChannelWise attention adapted from ControlGAN""" + + def __init__(self, fm_size: int, text_d: int) -> None: + """ + Initialize the Channel-Wise attention module + + :param int fm_size: + Height and width of feature map on k-th iteration of forward-pass. + In paper, it's H_k * W_k + :param int text_d: Dimensionality of sentence. 
From paper, it's D + """ + super().__init__() + # perception layer + self.text_conv = conv1d(text_d, fm_size) + # attention across channel dimension + self.softmax = nn.Softmax(2) + + def forward(self, v_k: torch.Tensor, w_text: torch.Tensor) -> Any: + """ + Apply attention to visual features taking into account features of words + + :param torch.Tensor v_k: Visual context + :param torch.Tensor w_text: Textual features + :return: Fused hidden visual features and word features + :rtype: Any + """ + w_hat = self.text_conv(w_text) + m_k = v_k @ w_hat + a_k = self.softmax(m_k) + w_hat = torch.transpose(w_hat, 1, 2) + return a_k @ w_hat + + +class SpatialAttention(nn.Module): + """Spatial attention module for attending textual context to visual features""" + + def __init__(self, d: int, d_hat: int) -> None: + """ + Set up softmax and conv layers + + :param int d: Initial embedding size for textual features. D from paper + :param int d_hat: Height of image feature map. D_hat from paper + """ + super().__init__() + self.softmax = nn.Softmax(2) + self.conv = conv1d(d, d_hat) + + def forward( + self, + text_context: torch.Tensor, + image: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Any: + """ + Project image features into the latent space + of textual features and apply attention + + :param torch.Tensor text_context: D x T tensor of hidden textual features + :param torch.Tensor image: D_hat x N visual features + :param Optional[torch.Tensor] mask: + Boolean tensor for masking the padded words. BxL + :return: Word features attended by visual features + :rtype: Any + """ + # number of features on image feature map H * W + feature_num = image.size(2) + # number of words in caption + len_caption = text_context.size(2) + text_context = self.conv(text_context) + image = torch.transpose(image, 1, 2) + s_i_j = image @ text_context + if mask is not None: + # duplicating mask and aligning dims with s_i_j + mask = mask.repeat(1, feature_num).view(-1, feature_num, len_caption) + s_i_j[mask] = -float("inf") + b_i_j = self.softmax(s_i_j) + c_i_j = b_i_j @ torch.transpose(text_context, 1, 2) + return torch.transpose(c_i_j, 1, 2) diff --git a/src/models/modules/cond_augment.py b/src/models/modules/cond_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..4bab9d86afda570670760d2f1b8bc2ba96085251 --- /dev/null +++ b/src/models/modules/cond_augment.py @@ -0,0 +1,57 @@ +"""Conditioning Augmentation Module""" + +from typing import Any + +import torch +from torch import nn + + +class CondAugmentation(nn.Module): + """Conditioning Augmentation Module""" + + def __init__(self, D: int, conditioning_dim: int): + """ + :param D: Dimension of the text embedding space [D from AttnGAN paper] + :param conditioning_dim: Dimension of the conditioning space + """ + super().__init__() + self.cond_dim = conditioning_dim + self.cond_augment = nn.Linear(D, conditioning_dim * 4, bias=True) + self.glu = nn.GLU(dim=1) + + def encode(self, text_embedding: torch.Tensor) -> Any: + """ + This function encodes the text embedding into the conditioning space + :param text_embedding: Text embedding + :return: Conditioning embedding + """ + x_tensor = self.glu(self.cond_augment(text_embedding)) + mu_tensor = x_tensor[:, : self.cond_dim] + logvar = x_tensor[:, self.cond_dim :] + return mu_tensor, logvar + + def sample(self, mu_tensor: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + """ + This function samples from the Gaussian distribution + :param mu: Mean of the Gaussian distribution + :param 
logvar: Log variance of the Gaussian distribution + :return: Sample from the Gaussian distribution + """ + std = torch.exp(0.5 * logvar) + eps = torch.randn_like( + std + ) # check if this should add requires_grad = True to this tensor? + return mu_tensor + eps * std + + def forward(self, text_embedding: torch.Tensor) -> Any: + """ + This function encodes the text embedding into the conditioning space, + and samples from the Gaussian distribution. + :param text_embedding: Text embedding + :return c_hat: Conditioning embedding (C^ from StackGAN++ paper) + :return mu: Mean of the Gaussian distribution + :return logvar: Log variance of the Gaussian distribution + """ + mu_tensor, logvar = self.encode(text_embedding) + c_hat = self.sample(mu_tensor, logvar) + return c_hat, mu_tensor, logvar diff --git a/src/models/modules/conv_utils.py b/src/models/modules/conv_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..788b219034597ba94e2a4f51dcf806d36cf7b4e1 --- /dev/null +++ b/src/models/modules/conv_utils.py @@ -0,0 +1,78 @@ +"""Frequently used convolution modules""" + +from torch import nn + +from typing import Tuple + + +def conv2d( + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, +) -> nn.Conv2d: + """ + Template convolution which is typically used throughout the project + + :param int in_channels: Number of input channels + :param int out_channels: Number of output channels + :param int kernel_size: Size of sliding kernel + :param int stride: How many steps kernel does when sliding + :param int padding: How many dimensions to pad + :return: Convolution layer with parameters + :rtype: nn.Conv2d + """ + return nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + +def conv1d( + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, +) -> nn.Conv1d: + """ + Template 1d convolution which is typically used throughout the project + + :param int in_channels: Number of input channels + :param int out_channels: Number of output channels + :param int kernel_size: Size of sliding kernel + :param int stride: How many steps kernel does when sliding + :param int padding: How many dimensions to pad + :return: Convolution layer with parameters + :rtype: nn.Conv2d + """ + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + +def calc_out_conv( + h_in: int, w_in: int, kernel_size: int = 3, stride: int = 1, padding: int = 0 +) -> Tuple[int, int]: + """ + Calculate the dimensionalities of images propagated through conv layers + + :param h_in: Height of the image + :param w_in: Width of the image + :param kernel_size: Size of sliding kernel + :param stride: How many steps kernel does when sliding + :param padding: How many dimensions to pad + :return: Height and width of image through convolution + :rtype: tuple[int, int] + """ + h_out = int((h_in + 2 * padding - kernel_size) / stride + 1) + w_out = int((w_in + 2 * padding - kernel_size) / stride + 1) + return h_out, w_out diff --git a/src/models/modules/discriminator.py b/src/models/modules/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..ebe184675eba85ff4e0a2144a5a333af8a5250db --- /dev/null +++ b/src/models/modules/discriminator.py @@ -0,0 +1,144 @@ +"""Discriminator providing word-level feedback""" +from typing import Any + 
+import torch +from torch import nn + +from src.models.modules.conv_utils import conv1d, conv2d +from src.models.modules.image_encoder import InceptionEncoder + + +class WordLevelLogits(nn.Module): + """API for converting regional feature maps into logits for multi-class classification""" + + def __init__(self) -> None: + """ + Instantiate the module with softmax on channel dimension + """ + super().__init__() + self.softmax = nn.Softmax(dim=1) + # layer for flattening the feature maps + self.flat = nn.Flatten(start_dim=2) + # change dism of of textual embs to correlate with chans of inception + self.chan_reduction = conv1d(256, 128) + + def forward( + self, visual_features: torch.Tensor, word_embs: torch.Tensor, mask: torch.Tensor + ) -> Any: + """ + Fuse two types of features together to get output for feeding into the classification loss + :param torch.Tensor visual_features: + Feature maps of an image after being processed by Inception encoder. Bx128x17x17 + :param torch.Tensor word_embs: + Word-level embeddings from the text encoder Bx256xL + :return: Logits for each word in the picture. BxL + :rtype: Any + """ + # make textual and visual features have the same amount of channels + word_embs = self.chan_reduction(word_embs) + # flattening the feature maps + visual_features = self.flat(visual_features) + word_embs = torch.transpose(word_embs, 1, 2) + word_region_correlations = word_embs @ visual_features + # normalize across L dimension + m_norm_l = nn.functional.normalize(word_region_correlations, dim=1) + # normalize across H*W dimension + m_norm_hw = nn.functional.normalize(m_norm_l, dim=2) + m_norm_hw = torch.transpose(m_norm_hw, 1, 2) + weighted_img_feats = visual_features @ m_norm_hw + weighted_img_feats = torch.sum(weighted_img_feats, dim=1) + weighted_img_feats[mask] = -float("inf") + deltas = self.softmax(weighted_img_feats) + return deltas + + +class UnconditionalLogits(nn.Module): + """Head for retrieving logits from an image""" + + def __init__(self) -> None: + """Initialize modules that reduce the features down to a set of logits""" + super().__init__() + self.conv = nn.Conv2d(128, 1, kernel_size=17) + # flattening BxLx1x1 into Bx1 + self.flat = nn.Flatten() + + def forward(self, visual_features: torch.Tensor) -> Any: + """ + Compute logits for unconditioned adversarial loss + + :param visual_features: Local features from Inception network. Bx128x17x17 + :return: Logits for unconditioned adversarial loss. Bx1 + :rtype: Any + """ + # reduce channels and feature maps for visual features + visual_features = self.conv(visual_features) + # flatten Bx1x1x1 into Bx1 + logits = self.flat(visual_features) + return logits + + +class ConditionalLogits(nn.Module): + """Logits extractor for conditioned adversarial loss""" + + def __init__(self) -> None: + super().__init__() + # layer for forming the feature maps out of textual info + self.text_to_fm = conv1d(256, 17 * 17) + # fitting the size of text channels to the size of visual channels + self.chan_aligner = conv2d(1, 128) + # for reduced textual + visual features down to 1x1 feature map + self.joint_conv = nn.Conv2d(2 * 128, 1, kernel_size=17) + # converting Bx1x1x1 into Bx1 + self.flat = nn.Flatten() + + def forward(self, visual_features: torch.Tensor, sent_embs: torch.Tensor) -> Any: + """ + Compute logits for conditional adversarial loss + + :param torch.Tensor visual_features: Features from Inception encoder. Bx128x17x17 + :param torch.Tensor sent_embs: Sentence embeddings from text encoder. 
Bx256 + :return: Logits for conditional adversarial loss. BxL + :rtype: Any + """ + # make text and visual features have the same sizes of feature maps + # Bx256 -> Bx256x1 -> Bx289x1 + sent_embs = sent_embs.view(-1, 256, 1) + sent_embs = self.text_to_fm(sent_embs) + # transform textual info into shape of visual feature maps + # Bx289x1 -> Bx1x17x17 + sent_embs = sent_embs.view(-1, 1, 17, 17) + # propagate text embs through 1d conv to + # align dims with visual feature maps + sent_embs = self.chan_aligner(sent_embs) + # unite textual and visual features across the dim of channels + cross_features = torch.cat((visual_features, sent_embs), dim=1) + # reduce dims down to length of caption and form raw logits + cross_features = self.joint_conv(cross_features) + # form logits from Bx1x1x1 into Bx1 + logits = self.flat(cross_features) + return logits + + +class Discriminator(nn.Module): + """Simple CNN-based discriminator""" + + def __init__(self) -> None: + """Use a pretrained InceptionNet to extract features""" + super().__init__() + self.encoder = InceptionEncoder(D=128) + # define different logit extractors for different losses + self.logits_word_level = WordLevelLogits() + self.logits_uncond = UnconditionalLogits() + self.logits_cond = ConditionalLogits() + + def forward(self, images: torch.Tensor) -> Any: + """ + Retrieves image features encoded by the image encoder + + :param torch.Tensor images: Images to be analyzed. Bx3x256x256 + :return: image features encoded by image encoder. Bx128x17x17 + """ + # only taking the local features from inception + # Bx3x256x256 -> Bx128x17x17 + img_features, _ = self.encoder(images) + return img_features diff --git a/src/models/modules/downsample.py b/src/models/modules/downsample.py new file mode 100644 index 0000000000000000000000000000000000000000..b34b1be63bbbb695be63c7be744bc7ca74e2b04e --- /dev/null +++ b/src/models/modules/downsample.py @@ -0,0 +1,14 @@ +"""downsample module.""" + +from torch import nn + + +def down_sample(in_planes: int, out_planes: int) -> nn.Module: + """UpSample module.""" + return nn.Sequential( + nn.Conv2d( + in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False + ), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.2, inplace=True), + ) diff --git a/src/models/modules/generator.py b/src/models/modules/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b05d73a7583345b7f67bd6baecc68483105d30 --- /dev/null +++ b/src/models/modules/generator.py @@ -0,0 +1,300 @@ +"""Generator Module""" + +from typing import Any, Optional + +import torch +from torch import nn + +from src.models.modules.acm import ACM +from src.models.modules.attention import ChannelWiseAttention, SpatialAttention +from src.models.modules.cond_augment import CondAugmentation +from src.models.modules.downsample import down_sample +from src.models.modules.residual import ResidualBlock +from src.models.modules.upsample import img_up_block, up_sample + + +class InitStageG(nn.Module): + """Initial Stage Generator Module""" + + # pylint: disable=too-many-instance-attributes + # pylint: disable=too-many-arguments + # pylint: disable=invalid-name + # pylint: disable=too-many-locals + + def __init__( + self, Ng: int, Ng_init: int, conditioning_dim: int, D: int, noise_dim: int + ): + """ + :param Ng: Number of channels. + :param Ng_init: Initial value of Ng, this is output channel of first image upsample. 
+ :param conditioning_dim: Dimension of the conditioning space + :param D: Dimension of the text embedding space [D from AttnGAN paper] + :param noise_dim: Dimension of the noise space + """ + super().__init__() + self.gf_dim = Ng + self.gf_init = Ng_init + self.in_dim = noise_dim + conditioning_dim + D + self.text_dim = D + + self.define_module() + + def define_module(self) -> None: + """Defines FC, Upsample, Residual, ACM, Attention modules""" + nz, ng = self.in_dim, self.gf_dim + self.fully_connect = nn.Sequential( + nn.Linear(nz, ng * 4 * 4 * 2, bias=False), + nn.BatchNorm1d(ng * 4 * 4 * 2), + nn.GLU(dim=1), # we start from 4 x 4 feat_map and return hidden_64. + ) + + self.upsample1 = up_sample(ng, ng // 2) + self.upsample2 = up_sample(ng // 2, ng // 4) + self.upsample3 = up_sample(ng // 4, ng // 8) + self.upsample4 = up_sample( + ng // 8 * 3, ng // 16 + ) # multiply channel by 3 because concat spatial and channel att + + self.residual = self._make_layer(ResidualBlock, ng // 8 * 3) + self.acm_module = ACM(self.gf_init, ng // 8 * 3) + + self.spatial_att = SpatialAttention(self.text_dim, ng // 8) + self.channel_att = ChannelWiseAttention( + 32 * 32, self.text_dim + ) # 32 x 32 is the feature map size + + def _make_layer(self, block: Any, channel_num: int) -> nn.Module: + layers = [] + for _ in range(2): # number of residual blocks hardcoded to 2 + layers.append(block(channel_num)) + return nn.Sequential(*layers) + + def forward( + self, + noise: torch.Tensor, + condition: torch.Tensor, + global_inception: torch.Tensor, + local_upsampled_inception: torch.Tensor, + word_embeddings: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Any: + """ + :param noise: Noise tensor + :param condition: Condition tensor (c^ from stackGAN++ paper) + :param global_inception: Global inception feature + :param local_upsampled_inception: Local inception feature, upsampled to 32 x 32 + :param word_embeddings: Word embeddings [shape: D x L or D x T] + :param mask: Mask for padding tokens + :return: Hidden Image feature map Tensor of 64 x 64 size + """ + noise_concat = torch.cat((noise, condition), 1) + inception_concat = torch.cat((noise_concat, global_inception), 1) + hidden = self.fully_connect(inception_concat) + hidden = hidden.view(-1, self.gf_dim, 4, 4) # convert to 4x4 image feature map + hidden = self.upsample1(hidden) + hidden = self.upsample2(hidden) + hidden_32 = self.upsample3(hidden) # shape: (batch_size, gf_dim // 8, 32, 32) + hidden_32_view = hidden_32.view( + hidden_32.shape[0], -1, hidden_32.shape[2] * hidden_32.shape[3] + ) # this reshaping is done as attention module expects this shape. 
+ + spatial_att_feat = self.spatial_att( + word_embeddings, hidden_32_view, mask + ) # spatial att shape: (batch, D^, 32 * 32) + channel_att_feat = self.channel_att( + spatial_att_feat, word_embeddings + ) # channel att shape: (batch, D^, 32 * 32), or (batch, C, Hk* Wk) from controlGAN paper + spatial_att_feat = spatial_att_feat.view( + word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3] + ) # reshape to (batch, D^, 32, 32) + channel_att_feat = channel_att_feat.view( + word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3] + ) # reshape to (batch, D^, 32, 32) + + spatial_concat = torch.cat( + (hidden_32, spatial_att_feat), 1 + ) # concat spatial attention feature with hidden_32 + attn_concat = torch.cat( + (spatial_concat, channel_att_feat), 1 + ) # concat channel and spatial attention feature + + hidden_32 = self.acm_module(attn_concat, local_upsampled_inception) + hidden_32 = self.residual(hidden_32) + hidden_64 = self.upsample4(hidden_32) + return hidden_64 + + +class NextStageG(nn.Module): + """Next Stage Generator Module""" + + # pylint: disable=too-many-instance-attributes + # pylint: disable=too-many-arguments + # pylint: disable=invalid-name + # pylint: disable=too-many-locals + + def __init__(self, Ng: int, Ng_init: int, D: int, image_size: int): + """ + :param Ng: Number of channels. + :param Ng_init: Initial value of Ng. + :param D: Dimension of the text embedding space [D from AttnGAN paper] + :param image_size: Size of the output image from previous generator stage. + """ + super().__init__() + self.gf_dim = Ng + self.gf_init = Ng_init + self.text_dim = D + self.img_size = image_size + + self.define_module() + + def define_module(self) -> None: + """Defines FC, Upsample, Residual, ACM, Attention modules""" + ng = self.gf_dim + self.spatial_att = SpatialAttention(self.text_dim, ng) + self.channel_att = ChannelWiseAttention( + self.img_size * self.img_size, self.text_dim + ) + + self.residual = self._make_layer(ResidualBlock, ng * 3) + self.upsample = up_sample(ng * 3, ng) + self.acm_module = ACM(self.gf_init, ng * 3) + self.upsample2 = up_sample(ng, ng) + + def _make_layer(self, block: Any, channel_num: int) -> nn.Module: + layers = [] + for _ in range(2): # no of residual layers hardcoded to 2 + layers.append(block(channel_num)) + return nn.Sequential(*layers) + + def forward( + self, + hidden_feat: Any, + word_embeddings: torch.Tensor, + vgg64_feat: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Any: + """ + :param hidden_feat: Hidden feature from previous generator stage [i.e. hidden_64] + :param word_embeddings: Word embeddings + :param vgg64_feat: VGG feature map of size 64 x 64 + :param mask: Mask for the padding tokens + :return: Image feature map of size 256 x 256 + """ + hidden_view = hidden_feat.view( + hidden_feat.shape[0], -1, hidden_feat.shape[2] * hidden_feat.shape[3] + ) # reshape to pass into attention modules. 
+ spatial_att_feat = self.spatial_att( + word_embeddings, hidden_view, mask + ) # spatial att shape: (batch, D^, 64 * 64), or D^ x N + channel_att_feat = self.channel_att( + spatial_att_feat, word_embeddings + ) # channel att shape: (batch, D^, 64 * 64), or (batch, C, Hk* Wk) from controlGAN paper + spatial_att_feat = spatial_att_feat.view( + word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3] + ) # reshape to (batch, D^, 64, 64) + channel_att_feat = channel_att_feat.view( + word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3] + ) # reshape to (batch, D^, 64, 64) + + spatial_concat = torch.cat( + (hidden_feat, spatial_att_feat), 1 + ) # concat spatial attention feature with hidden_64 + attn_concat = torch.cat( + (spatial_concat, channel_att_feat), 1 + ) # concat channel and spatial attention feature + + hidden_64 = self.acm_module(attn_concat, vgg64_feat) + hidden_64 = self.residual(hidden_64) + hidden_128 = self.upsample(hidden_64) + hidden_256 = self.upsample2(hidden_128) + return hidden_256 + + +class GetImageG(nn.Module): + """Generates the Final Fake Image from the Image Feature Map""" + + def __init__(self, Ng: int): + """ + :param Ng: Number of channels. + """ + super().__init__() + self.img = nn.Sequential( + nn.Conv2d(Ng, 3, kernel_size=3, stride=1, padding=1, bias=False), nn.Tanh() + ) + + def forward(self, hidden_feat: torch.Tensor) -> Any: + """ + :param hidden_feat: Image feature map + :return: Final fake image + """ + return self.img(hidden_feat) + + +class Generator(nn.Module): + """Generator Module""" + + # pylint: disable=too-many-instance-attributes + # pylint: disable=too-many-arguments + # pylint: disable=invalid-name + # pylint: disable=too-many-locals + + def __init__(self, Ng: int, D: int, conditioning_dim: int, noise_dim: int): + """ + :param Ng: Number of channels. 
[Taken from StackGAN++ paper] + :param D: Dimension of the text embedding space + :param conditioning_dim: Dimension of the conditioning space + :param noise_dim: Dimension of the noise space + """ + super().__init__() + self.cond_augment = CondAugmentation(D, conditioning_dim) + self.hidden_net1 = InitStageG(Ng * 16, Ng, conditioning_dim, D, noise_dim) + self.inception_img_upsample = img_up_block( + D, Ng + ) # as channel size returned by inception encoder is D (Default in paper: 256) + self.hidden_net2 = NextStageG(Ng, Ng, D, 64) + self.generate_img = GetImageG(Ng) + + self.acm_module = ACM(Ng, Ng) + + self.vgg_downsample = down_sample(D // 2, Ng) + self.upsample1 = up_sample(Ng, Ng) + self.upsample2 = up_sample(Ng, Ng) + + def forward( + self, + noise: torch.Tensor, + sentence_embeddings: torch.Tensor, + word_embeddings: torch.Tensor, + global_inception_feat: torch.Tensor, + local_inception_feat: torch.Tensor, + vgg_feat: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Any: + """ + :param noise: Noise vector [shape: (batch, noise_dim)] + :param sentence_embeddings: Sentence embeddings [shape: (batch, D)] + :param word_embeddings: Word embeddings [shape: D x L, where L is length of sentence] + :param global_inception_feat: Global Inception feature map [shape: (batch, D)] + :param local_inception_feat: Local Inception feature map [shape: (batch, D, 17, 17)] + :param vgg_feat: VGG feature map [shape: (batch, D // 2 = 128, 128, 128)] + :param mask: Mask for the padding tokens + :return: Final fake image + """ + c_hat, mu_tensor, logvar = self.cond_augment(sentence_embeddings) + hidden_32 = self.inception_img_upsample(local_inception_feat) + + hidden_64 = self.hidden_net1( + noise, c_hat, global_inception_feat, hidden_32, word_embeddings, mask + ) + + vgg_64 = self.vgg_downsample(vgg_feat) + + hidden_256 = self.hidden_net2(hidden_64, word_embeddings, vgg_64, mask) + + vgg_128 = self.upsample1(vgg_64) + vgg_256 = self.upsample2(vgg_128) + + hidden_256 = self.acm_module(hidden_256, vgg_256) + fake_img = self.generate_img(hidden_256) + + return fake_img, mu_tensor, logvar diff --git a/src/models/modules/image_encoder.py b/src/models/modules/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7b315f428adec23c832f80efb6b838a4f7808bb3 --- /dev/null +++ b/src/models/modules/image_encoder.py @@ -0,0 +1,138 @@ +"""Image Encoder Module""" +from typing import Any + +import torch +from torch import nn + +from src.models.modules.conv_utils import conv2d + +# build inception v3 image encoder + + +class InceptionEncoder(nn.Module): + """Image Encoder Module adapted from AttnGAN""" + + def __init__(self, D: int): + """ + :param D: Dimension of the text embedding space [D from AttnGAN paper] + """ + super().__init__() + + self.text_emb_dim = D + + model = torch.hub.load( + "pytorch/vision:v0.10.0", "inception_v3", pretrained=True + ) + for param in model.parameters(): + param.requires_grad = False + + self.define_module(model) + self.init_trainable_weights() + + def define_module(self, model: nn.Module) -> None: + """ + This function defines the modules of the image encoder + :param model: Pretrained Inception V3 model + """ + model.cust_upsample = nn.Upsample(size=(299, 299), mode="bilinear") + model.cust_maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) + model.cust_maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) + model.cust_avgpool = nn.AvgPool2d(kernel_size=8) + + attribute_list = [ + "cust_upsample", + "Conv2d_1a_3x3", + "Conv2d_2a_3x3", + "Conv2d_2b_3x3", 
+ "cust_maxpool1", + "Conv2d_3b_1x1", + "Conv2d_4a_3x3", + "cust_maxpool2", + "Mixed_5b", + "Mixed_5c", + "Mixed_5d", + "Mixed_6a", + "Mixed_6b", + "Mixed_6c", + "Mixed_6d", + "Mixed_6e", + ] + + self.feature_extractor = nn.Sequential( + *[getattr(model, name) for name in attribute_list] + ) + + attribute_list2 = ["Mixed_7a", "Mixed_7b", "Mixed_7c", "cust_avgpool"] + + self.feature_extractor2 = nn.Sequential( + *[getattr(model, name) for name in attribute_list2] + ) + + self.emb_features = conv2d( + 768, self.text_emb_dim, kernel_size=1, stride=1, padding=0 + ) + self.emb_cnn_code = nn.Linear(2048, self.text_emb_dim) + + def init_trainable_weights(self) -> None: + """ + This function initializes the trainable weights of the image encoder + """ + initrange = 0.1 + self.emb_features.weight.data.uniform_(-initrange, initrange) + self.emb_cnn_code.weight.data.uniform_(-initrange, initrange) + + def forward(self, image_tensor: torch.Tensor) -> Any: + """ + :param image_tensor: Input image + :return: features: local feature matrix (v from attnGAN paper) [shape: (batch, D, 17, 17)] + :return: cnn_code: global image feature (v^ from attnGAN paper) [shape: (batch, D)] + """ + # this is the image size + # x.shape: 10 3 256 256 + + features = self.feature_extractor(image_tensor) + # 17 x 17 x 768 + + image_tensor = self.feature_extractor2(features) + + image_tensor = image_tensor.view(image_tensor.size(0), -1) + # 2048 + + # global image features + cnn_code = self.emb_cnn_code(image_tensor) + + if features is not None: + features = self.emb_features(features) + + # feature.shape: 10 256 17 17 + # cnn_code.shape: 10 256 + return features, cnn_code + + +class VGGEncoder(nn.Module): + """Pre Trained VGG Encoder Module""" + + def __init__(self) -> None: + """ + Initialize pre-trained VGG model with frozen parameters + """ + super().__init__() + self.select = "8" ## We want to get the output of the 8th layer in VGG. 
+
+        self.model = torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True)
+
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        self.vgg_modules = self.model.features._modules
+
+    def forward(self, image_tensor: torch.Tensor) -> Any:
+        """
+        :param image_tensor: Input image tensor [shape: (batch, 3, 256, 256)]
+        :return: VGG features [shape: (batch, 128, 128, 128)]
+        """
+        for name, layer in self.vgg_modules.items():
+            image_tensor = layer(image_tensor)
+            if name == self.select:
+                return image_tensor
+        return None
diff --git a/src/models/modules/residual.py b/src/models/modules/residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aa9340e07c26f981b7f71376a544054875680b3
--- /dev/null
+++ b/src/models/modules/residual.py
@@ -0,0 +1,42 @@
+"""Residual Block Adopted from ManiGAN"""
+
+from typing import Any
+
+import torch
+from torch import nn
+
+
+class ResidualBlock(nn.Module):
+    """Residual Block"""
+
+    def __init__(self, channel_num: int) -> None:
+        """
+        :param channel_num: Number of channels in the input
+        """
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(
+                channel_num,
+                channel_num * 2,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
+            nn.InstanceNorm2d(channel_num * 2),
+            nn.GLU(dim=1),
+            nn.Conv2d(
+                channel_num, channel_num, kernel_size=3, stride=1, padding=1, bias=False
+            ),
+            nn.InstanceNorm2d(channel_num),
+        )
+
+    def forward(self, input_tensor: torch.Tensor) -> Any:
+        """
+        :param input_tensor: Input tensor
+        :return: Output tensor
+        """
+        residual = input_tensor
+        out = self.block(input_tensor)
+        out += residual
+        return out
diff --git a/src/models/modules/text_encoder.py b/src/models/modules/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9445847ff866c88e34ccdd03c639d89a296af74
--- /dev/null
+++ b/src/models/modules/text_encoder.py
@@ -0,0 +1,39 @@
+"""LSTM-based textual encoder for tokenized input"""
+
+from typing import Any
+
+import torch
+from torch import nn
+
+
+class TextEncoder(nn.Module):
+    """Simple text encoder based on RNN"""
+
+    def __init__(self, vocab_size: int, emb_dim: int, hidden_dim: int) -> None:
+        """
+        Initialize embeddings lookup for tokens and main LSTM
+
+        :param vocab_size:
+            Size of created vocabulary for textual input. L from paper
+        :param emb_dim: Length of embeddings for each word.
+        :param hidden_dim:
+            Length of hidden state of a LSTM cell.
+            2 x hidden_dim = C (from LWGAN paper)
+        """
+        super().__init__()
+        self.embs = nn.Embedding(vocab_size, emb_dim)
+        self.lstm = nn.LSTM(emb_dim, hidden_dim, bidirectional=True, batch_first=True)
+
+    def forward(self, tokens: torch.Tensor) -> Any:
+        """
+        Propagate the text token input through the LSTM and return
+        two types of embeddings: word-level and sentence-level
+
+        :param torch.Tensor tokens: Input text tokens from vocab
+        :return: Word-level embeddings (BxCxL) and sentence-level embeddings (BxC)
+        :rtype: Any
+        """
+        embs = self.embs(tokens)
+        output, (hidden_states, _) = self.lstm(embs)
+        word_embs = torch.transpose(output, 1, 2)
+        sent_embs = torch.cat((hidden_states[-1, :, :], hidden_states[0, :, :]), dim=1)
+        return word_embs, sent_embs
diff --git a/src/models/modules/upsample.py b/src/models/modules/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f38b7418c031b513b9396d98e8c1fb023efdd6e
--- /dev/null
+++ b/src/models/modules/upsample.py
@@ -0,0 +1,30 @@
+"""UpSample module."""
+
+from torch import nn
+
+
+def up_sample(in_planes: int, out_planes: int) -> nn.Module:
+    """UpSample module."""
+    return nn.Sequential(
+        nn.Upsample(scale_factor=2, mode="nearest"),
+        nn.Conv2d(
+            in_planes, out_planes * 2, kernel_size=3, stride=1, padding=1, bias=False
+        ),
+        nn.InstanceNorm2d(out_planes * 2),
+        nn.GLU(dim=1),
+    )
+
+
+def img_up_block(in_planes: int, out_planes: int) -> nn.Module:
+    """
+    Image upsample block.
+    Mainly used to convert the 17 x 17 local feature map from Inception to 32 x 32 size.
+    """
+    return nn.Sequential(
+        nn.Upsample(scale_factor=1.9, mode="nearest"),
+        nn.Conv2d(
+            in_planes, out_planes * 2, kernel_size=3, stride=1, padding=1, bias=False
+        ),
+        nn.InstanceNorm2d(out_planes * 2),
+        nn.GLU(dim=1),
+    )
diff --git a/src/models/predict_model.py b/src/models/predict_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/train_model.py b/src/models/train_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e27960160f4fa4cbc29f7b25558ccfd7860ad278
--- /dev/null
+++ b/src/models/train_model.py
@@ -0,0 +1,237 @@
+"""Module to train the GAN model"""
+
+from typing import Any, Dict
+
+import torch
+
+from src.models.losses import discriminator_loss, generator_loss, kl_loss
+from src.models.modules.discriminator import Discriminator
+from src.models.modules.generator import Generator
+from src.models.modules.image_encoder import InceptionEncoder, VGGEncoder
+from src.models.modules.text_encoder import TextEncoder
+from src.models.utils import (
+    define_optimizers,
+    load_model,
+    prepare_labels,
+    save_image_and_caption,
+    save_model,
+    save_plot,
+)
+
+# pylint: disable=too-many-locals
+# pylint: disable=too-many-statements
+
+
+def train(data_loader: Any, config_dict: Dict[str, Any]) -> None:
+    """
+    Function to train the GAN model
+    :param data_loader: Data loader for the dataset
+    :param config_dict: Dictionary containing the configuration parameters
+    """
+    (
+        Ng,  # pylint: disable=invalid-name
+        D,  # pylint: disable=invalid-name
+        condition_dim,
+        noise_dim,
+        lr_config,
+        batch_size,
+        device,
+        epochs,
+        vocab_len,
+        ix2word,
+        output_dir,
+        snapshot,
+        const_dict,
+    ) = (
+        config_dict["Ng"],
+        config_dict["D"],
+        config_dict["condition_dim"],
+        config_dict["noise_dim"],
+        config_dict["lr_config"],
+        config_dict["batch_size"],
+        config_dict["device"],
+        config_dict["epochs"],
+ config_dict["vocab_len"], + config_dict["ix2word"], + config_dict["output_dir"], + config_dict["snapshot"], + config_dict["const_dict"], + ) + + generator = Generator(Ng, D, condition_dim, noise_dim).to(device) + discriminator = Discriminator().to(device) + text_encoder = TextEncoder(vocab_len, D, D // 2).to(device) + image_encoder = InceptionEncoder(D).to(device) + vgg_encoder = VGGEncoder().to(device) + gen_loss = [] + disc_loss = [] + + load_model(generator, discriminator, image_encoder, text_encoder, output_dir) + + ( + optimizer_g, + optimizer_d, + optimizer_text_encoder, + opt_image_encoder, + ) = define_optimizers( + generator, discriminator, image_encoder, text_encoder, lr_config + ) + + for epoch in range(1, epochs + 1): + for batch_idx, ( + images, + correct_capt, + correct_capt_len, + curr_class, + word_labels, + ) in enumerate(data_loader): + + labels_real, labels_fake, labels_match, fake_word_labels = prepare_labels( + batch_size, word_labels.size(1), device + ) + + optimizer_d.zero_grad() + optimizer_text_encoder.zero_grad() + + noise = torch.randn(batch_size, noise_dim).to(device) + word_emb, sent_emb = text_encoder(correct_capt) + + local_incept_feat, global_incept_feat = image_encoder(images) + + vgg_feat = vgg_encoder(images) + mask = correct_capt == 0 + + # Generate Fake Images + fake_imgs, mu_tensor, logvar = generator( + noise, + sent_emb, + word_emb, + global_incept_feat, + local_incept_feat, + vgg_feat, + mask, + ) + + # Generate Logits for discriminator update + real_discri_feat = discriminator(images) + fake_discri_feat = discriminator(fake_imgs.detach()) + + logits_discri = { + "fake": { + "uncond": discriminator.logits_uncond(fake_discri_feat), + "cond": discriminator.logits_cond(fake_discri_feat, sent_emb), + }, + "real": { + "word_level": discriminator.logits_word_level( + real_discri_feat, word_emb, mask + ), + "uncond": discriminator.logits_uncond(real_discri_feat), + "cond": discriminator.logits_cond(real_discri_feat, sent_emb), + }, + } + + labels_discri = { + "fake": {"word_level": fake_word_labels, "image": labels_fake}, + "real": {"word_level": word_labels, "image": labels_real}, + } + + # Update Discriminator + + loss_discri = discriminator_loss(logits_discri, labels_discri) + + loss_discri.backward(retain_graph=True) + optimizer_d.step() + optimizer_text_encoder.step() + + disc_loss.append(loss_discri.item()) + + optimizer_g.zero_grad() + opt_image_encoder.zero_grad() + + word_emb, sent_emb = text_encoder(correct_capt) + + fake_imgs, mu_tensor, logvar = generator( + noise, + sent_emb, + word_emb, + global_incept_feat, + local_incept_feat, + vgg_feat, + mask, + ) + + local_fake_incept_feat, global_fake_incept_feat = image_encoder(fake_imgs) + + vgg_feat_fake = vgg_encoder(fake_imgs) + + fake_feat_d = discriminator(fake_imgs) + + logits_gen = { + "fake": { + "uncond": discriminator.logits_uncond(fake_feat_d), + "cond": discriminator.logits_cond(fake_feat_d, sent_emb), + } + } + + # Update Generator + loss_gen = generator_loss( + logits_gen, + local_fake_incept_feat, + global_fake_incept_feat, + labels_real, + word_emb, + sent_emb, + labels_match, + correct_capt_len, + curr_class, + vgg_feat, + vgg_feat_fake, + const_dict, + ) + + loss_kl = kl_loss(mu_tensor, logvar) + + loss_gen += loss_kl + + loss_gen.backward() + optimizer_g.step() + opt_image_encoder.step() + gen_loss.append(loss_gen.item()) + + if (batch_idx + 1) % 20 == 0: + print( + f"Epoch [{epoch}/{epochs}], Batch [{batch_idx + 1}/{len(data_loader)}],\ + Loss D: {loss_discri.item():.4f}, Loss 
G: {loss_gen.item():.4f}" + ) + + if (batch_idx + 1) % 50 == 0: + with torch.no_grad(): + fake_imgs_act, _, _ = generator( + noise, + sent_emb, + word_emb, + global_incept_feat, + local_incept_feat, + vgg_feat, + mask, + ) + save_image_and_caption( + fake_imgs_act, + images, + correct_capt, + ix2word, + batch_idx, + epoch, + output_dir, + ) + save_plot(gen_loss, disc_loss, epoch, batch_idx, output_dir) + + if epoch % snapshot == 0 and epoch != 0: + save_model( + generator, discriminator, image_encoder, text_encoder, epoch, output_dir + ) + + save_model( + generator, discriminator, image_encoder, text_encoder, epochs, output_dir + ) diff --git a/src/models/utils.py b/src/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41c58aa74447fca5269d8112a1bdf3dfe875becb --- /dev/null +++ b/src/models/utils.py @@ -0,0 +1,276 @@ +"""Helper functions for models.""" + +import pathlib +import pickle +from copy import deepcopy +from pathlib import Path +from typing import Any, List, Dict + +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image +from torch import optim + +from src.models.modules.discriminator import Discriminator +from src.models.modules.generator import Generator +from src.models.modules.image_encoder import InceptionEncoder +from src.models.modules.text_encoder import TextEncoder + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-locals + + +def copy_gen_params(generator: Generator) -> Any: + """ + Function to copy the parameters of the generator + """ + params = deepcopy(list(p.data for p in generator.parameters())) + return params + + +def define_optimizers( + generator: Generator, + discriminator: Discriminator, + image_encoder: InceptionEncoder, + text_encoder: TextEncoder, + lr_config: Dict[str, float], +) -> Any: + """ + Function to define the optimizers for the generator and discriminator + :param generator: Generator model + :param image_encoder: Image encoder model + :param text_encoder: Text encoder model + :param discriminator: Discriminator model + :param lr_config: Dictionary containing the learning rates for the optimizers + + """ + img_encoder_lr = lr_config["img_encoder_lr"] + text_encoder_lr = lr_config["text_encoder_lr"] + gen_lr = lr_config["gen_lr"] + disc_lr = lr_config["disc_lr"] + + optimizer_g = optim.Adam( + [{"params": generator.parameters()}], + lr=gen_lr, + betas=(0.5, 0.999), + ) + optimizer_d = optim.Adam( + [{"params": discriminator.parameters()}], + lr=disc_lr, + betas=(0.5, 0.999), + ) + optimizer_text_encoder = optim.Adam(text_encoder.parameters(), lr=text_encoder_lr) + optimizer_image_encoder = optim.Adam(image_encoder.parameters(), lr=img_encoder_lr) + + return optimizer_g, optimizer_d, optimizer_text_encoder, optimizer_image_encoder + + +def prepare_labels(batch_size: int, max_seq_len: int, device: torch.device) -> Any: + """ + Function to prepare the labels for the discriminator and generator. 
+ """ + real_labels = torch.FloatTensor(batch_size, 1).fill_(1).to(device) + fake_labels = torch.FloatTensor(batch_size, 1).fill_(0).to(device) + match_labels = torch.LongTensor(range(batch_size)).to(device) + fake_word_labels = torch.FloatTensor(batch_size, max_seq_len).fill_(0).to(device) + + return real_labels, fake_labels, match_labels, fake_word_labels + + +def load_params(generator: Generator, new_params: Any) -> Any: + """ + Function to load new parameters to the generator + """ + for param, new_p in zip(generator.parameters(), new_params): + param.data.copy_(new_p) + + +def get_image_arr(image_tensor: torch.Tensor) -> Any: + """ + Function to convert a tensor to an image array. + :param image_tensor: Tensor containing the image (shape: (batch_size, channels, height, width)) + """ + + image = image_tensor.cpu().detach().numpy() + image = (image + 1) * (255 / 2.0) + image = np.transpose(image, (0, 2, 3, 1)) # (B,C,H,W) -> (B,H,W,C) + image = image.astype(np.uint8) + return image # (B,H,W,C) + + +def get_captions(captions: torch.Tensor, ix2word: Dict[int, str]) -> Any: + """ + Function to convert a tensor to a list of captions. + :param captions: Tensor containing the captions (shape: (batch_size, max_seq_len)) + :param ix2word: Dictionary mapping indices to words + """ + captions = captions.cpu().detach().numpy() + captions = [[ix2word[ix] for ix in cap if ix != 0] for cap in captions] # type: ignore + return captions + + +def save_model( + generator: Generator, + discriminator: Discriminator, + image_encoder: InceptionEncoder, + text_encoder: TextEncoder, + epoch: int, + output_dir: pathlib.PosixPath, +) -> None: + """ + Function to save the model. + :param generator: Generator model + :param discriminator: Discriminator model + :param image_encoder: Image encoder model + :param text_encoder: Text encoder model + :param params: Parameters of the generator + :param epoch: Epoch number + :param output_dir: Output directory + """ + output_path = output_dir / "weights/" + Path(output_path / "generator").mkdir(parents=True, exist_ok=True) + torch.save( + generator.state_dict(), output_path / f"generator/generator_epoch_{epoch}.pth" + ) + Path(output_path / "discriminator").mkdir(parents=True, exist_ok=True) + torch.save( + discriminator.state_dict(), + output_path / f"discriminator/discriminator_epoch_{epoch}.pth", + ) + Path(output_path / "image_encoder").mkdir(parents=True, exist_ok=True) + torch.save( + image_encoder.state_dict(), + output_path / f"image_encoder/image_encoder_epoch_{epoch}.pth", + ) + Path(output_path / "text_encoder").mkdir(parents=True, exist_ok=True) + torch.save( + text_encoder.state_dict(), + output_path / f"text_encoder/text_encoder_epoch_{epoch}.pth", + ) + print(f"Model saved at epoch {epoch}.") + + +def save_image_and_caption( + fake_img_tensor: torch.Tensor, + img_tensor: torch.Tensor, + captions: torch.Tensor, + ix2word: Dict[int, str], + batch_idx: int, + epoch: int, + output_dir: pathlib.PosixPath, +) -> None: + """ + Function to save an image and its corresponding caption. 
+ :param fake_img_tensor: Tensor containing the generated image + (shape: (batch_size, channels, height, width)) + + :param img_tensor: Tensor containing the image + (shape: (batch_size, channels, height, width)) + + :param captions: Tensor containing the captions + (shape: (batch_size, max_seq_len)) + + :param ix2word: Dictionary mapping indices to words + :param batch_idx: Batch index + :param epoch: Epoch number + :param output_dir: Output directory + """ + output_path = output_dir + output_path_text = output_dir + capt_list = get_captions(captions, ix2word) + img_arr = get_image_arr(img_tensor) + fake_img_arr = get_image_arr(fake_img_tensor) + for i in range(img_arr.shape[0]): + img = Image.fromarray(img_arr[i]) + fake_img = Image.fromarray(fake_img_arr[i]) + + fake_img_path = ( + output_path / f"generated/{epoch}_epochs/{batch_idx}_batch/{i+1}.png" + ) + img_path = output_path / f"real/{epoch}_epochs/{batch_idx}_batch/{i+1}.png" + text_path = ( + output_path_text / f"text/{epoch}_epochs/{batch_idx}_batch/captions.txt" + ) + + Path(fake_img_path).parent.mkdir(parents=True, exist_ok=True) + Path(img_path).parent.mkdir(parents=True, exist_ok=True) + Path(text_path).parent.mkdir(parents=True, exist_ok=True) + + fake_img.save(fake_img_path) + img.save(img_path) + + with open(text_path, "a", encoding="utf-8") as txt_file: + text_str = str(i + 1) + ": " + " ".join(capt_list[i]) + txt_file.write(text_str) + txt_file.write("\n") + + +def save_plot( + gen_loss: List[float], + disc_loss: List[float], + epoch: int, + batch_idx: int, + output_dir: pathlib.PosixPath, +) -> None: + """ + Function to save the plot of the loss. + :param gen_loss: List of generator losses + :param disc_loss: List of discriminator losses + :param epoch: Epoch number + :param batch_idx: Batch index + :param output_dir: Output directory + """ + pickle_path = output_dir / "losses/" + output_path = output_dir / "plots" / f"{epoch}_epochs/{batch_idx}_batch/" + Path(output_path).mkdir(parents=True, exist_ok=True) + Path(pickle_path).mkdir(parents=True, exist_ok=True) + + with open(pickle_path / "gen_loss.pkl", "wb") as pickl_file: + pickle.dump(gen_loss, pickl_file) + + with open(pickle_path / "disc_loss.pkl", "wb") as pickl_file: + pickle.dump(disc_loss, pickl_file) + + plt.style.use("fivethirtyeight") + plt.figure(figsize=(24, 12)) + plt.plot(gen_loss, label="Generator Loss") + plt.plot(disc_loss, label="Discriminator Loss") + plt.xlabel("No of Iterations") + plt.ylabel("Loss") + plt.legend() + plt.savefig(output_path / "loss.png", bbox_inches="tight") + plt.clf() + plt.close() + + +def load_model( + generator: Generator, + discriminator: Discriminator, + image_encoder: InceptionEncoder, + text_encoder: TextEncoder, + output_dir: pathlib.Path, + device: torch.device +) -> None: + """ + Function to load the model. 
+ :param generator: Generator model + :param discriminator: Discriminator model + :param image_encoder: Image encoder model + :param text_encoder: Text encoder model + :param output_dir: Output directory + :param device: device to map the location of weights + """ + if (output_dir / "generator.pth").exists(): + generator.load_state_dict(torch.load(output_dir / "generator.pth", map_location=device)) + print("Generator loaded.") + if (output_dir / "discriminator.pth").exists(): + discriminator.load_state_dict(torch.load(output_dir / "discriminator.pth", map_location=device)) + print("Discriminator loaded.") + if (output_dir / "image_encoder.pth").exists(): + image_encoder.load_state_dict(torch.load(output_dir / "image_encoder.pth", map_location=device)) + print("Image Encoder loaded.") + + if (output_dir / "text_encoder.pth").exists(): + text_encoder.load_state_dict(torch.load(output_dir / "text_encoder.pth", map_location=device)) + print("Text Encoder loaded.") diff --git a/src/test_project/__init__.py b/src/test_project/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c98d66d286846041b80647bebc9c7c82b5bc7bcf --- /dev/null +++ b/src/test_project/__init__.py @@ -0,0 +1,2 @@ +"""test project imports""" +from src.test_project.example import Foo diff --git a/src/test_project/__pycache__/__init__.cpython-39.pyc b/src/test_project/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44d4828871b9c0d5924da0fc43bd7ca3c658cb3c Binary files /dev/null and b/src/test_project/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/test_project/__pycache__/example.cpython-39.pyc b/src/test_project/__pycache__/example.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3c1ab1a99eb7c965b3ca6b09ca27bcc8678df3d Binary files /dev/null and b/src/test_project/__pycache__/example.cpython-39.pyc differ diff --git a/src/test_project/example.py b/src/test_project/example.py new file mode 100644 index 0000000000000000000000000000000000000000..073fbb6ff1f6d54a671927d7e61d93f6e0ba7417 --- /dev/null +++ b/src/test_project/example.py @@ -0,0 +1,18 @@ +"""doing some stuff here""" + + +class Foo: + """sample text""" + + def __init__(self, first_var: int, second_var: int) -> None: + """init the bar""" + self.first = first_var + self.second = second_var + + def get_bar(self) -> int: + """return bar""" + return self.first + + def get_foo(self) -> int: + """return bar""" + return self.second diff --git a/src/visualization/.gitkeep b/src/visualization/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/visualization/__init__.py b/src/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/visualization/visualize.py b/src/visualization/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/weights/bird/captions.pickle b/weights/bird/captions.pickle new file mode 100644 index 0000000000000000000000000000000000000000..c1e2c803d643e6461437924e58c17ed674399fa2 --- /dev/null +++ b/weights/bird/captions.pickle @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62945b1308a6b25d37bc16383e4403813dd7f6b51053bc2dd788fc71e92b52ec +size 4823605 diff --git a/weights/bird/generator.pth b/weights/bird/generator.pth new file mode 100644 index 
0000000000000000000000000000000000000000..8b889e8e45d9816f01fd92d012543d2b81eed6fb --- /dev/null +++ b/weights/bird/generator.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5200ebbde8e7e73ecf4186f6101d0c9264e0e22358526120f3f534c9db8e538 +size 61722743 diff --git a/weights/bird/image_encoder.pth b/weights/bird/image_encoder.pth new file mode 100644 index 0000000000000000000000000000000000000000..6a839ab344b698fec62d5171f323138c1deebd7b --- /dev/null +++ b/weights/bird/image_encoder.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:707ee0e20699afc4d36bfab2ce397cf5e21850ed1b2ebf814a02fa1a24a36c86 +size 90371601 diff --git a/weights/bird/text_encoder.pth b/weights/bird/text_encoder.pth new file mode 100644 index 0000000000000000000000000000000000000000..c4b637169dc1495933da94551e74f747783e2801 --- /dev/null +++ b/weights/bird/text_encoder.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f716f37e3be768d53acdbd2bf8bd8783e83402fa84e3e33b84fb6b6decf2e81 +size 7163559 diff --git a/weights/coco/captions.pickle b/weights/coco/captions.pickle new file mode 100644 index 0000000000000000000000000000000000000000..622e32768bea1e1e886599555f2b9c4dce8bcad5 --- /dev/null +++ b/weights/coco/captions.pickle @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f32aa550147910b7e43c72b60864d894a3a379cbde370560ae49c79f148fbf7 +size 25016959 diff --git a/weights/coco/generator.pth b/weights/coco/generator.pth new file mode 100644 index 0000000000000000000000000000000000000000..e3dc6164eff9a265995bc8e902f3b4bff1c984a0 --- /dev/null +++ b/weights/coco/generator.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a86d0572f145b9c31d4e277c095f6a4b367a6080e37b56f1ecf560d99f7b4db +size 61722743 diff --git a/weights/coco/image_encoder.pth b/weights/coco/image_encoder.pth new file mode 100644 index 0000000000000000000000000000000000000000..ef13068e1c73aca23b730cddf533f31b2452c4a4 --- /dev/null +++ b/weights/coco/image_encoder.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c015411362fe7d55bd601e0f821b1976f495de7108eaeeb58f7713ef51c53dd +size 90371601 diff --git a/weights/coco/text_encoder.pth b/weights/coco/text_encoder.pth new file mode 100644 index 0000000000000000000000000000000000000000..6348d9bcb439fc1a27fc55f8de18200df9e12074 --- /dev/null +++ b/weights/coco/text_encoder.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6846ad7fdc59a855b636396f2bec544a245b19826325959413751122bde562e8 +size 29534887 diff --git a/weights/utkface/captions.pickle b/weights/utkface/captions.pickle new file mode 100644 index 0000000000000000000000000000000000000000..942dbf7e7de7481bab1e2cca2833fd0ccaa21cf9 --- /dev/null +++ b/weights/utkface/captions.pickle @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69a017da095502591cedeb9077bff65733d6e3874f91b440627f204e22422cab +size 276803 diff --git a/weights/utkface/generator.pth b/weights/utkface/generator.pth new file mode 100644 index 0000000000000000000000000000000000000000..726da77cfef56e14484897e3face79c0d52f364b --- /dev/null +++ b/weights/utkface/generator.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab4e8440174051fbfd0b0875c2d9752a37439e62ce1206592aacd136445ee377 +size 61722743 diff --git a/weights/utkface/image_encoder.pth b/weights/utkface/image_encoder.pth new file mode 100644 index 
0000000000000000000000000000000000000000..a759a0c74d320bd04ffbc93ca881ce6669d7c6b6 --- /dev/null +++ b/weights/utkface/image_encoder.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bcbad5041863652d0e4141baa478133ff014682a63ae05a2c74ad08bcd9ee54 +size 90371601 diff --git a/weights/utkface/text_encoder.pth b/weights/utkface/text_encoder.pth new file mode 100644 index 0000000000000000000000000000000000000000..e8dcb9e0b91fa286f1986de97d69a962b33987b6 --- /dev/null +++ b/weights/utkface/text_encoder.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9debe6414c6b17e0e36bbe5125ad4a5cee5e7b086ccb82e05b0df8ec56a2762e +size 2066087
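
Note (not part of the diff above): the modules introduced in this patch compose into a single text-driven image-manipulation pass, and the sketch below is a minimal smoke test of that wiring. It is written under explicit assumptions rather than taken from the repository: the src.models.modules package layout from this diff, a pad-token id of 0 (as used by the training loop), and illustrative toy sizes (batch of 2, 100-word vocabulary, 12-token captions, D = 256 embedding channels, Ng = 32 generator channels, 100-dimensional noise and conditioning vectors). Both image encoders download pretrained backbones through torch.hub on first use.

import torch

from src.models.modules.discriminator import Discriminator
from src.models.modules.generator import Generator
from src.models.modules.image_encoder import InceptionEncoder, VGGEncoder
from src.models.modules.text_encoder import TextEncoder

# illustrative sizes, assumed for this sketch (not taken from the training config)
B, L, VOCAB, D, NG, COND_DIM, NOISE_DIM = 2, 12, 100, 256, 32, 100, 100

text_encoder = TextEncoder(vocab_size=VOCAB, emb_dim=D, hidden_dim=D // 2).eval()
image_encoder = InceptionEncoder(D=D).eval()  # local (B, D, 17, 17) and global (B, D) features
vgg_encoder = VGGEncoder().eval()             # (B, 128, 128, 128) features from VGG16 layer "8"
generator = Generator(Ng=NG, D=D, conditioning_dim=COND_DIM, noise_dim=NOISE_DIM).eval()
discriminator = Discriminator().eval()

with torch.no_grad():
    images = torch.rand(B, 3, 256, 256)
    tokens = torch.randint(1, VOCAB, (B, L))
    mask = tokens == 0  # pad id assumed to be 0, as in train_model.py

    word_embs, sent_embs = text_encoder(tokens)      # (B, D, L), (B, D)
    local_feat, global_feat = image_encoder(images)  # (B, D, 17, 17), (B, D)
    vgg_feat = vgg_encoder(images)                   # (B, 128, 128, 128)
    noise = torch.rand(B, NOISE_DIM)

    # generator fuses noise, text and both visual feature maps into a new image
    fake, _, _ = generator(
        noise, sent_embs, word_embs, global_feat, local_feat, vgg_feat, mask
    )

    # the three discriminator heads consume the same encoded feature map
    disc_feat = discriminator(fake)                                    # (B, 128, 17, 17)
    uncond = discriminator.logits_uncond(disc_feat)                    # (B, 1)
    cond = discriminator.logits_cond(disc_feat, sent_embs)             # (B, 1)
    word_level = discriminator.logits_word_level(disc_feat, word_embs, mask)  # (B, L)
    print(fake.shape, uncond.shape, cond.shape, word_level.shape)

If the wiring matches the shapes documented in the module docstrings, the printed sizes should come out as (B, 3, 256, 256), (B, 1), (B, 1) and (B, L); running this after a refactor is a quick way to catch shape regressions without loading any trained weights.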