# OneEncoder-retriever / text_image.py
################################################## PACKAGES ############################################################
# PyTorch for deep learning operations
import torch
import torch.nn as nn
# PyTorch multiprocessing utilities (used with DataLoader workers)
import torch.multiprocessing
# OpenCV for image loading and color conversion
import cv2
# Transfer Learning model library
import timm
# HTTP requests for downloading remote images
import requests
# COCO dataset tools
from pycocotools.coco import COCO
import numpy as np
# Hugging Face Transformers library for BERT models
from transformers import BertModel, BertTokenizer, DistilBertModel, DistilBertConfig, DistilBertTokenizer
import torch.nn.functional as F
# Image processing and augmentations
import albumentations as A
# Visualization and progress tracking
from tqdm import tqdm
import matplotlib.pyplot as plt
# Additional utility for iterating over combinations
import itertools
from albumentations.pytorch import ToTensorV2
import pandas as pd
from configs import CFG
from huggingface_hub import PyTorchModelHubMixin
################################################### MODELS #############################################################
class ProjectionHead(nn.Module):
def __init__(self, input_dim, projection_dim=CFG.projection_dim, dropout_rate=CFG.dropout_rate, *args, **kwargs):
"""
Projection Head module for contrastive learning.
:param input_dim: Dimensionality of input features.
:param projection_dim: Dimensionality of projected features (default: CFG.projection_dim).
:param dropout_rate: Dropout rate (default: CFG.dropout_rate).
"""
super(ProjectionHead, self).__init__(*args, **kwargs)
# Attributes
self.input_dim = input_dim
self.projection_dim = projection_dim
self.dropout_rate = dropout_rate
# Layers
self.linear_layer1 = nn.Linear(self.input_dim, self.projection_dim)
self.gelu = nn.GELU()
self.linear_layer2 = nn.Linear(self.projection_dim, self.projection_dim)
self.dropout = nn.Dropout(self.dropout_rate)
self.normalization_layer = nn.LayerNorm(self.projection_dim)
def forward(self, inputs):
"""
Forward pass of the projection head.
:param inputs: Input features.
:return: Projected features.
"""
x = inputs
x = self.linear_layer1(x)
x = self.gelu(x)
x = self.linear_layer2(x)
x = self.dropout(x)
x = self.normalization_layer(x)
return x
def __call__(self, inputs):
"""
Callable method for the projection head.
:param inputs: Input features.
:return: Projected features.
"""
return self.forward(inputs)
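# A minimal usage sketch (illustrative; assumes a ViT-Base-style input of 197
# tokens with 768-dim embeddings and a 256-dim projection space):
def _demo_projection_head():
    head = ProjectionHead(input_dim=768, projection_dim=256, dropout_rate=0.1)
    features = torch.randn(8, 197, 768)   # (batch, tokens, embed_dim)
    projected = head(features)            # -> torch.Size([8, 197, 256])
    return projected.shape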
class ImageEncoder(nn.Module):
def __init__(self, model_name=CFG.vit_name, projection_dim=CFG.projection_dim, trainable=False,
dropout_rate=CFG.dropout_rate, *args, **kwargs):
"""
Image encoder module using Vision Transformer (ViT) backbone.
:param model_name: Name of the Vision Transformer model (default: CFG.vit_name).
:param projection_dim: Dimensionality of projected features (default: CFG.projection_dim).
:param trainable: Whether to make the backbone trainable (default: False).
:param dropout_rate: Dropout rate (default: CFG.dropout_rate).
"""
super(ImageEncoder, self).__init__(*args, **kwargs)
# Attributes
self.model_name = model_name
self.projection_dim = projection_dim
self.trainable = trainable
self.dropout_rate = dropout_rate
# Models
self.pretrained_vit = timm.create_model(self.model_name, pretrained=True, num_classes=0)
self.projection_head = ProjectionHead(self.pretrained_vit.embed_dim, self.projection_dim, self.dropout_rate)
        # Freeze (or unfreeze) the pretrained ViT backbone according to `trainable`
for parameter in self.pretrained_vit.parameters():
parameter.requires_grad = self.trainable
def forward(self, images):
"""
Forward pass of the image encoder.
:param images: Input images.
:return: Projected features.
"""
x = images
        # forward_features returns the encoder token sequence -> torch.Size([batch_size, 197, 768]).
        # (forward_head would instead return a flattened vector: torch.Size([batch_size, 768]) when
        # num_classes=0 in timm.create_model, or torch.Size([batch_size, 1000]) with a classification head.)
x = self.pretrained_vit.forward_features(x)
# output: torch.Size([batch_size, 197, 256])
x = self.projection_head(x)
return x
def __call__(self, images):
"""
Callable method for the image encoder.
:param images: Input images.
:return: Projected features.
"""
return self.forward(images)
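# Usage sketch (downloads the pretrained backbone on first use; the 197-token
# output assumes CFG.vit_name is a ViT-Base patch16/224 model):
def _demo_image_encoder():
    encoder = ImageEncoder()                 # backbone frozen by default
    images = torch.randn(2, 3, 224, 224)     # dummy batch of RGB images
    with torch.no_grad():
        features = encoder(images)           # -> torch.Size([2, 197, CFG.projection_dim])
    return features.shape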
class TextEncoder(nn.Module):
def __init__(self, model_name=CFG.bert_name, projection_dim=CFG.projection_dim,
trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs):
"""
Text encoder module using BERT backbone.
:param model_name: Name of the BERT model (default: CFG.bert_name).
:param projection_dim: Dimensionality of projected features (default: CFG.projection_dim).
:param trainable: Whether to make the backbone trainable (default: False).
:param dropout_rate: Dropout rate (default: CFG.dropout_rate).
"""
super(TextEncoder, self).__init__(*args, **kwargs)
# Attributes
self.model_name = model_name
self.projection_dim = projection_dim
self.dropout_rate = dropout_rate
self.trainable = trainable
# Models
self.pretrained_bert = BertModel.from_pretrained(self.model_name)
self.projection_head = ProjectionHead(self.pretrained_bert.config.hidden_size,
self.projection_dim, self.dropout_rate)
        # Freeze (or unfreeze) the pretrained BERT backbone according to `trainable`
for parameter in self.pretrained_bert.parameters():
parameter.requires_grad = self.trainable
def forward(self, captions):
"""
Forward pass of the text encoder.
:param captions: Input captions (input_ids, attention_mask).
:return: Projected features.
"""
input_ids, attention_mask = captions
# last_hidden_state: torch.Size([batch_size, sequence, 768])
# pooler_output: torch.Size([batch_size, 768])
x = self.pretrained_bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
# output: torch.Size([batch_size, sequence, 256])
x = self.projection_head(x)
return x
def __call__(self, captions):
"""
Callable method for the text encoder.
:param captions: Input captions (input_ids, attention_mask).
:return: Projected features.
"""
return self.forward(captions)
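# Usage sketch (downloads the pretrained BERT weights on first use):
def _demo_text_encoder():
    tokenizer = BertTokenizer.from_pretrained(CFG.bert_name)
    encoder = TextEncoder()                  # backbone frozen by default
    batch = tokenizer(["two cats sleeping on a couch"], return_tensors="pt")
    with torch.no_grad():
        # -> torch.Size([1, sequence_length, CFG.projection_dim])
        features = encoder([batch["input_ids"], batch["attention_mask"]])
    return features.shape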
class ModalityTokenEncoder(nn.Module):
def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', token_dim=CFG.token_dim, *args, **kwargs):
"""
Modality token encoder module for encoding modality token information.
:param projection_dim: Dimensionality of projected features (default: CFG.projection_dim).
        :param token_size: Number of learnable tokens per modality (default: CFG.token_size).
:param device: Device to run the module on (default: 'cpu').
        :param token_dim: Dimensionality of each token (default: CFG.token_dim).
"""
super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
# Attributes
self.projection_dim = projection_dim
self.device = device
self.token_size = token_size
self.token_dim = token_dim
        # Learnable modality tokens: a value in [0.1, 0.6) is drawn uniformly and
        # used below as the standard deviation of each token's Gaussian initializer
        text_variance = torch.rand(1) * 0.5 + 0.1
        image_variance = torch.rand(1) * 0.5 + 0.1
self.text_token = nn.Parameter(torch.normal(mean=0, std=text_variance.item(),
size=(self.token_size, self.token_dim)).to(self.device))
self.image_token = nn.Parameter(torch.normal(mean=0, std=image_variance.item(),
size=(self.token_size, self.token_dim)).to(self.device))
# Projection
self.token_projection = nn.Sequential(
nn.Linear(self.token_dim, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, self.projection_dim),
nn.LayerNorm(self.projection_dim)
)
def forward(self, modality_type):
"""
Forward pass of the modality encoder.
        :param modality_type: Modality indicator, either "image" or "text".
:return: Projected features.
"""
        # Select the learnable token for the requested modality
        token = self.image_token if modality_type == "image" else self.text_token
token_features = self.token_projection(token)
return token_features
def __call__(self, modality_type):
"""
Callable method for the token encoder.
        :param modality_type: Modality indicator, either "image" or "text".
:return: Projected features.
"""
return self.forward(modality_type)
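# Usage sketch: each modality owns one learnable (token_size, token_dim)
# parameter; the shared MLP projects it to CFG.projection_dim.
def _demo_modality_token_encoder():
    encoder = ModalityTokenEncoder()
    image_token = encoder("image")   # -> torch.Size([CFG.token_size, CFG.projection_dim])
    text_token = encoder("text")     # same shape, independent parameters
    return image_token.shape, text_token.shape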
class UniversalProjectionOutput:
def __init__(self, outputs):
"""
Wrapper class for projection model outputs.
:param outputs: Dictionary containing model outputs.
"""
self.outputs = outputs
def __getattr__(self, name):
"""
Retrieve attribute from outputs dictionary.
:param name: Name of the attribute to retrieve.
:return: Value of the attribute.
"""
if name in self.outputs:
return self.outputs[name]
else:
raise AttributeError(f"'UniversalProjectionOutput' object has no attribute '{name}'")
class UniversalProjectionEncoder(nn.Module):
def __init__(self, input_dim=CFG.projection_dim, num_head=CFG.num_head, num_layers=CFG.num_layers, *args, **kwargs):
"""
Initialize Universal Projection module.
:param input_dim: Dimensionality of the input embeddings. Defaults to CFG.projection_dim.
:param num_head: Number of attention heads. Defaults to CFG.num_head.
:param num_layers: Number of transformer layers. Defaults to CFG.num_layers.
"""
super(UniversalProjectionEncoder, self).__init__(*args, **kwargs)
self.input_dim = input_dim
self.num_head = num_head
self.num_layers = num_layers
self.transformer_encoder_block = nn.TransformerEncoderLayer(
d_model=self.input_dim,
nhead=self.num_head,
batch_first=True
)
self.transformer_encoder = nn.TransformerEncoder(
self.transformer_encoder_block,
num_layers=self.num_layers
)
        self.layer_normalization = nn.LayerNorm(self.input_dim)
def forward(self, inputs):
        # x: image or caption embeddings, tokens: modality token
        x, tokens = inputs
        # Universal Projection block: broadcast the modality token across the batch
        tokens = tokens.unsqueeze(0).expand(x.size(0), -1, -1)
        # Fuse the modality token with the image/caption embeddings by addition
        output_tensor = x + tokens
# Normalization
output_norm = self.layer_normalization(output_tensor)
# Projection
output_encoder = self.transformer_encoder(output_norm)
        # Residual connection around the transformer encoder
        residual_output = output_encoder + output_tensor
return UniversalProjectionOutput({'last_hidden_state': residual_output,
'mean_output': torch.mean(residual_output, dim=1),
'pooler_output': residual_output[:, 0, :]})
def __call__(self, inputs):
return self.forward(inputs)
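# Usage sketch: the modality token is broadcast over the batch and added to the
# sequence embeddings, so the sequence length must match (or broadcast against)
# CFG.token_size; here they are taken equal for illustration.
def _demo_universal_projection():
    encoder = UniversalProjectionEncoder()
    embeddings = torch.randn(2, CFG.token_size, CFG.projection_dim)  # (batch, seq, dim)
    token = torch.randn(CFG.token_size, CFG.projection_dim)          # modality token
    outputs = encoder([embeddings, token])
    return outputs.mean_output.shape                                 # -> torch.Size([2, CFG.projection_dim])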
class OneEncoder(nn.Module, PyTorchModelHubMixin):
def __init__(self, image_encoder=ImageEncoder(), text_encoder=TextEncoder(),
modality_token_encoder=ModalityTokenEncoder(),
universal_projection_encoder=UniversalProjectionEncoder(), device='cpu',
tokenizer=BertTokenizer.from_pretrained(CFG.bert_name),
image_preprocessor=A.Compose([A.Resize(CFG.image_size, CFG.image_size, always_apply=True),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
always_apply=True), ToTensorV2()]),
*args, **kwargs):
"""
Initialize the model.
:param image_encoder: Image encoder module (default: ImageEncoder()).
:param text_encoder: Text encoder module (default: TextEncoder()).
        :param modality_token_encoder: Modality token encoder module (default: ModalityTokenEncoder()).
        :param universal_projection_encoder: Universal projection encoder module (default: UniversalProjectionEncoder()).
:param device: Device to run the model on (default: 'cpu').
:param tokenizer: Tokenizer for text encoding (default: BertTokenizer.from_pretrained(CFG.bert_name)).
:param image_preprocessor: Preprocessor for image inputs (default: A.Compose([...])).
"""
super(OneEncoder, self).__init__(*args, **kwargs)
self.device = device
self.image_encoder = image_encoder
self.text_encoder = text_encoder
self.universal_projection_encoder = universal_projection_encoder
self.modality_token_encoder = modality_token_encoder
self.modality_token_encoder.device = self.device
self.tokenizer = tokenizer
self.image_preprocessor = image_preprocessor
        # Learnable temperature parameter, initialized to the equivalent of 0.07 as in CLIP
        # (following Wu et al., 2018). Note: CLIP additionally clips the logit scale to at
        # most 100 during training to avoid instability; no clipping is applied here.
self.temperature = nn.Parameter(torch.tensor(0.07).to(self.device))
    @classmethod
    def load_image(cls, image_path):
        """
        Load an image from a URL or a local path and return it as an RGB numpy array.
        """
        # Load remote image
        if image_path.startswith("http"):
            response = requests.get(image_path)
            if response.status_code != 200:
                raise ValueError(f"Failed to download image: {image_path}")
            # Decode the downloaded bytes with OpenCV
            img_array = np.asarray(bytearray(response.content), dtype=np.uint8)
            image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
        # Load local image
        else:
            image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to read image: {image_path}")
        # Convert BGR (OpenCV default) to RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image
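    # Example (sketch; the URL is the familiar COCO demo image, used here purely
    # for illustration):
    #   img = OneEncoder.load_image("http://images.cocodataset.org/val2017/000000039769.jpg")
    #   img is an RGB numpy array of shape (height, width, 3)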
    def encode_image(self, image_paths=None, image_tensors=None, outputs="mean"):
        """
        Encode images into feature vectors.
        :param image_paths: List of image paths.
        :param image_tensors: Torch tensor (batch, 3, 224, 224).
        :param outputs: Type of output to return: 'mean', 'pooler', or 'sequence'.
        :return: Encoded image features.
        """
        if image_paths is not None:
            image_processed = [self.image_preprocessor(image=self.load_image(path))["image"] for path in image_paths]
            image_tensors = torch.stack(image_processed)
        elif image_tensors is None:
            raise ValueError("Either image_paths or image_tensors must be provided.")
        with torch.no_grad():
            image_features = self.image_encoder(image_tensors.to(self.device))
            modality_token_feature = self.modality_token_encoder("image")
            output_features = self.universal_projection_encoder([image_features, modality_token_feature])
if outputs == "mean":
image_features = output_features.mean_output
elif outputs == "sequence":
image_features = output_features.last_hidden_state
else:
image_features = output_features.pooler_output
return image_features
def encode_text(self, texts, max_length=128, outputs="mean"):
"""
Encode text descriptions into feature vectors.
:param texts: List of text descriptions.
:param max_length: Maximum length of the text sequences (default: 128).
        :param outputs: Type of output to return: 'mean', 'sequence', or 'pooler'.
:return: Encoded text features.
"""
encoded_query = self.tokenizer(
texts, padding=True, truncation=True, max_length=max_length
)
batch = {
key: torch.tensor(values).to(self.device)
for key, values in encoded_query.items()
}
with torch.no_grad():
text_features = self.text_encoder([
batch["input_ids"], batch["attention_mask"]
])
modality_token_feature = self.modality_token_encoder("text")
output_features = self.universal_projection_encoder([text_features, modality_token_feature])
if outputs == "mean":
text_features = output_features.mean_output
elif outputs == "sequence":
text_features = output_features.last_hidden_state
else:
text_features = output_features.pooler_output
return text_features
def matching(self, image_paths, texts, normalize=True, top_k=None, strategy="similarity", temperature=0.0):
"""
Calculate similarities between images and texts.
:param image_paths: List of paths to images.
:param texts: List of text descriptions.
:param normalize: Whether to normalize the features (default: True).
:param top_k: Return top K results (default: None).
:param strategy: Matching strategy, either 'similarity' or 'softmax' (default: 'similarity').
        :param temperature: Log-scale temperature; similarities are multiplied by exp(temperature) (default: 0.0).
:return: If top_k is provided, returns top probabilities and labels, otherwise returns dot similarities.
"""
image_features = self.encode_image(image_paths=image_paths)
text_features = self.encode_text(texts=texts)
if normalize:
image_features = F.normalize(image_features, p=2, dim=-1)
text_features = F.normalize(text_features, p=2, dim=-1)
dot_similarities = (image_features @ text_features.T) * torch.exp(torch.tensor(temperature).to(self.device))
if strategy == 'softmax':
dot_similarities = (float(len(set(texts))) * dot_similarities).softmax(dim=-1)
if top_k is not None:
top_probs, top_labels = dot_similarities.cpu().topk(top_k, dim=-1)
return top_probs, top_labels
else:
return dot_similarities, None
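    # Example (sketch, with hypothetical file names):
    #   probs, labels = model.matching(["cat.jpg", "dog.jpg"],
    #                                  ["a photo of a cat", "a photo of a dog"],
    #                                  top_k=1, strategy="softmax")
    #   # probs[i]: score of the best caption for image i; labels[i]: its index.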
def image_retrieval(self, query, image_paths, image_embeddings=None, temperature=0.0, n=9, plot=False):
"""
Perform image retrieval based on a text query.
:param query: Text query (string).
        :param image_paths: List of image paths to search over.
        :param image_embeddings: Precomputed image embeddings (optional; computed from image_paths if omitted).
        :param temperature: Log-scale temperature; similarities are multiplied by exp(temperature) (default: 0.0).
:param n: Number of images to retrieve (default: 9).
:param plot: Whether to plot the retrieved images (default: False).
:return: Tuple containing similarity values and indices of the retrieved images.
"""
text_embeddings = self.encode_text([query])
if image_embeddings is None:
image_embeddings = self.encode_image(image_paths=image_paths)
image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
dot_similarity = (text_embeddings_n @ image_embeddings_n.T) * torch.exp(
torch.tensor(temperature).to(self.device))
if n > len(image_paths):
n = len(image_paths)
values, indices = torch.topk(dot_similarity.cpu().squeeze(0), n)
if plot:
nrows = int(np.sqrt(n))
ncols = int(np.ceil(n / nrows))
matches = [image_paths[idx] for idx in indices]
            fig, axes = plt.subplots(nrows, ncols, figsize=(20, 20))
            for match, ax in zip(matches, axes.flatten()):
                image = self.load_image(f"{match}")
                ax.imshow(image)
                ax.axis("off")
            plt.savefig("img.png")
        return values, indices
def text_retrieval(self, query, texts, text_embeddings=None, n=9, plot_image=False, temperature=0.0):
"""
Perform text retrieval based on an image query.
:param query: Image query (path of image).
:param texts: List of text samples.
:param text_embeddings: Precomputed text embeddings (optional).
:param n: Number of texts to retrieve (default: 9).
        :param plot_image: Whether to plot the query image (default: False).
        :param temperature: Log-scale temperature; similarities are multiplied by exp(temperature) (default: 0.0).
        :return: List of retrieved text samples and their similarity values.
"""
if text_embeddings is None:
text_embeddings = self.encode_text(texts)
image_embeddings = self.encode_image([query])
image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
dot_similarity = (image_embeddings_n @ text_embeddings_n.T) * torch.exp(
torch.tensor(temperature).to(self.device))
if n > len(texts):
n = len(texts)
values, indices = torch.topk(dot_similarity.cpu().squeeze(0), n)
matches = [texts[idx] for idx in indices]
        if plot_image:
            # Read and plot the query image
            image = self.load_image(query)
            plt.imshow(image)
            plt.axis('off')
            plt.savefig("img.png")
        return matches, values
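################################################### USAGE ##############################################################
# End-to-end sketch (illustrative; instantiating OneEncoder builds the default
# encoders and downloads their pretrained weights, and the image file names
# below are hypothetical):
def _demo_one_encoder():
    model = OneEncoder(device="cpu")
    image_paths = ["cat.jpg", "dog.jpg"]
    captions = ["a photo of a cat", "a photo of a dog"]
    # Pairwise image-text similarities (rows: images, columns: captions)
    similarities, _ = model.matching(image_paths, captions)
    # Best caption for the first image
    matches, values = model.text_retrieval(image_paths[0], captions, n=1)
    return similarities, matches, values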