Spaces:
Running
Running
import json
import tarfile
from pathlib import Path
from typing import List, Optional, Tuple

import faiss
import gdown
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from src.retrieval import ArrowMetadataProvider
from src.transforms import TextCompose, default_vocabulary_transforms
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Retrieval databases downloaded on first use, keyed by the name of the local
# database directory they are checked/extracted under.
RETRIEVAL_DATABASES = dict(
    cc12m="https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t",
)
class CaSED(torch.nn.Module):
    """Torch module for Category Search from External Databases (CaSED).

    Args:
        index_name (str): Name of the faiss index to use. Defaults to "ViT-L-14_CC12M".
        vocabulary_transforms (TextCompose, optional): Transforms applied to the retrieved
            captions to build the candidate vocabulary. Defaults to a fresh
            ``default_vocabulary_transforms()`` per instance.

    Extra hparams:
        alpha (float): Weight of the image-based prediction in the average of the image and
            text predictions (the text-based one gets ``1 - alpha``). Defaults to 0.5.
        artifact_dir (str): Path to the directory where the databases are stored. Defaults to
            "artifacts/".
        retrieval_num_results (int): Number of captions to retrieve from the index.
            Defaults to 10.
    """

    # CLIP's text encoder accepts at most 77 tokens; longer captions are truncated by the
    # tokenizer so that the end-of-text token is kept.
    _MAX_TEXT_LENGTH = 77

    def __init__(
        self,
        index_name: str = "ViT-L-14_CC12M",
        vocabulary_transforms: Optional[TextCompose] = None,
        **kwargs,
    ):
        super().__init__()
        # build the default per instance instead of once at class-definition time
        if vocabulary_transforms is None:
            vocabulary_transforms = default_vocabulary_transforms()

        # load CLIP and keep only the towers/projections used below
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
        self.index_name = index_name
        self.vocabulary_transforms = vocabulary_transforms
        self.vision_encoder = model.vision_model
        self.vision_proj = model.visual_projection
        self.language_encoder = model.text_model
        self.language_proj = model.text_projection
        self.logit_scale = model.logit_scale.exp()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

        # set hparams (explicit kwargs win over the defaults)
        kwargs.setdefault("alpha", 0.5)
        kwargs.setdefault("artifact_dir", "artifacts/")
        kwargs.setdefault("retrieval_num_results", 10)
        self.hparams = kwargs

        # download databases
        self.prepare_data()

        # load faiss indices and metadata providers; indices.json maps an index name to the
        # directory holding its "text.index" file and "metadata" folder
        indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
        indices_fp = indices_list_dir / "indices.json"
        with open(indices_fp, encoding="utf-8") as f:
            self.indices = json.load(f)

        self.resources = {}
        for name, index_fp in self.indices.items():
            text_index_fp = Path(index_fp) / "text.index"
            metadata_fp = Path(index_fp) / "metadata/"
            # memory-map the index read-only instead of loading it fully into RAM
            text_index = faiss.read_index(
                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            metadata_provider = ArrowMetadataProvider(metadata_fp)
            self.resources[name] = {
                "device": DEVICE,
                "model": "ViT-L-14",
                "text_index": text_index,
                "metadata_provider": metadata_provider,
            }

    def prepare_data(self):
        """Download and extract the retrieval databases that are not on disk yet.

        Each entry of ``RETRIEVAL_DATABASES`` is downloaded as ``<name>.tar.gz`` into
        ``<artifact_dir>/models/databases/``, extracted there, and the archive is deleted.
        Entries whose ``<name>`` directory already exists are skipped. A
        ``FileNotFoundError`` during download/extraction is reported and not re-raised.
        """
        databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases"
        for name, url in RETRIEVAL_DATABASES.items():
            database_path = databases_path / name
            if database_path.exists():
                continue

            target_path = databases_path / f"{name}.tar.gz"
            # the destination folder must exist or the download cannot be written
            databases_path.mkdir(parents=True, exist_ok=True)
            try:
                gdown.download(url, str(target_path), quiet=False)
                # NOTE(review): extractall trusts member paths inside the archive; consider
                # tarfile extraction filters (filter="data") if the source is not trusted.
                with tarfile.open(target_path, "r:gz") as tar:
                    tar.extractall(target_path.parent)
                target_path.unlink()
            except FileNotFoundError:
                print(f"Could not download {url}.")
                print(f"Please download it manually and place it in {target_path.parent}.")

    def query_index(self, sample_z: torch.Tensor) -> List[str]:
        """Retrieve the captions nearest to an embedding from the selected faiss index.

        Args:
            sample_z (torch.Tensor): Query embedding. A leading dimension of size 1 is
                squeezed away; presumably (1, D) or (D,) — TODO confirm other shapes.

        Returns:
            List[str]: Captions of the retrieved neighbours in faiss result order. Neighbours
            without a caption in the metadata are skipped; empty if nothing is retrieved.
        """
        resources = self.resources[self.index_name]
        text_index = resources["text_index"]
        metadata_provider = resources["metadata_provider"]

        # L2-normalise and build a (1, D) float32 query as faiss expects
        sample_z = sample_z.squeeze(0)
        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
        query = np.expand_dims(sample_z.detach().cpu().numpy().astype("float32"), 0)

        # only ids are needed, so a plain search avoids reconstructing the stored vectors
        _, idxs = text_index.search(query, self.hparams["retrieval_num_results"])

        # faiss pads missing neighbours with -1: keep everything before the first padding id
        ids = idxs[0]
        padding = np.flatnonzero(ids == -1)
        nb_results = padding[0] if padding.size > 0 else len(ids)
        ids = ids[:nb_results]
        if len(ids) == 0:
            return []

        # fetch captions for every retrieved id (a fixed cap here would drop captions when
        # retrieval_num_results is larger than the cap)
        metadata = metadata_provider.get(ids, ["caption"])
        captions = []
        for meta in metadata[: len(ids)]:
            fields = dict(meta) if meta is not None else {}
            caption = fields.get("caption")
            if caption is not None:
                captions.append(caption)
        return captions

    def _encode_text(self, texts: List[str]) -> torch.Tensor:
        """Embed ``texts`` with CLIP's text tower and return L2-normalised projections."""
        tokens = self.processor(
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self._MAX_TEXT_LENGTH,
        )
        tokens = {k: v.to(DEVICE) for k, v in tokens.items()}
        text_z = self.language_proj(self.language_encoder(**tokens)[1])
        return text_z / text_z.norm(dim=-1, keepdim=True)

    @torch.no_grad()
    def forward(
        self, image_fp: str, alpha: Optional[float] = None
    ) -> Tuple[List[str], List[float]]:
        """Score a vocabulary, built from captions retrieved for the image, against the image.

        Args:
            image_fp (str): Path of the image to classify.
            alpha (float, optional): Weight of the image-based prediction; the text-based
                prediction gets ``1 - alpha``. ``None`` uses ``hparams["alpha"]``; an explicit
                ``0.0`` is honoured.

        Returns:
            Tuple[List[str], List[float]]: The filtered vocabulary and one score per entry.
        """
        # embed the image (the file is closed once the processor has consumed it)
        with Image.open(image_fp) as img:
            image = self.processor(images=img.convert("RGB"), return_tensors="pt")
        image["pixel_values"] = image["pixel_values"].to(DEVICE)
        image_z = self.vision_proj(self.vision_encoder(**image)[1])

        # retrieve captions and collapse them into a single (centroid) text embedding
        captions = self.query_index(image_z)
        text_z = None
        if captions:
            text_z = self._encode_text(captions).mean(dim=0, keepdim=True)
            text_z = text_z / text_z.norm(dim=-1, keepdim=True)

        # filter the captions into the candidate vocabulary and embed each entry separately
        vocabulary = self.vocabulary_transforms(captions) or ["object"]
        vocabulary_z = self._encode_text(vocabulary)

        # image-based prediction over the vocabulary
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
        image_p = (torch.matmul(image_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)

        # text-based prediction; with no retrieved captions only the image prediction is used
        if text_z is not None:
            text_p = (torch.matmul(text_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
        else:
            text_p = image_p

        # average the image and text predictions
        if alpha is None:
            alpha = self.hparams["alpha"]
        sample_p = alpha * image_p + (1 - alpha) * text_p

        # get the scores
        scores = sample_p[0].cpu().tolist()
        return vocabulary, scores