Add viz/explanation feature for image and text activations
Browse files- CLIP_Explainability/README.md +43 -0
- CLIP_Explainability/auxilary.py +532 -0
- CLIP_Explainability/bpe_simple_vocab_16e6.txt.gz +3 -0
- CLIP_Explainability/clip_.py +305 -0
- CLIP_Explainability/image_utils.py +22 -0
- CLIP_Explainability/model.py +446 -0
- CLIP_Explainability/simple_tokenizer.py +136 -0
- CLIP_Explainability/vit_cam.py +325 -0
- app.py +247 -10
- requirements.txt +3 -0
- resized_ja_features.npy +3 -0
- resized_ml_features.npy +3 -0
CLIP_Explainability/README.md
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CLIP Explainability
|
2 |
+
|
3 |
+
This repo contains the code for the [CLIP Explainability project](CLIP_Explainability.pdf).
|
4 |
+
In this project, we conduct an in-depth study of CLIP’s learned image and text representations using saliency map visualization. We propose a modification to the existing saliency visualization method that improves its performance as shown by our qualitative evaluations. We then use this method to study CLIP’s ability in capturing similarities and dissimilarities between an input image and targets belonging to different domains including image, text, and emotion.
|
5 |
+
|
6 |
+
## Setup
|
7 |
+
|
8 |
+
To install the required libraries run the following command:
|
9 |
+
|
10 |
+
```
|
11 |
+
pip install -r requirements.txt
|
12 |
+
|
13 |
+
```
|
14 |
+
|
15 |
+
## Organization
|
16 |
+
|
17 |
+
[code](code) directory contains
|
18 |
+
|
19 |
+
- the implementation of saliency visualization methods: for [ViT](code/vit_cam.py) and [ResNet](code/rn_cam.py)-based CLIP
|
20 |
+
- [GradCAM](code/pytorch-grad-cam) implementation based on [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam/tree/e93f41104e20134e5feac2a660b343437f601ad0) slightly modified to adapt to CLIP.
|
21 |
+
- A re-implementation of CLIP taken from [Transformer-MM-Explainability](https://github.com/hila-chefer/Transformer-MM-Explainability) repo that keeps tack of attention maps and gradients: [clip_.py](code/clip_.py)
|
22 |
+
- [Notebooks](code/notebooks/) for the experiments explained in the report
|
23 |
+
|
24 |
+
|
25 |
+
[Images](Images) contains images used in the experiments.
|
26 |
+
|
27 |
+
[results](results) contains the results obtained from the experiments. Any result generated by the notebooks will be stored in this directory.
|
28 |
+
|
29 |
+
|
30 |
+
## Experiments
|
31 |
+
|
32 |
+
|
33 |
+
| Notebook Name | Experiment | Note |
|
34 |
+
| ------------- | ------------- | ------------- |
|
35 |
+
| [vit_block_vis](code/notebooks/vit_block_vis.ipynb) | Layer-wise Attention Visualization | - |
|
36 |
+
| [saliency_method_compare](code/notebooks/saliency_method_compare.ipynb) | ViT Explainability Method Comparison | Qualitative comparison |
|
37 |
+
| [affectnet_emotions](code/notebooks/affectnet_emotions.ipynb) | ViT Explainability Method Comparison | Bias comparison; you need to download a sample of the AffectNet dataset [here](https://drive.google.com/drive/u/1/folders/11RusPab71wGw6LTd9pUnY1Gz3JSH-N_N) and place it in [Images](Images). |
|
38 |
+
| [pos_neg_vis](code/notebooks/pos_neg_vis.ipynb) | Positive vs Negative Saliency | - |
|
39 |
+
| [artemis_emotions](code/notebooks/artemis_emotions.ipynb) | Emotion-Image Similarity | you need to download the pre-processed WikiArt images [here](https://drive.google.com/drive/u/1/folders/11RusPab71wGw6LTd9pUnY1Gz3JSH-N_N) and place it in [Images](Images). Note that this notebook chooses images randomly so the results may not be the same as the ones in the report. |
|
40 |
+
| [perword_vis](code/notebooks/perword_vis.ipynb) | Word-Wise Saliency Visualization |
|
41 |
+
| [global_vis](code/notebooks/global_vis.ipynb) | - | can be used to visualize saliency maps for ViT and ResNet-based CLIP.|
|
42 |
+
|
43 |
+
|
CLIP_Explainability/auxilary.py
ADDED
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import warnings
|
3 |
+
from typing import Tuple, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import Tensor
|
7 |
+
from torch.nn.init import xavier_uniform_
|
8 |
+
from torch.nn.init import constant_
|
9 |
+
from torch.nn.init import xavier_normal_
|
10 |
+
from torch.nn.parameter import Parameter
|
11 |
+
from torch.nn import functional as F
|
12 |
+
|
13 |
+
# We define this function as _pad because it takes an argument
|
14 |
+
# named pad, which clobbers the recursive reference to the pad
|
15 |
+
# function needed for __torch_function__ support
|
16 |
+
pad = F.pad
|
17 |
+
|
18 |
+
|
19 |
+
# This class exists solely for Transformer; it has an annotation stating
|
20 |
+
# that bias is never None, which appeases TorchScript
|
21 |
+
class _LinearWithBias(torch.nn.Linear):
|
22 |
+
bias: Tensor
|
23 |
+
|
24 |
+
def __init__(self, in_features: int, out_features: int) -> None:
|
25 |
+
super().__init__(in_features, out_features, bias=True)
|
26 |
+
|
27 |
+
|
28 |
+
def multi_head_attention_forward(
|
29 |
+
query: Tensor,
|
30 |
+
key: Tensor,
|
31 |
+
value: Tensor,
|
32 |
+
embed_dim_to_check: int,
|
33 |
+
num_heads: int,
|
34 |
+
in_proj_weight: Tensor,
|
35 |
+
in_proj_bias: Tensor,
|
36 |
+
bias_k: Optional[Tensor],
|
37 |
+
bias_v: Optional[Tensor],
|
38 |
+
add_zero_attn: bool,
|
39 |
+
dropout_p: float,
|
40 |
+
out_proj_weight: Tensor,
|
41 |
+
out_proj_bias: Tensor,
|
42 |
+
training: bool = True,
|
43 |
+
key_padding_mask: Optional[Tensor] = None,
|
44 |
+
need_weights: bool = True,
|
45 |
+
attn_mask: Optional[Tensor] = None,
|
46 |
+
use_separate_proj_weight: bool = False,
|
47 |
+
q_proj_weight: Optional[Tensor] = None,
|
48 |
+
k_proj_weight: Optional[Tensor] = None,
|
49 |
+
v_proj_weight: Optional[Tensor] = None,
|
50 |
+
static_k: Optional[Tensor] = None,
|
51 |
+
static_v: Optional[Tensor] = None,
|
52 |
+
attention_probs_forward_hook=None,
|
53 |
+
attention_probs_backwards_hook=None,
|
54 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
55 |
+
if not torch.jit.is_scripting():
|
56 |
+
tens_ops = (
|
57 |
+
query,
|
58 |
+
key,
|
59 |
+
value,
|
60 |
+
in_proj_weight,
|
61 |
+
in_proj_bias,
|
62 |
+
bias_k,
|
63 |
+
bias_v,
|
64 |
+
out_proj_weight,
|
65 |
+
out_proj_bias,
|
66 |
+
)
|
67 |
+
if any([type(t) is not Tensor for t in tens_ops]) and F.has_torch_function(
|
68 |
+
tens_ops
|
69 |
+
):
|
70 |
+
return F.handle_torch_function(
|
71 |
+
multi_head_attention_forward,
|
72 |
+
tens_ops,
|
73 |
+
query,
|
74 |
+
key,
|
75 |
+
value,
|
76 |
+
embed_dim_to_check,
|
77 |
+
num_heads,
|
78 |
+
in_proj_weight,
|
79 |
+
in_proj_bias,
|
80 |
+
bias_k,
|
81 |
+
bias_v,
|
82 |
+
add_zero_attn,
|
83 |
+
dropout_p,
|
84 |
+
out_proj_weight,
|
85 |
+
out_proj_bias,
|
86 |
+
training=training,
|
87 |
+
key_padding_mask=key_padding_mask,
|
88 |
+
need_weights=need_weights,
|
89 |
+
attn_mask=attn_mask,
|
90 |
+
use_separate_proj_weight=use_separate_proj_weight,
|
91 |
+
q_proj_weight=q_proj_weight,
|
92 |
+
k_proj_weight=k_proj_weight,
|
93 |
+
v_proj_weight=v_proj_weight,
|
94 |
+
static_k=static_k,
|
95 |
+
static_v=static_v,
|
96 |
+
)
|
97 |
+
tgt_len, bsz, embed_dim = query.size()
|
98 |
+
assert embed_dim == embed_dim_to_check
|
99 |
+
# allow MHA to have different sizes for the feature dimension
|
100 |
+
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
101 |
+
|
102 |
+
head_dim = embed_dim // num_heads
|
103 |
+
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
|
104 |
+
scaling = float(head_dim) ** -0.5
|
105 |
+
|
106 |
+
if not use_separate_proj_weight:
|
107 |
+
if torch.equal(query, key) and torch.equal(key, value):
|
108 |
+
# self-attention
|
109 |
+
q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
|
110 |
+
|
111 |
+
elif torch.equal(key, value):
|
112 |
+
# encoder-decoder attention
|
113 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
114 |
+
_b = in_proj_bias
|
115 |
+
_start = 0
|
116 |
+
_end = embed_dim
|
117 |
+
_w = in_proj_weight[_start:_end, :]
|
118 |
+
if _b is not None:
|
119 |
+
_b = _b[_start:_end]
|
120 |
+
q = F.linear(query, _w, _b)
|
121 |
+
|
122 |
+
if key is None:
|
123 |
+
assert value is None
|
124 |
+
k = None
|
125 |
+
v = None
|
126 |
+
else:
|
127 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
128 |
+
_b = in_proj_bias
|
129 |
+
_start = embed_dim
|
130 |
+
_end = None
|
131 |
+
_w = in_proj_weight[_start:, :]
|
132 |
+
if _b is not None:
|
133 |
+
_b = _b[_start:]
|
134 |
+
k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
|
135 |
+
|
136 |
+
else:
|
137 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
138 |
+
_b = in_proj_bias
|
139 |
+
_start = 0
|
140 |
+
_end = embed_dim
|
141 |
+
_w = in_proj_weight[_start:_end, :]
|
142 |
+
if _b is not None:
|
143 |
+
_b = _b[_start:_end]
|
144 |
+
q = F.linear(query, _w, _b)
|
145 |
+
|
146 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
147 |
+
_b = in_proj_bias
|
148 |
+
_start = embed_dim
|
149 |
+
_end = embed_dim * 2
|
150 |
+
_w = in_proj_weight[_start:_end, :]
|
151 |
+
if _b is not None:
|
152 |
+
_b = _b[_start:_end]
|
153 |
+
k = F.linear(key, _w, _b)
|
154 |
+
|
155 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
156 |
+
_b = in_proj_bias
|
157 |
+
_start = embed_dim * 2
|
158 |
+
_end = None
|
159 |
+
_w = in_proj_weight[_start:, :]
|
160 |
+
if _b is not None:
|
161 |
+
_b = _b[_start:]
|
162 |
+
v = F.linear(value, _w, _b)
|
163 |
+
else:
|
164 |
+
q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
|
165 |
+
len1, len2 = q_proj_weight_non_opt.size()
|
166 |
+
assert len1 == embed_dim and len2 == query.size(-1)
|
167 |
+
|
168 |
+
k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
|
169 |
+
len1, len2 = k_proj_weight_non_opt.size()
|
170 |
+
assert len1 == embed_dim and len2 == key.size(-1)
|
171 |
+
|
172 |
+
v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
|
173 |
+
len1, len2 = v_proj_weight_non_opt.size()
|
174 |
+
assert len1 == embed_dim and len2 == value.size(-1)
|
175 |
+
|
176 |
+
if in_proj_bias is not None:
|
177 |
+
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
|
178 |
+
k = F.linear(
|
179 |
+
key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)]
|
180 |
+
)
|
181 |
+
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
|
182 |
+
else:
|
183 |
+
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
|
184 |
+
k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
|
185 |
+
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
|
186 |
+
q = q * scaling
|
187 |
+
|
188 |
+
if attn_mask is not None:
|
189 |
+
assert (
|
190 |
+
attn_mask.dtype == torch.float32
|
191 |
+
or attn_mask.dtype == torch.float64
|
192 |
+
or attn_mask.dtype == torch.float16
|
193 |
+
or attn_mask.dtype == torch.uint8
|
194 |
+
or attn_mask.dtype == torch.bool
|
195 |
+
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
|
196 |
+
attn_mask.dtype
|
197 |
+
)
|
198 |
+
if attn_mask.dtype == torch.uint8:
|
199 |
+
warnings.warn(
|
200 |
+
"Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
|
201 |
+
)
|
202 |
+
attn_mask = attn_mask.to(torch.bool)
|
203 |
+
|
204 |
+
if attn_mask.dim() == 2:
|
205 |
+
attn_mask = attn_mask.unsqueeze(0)
|
206 |
+
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
207 |
+
raise RuntimeError("The size of the 2D attn_mask is not correct.")
|
208 |
+
elif attn_mask.dim() == 3:
|
209 |
+
if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
|
210 |
+
raise RuntimeError("The size of the 3D attn_mask is not correct.")
|
211 |
+
else:
|
212 |
+
raise RuntimeError(
|
213 |
+
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
|
214 |
+
)
|
215 |
+
# attn_mask's dim is 3 now.
|
216 |
+
|
217 |
+
# convert ByteTensor key_padding_mask to bool
|
218 |
+
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
219 |
+
warnings.warn(
|
220 |
+
"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
|
221 |
+
)
|
222 |
+
key_padding_mask = key_padding_mask.to(torch.bool)
|
223 |
+
|
224 |
+
if bias_k is not None and bias_v is not None:
|
225 |
+
if static_k is None and static_v is None:
|
226 |
+
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
227 |
+
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
228 |
+
if attn_mask is not None:
|
229 |
+
attn_mask = pad(attn_mask, (0, 1))
|
230 |
+
if key_padding_mask is not None:
|
231 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
232 |
+
else:
|
233 |
+
assert static_k is None, "bias cannot be added to static key."
|
234 |
+
assert static_v is None, "bias cannot be added to static value."
|
235 |
+
else:
|
236 |
+
assert bias_k is None
|
237 |
+
assert bias_v is None
|
238 |
+
|
239 |
+
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
240 |
+
if k is not None:
|
241 |
+
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
242 |
+
if v is not None:
|
243 |
+
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
244 |
+
|
245 |
+
if static_k is not None:
|
246 |
+
assert static_k.size(0) == bsz * num_heads
|
247 |
+
assert static_k.size(2) == head_dim
|
248 |
+
k = static_k
|
249 |
+
|
250 |
+
if static_v is not None:
|
251 |
+
assert static_v.size(0) == bsz * num_heads
|
252 |
+
assert static_v.size(2) == head_dim
|
253 |
+
v = static_v
|
254 |
+
|
255 |
+
src_len = k.size(1)
|
256 |
+
|
257 |
+
if key_padding_mask is not None:
|
258 |
+
assert key_padding_mask.size(0) == bsz
|
259 |
+
assert key_padding_mask.size(1) == src_len
|
260 |
+
|
261 |
+
if add_zero_attn:
|
262 |
+
src_len += 1
|
263 |
+
k = torch.cat(
|
264 |
+
[
|
265 |
+
k,
|
266 |
+
torch.zeros(
|
267 |
+
(k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device
|
268 |
+
),
|
269 |
+
],
|
270 |
+
dim=1,
|
271 |
+
)
|
272 |
+
v = torch.cat(
|
273 |
+
[
|
274 |
+
v,
|
275 |
+
torch.zeros(
|
276 |
+
(v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device
|
277 |
+
),
|
278 |
+
],
|
279 |
+
dim=1,
|
280 |
+
)
|
281 |
+
if attn_mask is not None:
|
282 |
+
attn_mask = pad(attn_mask, (0, 1))
|
283 |
+
if key_padding_mask is not None:
|
284 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
285 |
+
|
286 |
+
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
|
287 |
+
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
|
288 |
+
|
289 |
+
if attn_mask is not None:
|
290 |
+
if attn_mask.dtype == torch.bool:
|
291 |
+
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
|
292 |
+
else:
|
293 |
+
attn_output_weights += attn_mask
|
294 |
+
|
295 |
+
if key_padding_mask is not None:
|
296 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
297 |
+
attn_output_weights = attn_output_weights.masked_fill(
|
298 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
299 |
+
float("-inf"),
|
300 |
+
)
|
301 |
+
attn_output_weights = attn_output_weights.view(
|
302 |
+
bsz * num_heads, tgt_len, src_len
|
303 |
+
)
|
304 |
+
|
305 |
+
attn_output_weights = F.softmax(attn_output_weights, dim=-1)
|
306 |
+
attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
|
307 |
+
|
308 |
+
# use hooks for the attention weights if necessary
|
309 |
+
if (
|
310 |
+
attention_probs_forward_hook is not None
|
311 |
+
and attention_probs_backwards_hook is not None
|
312 |
+
):
|
313 |
+
attention_probs_forward_hook(attn_output_weights)
|
314 |
+
attn_output_weights.register_hook(attention_probs_backwards_hook)
|
315 |
+
|
316 |
+
attn_output = torch.bmm(attn_output_weights, v)
|
317 |
+
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
318 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
319 |
+
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
|
320 |
+
|
321 |
+
if need_weights:
|
322 |
+
# average attention weights over heads
|
323 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
324 |
+
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
325 |
+
else:
|
326 |
+
return attn_output, None
|
327 |
+
|
328 |
+
|
329 |
+
class MultiheadAttention(torch.nn.Module):
|
330 |
+
r"""Allows the model to jointly attend to information
|
331 |
+
from different representation subspaces.
|
332 |
+
See reference: Attention Is All You Need
|
333 |
+
|
334 |
+
.. math::
|
335 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
336 |
+
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
337 |
+
|
338 |
+
Args:
|
339 |
+
embed_dim: total dimension of the model.
|
340 |
+
num_heads: parallel attention heads.
|
341 |
+
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
342 |
+
bias: add bias as module parameter. Default: True.
|
343 |
+
add_bias_kv: add bias to the key and value sequences at dim=0.
|
344 |
+
add_zero_attn: add a new batch of zeros to the key and
|
345 |
+
value sequences at dim=1.
|
346 |
+
kdim: total number of features in key. Default: None.
|
347 |
+
vdim: total number of features in value. Default: None.
|
348 |
+
|
349 |
+
Note: if kdim and vdim are None, they will be set to embed_dim such that
|
350 |
+
query, key, and value have the same number of features.
|
351 |
+
|
352 |
+
Examples::
|
353 |
+
|
354 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
355 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
356 |
+
"""
|
357 |
+
|
358 |
+
bias_k: Optional[torch.Tensor]
|
359 |
+
bias_v: Optional[torch.Tensor]
|
360 |
+
|
361 |
+
def __init__(
|
362 |
+
self,
|
363 |
+
embed_dim,
|
364 |
+
num_heads,
|
365 |
+
dropout=0.0,
|
366 |
+
bias=True,
|
367 |
+
add_bias_kv=False,
|
368 |
+
add_zero_attn=False,
|
369 |
+
kdim=None,
|
370 |
+
vdim=None,
|
371 |
+
):
|
372 |
+
super(MultiheadAttention, self).__init__()
|
373 |
+
self.embed_dim = embed_dim
|
374 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
375 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
376 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
377 |
+
|
378 |
+
self.num_heads = num_heads
|
379 |
+
self.dropout = dropout
|
380 |
+
self.head_dim = embed_dim // num_heads
|
381 |
+
assert (
|
382 |
+
self.head_dim * num_heads == self.embed_dim
|
383 |
+
), "embed_dim must be divisible by num_heads"
|
384 |
+
|
385 |
+
if self._qkv_same_embed_dim is False:
|
386 |
+
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
|
387 |
+
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
|
388 |
+
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
|
389 |
+
self.register_parameter("in_proj_weight", None)
|
390 |
+
else:
|
391 |
+
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
|
392 |
+
self.register_parameter("q_proj_weight", None)
|
393 |
+
self.register_parameter("k_proj_weight", None)
|
394 |
+
self.register_parameter("v_proj_weight", None)
|
395 |
+
|
396 |
+
if bias:
|
397 |
+
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
|
398 |
+
else:
|
399 |
+
self.register_parameter("in_proj_bias", None)
|
400 |
+
self.out_proj = _LinearWithBias(embed_dim, embed_dim)
|
401 |
+
|
402 |
+
if add_bias_kv:
|
403 |
+
self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
|
404 |
+
self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
|
405 |
+
else:
|
406 |
+
self.bias_k = self.bias_v = None
|
407 |
+
|
408 |
+
self.add_zero_attn = add_zero_attn
|
409 |
+
|
410 |
+
self._reset_parameters()
|
411 |
+
|
412 |
+
def _reset_parameters(self):
|
413 |
+
if self._qkv_same_embed_dim:
|
414 |
+
xavier_uniform_(self.in_proj_weight)
|
415 |
+
else:
|
416 |
+
xavier_uniform_(self.q_proj_weight)
|
417 |
+
xavier_uniform_(self.k_proj_weight)
|
418 |
+
xavier_uniform_(self.v_proj_weight)
|
419 |
+
|
420 |
+
if self.in_proj_bias is not None:
|
421 |
+
constant_(self.in_proj_bias, 0.0)
|
422 |
+
constant_(self.out_proj.bias, 0.0)
|
423 |
+
if self.bias_k is not None:
|
424 |
+
xavier_normal_(self.bias_k)
|
425 |
+
if self.bias_v is not None:
|
426 |
+
xavier_normal_(self.bias_v)
|
427 |
+
|
428 |
+
def __setstate__(self, state):
|
429 |
+
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
430 |
+
if "_qkv_same_embed_dim" not in state:
|
431 |
+
state["_qkv_same_embed_dim"] = True
|
432 |
+
|
433 |
+
super(MultiheadAttention, self).__setstate__(state)
|
434 |
+
|
435 |
+
def forward(
|
436 |
+
self,
|
437 |
+
query,
|
438 |
+
key,
|
439 |
+
value,
|
440 |
+
key_padding_mask=None,
|
441 |
+
need_weights=True,
|
442 |
+
attn_mask=None,
|
443 |
+
attention_probs_forward_hook=None,
|
444 |
+
attention_probs_backwards_hook=None,
|
445 |
+
):
|
446 |
+
r"""
|
447 |
+
Args:
|
448 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
449 |
+
See "Attention Is All You Need" for more details.
|
450 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
451 |
+
be ignored by the attention. When given a binary mask and a value is True,
|
452 |
+
the corresponding value on the attention layer will be ignored. When given
|
453 |
+
a byte mask and a value is non-zero, the corresponding value on the attention
|
454 |
+
layer will be ignored
|
455 |
+
need_weights: output attn_output_weights.
|
456 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
457 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
458 |
+
|
459 |
+
Shape:
|
460 |
+
- Inputs:
|
461 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
462 |
+
the embedding dimension.
|
463 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
464 |
+
the embedding dimension.
|
465 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
466 |
+
the embedding dimension.
|
467 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
468 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
469 |
+
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
470 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
471 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
472 |
+
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
473 |
+
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
474 |
+
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
475 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
476 |
+
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
477 |
+
is provided, it will be added to the attention weight.
|
478 |
+
|
479 |
+
- Outputs:
|
480 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
481 |
+
E is the embedding dimension.
|
482 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
483 |
+
L is the target sequence length, S is the source sequence length.
|
484 |
+
"""
|
485 |
+
if not self._qkv_same_embed_dim:
|
486 |
+
return multi_head_attention_forward(
|
487 |
+
query,
|
488 |
+
key,
|
489 |
+
value,
|
490 |
+
self.embed_dim,
|
491 |
+
self.num_heads,
|
492 |
+
self.in_proj_weight,
|
493 |
+
self.in_proj_bias,
|
494 |
+
self.bias_k,
|
495 |
+
self.bias_v,
|
496 |
+
self.add_zero_attn,
|
497 |
+
self.dropout,
|
498 |
+
self.out_proj.weight,
|
499 |
+
self.out_proj.bias,
|
500 |
+
training=self.training,
|
501 |
+
key_padding_mask=key_padding_mask,
|
502 |
+
need_weights=need_weights,
|
503 |
+
attn_mask=attn_mask,
|
504 |
+
use_separate_proj_weight=True,
|
505 |
+
q_proj_weight=self.q_proj_weight,
|
506 |
+
k_proj_weight=self.k_proj_weight,
|
507 |
+
v_proj_weight=self.v_proj_weight,
|
508 |
+
attention_probs_forward_hook=attention_probs_forward_hook,
|
509 |
+
attention_probs_backwards_hook=attention_probs_backwards_hook,
|
510 |
+
)
|
511 |
+
else:
|
512 |
+
return multi_head_attention_forward(
|
513 |
+
query,
|
514 |
+
key,
|
515 |
+
value,
|
516 |
+
self.embed_dim,
|
517 |
+
self.num_heads,
|
518 |
+
self.in_proj_weight,
|
519 |
+
self.in_proj_bias,
|
520 |
+
self.bias_k,
|
521 |
+
self.bias_v,
|
522 |
+
self.add_zero_attn,
|
523 |
+
self.dropout,
|
524 |
+
self.out_proj.weight,
|
525 |
+
self.out_proj.bias,
|
526 |
+
training=self.training,
|
527 |
+
key_padding_mask=key_padding_mask,
|
528 |
+
need_weights=need_weights,
|
529 |
+
attn_mask=attn_mask,
|
530 |
+
attention_probs_forward_hook=attention_probs_forward_hook,
|
531 |
+
attention_probs_backwards_hook=attention_probs_backwards_hook,
|
532 |
+
)
|
CLIP_Explainability/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
CLIP_Explainability/clip_.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
taken from https://github.com/hila-chefer/Transformer-MM-Explainability
|
3 |
+
added similarity_score
|
4 |
+
"""
|
5 |
+
|
6 |
+
import hashlib
|
7 |
+
import os
|
8 |
+
import urllib
|
9 |
+
import warnings
|
10 |
+
from typing import Union, List
|
11 |
+
import re
|
12 |
+
import html
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from PIL import Image
|
16 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
17 |
+
from tqdm import tqdm
|
18 |
+
import ftfy
|
19 |
+
|
20 |
+
from transformers import BatchFeature
|
21 |
+
|
22 |
+
from .model import build_model
|
23 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
24 |
+
|
25 |
+
__all__ = ["available_models", "load", "tokenize"]
|
26 |
+
_tokenizer = _Tokenizer()
|
27 |
+
|
28 |
+
_MODELS = {
|
29 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
30 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
31 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
32 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
37 |
+
os.makedirs(root, exist_ok=True)
|
38 |
+
filename = os.path.basename(url)
|
39 |
+
|
40 |
+
expected_sha256 = url.split("/")[-2]
|
41 |
+
download_target = os.path.join(root, filename)
|
42 |
+
|
43 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
44 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
45 |
+
|
46 |
+
if os.path.isfile(download_target):
|
47 |
+
if (
|
48 |
+
hashlib.sha256(open(download_target, "rb").read()).hexdigest()
|
49 |
+
== expected_sha256
|
50 |
+
):
|
51 |
+
return download_target
|
52 |
+
else:
|
53 |
+
warnings.warn(
|
54 |
+
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
55 |
+
)
|
56 |
+
|
57 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
58 |
+
with tqdm(
|
59 |
+
total=int(source.info().get("Content-Length")),
|
60 |
+
ncols=80,
|
61 |
+
unit="iB",
|
62 |
+
unit_scale=True,
|
63 |
+
) as loop:
|
64 |
+
while True:
|
65 |
+
buffer = source.read(8192)
|
66 |
+
if not buffer:
|
67 |
+
break
|
68 |
+
|
69 |
+
output.write(buffer)
|
70 |
+
loop.update(len(buffer))
|
71 |
+
|
72 |
+
if (
|
73 |
+
hashlib.sha256(open(download_target, "rb").read()).hexdigest()
|
74 |
+
!= expected_sha256
|
75 |
+
):
|
76 |
+
raise RuntimeError(
|
77 |
+
f"Model has been downloaded but the SHA256 checksum does not not match"
|
78 |
+
)
|
79 |
+
|
80 |
+
return download_target
|
81 |
+
|
82 |
+
|
83 |
+
def _transform(n_px):
|
84 |
+
return Compose(
|
85 |
+
[
|
86 |
+
Resize(n_px, interpolation=Image.BICUBIC),
|
87 |
+
CenterCrop(n_px),
|
88 |
+
lambda image: image.convert("RGB"),
|
89 |
+
ToTensor(),
|
90 |
+
Normalize(
|
91 |
+
(0.48145466, 0.4578275, 0.40821073),
|
92 |
+
(0.26862954, 0.26130258, 0.27577711),
|
93 |
+
),
|
94 |
+
]
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
def available_models() -> List[str]:
|
99 |
+
"""Returns the names of available CLIP models"""
|
100 |
+
return list(_MODELS.keys())
|
101 |
+
|
102 |
+
|
103 |
+
def load(
|
104 |
+
name: str,
|
105 |
+
device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
|
106 |
+
jit=True,
|
107 |
+
):
|
108 |
+
"""Load a CLIP model
|
109 |
+
|
110 |
+
Parameters
|
111 |
+
----------
|
112 |
+
name : str
|
113 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
114 |
+
|
115 |
+
device : Union[str, torch.device]
|
116 |
+
The device to put the loaded model
|
117 |
+
|
118 |
+
jit : bool
|
119 |
+
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
120 |
+
|
121 |
+
Returns
|
122 |
+
-------
|
123 |
+
model : torch.nn.Module
|
124 |
+
The CLIP model
|
125 |
+
|
126 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
127 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
128 |
+
"""
|
129 |
+
if name in _MODELS:
|
130 |
+
model_path = _download(_MODELS[name])
|
131 |
+
elif os.path.isfile(name):
|
132 |
+
model_path = name
|
133 |
+
else:
|
134 |
+
raise RuntimeError(
|
135 |
+
f"Model {name} not found; available models = {available_models()}"
|
136 |
+
)
|
137 |
+
|
138 |
+
try:
|
139 |
+
# loading JIT archive
|
140 |
+
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
141 |
+
state_dict = None
|
142 |
+
except RuntimeError:
|
143 |
+
# loading saved state dict
|
144 |
+
if jit:
|
145 |
+
warnings.warn(
|
146 |
+
f"File {model_path} is not a JIT archive. Loading as a state dict instead"
|
147 |
+
)
|
148 |
+
jit = False
|
149 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
150 |
+
|
151 |
+
if not jit:
|
152 |
+
model = build_model(state_dict or model.state_dict()).to(device)
|
153 |
+
if str(device) == "cpu":
|
154 |
+
model.float()
|
155 |
+
return model, _transform(model.visual.input_resolution)
|
156 |
+
|
157 |
+
# patch the device names
|
158 |
+
device_holder = torch.jit.trace(
|
159 |
+
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
|
160 |
+
)
|
161 |
+
device_node = [
|
162 |
+
n
|
163 |
+
for n in device_holder.graph.findAllNodes("prim::Constant")
|
164 |
+
if "Device" in repr(n)
|
165 |
+
][-1]
|
166 |
+
|
167 |
+
def patch_device(module):
|
168 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
169 |
+
if hasattr(module, "forward1"):
|
170 |
+
graphs.append(module.forward1.graph)
|
171 |
+
|
172 |
+
for graph in graphs:
|
173 |
+
for node in graph.findAllNodes("prim::Constant"):
|
174 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith(
|
175 |
+
"cuda"
|
176 |
+
):
|
177 |
+
node.copyAttributes(device_node)
|
178 |
+
|
179 |
+
model.apply(patch_device)
|
180 |
+
patch_device(model.encode_image)
|
181 |
+
patch_device(model.encode_text)
|
182 |
+
|
183 |
+
# patch dtype to float32 on CPU
|
184 |
+
if str(device) == "cpu":
|
185 |
+
float_holder = torch.jit.trace(
|
186 |
+
lambda: torch.ones([]).float(), example_inputs=[]
|
187 |
+
)
|
188 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
189 |
+
float_node = float_input.node()
|
190 |
+
|
191 |
+
def patch_float(module):
|
192 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
193 |
+
if hasattr(module, "forward1"):
|
194 |
+
graphs.append(module.forward1.graph)
|
195 |
+
|
196 |
+
for graph in graphs:
|
197 |
+
for node in graph.findAllNodes("aten::to"):
|
198 |
+
inputs = list(node.inputs())
|
199 |
+
for i in [
|
200 |
+
1,
|
201 |
+
2,
|
202 |
+
]: # dtype can be the second or third argument to aten::to()
|
203 |
+
if inputs[i].node()["value"] == 5:
|
204 |
+
inputs[i].node().copyAttributes(float_node)
|
205 |
+
|
206 |
+
model.apply(patch_float)
|
207 |
+
patch_float(model.encode_image)
|
208 |
+
patch_float(model.encode_text)
|
209 |
+
|
210 |
+
model.float()
|
211 |
+
|
212 |
+
return model, _transform(model.input_resolution.item())
|
213 |
+
|
214 |
+
|
215 |
+
def tokenize(
|
216 |
+
texts: Union[str, List[str]], context_length: int = 77
|
217 |
+
) -> torch.LongTensor:
|
218 |
+
"""
|
219 |
+
Returns the tokenized representation of given input string(s)
|
220 |
+
|
221 |
+
Parameters
|
222 |
+
----------
|
223 |
+
texts : Union[str, List[str]]
|
224 |
+
An input string or a list of input strings to tokenize
|
225 |
+
|
226 |
+
context_length : int
|
227 |
+
The context length to use; all CLIP models use 77 as the context length
|
228 |
+
|
229 |
+
Returns
|
230 |
+
-------
|
231 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
232 |
+
"""
|
233 |
+
if isinstance(texts, str):
|
234 |
+
texts = [texts]
|
235 |
+
|
236 |
+
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
237 |
+
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
238 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
239 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
240 |
+
|
241 |
+
for i, tokens in enumerate(all_tokens):
|
242 |
+
if len(tokens) > context_length:
|
243 |
+
raise RuntimeError(
|
244 |
+
f"Input {texts[i]} is too long for context length {context_length}"
|
245 |
+
)
|
246 |
+
result[i, : len(tokens)] = torch.tensor(tokens)
|
247 |
+
|
248 |
+
return result
|
249 |
+
|
250 |
+
|
251 |
+
def basic_clean(text):
|
252 |
+
text = ftfy.fix_text(text)
|
253 |
+
text = html.unescape(html.unescape(text))
|
254 |
+
return text.strip()
|
255 |
+
|
256 |
+
|
257 |
+
def whitespace_clean(text):
|
258 |
+
text = re.sub(r"\s+", " ", text)
|
259 |
+
text = text.strip()
|
260 |
+
return text
|
261 |
+
|
262 |
+
|
263 |
+
def tokenize_ja(
|
264 |
+
tokenizer,
|
265 |
+
texts: Union[str, List[str]],
|
266 |
+
max_seq_len: int = 77,
|
267 |
+
):
|
268 |
+
"""
|
269 |
+
This is a function that have the original clip's code has.
|
270 |
+
https://github.com/openai/CLIP/blob/main/clip/clip.py#L195
|
271 |
+
"""
|
272 |
+
if isinstance(texts, str):
|
273 |
+
texts = [texts]
|
274 |
+
texts = [whitespace_clean(basic_clean(text)) for text in texts]
|
275 |
+
|
276 |
+
inputs = tokenizer(
|
277 |
+
texts,
|
278 |
+
max_length=max_seq_len - 1,
|
279 |
+
padding="max_length",
|
280 |
+
truncation=True,
|
281 |
+
add_special_tokens=False,
|
282 |
+
)
|
283 |
+
# add bos token at first place
|
284 |
+
input_ids = [[tokenizer.bos_token_id] + ids for ids in inputs["input_ids"]]
|
285 |
+
attention_mask = [[1] + am for am in inputs["attention_mask"]]
|
286 |
+
position_ids = [list(range(0, len(input_ids[0])))] * len(texts)
|
287 |
+
|
288 |
+
return BatchFeature(
|
289 |
+
{
|
290 |
+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
291 |
+
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
|
292 |
+
"position_ids": torch.tensor(position_ids, dtype=torch.long),
|
293 |
+
}
|
294 |
+
)
|
295 |
+
|
296 |
+
|
297 |
+
def similarity_score(clip_model, image, target_features):
|
298 |
+
image_features = clip_model.encode_image(image)
|
299 |
+
|
300 |
+
image_features_norm = image_features.norm(dim=-1, keepdim=True)
|
301 |
+
image_features_new = image_features / image_features_norm
|
302 |
+
target_features_norm = target_features.norm(dim=-1, keepdim=True)
|
303 |
+
target_features_new = target_features / target_features_norm
|
304 |
+
|
305 |
+
return image_features_new[0].dot(target_features_new[0]) * 100
|
CLIP_Explainability/image_utils.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
def show_cam_on_image(img, mask, neg_saliency=False):
|
5 |
+
|
6 |
+
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
7 |
+
|
8 |
+
heatmap = np.float32(heatmap) / 255
|
9 |
+
cam = heatmap + np.float32(img)
|
10 |
+
cam = cam / np.max(cam)
|
11 |
+
return cam
|
12 |
+
|
13 |
+
def show_overlapped_cam(img, neg_mask, pos_mask):
|
14 |
+
neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), cv2.COLORMAP_RAINBOW)
|
15 |
+
pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), cv2.COLORMAP_JET)
|
16 |
+
neg_heatmap = np.float32(neg_heatmap) / 255
|
17 |
+
pos_heatmap = np.float32(pos_heatmap) / 255
|
18 |
+
# try different options: sum, average, ...
|
19 |
+
heatmap = neg_heatmap + pos_heatmap
|
20 |
+
cam = heatmap + np.float32(img)
|
21 |
+
cam = cam / np.max(cam)
|
22 |
+
return cam
|
CLIP_Explainability/model.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
taken from https://github.com/hila-chefer/Transformer-MM-Explainability
|
3 |
+
"""
|
4 |
+
|
5 |
+
from collections import OrderedDict
|
6 |
+
from typing import Tuple, Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import nn
|
12 |
+
from .auxilary import *
|
13 |
+
|
14 |
+
class Bottleneck(nn.Module):
|
15 |
+
expansion = 4
|
16 |
+
|
17 |
+
def __init__(self, inplanes, planes, stride=1):
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
21 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
22 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
23 |
+
|
24 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
25 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
26 |
+
|
27 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
28 |
+
|
29 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
30 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
31 |
+
|
32 |
+
self.relu = nn.ReLU(inplace=True)
|
33 |
+
self.downsample = None
|
34 |
+
self.stride = stride
|
35 |
+
|
36 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
37 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
38 |
+
self.downsample = nn.Sequential(OrderedDict([
|
39 |
+
("-1", nn.AvgPool2d(stride)),
|
40 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
41 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
42 |
+
]))
|
43 |
+
|
44 |
+
def forward(self, x: torch.Tensor):
|
45 |
+
identity = x
|
46 |
+
|
47 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
48 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
49 |
+
out = self.avgpool(out)
|
50 |
+
out = self.bn3(self.conv3(out))
|
51 |
+
|
52 |
+
if self.downsample is not None:
|
53 |
+
identity = self.downsample(x)
|
54 |
+
|
55 |
+
out += identity
|
56 |
+
out = self.relu(out)
|
57 |
+
return out
|
58 |
+
|
59 |
+
|
60 |
+
class AttentionPool2d(nn.Module):
|
61 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
62 |
+
super().__init__()
|
63 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
64 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
65 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
66 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
67 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
68 |
+
self.num_heads = num_heads
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
72 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
73 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
74 |
+
x, _ = multi_head_attention_forward(
|
75 |
+
query=x, key=x, value=x,
|
76 |
+
embed_dim_to_check=x.shape[-1],
|
77 |
+
num_heads=self.num_heads,
|
78 |
+
q_proj_weight=self.q_proj.weight,
|
79 |
+
k_proj_weight=self.k_proj.weight,
|
80 |
+
v_proj_weight=self.v_proj.weight,
|
81 |
+
in_proj_weight=None,
|
82 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
83 |
+
bias_k=None,
|
84 |
+
bias_v=None,
|
85 |
+
add_zero_attn=False,
|
86 |
+
dropout_p=0,
|
87 |
+
out_proj_weight=self.c_proj.weight,
|
88 |
+
out_proj_bias=self.c_proj.bias,
|
89 |
+
use_separate_proj_weight=True,
|
90 |
+
training=self.training,
|
91 |
+
need_weights=False
|
92 |
+
)
|
93 |
+
|
94 |
+
return x[0]
|
95 |
+
|
96 |
+
|
97 |
+
class ModifiedResNet(nn.Module):
|
98 |
+
"""
|
99 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
100 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
101 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
102 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
103 |
+
"""
|
104 |
+
|
105 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
106 |
+
super().__init__()
|
107 |
+
self.output_dim = output_dim
|
108 |
+
self.input_resolution = input_resolution
|
109 |
+
|
110 |
+
# the 3-layer stem
|
111 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
112 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
113 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
114 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
115 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
116 |
+
self.bn3 = nn.BatchNorm2d(width)
|
117 |
+
self.avgpool = nn.AvgPool2d(2)
|
118 |
+
self.relu = nn.ReLU(inplace=True)
|
119 |
+
|
120 |
+
# residual layers
|
121 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
122 |
+
self.layer1 = self._make_layer(width, layers[0])
|
123 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
124 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
125 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
126 |
+
|
127 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
128 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
129 |
+
|
130 |
+
def _make_layer(self, planes, blocks, stride=1):
|
131 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
132 |
+
|
133 |
+
self._inplanes = planes * Bottleneck.expansion
|
134 |
+
for _ in range(1, blocks):
|
135 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
136 |
+
|
137 |
+
return nn.Sequential(*layers)
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
def stem(x):
|
141 |
+
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
142 |
+
x = self.relu(bn(conv(x)))
|
143 |
+
x = self.avgpool(x)
|
144 |
+
return x
|
145 |
+
|
146 |
+
x = x.type(self.conv1.weight.dtype)
|
147 |
+
x = stem(x)
|
148 |
+
x = self.layer1(x)
|
149 |
+
x = self.layer2(x)
|
150 |
+
x = self.layer3(x)
|
151 |
+
x = self.layer4(x)
|
152 |
+
x = self.attnpool(x)
|
153 |
+
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
class LayerNorm(nn.LayerNorm):
|
158 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
159 |
+
|
160 |
+
def forward(self, x: torch.Tensor):
|
161 |
+
orig_type = x.dtype
|
162 |
+
ret = super().forward(x.type(torch.float32))
|
163 |
+
return ret.type(orig_type)
|
164 |
+
|
165 |
+
|
166 |
+
class QuickGELU(nn.Module):
|
167 |
+
def forward(self, x: torch.Tensor):
|
168 |
+
return x * torch.sigmoid(1.702 * x)
|
169 |
+
|
170 |
+
|
171 |
+
class ResidualAttentionBlock(nn.Module):
|
172 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
173 |
+
super().__init__()
|
174 |
+
|
175 |
+
self.attn = MultiheadAttention(d_model, n_head)
|
176 |
+
self.ln_1 = LayerNorm(d_model)
|
177 |
+
self.mlp = nn.Sequential(OrderedDict([
|
178 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
179 |
+
("gelu", QuickGELU()),
|
180 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
181 |
+
]))
|
182 |
+
self.ln_2 = LayerNorm(d_model)
|
183 |
+
self.attn_mask = attn_mask
|
184 |
+
|
185 |
+
self.attn_probs = None
|
186 |
+
self.attn_grad = None
|
187 |
+
|
188 |
+
def set_attn_probs(self, attn_probs):
|
189 |
+
self.attn_probs = attn_probs
|
190 |
+
|
191 |
+
def set_attn_grad(self, attn_grad):
|
192 |
+
self.attn_grad = attn_grad
|
193 |
+
|
194 |
+
def attention(self, x: torch.Tensor):
|
195 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
196 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, attention_probs_forward_hook=self.set_attn_probs,
|
197 |
+
attention_probs_backwards_hook=self.set_attn_grad)[0]
|
198 |
+
|
199 |
+
def forward(self, x: torch.Tensor):
|
200 |
+
x = x + self.attention(self.ln_1(x))
|
201 |
+
x = x + self.mlp(self.ln_2(x))
|
202 |
+
return x
|
203 |
+
|
204 |
+
|
205 |
+
class Transformer(nn.Module):
|
206 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
207 |
+
super().__init__()
|
208 |
+
self.width = width
|
209 |
+
self.layers = layers
|
210 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
211 |
+
|
212 |
+
def forward(self, x: torch.Tensor):
|
213 |
+
return self.resblocks(x)
|
214 |
+
|
215 |
+
|
216 |
+
class VisualTransformer(nn.Module):
|
217 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
218 |
+
super().__init__()
|
219 |
+
self.input_resolution = input_resolution
|
220 |
+
self.output_dim = output_dim
|
221 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
222 |
+
|
223 |
+
scale = width ** -0.5
|
224 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
225 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
226 |
+
self.ln_pre = LayerNorm(width)
|
227 |
+
|
228 |
+
self.transformer = Transformer(width, layers, heads)
|
229 |
+
|
230 |
+
self.ln_post = LayerNorm(width)
|
231 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
232 |
+
|
233 |
+
def forward(self, x: torch.Tensor):
|
234 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
235 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
236 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
237 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
238 |
+
x = x + self.positional_embedding.to(x.dtype)
|
239 |
+
x = self.ln_pre(x)
|
240 |
+
|
241 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
242 |
+
x = self.transformer(x)
|
243 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
244 |
+
|
245 |
+
x = self.ln_post(x[:, 0, :])
|
246 |
+
|
247 |
+
if self.proj is not None:
|
248 |
+
x = x @ self.proj
|
249 |
+
|
250 |
+
return x
|
251 |
+
|
252 |
+
|
253 |
+
class CLIP(nn.Module):
|
254 |
+
def __init__(self,
|
255 |
+
embed_dim: int,
|
256 |
+
# vision
|
257 |
+
image_resolution: int,
|
258 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
259 |
+
vision_width: int,
|
260 |
+
vision_patch_size: int,
|
261 |
+
# text
|
262 |
+
context_length: int,
|
263 |
+
vocab_size: int,
|
264 |
+
transformer_width: int,
|
265 |
+
transformer_heads: int,
|
266 |
+
transformer_layers: int
|
267 |
+
):
|
268 |
+
super().__init__()
|
269 |
+
|
270 |
+
self.context_length = context_length
|
271 |
+
|
272 |
+
if isinstance(vision_layers, (tuple, list)):
|
273 |
+
vision_heads = vision_width * 32 // 64
|
274 |
+
self.visual = ModifiedResNet(
|
275 |
+
layers=vision_layers,
|
276 |
+
output_dim=embed_dim,
|
277 |
+
heads=vision_heads,
|
278 |
+
input_resolution=image_resolution,
|
279 |
+
width=vision_width
|
280 |
+
)
|
281 |
+
else:
|
282 |
+
vision_heads = vision_width // 64
|
283 |
+
self.visual = VisualTransformer(
|
284 |
+
input_resolution=image_resolution,
|
285 |
+
patch_size=vision_patch_size,
|
286 |
+
width=vision_width,
|
287 |
+
layers=vision_layers,
|
288 |
+
heads=vision_heads,
|
289 |
+
output_dim=embed_dim
|
290 |
+
)
|
291 |
+
|
292 |
+
self.transformer = Transformer(
|
293 |
+
width=transformer_width,
|
294 |
+
layers=transformer_layers,
|
295 |
+
heads=transformer_heads,
|
296 |
+
attn_mask=self.build_attention_mask()
|
297 |
+
)
|
298 |
+
|
299 |
+
self.vocab_size = vocab_size
|
300 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
301 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
302 |
+
self.ln_final = LayerNorm(transformer_width)
|
303 |
+
|
304 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
305 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
306 |
+
|
307 |
+
self.initialize_parameters()
|
308 |
+
|
309 |
+
def initialize_parameters(self):
|
310 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
311 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
312 |
+
|
313 |
+
if isinstance(self.visual, ModifiedResNet):
|
314 |
+
if self.visual.attnpool is not None:
|
315 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
316 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
317 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
318 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
319 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
320 |
+
|
321 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
322 |
+
for name, param in resnet_block.named_parameters():
|
323 |
+
if name.endswith("bn3.weight"):
|
324 |
+
nn.init.zeros_(param)
|
325 |
+
|
326 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
327 |
+
attn_std = self.transformer.width ** -0.5
|
328 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
329 |
+
for block in self.transformer.resblocks:
|
330 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
331 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
332 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
333 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
334 |
+
|
335 |
+
if self.text_projection is not None:
|
336 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
337 |
+
|
338 |
+
def build_attention_mask(self):
|
339 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
340 |
+
# pytorch uses additive attention mask; fill with -inf
|
341 |
+
mask = torch.empty(self.context_length, self.context_length)
|
342 |
+
mask.fill_(float("-inf"))
|
343 |
+
mask.triu_(1) # zero out the lower diagonal
|
344 |
+
return mask
|
345 |
+
|
346 |
+
@property
|
347 |
+
def dtype(self):
|
348 |
+
return self.visual.conv1.weight.dtype
|
349 |
+
|
350 |
+
def encode_image(self, image):
|
351 |
+
return self.visual(image.type(self.dtype))
|
352 |
+
|
353 |
+
def encode_text(self, text):
|
354 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
355 |
+
|
356 |
+
x = x + self.positional_embedding.type(self.dtype)
|
357 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
358 |
+
x = self.transformer(x)
|
359 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
360 |
+
x = self.ln_final(x).type(self.dtype)
|
361 |
+
|
362 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
363 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
364 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
365 |
+
|
366 |
+
return x
|
367 |
+
|
368 |
+
def forward(self, image, text):
|
369 |
+
image_features = self.encode_image(image)
|
370 |
+
text_features = self.encode_text(text)
|
371 |
+
|
372 |
+
# normalized features
|
373 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
374 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
375 |
+
|
376 |
+
# cosine similarity as logits
|
377 |
+
logit_scale = self.logit_scale.exp()
|
378 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
379 |
+
logits_per_text = logit_scale * text_features @ image_features.t()
|
380 |
+
|
381 |
+
# shape = [global_batch_size, global_batch_size]
|
382 |
+
return logits_per_image, logits_per_text
|
383 |
+
|
384 |
+
|
385 |
+
def convert_weights(model: nn.Module):
|
386 |
+
"""Convert applicable model parameters to fp16"""
|
387 |
+
|
388 |
+
def _convert_weights_to_fp16(l):
|
389 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
390 |
+
l.weight.data = l.weight.data.half()
|
391 |
+
if l.bias is not None:
|
392 |
+
l.bias.data = l.bias.data.half()
|
393 |
+
|
394 |
+
if isinstance(l, MultiheadAttention):
|
395 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
396 |
+
tensor = getattr(l, attr)
|
397 |
+
if tensor is not None:
|
398 |
+
tensor.data = tensor.data.half()
|
399 |
+
|
400 |
+
for name in ["text_projection", "proj"]:
|
401 |
+
if hasattr(l, name):
|
402 |
+
attr = getattr(l, name)
|
403 |
+
if attr is not None:
|
404 |
+
attr.data = attr.data.half()
|
405 |
+
|
406 |
+
model.apply(_convert_weights_to_fp16)
|
407 |
+
|
408 |
+
|
409 |
+
def build_model(state_dict: dict):
|
410 |
+
vit = "visual.proj" in state_dict
|
411 |
+
|
412 |
+
if vit:
|
413 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
414 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
415 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
416 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
417 |
+
image_resolution = vision_patch_size * grid_size
|
418 |
+
else:
|
419 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
420 |
+
vision_layers = tuple(counts)
|
421 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
422 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
423 |
+
vision_patch_size = None
|
424 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
425 |
+
image_resolution = output_width * 32
|
426 |
+
|
427 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
428 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
429 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
430 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
431 |
+
transformer_heads = transformer_width // 64
|
432 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
433 |
+
|
434 |
+
model = CLIP(
|
435 |
+
embed_dim,
|
436 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
437 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
438 |
+
)
|
439 |
+
|
440 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
441 |
+
if key in state_dict:
|
442 |
+
del state_dict[key]
|
443 |
+
|
444 |
+
convert_weights(model)
|
445 |
+
model.load_state_dict(state_dict)
|
446 |
+
return model.eval()
|
CLIP_Explainability/simple_tokenizer.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
taken from https://github.com/hila-chefer/Transformer-MM-Explainability
|
3 |
+
"""
|
4 |
+
|
5 |
+
import gzip
|
6 |
+
import html
|
7 |
+
import os
|
8 |
+
from functools import lru_cache
|
9 |
+
|
10 |
+
import ftfy
|
11 |
+
import regex as re
|
12 |
+
|
13 |
+
|
14 |
+
@lru_cache()
|
15 |
+
def default_bpe():
|
16 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
17 |
+
|
18 |
+
|
19 |
+
@lru_cache()
|
20 |
+
def bytes_to_unicode():
|
21 |
+
"""
|
22 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
23 |
+
The reversible bpe codes work on unicode strings.
|
24 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
25 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
26 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
27 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
28 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
29 |
+
"""
|
30 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
31 |
+
cs = bs[:]
|
32 |
+
n = 0
|
33 |
+
for b in range(2**8):
|
34 |
+
if b not in bs:
|
35 |
+
bs.append(b)
|
36 |
+
cs.append(2**8+n)
|
37 |
+
n += 1
|
38 |
+
cs = [chr(n) for n in cs]
|
39 |
+
return dict(zip(bs, cs))
|
40 |
+
|
41 |
+
|
42 |
+
def get_pairs(word):
|
43 |
+
"""Return set of symbol pairs in a word.
|
44 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
45 |
+
"""
|
46 |
+
pairs = set()
|
47 |
+
prev_char = word[0]
|
48 |
+
for char in word[1:]:
|
49 |
+
pairs.add((prev_char, char))
|
50 |
+
prev_char = char
|
51 |
+
return pairs
|
52 |
+
|
53 |
+
|
54 |
+
def basic_clean(text):
|
55 |
+
text = ftfy.fix_text(text)
|
56 |
+
text = html.unescape(html.unescape(text))
|
57 |
+
return text.strip()
|
58 |
+
|
59 |
+
|
60 |
+
def whitespace_clean(text):
|
61 |
+
text = re.sub(r'\s+', ' ', text)
|
62 |
+
text = text.strip()
|
63 |
+
return text
|
64 |
+
|
65 |
+
|
66 |
+
class SimpleTokenizer(object):
|
67 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
68 |
+
self.byte_encoder = bytes_to_unicode()
|
69 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
70 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
71 |
+
merges = merges[1:49152-256-2+1]
|
72 |
+
merges = [tuple(merge.split()) for merge in merges]
|
73 |
+
vocab = list(bytes_to_unicode().values())
|
74 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
75 |
+
for merge in merges:
|
76 |
+
vocab.append(''.join(merge))
|
77 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
78 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
79 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
80 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
81 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
82 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
83 |
+
|
84 |
+
def bpe(self, token):
|
85 |
+
if token in self.cache:
|
86 |
+
return self.cache[token]
|
87 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
88 |
+
pairs = get_pairs(word)
|
89 |
+
|
90 |
+
if not pairs:
|
91 |
+
return token+'</w>'
|
92 |
+
|
93 |
+
while True:
|
94 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
95 |
+
if bigram not in self.bpe_ranks:
|
96 |
+
break
|
97 |
+
first, second = bigram
|
98 |
+
new_word = []
|
99 |
+
i = 0
|
100 |
+
while i < len(word):
|
101 |
+
try:
|
102 |
+
j = word.index(first, i)
|
103 |
+
new_word.extend(word[i:j])
|
104 |
+
i = j
|
105 |
+
except:
|
106 |
+
new_word.extend(word[i:])
|
107 |
+
break
|
108 |
+
|
109 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
110 |
+
new_word.append(first+second)
|
111 |
+
i += 2
|
112 |
+
else:
|
113 |
+
new_word.append(word[i])
|
114 |
+
i += 1
|
115 |
+
new_word = tuple(new_word)
|
116 |
+
word = new_word
|
117 |
+
if len(word) == 1:
|
118 |
+
break
|
119 |
+
else:
|
120 |
+
pairs = get_pairs(word)
|
121 |
+
word = ' '.join(word)
|
122 |
+
self.cache[token] = word
|
123 |
+
return word
|
124 |
+
|
125 |
+
def encode(self, text):
|
126 |
+
bpe_tokens = []
|
127 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
128 |
+
for token in re.findall(self.pat, text):
|
129 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
130 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
131 |
+
return bpe_tokens
|
132 |
+
|
133 |
+
def decode(self, tokens):
|
134 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
135 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
136 |
+
return text
|
CLIP_Explainability/vit_cam.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import cv2
|
6 |
+
import regex as re
|
7 |
+
|
8 |
+
from .image_utils import show_cam_on_image, show_overlapped_cam
|
9 |
+
|
10 |
+
|
11 |
+
def vit_block_vis(
|
12 |
+
image,
|
13 |
+
target_features,
|
14 |
+
img_encoder,
|
15 |
+
block,
|
16 |
+
device,
|
17 |
+
grad=False,
|
18 |
+
neg_saliency=False,
|
19 |
+
img_dim=224,
|
20 |
+
):
|
21 |
+
img_encoder.eval()
|
22 |
+
image_features = img_encoder(image)
|
23 |
+
|
24 |
+
image_features_norm = image_features.norm(dim=-1, keepdim=True)
|
25 |
+
image_features_new = image_features / image_features_norm
|
26 |
+
target_features_norm = target_features.norm(dim=-1, keepdim=True)
|
27 |
+
target_features_new = target_features / target_features_norm
|
28 |
+
|
29 |
+
similarity = image_features_new[0].dot(target_features_new[0])
|
30 |
+
image = (image - image.min()) / (image.max() - image.min())
|
31 |
+
|
32 |
+
img_encoder.zero_grad()
|
33 |
+
similarity.backward(retain_graph=True)
|
34 |
+
|
35 |
+
image_attn_blocks = list(
|
36 |
+
dict(img_encoder.transformer.resblocks.named_children()).values()
|
37 |
+
)
|
38 |
+
|
39 |
+
if grad:
|
40 |
+
cam = image_attn_blocks[block].attn_grad.detach()
|
41 |
+
else:
|
42 |
+
cam = image_attn_blocks[block].attn_probs.detach()
|
43 |
+
|
44 |
+
cam = cam.mean(dim=0)
|
45 |
+
image_relevance = cam[0, 1:]
|
46 |
+
|
47 |
+
resize_dim = int(np.sqrt(list(image_relevance.shape)[0]))
|
48 |
+
|
49 |
+
# image_relevance = image_relevance.reshape(1, 1, 7, 7)
|
50 |
+
image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
|
51 |
+
|
52 |
+
image_relevance = torch.nn.functional.interpolate(
|
53 |
+
image_relevance, size=img_dim, mode="bilinear"
|
54 |
+
)
|
55 |
+
image_relevance = image_relevance.reshape(img_dim, img_dim)
|
56 |
+
image_relevance = (image_relevance - image_relevance.min()) / (
|
57 |
+
image_relevance.max() - image_relevance.min()
|
58 |
+
)
|
59 |
+
|
60 |
+
cam = image_relevance * image
|
61 |
+
cam = cam / torch.max(cam)
|
62 |
+
|
63 |
+
# TODO: maybe we can ignore this...
|
64 |
+
####
|
65 |
+
masked_image_features = img_encoder(cam)
|
66 |
+
masked_image_features_norm = masked_image_features.norm(dim=-1, keepdim=True)
|
67 |
+
masked_image_features_new = masked_image_features / masked_image_features_norm
|
68 |
+
new_score = masked_image_features_new[0].dot(target_features_new[0])
|
69 |
+
####
|
70 |
+
|
71 |
+
cam = cam[0].permute(1, 2, 0).data.cpu().numpy()
|
72 |
+
cam = np.float32(cam)
|
73 |
+
|
74 |
+
plt.imshow(cam)
|
75 |
+
|
76 |
+
return new_score
|
77 |
+
|
78 |
+
|
79 |
+
def vit_relevance(
|
80 |
+
image,
|
81 |
+
target_features,
|
82 |
+
img_encoder,
|
83 |
+
device,
|
84 |
+
method="last grad",
|
85 |
+
neg_saliency=False,
|
86 |
+
img_dim=224,
|
87 |
+
):
|
88 |
+
img_encoder.eval()
|
89 |
+
image_features = img_encoder(image)
|
90 |
+
|
91 |
+
image_features_norm = image_features.norm(dim=-1, keepdim=True)
|
92 |
+
image_features_new = image_features / image_features_norm
|
93 |
+
target_features_norm = target_features.norm(dim=-1, keepdim=True)
|
94 |
+
target_features_new = target_features / target_features_norm
|
95 |
+
similarity = image_features_new[0].dot(target_features_new[0])
|
96 |
+
if neg_saliency:
|
97 |
+
objective = 1 - similarity
|
98 |
+
else:
|
99 |
+
objective = similarity
|
100 |
+
img_encoder.zero_grad()
|
101 |
+
objective.backward(retain_graph=True)
|
102 |
+
image_attn_blocks = list(
|
103 |
+
dict(img_encoder.transformer.resblocks.named_children()).values()
|
104 |
+
)
|
105 |
+
num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
|
106 |
+
|
107 |
+
last_attn = image_attn_blocks[-1].attn_probs.detach()
|
108 |
+
last_attn = last_attn.reshape(-1, last_attn.shape[-1], last_attn.shape[-1])
|
109 |
+
|
110 |
+
last_grad = image_attn_blocks[-1].attn_grad.detach()
|
111 |
+
last_grad = last_grad.reshape(-1, last_grad.shape[-1], last_grad.shape[-1])
|
112 |
+
|
113 |
+
if method == "gradcam":
|
114 |
+
cam = last_grad * last_attn
|
115 |
+
cam = cam.clamp(min=0).mean(dim=0)
|
116 |
+
image_relevance = cam[0, 1:]
|
117 |
+
|
118 |
+
else:
|
119 |
+
R = torch.eye(
|
120 |
+
num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype
|
121 |
+
).to(device)
|
122 |
+
for blk in image_attn_blocks:
|
123 |
+
cam = blk.attn_probs.detach()
|
124 |
+
cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
|
125 |
+
|
126 |
+
if method == "last grad":
|
127 |
+
grad = last_grad
|
128 |
+
elif method == "all grads":
|
129 |
+
grad = blk.attn_grad.detach()
|
130 |
+
else:
|
131 |
+
print(
|
132 |
+
"The available visualization methods are: 'gradcam', 'last grad', 'all grads'."
|
133 |
+
)
|
134 |
+
return
|
135 |
+
|
136 |
+
cam = grad * cam
|
137 |
+
cam = cam.clamp(min=0).mean(dim=0)
|
138 |
+
R += torch.matmul(cam, R)
|
139 |
+
|
140 |
+
image_relevance = R[0, 1:]
|
141 |
+
|
142 |
+
resize_dim = int(np.sqrt(list(image_relevance.shape)[0]))
|
143 |
+
|
144 |
+
# image_relevance = image_relevance.reshape(1, 1, 7, 7)
|
145 |
+
image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
|
146 |
+
|
147 |
+
image_relevance = torch.nn.functional.interpolate(
|
148 |
+
image_relevance, size=img_dim, mode="bilinear"
|
149 |
+
)
|
150 |
+
image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
|
151 |
+
image_relevance = (image_relevance - image_relevance.min()) / (
|
152 |
+
image_relevance.max() - image_relevance.min()
|
153 |
+
)
|
154 |
+
image = image[0].permute(1, 2, 0).data.cpu().numpy()
|
155 |
+
image = (image - image.min()) / (image.max() - image.min())
|
156 |
+
|
157 |
+
return image_relevance, image
|
158 |
+
|
159 |
+
|
160 |
+
def interpret_vit(
|
161 |
+
image,
|
162 |
+
target_features,
|
163 |
+
img_encoder,
|
164 |
+
device,
|
165 |
+
method="last grad",
|
166 |
+
neg_saliency=False,
|
167 |
+
img_dim=224,
|
168 |
+
):
|
169 |
+
image_relevance, image = vit_relevance(
|
170 |
+
image,
|
171 |
+
target_features,
|
172 |
+
img_encoder,
|
173 |
+
device,
|
174 |
+
method=method,
|
175 |
+
neg_saliency=neg_saliency,
|
176 |
+
img_dim=img_dim,
|
177 |
+
)
|
178 |
+
|
179 |
+
vis = show_cam_on_image(image, image_relevance, neg_saliency=neg_saliency)
|
180 |
+
vis = np.uint8(255 * vis)
|
181 |
+
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
|
182 |
+
|
183 |
+
return vis
|
184 |
+
# plt.imshow(vis)
|
185 |
+
|
186 |
+
|
187 |
+
def interpret_vit_overlapped(
|
188 |
+
image, target_features, img_encoder, device, method="last grad", img_dim=224
|
189 |
+
):
|
190 |
+
pos_image_relevance, _ = vit_relevance(
|
191 |
+
image,
|
192 |
+
target_features,
|
193 |
+
img_encoder,
|
194 |
+
device,
|
195 |
+
method=method,
|
196 |
+
neg_saliency=False,
|
197 |
+
img_dim=img_dim,
|
198 |
+
)
|
199 |
+
neg_image_relevance, image = vit_relevance(
|
200 |
+
image,
|
201 |
+
target_features,
|
202 |
+
img_encoder,
|
203 |
+
device,
|
204 |
+
method=method,
|
205 |
+
neg_saliency=True,
|
206 |
+
img_dim=img_dim,
|
207 |
+
)
|
208 |
+
|
209 |
+
vis = show_overlapped_cam(image, neg_image_relevance, pos_image_relevance)
|
210 |
+
vis = np.uint8(255 * vis)
|
211 |
+
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
|
212 |
+
|
213 |
+
plt.imshow(vis)
|
214 |
+
|
215 |
+
|
216 |
+
def vit_perword_relevance(
|
217 |
+
image,
|
218 |
+
text,
|
219 |
+
clip_model,
|
220 |
+
clip_tokenizer,
|
221 |
+
device,
|
222 |
+
masked_word="",
|
223 |
+
use_last_grad=True,
|
224 |
+
data_only=False,
|
225 |
+
img_dim=224,
|
226 |
+
):
|
227 |
+
clip_model.eval()
|
228 |
+
|
229 |
+
main_text = clip_tokenizer(text).to(device)
|
230 |
+
# remove the word for which you want to visualize the saliency
|
231 |
+
masked_text = re.sub(masked_word, "", text)
|
232 |
+
masked_text = clip_tokenizer(masked_text).to(device)
|
233 |
+
|
234 |
+
image_features = clip_model.encode_image(image)
|
235 |
+
main_text_features = clip_model.encode_text(main_text)
|
236 |
+
masked_text_features = clip_model.encode_text(masked_text)
|
237 |
+
|
238 |
+
image_features_norm = image_features.norm(dim=-1, keepdim=True)
|
239 |
+
image_features_new = image_features / image_features_norm
|
240 |
+
main_text_features_norm = main_text_features.norm(dim=-1, keepdim=True)
|
241 |
+
main_text_features_new = main_text_features / main_text_features_norm
|
242 |
+
|
243 |
+
masked_text_features_norm = masked_text_features.norm(dim=-1, keepdim=True)
|
244 |
+
masked_text_features_new = masked_text_features / masked_text_features_norm
|
245 |
+
|
246 |
+
objective = image_features_new[0].dot(
|
247 |
+
main_text_features_new[0] - masked_text_features_new[0]
|
248 |
+
)
|
249 |
+
|
250 |
+
clip_model.visual.zero_grad()
|
251 |
+
objective.backward(retain_graph=True)
|
252 |
+
|
253 |
+
image_attn_blocks = list(
|
254 |
+
dict(clip_model.visual.transformer.resblocks.named_children()).values()
|
255 |
+
)
|
256 |
+
num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
|
257 |
+
|
258 |
+
R = torch.eye(
|
259 |
+
num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype
|
260 |
+
).to(device)
|
261 |
+
|
262 |
+
last_grad = image_attn_blocks[-1].attn_grad.detach()
|
263 |
+
last_grad = last_grad.reshape(-1, last_grad.shape[-1], last_grad.shape[-1])
|
264 |
+
|
265 |
+
for blk in image_attn_blocks:
|
266 |
+
cam = blk.attn_probs.detach()
|
267 |
+
cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
|
268 |
+
|
269 |
+
if use_last_grad:
|
270 |
+
grad = last_grad
|
271 |
+
else:
|
272 |
+
grad = blk.attn_grad.detach()
|
273 |
+
|
274 |
+
cam = grad * cam
|
275 |
+
cam = cam.clamp(min=0).mean(dim=0)
|
276 |
+
R += torch.matmul(cam, R)
|
277 |
+
|
278 |
+
image_relevance = R[0, 1:]
|
279 |
+
|
280 |
+
resize_dim = int(np.sqrt(list(image_relevance.shape)[0]))
|
281 |
+
|
282 |
+
image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
|
283 |
+
|
284 |
+
image_relevance = torch.nn.functional.interpolate(
|
285 |
+
image_relevance, size=img_dim, mode="bilinear"
|
286 |
+
)
|
287 |
+
image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
|
288 |
+
image_relevance = (image_relevance - image_relevance.min()) / (
|
289 |
+
image_relevance.max() - image_relevance.min()
|
290 |
+
)
|
291 |
+
|
292 |
+
if data_only:
|
293 |
+
return image_relevance
|
294 |
+
|
295 |
+
image = image[0].permute(1, 2, 0).data.cpu().numpy()
|
296 |
+
image = (image - image.min()) / (image.max() - image.min())
|
297 |
+
|
298 |
+
return image_relevance, image
|
299 |
+
|
300 |
+
|
301 |
+
def interpret_perword_vit(
|
302 |
+
image,
|
303 |
+
text,
|
304 |
+
clip_model,
|
305 |
+
clip_tokenizer,
|
306 |
+
device,
|
307 |
+
masked_word="",
|
308 |
+
use_last_grad=True,
|
309 |
+
img_dim=224,
|
310 |
+
):
|
311 |
+
image_relevance, image = vit_perword_relevance(
|
312 |
+
image,
|
313 |
+
text,
|
314 |
+
clip_model,
|
315 |
+
clip_tokenizer,
|
316 |
+
device,
|
317 |
+
masked_word,
|
318 |
+
use_last_grad,
|
319 |
+
img_dim=img_dim,
|
320 |
+
)
|
321 |
+
vis = show_cam_on_image(image, image_relevance)
|
322 |
+
vis = np.uint8(255 * vis)
|
323 |
+
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
|
324 |
+
|
325 |
+
plt.imshow(vis)
|
app.py
CHANGED
@@ -1,12 +1,26 @@
|
|
|
|
|
|
1 |
from math import ceil
|
2 |
|
|
|
3 |
from multilingual_clip import pt_multilingual_clip
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
|
|
|
|
6 |
import streamlit as st
|
7 |
import torch
|
|
|
8 |
from transformers import AutoTokenizer, AutoModel
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
st.set_page_config(layout="wide")
|
12 |
|
@@ -15,16 +29,28 @@ def init():
|
|
15 |
st.session_state.current_page = 1
|
16 |
|
17 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
18 |
|
19 |
# Load the open CLIP models
|
20 |
ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
|
21 |
-
|
|
|
|
|
|
|
|
|
22 |
|
23 |
st.session_state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
|
24 |
ml_model_name
|
25 |
)
|
26 |
st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
st.session_state.ja_model = AutoModel.from_pretrained(
|
29 |
ja_model_name, trust_remote_code=True
|
30 |
).to(device)
|
@@ -32,7 +58,12 @@ def init():
|
|
32 |
ja_model_name, trust_remote_code=True
|
33 |
)
|
34 |
|
|
|
|
|
35 |
st.session_state.search_image_ids = []
|
|
|
|
|
|
|
36 |
|
37 |
# Load the image IDs
|
38 |
st.session_state.images_info = pd.read_csv("./metadata.csv")
|
@@ -43,8 +74,10 @@ def init():
|
|
43 |
)
|
44 |
|
45 |
# Load the image feature vectors
|
46 |
-
ml_image_features = np.load("./multilingual_features.npy")
|
47 |
-
ja_image_features = np.load("./hakuhodo_features.npy")
|
|
|
|
|
48 |
|
49 |
# Convert features to Tensors: Float32 on CPU and Float16 on GPU
|
50 |
if device == "cpu":
|
@@ -128,16 +161,207 @@ def clip_search(search_query):
|
|
128 |
st.session_state.image_ids,
|
129 |
)
|
130 |
|
131 |
-
|
132 |
-
st.session_state.
|
133 |
|
134 |
|
135 |
def string_search():
|
136 |
clip_search(st.session_state.search_field_value)
|
137 |
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
st.title("Explore Japanese visual aesthetics with CLIP models")
|
140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
search_row = st.columns([45, 10, 13, 7, 25], vertical_alignment="center")
|
142 |
with search_row[0]:
|
143 |
search_field = st.text_input(
|
@@ -148,7 +372,9 @@ with search_row[0]:
|
|
148 |
key="search_field_value",
|
149 |
)
|
150 |
with search_row[1]:
|
151 |
-
st.button(
|
|
|
|
|
152 |
with search_row[2]:
|
153 |
st.empty()
|
154 |
with search_row[3]:
|
@@ -163,7 +389,7 @@ with search_row[4]:
|
|
163 |
label_visibility="collapsed",
|
164 |
)
|
165 |
|
166 |
-
canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="
|
167 |
with canned_searches[0]:
|
168 |
st.markdown("**Suggested searches:**")
|
169 |
if st.session_state.active_model == "M-CLIP (multiple languages)":
|
@@ -257,16 +483,27 @@ for image_id in batch:
|
|
257 |
link_text = st.session_state.images_info.loc[image_id]["permalink"].split("/")[
|
258 |
2
|
259 |
]
|
|
|
|
|
|
|
|
|
260 |
st.html(
|
261 |
f"""<div style="display: flex; flex-direction: column; align-items: center">
|
262 |
-
<img src="{st.session_state.images_info.loc[image_id]['image_url']}" style="max-width: 100%; max-height:
|
263 |
-
<div>{st.session_state.images_info.loc[image_id]['caption']}</div>
|
264 |
</div>"""
|
265 |
)
|
266 |
st.caption(
|
267 |
-
f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -
|
268 |
<a href="{st.session_state.images_info.loc[image_id]['permalink']}">{link_text}</a>
|
269 |
<div>""",
|
270 |
unsafe_allow_html=True,
|
271 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
col = (col + 1) % row_size
|
|
|
1 |
+
from base64 import b64encode
|
2 |
+
from io import BytesIO
|
3 |
from math import ceil
|
4 |
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
from multilingual_clip import pt_multilingual_clip
|
7 |
import numpy as np
|
8 |
import pandas as pd
|
9 |
+
from PIL import Image
|
10 |
+
import requests
|
11 |
import streamlit as st
|
12 |
import torch
|
13 |
+
from torchvision.transforms import ToPILImage
|
14 |
from transformers import AutoTokenizer, AutoModel
|
15 |
|
16 |
+
from CLIP_Explainability.clip_ import load, tokenize
|
17 |
+
from CLIP_Explainability.vit_cam import (
|
18 |
+
interpret_vit,
|
19 |
+
vit_perword_relevance,
|
20 |
+
) # , interpret_vit_overlapped
|
21 |
+
|
22 |
+
MAX_IMG_WIDTH = 450 # For small dialog
|
23 |
+
MAX_IMG_HEIGHT = 800
|
24 |
|
25 |
st.set_page_config(layout="wide")
|
26 |
|
|
|
29 |
st.session_state.current_page = 1
|
30 |
|
31 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
32 |
+
st.session_state.device = device
|
33 |
|
34 |
# Load the open CLIP models
|
35 |
ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
|
36 |
+
ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
|
37 |
+
|
38 |
+
st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
|
39 |
+
ml_model_path, device=device, jit=False
|
40 |
+
)
|
41 |
|
42 |
st.session_state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
|
43 |
ml_model_name
|
44 |
)
|
45 |
st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
|
46 |
|
47 |
+
ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
|
48 |
+
ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
|
49 |
+
|
50 |
+
st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
|
51 |
+
ja_model_path, device=device, jit=False
|
52 |
+
)
|
53 |
+
|
54 |
st.session_state.ja_model = AutoModel.from_pretrained(
|
55 |
ja_model_name, trust_remote_code=True
|
56 |
).to(device)
|
|
|
58 |
ja_model_name, trust_remote_code=True
|
59 |
)
|
60 |
|
61 |
+
st.session_state.active_model = "M-CLIP (multiple languages)"
|
62 |
+
|
63 |
st.session_state.search_image_ids = []
|
64 |
+
st.session_state.search_image_scores = {}
|
65 |
+
st.session_state.activations_image = None
|
66 |
+
st.session_state.text_table_df = None
|
67 |
|
68 |
# Load the image IDs
|
69 |
st.session_state.images_info = pd.read_csv("./metadata.csv")
|
|
|
74 |
)
|
75 |
|
76 |
# Load the image feature vectors
|
77 |
+
# ml_image_features = np.load("./multilingual_features.npy")
|
78 |
+
# ja_image_features = np.load("./hakuhodo_features.npy")
|
79 |
+
ml_image_features = np.load("./resized_ml_features.npy")
|
80 |
+
ja_image_features = np.load("./resized_ja_features.npy")
|
81 |
|
82 |
# Convert features to Tensors: Float32 on CPU and Float16 on GPU
|
83 |
if device == "cpu":
|
|
|
161 |
st.session_state.image_ids,
|
162 |
)
|
163 |
|
164 |
+
st.session_state.search_image_ids = [match[0] for match in matches]
|
165 |
+
st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
|
166 |
|
167 |
|
168 |
def string_search():
|
169 |
clip_search(st.session_state.search_field_value)
|
170 |
|
171 |
|
172 |
+
def visualize_gradcam(viz_image_id):
|
173 |
+
if not st.session_state.search_field_value:
|
174 |
+
return
|
175 |
+
|
176 |
+
header_cols = st.columns([80, 20], vertical_alignment="bottom")
|
177 |
+
with header_cols[0]:
|
178 |
+
st.title("Image + query details")
|
179 |
+
with header_cols[1]:
|
180 |
+
if st.button("Close"):
|
181 |
+
st.rerun()
|
182 |
+
|
183 |
+
st.markdown(
|
184 |
+
f"**Query text:** {st.session_state.search_field_value} | **Image relevance:** {round(st.session_state.search_image_scores[viz_image_id], 3)}"
|
185 |
+
)
|
186 |
+
|
187 |
+
# with st.spinner("Calculating..."):
|
188 |
+
info_text = st.text("Calculating activation regions...")
|
189 |
+
|
190 |
+
image_url = st.session_state.images_info.loc[viz_image_id]["image_url"]
|
191 |
+
image_response = requests.get(image_url)
|
192 |
+
image = Image.open(BytesIO(image_response.content), formats=["JPEG"])
|
193 |
+
|
194 |
+
img_dim = 224
|
195 |
+
if st.session_state.active_model == "M-CLIP (multiple languages)":
|
196 |
+
img_dim = 240
|
197 |
+
|
198 |
+
orig_img_dims = image.size
|
199 |
+
|
200 |
+
altered_image = image.resize((img_dim, img_dim), Image.LANCZOS)
|
201 |
+
|
202 |
+
if st.session_state.active_model == "M-CLIP (multiple languages)":
|
203 |
+
p_image = (
|
204 |
+
st.session_state.ml_image_preprocess(altered_image)
|
205 |
+
.unsqueeze(0)
|
206 |
+
.to(st.session_state.device)
|
207 |
+
)
|
208 |
+
|
209 |
+
# Sometimes used for token importance viz
|
210 |
+
tokenized_text = st.session_state.ml_tokenizer.tokenize(
|
211 |
+
st.session_state.search_field_value
|
212 |
+
)
|
213 |
+
image_model = st.session_state.ml_image_model
|
214 |
+
# tokenize = st.session_state.ml_tokenizer.tokenize
|
215 |
+
|
216 |
+
text_features = st.session_state.ml_model.forward(
|
217 |
+
st.session_state.search_field_value, st.session_state.ml_tokenizer
|
218 |
+
)
|
219 |
+
|
220 |
+
vis_t = interpret_vit(
|
221 |
+
p_image.type(st.session_state.ml_image_model.dtype),
|
222 |
+
text_features,
|
223 |
+
st.session_state.ml_image_model.visual,
|
224 |
+
st.session_state.device,
|
225 |
+
img_dim=img_dim,
|
226 |
+
)
|
227 |
+
|
228 |
+
else:
|
229 |
+
p_image = (
|
230 |
+
st.session_state.ja_image_preprocess(altered_image)
|
231 |
+
.unsqueeze(0)
|
232 |
+
.to(st.session_state.device)
|
233 |
+
)
|
234 |
+
|
235 |
+
# Sometimes used for token importance viz
|
236 |
+
tokenized_text = st.session_state.ja_tokenizer.tokenize(
|
237 |
+
st.session_state.search_field_value
|
238 |
+
)
|
239 |
+
image_model = st.session_state.ja_image_model
|
240 |
+
|
241 |
+
t_text = st.session_state.ja_tokenizer(
|
242 |
+
st.session_state.search_field_value, return_tensors="pt"
|
243 |
+
)
|
244 |
+
text_features = st.session_state.ja_model.get_text_features(**t_text)
|
245 |
+
|
246 |
+
vis_t = interpret_vit(
|
247 |
+
p_image.type(st.session_state.ja_image_model.dtype),
|
248 |
+
text_features,
|
249 |
+
st.session_state.ja_image_model.visual,
|
250 |
+
st.session_state.device,
|
251 |
+
img_dim=img_dim,
|
252 |
+
)
|
253 |
+
|
254 |
+
transform = ToPILImage()
|
255 |
+
vis_img = transform(vis_t)
|
256 |
+
|
257 |
+
if orig_img_dims[0] > orig_img_dims[1]:
|
258 |
+
scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
|
259 |
+
scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
|
260 |
+
else:
|
261 |
+
scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
|
262 |
+
scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]
|
263 |
+
|
264 |
+
st.session_state.activations_image = vis_img.resize(scaled_dims)
|
265 |
+
|
266 |
+
image_io = BytesIO()
|
267 |
+
st.session_state.activations_image.save(image_io, "PNG")
|
268 |
+
dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode("ascii")
|
269 |
+
|
270 |
+
st.html(
|
271 |
+
f"""<div style="display: flex; flex-direction: column; align-items: center">
|
272 |
+
<img src="{dataurl}" />
|
273 |
+
</div>"""
|
274 |
+
)
|
275 |
+
|
276 |
+
info_text.empty()
|
277 |
+
|
278 |
+
tokenized_text = [tok for tok in tokenized_text if tok != "▁"]
|
279 |
+
|
280 |
+
if (
|
281 |
+
len(tokenized_text) > 1
|
282 |
+
and len(tokenized_text) < 15
|
283 |
+
and st.button(
|
284 |
+
"Calculate text importance (may take some time)",
|
285 |
+
)
|
286 |
+
):
|
287 |
+
search_tokens = []
|
288 |
+
token_scores = []
|
289 |
+
|
290 |
+
progress_text = f"Processing {len(tokenized_text)} text tokens"
|
291 |
+
progress_bar = st.progress(0.0, text=progress_text)
|
292 |
+
|
293 |
+
for t, tok in enumerate(tokenized_text):
|
294 |
+
token = tok.replace("▁", "")
|
295 |
+
word_rel = vit_perword_relevance(
|
296 |
+
p_image,
|
297 |
+
st.session_state.search_field_value,
|
298 |
+
image_model,
|
299 |
+
tokenize,
|
300 |
+
st.session_state.device,
|
301 |
+
token,
|
302 |
+
data_only=True,
|
303 |
+
img_dim=img_dim,
|
304 |
+
)
|
305 |
+
avg_score = np.mean(word_rel)
|
306 |
+
if avg_score == 0 or np.isnan(avg_score):
|
307 |
+
continue
|
308 |
+
search_tokens.append(token)
|
309 |
+
token_scores.append(1 / avg_score)
|
310 |
+
|
311 |
+
progress_bar.progress(
|
312 |
+
(t + 1) / len(tokenized_text),
|
313 |
+
text=f"Processing token {t+1} of {len(tokenized_text)} tokens",
|
314 |
+
)
|
315 |
+
progress_bar.empty()
|
316 |
+
|
317 |
+
normed_scores = torch.softmax(torch.tensor(token_scores), dim=0)
|
318 |
+
|
319 |
+
token_scores = [f"{round(score.item() * 100, 3)}%" for score in normed_scores]
|
320 |
+
st.session_state.text_table_df = pd.DataFrame(
|
321 |
+
{"token": search_tokens, "importance": token_scores}
|
322 |
+
)
|
323 |
+
|
324 |
+
st.markdown("**Importance of each text token to relevance score**")
|
325 |
+
st.table(st.session_state.text_table_df)
|
326 |
+
|
327 |
+
|
328 |
+
@st.dialog(" ", width="small")
|
329 |
+
def image_modal(vis_image_id):
|
330 |
+
visualize_gradcam(vis_image_id)
|
331 |
+
|
332 |
+
|
333 |
st.title("Explore Japanese visual aesthetics with CLIP models")
|
334 |
|
335 |
+
st.markdown(
|
336 |
+
"""
|
337 |
+
<style>
|
338 |
+
[data-testid=stImageCaption] {
|
339 |
+
padding: 0 0 0 0;
|
340 |
+
}
|
341 |
+
[data-testid=stVerticalBlockBorderWrapper] {
|
342 |
+
line-height: 1.2;
|
343 |
+
}
|
344 |
+
[data-testid=stVerticalBlock] {
|
345 |
+
gap: .75rem;
|
346 |
+
}
|
347 |
+
[data-testid=baseButton-secondary] {
|
348 |
+
min-height: 1rem;
|
349 |
+
padding: 0 0.75rem;
|
350 |
+
margin: 0 0 1rem 0;
|
351 |
+
}
|
352 |
+
div[aria-label="dialog"]>button[aria-label="Close"] {
|
353 |
+
display: none;
|
354 |
+
}
|
355 |
+
[data-testid=stFullScreenFrame] {
|
356 |
+
display: flex;
|
357 |
+
flex-direction: column;
|
358 |
+
align-items: center;
|
359 |
+
}
|
360 |
+
</style>
|
361 |
+
""",
|
362 |
+
unsafe_allow_html=True,
|
363 |
+
)
|
364 |
+
|
365 |
search_row = st.columns([45, 10, 13, 7, 25], vertical_alignment="center")
|
366 |
with search_row[0]:
|
367 |
search_field = st.text_input(
|
|
|
372 |
key="search_field_value",
|
373 |
)
|
374 |
with search_row[1]:
|
375 |
+
st.button(
|
376 |
+
"Search", on_click=string_search, use_container_width=True, type="primary"
|
377 |
+
)
|
378 |
with search_row[2]:
|
379 |
st.empty()
|
380 |
with search_row[3]:
|
|
|
389 |
label_visibility="collapsed",
|
390 |
)
|
391 |
|
392 |
+
canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
|
393 |
with canned_searches[0]:
|
394 |
st.markdown("**Suggested searches:**")
|
395 |
if st.session_state.active_model == "M-CLIP (multiple languages)":
|
|
|
483 |
link_text = st.session_state.images_info.loc[image_id]["permalink"].split("/")[
|
484 |
2
|
485 |
]
|
486 |
+
# st.image(
|
487 |
+
# st.session_state.images_info.loc[image_id]["image_url"],
|
488 |
+
# caption=st.session_state.images_info.loc[image_id]["caption"],
|
489 |
+
# )
|
490 |
st.html(
|
491 |
f"""<div style="display: flex; flex-direction: column; align-items: center">
|
492 |
+
<img src="{st.session_state.images_info.loc[image_id]['image_url']}" style="max-width: 100%; max-height: {MAX_IMG_HEIGHT}px" />
|
493 |
+
<div>{st.session_state.images_info.loc[image_id]['caption']} <b>[{round(st.session_state.search_image_scores[image_id], 3)}]</b></div>
|
494 |
</div>"""
|
495 |
)
|
496 |
st.caption(
|
497 |
+
f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -12px">
|
498 |
<a href="{st.session_state.images_info.loc[image_id]['permalink']}">{link_text}</a>
|
499 |
<div>""",
|
500 |
unsafe_allow_html=True,
|
501 |
)
|
502 |
+
st.button(
|
503 |
+
"Explain this",
|
504 |
+
on_click=image_modal,
|
505 |
+
args=[image_id],
|
506 |
+
use_container_width=True,
|
507 |
+
key=image_id,
|
508 |
+
)
|
509 |
col = (col + 1) % row_size
|
requirements.txt
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
multilingual_clip==1.0.10
|
2 |
numpy==1.26
|
3 |
pandas==2.1.2
|
|
|
|
|
4 |
sentencepiece==0.2.0
|
5 |
torch==2.4.0
|
|
|
6 |
transformers==4.35.0
|
|
|
1 |
multilingual_clip==1.0.10
|
2 |
numpy==1.26
|
3 |
pandas==2.1.2
|
4 |
+
pillow==10.1.0
|
5 |
+
requests==2.31.0
|
6 |
sentencepiece==0.2.0
|
7 |
torch==2.4.0
|
8 |
+
torchvision==0.19.0
|
9 |
transformers==4.35.0
|
resized_ja_features.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5ec1ba33ef7ffe1236ce4adbfae3d785e89ab7ce98cbc1e99ff74c2391a8a657
|
3 |
+
size 25903232
|
resized_ml_features.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0b13a2171ead017721de26fe8c250b871ff4917dc573fbbe9da6b24cc348b156
|
3 |
+
size 16189568
|