import json
import tarfile
from pathlib import Path
from typing import List, Optional, Tuple

import faiss
import gdown
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from src.retrieval import ArrowMetadataProvider
from src.transforms import TextCompose, default_vocabulary_transforms

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


RETRIEVAL_DATABASES = {
    "cc12m": "https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t",
}


class CaSED(torch.nn.Module):
    """Torch module for Category Search from External Databases (CaSED).

    Args:
        index_name (str): Name of the faiss index to use.
        vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary.

    Extra hparams:
        alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
        artifact_dir (str): Path to the directory where the databases are stored. Defaults to
            "artifacts/".
        retrieval_num_results (int): Number of results to return. Defaults to 10.
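
    Example:
        A minimal usage sketch (the image path is a placeholder; the retrieval
        database is downloaded to artifact_dir on first instantiation):

        >>> model = CaSED(index_name="ViT-L-14_CC12M")
        >>> vocabulary, scores = model("path/to/image.jpg", alpha=0.5)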
    """

    def __init__(
        self,
        index_name: str = "ViT-L-14_CC12M",
        vocabulary_transforms: TextCompose = default_vocabulary_transforms(),
        **kwargs,
    ):
        super().__init__()

        # load CLIP
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
        self.index_name = index_name
        self.vocabulary_transforms = vocabulary_transforms
        self.vision_encoder = model.vision_model
        self.vision_proj = model.visual_projection
        self.language_encoder = model.text_model
        self.language_proj = model.text_projection
        self.logit_scale = model.logit_scale.exp()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

        # set hparams
        kwargs["alpha"] = kwargs.get("alpha", 0.5)
        kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/")
        kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10)
        self.hparams = kwargs

        # download databases
        self.prepare_data()

        # load faiss indices and metadata providers
        indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
        indices_fp = indices_list_dir / "indices.json"
        with open(indices_fp) as fp:
            self.indices = json.load(fp)
        self.resources = {}
        for name, index_fp in self.indices.items():
            text_index_fp = Path(index_fp) / "text.index"
            metadata_fp = Path(index_fp) / "metadata/"

            text_index = faiss.read_index(
                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            metadata_provider = ArrowMetadataProvider(metadata_fp)

            self.resources[name] = {
                "device": DEVICE,
                "model": "ViT-L-14",
                "text_index": text_index,
                "metadata_provider": metadata_provider,
            }

    def prepare_data(self):
        """Download data if needed."""
        databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases"
        # make sure the target directory exists before downloading
        databases_path.mkdir(parents=True, exist_ok=True)

        for name, url in RETRIEVAL_DATABASES.items():
            database_path = Path(databases_path, name)
            if database_path.exists():
                continue

            # download data
            target_path = Path(databases_path, name + ".tar.gz")
            try:
                gdown.download(url, str(target_path), quiet=False)
                with tarfile.open(target_path, "r:gz") as tar:
                    tar.extractall(target_path.parent)
                target_path.unlink()
            except FileNotFoundError:
                print(f"Could not download {url}.")
                print(f"Please download it manually and place it in {target_path.parent}.")

    @torch.no_grad()
    def query_index(self, sample_z: torch.Tensor) -> List[str]:
        """Retrieve the captions of the nearest neighbours of the image embedding."""
        # get the index
        resources = self.resources[self.index_name]
        text_index = resources["text_index"]
        metadata_provider = resources["metadata_provider"]

        # query the index
        sample_z = sample_z.squeeze(0)
        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
        query_input = sample_z.cpu().detach().numpy().tolist()
        query = np.expand_dims(np.array(query_input).astype("float32"), 0)

        distances, idxs, _ = text_index.search_and_reconstruct(
            query, self.hparams["retrieval_num_results"]
        )
        results = idxs[0]
        nb_results = np.where(results == -1)[0]
        nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
        indices = results[:nb_results]
        distances = distances[0][:nb_results]

        if len(distances) == 0:
            return []

        # get the metadata
        results = []
        metadata = metadata_provider.get(indices[:20], ["caption"])
        for key, (d, i) in enumerate(zip(distances, indices)):
            output = {}
            meta = None if key + 1 > len(metadata) else metadata[key]
            if meta is not None:
                output.update(meta)
            output["id"] = i.item()
            output["similarity"] = d.item()
            results.append(output)

        # get the captions only
        vocabularies = [result["caption"] for result in results]

        return vocabularies

    @torch.no_grad()
    def forward(
        self, image_fp: str, alpha: Optional[float] = None
    ) -> Tuple[List[str], List[float]]:
        """Classify an image by retrieving a candidate vocabulary and scoring it."""
        # forward the image
        image = self.processor(images=Image.open(image_fp), return_tensors="pt")
        image["pixel_values"] = image["pixel_values"].to(DEVICE)
        image_z = self.vision_proj(self.vision_encoder(**image)[1])

        # generate a single text embedding from the unfiltered vocabulary
        vocabulary = self.query_index(image_z)
        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
        text["input_ids"] = text["input_ids"][:, :77].to(DEVICE)
        text["attention_mask"] = text["attention_mask"][:, :77].to(DEVICE)
        text_z = self.language_encoder(**text)[1]
        text_z = self.language_proj(text_z)

        # filter the vocabulary, embed it, and get its mean embedding
        vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
        text = {k: v.to(DEVICE) for k, v in text.items()}
        vocabulary_z = self.language_encoder(**text)[1]
        vocabulary_z = self.language_proj(vocabulary_z)
        vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)

        # get the image and text predictions
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        image_p = (torch.matmul(image_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
        text_p = (torch.matmul(text_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)

        # average the image and text predictions
        alpha = alpha or self.hparams["alpha"]
        sample_p = alpha * image_p + (1 - alpha) * text_p

        # get the scores
        scores = sample_p[0].cpu().tolist()

        return vocabulary, scores
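

if __name__ == "__main__":
    # Minimal demo sketch (not part of the original API): the image path below is a
    # placeholder, and the retrieval database is assumed to be downloadable to "artifacts/".
    import sys

    image_path = sys.argv[1] if len(sys.argv) > 1 else "path/to/image.jpg"

    model = CaSED()
    candidates, candidate_scores = model(image_path)

    # print the candidate categories ranked by their fused image/text score
    ranked = sorted(zip(candidates, candidate_scores), key=lambda pair: pair[1], reverse=True)
    for name, score in ranked[:5]:
        print(f"{score:.4f}\t{name}")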