################################################## PACKAGES #############################################################
# PyTorch for deep learning operations
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing
# OpenCV for image processing
import cv2
# Transfer learning model library (Vision Transformer backbones)
import timm
# HTTP requests for loading online images
import requests
# COCO dataset tools
from pycocotools.coco import COCO
# Numerical and tabular data handling
import numpy as np
import pandas as pd
# Hugging Face Transformers library for BERT models
from transformers import BertModel, BertTokenizer, DistilBertModel, DistilBertConfig, DistilBertTokenizer
# Image processing and augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Visualization and progress tracking
from tqdm import tqdm
import matplotlib.pyplot as plt
# Utility for iterating over combinations
import itertools
# Project configuration and Hugging Face Hub integration
from configs import CFG
from huggingface_hub import PyTorchModelHubMixin
################################################### MODELS ##############################################################
class ProjectionHead(nn.Module):
    def __init__(self, input_dim, projection_dim=CFG.projection_dim, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        """
        Projection head module for contrastive learning.
        :param input_dim: Dimensionality of input features.
        :param projection_dim: Dimensionality of projected features (default: CFG.projection_dim).
        :param dropout_rate: Dropout rate (default: CFG.dropout_rate).
        """
        super(ProjectionHead, self).__init__(*args, **kwargs)
        # Attributes
        self.input_dim = input_dim
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        # Layers
        self.linear_layer1 = nn.Linear(self.input_dim, self.projection_dim)
        self.gelu = nn.GELU()
        self.linear_layer2 = nn.Linear(self.projection_dim, self.projection_dim)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.normalization_layer = nn.LayerNorm(self.projection_dim)

    def forward(self, inputs):
        """
        Forward pass of the projection head.
        :param inputs: Input features.
        :return: Projected features.
        """
        x = inputs
        x = self.linear_layer1(x)
        x = self.gelu(x)
        x = self.linear_layer2(x)
        x = self.dropout(x)
        x = self.normalization_layer(x)
        return x

    def __call__(self, inputs):
        """
        Callable method for the projection head.
        :param inputs: Input features.
        :return: Projected features.
        """
        return self.forward(inputs)
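
# Minimal usage sketch for ProjectionHead (assumes CFG.projection_dim = 256; the
# batch/sequence shapes below are illustrative, not taken from the original code):
#   head = ProjectionHead(input_dim=768)
#   features = torch.randn(8, 197, 768)   # e.g. ViT patch embeddings
#   projected = head(features)            # -> torch.Size([8, 197, 256])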
class ImageEncoder(nn.Module):
    def __init__(self, model_name=CFG.vit_name, projection_dim=CFG.projection_dim, trainable=False,
                 dropout_rate=CFG.dropout_rate, *args, **kwargs):
        """
        Image encoder module using a Vision Transformer (ViT) backbone.
        :param model_name: Name of the Vision Transformer model (default: CFG.vit_name).
        :param projection_dim: Dimensionality of projected features (default: CFG.projection_dim).
        :param trainable: Whether to make the backbone trainable (default: False).
        :param dropout_rate: Dropout rate (default: CFG.dropout_rate).
        """
        super(ImageEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.trainable = trainable
        self.dropout_rate = dropout_rate
        # Models
        self.pretrained_vit = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        self.projection_head = ProjectionHead(self.pretrained_vit.embed_dim, self.projection_dim, self.dropout_rate)
        # Freeze (or unfreeze) the pretrained ViT layers according to `trainable`
        for parameter in self.pretrained_vit.parameters():
            parameter.requires_grad = self.trainable

    def forward(self, images):
        """
        Forward pass of the image encoder.
        :param images: Input images.
        :return: Projected features.
        """
        x = images
        # forward_features returns the token sequence from the encoder,
        # e.g. torch.Size([batch_size, 197, 768]); forward_head would instead return
        # flattened vectors: torch.Size([batch_size, 768]) when num_classes=0 is passed
        # to timm.create_model (no classification head), or torch.Size([batch_size, 1000])
        # otherwise (classification).
        x = self.pretrained_vit.forward_features(x)
        # Output: torch.Size([batch_size, 197, 256])
        x = self.projection_head(x)
        return x

    def __call__(self, images):
        """
        Callable method for the image encoder.
        :param images: Input images.
        :return: Projected features.
        """
        return self.forward(images)
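
# Minimal usage sketch for ImageEncoder (assumes CFG.vit_name names a timm ViT
# such as 'vit_base_patch16_224'; pretrained weights are downloaded on first use):
#   encoder = ImageEncoder()
#   images = torch.randn(8, 3, 224, 224)
#   embeddings = encoder(images)          # -> torch.Size([8, 197, 256])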
class TextEncoder(nn.Module):
    def __init__(self, model_name=CFG.bert_name, projection_dim=CFG.projection_dim,
                 trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        """
        Text encoder module using a BERT backbone.
        :param model_name: Name of the BERT model (default: CFG.bert_name).
        :param projection_dim: Dimensionality of projected features (default: CFG.projection_dim).
        :param trainable: Whether to make the backbone trainable (default: False).
        :param dropout_rate: Dropout rate (default: CFG.dropout_rate).
        """
        super(TextEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        self.trainable = trainable
        # Models
        self.pretrained_bert = BertModel.from_pretrained(self.model_name)
        self.projection_head = ProjectionHead(self.pretrained_bert.config.hidden_size,
                                              self.projection_dim, self.dropout_rate)
        # Freeze (or unfreeze) the pretrained BERT layers according to `trainable`
        for parameter in self.pretrained_bert.parameters():
            parameter.requires_grad = self.trainable

    def forward(self, captions):
        """
        Forward pass of the text encoder.
        :param captions: Input captions (input_ids, attention_mask).
        :return: Projected features.
        """
        input_ids, attention_mask = captions
        # last_hidden_state: torch.Size([batch_size, sequence, 768])
        # pooler_output: torch.Size([batch_size, 768])
        x = self.pretrained_bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Output: torch.Size([batch_size, sequence, 256])
        x = self.projection_head(x)
        return x

    def __call__(self, captions):
        """
        Callable method for the text encoder.
        :param captions: Input captions (input_ids, attention_mask).
        :return: Projected features.
        """
        return self.forward(captions)
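
# Minimal usage sketch for TextEncoder (assumes CFG.bert_name names a Hugging Face
# BERT checkpoint such as 'bert-base-uncased'):
#   tokenizer = BertTokenizer.from_pretrained(CFG.bert_name)
#   batch = tokenizer(["a dog on the beach"], padding=True, return_tensors="pt")
#   encoder = TextEncoder()
#   embeddings = encoder([batch["input_ids"], batch["attention_mask"]])
#   # -> torch.Size([1, sequence_length, 256])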
class ModalityTokenEncoder(nn.Module):
    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu',
                 token_dim=CFG.token_dim, *args, **kwargs):
        """
        Modality token encoder module for encoding modality token information.
        :param projection_dim: Dimensionality of projected features (default: CFG.projection_dim).
        :param token_size: Number of modality tokens (default: CFG.token_size).
        :param device: Device to run the module on (default: 'cpu').
        :param token_dim: Dimensionality of each token (default: CFG.token_dim).
        """
        super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.projection_dim = projection_dim
        self.device = device
        self.token_size = token_size
        self.token_dim = token_dim
        # Learnable modality tokens, initialized with random standard deviations in [0.1, 0.6)
        text_variance = torch.rand(1) * 0.5 + 0.1
        image_variance = torch.rand(1) * 0.5 + 0.1
        self.text_token = nn.Parameter(torch.normal(mean=0, std=text_variance.item(),
                                                    size=(self.token_size, self.token_dim)).to(self.device))
        self.image_token = nn.Parameter(torch.normal(mean=0, std=image_variance.item(),
                                                     size=(self.token_size, self.token_dim)).to(self.device))
        # Projection
        self.token_projection = nn.Sequential(
            nn.Linear(self.token_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, self.projection_dim),
            nn.LayerNorm(self.projection_dim)
        )

    def forward(self, modality_type):
        """
        Forward pass of the modality token encoder.
        :param modality_type: Modality indicator, either "image" or "text".
        :return: Projected token features.
        """
        token = self.image_token if modality_type == "image" else self.text_token
        token_features = self.token_projection(token)
        return token_features

    def __call__(self, modality_type):
        """
        Callable method for the token encoder.
        :param modality_type: Modality indicator, either "image" or "text".
        :return: Projected token features.
        """
        return self.forward(modality_type)
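
# Minimal usage sketch for ModalityTokenEncoder (assumes CFG.token_size and
# CFG.token_dim are defined in configs.py; shapes below are illustrative):
#   token_encoder = ModalityTokenEncoder()
#   image_tokens = token_encoder("image")   # -> torch.Size([token_size, projection_dim])
#   text_tokens = token_encoder("text")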
class UniversalProjectionOutput:
    def __init__(self, outputs):
        """
        Wrapper class for projection model outputs.
        :param outputs: Dictionary containing model outputs.
        """
        self.outputs = outputs

    def __getattr__(self, name):
        """
        Retrieve an attribute from the outputs dictionary.
        :param name: Name of the attribute to retrieve.
        :return: Value of the attribute.
        """
        if name in self.outputs:
            return self.outputs[name]
        else:
            raise AttributeError(f"'UniversalProjectionOutput' object has no attribute '{name}'")
class UniversalProjectionEncoder(nn.Module):
    def __init__(self, input_dim=CFG.projection_dim, num_head=CFG.num_head, num_layers=CFG.num_layers, *args, **kwargs):
        """
        Initialize the Universal Projection module.
        :param input_dim: Dimensionality of the input embeddings (default: CFG.projection_dim).
        :param num_head: Number of attention heads (default: CFG.num_head).
        :param num_layers: Number of transformer layers (default: CFG.num_layers).
        """
        super(UniversalProjectionEncoder, self).__init__(*args, **kwargs)
        self.input_dim = input_dim
        self.num_head = num_head
        self.num_layers = num_layers
        self.transformer_encoder_block = nn.TransformerEncoderLayer(
            d_model=self.input_dim,
            nhead=self.num_head,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_block,
            num_layers=self.num_layers
        )
        # Alternative backbones considered:
        # self.transformer_encoder = TransformerModel(self.input_dim, self.num_head, self.num_layers)
        # self.transformer_encoder = BertModel.from_pretrained('bert-large-uncased')
        self.layer_normalization = nn.LayerNorm(self.input_dim)

    def forward(self, inputs):
        """
        Forward pass of the universal projection encoder.
        :param inputs: Pair (x, tokens) where x holds image or caption embeddings
                       and tokens holds the modality token features.
        :return: UniversalProjectionOutput with last_hidden_state, mean_output and pooler_output.
        """
        x, tokens = inputs
        # Universal Projection block: broadcast the modality tokens over the batch
        tokens = tokens.unsqueeze(0).expand(x.size()[0], -1, -1)
        # Add the tokens to the image/caption embeddings
        # (an alternative is concatenation: torch.cat((tokens, x), dim=1))
        output_tensor = x + tokens
        # Normalization
        output_norm = self.layer_normalization(output_tensor)
        # Projection
        output_encoder = self.transformer_encoder(output_norm)
        # Residual connection
        residual_output = output_encoder + output_tensor
        # output = output_encoder[:, CFG.token_size:, :]
        return UniversalProjectionOutput({'last_hidden_state': residual_output,
                                          'mean_output': torch.mean(residual_output, dim=1),
                                          'pooler_output': residual_output[:, 0, :]})

    def __call__(self, inputs):
        return self.forward(inputs)
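
# Minimal usage sketch tying the pieces together (assumes CFG.projection_dim = 256
# and that CFG.token_size matches the sequence length, or is 1 so broadcasting applies):
#   up_encoder = UniversalProjectionEncoder()
#   image_features = torch.randn(8, 197, 256)   # from ImageEncoder
#   tokens = ModalityTokenEncoder()("image")    # modality token features
#   out = up_encoder([image_features, tokens])
#   out.last_hidden_state.shape                 # torch.Size([8, 197, 256])
#   out.mean_output.shape                       # torch.Size([8, 256])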
class OneEncoder(nn.Module, PyTorchModelHubMixin):
    def __init__(self, image_encoder=None, text_encoder=None, modality_token_encoder=None,
                 universal_projection_encoder=None, device='cpu', tokenizer=None,
                 image_preprocessor=None, *args, **kwargs):
        """
        Initialize the model.
        :param image_encoder: Image encoder module (default: ImageEncoder()).
        :param text_encoder: Text encoder module (default: TextEncoder()).
        :param modality_token_encoder: Modality token encoder module (default: ModalityTokenEncoder()).
        :param universal_projection_encoder: Universal projection encoder module (default: UniversalProjectionEncoder()).
        :param device: Device to run the model on (default: 'cpu').
        :param tokenizer: Tokenizer for text encoding (default: BertTokenizer.from_pretrained(CFG.bert_name)).
        :param image_preprocessor: Preprocessor for image inputs (default: resize, normalize and tensor conversion).
        """
        super(OneEncoder, self).__init__(*args, **kwargs)
        self.device = device
        # Build the default sub-modules here rather than in the signature, so they are
        # not instantiated (and pretrained weights not downloaded) at import time and
        # are not shared between instances.
        self.image_encoder = image_encoder if image_encoder is not None else ImageEncoder()
        self.text_encoder = text_encoder if text_encoder is not None else TextEncoder()
        self.universal_projection_encoder = (universal_projection_encoder
                                             if universal_projection_encoder is not None
                                             else UniversalProjectionEncoder())
        self.modality_token_encoder = (modality_token_encoder if modality_token_encoder is not None
                                       else ModalityTokenEncoder())
        self.modality_token_encoder.device = self.device
        self.tokenizer = tokenizer if tokenizer is not None else BertTokenizer.from_pretrained(CFG.bert_name)
        self.image_preprocessor = image_preprocessor if image_preprocessor is not None else A.Compose([
            A.Resize(CFG.image_size, CFG.image_size, always_apply=True),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), always_apply=True),
            ToTensorV2()
        ])
        # The learnable temperature parameter τ was initialized to the equivalent of 0.07
        # from (Wu et al., 2018) and clipped to prevent scaling the logits by more than 100,
        # which we found necessary to prevent training instability.
        self.temperature = nn.Parameter(torch.tensor(0.07).to(self.device))
    @staticmethod
    def load_image(image_path):
        """
        Load an image from a URL or a local path and convert it to RGB.
        :param image_path: URL (http/https) or local file path.
        :return: Image as an RGB numpy array.
        """
        # Load online image
        if image_path.startswith("http"):
            response = requests.get(image_path)
            # Check that the request was successful
            if response.status_code != 200:
                raise ValueError(f"Failed to download image from {image_path} "
                                 f"(status code {response.status_code})")
            # Convert the image content to a numpy array
            img_array = np.asarray(bytearray(response.content), dtype=np.uint8)
            # Decode the image using OpenCV
            image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
            # Convert BGR to RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Load local image
        else:
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image
    def encode_image(self, image_paths=None, image_tensors=None, outputs="mean"):
        """
        Encode images into feature vectors.
        :param image_paths: List of image paths.
        :param image_tensors: Torch tensor of shape (batch, 3, 224, 224).
        :param outputs: Type of outputs: 'mean', 'pooler' or 'sequence'.
        :return: Encoded image features.
        """
        if image_paths is not None:
            image_processed = [self.image_preprocessor(image=self.load_image(image))["image"] for image in image_paths]
            image_processed = torch.stack(image_processed).to(self.device)
            with torch.no_grad():
                image_features = self.image_encoder(image_processed)
                modality_token_feature = self.modality_token_encoder("image")
                output_features = self.universal_projection_encoder([image_features, modality_token_feature])
        elif image_tensors is not None:
            with torch.no_grad():
                image_features = self.image_encoder(image_tensors.to(self.device))
                modality_token_feature = self.modality_token_encoder("image")
                output_features = self.universal_projection_encoder([image_features, modality_token_feature])
        else:
            raise ValueError("Either image_paths or image_tensors must be provided")
        if outputs == "mean":
            image_features = output_features.mean_output
        elif outputs == "sequence":
            image_features = output_features.last_hidden_state
        else:
            image_features = output_features.pooler_output
        return image_features
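
    # Minimal usage sketch ('dog.jpg' is a hypothetical path; the output dimension
    # assumes CFG.projection_dim = 256):
    #   model = OneEncoder(device='cpu')
    #   features = model.encode_image(image_paths=['dog.jpg'])   # -> torch.Size([1, 256])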
    def encode_text(self, texts, max_length=128, outputs="mean"):
        """
        Encode text descriptions into feature vectors.
        :param texts: List of text descriptions.
        :param max_length: Maximum length of the text sequences (default: 128).
        :param outputs: Type of outputs: 'mean', 'sequence' or 'pooler'.
        :return: Encoded text features.
        """
        encoded_query = self.tokenizer(
            texts, padding=True, truncation=True, max_length=max_length
        )
        batch = {
            key: torch.tensor(values).to(self.device)
            for key, values in encoded_query.items()
        }
        with torch.no_grad():
            text_features = self.text_encoder([
                batch["input_ids"], batch["attention_mask"]
            ])
            modality_token_feature = self.modality_token_encoder("text")
            output_features = self.universal_projection_encoder([text_features, modality_token_feature])
        if outputs == "mean":
            text_features = output_features.mean_output
        elif outputs == "sequence":
            text_features = output_features.last_hidden_state
        else:
            text_features = output_features.pooler_output
        return text_features
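
    # Minimal usage sketch (output dimension assumes CFG.projection_dim = 256):
    #   features = model.encode_text(["a dog on the beach"])     # -> torch.Size([1, 256])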
    def matching(self, image_paths, texts, normalize=True, top_k=None, strategy="similarity", temperature=0.0):
        """
        Calculate similarities between images and texts.
        :param image_paths: List of paths to images.
        :param texts: List of text descriptions.
        :param normalize: Whether to normalize the features (default: True).
        :param top_k: Return the top K results (default: None).
        :param strategy: Matching strategy, either 'similarity' or 'softmax' (default: 'similarity').
        :param temperature: Sharpens or flattens the similarity distribution (default: 0.0).
        :return: If top_k is provided, returns top probabilities and labels; otherwise returns the dot similarities.
        """
        image_features = self.encode_image(image_paths=image_paths)
        text_features = self.encode_text(texts=texts)
        if normalize:
            image_features = F.normalize(image_features, p=2, dim=-1)
            text_features = F.normalize(text_features, p=2, dim=-1)
        dot_similarities = (image_features @ text_features.T) * torch.exp(torch.tensor(temperature).to(self.device))
        if strategy == 'softmax':
            dot_similarities = (float(len(set(texts))) * dot_similarities).softmax(dim=-1)
        if top_k is not None:
            top_probs, top_labels = dot_similarities.cpu().topk(top_k, dim=-1)
            return top_probs, top_labels
        else:
            return dot_similarities, None
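
    # Minimal usage sketch (hypothetical paths and labels; the 'softmax' strategy
    # turns similarities into a distribution over the candidate texts):
    #   probs, labels = model.matching(image_paths=['dog.jpg'],
    #                                  texts=['a dog', 'a cat', 'a car'],
    #                                  strategy='softmax', top_k=1)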
    def image_retrieval(self, query, image_paths, image_embeddings=None, temperature=0.0, n=9, plot=False):
        """
        Perform image retrieval based on a text query.
        :param query: Text query (string).
        :param image_paths: List of image paths.
        :param image_embeddings: Precomputed image embeddings (optional).
        :param temperature: Sharpens or flattens the similarity distribution (default: 0.0).
        :param n: Number of images to retrieve (default: 9).
        :param plot: Whether to plot the retrieved images (default: False).
        :return: Tuple containing similarity values and indices of the retrieved images.
        """
        text_embeddings = self.encode_text([query])
        if image_embeddings is None:
            image_embeddings = self.encode_image(image_paths=image_paths)
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        dot_similarity = (text_embeddings_n @ image_embeddings_n.T) * torch.exp(
            torch.tensor(temperature).to(self.device))
        if n > len(image_paths):
            n = len(image_paths)
        values, indices = torch.topk(dot_similarity.cpu().squeeze(0), n)
        if plot:
            nrows = int(np.sqrt(n))
            ncols = int(np.ceil(n / nrows))
            matches = [image_paths[idx] for idx in indices]
            fig, axes = plt.subplots(nrows, ncols, figsize=(20, 20))
            # np.array(...).reshape(-1) also handles n == 1, where subplots returns a single Axes
            for match, ax in zip(matches, np.array(axes).reshape(-1)):
                image = self.load_image(f"{match}")
                ax.imshow(image)
                ax.axis("off")
            plt.savefig("img.png")
            # fig.suptitle(query)
            # plt.show()
        return values, indices
    def text_retrieval(self, query, texts, text_embeddings=None, n=9, plot_image=False, temperature=0.0):
        """
        Perform text retrieval based on an image query.
        :param query: Image query (path of the image).
        :param texts: List of text samples.
        :param text_embeddings: Precomputed text embeddings (optional).
        :param n: Number of texts to retrieve (default: 9).
        :param plot_image: Whether to plot the query image (default: False).
        :param temperature: Sharpens or flattens the similarity distribution (default: 0.0).
        :return: List of retrieved text samples and their similarity values.
        """
        if text_embeddings is None:
            text_embeddings = self.encode_text(texts)
        image_embeddings = self.encode_image([query])
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        dot_similarity = (image_embeddings_n @ text_embeddings_n.T) * torch.exp(
            torch.tensor(temperature).to(self.device))
        if n > len(texts):
            n = len(texts)
        values, indices = torch.topk(dot_similarity.cpu().squeeze(0), n)
        matches = [texts[idx] for idx in indices]
        if plot_image:
            # Read and plot the query image
            image = self.load_image(query)
            plt.imshow(image)
            plt.axis('off')
            plt.savefig("img.png")
            # plt.show()
        return matches, values
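
# Minimal end-to-end sketch, kept as comments so the module stays importable.
# The image paths and queries below are hypothetical; CFG values come from configs.py:
#   model = OneEncoder(device='cpu')
#   paths = ['dog.jpg', 'cat.jpg', 'car.jpg']
#   values, indices = model.image_retrieval("a dog playing", image_paths=paths, n=2)
#   matches, scores = model.text_retrieval('dog.jpg', ['a dog', 'a cat', 'a car'], n=2)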