Amirparsa-Sal
commited on
Commit
·
5d1f0ae
1
Parent(s):
9a99f40
Add codes
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +7 -0
- AnomalyCLIP_lib/AnomalyCLIP.py +531 -0
- AnomalyCLIP_lib/CLIP.py +436 -0
- AnomalyCLIP_lib/__init__.py +1 -0
- AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz +3 -0
- AnomalyCLIP_lib/build_model.py +50 -0
- AnomalyCLIP_lib/constants.py +2 -0
- AnomalyCLIP_lib/model_load.py +235 -0
- AnomalyCLIP_lib/simple_tokenizer.py +132 -0
- AnomalyCLIP_lib/transform.py +133 -0
- Dockerfile +14 -0
- LICENSE +21 -0
- README.md +142 -0
- checkpoints/9_12_4_multiscale/epoch_1.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_10.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_11.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_12.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_13.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_14.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_15.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_2.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_3.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_4.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_5.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_6.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_7.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_8.pth +3 -0
- checkpoints/9_12_4_multiscale/epoch_9.pth +3 -0
- checkpoints/9_12_4_multiscale/log.txt +0 -0
- checkpoints/9_12_4_multiscale_visa/epoch_1.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_10.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_11.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_12.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_13.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_14.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_15.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_2.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_3.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_4.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_5.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_6.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_7.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_8.pth +3 -0
- checkpoints/9_12_4_multiscale_visa/epoch_9.pth +3 -0
- dataset.py +50 -0
- datasets/rayan_dataset.py +127 -0
- docker-compose.yml +21 -0
- evaluation/base_eval.py +293 -0
- evaluation/class_name_mapping.json +5 -0
- evaluation/eval_main.py +78 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pyc
|
2 |
+
*.pyo
|
3 |
+
__pycache__/
|
4 |
+
*.tar.gz
|
5 |
+
*.tar.xz
|
6 |
+
ZSAD-dataset
|
7 |
+
data/
|
AnomalyCLIP_lib/AnomalyCLIP.py
ADDED
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
|
9 |
+
class Bottleneck(nn.Module):
|
10 |
+
expansion = 4
|
11 |
+
|
12 |
+
def __init__(self, inplanes, planes, stride=1):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
16 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
17 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
18 |
+
self.relu1 = nn.ReLU(inplace=True)
|
19 |
+
|
20 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
21 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
22 |
+
self.relu2 = nn.ReLU(inplace=True)
|
23 |
+
|
24 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
25 |
+
|
26 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
27 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
28 |
+
self.relu3 = nn.ReLU(inplace=True)
|
29 |
+
|
30 |
+
self.downsample = None
|
31 |
+
self.stride = stride
|
32 |
+
|
33 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
34 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
35 |
+
self.downsample = nn.Sequential(OrderedDict([
|
36 |
+
("-1", nn.AvgPool2d(stride)),
|
37 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
38 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
39 |
+
]))
|
40 |
+
|
41 |
+
def forward(self, x: torch.Tensor):
|
42 |
+
identity = x
|
43 |
+
|
44 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
45 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
46 |
+
out = self.avgpool(out)
|
47 |
+
out = self.bn3(self.conv3(out))
|
48 |
+
|
49 |
+
if self.downsample is not None:
|
50 |
+
identity = self.downsample(x)
|
51 |
+
|
52 |
+
out += identity
|
53 |
+
out = self.relu3(out)
|
54 |
+
return out
|
55 |
+
|
56 |
+
|
57 |
+
# implement attention module for v-v self-attention
|
58 |
+
class Attention(nn.Module):
|
59 |
+
def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''):
|
60 |
+
super().__init__()
|
61 |
+
self.num_heads = num_heads
|
62 |
+
head_dim = dim // num_heads
|
63 |
+
self.scale = qk_scale or head_dim ** -0.5
|
64 |
+
|
65 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
66 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
67 |
+
self.proj = nn.Linear(out_dim, dim)
|
68 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
69 |
+
self.settings = settings
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
B, N, C = x.shape
|
73 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
74 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
75 |
+
|
76 |
+
# original self-attention for the original path
|
77 |
+
attn_ori = (q @ k.transpose(-2, -1)) * self.scale
|
78 |
+
attn_ori = attn_ori.softmax(dim=-1)
|
79 |
+
attn_ori = self.attn_drop(attn_ori)
|
80 |
+
|
81 |
+
# replace k & q by v
|
82 |
+
k = v
|
83 |
+
q = k
|
84 |
+
|
85 |
+
# self-attention, higher temperate for resnets performs better
|
86 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
87 |
+
attn = (attn).softmax(dim=-1)
|
88 |
+
attn = self.attn_drop(attn)
|
89 |
+
|
90 |
+
x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
|
91 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
92 |
+
x = self.proj_drop(self.proj(x))
|
93 |
+
x_ori = self.proj_drop(self.proj(x_ori))
|
94 |
+
return [x, x_ori]
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
class LayerNorm(nn.LayerNorm):
|
99 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
100 |
+
|
101 |
+
def forward(self, x: torch.Tensor):
|
102 |
+
orig_type = x.dtype
|
103 |
+
ret = super().forward(x.type(torch.float32))
|
104 |
+
return ret.type(orig_type)
|
105 |
+
|
106 |
+
|
107 |
+
class QuickGELU(nn.Module):
|
108 |
+
def forward(self, x: torch.Tensor):
|
109 |
+
return x * torch.sigmoid(1.702 * x)
|
110 |
+
|
111 |
+
|
112 |
+
class ResidualAttentionBlock(nn.Module):
|
113 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details = None):
|
114 |
+
super().__init__()
|
115 |
+
|
116 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
117 |
+
self.ln_1 = LayerNorm(d_model)
|
118 |
+
self.mlp = nn.Sequential(OrderedDict([
|
119 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
120 |
+
("gelu", QuickGELU()),
|
121 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
122 |
+
]))
|
123 |
+
self.ln_2 = LayerNorm(d_model)
|
124 |
+
self.attn_mask = attn_mask
|
125 |
+
|
126 |
+
def attention(self, x: torch.Tensor):
|
127 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
128 |
+
if isinstance(self.attn, Attention):
|
129 |
+
x = x.transpose(0, 1)
|
130 |
+
x, x_ori = self.attn(x)
|
131 |
+
return [x.transpose(0, 1), x_ori.transpose(0, 1)]
|
132 |
+
else:
|
133 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
134 |
+
|
135 |
+
def forward(self, x, whole = False, ffn = False):
|
136 |
+
# print("xxxxx",x.shape)
|
137 |
+
# dual paths for blocks deeper than "d"
|
138 |
+
|
139 |
+
if isinstance(self.attn, Attention):
|
140 |
+
if isinstance(x, list):
|
141 |
+
if not ffn:
|
142 |
+
x, x_ori = x
|
143 |
+
x_res = self.attention(self.ln_1(x_ori))
|
144 |
+
x_res, x_ori_res = x_res
|
145 |
+
x_ori += x_ori_res
|
146 |
+
x_ori = x_ori + self.mlp(self.ln_2(x_ori))
|
147 |
+
x += x_res # skip ffn for the new path
|
148 |
+
# print('hellloooo')
|
149 |
+
return [x, x_ori]
|
150 |
+
else:
|
151 |
+
x, x_ori_1 = x
|
152 |
+
x_res = self.attention(self.ln_1(x_ori_1))
|
153 |
+
x_res, x_ori_res = x_res
|
154 |
+
x_ori = x_ori_1 + x_ori_res
|
155 |
+
x_ori = x_ori + self.mlp(self.ln_2(x_ori))
|
156 |
+
x += x_res # skip ffn for the new path
|
157 |
+
x = x_res + x_ori_1
|
158 |
+
x = x + self.mlp(self.ln_2(x))
|
159 |
+
return [x, x_ori]
|
160 |
+
# start of dual path
|
161 |
+
else:
|
162 |
+
x_res = self.attention(self.ln_1(x))
|
163 |
+
if isinstance(x_res, list):
|
164 |
+
x_res, x_ori_res = x_res
|
165 |
+
x_ori = x + x_ori_res
|
166 |
+
x_ori = x_ori + self.mlp(self.ln_2(x_ori))
|
167 |
+
x += x_res
|
168 |
+
return [x, x_ori]
|
169 |
+
|
170 |
+
# singl path before "d"
|
171 |
+
else:
|
172 |
+
x = x + self.attention(self.ln_1(x))
|
173 |
+
x = x + self.mlp(self.ln_2(x))
|
174 |
+
return x
|
175 |
+
|
176 |
+
class ResidualAttentionBlock_learnable_token(nn.Module):
|
177 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details=None,
|
178 |
+
text_layer=False, i = 0):
|
179 |
+
super().__init__()
|
180 |
+
|
181 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
182 |
+
self.ln_1 = LayerNorm(d_model)
|
183 |
+
self.mlp = nn.Sequential(OrderedDict([
|
184 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
185 |
+
("gelu", QuickGELU()),
|
186 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
187 |
+
]))
|
188 |
+
self.ln_2 = LayerNorm(d_model)
|
189 |
+
self.attn_mask = attn_mask
|
190 |
+
|
191 |
+
self.i = i
|
192 |
+
self.compound_prompt_nctx = design_details['learnabel_text_embedding_length']
|
193 |
+
self.text_layer = text_layer
|
194 |
+
if i == 0:
|
195 |
+
self.first_layer = True
|
196 |
+
else:
|
197 |
+
self.first_layer = False
|
198 |
+
|
199 |
+
def attention(self, x: torch.Tensor):
|
200 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
201 |
+
if isinstance(self.attn, Attention):
|
202 |
+
x = x.transpose(0, 1)
|
203 |
+
x, x_ori = self.attn(x)
|
204 |
+
return [x.transpose(0, 1), x_ori.transpose(0, 1)]
|
205 |
+
else:
|
206 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
207 |
+
|
208 |
+
def forward(self, inputs):
|
209 |
+
|
210 |
+
# dual paths for blocks deeper than "d"
|
211 |
+
if isinstance(self.attn, Attention):
|
212 |
+
x = inputs[0]
|
213 |
+
if isinstance(x, list):
|
214 |
+
x, x_ori = x
|
215 |
+
x_res = self.attention(self.ln_1(x_ori))
|
216 |
+
x_res, x_ori_res = x_res
|
217 |
+
x_ori += x_ori_res
|
218 |
+
x_ori = x_ori + self.mlp(self.ln_2(x_ori))
|
219 |
+
x += x_res # skip ffn for the new path
|
220 |
+
return [x, x_ori]
|
221 |
+
|
222 |
+
# start of dual path
|
223 |
+
else:
|
224 |
+
x_res = self.attention(self.ln_1(x))
|
225 |
+
if isinstance(x_res, list):
|
226 |
+
x_res, x_ori_res = x_res
|
227 |
+
x_ori = x + x_ori_res
|
228 |
+
x_ori = x_ori + self.mlp(self.ln_2(x_ori))
|
229 |
+
x += x_res
|
230 |
+
return [x, x_ori]
|
231 |
+
|
232 |
+
# singl path before "d"
|
233 |
+
else:
|
234 |
+
x = inputs[0]
|
235 |
+
compound_prompts_deeper = inputs[1]
|
236 |
+
counter = inputs[2]
|
237 |
+
if not self.first_layer:
|
238 |
+
# First check if the ith layer needs compound prompts or not
|
239 |
+
if not (counter > len(compound_prompts_deeper) - 1):
|
240 |
+
# Appending the learnable tokens in different way
|
241 |
+
# x -> [77, NCLS, DIM]
|
242 |
+
# First remove the learnable tokens from previous layer
|
243 |
+
prefix = x[:1, :, :]
|
244 |
+
suffix = x[1 + self.compound_prompt_nctx:, :, :]
|
245 |
+
textual_context = compound_prompts_deeper[counter]
|
246 |
+
textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
|
247 |
+
# Add the learnable tokens of this layer with the input, replaced by previous
|
248 |
+
# layer learnable tokens
|
249 |
+
x = torch.cat([prefix, textual_context, suffix], dim=0)
|
250 |
+
# Once done, update the counter, so that the next time, it does not use same learnable tokens
|
251 |
+
counter += 1
|
252 |
+
x = x + self.attention(self.ln_1(x))
|
253 |
+
x = x + self.mlp(self.ln_2(x))
|
254 |
+
return [x, compound_prompts_deeper, counter]
|
255 |
+
|
256 |
+
|
257 |
+
class Transformer(nn.Module):
|
258 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False, design_details = None ,text_layer = False):
|
259 |
+
super().__init__()
|
260 |
+
self.width = width
|
261 |
+
self.layers = layers
|
262 |
+
self.text_layer = text_layer
|
263 |
+
self.design_deatails = design_details
|
264 |
+
print("text_layer", self.text_layer)
|
265 |
+
if self.text_layer and (design_details is not None):
|
266 |
+
self.resblocks = nn.ModuleList([ResidualAttentionBlock_learnable_token(width, heads, attn_mask, design_details, text_layer, i=i) for i in range(layers)])
|
267 |
+
else:
|
268 |
+
self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, attn_mask,) for i in range(layers)])
|
269 |
+
|
270 |
+
def ori_CLIP_with_patch_forward(self, x, out_layers):
|
271 |
+
idx = 0
|
272 |
+
out_tokens = []
|
273 |
+
for r in self.resblocks:
|
274 |
+
idx += 1
|
275 |
+
x = r(x)
|
276 |
+
if idx in out_layers:
|
277 |
+
if isinstance(x, list):
|
278 |
+
out_tokens.append(x[1])
|
279 |
+
else:
|
280 |
+
out_tokens.append(x)
|
281 |
+
|
282 |
+
return [x, x], out_tokens
|
283 |
+
|
284 |
+
def AnomalyCLIP_forward(self, x, out_layers, ffn):
|
285 |
+
idx = 0
|
286 |
+
out_tokens = []
|
287 |
+
for r in self.resblocks:
|
288 |
+
idx += 1
|
289 |
+
x = r(x, ffn = ffn)
|
290 |
+
# print("out_layers", out_layers, idx)
|
291 |
+
if idx in out_layers:
|
292 |
+
if isinstance(x, list):
|
293 |
+
out_tokens.append(x[0])
|
294 |
+
else:
|
295 |
+
out_tokens.append(x)
|
296 |
+
return x, out_tokens
|
297 |
+
|
298 |
+
def forward(self, x: torch.Tensor, out_layers = [6, 12, 18, 24], DPAM_layer = None, ffn = False):
|
299 |
+
# visual encoder forward
|
300 |
+
if not self.text_layer:
|
301 |
+
out_tokens = []
|
302 |
+
|
303 |
+
if DPAM_layer is None:
|
304 |
+
[x, x], out_tokens = self.ori_CLIP_with_patch_forward(x, out_layers)
|
305 |
+
return [x, x], out_tokens
|
306 |
+
else:
|
307 |
+
x, out_tokens = self.AnomalyCLIP_forward(x, out_layers, ffn)
|
308 |
+
return x, out_tokens
|
309 |
+
# text encoder forward
|
310 |
+
# ori text embedding
|
311 |
+
elif self.design_deatails is None:
|
312 |
+
for idx, r in enumerate(self.resblocks):
|
313 |
+
x = r(x)
|
314 |
+
return x
|
315 |
+
# insert learnable text embedding
|
316 |
+
elif self.design_deatails is not None:
|
317 |
+
for idx, r in enumerate(self.resblocks):
|
318 |
+
x = r(x)
|
319 |
+
return x[0]
|
320 |
+
def get_cast_dtype(self) -> torch.dtype:
|
321 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
322 |
+
|
323 |
+
class VisionTransformer(nn.Module):
|
324 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
325 |
+
super().__init__()
|
326 |
+
self.input_resolution = input_resolution
|
327 |
+
self.output_dim = output_dim
|
328 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
329 |
+
|
330 |
+
scale = width ** -0.5
|
331 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
332 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
333 |
+
self.ln_pre = LayerNorm(width)
|
334 |
+
|
335 |
+
self.transformer = Transformer(width, layers, heads, need_weights=True)
|
336 |
+
self.attn = None
|
337 |
+
self.embed_dim = width
|
338 |
+
self.num_heads = heads
|
339 |
+
|
340 |
+
self.ln_post = LayerNorm(width)
|
341 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
342 |
+
|
343 |
+
|
344 |
+
@torch.no_grad()
|
345 |
+
def DAPM_replace(self, DPAM_layer):
|
346 |
+
if DPAM_layer is not None:
|
347 |
+
for i in range(1, DPAM_layer):
|
348 |
+
self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True)
|
349 |
+
self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone()
|
350 |
+
self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone()
|
351 |
+
self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone()
|
352 |
+
self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone()
|
353 |
+
self.transformer.resblocks[-i].attn = self.attn
|
354 |
+
|
355 |
+
@torch.no_grad()
|
356 |
+
def forward(self, x: torch.Tensor, features_list, ori_patch = False, proj_use = True, DPAM_layer = None, ffn = False):
|
357 |
+
|
358 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
359 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
360 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
361 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
362 |
+
side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
|
363 |
+
new_side = int((x.shape[1] - 1) ** 0.5)
|
364 |
+
|
365 |
+
# update the position embedding during inference for varied input size
|
366 |
+
if side != new_side:
|
367 |
+
new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
|
368 |
+
new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
|
369 |
+
new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
|
370 |
+
self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
|
371 |
+
|
372 |
+
pos = self.positional_embedding.to(x.dtype)
|
373 |
+
x = x + pos
|
374 |
+
x = self.ln_pre(x)
|
375 |
+
|
376 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
377 |
+
[x, x_ori], patch_tokens = self.transformer(x, features_list, DPAM_layer = DPAM_layer, ffn = ffn)
|
378 |
+
|
379 |
+
|
380 |
+
if True:
|
381 |
+
patch_token_list = []
|
382 |
+
for patch_token in patch_tokens:
|
383 |
+
patch_token = self.ln_post(patch_token.permute(1, 0, 2)) @ self.proj # LND -> NLD
|
384 |
+
patch_token_list.append(patch_token)
|
385 |
+
patch_tokens = patch_token_list
|
386 |
+
|
387 |
+
return x_ori[0, :, :] @ self.proj, patch_tokens
|
388 |
+
|
389 |
+
|
390 |
+
return x
|
391 |
+
|
392 |
+
|
393 |
+
from thop import profile
|
394 |
+
class AnomalyCLIP(nn.Module):
|
395 |
+
def __init__(self,
|
396 |
+
embed_dim: int,
|
397 |
+
# vision
|
398 |
+
image_resolution: int,
|
399 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
400 |
+
vision_width: int,
|
401 |
+
vision_patch_size: int,
|
402 |
+
# text
|
403 |
+
context_length: int,
|
404 |
+
vocab_size: int,
|
405 |
+
transformer_width: int,
|
406 |
+
transformer_heads: int,
|
407 |
+
transformer_layers: int,
|
408 |
+
design_details = None
|
409 |
+
):
|
410 |
+
super().__init__()
|
411 |
+
|
412 |
+
self.context_length = context_length
|
413 |
+
|
414 |
+
if isinstance(vision_layers, (tuple, list)):
|
415 |
+
vision_heads = vision_width * 32 // 64
|
416 |
+
self.visual = ModifiedResNet(
|
417 |
+
layers=vision_layers,
|
418 |
+
output_dim=embed_dim,
|
419 |
+
heads=vision_heads,
|
420 |
+
input_resolution=image_resolution,
|
421 |
+
width=vision_width
|
422 |
+
)
|
423 |
+
else:
|
424 |
+
vision_heads = vision_width // 64
|
425 |
+
self.visual = VisionTransformer(
|
426 |
+
input_resolution=image_resolution,
|
427 |
+
patch_size=vision_patch_size,
|
428 |
+
width=vision_width,
|
429 |
+
layers=vision_layers,
|
430 |
+
heads=vision_heads,
|
431 |
+
output_dim=embed_dim
|
432 |
+
)
|
433 |
+
|
434 |
+
self.transformer = Transformer(
|
435 |
+
width=transformer_width,
|
436 |
+
layers=transformer_layers,
|
437 |
+
heads=transformer_heads,
|
438 |
+
attn_mask=self.build_attention_mask(), text_layer=True, design_details=design_details
|
439 |
+
)
|
440 |
+
|
441 |
+
self.vocab_size = vocab_size
|
442 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
443 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
444 |
+
self.ln_final = LayerNorm(transformer_width)
|
445 |
+
|
446 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
447 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
448 |
+
|
449 |
+
self.initialize_parameters()
|
450 |
+
|
451 |
+
def initialize_parameters(self):
|
452 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
453 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
454 |
+
|
455 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
456 |
+
attn_std = self.transformer.width ** -0.5
|
457 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
458 |
+
for block in self.transformer.resblocks:
|
459 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
460 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
461 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
462 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
463 |
+
|
464 |
+
if self.text_projection is not None:
|
465 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
466 |
+
def build_attention_mask(self):
|
467 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
468 |
+
# pytorch uses additive attention mask; fill with -inf
|
469 |
+
mask = torch.empty(self.context_length, self.context_length)
|
470 |
+
mask.fill_(float("-inf"))
|
471 |
+
mask.triu_(1) # zero out the lower diagonal
|
472 |
+
return mask
|
473 |
+
|
474 |
+
@property
|
475 |
+
def dtype(self):
|
476 |
+
return self.visual.conv1.weight.dtype
|
477 |
+
|
478 |
+
def encode_image(self, image, feature_list = [], ori_patch = False, proj_use = True, DPAM_layer = None, ffn = False):
|
479 |
+
return self.visual(image.type(self.dtype), feature_list, ori_patch = ori_patch, proj_use = proj_use, DPAM_layer = DPAM_layer, ffn = ffn)
|
480 |
+
|
481 |
+
|
482 |
+
def encode_text(self, text):
|
483 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
484 |
+
|
485 |
+
x = x + self.positional_embedding.type(self.dtype)
|
486 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
487 |
+
x = self.transformer(x)
|
488 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
489 |
+
x = self.ln_final(x).type(self.dtype)
|
490 |
+
|
491 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
492 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
493 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
494 |
+
|
495 |
+
return x
|
496 |
+
|
497 |
+
def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text = None, normalize: bool = False):
|
498 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
499 |
+
|
500 |
+
# x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
501 |
+
|
502 |
+
# x = x + self.positional_embedding.to(cast_dtype)
|
503 |
+
|
504 |
+
x = prompts + self.positional_embedding.to(cast_dtype)
|
505 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
506 |
+
# print("test", x.shape, len(deep_compound_prompts_text))
|
507 |
+
if deep_compound_prompts_text is None:
|
508 |
+
x = self.transformer(x)
|
509 |
+
else:
|
510 |
+
x = self.transformer([x, deep_compound_prompts_text, 0])
|
511 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
512 |
+
x = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, transformer.width]
|
513 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
514 |
+
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
|
515 |
+
return x
|
516 |
+
|
517 |
+
def forward(self, image, text):
|
518 |
+
image_features = self.encode_image(image)
|
519 |
+
text_features = self.encode_text(text)
|
520 |
+
|
521 |
+
# normalized features
|
522 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
523 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
524 |
+
|
525 |
+
# cosine similarity as logits
|
526 |
+
logit_scale = self.logit_scale.exp()
|
527 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
528 |
+
logits_per_text = logits_per_image.t()
|
529 |
+
|
530 |
+
# shape = [global_batch_size, global_batch_size]
|
531 |
+
return logits_per_image, logits_per_text
|
AnomalyCLIP_lib/CLIP.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
class Bottleneck(nn.Module):
|
11 |
+
expansion = 4
|
12 |
+
|
13 |
+
def __init__(self, inplanes, planes, stride=1):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
17 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
18 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
19 |
+
self.relu1 = nn.ReLU(inplace=True)
|
20 |
+
|
21 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
22 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
23 |
+
self.relu2 = nn.ReLU(inplace=True)
|
24 |
+
|
25 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
26 |
+
|
27 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
28 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
29 |
+
self.relu3 = nn.ReLU(inplace=True)
|
30 |
+
|
31 |
+
self.downsample = None
|
32 |
+
self.stride = stride
|
33 |
+
|
34 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
35 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
36 |
+
self.downsample = nn.Sequential(OrderedDict([
|
37 |
+
("-1", nn.AvgPool2d(stride)),
|
38 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
39 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
40 |
+
]))
|
41 |
+
|
42 |
+
def forward(self, x: torch.Tensor):
|
43 |
+
identity = x
|
44 |
+
|
45 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
46 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
47 |
+
out = self.avgpool(out)
|
48 |
+
out = self.bn3(self.conv3(out))
|
49 |
+
|
50 |
+
if self.downsample is not None:
|
51 |
+
identity = self.downsample(x)
|
52 |
+
|
53 |
+
out += identity
|
54 |
+
out = self.relu3(out)
|
55 |
+
return out
|
56 |
+
|
57 |
+
|
58 |
+
class AttentionPool2d(nn.Module):
|
59 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
60 |
+
super().__init__()
|
61 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
62 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
63 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
64 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
65 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
66 |
+
self.num_heads = num_heads
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
70 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
71 |
+
|
72 |
+
side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
|
73 |
+
new_side = int((x.shape[0] - 1) ** 0.5)
|
74 |
+
|
75 |
+
# update the position embedding during inference for varied input size
|
76 |
+
if side != new_side:
|
77 |
+
new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
|
78 |
+
new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
|
79 |
+
new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
|
80 |
+
self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
|
81 |
+
|
82 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
83 |
+
x, _ = F.multi_head_attention_forward(
|
84 |
+
query=x, key=x, value=x,
|
85 |
+
embed_dim_to_check=x.shape[-1],
|
86 |
+
num_heads=self.num_heads,
|
87 |
+
q_proj_weight=self.q_proj.weight,
|
88 |
+
k_proj_weight=self.k_proj.weight,
|
89 |
+
v_proj_weight=self.v_proj.weight,
|
90 |
+
in_proj_weight=None,
|
91 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
92 |
+
bias_k=None,
|
93 |
+
bias_v=None,
|
94 |
+
add_zero_attn=False,
|
95 |
+
dropout_p=0,
|
96 |
+
out_proj_weight=self.c_proj.weight,
|
97 |
+
out_proj_bias=self.c_proj.bias,
|
98 |
+
use_separate_proj_weight=True,
|
99 |
+
training=self.training,
|
100 |
+
need_weights=False
|
101 |
+
)
|
102 |
+
|
103 |
+
#return x[0]
|
104 |
+
return x.transpose(0, 1) # return both cls token and image tokens, B,N,C
|
105 |
+
|
106 |
+
|
107 |
+
class ModifiedResNet(nn.Module):
|
108 |
+
"""
|
109 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
110 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
111 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
112 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
113 |
+
"""
|
114 |
+
|
115 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
116 |
+
super().__init__()
|
117 |
+
self.output_dim = output_dim
|
118 |
+
self.input_resolution = input_resolution
|
119 |
+
|
120 |
+
# the 3-layer stem
|
121 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
122 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
123 |
+
self.relu1 = nn.ReLU(inplace=True)
|
124 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
125 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
126 |
+
self.relu2 = nn.ReLU(inplace=True)
|
127 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
128 |
+
self.bn3 = nn.BatchNorm2d(width)
|
129 |
+
self.relu3 = nn.ReLU(inplace=True)
|
130 |
+
self.avgpool = nn.AvgPool2d(2)
|
131 |
+
|
132 |
+
# residual layers
|
133 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
134 |
+
self.layer1 = self._make_layer(width, layers[0])
|
135 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
136 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
137 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
138 |
+
|
139 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
140 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
141 |
+
|
142 |
+
def _make_layer(self, planes, blocks, stride=1):
|
143 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
144 |
+
|
145 |
+
self._inplanes = planes * Bottleneck.expansion
|
146 |
+
for _ in range(1, blocks):
|
147 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
148 |
+
|
149 |
+
return nn.Sequential(*layers)
|
150 |
+
|
151 |
+
def forward(self, x):
|
152 |
+
def stem(x):
|
153 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
154 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
155 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
156 |
+
x = self.avgpool(x)
|
157 |
+
return x
|
158 |
+
|
159 |
+
x = x.type(self.conv1.weight.dtype)
|
160 |
+
x = stem(x)
|
161 |
+
x = self.layer1(x)
|
162 |
+
x = self.layer2(x)
|
163 |
+
x = self.layer3(x)
|
164 |
+
x = self.layer4(x)
|
165 |
+
x = self.attnpool(x)
|
166 |
+
|
167 |
+
return x
|
168 |
+
|
169 |
+
|
170 |
+
class LayerNorm(nn.LayerNorm):
|
171 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
172 |
+
|
173 |
+
def forward(self, x: torch.Tensor):
|
174 |
+
orig_type = x.dtype
|
175 |
+
ret = super().forward(x.type(torch.float32))
|
176 |
+
return ret.type(orig_type)
|
177 |
+
|
178 |
+
|
179 |
+
class QuickGELU(nn.Module):
|
180 |
+
def forward(self, x: torch.Tensor):
|
181 |
+
return x * torch.sigmoid(1.702 * x)
|
182 |
+
|
183 |
+
|
184 |
+
class ResidualAttentionBlock(nn.Module):
|
185 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
|
186 |
+
super().__init__()
|
187 |
+
|
188 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
189 |
+
self.ln_1 = LayerNorm(d_model)
|
190 |
+
self.mlp = nn.Sequential(OrderedDict([
|
191 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
192 |
+
("gelu", QuickGELU()),
|
193 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
194 |
+
]))
|
195 |
+
self.ln_2 = LayerNorm(d_model)
|
196 |
+
self.attn_mask = attn_mask
|
197 |
+
self.need_weights = need_weights
|
198 |
+
|
199 |
+
def attention(self, x: torch.Tensor):
|
200 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
201 |
+
if self.need_weights == False:
|
202 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
203 |
+
else:
|
204 |
+
return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)
|
205 |
+
|
206 |
+
def forward(self, x: torch.Tensor):
|
207 |
+
if self.need_weights == False:
|
208 |
+
x = x + self.attention(self.ln_1(x))
|
209 |
+
x = x + self.mlp(self.ln_2(x))
|
210 |
+
return x
|
211 |
+
else:
|
212 |
+
y, attn = self.attention(self.ln_1(x))
|
213 |
+
x = x + y
|
214 |
+
x = x + self.mlp(self.ln_2(x))
|
215 |
+
return x
|
216 |
+
|
217 |
+
|
218 |
+
class Transformer(nn.Module):
|
219 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
|
220 |
+
super().__init__()
|
221 |
+
self.width = width
|
222 |
+
self.layers = layers
|
223 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, need_weights if i == layers - 1 else False) for i in range(layers)])
|
224 |
+
|
225 |
+
def forward(self, x: torch.Tensor):
|
226 |
+
return self.resblocks(x)
|
227 |
+
|
228 |
+
def get_cast_dtype(self) -> torch.dtype:
|
229 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
230 |
+
|
231 |
+
|
232 |
+
|
233 |
+
class VisionTransformer(nn.Module):
|
234 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
235 |
+
super().__init__()
|
236 |
+
self.input_resolution = input_resolution
|
237 |
+
self.output_dim = output_dim
|
238 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
239 |
+
|
240 |
+
scale = width ** -0.5
|
241 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
242 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
243 |
+
self.ln_pre = LayerNorm(width)
|
244 |
+
|
245 |
+
self.transformer = Transformer(width, layers, heads, need_weights=True)
|
246 |
+
|
247 |
+
self.ln_post = LayerNorm(width)
|
248 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
249 |
+
|
250 |
+
def forward(self, x: torch.Tensor):
|
251 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
252 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
253 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
254 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
255 |
+
|
256 |
+
#####################################################################################
|
257 |
+
side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
|
258 |
+
new_side = int((x.shape[1] - 1) ** 0.5)
|
259 |
+
|
260 |
+
# update the position embedding during inference for varied input size
|
261 |
+
if side != new_side:
|
262 |
+
new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
|
263 |
+
new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
|
264 |
+
new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
|
265 |
+
self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
|
266 |
+
#####################################################################################
|
267 |
+
|
268 |
+
|
269 |
+
x = x + self.positional_embedding.to(x.dtype)
|
270 |
+
x = self.ln_pre(x)
|
271 |
+
|
272 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
273 |
+
x = self.transformer(x)
|
274 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
275 |
+
|
276 |
+
#x = self.ln_post(x[:, 0, :])
|
277 |
+
x = self.ln_post(x) # return both cls token and image tokens
|
278 |
+
|
279 |
+
if self.proj is not None:
|
280 |
+
x = x @ self.proj
|
281 |
+
|
282 |
+
return x
|
283 |
+
|
284 |
+
|
285 |
+
class CLIP(nn.Module):
|
286 |
+
def __init__(self,
|
287 |
+
embed_dim: int,
|
288 |
+
# vision
|
289 |
+
image_resolution: int,
|
290 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
291 |
+
vision_width: int,
|
292 |
+
vision_patch_size: int,
|
293 |
+
# text
|
294 |
+
context_length: int,
|
295 |
+
vocab_size: int,
|
296 |
+
transformer_width: int,
|
297 |
+
transformer_heads: int,
|
298 |
+
transformer_layers: int
|
299 |
+
):
|
300 |
+
super().__init__()
|
301 |
+
|
302 |
+
self.context_length = context_length
|
303 |
+
|
304 |
+
if isinstance(vision_layers, (tuple, list)):
|
305 |
+
vision_heads = vision_width * 32 // 64
|
306 |
+
self.visual = ModifiedResNet(
|
307 |
+
layers=vision_layers,
|
308 |
+
output_dim=embed_dim,
|
309 |
+
heads=vision_heads,
|
310 |
+
input_resolution=image_resolution,
|
311 |
+
width=vision_width
|
312 |
+
)
|
313 |
+
else:
|
314 |
+
vision_heads = vision_width // 64
|
315 |
+
self.visual = VisionTransformer(
|
316 |
+
input_resolution=image_resolution,
|
317 |
+
patch_size=vision_patch_size,
|
318 |
+
width=vision_width,
|
319 |
+
layers=vision_layers,
|
320 |
+
heads=vision_heads,
|
321 |
+
output_dim=embed_dim
|
322 |
+
)
|
323 |
+
|
324 |
+
self.transformer = Transformer(
|
325 |
+
width=transformer_width,
|
326 |
+
layers=transformer_layers,
|
327 |
+
heads=transformer_heads,
|
328 |
+
attn_mask=self.build_attention_mask()
|
329 |
+
)
|
330 |
+
|
331 |
+
self.vocab_size = vocab_size
|
332 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
333 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
334 |
+
self.ln_final = LayerNorm(transformer_width)
|
335 |
+
|
336 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
337 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
338 |
+
|
339 |
+
self.initialize_parameters()
|
340 |
+
|
341 |
+
def initialize_parameters(self):
|
342 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
343 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
344 |
+
|
345 |
+
if isinstance(self.visual, ModifiedResNet):
|
346 |
+
if self.visual.attnpool is not None:
|
347 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
348 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
349 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
350 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
351 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
352 |
+
|
353 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
354 |
+
for name, param in resnet_block.named_parameters():
|
355 |
+
if name.endswith("bn3.weight"):
|
356 |
+
nn.init.zeros_(param)
|
357 |
+
|
358 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
359 |
+
attn_std = self.transformer.width ** -0.5
|
360 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
361 |
+
for block in self.transformer.resblocks:
|
362 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
363 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
364 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
365 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
366 |
+
|
367 |
+
if self.text_projection is not None:
|
368 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
369 |
+
|
370 |
+
def build_attention_mask(self):
|
371 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
372 |
+
# pytorch uses additive attention mask; fill with -inf
|
373 |
+
mask = torch.empty(self.context_length, self.context_length)
|
374 |
+
mask.fill_(float("-inf"))
|
375 |
+
mask.triu_(1) # zero out the lower diagonal
|
376 |
+
return mask
|
377 |
+
|
378 |
+
@property
|
379 |
+
def dtype(self):
|
380 |
+
return self.visual.conv1.weight.dtype
|
381 |
+
|
382 |
+
def encode_image(self, image):
|
383 |
+
return self.visual(image.type(self.dtype))
|
384 |
+
|
385 |
+
def encode_text(self, text):
|
386 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
387 |
+
|
388 |
+
x = x + self.positional_embedding.type(self.dtype)
|
389 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
390 |
+
x = self.transformer(x)
|
391 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
392 |
+
x = self.ln_final(x).type(self.dtype)
|
393 |
+
|
394 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
395 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
396 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
397 |
+
|
398 |
+
return x
|
399 |
+
|
400 |
+
def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text = None, normalize: bool = False):
|
401 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
402 |
+
|
403 |
+
# x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
404 |
+
|
405 |
+
# x = x + self.positional_embedding.to(cast_dtype)
|
406 |
+
|
407 |
+
x = prompts + self.positional_embedding.to(cast_dtype)
|
408 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
409 |
+
# print("test", x.shape, len(deep_compound_prompts_text))
|
410 |
+
if deep_compound_prompts_text is None:
|
411 |
+
x = self.transformer(x)
|
412 |
+
else:
|
413 |
+
x = self.transformer([x, deep_compound_prompts_text, 0])
|
414 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
415 |
+
x = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, transformer.width]
|
416 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
417 |
+
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
|
418 |
+
return x
|
419 |
+
|
420 |
+
|
421 |
+
|
422 |
+
def forward(self, image, text):
|
423 |
+
image_features = self.encode_image(image)
|
424 |
+
text_features = self.encode_text(text)
|
425 |
+
|
426 |
+
# normalized features
|
427 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
428 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
429 |
+
|
430 |
+
# cosine similarity as logits
|
431 |
+
logit_scale = self.logit_scale.exp()
|
432 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
433 |
+
logits_per_text = logits_per_image.t()
|
434 |
+
|
435 |
+
# shape = [global_batch_size, global_batch_size]
|
436 |
+
return logits_per_image, logits_per_text
|
AnomalyCLIP_lib/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model_load import *
|
AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
AnomalyCLIP_lib/build_model.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
from .CLIP import CLIP
|
3 |
+
from .AnomalyCLIP import AnomalyCLIP
|
4 |
+
|
5 |
+
def build_model(name: str, state_dict: dict, design_details = None):
|
6 |
+
vit = "visual.proj" in state_dict
|
7 |
+
|
8 |
+
if vit:
|
9 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
10 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
11 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
12 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
13 |
+
image_resolution = vision_patch_size * grid_size
|
14 |
+
else:
|
15 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
16 |
+
vision_layers = tuple(counts)
|
17 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
18 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
19 |
+
vision_patch_size = None
|
20 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
21 |
+
image_resolution = output_width * 32
|
22 |
+
|
23 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
24 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
25 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
26 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
27 |
+
transformer_heads = transformer_width // 64
|
28 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
29 |
+
# print('name', name)
|
30 |
+
# if 'CS-' in name:
|
31 |
+
if design_details is not None:
|
32 |
+
model = AnomalyCLIP(
|
33 |
+
embed_dim,
|
34 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
35 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, design_details = design_details
|
36 |
+
)
|
37 |
+
else:
|
38 |
+
model = CLIP(
|
39 |
+
embed_dim,
|
40 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
41 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
42 |
+
)
|
43 |
+
|
44 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
45 |
+
if key in state_dict:
|
46 |
+
del state_dict[key]
|
47 |
+
|
48 |
+
#convert_weights(model)
|
49 |
+
model.load_state_dict(state_dict)
|
50 |
+
return model.eval()
|
AnomalyCLIP_lib/constants.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
2 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
AnomalyCLIP_lib/model_load.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import warnings
|
5 |
+
from typing import Union, List
|
6 |
+
from pkg_resources import packaging
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
|
11 |
+
from tqdm import tqdm
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from .build_model import build_model
|
15 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
16 |
+
from torchvision.transforms import InterpolationMode
|
17 |
+
|
18 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
19 |
+
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
20 |
+
|
21 |
+
|
22 |
+
__all__ = ["available_models", "load",
|
23 |
+
"get_similarity_map", "compute_similarity"]
|
24 |
+
_tokenizer = _Tokenizer()
|
25 |
+
|
26 |
+
_MODELS = {
|
27 |
+
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
def _download(
|
32 |
+
url: str,
|
33 |
+
cache_dir: Union[str, None] = None,
|
34 |
+
):
|
35 |
+
|
36 |
+
if not cache_dir:
|
37 |
+
# cache_dir = os.path.expanduser("~/.cache/clip")
|
38 |
+
cache_dir = os.path.expanduser("/remote-home/iot_zhouqihang/root/.cache/clip")
|
39 |
+
os.makedirs(cache_dir, exist_ok=True)
|
40 |
+
filename = os.path.basename(url)
|
41 |
+
|
42 |
+
if 'openaipublic' in url:
|
43 |
+
expected_sha256 = url.split("/")[-2]
|
44 |
+
elif 'mlfoundations' in url:
|
45 |
+
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
|
46 |
+
else:
|
47 |
+
expected_sha256 = ''
|
48 |
+
|
49 |
+
download_target = os.path.join(cache_dir, filename)
|
50 |
+
|
51 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
52 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
53 |
+
|
54 |
+
if os.path.isfile(download_target):
|
55 |
+
if expected_sha256:
|
56 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
57 |
+
return download_target
|
58 |
+
else:
|
59 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
60 |
+
else:
|
61 |
+
return download_target
|
62 |
+
|
63 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
64 |
+
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
|
65 |
+
while True:
|
66 |
+
buffer = source.read(8192)
|
67 |
+
if not buffer:
|
68 |
+
break
|
69 |
+
|
70 |
+
output.write(buffer)
|
71 |
+
loop.update(len(buffer))
|
72 |
+
|
73 |
+
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
74 |
+
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
75 |
+
|
76 |
+
return download_target
|
77 |
+
|
78 |
+
|
79 |
+
def _convert_image_to_rgb(image):
|
80 |
+
return image.convert("RGB")
|
81 |
+
|
82 |
+
|
83 |
+
def _transform(n_px):
|
84 |
+
return Compose([
|
85 |
+
Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC),
|
86 |
+
#CenterCrop(n_px), # rm center crop to explain whole image
|
87 |
+
_convert_image_to_rgb,
|
88 |
+
ToTensor(),
|
89 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
90 |
+
])
|
91 |
+
|
92 |
+
|
93 |
+
def available_models() -> List[str]:
|
94 |
+
"""Returns the names of available CLIP models"""
|
95 |
+
return list(_MODELS.keys())
|
96 |
+
|
97 |
+
|
98 |
+
def load_state_dict(checkpoint_path: str, map_location='cpu'):
|
99 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
100 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
101 |
+
state_dict = checkpoint['state_dict']
|
102 |
+
else:
|
103 |
+
state_dict = checkpoint
|
104 |
+
if next(iter(state_dict.items()))[0].startswith('module'):
|
105 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
106 |
+
return state_dict
|
107 |
+
|
108 |
+
def load_checkpoint(model, checkpoint_path, strict=True):
|
109 |
+
state_dict = load_state_dict(checkpoint_path)
|
110 |
+
# detect old format and make compatible with new format
|
111 |
+
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
112 |
+
state_dict = convert_to_custom_text_state_dict(state_dict)
|
113 |
+
resize_pos_embed(state_dict, model)
|
114 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
115 |
+
return incompatible_keys
|
116 |
+
|
117 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", design_details = None, jit: bool = False, download_root: str = None):
|
118 |
+
"""Load a CLIP model
|
119 |
+
|
120 |
+
Parameters
|
121 |
+
----------
|
122 |
+
name : str
|
123 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
124 |
+
|
125 |
+
device : Union[str, torch.device]
|
126 |
+
The device to put the loaded model
|
127 |
+
|
128 |
+
jit : bool
|
129 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
130 |
+
|
131 |
+
download_root: str
|
132 |
+
path to download the model files; by default, it uses "~/.cache/clip"
|
133 |
+
|
134 |
+
Returns
|
135 |
+
-------
|
136 |
+
model : torch.nn.Module
|
137 |
+
The CLIP model
|
138 |
+
|
139 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
140 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
141 |
+
"""
|
142 |
+
print("name", name)
|
143 |
+
if name in _MODELS:
|
144 |
+
# model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
145 |
+
model_path = _download(_MODELS[name], download_root or os.path.expanduser("/remote-home/iot_zhouqihang/root/.cache/clip"))
|
146 |
+
elif os.path.isfile(name):
|
147 |
+
model_path = name
|
148 |
+
else:
|
149 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
150 |
+
|
151 |
+
with open(model_path, 'rb') as opened_file:
|
152 |
+
try:
|
153 |
+
# loading JIT archive
|
154 |
+
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
155 |
+
state_dict = None
|
156 |
+
except RuntimeError:
|
157 |
+
# loading saved state dict
|
158 |
+
if jit:
|
159 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
160 |
+
jit = False
|
161 |
+
state_dict = torch.load(opened_file, map_location="cpu")
|
162 |
+
|
163 |
+
if not jit:
|
164 |
+
model = build_model(name, state_dict or model.state_dict(), design_details).to(device)
|
165 |
+
if str(device) == "cpu":
|
166 |
+
model.float()
|
167 |
+
return model, _transform(model.visual.input_resolution)
|
168 |
+
|
169 |
+
# patch the device names
|
170 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
171 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
172 |
+
|
173 |
+
def patch_device(module):
|
174 |
+
try:
|
175 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
176 |
+
except RuntimeError:
|
177 |
+
graphs = []
|
178 |
+
|
179 |
+
if hasattr(module, "forward1"):
|
180 |
+
graphs.append(module.forward1.graph)
|
181 |
+
|
182 |
+
for graph in graphs:
|
183 |
+
for node in graph.findAllNodes("prim::Constant"):
|
184 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
185 |
+
node.copyAttributes(device_node)
|
186 |
+
|
187 |
+
model.apply(patch_device)
|
188 |
+
patch_device(model.encode_image)
|
189 |
+
patch_device(model.encode_text)
|
190 |
+
|
191 |
+
# patch dtype to float32 on CPU
|
192 |
+
if str(device) == "cpu":
|
193 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
194 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
195 |
+
float_node = float_input.node()
|
196 |
+
|
197 |
+
def patch_float(module):
|
198 |
+
try:
|
199 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
200 |
+
except RuntimeError:
|
201 |
+
graphs = []
|
202 |
+
|
203 |
+
if hasattr(module, "forward1"):
|
204 |
+
graphs.append(module.forward1.graph)
|
205 |
+
|
206 |
+
for graph in graphs:
|
207 |
+
for node in graph.findAllNodes("aten::to"):
|
208 |
+
inputs = list(node.inputs())
|
209 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
210 |
+
if inputs[i].node()["value"] == 5:
|
211 |
+
inputs[i].node().copyAttributes(float_node)
|
212 |
+
|
213 |
+
model.apply(patch_float)
|
214 |
+
patch_float(model.encode_image)
|
215 |
+
patch_float(model.encode_text)
|
216 |
+
|
217 |
+
model.float()
|
218 |
+
|
219 |
+
return model, _transform(model.input_resolution.item())
|
220 |
+
|
221 |
+
|
222 |
+
def get_similarity_map(sm, shape):
|
223 |
+
side = int(sm.shape[1] ** 0.5)
|
224 |
+
sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)
|
225 |
+
sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
|
226 |
+
sm = sm.permute(0, 2, 3, 1)
|
227 |
+
return sm
|
228 |
+
|
229 |
+
|
230 |
+
def compute_similarity(image_features, text_features, t=2):
|
231 |
+
prob_1 = image_features[:, :1, :] @ text_features.t()
|
232 |
+
b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]
|
233 |
+
feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)
|
234 |
+
similarity = feats.sum(-1)
|
235 |
+
return (similarity/0.07).softmax(-1), prob_1
|
AnomalyCLIP_lib/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gzip
|
2 |
+
import html
|
3 |
+
import os
|
4 |
+
from functools import lru_cache
|
5 |
+
|
6 |
+
import ftfy
|
7 |
+
import regex as re
|
8 |
+
|
9 |
+
|
10 |
+
@lru_cache()
|
11 |
+
def default_bpe():
|
12 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
13 |
+
|
14 |
+
|
15 |
+
@lru_cache()
|
16 |
+
def bytes_to_unicode():
|
17 |
+
"""
|
18 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
19 |
+
The reversible bpe codes work on unicode strings.
|
20 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
21 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
22 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
23 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
24 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
25 |
+
"""
|
26 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
27 |
+
cs = bs[:]
|
28 |
+
n = 0
|
29 |
+
for b in range(2**8):
|
30 |
+
if b not in bs:
|
31 |
+
bs.append(b)
|
32 |
+
cs.append(2**8+n)
|
33 |
+
n += 1
|
34 |
+
cs = [chr(n) for n in cs]
|
35 |
+
return dict(zip(bs, cs))
|
36 |
+
|
37 |
+
|
38 |
+
def get_pairs(word):
|
39 |
+
"""Return set of symbol pairs in a word.
|
40 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
41 |
+
"""
|
42 |
+
pairs = set()
|
43 |
+
prev_char = word[0]
|
44 |
+
for char in word[1:]:
|
45 |
+
pairs.add((prev_char, char))
|
46 |
+
prev_char = char
|
47 |
+
return pairs
|
48 |
+
|
49 |
+
|
50 |
+
def basic_clean(text):
|
51 |
+
text = ftfy.fix_text(text)
|
52 |
+
text = html.unescape(html.unescape(text))
|
53 |
+
return text.strip()
|
54 |
+
|
55 |
+
|
56 |
+
def whitespace_clean(text):
|
57 |
+
text = re.sub(r'\s+', ' ', text)
|
58 |
+
text = text.strip()
|
59 |
+
return text
|
60 |
+
|
61 |
+
|
62 |
+
class SimpleTokenizer(object):
|
63 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
64 |
+
self.byte_encoder = bytes_to_unicode()
|
65 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
66 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
67 |
+
merges = merges[1:49152-256-2+1]
|
68 |
+
merges = [tuple(merge.split()) for merge in merges]
|
69 |
+
vocab = list(bytes_to_unicode().values())
|
70 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
71 |
+
for merge in merges:
|
72 |
+
vocab.append(''.join(merge))
|
73 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
74 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
75 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
76 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
77 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
78 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
79 |
+
|
80 |
+
def bpe(self, token):
|
81 |
+
if token in self.cache:
|
82 |
+
return self.cache[token]
|
83 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
84 |
+
pairs = get_pairs(word)
|
85 |
+
|
86 |
+
if not pairs:
|
87 |
+
return token+'</w>'
|
88 |
+
|
89 |
+
while True:
|
90 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
91 |
+
if bigram not in self.bpe_ranks:
|
92 |
+
break
|
93 |
+
first, second = bigram
|
94 |
+
new_word = []
|
95 |
+
i = 0
|
96 |
+
while i < len(word):
|
97 |
+
try:
|
98 |
+
j = word.index(first, i)
|
99 |
+
new_word.extend(word[i:j])
|
100 |
+
i = j
|
101 |
+
except:
|
102 |
+
new_word.extend(word[i:])
|
103 |
+
break
|
104 |
+
|
105 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
106 |
+
new_word.append(first+second)
|
107 |
+
i += 2
|
108 |
+
else:
|
109 |
+
new_word.append(word[i])
|
110 |
+
i += 1
|
111 |
+
new_word = tuple(new_word)
|
112 |
+
word = new_word
|
113 |
+
if len(word) == 1:
|
114 |
+
break
|
115 |
+
else:
|
116 |
+
pairs = get_pairs(word)
|
117 |
+
word = ' '.join(word)
|
118 |
+
self.cache[token] = word
|
119 |
+
return word
|
120 |
+
|
121 |
+
def encode(self, text):
|
122 |
+
bpe_tokens = []
|
123 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
124 |
+
for token in re.findall(self.pat, text):
|
125 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
126 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
127 |
+
return bpe_tokens
|
128 |
+
|
129 |
+
def decode(self, tokens):
|
130 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
131 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
132 |
+
return text
|
AnomalyCLIP_lib/transform.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from dataclasses import dataclass, asdict
|
3 |
+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torchvision.transforms.functional as F
|
8 |
+
|
9 |
+
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
10 |
+
CenterCrop
|
11 |
+
|
12 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class AugmentationCfg:
|
17 |
+
scale: Tuple[float, float] = (0.9, 1.0)
|
18 |
+
ratio: Optional[Tuple[float, float]] = None
|
19 |
+
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
|
20 |
+
interpolation: Optional[str] = None
|
21 |
+
re_prob: Optional[float] = None
|
22 |
+
re_count: Optional[int] = None
|
23 |
+
use_timm: bool = False
|
24 |
+
|
25 |
+
|
26 |
+
class ResizeMaxSize(nn.Module):
|
27 |
+
|
28 |
+
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
|
29 |
+
super().__init__()
|
30 |
+
if not isinstance(max_size, int):
|
31 |
+
raise TypeError(f"Size should be int. Got {type(max_size)}")
|
32 |
+
self.max_size = max_size
|
33 |
+
self.interpolation = interpolation
|
34 |
+
self.fn = min if fn == 'min' else min
|
35 |
+
self.fill = fill
|
36 |
+
|
37 |
+
def forward(self, img):
|
38 |
+
if isinstance(img, torch.Tensor):
|
39 |
+
height, width = img.shape[:2]
|
40 |
+
else:
|
41 |
+
width, height = img.size
|
42 |
+
scale = self.max_size / float(max(height, width))
|
43 |
+
if scale != 1.0:
|
44 |
+
new_size = tuple(round(dim * scale) for dim in (height, width))
|
45 |
+
img = F.resize(img, new_size, self.interpolation)
|
46 |
+
pad_h = self.max_size - new_size[0]
|
47 |
+
pad_w = self.max_size - new_size[1]
|
48 |
+
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
|
49 |
+
return img
|
50 |
+
|
51 |
+
|
52 |
+
def _convert_to_rgb(image):
|
53 |
+
return image.convert('RGB')
|
54 |
+
|
55 |
+
|
56 |
+
def image_transform(
|
57 |
+
image_size: int,
|
58 |
+
is_train: bool,
|
59 |
+
mean: Optional[Tuple[float, ...]] = None,
|
60 |
+
std: Optional[Tuple[float, ...]] = None,
|
61 |
+
resize_longest_max: bool = False,
|
62 |
+
fill_color: int = 0,
|
63 |
+
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
64 |
+
):
|
65 |
+
mean = mean or OPENAI_DATASET_MEAN
|
66 |
+
if not isinstance(mean, (list, tuple)):
|
67 |
+
mean = (mean,) * 3
|
68 |
+
|
69 |
+
std = std or OPENAI_DATASET_STD
|
70 |
+
if not isinstance(std, (list, tuple)):
|
71 |
+
std = (std,) * 3
|
72 |
+
|
73 |
+
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
74 |
+
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
75 |
+
image_size = image_size[0]
|
76 |
+
|
77 |
+
if isinstance(aug_cfg, dict):
|
78 |
+
aug_cfg = AugmentationCfg(**aug_cfg)
|
79 |
+
else:
|
80 |
+
aug_cfg = aug_cfg or AugmentationCfg()
|
81 |
+
normalize = Normalize(mean=mean, std=std)
|
82 |
+
if is_train:
|
83 |
+
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
84 |
+
use_timm = aug_cfg_dict.pop('use_timm', False)
|
85 |
+
if use_timm:
|
86 |
+
from timm.data import create_transform # timm can still be optional
|
87 |
+
if isinstance(image_size, (tuple, list)):
|
88 |
+
assert len(image_size) >= 2
|
89 |
+
input_size = (3,) + image_size[-2:]
|
90 |
+
else:
|
91 |
+
input_size = (3, image_size, image_size)
|
92 |
+
# by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
|
93 |
+
aug_cfg_dict.setdefault('interpolation', 'random')
|
94 |
+
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
|
95 |
+
train_transform = create_transform(
|
96 |
+
input_size=input_size,
|
97 |
+
is_training=True,
|
98 |
+
hflip=0.,
|
99 |
+
mean=mean,
|
100 |
+
std=std,
|
101 |
+
re_mode='pixel',
|
102 |
+
**aug_cfg_dict,
|
103 |
+
)
|
104 |
+
else:
|
105 |
+
train_transform = Compose([
|
106 |
+
RandomResizedCrop(
|
107 |
+
image_size,
|
108 |
+
scale=aug_cfg_dict.pop('scale'),
|
109 |
+
interpolation=InterpolationMode.BICUBIC,
|
110 |
+
),
|
111 |
+
_convert_to_rgb,
|
112 |
+
ToTensor(),
|
113 |
+
normalize,
|
114 |
+
])
|
115 |
+
if aug_cfg_dict:
|
116 |
+
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
|
117 |
+
return train_transform
|
118 |
+
else:
|
119 |
+
if resize_longest_max:
|
120 |
+
transforms = [
|
121 |
+
ResizeMaxSize(image_size, fill=fill_color)
|
122 |
+
]
|
123 |
+
else:
|
124 |
+
transforms = [
|
125 |
+
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
126 |
+
CenterCrop(image_size),
|
127 |
+
]
|
128 |
+
transforms.extend([
|
129 |
+
_convert_to_rgb,
|
130 |
+
ToTensor(),
|
131 |
+
normalize,
|
132 |
+
])
|
133 |
+
return Compose(transforms)
|
Dockerfile
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------
|
2 |
+
# A sample Dockerfile to help you replicate our test environment
|
3 |
+
# -----------------------------------------------------------------------------
|
4 |
+
|
5 |
+
FROM pytorch/pytorch:2.4.1-cuda12.4-cudnn9-runtime
|
6 |
+
WORKDIR /app
|
7 |
+
COPY . .
|
8 |
+
|
9 |
+
# Install your python and apt requirements
|
10 |
+
RUN pip install -r requirements.txt
|
11 |
+
RUN apt-get update && apt-get install $(cat apt_requirements.txt) -y
|
12 |
+
RUN chmod +x run.sh
|
13 |
+
|
14 |
+
CMD ["python3", "runner.py"]
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Qihang Zhou
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AnomalyCLIP (Train once and test other)
|
2 |
+
> [**ICLR 24**] [**AnomalyCLIP: Object-agnostic Prompt Learning for Zero-shot Anomaly Detection**](https://arxiv.org/pdf/2310.18961.pdf)
|
3 |
+
>
|
4 |
+
> by [Qihang Zhou*](), [Guansong Pang*](https://www.guansongpang.com/), [Yu Tian](https://yutianyt.com/), [Shibo He](https://scholar.google.com/citations?hl=zh-CN&user=5GOcb4gAAAAJ&view_op=list_works&sortby=pubdate), [Jiming Chen](https://scholar.google.com/citations?user=zK9tvo8AAAAJ&hl=zh-CN).
|
5 |
+
|
6 |
+
|
7 |
+
## Updates
|
8 |
+
|
9 |
+
- **03.19.2024**: Code has been released !!!
|
10 |
+
- **08.08.2024**: Update the code for testing one image.
|
11 |
+
|
12 |
+
## Introduction
|
13 |
+
Zero-shot anomaly detection (ZSAD) requires detection models trained using auxiliary data to detect anomalies without any training sample in a target dataset. It is a crucial task when training data is not accessible due to various concerns, e.g., data privacy, yet it is challenging since the models need to generalize to anomalies across different domains where the appearance of foreground objects, abnormal regions, and background features, such as defects/tumors on different products/organs, can vary significantly. Recently large pre-trained vision-language models (VLMs), such as CLIP,
|
14 |
+
have demonstrated strong zero-shot recognition ability in various vision tasks, including anomaly detection. However, their ZSAD performance is weak since the VLMs focus more on modeling the class semantics of the foreground objects rather than the abnormality/normality in the images.
|
15 |
+
In this paper we introduce a novel approach, namely AnomalyCLIP, to adapt CLIP for accurate ZSAD across different domains. The key insight of AnomalyCLIP is to learn object-agnostic text prompts that capture generic normality and abnormality in an image regardless of its foreground objects. This allows our model to focus on the abnormal image regions rather than the object semantics, enabling generalized normality and abnormality recognition on diverse types of objects. Large-scale experiments on 17 real-world anomaly detection datasets show that AnomalyCLIP achieves superior zero-shot performance of detecting and segmenting anomalies in datasets of highly diverse class semantics from various defect inspection and medical imaging domains. All experiments are conducted in PyTorch-2.0.0 with a single NVIDIA RTX 3090 24GB.
|
16 |
+
|
17 |
+
## Overview of AnomalyCLIP
|
18 |
+

|
19 |
+
|
20 |
+
|
21 |
+
## Analysis of different text prompt templates
|
22 |
+

|
23 |
+
|
24 |
+
|
25 |
+
## How to Run
|
26 |
+
### Prepare your dataset
|
27 |
+
Download the dataset below:
|
28 |
+
|
29 |
+
* Industrial Domain:
|
30 |
+
[MVTec](https://www.mvtec.com/company/research/datasets/mvtec-ad), [VisA](https://github.com/amazon-science/spot-diff), [MPDD](https://github.com/stepanje/MPDD), [BTAD](http://avires.dimi.uniud.it/papers/btad/btad.zip), [SDD](https://www.vicos.si/resources/kolektorsdd/), [DAGM](https://www.kaggle.com/datasets/mhskjelvareid/dagm-2007-competition-dataset-optical-inspection), [DTD-Synthetic](https://drive.google.com/drive/folders/10OyPzvI3H6llCZBxKxFlKWt1Pw1tkMK1)
|
31 |
+
|
32 |
+
* Medical Domain:
|
33 |
+
[HeadCT](https://www.kaggle.com/datasets/felipekitamura/head-ct-hemorrhage), [BrainMRI](https://www.kaggle.com/datasets/navoneel/brain-mri-images-for-brain-tumor-detection), [Br35H](https://www.kaggle.com/datasets/ahmedhamada0/brain-tumor-detection), [COVID-19](https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database), [ISIC](https://isic-challenge-data.s3.amazonaws.com/2016/ISBI2016_ISIC_Part1_Test_Data.zip), [CVC-ColonDB](https://figshare.com/articles/figure/Polyp_DataSet_zip/21221579), [CVC-ClinicDB](https://figshare.com/articles/figure/Polyp_DataSet_zip/21221579), [Kvasir](https://figshare.com/articles/figure/Polyp_DataSet_zip/21221579), [Endo](https://drive.google.com/file/d/1LNpLkv5ZlEUzr_RPN5rdOHaqk0SkZa3m/view), [TN3K](https://github.com/haifangong/TRFE-Net-for-thyroid-nodule-segmentation?tab=readme-ov-file).
|
34 |
+
|
35 |
+
* Google Drive link (frequently requested dataset): [SDD](https://drive.google.com/drive/folders/1oqaxUZYi44jlLT4WtT6D5T6onPTNZXsu?usp=drive_link), [Br35H](https://drive.google.com/file/d/1l9XODMBm4X23K70LtpxAxgoaBbNzr4Nc/view?usp=drive_link), [COVID-19](https://drive.google.com/file/d/1ECwI8DJmhEtcVHatxCAdFqnSmXs35WFL/view?usp=drive_link)
|
36 |
+
### Generate the dataset JSON
|
37 |
+
Take MVTec AD for example (With multiple anomaly categories)
|
38 |
+
|
39 |
+
Structure of MVTec Folder:
|
40 |
+
```
|
41 |
+
mvtec/
|
42 |
+
│
|
43 |
+
├── meta.json
|
44 |
+
│
|
45 |
+
├── bottle/
|
46 |
+
│ ├── ground_truth/
|
47 |
+
│ │ ├── broken_large/
|
48 |
+
│ │ │ └── 000_mask.png
|
49 |
+
| | | └── ...
|
50 |
+
│ │ └── ...
|
51 |
+
│ └── test/
|
52 |
+
│ ├── broken_large/
|
53 |
+
│ │ └── 000.png
|
54 |
+
| | └── ...
|
55 |
+
│ └── ...
|
56 |
+
│
|
57 |
+
└── ...
|
58 |
+
```
|
59 |
+
|
60 |
+
```bash
|
61 |
+
cd generate_dataset_json
|
62 |
+
python mvtec.py
|
63 |
+
```
|
64 |
+
|
65 |
+
Take SDD for example (With single anomaly category)
|
66 |
+
|
67 |
+
Structure of SDD Folder:
|
68 |
+
```
|
69 |
+
SDD/
|
70 |
+
│
|
71 |
+
├── electrical_commutators/
|
72 |
+
│ └── test/
|
73 |
+
│ ├─��� defect/
|
74 |
+
│ │ └── kos01_Part5_0.png
|
75 |
+
| | └── ...
|
76 |
+
│ └── good/
|
77 |
+
│ └── kos01_Part0_0.png
|
78 |
+
│ └── ...
|
79 |
+
│
|
80 |
+
└── meta.json
|
81 |
+
```
|
82 |
+
|
83 |
+
```bash
|
84 |
+
cd generate_dataset_json
|
85 |
+
python SDD.py
|
86 |
+
```
|
87 |
+
Select the corresponding script and run it (we provide all scripts for datasets that AnomalyCLIP reported). The generated JSON stores all the information that AnomalyCLIP needs.
|
88 |
+
|
89 |
+
### Custom dataset (optional)
|
90 |
+
1. Create a new JSON script in fold [generate_dataset_json](https://github.com/zqhang/AnomalyCLIP/tree/main/generate_dataset_json) according to the fold structure of your own datasets.
|
91 |
+
2. Add the related info of your dataset (i.e., dataset name and class names) in script [dataset\.py](https://github.com/zqhang/AnomalyCLIP/blob/main/dataset.py)
|
92 |
+
|
93 |
+
### Run AnomalyCLIP
|
94 |
+
* Quick start (use the pre-trained weights)
|
95 |
+
```bash
|
96 |
+
bash test.sh
|
97 |
+
```
|
98 |
+
|
99 |
+
* Train your own weights
|
100 |
+
```bash
|
101 |
+
bash train.sh
|
102 |
+
```
|
103 |
+
|
104 |
+
|
105 |
+
## Main results (We test all datasets by training once on MVTec AD. For MVTec AD, AnomalyCLIP is trained on VisA.)
|
106 |
+
|
107 |
+
### Industrial dataset
|
108 |
+

|
109 |
+
|
110 |
+
|
111 |
+
### Medical dataset
|
112 |
+

|
113 |
+
|
114 |
+
|
115 |
+
## Visualization
|
116 |
+
|
117 |
+

|
118 |
+
|
119 |
+

|
120 |
+
|
121 |
+

|
122 |
+
|
123 |
+

|
124 |
+
|
125 |
+
|
126 |
+
## We provide the reproduction of WinCLIP [here](https://github.com/zqhang/WinCLIP-pytorch)
|
127 |
+
|
128 |
+
|
129 |
+
* We thank for the code repository: [open_clip](https://github.com/mlfoundations/open_clip), [DualCoOp](https://github.com/sunxm2357/DualCoOp), [CLIP_Surgery](https://github.com/xmed-lab/CLIP_Surgery), and [VAND](https://github.com/ByChelsea/VAND-APRIL-GAN/tree/master).
|
130 |
+
|
131 |
+
## BibTex Citation
|
132 |
+
|
133 |
+
If you find this paper and repository useful, please cite our paper.
|
134 |
+
|
135 |
+
```
|
136 |
+
@inproceedings{zhou2023anomalyclip,
|
137 |
+
title={AnomalyCLIP: Object-agnostic Prompt Learning for Zero-shot Anomaly Detection},
|
138 |
+
author={Zhou, Qihang and Pang, Guansong and Tian, Yu and He, Shibo and Chen, Jiming},
|
139 |
+
booktitle={The Twelfth International Conference on Learning Representations},
|
140 |
+
year={2023}
|
141 |
+
}
|
142 |
+
```
|
checkpoints/9_12_4_multiscale/epoch_1.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a89d1ffe49d86995e936c8e91515efa878d4e1777c73888622091e89a8df9e5b
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale/epoch_10.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7205c05df3319984b349686cbfd8cc01d3ac241a82f33943e9217cbb85604b0b
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale/epoch_11.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:40017b0588b3e41aea4cf3902b388bbee494201b4406583f0a9c96f90818a986
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale/epoch_12.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef4bdfad5689797d48296eeceb57343aabba5ae5a2c7e57d4b9e225d2d254252
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale/epoch_13.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4381596b44bbaa33e7b04b4a19a46582980f1ee8742414d71147c8be95ef90d7
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale/epoch_14.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fd2a3865c4cf1363b80f301da7dc181a54787e3c218cc1f3464650a5f749cb26
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale/epoch_15.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:94ce202da3e6486a864b904fdfed5057de75846c5834e446fd1d2fe7f97acb44
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale/epoch_2.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f6bfcd2ed1725b3d58dd06d5d38f7ef6d3b9c49d817bb4714a16f3153c3d7450
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale/epoch_3.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5af4c383158732845ac2ef195e5036e8528f187ed80173c8d993830a0abed64c
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale/epoch_4.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ab9a9909711c89cac5f02f0c46c7baac82b09bfaca59a83271a50b195cad89f
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale/epoch_5.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:317837a0ef5b46d2476c234d3fa77e8cfab7bbfa85711f5fe7eb7f50ea7151a0
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale/epoch_6.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:04379155c0df8d4e1194335427091e626df512a9747e47c1bbb7ee3a55708164
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale/epoch_7.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:41c5a77a355c27266d6a9c7b6da4b3ee2c193596873d889822e68a797a2688b2
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale/epoch_8.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c92bfa088eccb2efb71b27c9703c0f21158903581efd7292f42938ad96940c82
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale/epoch_9.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:43f0eca2d506b88370a06c94a6cd557360c7bcb179a4f3f24981230349a9581a
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale/log.txt
ADDED
File without changes
|
checkpoints/9_12_4_multiscale_visa/epoch_1.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:de5df7fc2ec18acb5709e65b1889d586974d365c39d1aa4df728336633e4ee70
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale_visa/epoch_10.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:397255934bd313beeab2b610fa901f113e12342974687147cad78f502e5ae7e5
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale_visa/epoch_11.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:843fb9df1c46da89f6976a42d10d5fe34675ad48eccb365e3f43785f925c2ae9
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale_visa/epoch_12.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:17f69ad9ae4bcc5823fdd9ad56b51ec57cc641270280a1776c1014ea1969f282
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale_visa/epoch_13.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5bf5fd9c269e3f68e81134f4361c3239ba14d5f2cd4e3564f93f5b59f616cd19
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale_visa/epoch_14.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:969dbaaa1a986f17d79dfb81d2ce90443d0e9dd9f19db7fd9a9190f97cc8e3d4
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale_visa/epoch_15.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:415c5dcb52668b8c33fb9c1a351c686d632b919df5b384d63fa9ce7a2338ced4
|
3 |
+
size 22631975
|
checkpoints/9_12_4_multiscale_visa/epoch_2.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c98c722977ac0fc42c1067a8038656c10466728f6e9d448aad9e3f6b3d5368b6
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale_visa/epoch_3.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d3e7a65d6b9ff057b5fa53bfc59bfa57a25619b5a5d9cd40ed37579e312ab4aa
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale_visa/epoch_4.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f56b0ed7bd9da05f77780a3c4318e038c258b99a02ad1455652cad146b3dded5
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale_visa/epoch_5.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f2c44c082a19abde2993e80044466c1e45a620cc24aad39e85bd65ed60d3572d
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale_visa/epoch_6.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:402d63bca2150631fb09d8d1c7529712a4ee8eea29bd7746412eae99b4ec6dc5
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale_visa/epoch_7.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:081526236212ebc011ec53babaf8f0da7e25fbe92300aa7cc68eb41ca29b054f
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale_visa/epoch_8.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3f2587be72657ab30fc26bc5957e130ba7359ff53c32beb7984be517a818427c
|
3 |
+
size 22631493
|
checkpoints/9_12_4_multiscale_visa/epoch_9.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4850f209b34912c33718b86c13d2a01c340907d182236a8ef8903f35c80daec0
|
3 |
+
size 22631493
|
dataset.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.utils.data as data
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
+
|
9 |
+
class Dataset(data.Dataset):
|
10 |
+
def __init__(self, root, transform, target_transform, dataset_name, mode='test'):
|
11 |
+
self.root = root
|
12 |
+
self.transform = transform
|
13 |
+
self.target_transform = target_transform
|
14 |
+
self.data_all = []
|
15 |
+
meta_info = json.load(open(f'{self.root}/meta.json', 'r'))
|
16 |
+
name = self.root.split('/')[-1]
|
17 |
+
meta_info = meta_info[mode]
|
18 |
+
|
19 |
+
self.cls_names = list(meta_info.keys())
|
20 |
+
for cls_name in self.cls_names:
|
21 |
+
self.data_all.extend(meta_info[cls_name])
|
22 |
+
self.length = len(self.data_all)
|
23 |
+
|
24 |
+
self.obj_list = [folder for folder in os.listdir(root) if os.path.isdir(os.path.join(root, folder)) and not folder.startswith('.')]
|
25 |
+
self.class_name_map_class_id = {o: i for i, o in enumerate(self.obj_list)}
|
26 |
+
|
27 |
+
def __len__(self):
|
28 |
+
return self.length
|
29 |
+
|
30 |
+
def __getitem__(self, index):
|
31 |
+
data = self.data_all[index]
|
32 |
+
img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \
|
33 |
+
data['specie_name'], data['anomaly']
|
34 |
+
img = Image.open(os.path.join(self.root, img_path))
|
35 |
+
if anomaly == 0:
|
36 |
+
img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
|
37 |
+
else:
|
38 |
+
if os.path.isdir(os.path.join(self.root, mask_path)):
|
39 |
+
# just for classification not report error
|
40 |
+
img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
|
41 |
+
else:
|
42 |
+
img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
|
43 |
+
img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
|
44 |
+
# transforms
|
45 |
+
img = self.transform(img) if self.transform is not None else img
|
46 |
+
img_mask = self.target_transform(
|
47 |
+
img_mask) if self.target_transform is not None and img_mask is not None else img_mask
|
48 |
+
img_mask = [] if img_mask is None else img_mask
|
49 |
+
return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
|
50 |
+
'img_path': os.path.join(self.root, img_path), "cls_id": self.class_name_map_class_id[cls_name]}
|
datasets/rayan_dataset.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------
|
2 |
+
# Do Not Alter This File!
|
3 |
+
# -----------------------------------------------------------------------------
|
4 |
+
# The following code is part of the logic used for loading and evaluating your
|
5 |
+
# output scores. Please DO NOT modify this section, as upon your submission,
|
6 |
+
# the whole evaluation logic will be overwritten by the original code.
|
7 |
+
# -----------------------------------------------------------------------------
|
8 |
+
# If you'd like to make modifications, you can create a completely new Dataset
|
9 |
+
# class or a child class that inherits from this one and use that with your
|
10 |
+
# data loader.
|
11 |
+
# -----------------------------------------------------------------------------
|
12 |
+
|
13 |
+
import os
|
14 |
+
from enum import Enum
|
15 |
+
|
16 |
+
import PIL
|
17 |
+
import torch
|
18 |
+
from torchvision import transforms
|
19 |
+
|
20 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
21 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
22 |
+
|
23 |
+
|
24 |
+
class DatasetSplit(Enum):
|
25 |
+
TRAIN = "train"
|
26 |
+
VAL = "val"
|
27 |
+
TEST = "test"
|
28 |
+
|
29 |
+
|
30 |
+
class RayanDataset(torch.utils.data.Dataset):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
source,
|
34 |
+
classname,
|
35 |
+
input_size=518,
|
36 |
+
output_size=224,
|
37 |
+
split=DatasetSplit.TEST,
|
38 |
+
external_transform=None,
|
39 |
+
**kwargs,
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
self.source = source
|
43 |
+
self.split = split
|
44 |
+
self.classnames_to_use = [classname]
|
45 |
+
self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
|
46 |
+
|
47 |
+
if external_transform is None:
|
48 |
+
self.transform_img = [
|
49 |
+
transforms.Resize((input_size, input_size)),
|
50 |
+
transforms.CenterCrop(input_size),
|
51 |
+
transforms.ToTensor(),
|
52 |
+
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
|
53 |
+
]
|
54 |
+
self.transform_img = transforms.Compose(self.transform_img)
|
55 |
+
else:
|
56 |
+
self.transform_img = external_transform
|
57 |
+
|
58 |
+
# Output size of the mask has to be of shape: 1×224×224
|
59 |
+
self.transform_mask = [
|
60 |
+
transforms.Resize((output_size, output_size)),
|
61 |
+
transforms.CenterCrop(output_size),
|
62 |
+
transforms.ToTensor(),
|
63 |
+
]
|
64 |
+
self.transform_mask = transforms.Compose(self.transform_mask)
|
65 |
+
self.output_shape = (1, output_size, output_size)
|
66 |
+
|
67 |
+
def __getitem__(self, idx):
|
68 |
+
classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
|
69 |
+
image = PIL.Image.open(image_path).convert("RGB")
|
70 |
+
image = self.transform_img(image)
|
71 |
+
|
72 |
+
if self.split == DatasetSplit.TEST and mask_path is not None:
|
73 |
+
mask = PIL.Image.open(mask_path).convert("L")
|
74 |
+
mask = self.transform_mask(mask) > 0
|
75 |
+
else:
|
76 |
+
mask = torch.zeros([*self.output_shape])
|
77 |
+
|
78 |
+
return {
|
79 |
+
"image": image,
|
80 |
+
"mask": mask,
|
81 |
+
"is_anomaly": int(anomaly != "good"),
|
82 |
+
"image_path": image_path,
|
83 |
+
}
|
84 |
+
|
85 |
+
def __len__(self):
|
86 |
+
return len(self.data_to_iterate)
|
87 |
+
|
88 |
+
def get_image_data(self):
|
89 |
+
imgpaths_per_class = {}
|
90 |
+
maskpaths_per_class = {}
|
91 |
+
|
92 |
+
for classname in self.classnames_to_use:
|
93 |
+
classpath = os.path.join(self.source, classname, self.split.value)
|
94 |
+
maskpath = os.path.join(self.source, classname, "ground_truth")
|
95 |
+
anomaly_types = os.listdir(classpath)
|
96 |
+
|
97 |
+
imgpaths_per_class[classname] = {}
|
98 |
+
maskpaths_per_class[classname] = {}
|
99 |
+
|
100 |
+
for anomaly in anomaly_types:
|
101 |
+
anomaly_path = os.path.join(classpath, anomaly)
|
102 |
+
anomaly_files = sorted(os.listdir(anomaly_path))
|
103 |
+
imgpaths_per_class[classname][anomaly] = [
|
104 |
+
os.path.join(anomaly_path, x) for x in anomaly_files
|
105 |
+
]
|
106 |
+
|
107 |
+
if self.split == DatasetSplit.TEST and anomaly != "good":
|
108 |
+
anomaly_mask_path = os.path.join(maskpath, anomaly)
|
109 |
+
anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
|
110 |
+
maskpaths_per_class[classname][anomaly] = [
|
111 |
+
os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
|
112 |
+
]
|
113 |
+
else:
|
114 |
+
maskpaths_per_class[classname]["good"] = None
|
115 |
+
|
116 |
+
data_to_iterate = []
|
117 |
+
for classname in sorted(imgpaths_per_class.keys()):
|
118 |
+
for anomaly in sorted(imgpaths_per_class[classname].keys()):
|
119 |
+
for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
|
120 |
+
data_tuple = [classname, anomaly, image_path]
|
121 |
+
if self.split == DatasetSplit.TEST and anomaly != "good":
|
122 |
+
data_tuple.append(maskpaths_per_class[classname][anomaly][i])
|
123 |
+
else:
|
124 |
+
data_tuple.append(None)
|
125 |
+
data_to_iterate.append(data_tuple)
|
126 |
+
|
127 |
+
return imgpaths_per_class, data_to_iterate
|
docker-compose.yml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------
|
2 |
+
# A sample Docker Compose file to help you replicate our test environment
|
3 |
+
# -----------------------------------------------------------------------------
|
4 |
+
|
5 |
+
services:
|
6 |
+
zsad-service:
|
7 |
+
image: zsad-image:1
|
8 |
+
build:
|
9 |
+
context: .
|
10 |
+
container_name: zsad-container
|
11 |
+
volumes:
|
12 |
+
- ./shared_folder:/app/output
|
13 |
+
deploy:
|
14 |
+
resources:
|
15 |
+
reservations:
|
16 |
+
devices:
|
17 |
+
- driver: nvidia
|
18 |
+
count: all
|
19 |
+
capabilities: [gpu]
|
20 |
+
|
21 |
+
command: [ "python3", "runner.py" ]
|
evaluation/base_eval.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------
|
2 |
+
# Do Not Alter This File!
|
3 |
+
# -----------------------------------------------------------------------------
|
4 |
+
# The following code is part of the logic used for loading and evaluating your
|
5 |
+
# output scores. Please DO NOT modify this section, as upon your submission,
|
6 |
+
# the whole evaluation logic will be overwritten by the original code.
|
7 |
+
# -----------------------------------------------------------------------------
|
8 |
+
|
9 |
+
import warnings
|
10 |
+
import os
|
11 |
+
from pathlib import Path
|
12 |
+
import csv
|
13 |
+
import json
|
14 |
+
import torch
|
15 |
+
|
16 |
+
import datasets.rayan_dataset as rayan_dataset
|
17 |
+
from evaluation.utils.metrics import compute_metrics
|
18 |
+
|
19 |
+
warnings.filterwarnings("ignore")
|
20 |
+
|
21 |
+
|
22 |
+
class BaseEval:
|
23 |
+
def __init__(self, cfg):
|
24 |
+
self.cfg = cfg
|
25 |
+
self.device = torch.device(
|
26 |
+
"cuda:{}".format(cfg["device"]) if torch.cuda.is_available() else "cpu"
|
27 |
+
)
|
28 |
+
|
29 |
+
self.path = cfg["datasets"]["data_path"]
|
30 |
+
self.dataset = cfg["datasets"]["dataset_name"]
|
31 |
+
self.save_csv = cfg["testing"]["save_csv"]
|
32 |
+
self.save_json = cfg["testing"]["save_json"]
|
33 |
+
self.categories = cfg["datasets"]["class_name"]
|
34 |
+
if isinstance(self.categories, str):
|
35 |
+
if self.categories.lower() == "all":
|
36 |
+
if self.dataset == "rayan_dataset":
|
37 |
+
self.categories = self.get_available_class_names(self.path)
|
38 |
+
else:
|
39 |
+
self.categories = [self.categories]
|
40 |
+
self.output_dir = cfg["testing"]["output_dir"]
|
41 |
+
os.makedirs(self.output_dir, exist_ok=True)
|
42 |
+
self.scores_dir = cfg["testing"]["output_scores_dir"]
|
43 |
+
self.class_name_mapping_dir = cfg["testing"]["class_name_mapping_dir"]
|
44 |
+
|
45 |
+
self.leaderboard_metric_weights = {
|
46 |
+
"image_auroc": 1.2,
|
47 |
+
"image_ap": 1.1,
|
48 |
+
"image_f1": 1.1,
|
49 |
+
"pixel_auroc": 1.0,
|
50 |
+
"pixel_aupro": 1.4,
|
51 |
+
"pixel_ap": 1.3,
|
52 |
+
"pixel_f1": 1.3,
|
53 |
+
}
|
54 |
+
|
55 |
+
def get_available_class_names(self, root_data_path):
|
56 |
+
all_items = os.listdir(root_data_path)
|
57 |
+
folder_names = [
|
58 |
+
item
|
59 |
+
for item in all_items
|
60 |
+
if os.path.isdir(os.path.join(root_data_path, item))
|
61 |
+
]
|
62 |
+
|
63 |
+
return folder_names
|
64 |
+
|
65 |
+
def load_datasets(self, category):
|
66 |
+
dataset_classes = {
|
67 |
+
"rayan_dataset": rayan_dataset.RayanDataset,
|
68 |
+
}
|
69 |
+
|
70 |
+
dataset_splits = {
|
71 |
+
"rayan_dataset": rayan_dataset.DatasetSplit.TEST,
|
72 |
+
}
|
73 |
+
|
74 |
+
test_dataset = dataset_classes[self.dataset](
|
75 |
+
source=self.path,
|
76 |
+
split=dataset_splits[self.dataset],
|
77 |
+
classname=category,
|
78 |
+
)
|
79 |
+
return test_dataset
|
80 |
+
|
81 |
+
def get_category_metrics(self, category):
|
82 |
+
print(f"Loading scores of '{category}'")
|
83 |
+
gt_sp, pr_sp, gt_px, pr_px, _ = self.load_category_scores(category)
|
84 |
+
|
85 |
+
print(f"Computing metrics for '{category}'")
|
86 |
+
image_metric, pixel_metric = compute_metrics(gt_sp, pr_sp, gt_px, pr_px)
|
87 |
+
|
88 |
+
return image_metric, pixel_metric
|
89 |
+
|
90 |
+
def load_category_scores(self, category):
|
91 |
+
raise NotImplementedError()
|
92 |
+
|
93 |
+
def get_scores_path_for_image(self, image_path):
|
94 |
+
"""example image_path: './data/photovoltaic_module/test/good/037.png'"""
|
95 |
+
path = Path(image_path)
|
96 |
+
|
97 |
+
category, split, anomaly_type = path.parts[-4:-1]
|
98 |
+
image_name = path.stem
|
99 |
+
|
100 |
+
return os.path.join(
|
101 |
+
self.scores_dir, category, split, anomaly_type, f"{image_name}_scores.json"
|
102 |
+
)
|
103 |
+
|
104 |
+
def calc_leaderboard_score(self, **metrics):
|
105 |
+
weighted_sum = 0
|
106 |
+
total_weight = 0
|
107 |
+
for key, weight in self.leaderboard_metric_weights.items():
|
108 |
+
metric = metrics.get(key)
|
109 |
+
weighted_sum += metric * weight
|
110 |
+
total_weight += weight
|
111 |
+
|
112 |
+
if total_weight == 0:
|
113 |
+
return 0
|
114 |
+
|
115 |
+
return weighted_sum / total_weight
|
116 |
+
|
117 |
+
def main(self):
|
118 |
+
image_auroc_list = []
|
119 |
+
image_f1_list = []
|
120 |
+
image_ap_list = []
|
121 |
+
pixel_auroc_list = []
|
122 |
+
pixel_f1_list = []
|
123 |
+
pixel_ap_list = []
|
124 |
+
pixel_aupro_list = []
|
125 |
+
leaderboard_score_list = []
|
126 |
+
for category in self.categories:
|
127 |
+
image_metric, pixel_metric = self.get_category_metrics(
|
128 |
+
category=category,
|
129 |
+
)
|
130 |
+
image_auroc, image_f1, image_ap = image_metric
|
131 |
+
pixel_auroc, pixel_f1, pixel_ap, pixel_aupro = pixel_metric
|
132 |
+
leaderboard_score = self.calc_leaderboard_score(
|
133 |
+
image_auroc=image_auroc,
|
134 |
+
image_f1=image_f1,
|
135 |
+
image_ap=image_ap,
|
136 |
+
pixel_auroc=pixel_auroc,
|
137 |
+
pixel_aupro=pixel_aupro,
|
138 |
+
pixel_f1=pixel_f1,
|
139 |
+
pixel_ap=pixel_ap,
|
140 |
+
)
|
141 |
+
|
142 |
+
image_auroc_list.append(image_auroc)
|
143 |
+
image_f1_list.append(image_f1)
|
144 |
+
image_ap_list.append(image_ap)
|
145 |
+
pixel_auroc_list.append(pixel_auroc)
|
146 |
+
pixel_f1_list.append(pixel_f1)
|
147 |
+
pixel_ap_list.append(pixel_ap)
|
148 |
+
pixel_aupro_list.append(pixel_aupro)
|
149 |
+
leaderboard_score_list.append(leaderboard_score)
|
150 |
+
|
151 |
+
print(category)
|
152 |
+
print(
|
153 |
+
"[image level] auroc:{}, f1:{}, ap:{}".format(
|
154 |
+
image_auroc * 100,
|
155 |
+
image_f1 * 100,
|
156 |
+
image_ap * 100,
|
157 |
+
)
|
158 |
+
)
|
159 |
+
print(
|
160 |
+
"[pixel level] auroc:{}, f1:{}, ap:{}, aupro:{}".format(
|
161 |
+
pixel_auroc * 100,
|
162 |
+
pixel_f1 * 100,
|
163 |
+
pixel_ap * 100,
|
164 |
+
pixel_aupro * 100,
|
165 |
+
)
|
166 |
+
)
|
167 |
+
print(
|
168 |
+
"leaderboard score:{}".format(
|
169 |
+
leaderboard_score * 100,
|
170 |
+
)
|
171 |
+
)
|
172 |
+
|
173 |
+
image_auroc_mean = sum(image_auroc_list) / len(image_auroc_list)
|
174 |
+
image_f1_mean = sum(image_f1_list) / len(image_f1_list)
|
175 |
+
image_ap_mean = sum(image_ap_list) / len(image_ap_list)
|
176 |
+
pixel_auroc_mean = sum(pixel_auroc_list) / len(pixel_auroc_list)
|
177 |
+
pixel_f1_mean = sum(pixel_f1_list) / len(pixel_f1_list)
|
178 |
+
pixel_ap_mean = sum(pixel_ap_list) / len(pixel_ap_list)
|
179 |
+
pixel_aupro_mean = sum(pixel_aupro_list) / len(pixel_aupro_list)
|
180 |
+
leaderboard_score_mean = sum(leaderboard_score_list) / len(
|
181 |
+
leaderboard_score_list
|
182 |
+
)
|
183 |
+
|
184 |
+
print("mean")
|
185 |
+
print(
|
186 |
+
"[image level] auroc:{}, f1:{}, ap:{}".format(
|
187 |
+
image_auroc_mean * 100, image_f1_mean * 100, image_ap_mean * 100
|
188 |
+
)
|
189 |
+
)
|
190 |
+
print(
|
191 |
+
"[pixel level] auroc:{}, f1:{}, ap:{}, aupro:{}".format(
|
192 |
+
pixel_auroc_mean * 100,
|
193 |
+
pixel_f1_mean * 100,
|
194 |
+
pixel_ap_mean * 100,
|
195 |
+
pixel_aupro_mean * 100,
|
196 |
+
)
|
197 |
+
)
|
198 |
+
print(
|
199 |
+
"leaderboard score:{}".format(
|
200 |
+
leaderboard_score_mean * 100,
|
201 |
+
)
|
202 |
+
)
|
203 |
+
|
204 |
+
# Save the final results as a csv file
|
205 |
+
if self.save_csv:
|
206 |
+
with open(self.class_name_mapping_dir, "r") as f:
|
207 |
+
class_name_mapping_dict = json.load(f)
|
208 |
+
csv_data = [
|
209 |
+
[
|
210 |
+
"Category",
|
211 |
+
"pixel_auroc",
|
212 |
+
"pixel_f1",
|
213 |
+
"pixel_ap",
|
214 |
+
"pixel_aupro",
|
215 |
+
"image_auroc",
|
216 |
+
"image_f1",
|
217 |
+
"image_ap",
|
218 |
+
"leaderboard_score",
|
219 |
+
]
|
220 |
+
]
|
221 |
+
for i, category in enumerate(self.categories):
|
222 |
+
csv_data.append(
|
223 |
+
[
|
224 |
+
class_name_mapping_dict[category],
|
225 |
+
pixel_auroc_list[i] * 100,
|
226 |
+
pixel_f1_list[i] * 100,
|
227 |
+
pixel_ap_list[i] * 100,
|
228 |
+
pixel_aupro_list[i] * 100,
|
229 |
+
image_auroc_list[i] * 100,
|
230 |
+
image_f1_list[i] * 100,
|
231 |
+
image_ap_list[i] * 100,
|
232 |
+
leaderboard_score_list[i] * 100,
|
233 |
+
]
|
234 |
+
)
|
235 |
+
csv_data.append(
|
236 |
+
[
|
237 |
+
"mean",
|
238 |
+
pixel_auroc_mean * 100,
|
239 |
+
pixel_f1_mean * 100,
|
240 |
+
pixel_ap_mean * 100,
|
241 |
+
pixel_aupro_mean * 100,
|
242 |
+
image_auroc_mean * 100,
|
243 |
+
image_f1_mean * 100,
|
244 |
+
image_ap_mean * 100,
|
245 |
+
leaderboard_score_mean * 100,
|
246 |
+
]
|
247 |
+
)
|
248 |
+
|
249 |
+
csv_file_path = os.path.join(self.output_dir, "results.csv")
|
250 |
+
with open(csv_file_path, mode="w", newline="") as file:
|
251 |
+
writer = csv.writer(file)
|
252 |
+
writer.writerows(csv_data)
|
253 |
+
|
254 |
+
# Save the final results as a json file
|
255 |
+
if self.save_json:
|
256 |
+
json_data = []
|
257 |
+
with open(self.class_name_mapping_dir, "r") as f:
|
258 |
+
class_name_mapping_dict = json.load(f)
|
259 |
+
for i, category in enumerate(self.categories):
|
260 |
+
json_data.append(
|
261 |
+
{
|
262 |
+
"Category": class_name_mapping_dict[category],
|
263 |
+
"pixel_auroc": pixel_auroc_list[i] * 100,
|
264 |
+
"pixel_f1": pixel_f1_list[i] * 100,
|
265 |
+
"pixel_ap": pixel_ap_list[i] * 100,
|
266 |
+
"pixel_aupro": pixel_aupro_list[i] * 100,
|
267 |
+
"image_auroc": image_auroc_list[i] * 100,
|
268 |
+
"image_f1": image_f1_list[i] * 100,
|
269 |
+
"image_ap": image_ap_list[i] * 100,
|
270 |
+
"leaderboard_score": leaderboard_score_list[i] * 100,
|
271 |
+
}
|
272 |
+
)
|
273 |
+
json_data.append(
|
274 |
+
{
|
275 |
+
"Category": "mean",
|
276 |
+
"pixel_auroc": pixel_auroc_mean * 100,
|
277 |
+
"pixel_f1": pixel_f1_mean * 100,
|
278 |
+
"pixel_ap": pixel_ap_mean * 100,
|
279 |
+
"pixel_aupro": pixel_aupro_mean * 100,
|
280 |
+
"image_auroc": image_auroc_mean * 100,
|
281 |
+
"image_f1": image_f1_mean * 100,
|
282 |
+
"image_ap": image_ap_mean * 100,
|
283 |
+
"leaderboard_score": leaderboard_score_mean * 100,
|
284 |
+
}
|
285 |
+
)
|
286 |
+
|
287 |
+
json_file_path = os.path.join(self.output_dir, "results.json")
|
288 |
+
with open(json_file_path, mode="w") as file:
|
289 |
+
final_json = {
|
290 |
+
"result": leaderboard_score_mean * 100,
|
291 |
+
"metadata": json_data,
|
292 |
+
}
|
293 |
+
json.dump(final_json, file, indent=4)
|
evaluation/class_name_mapping.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"pill": "industrial_01",
|
3 |
+
"photovoltaic_module": "industrial_02",
|
4 |
+
"capsules": "industrial_03"
|
5 |
+
}
|
evaluation/eval_main.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------
|
2 |
+
# Do Not Alter This File!
|
3 |
+
# -----------------------------------------------------------------------------
|
4 |
+
# The following code is part of the logic used for loading and evaluating your
|
5 |
+
# output scores. Please DO NOT modify this section, as upon your submission,
|
6 |
+
# the whole evaluation logic will be overwritten by the original code.
|
7 |
+
# -----------------------------------------------------------------------------
|
8 |
+
|
9 |
+
import warnings
|
10 |
+
import argparse
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
|
14 |
+
sys.path.append(os.getcwd())
|
15 |
+
from evaluation.json_score import JsonScoreEvaluator
|
16 |
+
|
17 |
+
warnings.filterwarnings("ignore")
|
18 |
+
|
19 |
+
|
20 |
+
def get_args():
|
21 |
+
parser = argparse.ArgumentParser(description="Rayan ZSAD Evaluation Code")
|
22 |
+
parser.add_argument("--data_path", type=str, default=None, help="dataset path")
|
23 |
+
parser.add_argument("--dataset_name", type=str, default=None, help="dataset name")
|
24 |
+
parser.add_argument("--class_name", type=str, default=None, help="category")
|
25 |
+
parser.add_argument("--device", type=int, default=None, help="gpu id")
|
26 |
+
parser.add_argument(
|
27 |
+
"--output_dir", type=str, default=None, help="save results path"
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--output_scores_dir", type=str, default=None, help="save scores path"
|
31 |
+
)
|
32 |
+
parser.add_argument("--save_csv", type=str, default=None, help="save csv")
|
33 |
+
parser.add_argument("--save_json", type=str, default=None, help="save json")
|
34 |
+
|
35 |
+
parser.add_argument(
|
36 |
+
"--class_name_mapping_dir",
|
37 |
+
type=str,
|
38 |
+
default=None,
|
39 |
+
help="mapping from actual class names to class numbers",
|
40 |
+
)
|
41 |
+
args = parser.parse_args()
|
42 |
+
return args
|
43 |
+
|
44 |
+
|
45 |
+
def load_args(cfg, args):
|
46 |
+
cfg["datasets"]["data_path"] = args.data_path
|
47 |
+
assert os.path.exists(
|
48 |
+
cfg["datasets"]["data_path"]
|
49 |
+
), f"The dataset path {cfg['datasets']['data_path']} does not exist."
|
50 |
+
cfg["datasets"]["dataset_name"] = args.dataset_name
|
51 |
+
cfg["datasets"]["class_name"] = args.class_name
|
52 |
+
cfg["device"] = args.device
|
53 |
+
if isinstance(cfg["device"], int):
|
54 |
+
cfg["device"] = str(cfg["device"])
|
55 |
+
cfg["testing"]["output_dir"] = args.output_dir
|
56 |
+
cfg["testing"]["output_scores_dir"] = args.output_scores_dir
|
57 |
+
os.makedirs(cfg["testing"]["output_scores_dir"], exist_ok=True)
|
58 |
+
|
59 |
+
cfg["testing"]["class_name_mapping_dir"] = args.class_name_mapping_dir
|
60 |
+
if args.save_csv.lower() == "true":
|
61 |
+
cfg["testing"]["save_csv"] = True
|
62 |
+
else:
|
63 |
+
cfg["testing"]["save_csv"] = False
|
64 |
+
|
65 |
+
if args.save_json.lower() == "true":
|
66 |
+
cfg["testing"]["save_json"] = True
|
67 |
+
else:
|
68 |
+
cfg["testing"]["save_json"] = False
|
69 |
+
|
70 |
+
return cfg
|
71 |
+
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
args = get_args()
|
75 |
+
cfg = load_args(cfg={"datasets": {}, "testing": {}, "models": {}}, args=args)
|
76 |
+
print(cfg)
|
77 |
+
model = JsonScoreEvaluator(cfg=cfg)
|
78 |
+
model.main()
|