Amirparsa-Sal commited on
Commit
5d1f0ae
·
1 Parent(s): 9a99f40
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +7 -0
  2. AnomalyCLIP_lib/AnomalyCLIP.py +531 -0
  3. AnomalyCLIP_lib/CLIP.py +436 -0
  4. AnomalyCLIP_lib/__init__.py +1 -0
  5. AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz +3 -0
  6. AnomalyCLIP_lib/build_model.py +50 -0
  7. AnomalyCLIP_lib/constants.py +2 -0
  8. AnomalyCLIP_lib/model_load.py +235 -0
  9. AnomalyCLIP_lib/simple_tokenizer.py +132 -0
  10. AnomalyCLIP_lib/transform.py +133 -0
  11. Dockerfile +14 -0
  12. LICENSE +21 -0
  13. README.md +142 -0
  14. checkpoints/9_12_4_multiscale/epoch_1.pth +3 -0
  15. checkpoints/9_12_4_multiscale/epoch_10.pth +3 -0
  16. checkpoints/9_12_4_multiscale/epoch_11.pth +3 -0
  17. checkpoints/9_12_4_multiscale/epoch_12.pth +3 -0
  18. checkpoints/9_12_4_multiscale/epoch_13.pth +3 -0
  19. checkpoints/9_12_4_multiscale/epoch_14.pth +3 -0
  20. checkpoints/9_12_4_multiscale/epoch_15.pth +3 -0
  21. checkpoints/9_12_4_multiscale/epoch_2.pth +3 -0
  22. checkpoints/9_12_4_multiscale/epoch_3.pth +3 -0
  23. checkpoints/9_12_4_multiscale/epoch_4.pth +3 -0
  24. checkpoints/9_12_4_multiscale/epoch_5.pth +3 -0
  25. checkpoints/9_12_4_multiscale/epoch_6.pth +3 -0
  26. checkpoints/9_12_4_multiscale/epoch_7.pth +3 -0
  27. checkpoints/9_12_4_multiscale/epoch_8.pth +3 -0
  28. checkpoints/9_12_4_multiscale/epoch_9.pth +3 -0
  29. checkpoints/9_12_4_multiscale/log.txt +0 -0
  30. checkpoints/9_12_4_multiscale_visa/epoch_1.pth +3 -0
  31. checkpoints/9_12_4_multiscale_visa/epoch_10.pth +3 -0
  32. checkpoints/9_12_4_multiscale_visa/epoch_11.pth +3 -0
  33. checkpoints/9_12_4_multiscale_visa/epoch_12.pth +3 -0
  34. checkpoints/9_12_4_multiscale_visa/epoch_13.pth +3 -0
  35. checkpoints/9_12_4_multiscale_visa/epoch_14.pth +3 -0
  36. checkpoints/9_12_4_multiscale_visa/epoch_15.pth +3 -0
  37. checkpoints/9_12_4_multiscale_visa/epoch_2.pth +3 -0
  38. checkpoints/9_12_4_multiscale_visa/epoch_3.pth +3 -0
  39. checkpoints/9_12_4_multiscale_visa/epoch_4.pth +3 -0
  40. checkpoints/9_12_4_multiscale_visa/epoch_5.pth +3 -0
  41. checkpoints/9_12_4_multiscale_visa/epoch_6.pth +3 -0
  42. checkpoints/9_12_4_multiscale_visa/epoch_7.pth +3 -0
  43. checkpoints/9_12_4_multiscale_visa/epoch_8.pth +3 -0
  44. checkpoints/9_12_4_multiscale_visa/epoch_9.pth +3 -0
  45. dataset.py +50 -0
  46. datasets/rayan_dataset.py +127 -0
  47. docker-compose.yml +21 -0
  48. evaluation/base_eval.py +293 -0
  49. evaluation/class_name_mapping.json +5 -0
  50. evaluation/eval_main.py +78 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ *.pyc
2
+ *.pyo
3
+ __pycache__/
4
+ *.tar.gz
5
+ *.tar.xz
6
+ ZSAD-dataset
7
+ data/
AnomalyCLIP_lib/AnomalyCLIP.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+
8
+
9
+ class Bottleneck(nn.Module):
10
+ expansion = 4
11
+
12
+ def __init__(self, inplanes, planes, stride=1):
13
+ super().__init__()
14
+
15
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
16
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
17
+ self.bn1 = nn.BatchNorm2d(planes)
18
+ self.relu1 = nn.ReLU(inplace=True)
19
+
20
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+ self.relu2 = nn.ReLU(inplace=True)
23
+
24
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
25
+
26
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
27
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
28
+ self.relu3 = nn.ReLU(inplace=True)
29
+
30
+ self.downsample = None
31
+ self.stride = stride
32
+
33
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
34
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
35
+ self.downsample = nn.Sequential(OrderedDict([
36
+ ("-1", nn.AvgPool2d(stride)),
37
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
38
+ ("1", nn.BatchNorm2d(planes * self.expansion))
39
+ ]))
40
+
41
+ def forward(self, x: torch.Tensor):
42
+ identity = x
43
+
44
+ out = self.relu1(self.bn1(self.conv1(x)))
45
+ out = self.relu2(self.bn2(self.conv2(out)))
46
+ out = self.avgpool(out)
47
+ out = self.bn3(self.conv3(out))
48
+
49
+ if self.downsample is not None:
50
+ identity = self.downsample(x)
51
+
52
+ out += identity
53
+ out = self.relu3(out)
54
+ return out
55
+
56
+
57
+ # implement attention module for v-v self-attention
58
+ class Attention(nn.Module):
59
+ def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''):
60
+ super().__init__()
61
+ self.num_heads = num_heads
62
+ head_dim = dim // num_heads
63
+ self.scale = qk_scale or head_dim ** -0.5
64
+
65
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
66
+ self.attn_drop = nn.Dropout(attn_drop)
67
+ self.proj = nn.Linear(out_dim, dim)
68
+ self.proj_drop = nn.Dropout(proj_drop)
69
+ self.settings = settings
70
+
71
+ def forward(self, x):
72
+ B, N, C = x.shape
73
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
74
+ q, k, v = qkv[0], qkv[1], qkv[2]
75
+
76
+ # original self-attention for the original path
77
+ attn_ori = (q @ k.transpose(-2, -1)) * self.scale
78
+ attn_ori = attn_ori.softmax(dim=-1)
79
+ attn_ori = self.attn_drop(attn_ori)
80
+
81
+ # replace k & q by v
82
+ k = v
83
+ q = k
84
+
85
+ # self-attention, higher temperate for resnets performs better
86
+ attn = (q @ k.transpose(-2, -1)) * self.scale
87
+ attn = (attn).softmax(dim=-1)
88
+ attn = self.attn_drop(attn)
89
+
90
+ x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
91
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
92
+ x = self.proj_drop(self.proj(x))
93
+ x_ori = self.proj_drop(self.proj(x_ori))
94
+ return [x, x_ori]
95
+
96
+
97
+
98
+ class LayerNorm(nn.LayerNorm):
99
+ """Subclass torch's LayerNorm to handle fp16."""
100
+
101
+ def forward(self, x: torch.Tensor):
102
+ orig_type = x.dtype
103
+ ret = super().forward(x.type(torch.float32))
104
+ return ret.type(orig_type)
105
+
106
+
107
+ class QuickGELU(nn.Module):
108
+ def forward(self, x: torch.Tensor):
109
+ return x * torch.sigmoid(1.702 * x)
110
+
111
+
112
+ class ResidualAttentionBlock(nn.Module):
113
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details = None):
114
+ super().__init__()
115
+
116
+ self.attn = nn.MultiheadAttention(d_model, n_head)
117
+ self.ln_1 = LayerNorm(d_model)
118
+ self.mlp = nn.Sequential(OrderedDict([
119
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
120
+ ("gelu", QuickGELU()),
121
+ ("c_proj", nn.Linear(d_model * 4, d_model))
122
+ ]))
123
+ self.ln_2 = LayerNorm(d_model)
124
+ self.attn_mask = attn_mask
125
+
126
+ def attention(self, x: torch.Tensor):
127
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
128
+ if isinstance(self.attn, Attention):
129
+ x = x.transpose(0, 1)
130
+ x, x_ori = self.attn(x)
131
+ return [x.transpose(0, 1), x_ori.transpose(0, 1)]
132
+ else:
133
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
134
+
135
+ def forward(self, x, whole = False, ffn = False):
136
+ # print("xxxxx",x.shape)
137
+ # dual paths for blocks deeper than "d"
138
+
139
+ if isinstance(self.attn, Attention):
140
+ if isinstance(x, list):
141
+ if not ffn:
142
+ x, x_ori = x
143
+ x_res = self.attention(self.ln_1(x_ori))
144
+ x_res, x_ori_res = x_res
145
+ x_ori += x_ori_res
146
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
147
+ x += x_res # skip ffn for the new path
148
+ # print('hellloooo')
149
+ return [x, x_ori]
150
+ else:
151
+ x, x_ori_1 = x
152
+ x_res = self.attention(self.ln_1(x_ori_1))
153
+ x_res, x_ori_res = x_res
154
+ x_ori = x_ori_1 + x_ori_res
155
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
156
+ x += x_res # skip ffn for the new path
157
+ x = x_res + x_ori_1
158
+ x = x + self.mlp(self.ln_2(x))
159
+ return [x, x_ori]
160
+ # start of dual path
161
+ else:
162
+ x_res = self.attention(self.ln_1(x))
163
+ if isinstance(x_res, list):
164
+ x_res, x_ori_res = x_res
165
+ x_ori = x + x_ori_res
166
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
167
+ x += x_res
168
+ return [x, x_ori]
169
+
170
+ # singl path before "d"
171
+ else:
172
+ x = x + self.attention(self.ln_1(x))
173
+ x = x + self.mlp(self.ln_2(x))
174
+ return x
175
+
176
+ class ResidualAttentionBlock_learnable_token(nn.Module):
177
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details=None,
178
+ text_layer=False, i = 0):
179
+ super().__init__()
180
+
181
+ self.attn = nn.MultiheadAttention(d_model, n_head)
182
+ self.ln_1 = LayerNorm(d_model)
183
+ self.mlp = nn.Sequential(OrderedDict([
184
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
185
+ ("gelu", QuickGELU()),
186
+ ("c_proj", nn.Linear(d_model * 4, d_model))
187
+ ]))
188
+ self.ln_2 = LayerNorm(d_model)
189
+ self.attn_mask = attn_mask
190
+
191
+ self.i = i
192
+ self.compound_prompt_nctx = design_details['learnabel_text_embedding_length']
193
+ self.text_layer = text_layer
194
+ if i == 0:
195
+ self.first_layer = True
196
+ else:
197
+ self.first_layer = False
198
+
199
+ def attention(self, x: torch.Tensor):
200
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
201
+ if isinstance(self.attn, Attention):
202
+ x = x.transpose(0, 1)
203
+ x, x_ori = self.attn(x)
204
+ return [x.transpose(0, 1), x_ori.transpose(0, 1)]
205
+ else:
206
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
207
+
208
+ def forward(self, inputs):
209
+
210
+ # dual paths for blocks deeper than "d"
211
+ if isinstance(self.attn, Attention):
212
+ x = inputs[0]
213
+ if isinstance(x, list):
214
+ x, x_ori = x
215
+ x_res = self.attention(self.ln_1(x_ori))
216
+ x_res, x_ori_res = x_res
217
+ x_ori += x_ori_res
218
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
219
+ x += x_res # skip ffn for the new path
220
+ return [x, x_ori]
221
+
222
+ # start of dual path
223
+ else:
224
+ x_res = self.attention(self.ln_1(x))
225
+ if isinstance(x_res, list):
226
+ x_res, x_ori_res = x_res
227
+ x_ori = x + x_ori_res
228
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
229
+ x += x_res
230
+ return [x, x_ori]
231
+
232
+ # singl path before "d"
233
+ else:
234
+ x = inputs[0]
235
+ compound_prompts_deeper = inputs[1]
236
+ counter = inputs[2]
237
+ if not self.first_layer:
238
+ # First check if the ith layer needs compound prompts or not
239
+ if not (counter > len(compound_prompts_deeper) - 1):
240
+ # Appending the learnable tokens in different way
241
+ # x -> [77, NCLS, DIM]
242
+ # First remove the learnable tokens from previous layer
243
+ prefix = x[:1, :, :]
244
+ suffix = x[1 + self.compound_prompt_nctx:, :, :]
245
+ textual_context = compound_prompts_deeper[counter]
246
+ textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
247
+ # Add the learnable tokens of this layer with the input, replaced by previous
248
+ # layer learnable tokens
249
+ x = torch.cat([prefix, textual_context, suffix], dim=0)
250
+ # Once done, update the counter, so that the next time, it does not use same learnable tokens
251
+ counter += 1
252
+ x = x + self.attention(self.ln_1(x))
253
+ x = x + self.mlp(self.ln_2(x))
254
+ return [x, compound_prompts_deeper, counter]
255
+
256
+
257
+ class Transformer(nn.Module):
258
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False, design_details = None ,text_layer = False):
259
+ super().__init__()
260
+ self.width = width
261
+ self.layers = layers
262
+ self.text_layer = text_layer
263
+ self.design_deatails = design_details
264
+ print("text_layer", self.text_layer)
265
+ if self.text_layer and (design_details is not None):
266
+ self.resblocks = nn.ModuleList([ResidualAttentionBlock_learnable_token(width, heads, attn_mask, design_details, text_layer, i=i) for i in range(layers)])
267
+ else:
268
+ self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, attn_mask,) for i in range(layers)])
269
+
270
+ def ori_CLIP_with_patch_forward(self, x, out_layers):
271
+ idx = 0
272
+ out_tokens = []
273
+ for r in self.resblocks:
274
+ idx += 1
275
+ x = r(x)
276
+ if idx in out_layers:
277
+ if isinstance(x, list):
278
+ out_tokens.append(x[1])
279
+ else:
280
+ out_tokens.append(x)
281
+
282
+ return [x, x], out_tokens
283
+
284
+ def AnomalyCLIP_forward(self, x, out_layers, ffn):
285
+ idx = 0
286
+ out_tokens = []
287
+ for r in self.resblocks:
288
+ idx += 1
289
+ x = r(x, ffn = ffn)
290
+ # print("out_layers", out_layers, idx)
291
+ if idx in out_layers:
292
+ if isinstance(x, list):
293
+ out_tokens.append(x[0])
294
+ else:
295
+ out_tokens.append(x)
296
+ return x, out_tokens
297
+
298
+ def forward(self, x: torch.Tensor, out_layers = [6, 12, 18, 24], DPAM_layer = None, ffn = False):
299
+ # visual encoder forward
300
+ if not self.text_layer:
301
+ out_tokens = []
302
+
303
+ if DPAM_layer is None:
304
+ [x, x], out_tokens = self.ori_CLIP_with_patch_forward(x, out_layers)
305
+ return [x, x], out_tokens
306
+ else:
307
+ x, out_tokens = self.AnomalyCLIP_forward(x, out_layers, ffn)
308
+ return x, out_tokens
309
+ # text encoder forward
310
+ # ori text embedding
311
+ elif self.design_deatails is None:
312
+ for idx, r in enumerate(self.resblocks):
313
+ x = r(x)
314
+ return x
315
+ # insert learnable text embedding
316
+ elif self.design_deatails is not None:
317
+ for idx, r in enumerate(self.resblocks):
318
+ x = r(x)
319
+ return x[0]
320
+ def get_cast_dtype(self) -> torch.dtype:
321
+ return self.resblocks[0].mlp.c_fc.weight.dtype
322
+
323
+ class VisionTransformer(nn.Module):
324
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
325
+ super().__init__()
326
+ self.input_resolution = input_resolution
327
+ self.output_dim = output_dim
328
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
329
+
330
+ scale = width ** -0.5
331
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
332
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
333
+ self.ln_pre = LayerNorm(width)
334
+
335
+ self.transformer = Transformer(width, layers, heads, need_weights=True)
336
+ self.attn = None
337
+ self.embed_dim = width
338
+ self.num_heads = heads
339
+
340
+ self.ln_post = LayerNorm(width)
341
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
342
+
343
+
344
+ @torch.no_grad()
345
+ def DAPM_replace(self, DPAM_layer):
346
+ if DPAM_layer is not None:
347
+ for i in range(1, DPAM_layer):
348
+ self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True)
349
+ self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone()
350
+ self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone()
351
+ self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone()
352
+ self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone()
353
+ self.transformer.resblocks[-i].attn = self.attn
354
+
355
+ @torch.no_grad()
356
+ def forward(self, x: torch.Tensor, features_list, ori_patch = False, proj_use = True, DPAM_layer = None, ffn = False):
357
+
358
+ x = self.conv1(x) # shape = [*, width, grid, grid]
359
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
360
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
361
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
362
+ side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
363
+ new_side = int((x.shape[1] - 1) ** 0.5)
364
+
365
+ # update the position embedding during inference for varied input size
366
+ if side != new_side:
367
+ new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
368
+ new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
369
+ new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
370
+ self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
371
+
372
+ pos = self.positional_embedding.to(x.dtype)
373
+ x = x + pos
374
+ x = self.ln_pre(x)
375
+
376
+ x = x.permute(1, 0, 2) # NLD -> LND
377
+ [x, x_ori], patch_tokens = self.transformer(x, features_list, DPAM_layer = DPAM_layer, ffn = ffn)
378
+
379
+
380
+ if True:
381
+ patch_token_list = []
382
+ for patch_token in patch_tokens:
383
+ patch_token = self.ln_post(patch_token.permute(1, 0, 2)) @ self.proj # LND -> NLD
384
+ patch_token_list.append(patch_token)
385
+ patch_tokens = patch_token_list
386
+
387
+ return x_ori[0, :, :] @ self.proj, patch_tokens
388
+
389
+
390
+ return x
391
+
392
+
393
+ from thop import profile
394
+ class AnomalyCLIP(nn.Module):
395
+ def __init__(self,
396
+ embed_dim: int,
397
+ # vision
398
+ image_resolution: int,
399
+ vision_layers: Union[Tuple[int, int, int, int], int],
400
+ vision_width: int,
401
+ vision_patch_size: int,
402
+ # text
403
+ context_length: int,
404
+ vocab_size: int,
405
+ transformer_width: int,
406
+ transformer_heads: int,
407
+ transformer_layers: int,
408
+ design_details = None
409
+ ):
410
+ super().__init__()
411
+
412
+ self.context_length = context_length
413
+
414
+ if isinstance(vision_layers, (tuple, list)):
415
+ vision_heads = vision_width * 32 // 64
416
+ self.visual = ModifiedResNet(
417
+ layers=vision_layers,
418
+ output_dim=embed_dim,
419
+ heads=vision_heads,
420
+ input_resolution=image_resolution,
421
+ width=vision_width
422
+ )
423
+ else:
424
+ vision_heads = vision_width // 64
425
+ self.visual = VisionTransformer(
426
+ input_resolution=image_resolution,
427
+ patch_size=vision_patch_size,
428
+ width=vision_width,
429
+ layers=vision_layers,
430
+ heads=vision_heads,
431
+ output_dim=embed_dim
432
+ )
433
+
434
+ self.transformer = Transformer(
435
+ width=transformer_width,
436
+ layers=transformer_layers,
437
+ heads=transformer_heads,
438
+ attn_mask=self.build_attention_mask(), text_layer=True, design_details=design_details
439
+ )
440
+
441
+ self.vocab_size = vocab_size
442
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
443
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
444
+ self.ln_final = LayerNorm(transformer_width)
445
+
446
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
447
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
448
+
449
+ self.initialize_parameters()
450
+
451
+ def initialize_parameters(self):
452
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
453
+ nn.init.normal_(self.positional_embedding, std=0.01)
454
+
455
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
456
+ attn_std = self.transformer.width ** -0.5
457
+ fc_std = (2 * self.transformer.width) ** -0.5
458
+ for block in self.transformer.resblocks:
459
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
460
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
461
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
462
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
463
+
464
+ if self.text_projection is not None:
465
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
466
+ def build_attention_mask(self):
467
+ # lazily create causal attention mask, with full attention between the vision tokens
468
+ # pytorch uses additive attention mask; fill with -inf
469
+ mask = torch.empty(self.context_length, self.context_length)
470
+ mask.fill_(float("-inf"))
471
+ mask.triu_(1) # zero out the lower diagonal
472
+ return mask
473
+
474
+ @property
475
+ def dtype(self):
476
+ return self.visual.conv1.weight.dtype
477
+
478
+ def encode_image(self, image, feature_list = [], ori_patch = False, proj_use = True, DPAM_layer = None, ffn = False):
479
+ return self.visual(image.type(self.dtype), feature_list, ori_patch = ori_patch, proj_use = proj_use, DPAM_layer = DPAM_layer, ffn = ffn)
480
+
481
+
482
+ def encode_text(self, text):
483
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
484
+
485
+ x = x + self.positional_embedding.type(self.dtype)
486
+ x = x.permute(1, 0, 2) # NLD -> LND
487
+ x = self.transformer(x)
488
+ x = x.permute(1, 0, 2) # LND -> NLD
489
+ x = self.ln_final(x).type(self.dtype)
490
+
491
+ # x.shape = [batch_size, n_ctx, transformer.width]
492
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
493
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
494
+
495
+ return x
496
+
497
+ def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text = None, normalize: bool = False):
498
+ cast_dtype = self.transformer.get_cast_dtype()
499
+
500
+ # x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
501
+
502
+ # x = x + self.positional_embedding.to(cast_dtype)
503
+
504
+ x = prompts + self.positional_embedding.to(cast_dtype)
505
+ x = x.permute(1, 0, 2) # NLD -> LND
506
+ # print("test", x.shape, len(deep_compound_prompts_text))
507
+ if deep_compound_prompts_text is None:
508
+ x = self.transformer(x)
509
+ else:
510
+ x = self.transformer([x, deep_compound_prompts_text, 0])
511
+ x = x.permute(1, 0, 2) # LND -> NLD
512
+ x = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, transformer.width]
513
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
514
+ x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
515
+ return x
516
+
517
+ def forward(self, image, text):
518
+ image_features = self.encode_image(image)
519
+ text_features = self.encode_text(text)
520
+
521
+ # normalized features
522
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
523
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
524
+
525
+ # cosine similarity as logits
526
+ logit_scale = self.logit_scale.exp()
527
+ logits_per_image = logit_scale * image_features @ text_features.t()
528
+ logits_per_text = logits_per_image.t()
529
+
530
+ # shape = [global_batch_size, global_batch_size]
531
+ return logits_per_image, logits_per_text
AnomalyCLIP_lib/CLIP.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.relu2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.relu3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.relu1(self.bn1(self.conv1(x)))
46
+ out = self.relu2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.relu3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+
72
+ side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
73
+ new_side = int((x.shape[0] - 1) ** 0.5)
74
+
75
+ # update the position embedding during inference for varied input size
76
+ if side != new_side:
77
+ new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
78
+ new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
79
+ new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
80
+ self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
81
+
82
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
83
+ x, _ = F.multi_head_attention_forward(
84
+ query=x, key=x, value=x,
85
+ embed_dim_to_check=x.shape[-1],
86
+ num_heads=self.num_heads,
87
+ q_proj_weight=self.q_proj.weight,
88
+ k_proj_weight=self.k_proj.weight,
89
+ v_proj_weight=self.v_proj.weight,
90
+ in_proj_weight=None,
91
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
92
+ bias_k=None,
93
+ bias_v=None,
94
+ add_zero_attn=False,
95
+ dropout_p=0,
96
+ out_proj_weight=self.c_proj.weight,
97
+ out_proj_bias=self.c_proj.bias,
98
+ use_separate_proj_weight=True,
99
+ training=self.training,
100
+ need_weights=False
101
+ )
102
+
103
+ #return x[0]
104
+ return x.transpose(0, 1) # return both cls token and image tokens, B,N,C
105
+
106
+
107
+ class ModifiedResNet(nn.Module):
108
+ """
109
+ A ResNet class that is similar to torchvision's but contains the following changes:
110
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
111
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
112
+ - The final pooling layer is a QKV attention instead of an average pool
113
+ """
114
+
115
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
116
+ super().__init__()
117
+ self.output_dim = output_dim
118
+ self.input_resolution = input_resolution
119
+
120
+ # the 3-layer stem
121
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
122
+ self.bn1 = nn.BatchNorm2d(width // 2)
123
+ self.relu1 = nn.ReLU(inplace=True)
124
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
125
+ self.bn2 = nn.BatchNorm2d(width // 2)
126
+ self.relu2 = nn.ReLU(inplace=True)
127
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
128
+ self.bn3 = nn.BatchNorm2d(width)
129
+ self.relu3 = nn.ReLU(inplace=True)
130
+ self.avgpool = nn.AvgPool2d(2)
131
+
132
+ # residual layers
133
+ self._inplanes = width # this is a *mutable* variable used during construction
134
+ self.layer1 = self._make_layer(width, layers[0])
135
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
136
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
137
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
138
+
139
+ embed_dim = width * 32 # the ResNet feature dimension
140
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
141
+
142
+ def _make_layer(self, planes, blocks, stride=1):
143
+ layers = [Bottleneck(self._inplanes, planes, stride)]
144
+
145
+ self._inplanes = planes * Bottleneck.expansion
146
+ for _ in range(1, blocks):
147
+ layers.append(Bottleneck(self._inplanes, planes))
148
+
149
+ return nn.Sequential(*layers)
150
+
151
+ def forward(self, x):
152
+ def stem(x):
153
+ x = self.relu1(self.bn1(self.conv1(x)))
154
+ x = self.relu2(self.bn2(self.conv2(x)))
155
+ x = self.relu3(self.bn3(self.conv3(x)))
156
+ x = self.avgpool(x)
157
+ return x
158
+
159
+ x = x.type(self.conv1.weight.dtype)
160
+ x = stem(x)
161
+ x = self.layer1(x)
162
+ x = self.layer2(x)
163
+ x = self.layer3(x)
164
+ x = self.layer4(x)
165
+ x = self.attnpool(x)
166
+
167
+ return x
168
+
169
+
170
+ class LayerNorm(nn.LayerNorm):
171
+ """Subclass torch's LayerNorm to handle fp16."""
172
+
173
+ def forward(self, x: torch.Tensor):
174
+ orig_type = x.dtype
175
+ ret = super().forward(x.type(torch.float32))
176
+ return ret.type(orig_type)
177
+
178
+
179
+ class QuickGELU(nn.Module):
180
+ def forward(self, x: torch.Tensor):
181
+ return x * torch.sigmoid(1.702 * x)
182
+
183
+
184
+ class ResidualAttentionBlock(nn.Module):
185
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
186
+ super().__init__()
187
+
188
+ self.attn = nn.MultiheadAttention(d_model, n_head)
189
+ self.ln_1 = LayerNorm(d_model)
190
+ self.mlp = nn.Sequential(OrderedDict([
191
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
192
+ ("gelu", QuickGELU()),
193
+ ("c_proj", nn.Linear(d_model * 4, d_model))
194
+ ]))
195
+ self.ln_2 = LayerNorm(d_model)
196
+ self.attn_mask = attn_mask
197
+ self.need_weights = need_weights
198
+
199
+ def attention(self, x: torch.Tensor):
200
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
201
+ if self.need_weights == False:
202
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
203
+ else:
204
+ return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)
205
+
206
+ def forward(self, x: torch.Tensor):
207
+ if self.need_weights == False:
208
+ x = x + self.attention(self.ln_1(x))
209
+ x = x + self.mlp(self.ln_2(x))
210
+ return x
211
+ else:
212
+ y, attn = self.attention(self.ln_1(x))
213
+ x = x + y
214
+ x = x + self.mlp(self.ln_2(x))
215
+ return x
216
+
217
+
218
+ class Transformer(nn.Module):
219
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
220
+ super().__init__()
221
+ self.width = width
222
+ self.layers = layers
223
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, need_weights if i == layers - 1 else False) for i in range(layers)])
224
+
225
+ def forward(self, x: torch.Tensor):
226
+ return self.resblocks(x)
227
+
228
+ def get_cast_dtype(self) -> torch.dtype:
229
+ return self.resblocks[0].mlp.c_fc.weight.dtype
230
+
231
+
232
+
233
+ class VisionTransformer(nn.Module):
234
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
235
+ super().__init__()
236
+ self.input_resolution = input_resolution
237
+ self.output_dim = output_dim
238
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
239
+
240
+ scale = width ** -0.5
241
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
242
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
243
+ self.ln_pre = LayerNorm(width)
244
+
245
+ self.transformer = Transformer(width, layers, heads, need_weights=True)
246
+
247
+ self.ln_post = LayerNorm(width)
248
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
249
+
250
+ def forward(self, x: torch.Tensor):
251
+ x = self.conv1(x) # shape = [*, width, grid, grid]
252
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
253
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
254
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
255
+
256
+ #####################################################################################
257
+ side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
258
+ new_side = int((x.shape[1] - 1) ** 0.5)
259
+
260
+ # update the position embedding during inference for varied input size
261
+ if side != new_side:
262
+ new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
263
+ new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
264
+ new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
265
+ self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
266
+ #####################################################################################
267
+
268
+
269
+ x = x + self.positional_embedding.to(x.dtype)
270
+ x = self.ln_pre(x)
271
+
272
+ x = x.permute(1, 0, 2) # NLD -> LND
273
+ x = self.transformer(x)
274
+ x = x.permute(1, 0, 2) # LND -> NLD
275
+
276
+ #x = self.ln_post(x[:, 0, :])
277
+ x = self.ln_post(x) # return both cls token and image tokens
278
+
279
+ if self.proj is not None:
280
+ x = x @ self.proj
281
+
282
+ return x
283
+
284
+
285
+ class CLIP(nn.Module):
286
+ def __init__(self,
287
+ embed_dim: int,
288
+ # vision
289
+ image_resolution: int,
290
+ vision_layers: Union[Tuple[int, int, int, int], int],
291
+ vision_width: int,
292
+ vision_patch_size: int,
293
+ # text
294
+ context_length: int,
295
+ vocab_size: int,
296
+ transformer_width: int,
297
+ transformer_heads: int,
298
+ transformer_layers: int
299
+ ):
300
+ super().__init__()
301
+
302
+ self.context_length = context_length
303
+
304
+ if isinstance(vision_layers, (tuple, list)):
305
+ vision_heads = vision_width * 32 // 64
306
+ self.visual = ModifiedResNet(
307
+ layers=vision_layers,
308
+ output_dim=embed_dim,
309
+ heads=vision_heads,
310
+ input_resolution=image_resolution,
311
+ width=vision_width
312
+ )
313
+ else:
314
+ vision_heads = vision_width // 64
315
+ self.visual = VisionTransformer(
316
+ input_resolution=image_resolution,
317
+ patch_size=vision_patch_size,
318
+ width=vision_width,
319
+ layers=vision_layers,
320
+ heads=vision_heads,
321
+ output_dim=embed_dim
322
+ )
323
+
324
+ self.transformer = Transformer(
325
+ width=transformer_width,
326
+ layers=transformer_layers,
327
+ heads=transformer_heads,
328
+ attn_mask=self.build_attention_mask()
329
+ )
330
+
331
+ self.vocab_size = vocab_size
332
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
333
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
334
+ self.ln_final = LayerNorm(transformer_width)
335
+
336
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
337
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
338
+
339
+ self.initialize_parameters()
340
+
341
+ def initialize_parameters(self):
342
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
343
+ nn.init.normal_(self.positional_embedding, std=0.01)
344
+
345
+ if isinstance(self.visual, ModifiedResNet):
346
+ if self.visual.attnpool is not None:
347
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
348
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
349
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
350
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
351
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
352
+
353
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
354
+ for name, param in resnet_block.named_parameters():
355
+ if name.endswith("bn3.weight"):
356
+ nn.init.zeros_(param)
357
+
358
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
359
+ attn_std = self.transformer.width ** -0.5
360
+ fc_std = (2 * self.transformer.width) ** -0.5
361
+ for block in self.transformer.resblocks:
362
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
363
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
364
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
365
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
366
+
367
+ if self.text_projection is not None:
368
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
369
+
370
+ def build_attention_mask(self):
371
+ # lazily create causal attention mask, with full attention between the vision tokens
372
+ # pytorch uses additive attention mask; fill with -inf
373
+ mask = torch.empty(self.context_length, self.context_length)
374
+ mask.fill_(float("-inf"))
375
+ mask.triu_(1) # zero out the lower diagonal
376
+ return mask
377
+
378
+ @property
379
+ def dtype(self):
380
+ return self.visual.conv1.weight.dtype
381
+
382
+ def encode_image(self, image):
383
+ return self.visual(image.type(self.dtype))
384
+
385
+ def encode_text(self, text):
386
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
387
+
388
+ x = x + self.positional_embedding.type(self.dtype)
389
+ x = x.permute(1, 0, 2) # NLD -> LND
390
+ x = self.transformer(x)
391
+ x = x.permute(1, 0, 2) # LND -> NLD
392
+ x = self.ln_final(x).type(self.dtype)
393
+
394
+ # x.shape = [batch_size, n_ctx, transformer.width]
395
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
396
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
397
+
398
+ return x
399
+
400
+ def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text = None, normalize: bool = False):
401
+ cast_dtype = self.transformer.get_cast_dtype()
402
+
403
+ # x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
404
+
405
+ # x = x + self.positional_embedding.to(cast_dtype)
406
+
407
+ x = prompts + self.positional_embedding.to(cast_dtype)
408
+ x = x.permute(1, 0, 2) # NLD -> LND
409
+ # print("test", x.shape, len(deep_compound_prompts_text))
410
+ if deep_compound_prompts_text is None:
411
+ x = self.transformer(x)
412
+ else:
413
+ x = self.transformer([x, deep_compound_prompts_text, 0])
414
+ x = x.permute(1, 0, 2) # LND -> NLD
415
+ x = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, transformer.width]
416
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
417
+ x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
418
+ return x
419
+
420
+
421
+
422
+ def forward(self, image, text):
423
+ image_features = self.encode_image(image)
424
+ text_features = self.encode_text(text)
425
+
426
+ # normalized features
427
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
428
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
429
+
430
+ # cosine similarity as logits
431
+ logit_scale = self.logit_scale.exp()
432
+ logits_per_image = logit_scale * image_features @ text_features.t()
433
+ logits_per_text = logits_per_image.t()
434
+
435
+ # shape = [global_batch_size, global_batch_size]
436
+ return logits_per_image, logits_per_text
AnomalyCLIP_lib/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model_load import *
AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
AnomalyCLIP_lib/build_model.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from .CLIP import CLIP
3
+ from .AnomalyCLIP import AnomalyCLIP
4
+
5
+ def build_model(name: str, state_dict: dict, design_details = None):
6
+ vit = "visual.proj" in state_dict
7
+
8
+ if vit:
9
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
10
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
11
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
12
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
13
+ image_resolution = vision_patch_size * grid_size
14
+ else:
15
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
16
+ vision_layers = tuple(counts)
17
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
18
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
19
+ vision_patch_size = None
20
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
21
+ image_resolution = output_width * 32
22
+
23
+ embed_dim = state_dict["text_projection"].shape[1]
24
+ context_length = state_dict["positional_embedding"].shape[0]
25
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
26
+ transformer_width = state_dict["ln_final.weight"].shape[0]
27
+ transformer_heads = transformer_width // 64
28
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
29
+ # print('name', name)
30
+ # if 'CS-' in name:
31
+ if design_details is not None:
32
+ model = AnomalyCLIP(
33
+ embed_dim,
34
+ image_resolution, vision_layers, vision_width, vision_patch_size,
35
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, design_details = design_details
36
+ )
37
+ else:
38
+ model = CLIP(
39
+ embed_dim,
40
+ image_resolution, vision_layers, vision_width, vision_patch_size,
41
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
42
+ )
43
+
44
+ for key in ["input_resolution", "context_length", "vocab_size"]:
45
+ if key in state_dict:
46
+ del state_dict[key]
47
+
48
+ #convert_weights(model)
49
+ model.load_state_dict(state_dict)
50
+ return model.eval()
AnomalyCLIP_lib/constants.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
AnomalyCLIP_lib/model_load.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+
14
+ from .build_model import build_model
15
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
16
+ from torchvision.transforms import InterpolationMode
17
+
18
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
19
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
20
+
21
+
22
+ __all__ = ["available_models", "load",
23
+ "get_similarity_map", "compute_similarity"]
24
+ _tokenizer = _Tokenizer()
25
+
26
+ _MODELS = {
27
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
28
+ }
29
+
30
+
31
+ def _download(
32
+ url: str,
33
+ cache_dir: Union[str, None] = None,
34
+ ):
35
+
36
+ if not cache_dir:
37
+ # cache_dir = os.path.expanduser("~/.cache/clip")
38
+ cache_dir = os.path.expanduser("/remote-home/iot_zhouqihang/root/.cache/clip")
39
+ os.makedirs(cache_dir, exist_ok=True)
40
+ filename = os.path.basename(url)
41
+
42
+ if 'openaipublic' in url:
43
+ expected_sha256 = url.split("/")[-2]
44
+ elif 'mlfoundations' in url:
45
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
46
+ else:
47
+ expected_sha256 = ''
48
+
49
+ download_target = os.path.join(cache_dir, filename)
50
+
51
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
52
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
53
+
54
+ if os.path.isfile(download_target):
55
+ if expected_sha256:
56
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
57
+ return download_target
58
+ else:
59
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
60
+ else:
61
+ return download_target
62
+
63
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
64
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
65
+ while True:
66
+ buffer = source.read(8192)
67
+ if not buffer:
68
+ break
69
+
70
+ output.write(buffer)
71
+ loop.update(len(buffer))
72
+
73
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
74
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
75
+
76
+ return download_target
77
+
78
+
79
+ def _convert_image_to_rgb(image):
80
+ return image.convert("RGB")
81
+
82
+
83
+ def _transform(n_px):
84
+ return Compose([
85
+ Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC),
86
+ #CenterCrop(n_px), # rm center crop to explain whole image
87
+ _convert_image_to_rgb,
88
+ ToTensor(),
89
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
90
+ ])
91
+
92
+
93
+ def available_models() -> List[str]:
94
+ """Returns the names of available CLIP models"""
95
+ return list(_MODELS.keys())
96
+
97
+
98
+ def load_state_dict(checkpoint_path: str, map_location='cpu'):
99
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
100
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
101
+ state_dict = checkpoint['state_dict']
102
+ else:
103
+ state_dict = checkpoint
104
+ if next(iter(state_dict.items()))[0].startswith('module'):
105
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
106
+ return state_dict
107
+
108
+ def load_checkpoint(model, checkpoint_path, strict=True):
109
+ state_dict = load_state_dict(checkpoint_path)
110
+ # detect old format and make compatible with new format
111
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
112
+ state_dict = convert_to_custom_text_state_dict(state_dict)
113
+ resize_pos_embed(state_dict, model)
114
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
115
+ return incompatible_keys
116
+
117
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", design_details = None, jit: bool = False, download_root: str = None):
118
+ """Load a CLIP model
119
+
120
+ Parameters
121
+ ----------
122
+ name : str
123
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
124
+
125
+ device : Union[str, torch.device]
126
+ The device to put the loaded model
127
+
128
+ jit : bool
129
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
130
+
131
+ download_root: str
132
+ path to download the model files; by default, it uses "~/.cache/clip"
133
+
134
+ Returns
135
+ -------
136
+ model : torch.nn.Module
137
+ The CLIP model
138
+
139
+ preprocess : Callable[[PIL.Image], torch.Tensor]
140
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
141
+ """
142
+ print("name", name)
143
+ if name in _MODELS:
144
+ # model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
145
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("/remote-home/iot_zhouqihang/root/.cache/clip"))
146
+ elif os.path.isfile(name):
147
+ model_path = name
148
+ else:
149
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
150
+
151
+ with open(model_path, 'rb') as opened_file:
152
+ try:
153
+ # loading JIT archive
154
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
155
+ state_dict = None
156
+ except RuntimeError:
157
+ # loading saved state dict
158
+ if jit:
159
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
160
+ jit = False
161
+ state_dict = torch.load(opened_file, map_location="cpu")
162
+
163
+ if not jit:
164
+ model = build_model(name, state_dict or model.state_dict(), design_details).to(device)
165
+ if str(device) == "cpu":
166
+ model.float()
167
+ return model, _transform(model.visual.input_resolution)
168
+
169
+ # patch the device names
170
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
171
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
172
+
173
+ def patch_device(module):
174
+ try:
175
+ graphs = [module.graph] if hasattr(module, "graph") else []
176
+ except RuntimeError:
177
+ graphs = []
178
+
179
+ if hasattr(module, "forward1"):
180
+ graphs.append(module.forward1.graph)
181
+
182
+ for graph in graphs:
183
+ for node in graph.findAllNodes("prim::Constant"):
184
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
185
+ node.copyAttributes(device_node)
186
+
187
+ model.apply(patch_device)
188
+ patch_device(model.encode_image)
189
+ patch_device(model.encode_text)
190
+
191
+ # patch dtype to float32 on CPU
192
+ if str(device) == "cpu":
193
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
194
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
195
+ float_node = float_input.node()
196
+
197
+ def patch_float(module):
198
+ try:
199
+ graphs = [module.graph] if hasattr(module, "graph") else []
200
+ except RuntimeError:
201
+ graphs = []
202
+
203
+ if hasattr(module, "forward1"):
204
+ graphs.append(module.forward1.graph)
205
+
206
+ for graph in graphs:
207
+ for node in graph.findAllNodes("aten::to"):
208
+ inputs = list(node.inputs())
209
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
210
+ if inputs[i].node()["value"] == 5:
211
+ inputs[i].node().copyAttributes(float_node)
212
+
213
+ model.apply(patch_float)
214
+ patch_float(model.encode_image)
215
+ patch_float(model.encode_text)
216
+
217
+ model.float()
218
+
219
+ return model, _transform(model.input_resolution.item())
220
+
221
+
222
+ def get_similarity_map(sm, shape):
223
+ side = int(sm.shape[1] ** 0.5)
224
+ sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)
225
+ sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
226
+ sm = sm.permute(0, 2, 3, 1)
227
+ return sm
228
+
229
+
230
+ def compute_similarity(image_features, text_features, t=2):
231
+ prob_1 = image_features[:, :1, :] @ text_features.t()
232
+ b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]
233
+ feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)
234
+ similarity = feats.sum(-1)
235
+ return (similarity/0.07).softmax(-1), prob_1
AnomalyCLIP_lib/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except ValueError:  # `first` does not occur in the remainder of the word
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
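For reference, a quick sanity check of the tokenizer above might look like the following sketch. It assumes the `ftfy` and `regex` packages are installed and that `bpe_simple_vocab_16e6.txt.gz` sits next to the module, as it does in this repository.
```python
# Illustrative only: encode/decode round trip with the SimpleTokenizer defined above.
from AnomalyCLIP_lib.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
token_ids = tokenizer.encode("a photo of a damaged object")
print(token_ids)                      # BPE token ids (ints)
print(tokenizer.decode(token_ids))    # roughly round-trips to the lower-cased input
```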
AnomalyCLIP_lib/transform.py ADDED
@@ -0,0 +1,133 @@
1
+ import warnings
2
+ from dataclasses import dataclass, asdict
3
+ from typing import Any, Dict, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.transforms.functional as F
8
+
9
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
10
+ CenterCrop
11
+
12
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+
14
+
15
+ @dataclass
16
+ class AugmentationCfg:
17
+ scale: Tuple[float, float] = (0.9, 1.0)
18
+ ratio: Optional[Tuple[float, float]] = None
19
+ color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
20
+ interpolation: Optional[str] = None
21
+ re_prob: Optional[float] = None
22
+ re_count: Optional[int] = None
23
+ use_timm: bool = False
24
+
25
+
26
+ class ResizeMaxSize(nn.Module):
27
+
28
+ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
29
+ super().__init__()
30
+ if not isinstance(max_size, int):
31
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
32
+ self.max_size = max_size
33
+ self.interpolation = interpolation
34
+ self.fn = min if fn == 'min' else max
35
+ self.fill = fill
36
+
37
+ def forward(self, img):
38
+ if isinstance(img, torch.Tensor):
39
+ height, width = img.shape[:2]
40
+ else:
41
+ width, height = img.size
42
+ scale = self.max_size / float(max(height, width))
43
+ if scale != 1.0:
44
+ new_size = tuple(round(dim * scale) for dim in (height, width))
45
+ img = F.resize(img, new_size, self.interpolation)
46
+ pad_h = self.max_size - new_size[0]
47
+ pad_w = self.max_size - new_size[1]
48
+ img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
49
+ return img
50
+
51
+
52
+ def _convert_to_rgb(image):
53
+ return image.convert('RGB')
54
+
55
+
56
+ def image_transform(
57
+ image_size: int,
58
+ is_train: bool,
59
+ mean: Optional[Tuple[float, ...]] = None,
60
+ std: Optional[Tuple[float, ...]] = None,
61
+ resize_longest_max: bool = False,
62
+ fill_color: int = 0,
63
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
64
+ ):
65
+ mean = mean or OPENAI_DATASET_MEAN
66
+ if not isinstance(mean, (list, tuple)):
67
+ mean = (mean,) * 3
68
+
69
+ std = std or OPENAI_DATASET_STD
70
+ if not isinstance(std, (list, tuple)):
71
+ std = (std,) * 3
72
+
73
+ if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
74
+ # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
75
+ image_size = image_size[0]
76
+
77
+ if isinstance(aug_cfg, dict):
78
+ aug_cfg = AugmentationCfg(**aug_cfg)
79
+ else:
80
+ aug_cfg = aug_cfg or AugmentationCfg()
81
+ normalize = Normalize(mean=mean, std=std)
82
+ if is_train:
83
+ aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
84
+ use_timm = aug_cfg_dict.pop('use_timm', False)
85
+ if use_timm:
86
+ from timm.data import create_transform # timm can still be optional
87
+ if isinstance(image_size, (tuple, list)):
88
+ assert len(image_size) >= 2
89
+ input_size = (3,) + image_size[-2:]
90
+ else:
91
+ input_size = (3, image_size, image_size)
92
+ # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
93
+ aug_cfg_dict.setdefault('interpolation', 'random')
94
+ aug_cfg_dict.setdefault('color_jitter', None) # disable by default
95
+ train_transform = create_transform(
96
+ input_size=input_size,
97
+ is_training=True,
98
+ hflip=0.,
99
+ mean=mean,
100
+ std=std,
101
+ re_mode='pixel',
102
+ **aug_cfg_dict,
103
+ )
104
+ else:
105
+ train_transform = Compose([
106
+ RandomResizedCrop(
107
+ image_size,
108
+ scale=aug_cfg_dict.pop('scale'),
109
+ interpolation=InterpolationMode.BICUBIC,
110
+ ),
111
+ _convert_to_rgb,
112
+ ToTensor(),
113
+ normalize,
114
+ ])
115
+ if aug_cfg_dict:
116
+ warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
117
+ return train_transform
118
+ else:
119
+ if resize_longest_max:
120
+ transforms = [
121
+ ResizeMaxSize(image_size, fill=fill_color)
122
+ ]
123
+ else:
124
+ transforms = [
125
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
126
+ CenterCrop(image_size),
127
+ ]
128
+ transforms.extend([
129
+ _convert_to_rgb,
130
+ ToTensor(),
131
+ normalize,
132
+ ])
133
+ return Compose(transforms)
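As a usage sketch (not part of the commit), the evaluation-time preprocessing built by `image_transform()` above can be obtained as follows; with no `mean`/`std` given, it falls back to the OpenAI CLIP constants from `constants.py`.
```python
# Illustrative only: the default evaluation preprocessing pipeline.
from AnomalyCLIP_lib.transform import image_transform

preprocess = image_transform(image_size=224, is_train=False)   # a torchvision Compose
# Applying it to a PIL image would yield a normalized tensor of shape [3, 224, 224]:
# tensor = preprocess(Image.open("some_image.png"))
```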
Dockerfile ADDED
@@ -0,0 +1,14 @@
1
+ # -----------------------------------------------------------------------------
2
+ # A sample Dockerfile to help you replicate our test environment
3
+ # -----------------------------------------------------------------------------
4
+
5
+ FROM pytorch/pytorch:2.4.1-cuda12.4-cudnn9-runtime
6
+ WORKDIR /app
7
+ COPY . .
8
+
9
+ # Install your python and apt requirements
10
+ RUN pip install -r requirements.txt
11
+ RUN apt-get update && apt-get install $(cat apt_requirements.txt) -y
12
+ RUN chmod +x run.sh
13
+
14
+ CMD ["python3", "runner.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Qihang Zhou
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,142 @@
1
+ # AnomalyCLIP (Train once and test others)
2
+ > [**ICLR 24**] [**AnomalyCLIP: Object-agnostic Prompt Learning for Zero-shot Anomaly Detection**](https://arxiv.org/pdf/2310.18961.pdf)
3
+ >
4
+ > by [Qihang Zhou*](), [Guansong Pang*](https://www.guansongpang.com/), [Yu Tian](https://yutianyt.com/), [Shibo He](https://scholar.google.com/citations?hl=zh-CN&user=5GOcb4gAAAAJ&view_op=list_works&sortby=pubdate), [Jiming Chen](https://scholar.google.com/citations?user=zK9tvo8AAAAJ&hl=zh-CN).
5
+
6
+
7
+ ## Updates
8
+
9
+ - **03.19.2024**: Code has been released!
10
+ - **08.08.2024**: Updated the code for testing a single image.
11
+
12
+ ## Introduction
13
+ Zero-shot anomaly detection (ZSAD) requires detection models trained using auxiliary data to detect anomalies without any training samples in a target dataset. It is a crucial task when training data is not accessible due to various concerns, e.g., data privacy, yet it is challenging since the models need to generalize to anomalies across different domains where the appearance of foreground objects, abnormal regions, and background features, such as defects/tumors on different products/organs, can vary significantly. Recently, large pre-trained vision-language models (VLMs), such as CLIP,
14
+ have demonstrated strong zero-shot recognition ability in various vision tasks, including anomaly detection. However, their ZSAD performance is weak since the VLMs focus more on modeling the class semantics of the foreground objects rather than on the abnormality/normality in the images.
15
+ In this paper, we introduce a novel approach, namely AnomalyCLIP, to adapt CLIP for accurate ZSAD across different domains. The key insight of AnomalyCLIP is to learn object-agnostic text prompts that capture generic normality and abnormality in an image regardless of its foreground objects. This allows our model to focus on the abnormal image regions rather than the object semantics, enabling generalized normality and abnormality recognition on diverse types of objects. Large-scale experiments on 17 real-world anomaly detection datasets show that AnomalyCLIP achieves superior zero-shot performance in detecting and segmenting anomalies in datasets of highly diverse class semantics from various defect inspection and medical imaging domains. All experiments are conducted in PyTorch-2.0.0 with a single NVIDIA RTX 3090 24GB GPU.
16
+
17
+ ## Overview of AnomalyCLIP
18
+ ![overview](https://github.com/zqhang/AnomalyCLIP/assets/19222962/4ec3e5fc-9570-41f7-8067-6e7a515841be)
19
+
20
+
21
+ ## Analysis of different text prompt templates
22
+ ![analysis](./assets/analysis.png)
23
+
24
+
25
+ ## How to Run
26
+ ### Prepare your dataset
27
+ Download the dataset below:
28
+
29
+ * Industrial Domain:
30
+ [MVTec](https://www.mvtec.com/company/research/datasets/mvtec-ad), [VisA](https://github.com/amazon-science/spot-diff), [MPDD](https://github.com/stepanje/MPDD), [BTAD](http://avires.dimi.uniud.it/papers/btad/btad.zip), [SDD](https://www.vicos.si/resources/kolektorsdd/), [DAGM](https://www.kaggle.com/datasets/mhskjelvareid/dagm-2007-competition-dataset-optical-inspection), [DTD-Synthetic](https://drive.google.com/drive/folders/10OyPzvI3H6llCZBxKxFlKWt1Pw1tkMK1)
31
+
32
+ * Medical Domain:
33
+ [HeadCT](https://www.kaggle.com/datasets/felipekitamura/head-ct-hemorrhage), [BrainMRI](https://www.kaggle.com/datasets/navoneel/brain-mri-images-for-brain-tumor-detection), [Br35H](https://www.kaggle.com/datasets/ahmedhamada0/brain-tumor-detection), [COVID-19](https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database), [ISIC](https://isic-challenge-data.s3.amazonaws.com/2016/ISBI2016_ISIC_Part1_Test_Data.zip), [CVC-ColonDB](https://figshare.com/articles/figure/Polyp_DataSet_zip/21221579), [CVC-ClinicDB](https://figshare.com/articles/figure/Polyp_DataSet_zip/21221579), [Kvasir](https://figshare.com/articles/figure/Polyp_DataSet_zip/21221579), [Endo](https://drive.google.com/file/d/1LNpLkv5ZlEUzr_RPN5rdOHaqk0SkZa3m/view), [TN3K](https://github.com/haifangong/TRFE-Net-for-thyroid-nodule-segmentation?tab=readme-ov-file).
34
+
35
+ * Google Drive link (frequently requested dataset): [SDD](https://drive.google.com/drive/folders/1oqaxUZYi44jlLT4WtT6D5T6onPTNZXsu?usp=drive_link), [Br35H](https://drive.google.com/file/d/1l9XODMBm4X23K70LtpxAxgoaBbNzr4Nc/view?usp=drive_link), [COVID-19](https://drive.google.com/file/d/1ECwI8DJmhEtcVHatxCAdFqnSmXs35WFL/view?usp=drive_link)
36
+ ### Generate the dataset JSON
37
+ Take MVTec AD for example (with multiple anomaly categories)
38
+
39
+ Structure of MVTec Folder:
40
+ ```
41
+ mvtec/
42
+
43
+ ├── meta.json
44
+
45
+ ├── bottle/
46
+ │ ├── ground_truth/
47
+ │ │ ├── broken_large/
48
+ │ │ │ └── 000_mask.png
49
+ | | | └── ...
50
+ │ │ └── ...
51
+ │ └── test/
52
+ │ ├── broken_large/
53
+ │ │ └── 000.png
54
+ | | └── ...
55
+ │ └── ...
56
+
57
+ └── ...
58
+ ```
59
+
60
+ ```bash
61
+ cd generate_dataset_json
62
+ python mvtec.py
63
+ ```
64
+
65
+ Take SDD for example (with a single anomaly category)
66
+
67
+ Structure of SDD Folder:
68
+ ```
69
+ SDD/
70
+
71
+ ├── electrical_commutators/
72
+ │ └── test/
73
+ │ ├── defect/
74
+ │ │ └── kos01_Part5_0.png
75
+ | | └── ...
76
+ │ └── good/
77
+ │ └── kos01_Part0_0.png
78
+ │ └── ...
79
+
80
+ └── meta.json
81
+ ```
82
+
83
+ ```bash
84
+ cd generate_dataset_json
85
+ python SDD.py
86
+ ```
87
+ Select the corresponding script and run it (we provide scripts for all datasets reported in the AnomalyCLIP paper). The generated JSON stores all the information that AnomalyCLIP needs; a minimal sketch of its layout is shown below.
88
+
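The sketch below shows the `meta.json` layout that `dataset.py` expects. The field names are taken from `dataset.py`, and the example entry reuses the `bottle/broken_large` paths from the folder tree above; it is an illustrative minimum, not the exact output of `generate_dataset_json/mvtec.py`.
```python
# Minimal sketch of a meta.json entry (illustrative, not the script's exact output).
import json

meta = {
    "test": {
        "bottle": [
            {
                "img_path": "bottle/test/broken_large/000.png",
                "mask_path": "bottle/ground_truth/broken_large/000_mask.png",
                "cls_name": "bottle",
                "specie_name": "broken_large",
                "anomaly": 1,
            },
        ],
    },
}
with open("mvtec/meta.json", "w") as f:
    json.dump(meta, f, indent=4)
```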
89
+ ### Custom dataset (optional)
90
+ 1. Create a new JSON script in the folder [generate_dataset_json](https://github.com/zqhang/AnomalyCLIP/tree/main/generate_dataset_json) according to the folder structure of your own dataset.
91
+ 2. Add the related info of your dataset (i.e., dataset name and class names) to the script [dataset\.py](https://github.com/zqhang/AnomalyCLIP/blob/main/dataset.py)
92
+
93
+ ### Run AnomalyCLIP
94
+ * Quick start (use the pre-trained weights)
95
+ ```bash
96
+ bash test.sh
97
+ ```
98
+
99
+ * Train your own weights
100
+ ```bash
101
+ bash train.sh
102
+ ```
103
+
104
+
105
+ ## Main results (We test all datasets with a model trained once on MVTec AD; for MVTec AD itself, AnomalyCLIP is trained on VisA.)
106
+
107
+ ### Industrial dataset
108
+ ![industrial](./assets/Industrial.png)
109
+
110
+
111
+ ### Medical dataset
112
+ ![medical](./assets/medical.png)
113
+
114
+
115
+ ## Visualization
116
+
117
+ ![hazelnut](./assets/hazelnut.png)
118
+
119
+ ![capusle](./assets/capusle.png)
120
+
121
+ ![skin](./assets/skin.png)
122
+
123
+ ![brain](./assets/brain.png)
124
+
125
+
126
+ ## Our reproduction of WinCLIP is available [here](https://github.com/zqhang/WinCLIP-pytorch)
127
+
128
+
129
+ * We thank the authors of the following code repositories: [open_clip](https://github.com/mlfoundations/open_clip), [DualCoOp](https://github.com/sunxm2357/DualCoOp), [CLIP_Surgery](https://github.com/xmed-lab/CLIP_Surgery), and [VAND](https://github.com/ByChelsea/VAND-APRIL-GAN/tree/master).
130
+
131
+ ## BibTex Citation
132
+
133
+ If you find this paper and repository useful, please cite our paper.
134
+
135
+ ```
136
+ @inproceedings{zhou2023anomalyclip,
137
+ title={AnomalyCLIP: Object-agnostic Prompt Learning for Zero-shot Anomaly Detection},
138
+ author={Zhou, Qihang and Pang, Guansong and Tian, Yu and He, Shibo and Chen, Jiming},
139
+ booktitle={The Twelfth International Conference on Learning Representations},
140
+ year={2023}
141
+ }
142
+ ```
checkpoints/9_12_4_multiscale/epoch_1.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a89d1ffe49d86995e936c8e91515efa878d4e1777c73888622091e89a8df9e5b
3
+ size 22631493
checkpoints/9_12_4_multiscale/epoch_10.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7205c05df3319984b349686cbfd8cc01d3ac241a82f33943e9217cbb85604b0b
3
+ size 22631975
checkpoints/9_12_4_multiscale/epoch_11.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40017b0588b3e41aea4cf3902b388bbee494201b4406583f0a9c96f90818a986
3
+ size 22631975
checkpoints/9_12_4_multiscale/epoch_12.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef4bdfad5689797d48296eeceb57343aabba5ae5a2c7e57d4b9e225d2d254252
3
+ size 22631975
checkpoints/9_12_4_multiscale/epoch_13.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4381596b44bbaa33e7b04b4a19a46582980f1ee8742414d71147c8be95ef90d7
3
+ size 22631975
checkpoints/9_12_4_multiscale/epoch_14.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd2a3865c4cf1363b80f301da7dc181a54787e3c218cc1f3464650a5f749cb26
3
+ size 22631975
checkpoints/9_12_4_multiscale/epoch_15.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94ce202da3e6486a864b904fdfed5057de75846c5834e446fd1d2fe7f97acb44
3
+ size 22631975
checkpoints/9_12_4_multiscale/epoch_2.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6bfcd2ed1725b3d58dd06d5d38f7ef6d3b9c49d817bb4714a16f3153c3d7450
3
+ size 22631493
checkpoints/9_12_4_multiscale/epoch_3.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5af4c383158732845ac2ef195e5036e8528f187ed80173c8d993830a0abed64c
3
+ size 22631493
checkpoints/9_12_4_multiscale/epoch_4.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ab9a9909711c89cac5f02f0c46c7baac82b09bfaca59a83271a50b195cad89f
3
+ size 22631493
checkpoints/9_12_4_multiscale/epoch_5.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:317837a0ef5b46d2476c234d3fa77e8cfab7bbfa85711f5fe7eb7f50ea7151a0
3
+ size 22631493
checkpoints/9_12_4_multiscale/epoch_6.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04379155c0df8d4e1194335427091e626df512a9747e47c1bbb7ee3a55708164
3
+ size 22631493
checkpoints/9_12_4_multiscale/epoch_7.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41c5a77a355c27266d6a9c7b6da4b3ee2c193596873d889822e68a797a2688b2
3
+ size 22631493
checkpoints/9_12_4_multiscale/epoch_8.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c92bfa088eccb2efb71b27c9703c0f21158903581efd7292f42938ad96940c82
3
+ size 22631493
checkpoints/9_12_4_multiscale/epoch_9.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43f0eca2d506b88370a06c94a6cd557360c7bcb179a4f3f24981230349a9581a
3
+ size 22631493
checkpoints/9_12_4_multiscale/log.txt ADDED
File without changes
checkpoints/9_12_4_multiscale_visa/epoch_1.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de5df7fc2ec18acb5709e65b1889d586974d365c39d1aa4df728336633e4ee70
3
+ size 22631493
checkpoints/9_12_4_multiscale_visa/epoch_10.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:397255934bd313beeab2b610fa901f113e12342974687147cad78f502e5ae7e5
3
+ size 22631975
checkpoints/9_12_4_multiscale_visa/epoch_11.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:843fb9df1c46da89f6976a42d10d5fe34675ad48eccb365e3f43785f925c2ae9
3
+ size 22631975
checkpoints/9_12_4_multiscale_visa/epoch_12.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17f69ad9ae4bcc5823fdd9ad56b51ec57cc641270280a1776c1014ea1969f282
3
+ size 22631975
checkpoints/9_12_4_multiscale_visa/epoch_13.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bf5fd9c269e3f68e81134f4361c3239ba14d5f2cd4e3564f93f5b59f616cd19
3
+ size 22631975
checkpoints/9_12_4_multiscale_visa/epoch_14.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:969dbaaa1a986f17d79dfb81d2ce90443d0e9dd9f19db7fd9a9190f97cc8e3d4
3
+ size 22631975
checkpoints/9_12_4_multiscale_visa/epoch_15.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:415c5dcb52668b8c33fb9c1a351c686d632b919df5b384d63fa9ce7a2338ced4
3
+ size 22631975
checkpoints/9_12_4_multiscale_visa/epoch_2.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c98c722977ac0fc42c1067a8038656c10466728f6e9d448aad9e3f6b3d5368b6
3
+ size 22631493
checkpoints/9_12_4_multiscale_visa/epoch_3.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3e7a65d6b9ff057b5fa53bfc59bfa57a25619b5a5d9cd40ed37579e312ab4aa
3
+ size 22631493
checkpoints/9_12_4_multiscale_visa/epoch_4.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f56b0ed7bd9da05f77780a3c4318e038c258b99a02ad1455652cad146b3dded5
3
+ size 22631493
checkpoints/9_12_4_multiscale_visa/epoch_5.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2c44c082a19abde2993e80044466c1e45a620cc24aad39e85bd65ed60d3572d
3
+ size 22631493
checkpoints/9_12_4_multiscale_visa/epoch_6.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:402d63bca2150631fb09d8d1c7529712a4ee8eea29bd7746412eae99b4ec6dc5
3
+ size 22631493
checkpoints/9_12_4_multiscale_visa/epoch_7.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:081526236212ebc011ec53babaf8f0da7e25fbe92300aa7cc68eb41ca29b054f
3
+ size 22631493
checkpoints/9_12_4_multiscale_visa/epoch_8.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f2587be72657ab30fc26bc5957e130ba7359ff53c32beb7984be517a818427c
3
+ size 22631493
checkpoints/9_12_4_multiscale_visa/epoch_9.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4850f209b34912c33718b86c13d2a01c340907d182236a8ef8903f35c80daec0
3
+ size 22631493
dataset.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch.utils.data as data
2
+ import json
3
+ import random
4
+ from PIL import Image
5
+ import numpy as np
6
+ import torch
7
+ import os
8
+
9
+ class Dataset(data.Dataset):
10
+ def __init__(self, root, transform, target_transform, dataset_name, mode='test'):
11
+ self.root = root
12
+ self.transform = transform
13
+ self.target_transform = target_transform
14
+ self.data_all = []
15
+ meta_info = json.load(open(f'{self.root}/meta.json', 'r'))
16
+ name = self.root.split('/')[-1]
17
+ meta_info = meta_info[mode]
18
+
19
+ self.cls_names = list(meta_info.keys())
20
+ for cls_name in self.cls_names:
21
+ self.data_all.extend(meta_info[cls_name])
22
+ self.length = len(self.data_all)
23
+
24
+ self.obj_list = [folder for folder in os.listdir(root) if os.path.isdir(os.path.join(root, folder)) and not folder.startswith('.')]
25
+ self.class_name_map_class_id = {o: i for i, o in enumerate(self.obj_list)}
26
+
27
+ def __len__(self):
28
+ return self.length
29
+
30
+ def __getitem__(self, index):
31
+ data = self.data_all[index]
32
+ img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \
33
+ data['specie_name'], data['anomaly']
34
+ img = Image.open(os.path.join(self.root, img_path))
35
+ if anomaly == 0:
36
+ img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')  # PIL size is (W, H); numpy expects (H, W)
37
+ else:
38
+ if os.path.isdir(os.path.join(self.root, mask_path)):
39
+ # the mask path may point to a directory when no pixel-level mask exists; use an empty mask so classification-only samples do not raise an error
40
+ img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')
41
+ else:
42
+ img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
43
+ img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
44
+ # transforms
45
+ img = self.transform(img) if self.transform is not None else img
46
+ img_mask = self.target_transform(
47
+ img_mask) if self.target_transform is not None and img_mask is not None else img_mask
48
+ img_mask = [] if img_mask is None else img_mask
49
+ return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
50
+ 'img_path': os.path.join(self.root, img_path), "cls_id": self.class_name_map_class_id[cls_name]}
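A minimal usage sketch for the `Dataset` class above (not part of the commit): the transforms are plain torchvision pipelines chosen only for illustration, and the data path is a placeholder for wherever the MVTec-style folder and its `meta.json` live.
```python
# Illustrative only: iterating an MVTec-style folder with the Dataset class above.
from torchvision import transforms
from dataset import Dataset

preprocess = transforms.Compose([transforms.Resize((518, 518)), transforms.ToTensor()])
mask_transform = transforms.Compose([transforms.Resize((518, 518)), transforms.ToTensor()])

test_data = Dataset(root="./data/mvtec", transform=preprocess,
                    target_transform=mask_transform, dataset_name="mvtec", mode="test")
sample = test_data[0]
print(sample["cls_name"], sample["anomaly"], sample["img"].shape)
```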
datasets/rayan_dataset.py ADDED
@@ -0,0 +1,127 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Do Not Alter This File!
3
+ # -----------------------------------------------------------------------------
4
+ # The following code is part of the logic used for loading and evaluating your
5
+ # output scores. Please DO NOT modify this section, as upon your submission,
6
+ # the whole evaluation logic will be overwritten by the original code.
7
+ # -----------------------------------------------------------------------------
8
+ # If you'd like to make modifications, you can create a completely new Dataset
9
+ # class or a child class that inherits from this one and use that with your
10
+ # data loader.
11
+ # -----------------------------------------------------------------------------
12
+
13
+ import os
14
+ from enum import Enum
15
+
16
+ import PIL
17
+ import torch
18
+ from torchvision import transforms
19
+
20
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
21
+ IMAGENET_STD = [0.229, 0.224, 0.225]
22
+
23
+
24
+ class DatasetSplit(Enum):
25
+ TRAIN = "train"
26
+ VAL = "val"
27
+ TEST = "test"
28
+
29
+
30
+ class RayanDataset(torch.utils.data.Dataset):
31
+ def __init__(
32
+ self,
33
+ source,
34
+ classname,
35
+ input_size=518,
36
+ output_size=224,
37
+ split=DatasetSplit.TEST,
38
+ external_transform=None,
39
+ **kwargs,
40
+ ):
41
+ super().__init__()
42
+ self.source = source
43
+ self.split = split
44
+ self.classnames_to_use = [classname]
45
+ self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
46
+
47
+ if external_transform is None:
48
+ self.transform_img = [
49
+ transforms.Resize((input_size, input_size)),
50
+ transforms.CenterCrop(input_size),
51
+ transforms.ToTensor(),
52
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
53
+ ]
54
+ self.transform_img = transforms.Compose(self.transform_img)
55
+ else:
56
+ self.transform_img = external_transform
57
+
58
+ # Output size of the mask has to be of shape: 1×224×224
59
+ self.transform_mask = [
60
+ transforms.Resize((output_size, output_size)),
61
+ transforms.CenterCrop(output_size),
62
+ transforms.ToTensor(),
63
+ ]
64
+ self.transform_mask = transforms.Compose(self.transform_mask)
65
+ self.output_shape = (1, output_size, output_size)
66
+
67
+ def __getitem__(self, idx):
68
+ classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
69
+ image = PIL.Image.open(image_path).convert("RGB")
70
+ image = self.transform_img(image)
71
+
72
+ if self.split == DatasetSplit.TEST and mask_path is not None:
73
+ mask = PIL.Image.open(mask_path).convert("L")
74
+ mask = self.transform_mask(mask) > 0
75
+ else:
76
+ mask = torch.zeros([*self.output_shape])
77
+
78
+ return {
79
+ "image": image,
80
+ "mask": mask,
81
+ "is_anomaly": int(anomaly != "good"),
82
+ "image_path": image_path,
83
+ }
84
+
85
+ def __len__(self):
86
+ return len(self.data_to_iterate)
87
+
88
+ def get_image_data(self):
89
+ imgpaths_per_class = {}
90
+ maskpaths_per_class = {}
91
+
92
+ for classname in self.classnames_to_use:
93
+ classpath = os.path.join(self.source, classname, self.split.value)
94
+ maskpath = os.path.join(self.source, classname, "ground_truth")
95
+ anomaly_types = os.listdir(classpath)
96
+
97
+ imgpaths_per_class[classname] = {}
98
+ maskpaths_per_class[classname] = {}
99
+
100
+ for anomaly in anomaly_types:
101
+ anomaly_path = os.path.join(classpath, anomaly)
102
+ anomaly_files = sorted(os.listdir(anomaly_path))
103
+ imgpaths_per_class[classname][anomaly] = [
104
+ os.path.join(anomaly_path, x) for x in anomaly_files
105
+ ]
106
+
107
+ if self.split == DatasetSplit.TEST and anomaly != "good":
108
+ anomaly_mask_path = os.path.join(maskpath, anomaly)
109
+ anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
110
+ maskpaths_per_class[classname][anomaly] = [
111
+ os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
112
+ ]
113
+ else:
114
+ maskpaths_per_class[classname]["good"] = None
115
+
116
+ data_to_iterate = []
117
+ for classname in sorted(imgpaths_per_class.keys()):
118
+ for anomaly in sorted(imgpaths_per_class[classname].keys()):
119
+ for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
120
+ data_tuple = [classname, anomaly, image_path]
121
+ if self.split == DatasetSplit.TEST and anomaly != "good":
122
+ data_tuple.append(maskpaths_per_class[classname][anomaly][i])
123
+ else:
124
+ data_tuple.append(None)
125
+ data_to_iterate.append(data_tuple)
126
+
127
+ return imgpaths_per_class, data_to_iterate
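Since this file must not be modified, here is a separate usage sketch: `RayanDataset` plugs directly into a standard `DataLoader`. The `source` path and class name below are placeholders consistent with the example path mentioned in `evaluation/base_eval.py`.
```python
# Illustrative only: loading test batches with the (unmodified) RayanDataset above.
import torch
from datasets.rayan_dataset import RayanDataset, DatasetSplit

dataset = RayanDataset(source="./data", classname="photovoltaic_module",
                       split=DatasetSplit.TEST)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
batch = next(iter(loader))
print(batch["image"].shape, batch["mask"].shape, batch["is_anomaly"])
```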
docker-compose.yml ADDED
@@ -0,0 +1,21 @@
1
+ # -----------------------------------------------------------------------------
2
+ # A sample Docker Compose file to help you replicate our test environment
3
+ # -----------------------------------------------------------------------------
4
+
5
+ services:
6
+ zsad-service:
7
+ image: zsad-image:1
8
+ build:
9
+ context: .
10
+ container_name: zsad-container
11
+ volumes:
12
+ - ./shared_folder:/app/output
13
+ deploy:
14
+ resources:
15
+ reservations:
16
+ devices:
17
+ - driver: nvidia
18
+ count: all
19
+ capabilities: [gpu]
20
+
21
+ command: [ "python3", "runner.py" ]
evaluation/base_eval.py ADDED
@@ -0,0 +1,293 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Do Not Alter This File!
3
+ # -----------------------------------------------------------------------------
4
+ # The following code is part of the logic used for loading and evaluating your
5
+ # output scores. Please DO NOT modify this section, as upon your submission,
6
+ # the whole evaluation logic will be overwritten by the original code.
7
+ # -----------------------------------------------------------------------------
8
+
9
+ import warnings
10
+ import os
11
+ from pathlib import Path
12
+ import csv
13
+ import json
14
+ import torch
15
+
16
+ import datasets.rayan_dataset as rayan_dataset
17
+ from evaluation.utils.metrics import compute_metrics
18
+
19
+ warnings.filterwarnings("ignore")
20
+
21
+
22
+ class BaseEval:
23
+ def __init__(self, cfg):
24
+ self.cfg = cfg
25
+ self.device = torch.device(
26
+ "cuda:{}".format(cfg["device"]) if torch.cuda.is_available() else "cpu"
27
+ )
28
+
29
+ self.path = cfg["datasets"]["data_path"]
30
+ self.dataset = cfg["datasets"]["dataset_name"]
31
+ self.save_csv = cfg["testing"]["save_csv"]
32
+ self.save_json = cfg["testing"]["save_json"]
33
+ self.categories = cfg["datasets"]["class_name"]
34
+ if isinstance(self.categories, str):
35
+ if self.categories.lower() == "all":
36
+ if self.dataset == "rayan_dataset":
37
+ self.categories = self.get_available_class_names(self.path)
38
+ else:
39
+ self.categories = [self.categories]
40
+ self.output_dir = cfg["testing"]["output_dir"]
41
+ os.makedirs(self.output_dir, exist_ok=True)
42
+ self.scores_dir = cfg["testing"]["output_scores_dir"]
43
+ self.class_name_mapping_dir = cfg["testing"]["class_name_mapping_dir"]
44
+
45
+ self.leaderboard_metric_weights = {
46
+ "image_auroc": 1.2,
47
+ "image_ap": 1.1,
48
+ "image_f1": 1.1,
49
+ "pixel_auroc": 1.0,
50
+ "pixel_aupro": 1.4,
51
+ "pixel_ap": 1.3,
52
+ "pixel_f1": 1.3,
53
+ }
54
+
55
+ def get_available_class_names(self, root_data_path):
56
+ all_items = os.listdir(root_data_path)
57
+ folder_names = [
58
+ item
59
+ for item in all_items
60
+ if os.path.isdir(os.path.join(root_data_path, item))
61
+ ]
62
+
63
+ return folder_names
64
+
65
+ def load_datasets(self, category):
66
+ dataset_classes = {
67
+ "rayan_dataset": rayan_dataset.RayanDataset,
68
+ }
69
+
70
+ dataset_splits = {
71
+ "rayan_dataset": rayan_dataset.DatasetSplit.TEST,
72
+ }
73
+
74
+ test_dataset = dataset_classes[self.dataset](
75
+ source=self.path,
76
+ split=dataset_splits[self.dataset],
77
+ classname=category,
78
+ )
79
+ return test_dataset
80
+
81
+ def get_category_metrics(self, category):
82
+ print(f"Loading scores of '{category}'")
83
+ gt_sp, pr_sp, gt_px, pr_px, _ = self.load_category_scores(category)
84
+
85
+ print(f"Computing metrics for '{category}'")
86
+ image_metric, pixel_metric = compute_metrics(gt_sp, pr_sp, gt_px, pr_px)
87
+
88
+ return image_metric, pixel_metric
89
+
90
+ def load_category_scores(self, category):
91
+ raise NotImplementedError()
92
+
93
+ def get_scores_path_for_image(self, image_path):
94
+ """example image_path: './data/photovoltaic_module/test/good/037.png'"""
95
+ path = Path(image_path)
96
+
97
+ category, split, anomaly_type = path.parts[-4:-1]
98
+ image_name = path.stem
99
+
100
+ return os.path.join(
101
+ self.scores_dir, category, split, anomaly_type, f"{image_name}_scores.json"
102
+ )
103
+
104
+ def calc_leaderboard_score(self, **metrics):
105
+ weighted_sum = 0
106
+ total_weight = 0
107
+ for key, weight in self.leaderboard_metric_weights.items():
108
+ metric = metrics.get(key)
109
+ weighted_sum += metric * weight
110
+ total_weight += weight
111
+
112
+ if total_weight == 0:
113
+ return 0
114
+
115
+ return weighted_sum / total_weight
116
+
117
+ def main(self):
118
+ image_auroc_list = []
119
+ image_f1_list = []
120
+ image_ap_list = []
121
+ pixel_auroc_list = []
122
+ pixel_f1_list = []
123
+ pixel_ap_list = []
124
+ pixel_aupro_list = []
125
+ leaderboard_score_list = []
126
+ for category in self.categories:
127
+ image_metric, pixel_metric = self.get_category_metrics(
128
+ category=category,
129
+ )
130
+ image_auroc, image_f1, image_ap = image_metric
131
+ pixel_auroc, pixel_f1, pixel_ap, pixel_aupro = pixel_metric
132
+ leaderboard_score = self.calc_leaderboard_score(
133
+ image_auroc=image_auroc,
134
+ image_f1=image_f1,
135
+ image_ap=image_ap,
136
+ pixel_auroc=pixel_auroc,
137
+ pixel_aupro=pixel_aupro,
138
+ pixel_f1=pixel_f1,
139
+ pixel_ap=pixel_ap,
140
+ )
141
+
142
+ image_auroc_list.append(image_auroc)
143
+ image_f1_list.append(image_f1)
144
+ image_ap_list.append(image_ap)
145
+ pixel_auroc_list.append(pixel_auroc)
146
+ pixel_f1_list.append(pixel_f1)
147
+ pixel_ap_list.append(pixel_ap)
148
+ pixel_aupro_list.append(pixel_aupro)
149
+ leaderboard_score_list.append(leaderboard_score)
150
+
151
+ print(category)
152
+ print(
153
+ "[image level] auroc:{}, f1:{}, ap:{}".format(
154
+ image_auroc * 100,
155
+ image_f1 * 100,
156
+ image_ap * 100,
157
+ )
158
+ )
159
+ print(
160
+ "[pixel level] auroc:{}, f1:{}, ap:{}, aupro:{}".format(
161
+ pixel_auroc * 100,
162
+ pixel_f1 * 100,
163
+ pixel_ap * 100,
164
+ pixel_aupro * 100,
165
+ )
166
+ )
167
+ print(
168
+ "leaderboard score:{}".format(
169
+ leaderboard_score * 100,
170
+ )
171
+ )
172
+
173
+ image_auroc_mean = sum(image_auroc_list) / len(image_auroc_list)
174
+ image_f1_mean = sum(image_f1_list) / len(image_f1_list)
175
+ image_ap_mean = sum(image_ap_list) / len(image_ap_list)
176
+ pixel_auroc_mean = sum(pixel_auroc_list) / len(pixel_auroc_list)
177
+ pixel_f1_mean = sum(pixel_f1_list) / len(pixel_f1_list)
178
+ pixel_ap_mean = sum(pixel_ap_list) / len(pixel_ap_list)
179
+ pixel_aupro_mean = sum(pixel_aupro_list) / len(pixel_aupro_list)
180
+ leaderboard_score_mean = sum(leaderboard_score_list) / len(
181
+ leaderboard_score_list
182
+ )
183
+
184
+ print("mean")
185
+ print(
186
+ "[image level] auroc:{}, f1:{}, ap:{}".format(
187
+ image_auroc_mean * 100, image_f1_mean * 100, image_ap_mean * 100
188
+ )
189
+ )
190
+ print(
191
+ "[pixel level] auroc:{}, f1:{}, ap:{}, aupro:{}".format(
192
+ pixel_auroc_mean * 100,
193
+ pixel_f1_mean * 100,
194
+ pixel_ap_mean * 100,
195
+ pixel_aupro_mean * 100,
196
+ )
197
+ )
198
+ print(
199
+ "leaderboard score:{}".format(
200
+ leaderboard_score_mean * 100,
201
+ )
202
+ )
203
+
204
+ # Save the final results as a csv file
205
+ if self.save_csv:
206
+ with open(self.class_name_mapping_dir, "r") as f:
207
+ class_name_mapping_dict = json.load(f)
208
+ csv_data = [
209
+ [
210
+ "Category",
211
+ "pixel_auroc",
212
+ "pixel_f1",
213
+ "pixel_ap",
214
+ "pixel_aupro",
215
+ "image_auroc",
216
+ "image_f1",
217
+ "image_ap",
218
+ "leaderboard_score",
219
+ ]
220
+ ]
221
+ for i, category in enumerate(self.categories):
222
+ csv_data.append(
223
+ [
224
+ class_name_mapping_dict[category],
225
+ pixel_auroc_list[i] * 100,
226
+ pixel_f1_list[i] * 100,
227
+ pixel_ap_list[i] * 100,
228
+ pixel_aupro_list[i] * 100,
229
+ image_auroc_list[i] * 100,
230
+ image_f1_list[i] * 100,
231
+ image_ap_list[i] * 100,
232
+ leaderboard_score_list[i] * 100,
233
+ ]
234
+ )
235
+ csv_data.append(
236
+ [
237
+ "mean",
238
+ pixel_auroc_mean * 100,
239
+ pixel_f1_mean * 100,
240
+ pixel_ap_mean * 100,
241
+ pixel_aupro_mean * 100,
242
+ image_auroc_mean * 100,
243
+ image_f1_mean * 100,
244
+ image_ap_mean * 100,
245
+ leaderboard_score_mean * 100,
246
+ ]
247
+ )
248
+
249
+ csv_file_path = os.path.join(self.output_dir, "results.csv")
250
+ with open(csv_file_path, mode="w", newline="") as file:
251
+ writer = csv.writer(file)
252
+ writer.writerows(csv_data)
253
+
254
+ # Save the final results as a json file
255
+ if self.save_json:
256
+ json_data = []
257
+ with open(self.class_name_mapping_dir, "r") as f:
258
+ class_name_mapping_dict = json.load(f)
259
+ for i, category in enumerate(self.categories):
260
+ json_data.append(
261
+ {
262
+ "Category": class_name_mapping_dict[category],
263
+ "pixel_auroc": pixel_auroc_list[i] * 100,
264
+ "pixel_f1": pixel_f1_list[i] * 100,
265
+ "pixel_ap": pixel_ap_list[i] * 100,
266
+ "pixel_aupro": pixel_aupro_list[i] * 100,
267
+ "image_auroc": image_auroc_list[i] * 100,
268
+ "image_f1": image_f1_list[i] * 100,
269
+ "image_ap": image_ap_list[i] * 100,
270
+ "leaderboard_score": leaderboard_score_list[i] * 100,
271
+ }
272
+ )
273
+ json_data.append(
274
+ {
275
+ "Category": "mean",
276
+ "pixel_auroc": pixel_auroc_mean * 100,
277
+ "pixel_f1": pixel_f1_mean * 100,
278
+ "pixel_ap": pixel_ap_mean * 100,
279
+ "pixel_aupro": pixel_aupro_mean * 100,
280
+ "image_auroc": image_auroc_mean * 100,
281
+ "image_f1": image_f1_mean * 100,
282
+ "image_ap": image_ap_mean * 100,
283
+ "leaderboard_score": leaderboard_score_mean * 100,
284
+ }
285
+ )
286
+
287
+ json_file_path = os.path.join(self.output_dir, "results.json")
288
+ with open(json_file_path, mode="w") as file:
289
+ final_json = {
290
+ "result": leaderboard_score_mean * 100,
291
+ "metadata": json_data,
292
+ }
293
+ json.dump(final_json, file, indent=4)
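For intuition (this file is also not to be modified), the leaderboard score defined in `calc_leaderboard_score` is a weighted mean of the seven metrics. A small worked example with made-up metric values:
```python
# Worked example with made-up numbers; the weights mirror leaderboard_metric_weights.
weights = {"image_auroc": 1.2, "image_ap": 1.1, "image_f1": 1.1,
           "pixel_auroc": 1.0, "pixel_aupro": 1.4, "pixel_ap": 1.3, "pixel_f1": 1.3}
metrics = {"image_auroc": 0.90, "image_ap": 0.95, "image_f1": 0.88,
           "pixel_auroc": 0.97, "pixel_aupro": 0.85, "pixel_ap": 0.40, "pixel_f1": 0.45}
score = sum(metrics[k] * w for k, w in weights.items()) / sum(weights.values())
print(round(score, 4))   # ~0.7569, reported as ~75.69 after the *100 scaling
```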
evaluation/class_name_mapping.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "pill": "industrial_01",
3
+ "photovoltaic_module": "industrial_02",
4
+ "capsules": "industrial_03"
5
+ }
evaluation/eval_main.py ADDED
@@ -0,0 +1,78 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Do Not Alter This File!
3
+ # -----------------------------------------------------------------------------
4
+ # The following code is part of the logic used for loading and evaluating your
5
+ # output scores. Please DO NOT modify this section, as upon your submission,
6
+ # the whole evaluation logic will be overwritten by the original code.
7
+ # -----------------------------------------------------------------------------
8
+
9
+ import warnings
10
+ import argparse
11
+ import os
12
+ import sys
13
+
14
+ sys.path.append(os.getcwd())
15
+ from evaluation.json_score import JsonScoreEvaluator
16
+
17
+ warnings.filterwarnings("ignore")
18
+
19
+
20
+ def get_args():
21
+ parser = argparse.ArgumentParser(description="Rayan ZSAD Evaluation Code")
22
+ parser.add_argument("--data_path", type=str, default=None, help="dataset path")
23
+ parser.add_argument("--dataset_name", type=str, default=None, help="dataset name")
24
+ parser.add_argument("--class_name", type=str, default=None, help="category")
25
+ parser.add_argument("--device", type=int, default=None, help="gpu id")
26
+ parser.add_argument(
27
+ "--output_dir", type=str, default=None, help="save results path"
28
+ )
29
+ parser.add_argument(
30
+ "--output_scores_dir", type=str, default=None, help="save scores path"
31
+ )
32
+ parser.add_argument("--save_csv", type=str, default=None, help="save csv")
33
+ parser.add_argument("--save_json", type=str, default=None, help="save json")
34
+
35
+ parser.add_argument(
36
+ "--class_name_mapping_dir",
37
+ type=str,
38
+ default=None,
39
+ help="mapping from actual class names to class numbers",
40
+ )
41
+ args = parser.parse_args()
42
+ return args
43
+
44
+
45
+ def load_args(cfg, args):
46
+ cfg["datasets"]["data_path"] = args.data_path
47
+ assert os.path.exists(
48
+ cfg["datasets"]["data_path"]
49
+ ), f"The dataset path {cfg['datasets']['data_path']} does not exist."
50
+ cfg["datasets"]["dataset_name"] = args.dataset_name
51
+ cfg["datasets"]["class_name"] = args.class_name
52
+ cfg["device"] = args.device
53
+ if isinstance(cfg["device"], int):
54
+ cfg["device"] = str(cfg["device"])
55
+ cfg["testing"]["output_dir"] = args.output_dir
56
+ cfg["testing"]["output_scores_dir"] = args.output_scores_dir
57
+ os.makedirs(cfg["testing"]["output_scores_dir"], exist_ok=True)
58
+
59
+ cfg["testing"]["class_name_mapping_dir"] = args.class_name_mapping_dir
60
+ if args.save_csv.lower() == "true":
61
+ cfg["testing"]["save_csv"] = True
62
+ else:
63
+ cfg["testing"]["save_csv"] = False
64
+
65
+ if args.save_json.lower() == "true":
66
+ cfg["testing"]["save_json"] = True
67
+ else:
68
+ cfg["testing"]["save_json"] = False
69
+
70
+ return cfg
71
+
72
+
73
+ if __name__ == "__main__":
74
+ args = get_args()
75
+ cfg = load_args(cfg={"datasets": {}, "testing": {}, "models": {}}, args=args)
76
+ print(cfg)
77
+ model = JsonScoreEvaluator(cfg=cfg)
78
+ model.main()