broadwell committed on
Commit
a6b26e3
·
verified ·
1 Parent(s): 02ab20e

Add viz/explanation feature for image and text activations

CLIP_Explainability/README.md ADDED
@@ -0,0 +1,43 @@
1
+ # CLIP Explainability
2
+
3
+ This repo contains the code for the [CLIP Explainability project](CLIP_Explainability.pdf).
4
+ In this project, we conduct an in-depth study of CLIP’s learned image and text representations using saliency map visualization. We propose a modification to the existing saliency visualization method that improves its performance, as shown by our qualitative evaluations. We then use this method to study CLIP’s ability to capture similarities and dissimilarities between an input image and targets belonging to different domains, including image, text, and emotion.
5
+
6
+ ## Setup
7
+
8
+ To install the required libraries run the following command:
9
+
10
+ ```
11
+ pip install -r requirements.txt
12
+
13
+ ```
14
+
15
+ ## Organization
16
+
17
+ The [code](code) directory contains:
18
+
19
+ - The implementation of the saliency visualization methods for [ViT](code/vit_cam.py)- and [ResNet](code/rn_cam.py)-based CLIP
20
+ - [GradCAM](code/pytorch-grad-cam) implementation based on [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam/tree/e93f41104e20134e5feac2a660b343437f601ad0) slightly modified to adapt to CLIP.
21
+ - A re-implementation of CLIP taken from the [Transformer-MM-Explainability](https://github.com/hila-chefer/Transformer-MM-Explainability) repo that keeps track of attention maps and gradients: [clip_.py](code/clip_.py)
22
+ - [Notebooks](code/notebooks/) for the experiments explained in the report
23
+
24
+
25
+ [Images](Images) contains images used in the experiments.
26
+
27
+ [results](results) contains the results obtained from the experiments. Any result generated by the notebooks will be stored in this directory.
28
+
29
+
30
+ ## Experiments
31
+
32
+
33
+ | Notebook Name | Experiment | Note |
34
+ | ------------- | ------------- | ------------- |
35
+ | [vit_block_vis](code/notebooks/vit_block_vis.ipynb) | Layer-wise Attention Visualization | - |
36
+ | [saliency_method_compare](code/notebooks/saliency_method_compare.ipynb) | ViT Explainability Method Comparison | Qualitative comparison |
37
+ | [affectnet_emotions](code/notebooks/affectnet_emotions.ipynb) | ViT Explainability Method Comparison | Bias comparison. Requires a sample of the AffectNet dataset: download it [here](https://drive.google.com/drive/u/1/folders/11RusPab71wGw6LTd9pUnY1Gz3JSH-N_N) and place it in [Images](Images). |
38
+ | [pos_neg_vis](code/notebooks/pos_neg_vis.ipynb) | Positive vs Negative Saliency | - |
39
+ | [artemis_emotions](code/notebooks/artemis_emotions.ipynb) | Emotion-Image Similarity | Requires the pre-processed WikiArt images: download them [here](https://drive.google.com/drive/u/1/folders/11RusPab71wGw6LTd9pUnY1Gz3JSH-N_N) and place them in [Images](Images). This notebook chooses images at random, so the results may differ from those in the report. |
40
+ | [perword_vis](code/notebooks/perword_vis.ipynb) | Word-Wise Saliency Visualization | - |
41
+ | [global_vis](code/notebooks/global_vis.ipynb) | - | Can be used to visualize saliency maps for ViT- and ResNet-based CLIP. |
42
+
43
+
CLIP_Explainability/auxilary.py ADDED
@@ -0,0 +1,532 @@
1
+ import torch
2
+ import warnings
3
+ from typing import Tuple, Optional
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ from torch.nn.init import xavier_uniform_
8
+ from torch.nn.init import constant_
9
+ from torch.nn.init import xavier_normal_
10
+ from torch.nn.parameter import Parameter
11
+ from torch.nn import functional as F
12
+
13
+ # Upstream PyTorch defines this function as _pad because it takes an argument
14
+ # named pad, which clobbers the recursive reference to the pad
15
+ # function needed for __torch_function__ support. Here it is a plain alias of F.pad.
16
+ pad = F.pad
17
+
18
+
19
+ # This class exists solely for Transformer; it has an annotation stating
20
+ # that bias is never None, which appeases TorchScript
21
+ class _LinearWithBias(torch.nn.Linear):
22
+ bias: Tensor
23
+
24
+ def __init__(self, in_features: int, out_features: int) -> None:
25
+ super().__init__(in_features, out_features, bias=True)
26
+
27
+
28
+ def multi_head_attention_forward(
29
+ query: Tensor,
30
+ key: Tensor,
31
+ value: Tensor,
32
+ embed_dim_to_check: int,
33
+ num_heads: int,
34
+ in_proj_weight: Tensor,
35
+ in_proj_bias: Tensor,
36
+ bias_k: Optional[Tensor],
37
+ bias_v: Optional[Tensor],
38
+ add_zero_attn: bool,
39
+ dropout_p: float,
40
+ out_proj_weight: Tensor,
41
+ out_proj_bias: Tensor,
42
+ training: bool = True,
43
+ key_padding_mask: Optional[Tensor] = None,
44
+ need_weights: bool = True,
45
+ attn_mask: Optional[Tensor] = None,
46
+ use_separate_proj_weight: bool = False,
47
+ q_proj_weight: Optional[Tensor] = None,
48
+ k_proj_weight: Optional[Tensor] = None,
49
+ v_proj_weight: Optional[Tensor] = None,
50
+ static_k: Optional[Tensor] = None,
51
+ static_v: Optional[Tensor] = None,
52
+ attention_probs_forward_hook=None,
53
+ attention_probs_backwards_hook=None,
54
+ ) -> Tuple[Tensor, Optional[Tensor]]:
55
+ if not torch.jit.is_scripting():
56
+ tens_ops = (
57
+ query,
58
+ key,
59
+ value,
60
+ in_proj_weight,
61
+ in_proj_bias,
62
+ bias_k,
63
+ bias_v,
64
+ out_proj_weight,
65
+ out_proj_bias,
66
+ )
67
+ if any([type(t) is not Tensor for t in tens_ops]) and F.has_torch_function(
68
+ tens_ops
69
+ ):
70
+ return F.handle_torch_function(
71
+ multi_head_attention_forward,
72
+ tens_ops,
73
+ query,
74
+ key,
75
+ value,
76
+ embed_dim_to_check,
77
+ num_heads,
78
+ in_proj_weight,
79
+ in_proj_bias,
80
+ bias_k,
81
+ bias_v,
82
+ add_zero_attn,
83
+ dropout_p,
84
+ out_proj_weight,
85
+ out_proj_bias,
86
+ training=training,
87
+ key_padding_mask=key_padding_mask,
88
+ need_weights=need_weights,
89
+ attn_mask=attn_mask,
90
+ use_separate_proj_weight=use_separate_proj_weight,
91
+ q_proj_weight=q_proj_weight,
92
+ k_proj_weight=k_proj_weight,
93
+ v_proj_weight=v_proj_weight,
94
+ static_k=static_k,
95
+ static_v=static_v,
96
+ )
97
+ tgt_len, bsz, embed_dim = query.size()
98
+ assert embed_dim == embed_dim_to_check
99
+ # allow MHA to have different sizes for the feature dimension
100
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
101
+
102
+ head_dim = embed_dim // num_heads
103
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
104
+ scaling = float(head_dim) ** -0.5
105
+
106
+ if not use_separate_proj_weight:
107
+ if torch.equal(query, key) and torch.equal(key, value):
108
+ # self-attention
109
+ q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
110
+
111
+ elif torch.equal(key, value):
112
+ # encoder-decoder attention
113
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
114
+ _b = in_proj_bias
115
+ _start = 0
116
+ _end = embed_dim
117
+ _w = in_proj_weight[_start:_end, :]
118
+ if _b is not None:
119
+ _b = _b[_start:_end]
120
+ q = F.linear(query, _w, _b)
121
+
122
+ if key is None:
123
+ assert value is None
124
+ k = None
125
+ v = None
126
+ else:
127
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
128
+ _b = in_proj_bias
129
+ _start = embed_dim
130
+ _end = None
131
+ _w = in_proj_weight[_start:, :]
132
+ if _b is not None:
133
+ _b = _b[_start:]
134
+ k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
135
+
136
+ else:
137
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
138
+ _b = in_proj_bias
139
+ _start = 0
140
+ _end = embed_dim
141
+ _w = in_proj_weight[_start:_end, :]
142
+ if _b is not None:
143
+ _b = _b[_start:_end]
144
+ q = F.linear(query, _w, _b)
145
+
146
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
147
+ _b = in_proj_bias
148
+ _start = embed_dim
149
+ _end = embed_dim * 2
150
+ _w = in_proj_weight[_start:_end, :]
151
+ if _b is not None:
152
+ _b = _b[_start:_end]
153
+ k = F.linear(key, _w, _b)
154
+
155
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
156
+ _b = in_proj_bias
157
+ _start = embed_dim * 2
158
+ _end = None
159
+ _w = in_proj_weight[_start:, :]
160
+ if _b is not None:
161
+ _b = _b[_start:]
162
+ v = F.linear(value, _w, _b)
163
+ else:
164
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
165
+ len1, len2 = q_proj_weight_non_opt.size()
166
+ assert len1 == embed_dim and len2 == query.size(-1)
167
+
168
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
169
+ len1, len2 = k_proj_weight_non_opt.size()
170
+ assert len1 == embed_dim and len2 == key.size(-1)
171
+
172
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
173
+ len1, len2 = v_proj_weight_non_opt.size()
174
+ assert len1 == embed_dim and len2 == value.size(-1)
175
+
176
+ if in_proj_bias is not None:
177
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
178
+ k = F.linear(
179
+ key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)]
180
+ )
181
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
182
+ else:
183
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
184
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
185
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
186
+ q = q * scaling
187
+
188
+ if attn_mask is not None:
189
+ assert (
190
+ attn_mask.dtype == torch.float32
191
+ or attn_mask.dtype == torch.float64
192
+ or attn_mask.dtype == torch.float16
193
+ or attn_mask.dtype == torch.uint8
194
+ or attn_mask.dtype == torch.bool
195
+ ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
196
+ attn_mask.dtype
197
+ )
198
+ if attn_mask.dtype == torch.uint8:
199
+ warnings.warn(
200
+ "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
201
+ )
202
+ attn_mask = attn_mask.to(torch.bool)
203
+
204
+ if attn_mask.dim() == 2:
205
+ attn_mask = attn_mask.unsqueeze(0)
206
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
207
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
208
+ elif attn_mask.dim() == 3:
209
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
210
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
211
+ else:
212
+ raise RuntimeError(
213
+ "attn_mask's dimension {} is not supported".format(attn_mask.dim())
214
+ )
215
+ # attn_mask's dim is 3 now.
216
+
217
+ # convert ByteTensor key_padding_mask to bool
218
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
219
+ warnings.warn(
220
+ "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
221
+ )
222
+ key_padding_mask = key_padding_mask.to(torch.bool)
223
+
224
+ if bias_k is not None and bias_v is not None:
225
+ if static_k is None and static_v is None:
226
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
227
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
228
+ if attn_mask is not None:
229
+ attn_mask = pad(attn_mask, (0, 1))
230
+ if key_padding_mask is not None:
231
+ key_padding_mask = pad(key_padding_mask, (0, 1))
232
+ else:
233
+ assert static_k is None, "bias cannot be added to static key."
234
+ assert static_v is None, "bias cannot be added to static value."
235
+ else:
236
+ assert bias_k is None
237
+ assert bias_v is None
238
+
239
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
240
+ if k is not None:
241
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
242
+ if v is not None:
243
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
244
+
245
+ if static_k is not None:
246
+ assert static_k.size(0) == bsz * num_heads
247
+ assert static_k.size(2) == head_dim
248
+ k = static_k
249
+
250
+ if static_v is not None:
251
+ assert static_v.size(0) == bsz * num_heads
252
+ assert static_v.size(2) == head_dim
253
+ v = static_v
254
+
255
+ src_len = k.size(1)
256
+
257
+ if key_padding_mask is not None:
258
+ assert key_padding_mask.size(0) == bsz
259
+ assert key_padding_mask.size(1) == src_len
260
+
261
+ if add_zero_attn:
262
+ src_len += 1
263
+ k = torch.cat(
264
+ [
265
+ k,
266
+ torch.zeros(
267
+ (k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device
268
+ ),
269
+ ],
270
+ dim=1,
271
+ )
272
+ v = torch.cat(
273
+ [
274
+ v,
275
+ torch.zeros(
276
+ (v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device
277
+ ),
278
+ ],
279
+ dim=1,
280
+ )
281
+ if attn_mask is not None:
282
+ attn_mask = pad(attn_mask, (0, 1))
283
+ if key_padding_mask is not None:
284
+ key_padding_mask = pad(key_padding_mask, (0, 1))
285
+
286
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
287
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
288
+
289
+ if attn_mask is not None:
290
+ if attn_mask.dtype == torch.bool:
291
+ attn_output_weights.masked_fill_(attn_mask, float("-inf"))
292
+ else:
293
+ attn_output_weights += attn_mask
294
+
295
+ if key_padding_mask is not None:
296
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
297
+ attn_output_weights = attn_output_weights.masked_fill(
298
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
299
+ float("-inf"),
300
+ )
301
+ attn_output_weights = attn_output_weights.view(
302
+ bsz * num_heads, tgt_len, src_len
303
+ )
304
+
305
+ attn_output_weights = F.softmax(attn_output_weights, dim=-1)
306
+ attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
307
+
308
+ # use hooks for the attention weights if necessary
309
+ if (
310
+ attention_probs_forward_hook is not None
311
+ and attention_probs_backwards_hook is not None
312
+ ):
313
+ attention_probs_forward_hook(attn_output_weights)
314
+ attn_output_weights.register_hook(attention_probs_backwards_hook)
315
+
316
+ attn_output = torch.bmm(attn_output_weights, v)
317
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
318
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
319
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
320
+
321
+ if need_weights:
322
+ # average attention weights over heads
323
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
324
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
325
+ else:
326
+ return attn_output, None
327
+
328
+
329
+ class MultiheadAttention(torch.nn.Module):
330
+ r"""Allows the model to jointly attend to information
331
+ from different representation subspaces.
332
+ See reference: Attention Is All You Need
333
+
334
+ .. math::
335
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
336
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
337
+
338
+ Args:
339
+ embed_dim: total dimension of the model.
340
+ num_heads: parallel attention heads.
341
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
342
+ bias: add bias as module parameter. Default: True.
343
+ add_bias_kv: add bias to the key and value sequences at dim=0.
344
+ add_zero_attn: add a new batch of zeros to the key and
345
+ value sequences at dim=1.
346
+ kdim: total number of features in key. Default: None.
347
+ vdim: total number of features in value. Default: None.
348
+
349
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
350
+ query, key, and value have the same number of features.
351
+
352
+ Examples::
353
+
354
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
355
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
356
+ """
357
+
358
+ bias_k: Optional[torch.Tensor]
359
+ bias_v: Optional[torch.Tensor]
360
+
361
+ def __init__(
362
+ self,
363
+ embed_dim,
364
+ num_heads,
365
+ dropout=0.0,
366
+ bias=True,
367
+ add_bias_kv=False,
368
+ add_zero_attn=False,
369
+ kdim=None,
370
+ vdim=None,
371
+ ):
372
+ super(MultiheadAttention, self).__init__()
373
+ self.embed_dim = embed_dim
374
+ self.kdim = kdim if kdim is not None else embed_dim
375
+ self.vdim = vdim if vdim is not None else embed_dim
376
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
377
+
378
+ self.num_heads = num_heads
379
+ self.dropout = dropout
380
+ self.head_dim = embed_dim // num_heads
381
+ assert (
382
+ self.head_dim * num_heads == self.embed_dim
383
+ ), "embed_dim must be divisible by num_heads"
384
+
385
+ if self._qkv_same_embed_dim is False:
386
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
387
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
388
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
389
+ self.register_parameter("in_proj_weight", None)
390
+ else:
391
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
392
+ self.register_parameter("q_proj_weight", None)
393
+ self.register_parameter("k_proj_weight", None)
394
+ self.register_parameter("v_proj_weight", None)
395
+
396
+ if bias:
397
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
398
+ else:
399
+ self.register_parameter("in_proj_bias", None)
400
+ self.out_proj = _LinearWithBias(embed_dim, embed_dim)
401
+
402
+ if add_bias_kv:
403
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
404
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
405
+ else:
406
+ self.bias_k = self.bias_v = None
407
+
408
+ self.add_zero_attn = add_zero_attn
409
+
410
+ self._reset_parameters()
411
+
412
+ def _reset_parameters(self):
413
+ if self._qkv_same_embed_dim:
414
+ xavier_uniform_(self.in_proj_weight)
415
+ else:
416
+ xavier_uniform_(self.q_proj_weight)
417
+ xavier_uniform_(self.k_proj_weight)
418
+ xavier_uniform_(self.v_proj_weight)
419
+
420
+ if self.in_proj_bias is not None:
421
+ constant_(self.in_proj_bias, 0.0)
422
+ constant_(self.out_proj.bias, 0.0)
423
+ if self.bias_k is not None:
424
+ xavier_normal_(self.bias_k)
425
+ if self.bias_v is not None:
426
+ xavier_normal_(self.bias_v)
427
+
428
+ def __setstate__(self, state):
429
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
430
+ if "_qkv_same_embed_dim" not in state:
431
+ state["_qkv_same_embed_dim"] = True
432
+
433
+ super(MultiheadAttention, self).__setstate__(state)
434
+
435
+ def forward(
436
+ self,
437
+ query,
438
+ key,
439
+ value,
440
+ key_padding_mask=None,
441
+ need_weights=True,
442
+ attn_mask=None,
443
+ attention_probs_forward_hook=None,
444
+ attention_probs_backwards_hook=None,
445
+ ):
446
+ r"""
447
+ Args:
448
+ query, key, value: map a query and a set of key-value pairs to an output.
449
+ See "Attention Is All You Need" for more details.
450
+ key_padding_mask: if provided, specified padding elements in the key will
451
+ be ignored by the attention. When given a binary mask and a value is True,
452
+ the corresponding value on the attention layer will be ignored. When given
453
+ a byte mask and a value is non-zero, the corresponding value on the attention
454
+ layer will be ignored
455
+ need_weights: output attn_output_weights.
456
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
457
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
458
+
459
+ Shape:
460
+ - Inputs:
461
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
462
+ the embedding dimension.
463
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
464
+ the embedding dimension.
465
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
466
+ the embedding dimension.
467
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
468
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
469
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
470
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
471
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
472
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
473
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
474
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
475
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
476
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
477
+ is provided, it will be added to the attention weight.
478
+
479
+ - Outputs:
480
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
481
+ E is the embedding dimension.
482
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
483
+ L is the target sequence length, S is the source sequence length.
484
+ """
485
+ if not self._qkv_same_embed_dim:
486
+ return multi_head_attention_forward(
487
+ query,
488
+ key,
489
+ value,
490
+ self.embed_dim,
491
+ self.num_heads,
492
+ self.in_proj_weight,
493
+ self.in_proj_bias,
494
+ self.bias_k,
495
+ self.bias_v,
496
+ self.add_zero_attn,
497
+ self.dropout,
498
+ self.out_proj.weight,
499
+ self.out_proj.bias,
500
+ training=self.training,
501
+ key_padding_mask=key_padding_mask,
502
+ need_weights=need_weights,
503
+ attn_mask=attn_mask,
504
+ use_separate_proj_weight=True,
505
+ q_proj_weight=self.q_proj_weight,
506
+ k_proj_weight=self.k_proj_weight,
507
+ v_proj_weight=self.v_proj_weight,
508
+ attention_probs_forward_hook=attention_probs_forward_hook,
509
+ attention_probs_backwards_hook=attention_probs_backwards_hook,
510
+ )
511
+ else:
512
+ return multi_head_attention_forward(
513
+ query,
514
+ key,
515
+ value,
516
+ self.embed_dim,
517
+ self.num_heads,
518
+ self.in_proj_weight,
519
+ self.in_proj_bias,
520
+ self.bias_k,
521
+ self.bias_v,
522
+ self.add_zero_attn,
523
+ self.dropout,
524
+ self.out_proj.weight,
525
+ self.out_proj.bias,
526
+ training=self.training,
527
+ key_padding_mask=key_padding_mask,
528
+ need_weights=need_weights,
529
+ attn_mask=attn_mask,
530
+ attention_probs_forward_hook=attention_probs_forward_hook,
531
+ attention_probs_backwards_hook=attention_probs_backwards_hook,
532
+ )
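Review note on `auxilary.py`: the core of `multi_head_attention_forward` above is scaling the queries by `head_dim ** -0.5`, filling boolean-masked positions with `-inf`, applying a softmax over the source positions, and (when `need_weights` is set) averaging the weights over heads. The following is a minimal plain-Python sketch of just that arithmetic, assuming tiny list-of-lists matrices instead of tensors. The helper names `softmax`, `attention_weights`, and `average_heads` are hypothetical and not part of this commit; dropout, the projection layers, and the forward/backward hooks are deliberately left out.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax; -inf entries receive zero weight.

    A row that is entirely -inf is not handled in this sketch.
    """
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(
    q: Matrix, k: Matrix, mask: Optional[List[List[bool]]] = None
) -> Matrix:
    """Weights for one head: softmax((q * head_dim ** -0.5) @ k^T).

    A True entry in `mask` blocks that (target, source) pair.
    """
    scaling = len(q[0]) ** -0.5
    weights = []
    for i, q_row in enumerate(q):
        scores = [scaling * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        if mask is not None:
            scores = [float("-inf") if mask[i][j] else s for j, s in enumerate(scores)]
        weights.append(softmax(scores))
    return weights


def average_heads(per_head: List[Matrix]) -> Matrix:
    """Average the per-head weights, mirroring the need_weights branch."""
    n = len(per_head)
    rows, cols = len(per_head[0]), len(per_head[0][0])
    return [
        [sum(head[r][c] for head in per_head) / n for c in range(cols)]
        for r in range(rows)
    ]


# Two heads, two positions, and a causal-style mask (position 0 cannot see position 1).
causal = [[False, True], [False, False]]
keys = [[1.0, 0.0], [0.0, 1.0]]
head_a = attention_weights([[1.0, 0.0], [0.0, 1.0]], keys, causal)
head_b = attention_weights([[0.0, 1.0], [1.0, 0.0]], keys, causal)
averaged = average_heads([head_a, head_b])
```

In the committed function the hooks see the per-head weights before this averaging step, which is presumably why the explainability code registers hooks rather than relying on the returned, head-averaged tensor.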
CLIP_Explainability/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
CLIP_Explainability/clip_.py ADDED
@@ -0,0 +1,305 @@
1
+ """
2
+ taken from https://github.com/hila-chefer/Transformer-MM-Explainability
3
+ added similarity_score
4
+ """
5
+
6
+ import hashlib
7
+ import os
8
+ import urllib.request
9
+ import warnings
10
+ from typing import Union, List
11
+ import re
12
+ import html
13
+
14
+ import torch
15
+ from PIL import Image
16
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
17
+ from tqdm import tqdm
18
+ import ftfy
19
+
20
+ from transformers import BatchFeature
21
+
22
+ from .model import build_model
23
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
24
+
25
+ __all__ = ["available_models", "load", "tokenize"]
26
+ _tokenizer = _Tokenizer()
27
+
28
+ _MODELS = {
29
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
30
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
31
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
33
+ }
34
+
35
+
36
+ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
37
+ os.makedirs(root, exist_ok=True)
38
+ filename = os.path.basename(url)
39
+
40
+ expected_sha256 = url.split("/")[-2]
41
+ download_target = os.path.join(root, filename)
42
+
43
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
44
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
45
+
46
+ if os.path.isfile(download_target):
47
+ if (
48
+ hashlib.sha256(open(download_target, "rb").read()).hexdigest()
49
+ == expected_sha256
50
+ ):
51
+ return download_target
52
+ else:
53
+ warnings.warn(
54
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
55
+ )
56
+
57
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
58
+ with tqdm(
59
+ total=int(source.info().get("Content-Length")),
60
+ ncols=80,
61
+ unit="iB",
62
+ unit_scale=True,
63
+ ) as loop:
64
+ while True:
65
+ buffer = source.read(8192)
66
+ if not buffer:
67
+ break
68
+
69
+ output.write(buffer)
70
+ loop.update(len(buffer))
71
+
72
+ if (
73
+ hashlib.sha256(open(download_target, "rb").read()).hexdigest()
74
+ != expected_sha256
75
+ ):
76
+ raise RuntimeError(
77
+ "Model has been downloaded but the SHA256 checksum does not match"
78
+ )
79
+
80
+ return download_target
81
+
82
+
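Review note on `_download`: the function above takes the expected digest from the second-to-last URL path component and hashes the file by reading it into memory in one call. A possible alternative, sketched here under the assumption that the same URL layout holds, is to hash in fixed-size chunks so a large checkpoint never has to sit fully in memory. The helper names `expected_sha256_from_url`, `sha256_of_file`, and `checksum_matches` are hypothetical, and the `example.invalid` URLs are placeholders rather than real model locations.

```python
import hashlib
import os
import tempfile


def expected_sha256_from_url(url: str) -> str:
    """Return the second-to-last path component, used as the digest in the URLs above."""
    return url.split("/")[-2]


def sha256_of_file(path: str, chunk_size: int = 8192) -> str:
    """Hash a file chunk by chunk and close the handle when done."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def checksum_matches(path: str, url: str) -> bool:
    """True when the file on disk has the digest embedded in the URL."""
    return sha256_of_file(path) == expected_sha256_from_url(url)


# Demonstration on a throwaway file; no network access is involved.
payload = b"not a real checkpoint"
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.pt")
    with open(demo_path, "wb") as handle:
        handle.write(payload)
    good_url = (
        "https://example.invalid/models/"
        + hashlib.sha256(payload).hexdigest()
        + "/demo.pt"
    )
    bad_url = "https://example.invalid/models/" + "0" * 64 + "/demo.pt"
    matches_good = checksum_matches(demo_path, good_url)
    matches_bad = checksum_matches(demo_path, bad_url)
```

Using a `with` block also avoids leaving the file handle open, which the bare `open(...).read()` calls in `_download` leave to the garbage collector.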
83
+ def _transform(n_px):
84
+ return Compose(
85
+ [
86
+ Resize(n_px, interpolation=Image.BICUBIC),
87
+ CenterCrop(n_px),
88
+ lambda image: image.convert("RGB"),
89
+ ToTensor(),
90
+ Normalize(
91
+ (0.48145466, 0.4578275, 0.40821073),
92
+ (0.26862954, 0.26130258, 0.27577711),
93
+ ),
94
+ ]
95
+ )
96
+
97
+
98
+ def available_models() -> List[str]:
99
+ """Returns the names of available CLIP models"""
100
+ return list(_MODELS.keys())
101
+
102
+
103
+ def load(
104
+ name: str,
105
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
106
+ jit=True,
107
+ ):
108
+ """Load a CLIP model
109
+
110
+ Parameters
111
+ ----------
112
+ name : str
113
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
114
+
115
+ device : Union[str, torch.device]
116
+ The device to put the loaded model
117
+
118
+ jit : bool
119
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
120
+
121
+ Returns
122
+ -------
123
+ model : torch.nn.Module
124
+ The CLIP model
125
+
126
+ preprocess : Callable[[PIL.Image], torch.Tensor]
127
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
128
+ """
129
+ if name in _MODELS:
130
+ model_path = _download(_MODELS[name])
131
+ elif os.path.isfile(name):
132
+ model_path = name
133
+ else:
134
+ raise RuntimeError(
135
+ f"Model {name} not found; available models = {available_models()}"
136
+ )
137
+
138
+ try:
139
+ # loading JIT archive
140
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
141
+ state_dict = None
142
+ except RuntimeError:
143
+ # loading saved state dict
144
+ if jit:
145
+ warnings.warn(
146
+ f"File {model_path} is not a JIT archive. Loading as a state dict instead"
147
+ )
148
+ jit = False
149
+ state_dict = torch.load(model_path, map_location="cpu")
150
+
151
+ if not jit:
152
+ model = build_model(state_dict or model.state_dict()).to(device)
153
+ if str(device) == "cpu":
154
+ model.float()
155
+ return model, _transform(model.visual.input_resolution)
156
+
157
+ # patch the device names
158
+ device_holder = torch.jit.trace(
159
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
160
+ )
161
+ device_node = [
162
+ n
163
+ for n in device_holder.graph.findAllNodes("prim::Constant")
164
+ if "Device" in repr(n)
165
+ ][-1]
166
+
167
+ def patch_device(module):
168
+ graphs = [module.graph] if hasattr(module, "graph") else []
169
+ if hasattr(module, "forward1"):
170
+ graphs.append(module.forward1.graph)
171
+
172
+ for graph in graphs:
173
+ for node in graph.findAllNodes("prim::Constant"):
174
+ if "value" in node.attributeNames() and str(node["value"]).startswith(
175
+ "cuda"
176
+ ):
177
+ node.copyAttributes(device_node)
178
+
179
+ model.apply(patch_device)
180
+ patch_device(model.encode_image)
181
+ patch_device(model.encode_text)
182
+
183
+ # patch dtype to float32 on CPU
184
+ if str(device) == "cpu":
185
+ float_holder = torch.jit.trace(
186
+ lambda: torch.ones([]).float(), example_inputs=[]
187
+ )
188
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
189
+ float_node = float_input.node()
190
+
191
+ def patch_float(module):
192
+ graphs = [module.graph] if hasattr(module, "graph") else []
193
+ if hasattr(module, "forward1"):
194
+ graphs.append(module.forward1.graph)
195
+
196
+ for graph in graphs:
197
+ for node in graph.findAllNodes("aten::to"):
198
+ inputs = list(node.inputs())
199
+ for i in [
200
+ 1,
201
+ 2,
202
+ ]: # dtype can be the second or third argument to aten::to()
203
+ if inputs[i].node()["value"] == 5:
204
+ inputs[i].node().copyAttributes(float_node)
205
+
206
+ model.apply(patch_float)
207
+ patch_float(model.encode_image)
208
+ patch_float(model.encode_text)
209
+
210
+ model.float()
211
+
212
+ return model, _transform(model.input_resolution.item())
213
+
214
+
215
+ def tokenize(
216
+ texts: Union[str, List[str]], context_length: int = 77
217
+ ) -> torch.LongTensor:
218
+ """
219
+ Returns the tokenized representation of given input string(s)
220
+
221
+ Parameters
222
+ ----------
223
+ texts : Union[str, List[str]]
224
+ An input string or a list of input strings to tokenize
225
+
226
+ context_length : int
227
+ The context length to use; all CLIP models use 77 as the context length
228
+
229
+ Returns
230
+ -------
231
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
232
+ """
233
+ if isinstance(texts, str):
234
+ texts = [texts]
235
+
236
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
237
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
238
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
239
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
240
+
241
+ for i, tokens in enumerate(all_tokens):
242
+ if len(tokens) > context_length:
243
+ raise RuntimeError(
244
+ f"Input {texts[i]} is too long for context length {context_length}"
245
+ )
246
+ result[i, : len(tokens)] = torch.tensor(tokens)
247
+
248
+ return result
249
+
250
+
251
+ def basic_clean(text):
252
+ text = ftfy.fix_text(text)
253
+ text = html.unescape(html.unescape(text))
254
+ return text.strip()
255
+
256
+
257
+ def whitespace_clean(text):
258
+ text = re.sub(r"\s+", " ", text)
259
+ text = text.strip()
260
+ return text
261
+
262
+
263
+ def tokenize_ja(
264
+ tokenizer,
265
+ texts: Union[str, List[str]],
266
+ max_seq_len: int = 77,
267
+ ):
268
+ """
269
+ This function is adapted from the original CLIP code:
270
+ https://github.com/openai/CLIP/blob/main/clip/clip.py#L195
271
+ """
272
+ if isinstance(texts, str):
273
+ texts = [texts]
274
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
275
+
276
+ inputs = tokenizer(
277
+ texts,
278
+ max_length=max_seq_len - 1,
279
+ padding="max_length",
280
+ truncation=True,
281
+ add_special_tokens=False,
282
+ )
283
+ # prepend the BOS token to every sequence
284
+ input_ids = [[tokenizer.bos_token_id] + ids for ids in inputs["input_ids"]]
285
+ attention_mask = [[1] + am for am in inputs["attention_mask"]]
286
+ position_ids = [list(range(0, len(input_ids[0])))] * len(texts)
287
+
288
+ return BatchFeature(
289
+ {
290
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
291
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
292
+ "position_ids": torch.tensor(position_ids, dtype=torch.long),
293
+ }
294
+ )
295
+
296
+
297
+ def similarity_score(clip_model, image, target_features):
298
+ image_features = clip_model.encode_image(image)
299
+
300
+ image_features_norm = image_features.norm(dim=-1, keepdim=True)
301
+ image_features_new = image_features / image_features_norm
302
+ target_features_norm = target_features.norm(dim=-1, keepdim=True)
303
+ target_features_new = target_features / target_features_norm
304
+
305
+ return image_features_new[0].dot(target_features_new[0]) * 100
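
The `similarity_score` helper above L2-normalizes both feature vectors and returns their dot product multiplied by 100, which is a cosine similarity on a scaled range (it can be negative). As a rough, dependency-free sketch of the same arithmetic, the example below uses plain Python lists in place of tensors. The function name `cosine_score`, the `scale` parameter, and the sample vectors are illustrative assumptions and are not part of this commit.

```python
import math


def cosine_score(a, b, scale=100.0):
    """Cosine similarity of two equal-length vectors, multiplied by `scale`."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # The tensor version would divide by zero here; fail loudly instead.
        raise ValueError("cosine similarity is undefined for a zero vector")
    dot = sum(x * y for x, y in zip(a, b))
    return scale * dot / (norm_a * norm_b)


# Vectors pointing the same way score at the top of the range,
# orthogonal vectors score zero.
same_direction = cosine_score([3.0, 4.0], [6.0, 8.0])
orthogonal = cosine_score([1.0, 0.0], [0.0, 1.0])
```

Unlike the helper in the diff, this sketch rejects zero vectors explicitly; whether that guard is wanted in the real code is a design choice left open here.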
CLIP_Explainability/image_utils.py ADDED
@@ -0,0 +1,22 @@
+ import numpy as np
+ import cv2
+
+ def show_cam_on_image(img, mask, neg_saliency=False):
+
+     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+
+     heatmap = np.float32(heatmap) / 255
+     cam = heatmap + np.float32(img)
+     cam = cam / np.max(cam)
+     return cam
+
+ def show_overlapped_cam(img, neg_mask, pos_mask):
+     neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), cv2.COLORMAP_RAINBOW)
+     pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), cv2.COLORMAP_JET)
+     neg_heatmap = np.float32(neg_heatmap) / 255
+     pos_heatmap = np.float32(pos_heatmap) / 255
+     # try different options: sum, average, ...
+     heatmap = neg_heatmap + pos_heatmap
+     cam = heatmap + np.float32(img)
+     cam = cam / np.max(cam)
+     return cam
CLIP_Explainability/model.py ADDED
@@ -0,0 +1,446 @@
1
+ """
2
+ taken from https://github.com/hila-chefer/Transformer-MM-Explainability
3
+ """
4
+
5
+ from collections import OrderedDict
6
+ from typing import Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from .auxilary import *
13
+
14
+ class Bottleneck(nn.Module):
15
+ expansion = 4
16
+
17
+ def __init__(self, inplanes, planes, stride=1):
18
+ super().__init__()
19
+
20
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
21
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
22
+ self.bn1 = nn.BatchNorm2d(planes)
23
+
24
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
25
+ self.bn2 = nn.BatchNorm2d(planes)
26
+
27
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
28
+
29
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
30
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
31
+
32
+ self.relu = nn.ReLU(inplace=True)
33
+ self.downsample = None
34
+ self.stride = stride
35
+
36
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
37
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
38
+ self.downsample = nn.Sequential(OrderedDict([
39
+ ("-1", nn.AvgPool2d(stride)),
40
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
41
+ ("1", nn.BatchNorm2d(planes * self.expansion))
42
+ ]))
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ identity = x
46
+
47
+ out = self.relu(self.bn1(self.conv1(x)))
48
+ out = self.relu(self.bn2(self.conv2(out)))
49
+ out = self.avgpool(out)
50
+ out = self.bn3(self.conv3(out))
51
+
52
+ if self.downsample is not None:
53
+ identity = self.downsample(x)
54
+
55
+ out += identity
56
+ out = self.relu(out)
57
+ return out
58
+
59
+
60
+ class AttentionPool2d(nn.Module):
61
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
62
+ super().__init__()
63
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
64
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
66
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
67
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
68
+ self.num_heads = num_heads
69
+
70
+ def forward(self, x):
71
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
72
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
73
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
74
+ x, _ = multi_head_attention_forward(
75
+ query=x, key=x, value=x,
76
+ embed_dim_to_check=x.shape[-1],
77
+ num_heads=self.num_heads,
78
+ q_proj_weight=self.q_proj.weight,
79
+ k_proj_weight=self.k_proj.weight,
80
+ v_proj_weight=self.v_proj.weight,
81
+ in_proj_weight=None,
82
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
83
+ bias_k=None,
84
+ bias_v=None,
85
+ add_zero_attn=False,
86
+ dropout_p=0,
87
+ out_proj_weight=self.c_proj.weight,
88
+ out_proj_bias=self.c_proj.bias,
89
+ use_separate_proj_weight=True,
90
+ training=self.training,
91
+ need_weights=False
92
+ )
93
+
94
+ return x[0]
95
+
96
+
97
+ class ModifiedResNet(nn.Module):
98
+ """
99
+ A ResNet class that is similar to torchvision's but contains the following changes:
100
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
101
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
102
+ - The final pooling layer is a QKV attention instead of an average pool
103
+ """
104
+
105
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
106
+ super().__init__()
107
+ self.output_dim = output_dim
108
+ self.input_resolution = input_resolution
109
+
110
+ # the 3-layer stem
111
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
112
+ self.bn1 = nn.BatchNorm2d(width // 2)
113
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
114
+ self.bn2 = nn.BatchNorm2d(width // 2)
115
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116
+ self.bn3 = nn.BatchNorm2d(width)
117
+ self.avgpool = nn.AvgPool2d(2)
118
+ self.relu = nn.ReLU(inplace=True)
119
+
120
+ # residual layers
121
+ self._inplanes = width # this is a *mutable* variable used during construction
122
+ self.layer1 = self._make_layer(width, layers[0])
123
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126
+
127
+ embed_dim = width * 32 # the ResNet feature dimension
128
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
129
+
130
+ def _make_layer(self, planes, blocks, stride=1):
131
+ layers = [Bottleneck(self._inplanes, planes, stride)]
132
+
133
+ self._inplanes = planes * Bottleneck.expansion
134
+ for _ in range(1, blocks):
135
+ layers.append(Bottleneck(self._inplanes, planes))
136
+
137
+ return nn.Sequential(*layers)
138
+
139
+ def forward(self, x):
140
+ def stem(x):
141
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
142
+ x = self.relu(bn(conv(x)))
143
+ x = self.avgpool(x)
144
+ return x
145
+
146
+ x = x.type(self.conv1.weight.dtype)
147
+ x = stem(x)
148
+ x = self.layer1(x)
149
+ x = self.layer2(x)
150
+ x = self.layer3(x)
151
+ x = self.layer4(x)
152
+ x = self.attnpool(x)
153
+
154
+ return x
155
+
156
+
157
+ class LayerNorm(nn.LayerNorm):
158
+ """Subclass torch's LayerNorm to handle fp16."""
159
+
160
+ def forward(self, x: torch.Tensor):
161
+ orig_type = x.dtype
162
+ ret = super().forward(x.type(torch.float32))
163
+ return ret.type(orig_type)
164
+
165
+
166
+ class QuickGELU(nn.Module):
167
+ def forward(self, x: torch.Tensor):
168
+ return x * torch.sigmoid(1.702 * x)
169
+
170
+
171
+ class ResidualAttentionBlock(nn.Module):
172
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173
+ super().__init__()
174
+
175
+ self.attn = MultiheadAttention(d_model, n_head)
176
+ self.ln_1 = LayerNorm(d_model)
177
+ self.mlp = nn.Sequential(OrderedDict([
178
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
179
+ ("gelu", QuickGELU()),
180
+ ("c_proj", nn.Linear(d_model * 4, d_model))
181
+ ]))
182
+ self.ln_2 = LayerNorm(d_model)
183
+ self.attn_mask = attn_mask
184
+
185
+ self.attn_probs = None
186
+ self.attn_grad = None
187
+
188
+ def set_attn_probs(self, attn_probs):
189
+ self.attn_probs = attn_probs
190
+
191
+ def set_attn_grad(self, attn_grad):
192
+ self.attn_grad = attn_grad
193
+
194
+ def attention(self, x: torch.Tensor):
195
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
196
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, attention_probs_forward_hook=self.set_attn_probs,
197
+ attention_probs_backwards_hook=self.set_attn_grad)[0]
198
+
199
+ def forward(self, x: torch.Tensor):
200
+ x = x + self.attention(self.ln_1(x))
201
+ x = x + self.mlp(self.ln_2(x))
202
+ return x
203
+
204
+
205
+ class Transformer(nn.Module):
206
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
207
+ super().__init__()
208
+ self.width = width
209
+ self.layers = layers
210
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
211
+
212
+ def forward(self, x: torch.Tensor):
213
+ return self.resblocks(x)
214
+
215
+
216
+ class VisualTransformer(nn.Module):
217
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
218
+ super().__init__()
219
+ self.input_resolution = input_resolution
220
+ self.output_dim = output_dim
221
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
222
+
223
+ scale = width ** -0.5
224
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
225
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
226
+ self.ln_pre = LayerNorm(width)
227
+
228
+ self.transformer = Transformer(width, layers, heads)
229
+
230
+ self.ln_post = LayerNorm(width)
231
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
232
+
233
+ def forward(self, x: torch.Tensor):
234
+ x = self.conv1(x) # shape = [*, width, grid, grid]
235
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
236
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
237
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
238
+ x = x + self.positional_embedding.to(x.dtype)
239
+ x = self.ln_pre(x)
240
+
241
+ x = x.permute(1, 0, 2) # NLD -> LND
242
+ x = self.transformer(x)
243
+ x = x.permute(1, 0, 2) # LND -> NLD
244
+
245
+ x = self.ln_post(x[:, 0, :])
246
+
247
+ if self.proj is not None:
248
+ x = x @ self.proj
249
+
250
+ return x
251
+
252
+
253
+ class CLIP(nn.Module):
254
+ def __init__(self,
255
+ embed_dim: int,
256
+ # vision
257
+ image_resolution: int,
258
+ vision_layers: Union[Tuple[int, int, int, int], int],
259
+ vision_width: int,
260
+ vision_patch_size: int,
261
+ # text
262
+ context_length: int,
263
+ vocab_size: int,
264
+ transformer_width: int,
265
+ transformer_heads: int,
266
+ transformer_layers: int
267
+ ):
268
+ super().__init__()
269
+
270
+ self.context_length = context_length
271
+
272
+ if isinstance(vision_layers, (tuple, list)):
273
+ vision_heads = vision_width * 32 // 64
274
+ self.visual = ModifiedResNet(
275
+ layers=vision_layers,
276
+ output_dim=embed_dim,
277
+ heads=vision_heads,
278
+ input_resolution=image_resolution,
279
+ width=vision_width
280
+ )
281
+ else:
282
+ vision_heads = vision_width // 64
283
+ self.visual = VisualTransformer(
284
+ input_resolution=image_resolution,
285
+ patch_size=vision_patch_size,
286
+ width=vision_width,
287
+ layers=vision_layers,
288
+ heads=vision_heads,
289
+ output_dim=embed_dim
290
+ )
291
+
292
+ self.transformer = Transformer(
293
+ width=transformer_width,
294
+ layers=transformer_layers,
295
+ heads=transformer_heads,
296
+ attn_mask=self.build_attention_mask()
297
+ )
298
+
299
+ self.vocab_size = vocab_size
300
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
301
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
302
+ self.ln_final = LayerNorm(transformer_width)
303
+
304
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
305
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
306
+
307
+ self.initialize_parameters()
308
+
309
+ def initialize_parameters(self):
310
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
311
+ nn.init.normal_(self.positional_embedding, std=0.01)
312
+
313
+ if isinstance(self.visual, ModifiedResNet):
314
+ if self.visual.attnpool is not None:
315
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
316
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
317
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
318
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
319
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
320
+
321
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
322
+ for name, param in resnet_block.named_parameters():
323
+ if name.endswith("bn3.weight"):
324
+ nn.init.zeros_(param)
325
+
326
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
327
+ attn_std = self.transformer.width ** -0.5
328
+ fc_std = (2 * self.transformer.width) ** -0.5
329
+ for block in self.transformer.resblocks:
330
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
331
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
332
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
333
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
334
+
335
+ if self.text_projection is not None:
336
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
337
+
338
+ def build_attention_mask(self):
339
+ # lazily create causal attention mask, with full attention between the vision tokens
340
+ # pytorch uses additive attention mask; fill with -inf
341
+ mask = torch.empty(self.context_length, self.context_length)
342
+ mask.fill_(float("-inf"))
343
+ mask.triu_(1) # keep -inf strictly above the diagonal; zero the diagonal and everything below it
344
+ return mask
345
+
346
+ @property
347
+ def dtype(self):
348
+ return self.visual.conv1.weight.dtype
349
+
350
+ def encode_image(self, image):
351
+ return self.visual(image.type(self.dtype))
352
+
353
+ def encode_text(self, text):
354
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
355
+
356
+ x = x + self.positional_embedding.type(self.dtype)
357
+ x = x.permute(1, 0, 2) # NLD -> LND
358
+ x = self.transformer(x)
359
+ x = x.permute(1, 0, 2) # LND -> NLD
360
+ x = self.ln_final(x).type(self.dtype)
361
+
362
+ # x.shape = [batch_size, n_ctx, transformer.width]
363
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
364
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
365
+
366
+ return x
367
+
368
+ def forward(self, image, text):
369
+ image_features = self.encode_image(image)
370
+ text_features = self.encode_text(text)
371
+
372
+ # normalized features
373
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
374
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
375
+
376
+ # cosine similarity as logits
377
+ logit_scale = self.logit_scale.exp()
378
+ logits_per_image = logit_scale * image_features @ text_features.t()
379
+ logits_per_text = logit_scale * text_features @ image_features.t()
380
+
381
+ # shape = [global_batch_size, global_batch_size]
382
+ return logits_per_image, logits_per_text
383
+
384
+
385
+ def convert_weights(model: nn.Module):
386
+ """Convert applicable model parameters to fp16"""
387
+
388
+ def _convert_weights_to_fp16(l):
389
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
390
+ l.weight.data = l.weight.data.half()
391
+ if l.bias is not None:
392
+ l.bias.data = l.bias.data.half()
393
+
394
+ if isinstance(l, MultiheadAttention):
395
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
396
+ tensor = getattr(l, attr)
397
+ if tensor is not None:
398
+ tensor.data = tensor.data.half()
399
+
400
+ for name in ["text_projection", "proj"]:
401
+ if hasattr(l, name):
402
+ attr = getattr(l, name)
403
+ if attr is not None:
404
+ attr.data = attr.data.half()
405
+
406
+ model.apply(_convert_weights_to_fp16)
407
+
408
+
409
+ def build_model(state_dict: dict):
410
+ vit = "visual.proj" in state_dict
411
+
412
+ if vit:
413
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
414
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
415
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
416
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
417
+ image_resolution = vision_patch_size * grid_size
418
+ else:
419
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
420
+ vision_layers = tuple(counts)
421
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
422
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
423
+ vision_patch_size = None
424
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
425
+ image_resolution = output_width * 32
426
+
427
+ embed_dim = state_dict["text_projection"].shape[1]
428
+ context_length = state_dict["positional_embedding"].shape[0]
429
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
430
+ transformer_width = state_dict["ln_final.weight"].shape[0]
431
+ transformer_heads = transformer_width // 64
432
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
433
+
434
+ model = CLIP(
435
+ embed_dim,
436
+ image_resolution, vision_layers, vision_width, vision_patch_size,
437
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
438
+ )
439
+
440
+ for key in ["input_resolution", "context_length", "vocab_size"]:
441
+ if key in state_dict:
442
+ del state_dict[key]
443
+
444
+ convert_weights(model)
445
+ model.load_state_dict(state_dict)
446
+ return model.eval()
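
`build_model` above recovers the vision geometry from tensor shapes in the checkpoint. For the ViT branch, the positional embedding has one row per image patch plus one row for the class token, so the grid side is the square root of (rows - 1) and the input resolution is the grid side times the patch size. The sketch below replays that arithmetic on plain integers. The helper name `vit_geometry` is hypothetical, and the consistency check is an addition that mirrors the `assert` the ResNet branch performs; the ViT branch in the diff does not check this.

```python
from typing import Tuple


def vit_geometry(num_positions: int, patch_size: int) -> Tuple[int, int]:
    """Return (grid_size, image_resolution) from a positional-embedding length."""
    # One position per patch plus one for the class token.
    grid_size = round((num_positions - 1) ** 0.5)
    if grid_size ** 2 + 1 != num_positions:
        raise ValueError("expected grid_size ** 2 + 1 positional embeddings")
    return grid_size, grid_size * patch_size


# 50 positions with 32-pixel patches: a 7 x 7 grid on a 224-pixel input.
grid, resolution = vit_geometry(num_positions=50, patch_size=32)
```

The same helper applied to 197 positions and 16-pixel patches gives a 14 x 14 grid, again on a 224-pixel input.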
CLIP_Explainability/simple_tokenizer.py ADDED
@@ -0,0 +1,136 @@
1
+ """
2
+ taken from https://github.com/hila-chefer/Transformer-MM-Explainability
3
+ """
4
+
5
+ import gzip
6
+ import html
7
+ import os
8
+ from functools import lru_cache
9
+
10
+ import ftfy
11
+ import regex as re
12
+
13
+
14
+ @lru_cache()
15
+ def default_bpe():
16
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
17
+
18
+
19
+ @lru_cache()
20
+ def bytes_to_unicode():
21
+ """
22
+ Returns a dict mapping each utf-8 byte value to a corresponding unicode string.
23
+ The reversible bpe codes work on unicode strings.
24
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
25
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
26
+ This is a significant percentage of your normal, say, 32K bpe vocab.
27
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
28
+ The mapping also avoids whitespace/control characters that the bpe code barfs on.
29
+ """
30
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
31
+ cs = bs[:]
32
+ n = 0
33
+ for b in range(2**8):
34
+ if b not in bs:
35
+ bs.append(b)
36
+ cs.append(2**8+n)
37
+ n += 1
38
+ cs = [chr(n) for n in cs]
39
+ return dict(zip(bs, cs))
40
+
41
+
42
+ def get_pairs(word):
43
+ """Return set of symbol pairs in a word.
44
+ Word is represented as tuple of symbols (symbols being variable-length strings).
45
+ """
46
+ pairs = set()
47
+ prev_char = word[0]
48
+ for char in word[1:]:
49
+ pairs.add((prev_char, char))
50
+ prev_char = char
51
+ return pairs
52
+
53
+
54
+ def basic_clean(text):
55
+ text = ftfy.fix_text(text)
56
+ text = html.unescape(html.unescape(text))
57
+ return text.strip()
58
+
59
+
60
+ def whitespace_clean(text):
61
+ text = re.sub(r'\s+', ' ', text)
62
+ text = text.strip()
63
+ return text
64
+
65
+
66
+ class SimpleTokenizer(object):
67
+ def __init__(self, bpe_path: str = default_bpe()):
68
+ self.byte_encoder = bytes_to_unicode()
69
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
70
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
71
+ merges = merges[1:49152-256-2+1]
72
+ merges = [tuple(merge.split()) for merge in merges]
73
+ vocab = list(bytes_to_unicode().values())
74
+ vocab = vocab + [v+'</w>' for v in vocab]
75
+ for merge in merges:
76
+ vocab.append(''.join(merge))
77
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
78
+ self.encoder = dict(zip(vocab, range(len(vocab))))
79
+ self.decoder = {v: k for k, v in self.encoder.items()}
80
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
81
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
82
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
83
+
84
+ def bpe(self, token):
85
+ if token in self.cache:
86
+ return self.cache[token]
87
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
88
+ pairs = get_pairs(word)
89
+
90
+ if not pairs:
91
+ return token+'</w>'
92
+
93
+ while True:
94
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
95
+ if bigram not in self.bpe_ranks:
96
+ break
97
+ first, second = bigram
98
+ new_word = []
99
+ i = 0
100
+ while i < len(word):
101
+ try:
102
+ j = word.index(first, i)
103
+ new_word.extend(word[i:j])
104
+ i = j
105
+ except ValueError:
106
+ new_word.extend(word[i:])
107
+ break
108
+
109
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
110
+ new_word.append(first+second)
111
+ i += 2
112
+ else:
113
+ new_word.append(word[i])
114
+ i += 1
115
+ new_word = tuple(new_word)
116
+ word = new_word
117
+ if len(word) == 1:
118
+ break
119
+ else:
120
+ pairs = get_pairs(word)
121
+ word = ' '.join(word)
122
+ self.cache[token] = word
123
+ return word
124
+
125
+ def encode(self, text):
126
+ bpe_tokens = []
127
+ text = whitespace_clean(basic_clean(text)).lower()
128
+ for token in re.findall(self.pat, text):
129
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
130
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
131
+ return bpe_tokens
132
+
133
+ def decode(self, tokens):
134
+ text = ''.join([self.decoder[token] for token in tokens])
135
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
136
+ return text
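
The `bpe` method above repeatedly picks the adjacent symbol pair with the lowest merge rank and fuses every occurrence of it, stopping when no remaining pair has a rank. A self-contained sketch of that loop on a toy merge table may make the control flow easier to follow. `toy_bpe` and `toy_ranks` are made-up names for illustration, and the ranks below are invented, not taken from the real vocabulary file.

```python
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a tuple of symbols."""
    return set(zip(word, word[1:]))


def toy_bpe(token, ranks):
    """Greedy BPE: merge the lowest-ranked adjacent pair until none is ranked."""
    # The last character carries the end-of-word marker, as in SimpleTokenizer.
    word = tuple(token[:-1]) + (token[-1] + "</w>",)
    while len(word) > 1:
        pairs = get_pairs(word)
        bigram = min(pairs, key=lambda pair: ranks.get(pair, float("inf")))
        if bigram not in ranks:
            break  # no known merge applies any more
        first, second = bigram
        merged, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                merged.append(first + second)
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = tuple(merged)
    return " ".join(word)


toy_ranks = {("l", "o"): 0, ("lo", "w</w>"): 1}
result = toy_bpe("low", toy_ranks)    # both merges apply
unmerged = toy_bpe("hi", toy_ranks)   # no merge applies
```

The sketch scans the word position by position rather than using `word.index` with exception handling as the diff does; both approaches merge the same pairs, and the scan simply avoids using an exception for control flow.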
CLIP_Explainability/vit_cam.py ADDED
@@ -0,0 +1,325 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ import cv2
6
+ import regex as re
7
+
8
+ from .image_utils import show_cam_on_image, show_overlapped_cam
9
+
10
+
11
+ def vit_block_vis(
12
+ image,
13
+ target_features,
14
+ img_encoder,
15
+ block,
16
+ device,
17
+ grad=False,
18
+ neg_saliency=False,
19
+ img_dim=224,
20
+ ):
21
+ img_encoder.eval()
22
+ image_features = img_encoder(image)
23
+
24
+ image_features_norm = image_features.norm(dim=-1, keepdim=True)
25
+ image_features_new = image_features / image_features_norm
26
+ target_features_norm = target_features.norm(dim=-1, keepdim=True)
27
+ target_features_new = target_features / target_features_norm
28
+
29
+ similarity = image_features_new[0].dot(target_features_new[0])
30
+ image = (image - image.min()) / (image.max() - image.min())
31
+
32
+ img_encoder.zero_grad()
33
+ similarity.backward(retain_graph=True)
34
+
35
+ image_attn_blocks = list(
36
+ dict(img_encoder.transformer.resblocks.named_children()).values()
37
+ )
38
+
39
+ if grad:
40
+ cam = image_attn_blocks[block].attn_grad.detach()
41
+ else:
42
+ cam = image_attn_blocks[block].attn_probs.detach()
43
+
44
+ cam = cam.mean(dim=0)
45
+ image_relevance = cam[0, 1:]
46
+
47
+ resize_dim = int(np.sqrt(list(image_relevance.shape)[0]))
48
+
49
+ # image_relevance = image_relevance.reshape(1, 1, 7, 7)
50
+ image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
51
+
52
+ image_relevance = torch.nn.functional.interpolate(
53
+ image_relevance, size=img_dim, mode="bilinear"
54
+ )
55
+ image_relevance = image_relevance.reshape(img_dim, img_dim)
56
+ image_relevance = (image_relevance - image_relevance.min()) / (
57
+ image_relevance.max() - image_relevance.min()
58
+ )
59
+
60
+ cam = image_relevance * image
61
+ cam = cam / torch.max(cam)
62
+
63
+ # TODO: maybe we can ignore this...
64
+ ####
65
+ masked_image_features = img_encoder(cam)
66
+ masked_image_features_norm = masked_image_features.norm(dim=-1, keepdim=True)
67
+ masked_image_features_new = masked_image_features / masked_image_features_norm
68
+ new_score = masked_image_features_new[0].dot(target_features_new[0])
69
+ ####
70
+
71
+ cam = cam[0].permute(1, 2, 0).data.cpu().numpy()
72
+ cam = np.float32(cam)
73
+
74
+ plt.imshow(cam)
75
+
76
+ return new_score
77
+
78
+
79
+ def vit_relevance(
80
+ image,
81
+ target_features,
82
+ img_encoder,
83
+ device,
84
+ method="last grad",
85
+ neg_saliency=False,
86
+ img_dim=224,
87
+ ):
88
+ img_encoder.eval()
89
+ image_features = img_encoder(image)
90
+
91
+ image_features_norm = image_features.norm(dim=-1, keepdim=True)
92
+ image_features_new = image_features / image_features_norm
93
+ target_features_norm = target_features.norm(dim=-1, keepdim=True)
94
+ target_features_new = target_features / target_features_norm
95
+ similarity = image_features_new[0].dot(target_features_new[0])
96
+ if neg_saliency:
97
+ objective = 1 - similarity
98
+ else:
99
+ objective = similarity
100
+ img_encoder.zero_grad()
101
+ objective.backward(retain_graph=True)
102
+ image_attn_blocks = list(
103
+ dict(img_encoder.transformer.resblocks.named_children()).values()
104
+ )
105
+ num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
106
+
107
+ last_attn = image_attn_blocks[-1].attn_probs.detach()
108
+ last_attn = last_attn.reshape(-1, last_attn.shape[-1], last_attn.shape[-1])
109
+
110
+ last_grad = image_attn_blocks[-1].attn_grad.detach()
111
+ last_grad = last_grad.reshape(-1, last_grad.shape[-1], last_grad.shape[-1])
112
+
113
+ if method == "gradcam":
114
+ cam = last_grad * last_attn
115
+ cam = cam.clamp(min=0).mean(dim=0)
116
+ image_relevance = cam[0, 1:]
117
+
118
+ else:
119
+ R = torch.eye(
120
+ num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype
121
+ ).to(device)
122
+ for blk in image_attn_blocks:
123
+ cam = blk.attn_probs.detach()
124
+ cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
125
+
126
+ if method == "last grad":
127
+ grad = last_grad
128
+ elif method == "all grads":
129
+ grad = blk.attn_grad.detach()
130
+ else:
131
+ raise ValueError(
131
+ "Unknown method; the available visualization methods are: "
132
+ "'gradcam', 'last grad', 'all grads'."
133
+ )
135
+
136
+ cam = grad * cam
137
+ cam = cam.clamp(min=0).mean(dim=0)
138
+ R += torch.matmul(cam, R)
139
+
140
+ image_relevance = R[0, 1:]
141
+
142
+ resize_dim = int(np.sqrt(list(image_relevance.shape)[0]))
143
+
144
+ # image_relevance = image_relevance.reshape(1, 1, 7, 7)
145
+ image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
146
+
147
+ image_relevance = torch.nn.functional.interpolate(
148
+ image_relevance, size=img_dim, mode="bilinear"
149
+ )
150
+ image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
151
+ image_relevance = (image_relevance - image_relevance.min()) / (
152
+ image_relevance.max() - image_relevance.min()
153
+ )
154
+ image = image[0].permute(1, 2, 0).data.cpu().numpy()
155
+ image = (image - image.min()) / (image.max() - image.min())
156
+
157
+ return image_relevance, image
158
+
159
+
160
+ def interpret_vit(
161
+ image,
162
+ target_features,
163
+ img_encoder,
164
+ device,
165
+ method="last grad",
166
+ neg_saliency=False,
167
+ img_dim=224,
168
+ ):
169
+ image_relevance, image = vit_relevance(
170
+ image,
171
+ target_features,
172
+ img_encoder,
173
+ device,
174
+ method=method,
175
+ neg_saliency=neg_saliency,
176
+ img_dim=img_dim,
177
+ )
178
+
179
+ vis = show_cam_on_image(image, image_relevance, neg_saliency=neg_saliency)
180
+ vis = np.uint8(255 * vis)
181
+ vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
182
+
183
+ return vis
184
+ # plt.imshow(vis)
185
+
186
+
187
+ def interpret_vit_overlapped(
188
+ image, target_features, img_encoder, device, method="last grad", img_dim=224
189
+ ):
190
+ pos_image_relevance, _ = vit_relevance(
191
+ image,
192
+ target_features,
193
+ img_encoder,
194
+ device,
195
+ method=method,
196
+ neg_saliency=False,
197
+ img_dim=img_dim,
198
+ )
199
+ neg_image_relevance, image = vit_relevance(
200
+ image,
201
+ target_features,
202
+ img_encoder,
203
+ device,
204
+ method=method,
205
+ neg_saliency=True,
206
+ img_dim=img_dim,
207
+ )
208
+
209
+ vis = show_overlapped_cam(image, neg_image_relevance, pos_image_relevance)
210
+ vis = np.uint8(255 * vis)
211
+ vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
212
+
213
+ plt.imshow(vis)
214
+
215
+
216
+ def vit_perword_relevance(
217
+ image,
218
+ text,
219
+ clip_model,
220
+ clip_tokenizer,
221
+ device,
222
+ masked_word="",
223
+ use_last_grad=True,
224
+ data_only=False,
225
+ img_dim=224,
226
+ ):
227
+ clip_model.eval()
228
+
229
+ main_text = clip_tokenizer(text).to(device)
230
+ # remove the word for which you want to visualize the saliency
231
+ masked_text = re.sub(re.escape(masked_word), "", text)
232
+ masked_text = clip_tokenizer(masked_text).to(device)
233
+
234
+ image_features = clip_model.encode_image(image)
235
+ main_text_features = clip_model.encode_text(main_text)
236
+ masked_text_features = clip_model.encode_text(masked_text)
237
+
238
+ image_features_norm = image_features.norm(dim=-1, keepdim=True)
239
+ image_features_new = image_features / image_features_norm
240
+ main_text_features_norm = main_text_features.norm(dim=-1, keepdim=True)
241
+ main_text_features_new = main_text_features / main_text_features_norm
242
+
243
+ masked_text_features_norm = masked_text_features.norm(dim=-1, keepdim=True)
244
+ masked_text_features_new = masked_text_features / masked_text_features_norm
245
+
246
+ objective = image_features_new[0].dot(
247
+ main_text_features_new[0] - masked_text_features_new[0]
248
+ )
249
+
250
+ clip_model.visual.zero_grad()
251
+ objective.backward(retain_graph=True)
252
+
253
+ image_attn_blocks = list(
254
+ dict(clip_model.visual.transformer.resblocks.named_children()).values()
255
+ )
256
+ num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
257
+
258
+ R = torch.eye(
259
+ num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype
260
+ ).to(device)
261
+
262
+ last_grad = image_attn_blocks[-1].attn_grad.detach()
263
+ last_grad = last_grad.reshape(-1, last_grad.shape[-1], last_grad.shape[-1])
264
+
265
+ for blk in image_attn_blocks:
266
+ cam = blk.attn_probs.detach()
267
+ cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
268
+
269
+ if use_last_grad:
270
+ grad = last_grad
271
+ else:
272
+ grad = blk.attn_grad.detach()
273
+
274
+ cam = grad * cam
275
+ cam = cam.clamp(min=0).mean(dim=0)
276
+ R += torch.matmul(cam, R)
277
+
278
+ image_relevance = R[0, 1:]
279
+
280
+ resize_dim = int(np.sqrt(list(image_relevance.shape)[0]))
281
+
282
+ image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
283
+
284
+ image_relevance = torch.nn.functional.interpolate(
285
+ image_relevance, size=img_dim, mode="bilinear"
286
+ )
287
+ image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
288
+ image_relevance = (image_relevance - image_relevance.min()) / (
289
+ image_relevance.max() - image_relevance.min()
290
+ )
291
+
292
+ if data_only:
293
+ return image_relevance
294
+
295
+ image = image[0].permute(1, 2, 0).data.cpu().numpy()
296
+ image = (image - image.min()) / (image.max() - image.min())
297
+
298
+ return image_relevance, image
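The objective in `vit_perword_relevance` is the difference between the image's cosine similarity to the full query and to the query with one word removed. A minimal sketch of that computation, assuming made-up 2-D embeddings and plain lists in place of CLIP features; the function names are hypothetical. `mask_word` escapes the word before substitution so that regex metacharacters in a query are treated literally.

```python
import math
import re


def normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def mask_word(text, word):
    """Remove every literal occurrence of `word` from `text`."""
    return re.sub(re.escape(word), "", text)


def perword_objective(image_vec, text_vec, masked_text_vec):
    """Dot product of the unit image vector with (full text - masked text)."""
    img = normalize(image_vec)
    full = normalize(text_vec)
    masked = normalize(masked_text_vec)
    return sum(i * (f - m) for i, f, m in zip(img, full, masked))


# Hypothetical embeddings: the full caption aligns with the image,
# while the masked caption is orthogonal to it.
score = perword_objective([3.0, 0.0], [2.0, 0.0], [0.0, 5.0])
masked = mask_word("a red (old) temple", "(old)")
```

A large positive score means removing the word moved the text embedding away from the image, so the backward pass highlights the image regions that the word accounts for.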
299
+
300
+
301
+ def interpret_perword_vit(
302
+ image,
303
+ text,
304
+ clip_model,
305
+ clip_tokenizer,
306
+ device,
307
+ masked_word="",
308
+ use_last_grad=True,
309
+ img_dim=224,
310
+ ):
311
+ image_relevance, image = vit_perword_relevance(
312
+ image,
313
+ text,
314
+ clip_model,
315
+ clip_tokenizer,
316
+ device,
317
+ masked_word,
318
+ use_last_grad,
319
+ img_dim=img_dim,
320
+ )
321
+ vis = show_cam_on_image(image, image_relevance)
322
+ vis = np.uint8(255 * vis)
323
+ vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
324
+
325
+ plt.imshow(vis)
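Both relevance functions in this file rescale the map with `(x - min) / (max - min)`. When the map is constant the denominator is zero, which produces NaN values with NumPy arrays. A minimal sketch of a guarded variant, using plain lists rather than arrays; the function name and the `eps` threshold are assumptions for illustration only.

```python
def min_max_normalize(values, eps=1e-8):
    """Rescale values to [0, 1]; return all zeros when the range is (near) zero."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span < eps:
        # A constant map carries no spatial information, so report zeros.
        return [0.0 for _ in values]
    return [(v - lo) / span for v in values]


normalized = min_max_normalize([2.0, 4.0, 6.0])  # -> [0.0, 0.5, 1.0]
flat = min_max_normalize([3.0, 3.0])             # -> [0.0, 0.0]
```

Returning zeros for a flat map is one possible convention; callers that need to distinguish "no signal" from "low signal" might prefer to raise or return `None` instead.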
app.py CHANGED
@@ -1,12 +1,26 @@
 
 
1
  from math import ceil
2
 
 
3
  from multilingual_clip import pt_multilingual_clip
4
  import numpy as np
5
  import pandas as pd
 
 
6
  import streamlit as st
7
  import torch
 
8
  from transformers import AutoTokenizer, AutoModel
9
 
 
 
 
 
 
 
 
 
10
 
11
  st.set_page_config(layout="wide")
12
 
@@ -15,16 +29,28 @@ def init():
15
  st.session_state.current_page = 1
16
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
18
 
19
  # Load the open CLIP models
20
  ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
21
- ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
 
 
 
 
22
 
23
  st.session_state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
24
  ml_model_name
25
  )
26
  st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
27
 
 
 
 
 
 
 
 
28
  st.session_state.ja_model = AutoModel.from_pretrained(
29
  ja_model_name, trust_remote_code=True
30
  ).to(device)
@@ -32,7 +58,12 @@ def init():
32
  ja_model_name, trust_remote_code=True
33
  )
34
 
 
 
35
  st.session_state.search_image_ids = []
 
 
 
36
 
37
  # Load the image IDs
38
  st.session_state.images_info = pd.read_csv("./metadata.csv")
@@ -43,8 +74,10 @@ def init():
43
  )
44
 
45
  # Load the image feature vectors
46
- ml_image_features = np.load("./multilingual_features.npy")
47
- ja_image_features = np.load("./hakuhodo_features.npy")
 
 
48
 
49
  # Convert features to Tensors: Float32 on CPU and Float16 on GPU
50
  if device == "cpu":
@@ -128,16 +161,207 @@ def clip_search(search_query):
128
  st.session_state.image_ids,
129
  )
130
 
131
- result_image_ids = [match[0] for match in matches]
132
- st.session_state.search_image_ids = result_image_ids
133
 
134
 
135
  def string_search():
136
  clip_search(st.session_state.search_field_value)
137
 
138
 
139
  st.title("Explore Japanese visual aesthetics with CLIP models")
140
 
141
  search_row = st.columns([45, 10, 13, 7, 25], vertical_alignment="center")
142
  with search_row[0]:
143
  search_field = st.text_input(
@@ -148,7 +372,9 @@ with search_row[0]:
148
  key="search_field_value",
149
  )
150
  with search_row[1]:
151
- st.button("Search", on_click=string_search, use_container_width=True)
 
 
152
  with search_row[2]:
153
  st.empty()
154
  with search_row[3]:
@@ -163,7 +389,7 @@ with search_row[4]:
163
  label_visibility="collapsed",
164
  )
165
 
166
- canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="center")
167
  with canned_searches[0]:
168
  st.markdown("**Suggested searches:**")
169
  if st.session_state.active_model == "M-CLIP (multiple languages)":
@@ -257,16 +483,27 @@ for image_id in batch:
257
  link_text = st.session_state.images_info.loc[image_id]["permalink"].split("/")[
258
  2
259
  ]
 
 
 
 
260
  st.html(
261
  f"""<div style="display: flex; flex-direction: column; align-items: center">
262
- <img src="{st.session_state.images_info.loc[image_id]['image_url']}" style="max-width: 100%; max-height: 800px" />
263
- <div>{st.session_state.images_info.loc[image_id]['caption']}</div>
264
  </div>"""
265
  )
266
  st.caption(
267
- f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -20px">
268
  <a href="{st.session_state.images_info.loc[image_id]['permalink']}">{link_text}</a>
269
  </div>""",
270
  unsafe_allow_html=True,
271
  )
 
 
 
 
 
 
 
272
  col = (col + 1) % row_size
 
1
+ from base64 import b64encode
2
+ from io import BytesIO
3
  from math import ceil
4
 
5
+ import matplotlib.pyplot as plt
6
  from multilingual_clip import pt_multilingual_clip
7
  import numpy as np
8
  import pandas as pd
9
+ from PIL import Image
10
+ import requests
11
  import streamlit as st
12
  import torch
13
+ from torchvision.transforms import ToPILImage
14
  from transformers import AutoTokenizer, AutoModel
15
 
16
+ from CLIP_Explainability.clip_ import load, tokenize
17
+ from CLIP_Explainability.vit_cam import (
18
+ interpret_vit,
19
+ vit_perword_relevance,
20
+ ) # , interpret_vit_overlapped
21
+
22
+ MAX_IMG_WIDTH = 450 # For small dialog
23
+ MAX_IMG_HEIGHT = 800
24
 
25
  st.set_page_config(layout="wide")
26
 
 
29
  st.session_state.current_page = 1
30
 
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ st.session_state.device = device
33
 
34
  # Load the open CLIP models
35
  ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
36
+ ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
37
+
38
+ st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
39
+ ml_model_path, device=device, jit=False
40
+ )
41
 
42
  st.session_state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
43
  ml_model_name
44
  )
45
  st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
46
 
47
+ ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
48
+ ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
49
+
50
+ st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
51
+ ja_model_path, device=device, jit=False
52
+ )
53
+
54
  st.session_state.ja_model = AutoModel.from_pretrained(
55
  ja_model_name, trust_remote_code=True
56
  ).to(device)
 
58
  ja_model_name, trust_remote_code=True
59
  )
60
 
61
+ st.session_state.active_model = "M-CLIP (multiple languages)"
62
+
63
  st.session_state.search_image_ids = []
64
+ st.session_state.search_image_scores = {}
65
+ st.session_state.activations_image = None
66
+ st.session_state.text_table_df = None
67
 
68
  # Load the image IDs
69
  st.session_state.images_info = pd.read_csv("./metadata.csv")
 
74
  )
75
 
76
  # Load the image feature vectors
77
+ # ml_image_features = np.load("./multilingual_features.npy")
78
+ # ja_image_features = np.load("./hakuhodo_features.npy")
79
+ ml_image_features = np.load("./resized_ml_features.npy")
80
+ ja_image_features = np.load("./resized_ja_features.npy")
81
 
82
  # Convert features to Tensors: Float32 on CPU and Float16 on GPU
83
  if device == "cpu":
 
161
  st.session_state.image_ids,
162
  )
163
 
164
+ st.session_state.search_image_ids = [match[0] for match in matches]
165
+ st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
166
 
167
 
168
  def string_search():
169
  clip_search(st.session_state.search_field_value)
170
 
171
 
172
+ def visualize_gradcam(viz_image_id):
173
+ if not st.session_state.search_field_value:
174
+ return
175
+
176
+ header_cols = st.columns([80, 20], vertical_alignment="bottom")
177
+ with header_cols[0]:
178
+ st.title("Image + query details")
179
+ with header_cols[1]:
180
+ if st.button("Close"):
181
+ st.rerun()
182
+
183
+ st.markdown(
184
+ f"**Query text:** {st.session_state.search_field_value} | **Image relevance:** {round(st.session_state.search_image_scores[viz_image_id], 3)}"
185
+ )
186
+
187
+ # with st.spinner("Calculating..."):
188
+ info_text = st.text("Calculating activation regions...")
189
+
190
+ image_url = st.session_state.images_info.loc[viz_image_id]["image_url"]
191
+ image_response = requests.get(image_url)
192
+ image = Image.open(BytesIO(image_response.content), formats=["JPEG"])
193
+
194
+ img_dim = 224
195
+ if st.session_state.active_model == "M-CLIP (multiple languages)":
196
+ img_dim = 240
197
+
198
+ orig_img_dims = image.size
199
+
200
+ altered_image = image.resize((img_dim, img_dim), Image.LANCZOS)
201
+
202
+ if st.session_state.active_model == "M-CLIP (multiple languages)":
203
+ p_image = (
204
+ st.session_state.ml_image_preprocess(altered_image)
205
+ .unsqueeze(0)
206
+ .to(st.session_state.device)
207
+ )
208
+
209
+ # Sometimes used for token importance viz
210
+ tokenized_text = st.session_state.ml_tokenizer.tokenize(
211
+ st.session_state.search_field_value
212
+ )
213
+ image_model = st.session_state.ml_image_model
214
+ # tokenize = st.session_state.ml_tokenizer.tokenize
215
+
216
+ text_features = st.session_state.ml_model.forward(
217
+ st.session_state.search_field_value, st.session_state.ml_tokenizer
218
+ )
219
+
220
+ vis_t = interpret_vit(
221
+ p_image.type(st.session_state.ml_image_model.dtype),
222
+ text_features,
223
+ st.session_state.ml_image_model.visual,
224
+ st.session_state.device,
225
+ img_dim=img_dim,
226
+ )
227
+
228
+ else:
229
+ p_image = (
230
+ st.session_state.ja_image_preprocess(altered_image)
231
+ .unsqueeze(0)
232
+ .to(st.session_state.device)
233
+ )
234
+
235
+ # Sometimes used for token importance viz
236
+ tokenized_text = st.session_state.ja_tokenizer.tokenize(
237
+ st.session_state.search_field_value
238
+ )
239
+ image_model = st.session_state.ja_image_model
240
+
241
+ t_text = st.session_state.ja_tokenizer(
242
+ st.session_state.search_field_value, return_tensors="pt"
243
+ ).to(st.session_state.device)
244
+ text_features = st.session_state.ja_model.get_text_features(**t_text)
245
+
246
+ vis_t = interpret_vit(
247
+ p_image.type(st.session_state.ja_image_model.dtype),
248
+ text_features,
249
+ st.session_state.ja_image_model.visual,
250
+ st.session_state.device,
251
+ img_dim=img_dim,
252
+ )
253
+
254
+ transform = ToPILImage()
255
+ vis_img = transform(vis_t)
256
+
257
+ if orig_img_dims[0] > orig_img_dims[1]:
258
+ scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
259
+ scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
260
+ else:
261
+ scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
262
+ scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]
263
+
264
+ st.session_state.activations_image = vis_img.resize(scaled_dims)
265
+
266
+ image_io = BytesIO()
267
+ st.session_state.activations_image.save(image_io, "PNG")
268
+ dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode("ascii")
269
+
270
+ st.html(
271
+ f"""<div style="display: flex; flex-direction: column; align-items: center">
272
+ <img src="{dataurl}" />
273
+ </div>"""
274
+ )
275
+
276
+ info_text.empty()
277
+
278
+ tokenized_text = [tok for tok in tokenized_text if tok != "▁"]
279
+
280
+ if (
281
+ len(tokenized_text) > 1
282
+ and len(tokenized_text) < 15
283
+ and st.button(
284
+ "Calculate text importance (may take some time)",
285
+ )
286
+ ):
287
+ search_tokens = []
288
+ token_scores = []
289
+
290
+ progress_text = f"Processing {len(tokenized_text)} text tokens"
291
+ progress_bar = st.progress(0.0, text=progress_text)
292
+
293
+ for t, tok in enumerate(tokenized_text):
294
+ token = tok.replace("▁", "")
295
+ word_rel = vit_perword_relevance(
296
+ p_image,
297
+ st.session_state.search_field_value,
298
+ image_model,
299
+ tokenize,
300
+ st.session_state.device,
301
+ token,
302
+ data_only=True,
303
+ img_dim=img_dim,
304
+ )
305
+ avg_score = np.mean(word_rel)
306
+ if avg_score == 0 or np.isnan(avg_score):
307
+ continue
308
+ search_tokens.append(token)
309
+ token_scores.append(1 / avg_score)
310
+
311
+ progress_bar.progress(
312
+ (t + 1) / len(tokenized_text),
313
+ text=f"Processing token {t+1} of {len(tokenized_text)} tokens",
314
+ )
315
+ progress_bar.empty()
316
+
317
+ normed_scores = torch.softmax(torch.tensor(token_scores), dim=0)
318
+
319
+ token_scores = [f"{round(score.item() * 100, 3)}%" for score in normed_scores]
320
+ st.session_state.text_table_df = pd.DataFrame(
321
+ {"token": search_tokens, "importance": token_scores}
322
+ )
323
+
324
+ st.markdown("**Importance of each text token to relevance score**")
325
+ st.table(st.session_state.text_table_df)
326
+
327
+
328
+ @st.dialog(" ", width="small")
329
+ def image_modal(vis_image_id):
330
+ visualize_gradcam(vis_image_id)
331
+
332
+
333
  st.title("Explore Japanese visual aesthetics with CLIP models")
334
 
335
+ st.markdown(
336
+ """
337
+ <style>
338
+ [data-testid=stImageCaption] {
339
+ padding: 0 0 0 0;
340
+ }
341
+ [data-testid=stVerticalBlockBorderWrapper] {
342
+ line-height: 1.2;
343
+ }
344
+ [data-testid=stVerticalBlock] {
345
+ gap: .75rem;
346
+ }
347
+ [data-testid=baseButton-secondary] {
348
+ min-height: 1rem;
349
+ padding: 0 0.75rem;
350
+ margin: 0 0 1rem 0;
351
+ }
352
+ div[aria-label="dialog"]>button[aria-label="Close"] {
353
+ display: none;
354
+ }
355
+ [data-testid=stFullScreenFrame] {
356
+ display: flex;
357
+ flex-direction: column;
358
+ align-items: center;
359
+ }
360
+ </style>
361
+ """,
362
+ unsafe_allow_html=True,
363
+ )
364
+
365
  search_row = st.columns([45, 10, 13, 7, 25], vertical_alignment="center")
366
  with search_row[0]:
367
  search_field = st.text_input(
 
372
  key="search_field_value",
373
  )
374
  with search_row[1]:
375
+ st.button(
376
+ "Search", on_click=string_search, use_container_width=True, type="primary"
377
+ )
378
  with search_row[2]:
379
  st.empty()
380
  with search_row[3]:
 
389
  label_visibility="collapsed",
390
  )
391
 
392
+ canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
393
  with canned_searches[0]:
394
  st.markdown("**Suggested searches:**")
395
  if st.session_state.active_model == "M-CLIP (multiple languages)":
 
483
  link_text = st.session_state.images_info.loc[image_id]["permalink"].split("/")[
484
  2
485
  ]
486
+ # st.image(
487
+ # st.session_state.images_info.loc[image_id]["image_url"],
488
+ # caption=st.session_state.images_info.loc[image_id]["caption"],
489
+ # )
490
  st.html(
491
  f"""<div style="display: flex; flex-direction: column; align-items: center">
492
+ <img src="{st.session_state.images_info.loc[image_id]['image_url']}" style="max-width: 100%; max-height: {MAX_IMG_HEIGHT}px" />
493
+ <div>{st.session_state.images_info.loc[image_id]['caption']} <b>[{round(st.session_state.search_image_scores[image_id], 3)}]</b></div>
494
  </div>"""
495
  )
496
  st.caption(
497
+ f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -12px">
498
  <a href="{st.session_state.images_info.loc[image_id]['permalink']}">{link_text}</a>
499
  </div>""",
500
  unsafe_allow_html=True,
501
  )
502
+ st.button(
503
+ "Explain this",
504
+ on_click=image_modal,
505
+ args=[image_id],
506
+ use_container_width=True,
507
+ key=image_id,
508
+ )
509
  col = (col + 1) % row_size
requirements.txt CHANGED
@@ -1,6 +1,9 @@
1
  multilingual_clip==1.0.10
2
  numpy==1.26
3
  pandas==2.1.2
 
 
4
  sentencepiece==0.2.0
5
  torch==2.4.0
 
6
  transformers==4.35.0
 
1
  multilingual_clip==1.0.10
2
  numpy==1.26
3
  pandas==2.1.2
4
+ pillow==10.1.0
5
+ requests==2.31.0
6
  sentencepiece==0.2.0
7
  torch==2.4.0
8
+ torchvision==0.19.0
9
  transformers==4.35.0
resized_ja_features.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ec1ba33ef7ffe1236ce4adbfae3d785e89ab7ce98cbc1e99ff74c2391a8a657
3
+ size 25903232
resized_ml_features.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b13a2171ead017721de26fe8c250b871ff4917dc573fbbe9da6b24cc348b156
3
+ size 16189568