Load CLIP model from transformers
- requirements.txt +1 -2
- src/nn.py +35 -179
- src/retrieval.py +2 -2
requirements.txt CHANGED
@@ -6,5 +6,4 @@ gradio==3.33.1
 gdown==4.4.0
 inflect==6.0.4
 nltk==3.8.1
-
-transformers==4.26.1
+transformers==4.29.2
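For context, transformers 4.29.2 is the release the Space now pins, and it provides the CLIPModel and CLIPProcessor classes that src/nn.py relies on below. A quick sanity-check sketch (not part of the Space; the printed projection dimension is 768 for ViT-L/14):

# Sketch: verify the pinned transformers release and the CLIP classes used in src/nn.py.
import transformers
from packaging import version
from transformers import CLIPModel, CLIPProcessor

assert version.parse(transformers.__version__) >= version.parse("4.29.2")

# same checkpoint id that src/nn.py hard-codes
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
print(model.config.projection_dim)  # 768 for ViT-L/14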
src/nn.py CHANGED
@@ -6,10 +6,9 @@ from typing import Optional
 import faiss
 import gdown
 import numpy as np
-import open_clip
 import torch
-from open_clip.transformer import Transformer
 from PIL import Image
+from transformers import CLIPModel, CLIPProcessor
 
 from src.retrieval import ArrowMetadataProvider
 from src.transforms import TextCompose, default_vocabulary_transforms
@@ -28,73 +27,46 @@ class CaSED(torch.nn.Module):
     Args:
         index_name (str): Name of the faiss index to use.
         vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary.
-        model_name (str): Name of the CLIP model to use. Defaults to "ViT-L-14".
-        pretrained (str): Pretrained weights to use for the CLIP model. Defaults to "openai".
 
     Extra hparams:
         alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
         artifact_dir (str): Path to the directory where the databases are stored. Defaults to
             "artifacts/".
         retrieval_num_results (int): Number of results to return. Defaults to 10.
-        vocabulary_prompt (str): Prompt to use for the vocabulary. Defaults to "{}".
-        tau (float): Temperature to use for the classifier. Defaults to 1.0.
     """
 
     def __init__(
        self,
        index_name: str = "ViT-L-14_CC12M",
        vocabulary_transforms: TextCompose = default_vocabulary_transforms(),
-       model_name: str = "ViT-L-14",
-       pretrained: str = "openai",
-       vocabulary_prompt: str = "{}",
        **kwargs,
     ):
         super().__init__()
-        self._prev_vocab_words = None
-        self._prev_used_prompts = None
-        self._prev_vocab_words_z = None
-
-        model, _, preprocess = open_clip.create_model_and_transforms(
-            model_name, pretrained=pretrained, device="cpu"
-        )
-        tokenizer = open_clip.get_tokenizer(model_name)
-        self.tokenizer = tokenizer
-        self.preprocess = preprocess
 
+        # load CLIP
+        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
+        self.index_name = index_name
+        self.vocabulary_transforms = vocabulary_transforms
+        self.vision_encoder = model.vision_model
+        self.vision_proj = model.visual_projection
+        self.language_encoder = model.text_model
+        self.language_proj = model.text_projection
+        self.logit_scale = model.logit_scale.exp()
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+        # set hparams
         kwargs["alpha"] = kwargs.get("alpha", 0.5)
         kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/")
         kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10)
-        vocabulary_prompt = kwargs.get("vocabulary_prompt", "{}")
-        kwargs["vocabulary_prompts"] = [vocabulary_prompt]
-        kwargs["tau"] = kwargs.get("tau", 1.0)
         self.hparams = kwargs
 
-        language_encoder = LanguageTransformer(
-            model.transformer,
-            model.token_embedding,
-            model.positional_embedding,
-            model.ln_final,
-            model.text_projection,
-            model.attn_mask,
-        )
-        scale = model.logit_scale.exp().item()
-        classifier = NearestNeighboursClassifier(scale=scale, tau=self.hparams["tau"])
-
-        self.index_name = index_name
-        self.vocabulary_transforms = vocabulary_transforms
-        self.vision_encoder = model.visual
-        self.language_encoder = language_encoder
-        self.classifier = classifier
-
         # download databases
         self.prepare_data()
 
-        # load faiss indices
+        # load faiss indices and metadata providers
         indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
         indices_fp = indices_list_dir / "indices.json"
         self.indices = json.load(open(indices_fp))
-
-        # load faiss indices and metadata providers
         self.resources = {}
         for name, index_fp in self.indices.items():
             text_index_fp = Path(index_fp) / "text.index"
@@ -107,7 +79,7 @@ class CaSED(torch.nn.Module):
 
             self.resources[name] = {
                 "device": DEVICE,
-                "model":
+                "model": "ViT-L-14",
                 "text_index": text_index,
                 "metadata_provider": metadata_provider,
             }
@@ -175,156 +147,40 @@ class CaSED(torch.nn.Module):
 
         return vocabularies
 
-    @torch.no_grad()
-    def encode_vocabulary(self, vocabulary: list, use_prompts: bool = False) -> torch.Tensor:
-        """Encode a vocabulary.
-
-        Args:
-            vocabulary (list): List of words.
-        """
-        # check if vocabulary has changed
-        if vocabulary == self._prev_vocab_words and use_prompts == self._prev_used_prompts:
-            return self._prev_vocab_words_z
-
-        # tokenize vocabulary
-        classes = [c.replace("_", " ") for c in vocabulary]
-        prompts = self.hparams["vocabulary_prompts"] if use_prompts else ["{}"]
-        texts_views = [[p.format(c) for c in classes] for p in prompts]
-        tokenized_texts_views = [
-            torch.cat([self.tokenizer(prompt) for prompt in class_prompts])
-            for class_prompts in texts_views
-        ]
-        tokenized_texts_views = torch.stack(tokenized_texts_views).to(DEVICE)
-
-        # encode vocabulary
-        T, C, _ = tokenized_texts_views.shape
-        texts_z_views = self.language_encoder(tokenized_texts_views.view(T * C, -1))
-        texts_z_views = texts_z_views.view(T, C, -1)
-        texts_z_views = texts_z_views / texts_z_views.norm(dim=-1, keepdim=True)
-
-        # cache vocabulary
-        self._prev_vocab_words = vocabulary
-        self._prev_used_prompts = use_prompts
-        self._prev_vocab_words_z = texts_z_views
-
-        return texts_z_views
-
     @torch.no_grad()
     def forward(self, image_fp: str, alpha: Optional[float] = None) -> torch.Tensor():
-        ...
-        vocabulary = self.query_index(image_z)
+        # forward the image
+        image = self.processor(images=Image.open(image_fp), return_tensors="pt")
+        image["pixel_values"] = image["pixel_values"].to(DEVICE)
+        image_z = self.vision_proj(self.vision_encoder(**image)[1])
 
         # generate a single text embedding from the unfiltered vocabulary
-        ...
+        vocabulary = self.query_index(image_z)
+        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
+        text["input_ids"] = text["input_ids"][:, :77].to(DEVICE)
+        text["attention_mask"] = text["attention_mask"][:, :77].to(DEVICE)
+        text_z = self.language_encoder(**text)[1]
+        text_z = self.language_proj(text_z)
 
         # filter the vocabulary, embed it, and get its mean embedding
         vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
-        ...
+        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
+        text = {k: v.to(DEVICE) for k, v in text.items()}
+        vocabulary_z = self.language_encoder(**text)[1]
+        vocabulary_z = self.language_proj(vocabulary_z)
+        vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)
 
         # get the image and text predictions
-        ...
+        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
+        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+        image_p = (torch.matmul(image_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
+        text_p = (torch.matmul(text_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
 
         # average the image and text predictions
         alpha = alpha or self.hparams["alpha"]
         sample_p = alpha * image_p + (1 - alpha) * text_p
 
         # get the scores
-        ...
-        scores = sample_p[0].tolist()
-
-        del image_z, unfiltered_vocabulary_z, text_z, vocabulary_z, mean_vocabulary_z
-        del image_p, text_p, sample_p
+        scores = sample_p[0].cpu().tolist()
 
         return vocabulary, scores
-
-
-class NearestNeighboursClassifier(torch.nn.Module):
-    """Nearest neighbours classifier.
-
-    It computes the similarity between the query and the supports using the
-    cosine similarity and then applies a softmax to obtain the logits.
-
-    Args:
-        scale (float): Scale for the logits of the query. Defaults to 1.0.
-        tau (float): Temperature for the softmax. Defaults to 1.0.
-    """
-
-    def __init__(self, scale: float = 1.0, tau: float = 1.0):
-        super().__init__()
-        self.scale = scale
-        self.tau = tau
-
-    def forward(self, query: torch.Tensor, supports: torch.Tensor):
-        query = query / query.norm(dim=-1, keepdim=True)
-        supports = supports / supports.norm(dim=-1, keepdim=True)
-
-        if supports.dim() == 2:
-            supports = supports.unsqueeze(0)
-
-        Q, _ = query.shape
-        N, C, _ = supports.shape
-
-        supports = supports.mean(dim=0)
-        supports = supports / supports.norm(dim=-1, keepdim=True)
-        similarity = self.scale * query @ supports.T
-        similarity = similarity / self.tau if self.tau != 1.0 else similarity
-        logits = similarity.softmax(dim=-1)
-
-        return logits
-
-
-class LanguageTransformer(torch.nn.Module):
-    """Language Transformer for CLIP.
-
-    Args:
-        transformer (Transformer): Transformer model.
-        token_embedding (torch.nn.Embedding): Token embedding.
-        positional_embedding (torch.nn.Parameter): Positional embedding.
-        ln_final (torch.nn.LayerNorm): Layer norm.
-        text_projection (torch.nn.Parameter): Text projection.
-    """
-
-    def __init__(
-        self,
-        model: Transformer,
-        token_embedding: torch.nn.Embedding,
-        positional_embedding: torch.nn.Parameter,
-        ln_final: torch.nn.LayerNorm,
-        text_projection: torch.nn.Parameter,
-        attn_mask: torch.Tensor,
-    ):
-        super().__init__()
-        self.transformer = model
-        self.token_embedding = token_embedding
-        self.positional_embedding = positional_embedding
-        self.ln_final = ln_final
-        self.text_projection = text_projection
-
-        self.register_buffer("attn_mask", attn_mask, persistent=False)
-
-    def forward(self, text: torch.Tensor) -> torch.Tensor:
-        cast_dtype = self.transformer.get_cast_dtype()
-
-        """Forward pass for the text encoder."""
-        x = self.token_embedding(text).to(cast_dtype)
-
-        x = x + self.positional_embedding.to(cast_dtype)
-        x = x.permute(1, 0, 2)
-        x = self.transformer(x, attn_mask=self.attn_mask)
-        x = x.permute(1, 0, 2)
-        x = self.ln_final(x)
-
-        # x.shape = [batch_size, n_ctx, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
-
-        return x
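For readers unfamiliar with the transformers CLIP API, here is a minimal, self-contained sketch of the encoding-and-scoring pattern the new forward() follows. The image path, candidate vocabulary, and alpha value are placeholders for illustration; in the Space the vocabulary comes from the faiss retrieval step.

# Sketch of the transformers-based CLIP scoring used by the new CaSED.forward().
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

vocabulary = ["a photo of a dog", "a photo of a cat", "a photo of a car"]  # placeholder candidates
alpha = 0.5  # weight of the image prediction, matching the hparams default

with torch.no_grad():
    # image embedding: vision tower pooled output, then the visual projection
    image_inputs = processor(images=Image.open("example.jpg"), return_tensors="pt").to(DEVICE)
    image_z = model.visual_projection(model.vision_model(**image_inputs).pooler_output)

    # text embeddings: text tower pooled output, then the text projection (CLIP caps inputs at 77 tokens)
    text_inputs = processor(
        text=vocabulary, return_tensors="pt", padding=True, truncation=True, max_length=77
    ).to(DEVICE)
    vocabulary_z = model.text_projection(model.text_model(**text_inputs).pooler_output)

    # cosine similarity scaled by the learned logit scale, softmaxed over the candidates
    image_z = image_z / image_z.norm(dim=-1, keepdim=True)
    vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)
    image_p = (image_z @ vocabulary_z.T * model.logit_scale.exp()).softmax(dim=-1)

    # CaSED additionally embeds the unfiltered retrieved vocabulary as a text-side query and blends:
    # sample_p = alpha * image_p + (1 - alpha) * text_p
    print(dict(zip(vocabulary, image_p[0].cpu().tolist())))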
src/retrieval.py CHANGED
@@ -11,8 +11,8 @@ class ArrowMetadataProvider:
     Code taken from: https://github.dev/rom1504/clip-retrieval
     """
 
-    def __init__(self, arrow_folder: ...
-        arrow_files = [str(a) for a in sorted(...
+    def __init__(self, arrow_folder: Path):
+        arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
         self.table = pa.concat_tables(
             [
                 pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
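For reference, a small standalone sketch of how the updated ArrowMetadataProvider loads its metadata with pyarrow. The folder path below is an assumption for illustration; the Space resolves the real paths from artifacts/models/retrieval/indices.json.

# Sketch: read and concatenate the Arrow metadata files the way the fixed __init__ does.
from pathlib import Path

import pyarrow as pa


def load_arrow_table(arrow_folder: Path) -> pa.Table:
    # collect every file under the folder, then concatenate all record batches into one table
    arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
    return pa.concat_tables(
        [
            pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
            for arrow_file in arrow_files
        ]
    )


table = load_arrow_table(Path("artifacts/models/retrieval/example_index/metadata"))  # hypothetical path
print(table.num_rows, table.schema.names)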