diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..63a56a8fc48ba74b36f39e2a9b3f1242c490235f --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +*.pyc +*.pyo +__pycache__/ +*.tar.gz +*.tar.xz +ZSAD-dataset +data/ diff --git a/AnomalyCLIP_lib/AnomalyCLIP.py b/AnomalyCLIP_lib/AnomalyCLIP.py new file mode 100644 index 0000000000000000000000000000000000000000..bc011cd675b84a93fd65118be2e8dee1423ef5cb --- /dev/null +++ b/AnomalyCLIP_lib/AnomalyCLIP.py @@ -0,0 +1,531 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +# implement attention module for v-v self-attention +class Attention(nn.Module): + def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(out_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.settings = settings + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + # original self-attention for the original path + attn_ori = (q @ k.transpose(-2, -1)) * self.scale + attn_ori = attn_ori.softmax(dim=-1) + attn_ori = self.attn_drop(attn_ori) + + # replace k & q by v + k = v + q = k + + # self-attention, higher temperate for resnets performs better + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = (attn).softmax(dim=-1) + attn = self.attn_drop(attn) + + x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj_drop(self.proj(x)) + x_ori = self.proj_drop(self.proj(x_ori)) + return [x, x_ori] + + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + if isinstance(self.attn, Attention): + x = x.transpose(0, 1) + x, x_ori = self.attn(x) + return [x.transpose(0, 1), x_ori.transpose(0, 1)] + else: + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x, whole = False, ffn = False): + # print("xxxxx",x.shape) + # dual paths for blocks deeper than "d" + + if isinstance(self.attn, Attention): + if isinstance(x, list): + if not ffn: + x, x_ori = x + x_res = self.attention(self.ln_1(x_ori)) + x_res, x_ori_res = x_res + x_ori += x_ori_res + x_ori = x_ori + self.mlp(self.ln_2(x_ori)) + x += x_res # skip ffn for the new path + # print('hellloooo') + return [x, x_ori] + else: + x, x_ori_1 = x + x_res = self.attention(self.ln_1(x_ori_1)) + x_res, x_ori_res = x_res + x_ori = x_ori_1 + x_ori_res + x_ori = x_ori + self.mlp(self.ln_2(x_ori)) + x += x_res # skip ffn for the new path + x = x_res + x_ori_1 + x = x + self.mlp(self.ln_2(x)) + return [x, x_ori] + # start of dual path + else: + x_res = self.attention(self.ln_1(x)) + if isinstance(x_res, list): + x_res, x_ori_res = x_res + x_ori = x + x_ori_res + x_ori = x_ori + self.mlp(self.ln_2(x_ori)) + x += x_res + return [x, x_ori] + + # singl path before "d" + else: + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + +class ResidualAttentionBlock_learnable_token(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details=None, + text_layer=False, i = 0): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + self.i = i + self.compound_prompt_nctx = design_details['learnabel_text_embedding_length'] + self.text_layer = text_layer + if i == 0: + self.first_layer = True + else: + self.first_layer = False + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + if isinstance(self.attn, Attention): + x = x.transpose(0, 1) + x, x_ori = self.attn(x) + return [x.transpose(0, 1), x_ori.transpose(0, 1)] + else: + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, inputs): + + # dual paths for blocks deeper than "d" + if isinstance(self.attn, Attention): + x = inputs[0] + if isinstance(x, list): + x, x_ori = x + x_res = self.attention(self.ln_1(x_ori)) + x_res, x_ori_res = x_res + x_ori += x_ori_res + x_ori = x_ori + self.mlp(self.ln_2(x_ori)) + x += x_res # skip ffn for the new path + return [x, x_ori] + + # start of dual path + else: + x_res = self.attention(self.ln_1(x)) + if isinstance(x_res, list): + x_res, x_ori_res = x_res + x_ori = x + x_ori_res + x_ori = x_ori + self.mlp(self.ln_2(x_ori)) + x += x_res + return [x, x_ori] + + # singl path before "d" + else: + x = inputs[0] + compound_prompts_deeper = inputs[1] + counter = inputs[2] + if not self.first_layer: + # First check if the ith layer needs compound prompts or not + if not (counter > len(compound_prompts_deeper) - 1): + # Appending the learnable tokens in different way + # x -> [77, NCLS, DIM] + # First remove the learnable tokens from previous layer + prefix = x[:1, :, :] + suffix = x[1 + self.compound_prompt_nctx:, :, :] + textual_context = compound_prompts_deeper[counter] + textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half() + # Add the learnable tokens of this layer with the input, replaced by previous + # layer learnable tokens + x = torch.cat([prefix, textual_context, suffix], dim=0) + # Once done, update the counter, so that the next time, it does not use same learnable tokens + counter += 1 + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return [x, compound_prompts_deeper, counter] + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False, design_details = None ,text_layer = False): + super().__init__() + self.width = width + self.layers = layers + self.text_layer = text_layer + self.design_deatails = design_details + print("text_layer", self.text_layer) + if self.text_layer and (design_details is not None): + self.resblocks = nn.ModuleList([ResidualAttentionBlock_learnable_token(width, heads, attn_mask, design_details, text_layer, i=i) for i in range(layers)]) + else: + self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, attn_mask,) for i in range(layers)]) + + def ori_CLIP_with_patch_forward(self, x, out_layers): + idx = 0 + out_tokens = [] + for r in self.resblocks: + idx += 1 + x = r(x) + if idx in out_layers: + if isinstance(x, list): + out_tokens.append(x[1]) + else: + out_tokens.append(x) + + return [x, x], out_tokens + + def AnomalyCLIP_forward(self, x, out_layers, ffn): + idx = 0 + out_tokens = [] + for r in self.resblocks: + idx += 1 + x = r(x, ffn = ffn) + # print("out_layers", out_layers, idx) + if idx in out_layers: + if isinstance(x, list): + out_tokens.append(x[0]) + else: + out_tokens.append(x) + return x, out_tokens + + def forward(self, x: torch.Tensor, out_layers = [6, 12, 18, 24], DPAM_layer = None, ffn = False): + # visual encoder forward + if not self.text_layer: + out_tokens = [] + + if DPAM_layer is None: + [x, x], out_tokens = self.ori_CLIP_with_patch_forward(x, out_layers) + return [x, x], out_tokens + else: + x, out_tokens = self.AnomalyCLIP_forward(x, out_layers, ffn) + return x, out_tokens + # text encoder forward + # ori text embedding + elif self.design_deatails is None: + for idx, r in enumerate(self.resblocks): + x = r(x) + return x + # insert learnable text embedding + elif self.design_deatails is not None: + for idx, r in enumerate(self.resblocks): + x = r(x) + return x[0] + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads, need_weights=True) + self.attn = None + self.embed_dim = width + self.num_heads = heads + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + + @torch.no_grad() + def DAPM_replace(self, DPAM_layer): + if DPAM_layer is not None: + for i in range(1, DPAM_layer): + self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True) + self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone() + self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone() + self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone() + self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone() + self.transformer.resblocks[-i].attn = self.attn + + @torch.no_grad() + def forward(self, x: torch.Tensor, features_list, ori_patch = False, proj_use = True, DPAM_layer = None, ffn = False): + + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + side = int((self.positional_embedding.shape[0] - 1) ** 0.5) + new_side = int((x.shape[1] - 1) ** 0.5) + + # update the position embedding during inference for varied input size + if side != new_side: + new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2) + new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear') + new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2) + self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0) + + pos = self.positional_embedding.to(x.dtype) + x = x + pos + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + [x, x_ori], patch_tokens = self.transformer(x, features_list, DPAM_layer = DPAM_layer, ffn = ffn) + + + if True: + patch_token_list = [] + for patch_token in patch_tokens: + patch_token = self.ln_post(patch_token.permute(1, 0, 2)) @ self.proj # LND -> NLD + patch_token_list.append(patch_token) + patch_tokens = patch_token_list + + return x_ori[0, :, :] @ self.proj, patch_tokens + + + return x + + +from thop import profile +class AnomalyCLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + design_details = None + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), text_layer=True, design_details=design_details + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image, feature_list = [], ori_patch = False, proj_use = True, DPAM_layer = None, ffn = False): + return self.visual(image.type(self.dtype), feature_list, ori_patch = ori_patch, proj_use = proj_use, DPAM_layer = DPAM_layer, ffn = ffn) + + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text = None, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + # x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + # x = x + self.positional_embedding.to(cast_dtype) + + x = prompts + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + # print("test", x.shape, len(deep_compound_prompts_text)) + if deep_compound_prompts_text is None: + x = self.transformer(x) + else: + x = self.transformer([x, deep_compound_prompts_text, 0]) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text diff --git a/AnomalyCLIP_lib/CLIP.py b/AnomalyCLIP_lib/CLIP.py new file mode 100644 index 0000000000000000000000000000000000000000..6846ebab15ca509b347796fc36efa01f49be1cf2 --- /dev/null +++ b/AnomalyCLIP_lib/CLIP.py @@ -0,0 +1,436 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + + side = int((self.positional_embedding.shape[0] - 1) ** 0.5) + new_side = int((x.shape[0] - 1) ** 0.5) + + # update the position embedding during inference for varied input size + if side != new_side: + new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2) + new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear') + new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2) + self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0) + + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + #return x[0] + return x.transpose(0, 1) # return both cls token and image tokens, B,N,C + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, need_weights: bool = False): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + self.need_weights = need_weights + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + if self.need_weights == False: + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + else: + return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask) + + def forward(self, x: torch.Tensor): + if self.need_weights == False: + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + else: + y, attn = self.attention(self.ln_1(x)) + x = x + y + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, need_weights if i == layers - 1 else False) for i in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads, need_weights=True) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + + ##################################################################################### + side = int((self.positional_embedding.shape[0] - 1) ** 0.5) + new_side = int((x.shape[1] - 1) ** 0.5) + + # update the position embedding during inference for varied input size + if side != new_side: + new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2) + new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear') + new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2) + self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0) + ##################################################################################### + + + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + #x = self.ln_post(x[:, 0, :]) + x = self.ln_post(x) # return both cls token and image tokens + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text = None, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + # x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + # x = x + self.positional_embedding.to(cast_dtype) + + x = prompts + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + # print("test", x.shape, len(deep_compound_prompts_text)) + if deep_compound_prompts_text is None: + x = self.transformer(x) + else: + x = self.transformer([x, deep_compound_prompts_text, 0]) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection + return x + + + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text diff --git a/AnomalyCLIP_lib/__init__.py b/AnomalyCLIP_lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..274e95764dc422743a4a9599dd99348eb94d9458 --- /dev/null +++ b/AnomalyCLIP_lib/__init__.py @@ -0,0 +1 @@ +from .model_load import * diff --git a/AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz b/AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/AnomalyCLIP_lib/build_model.py b/AnomalyCLIP_lib/build_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5abf74ee9159e5847557f8e645adf35565ab673e --- /dev/null +++ b/AnomalyCLIP_lib/build_model.py @@ -0,0 +1,50 @@ +from torch import nn +from .CLIP import CLIP +from .AnomalyCLIP import AnomalyCLIP + +def build_model(name: str, state_dict: dict, design_details = None): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + # print('name', name) + # if 'CS-' in name: + if design_details is not None: + model = AnomalyCLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, design_details = design_details + ) + else: + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + #convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/AnomalyCLIP_lib/constants.py b/AnomalyCLIP_lib/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/AnomalyCLIP_lib/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/AnomalyCLIP_lib/model_load.py b/AnomalyCLIP_lib/model_load.py new file mode 100644 index 0000000000000000000000000000000000000000..c91d4ad619b8503854e0e88800ae3c7fd13146c8 --- /dev/null +++ b/AnomalyCLIP_lib/model_load.py @@ -0,0 +1,235 @@ +import hashlib +import os +import urllib +import warnings +from typing import Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, ToTensor, Normalize +from tqdm import tqdm +import numpy as np + +from .build_model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer +from torchvision.transforms import InterpolationMode + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", + "get_similarity_map", "compute_similarity"] +_tokenizer = _Tokenizer() + +_MODELS = { + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download( + url: str, + cache_dir: Union[str, None] = None, +): + + if not cache_dir: + # cache_dir = os.path.expanduser("~/.cache/clip") + cache_dir = os.path.expanduser("/remote-home/iot_zhouqihang/root/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC), + #CenterCrop(n_px), # rm center crop to explain whole image + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load_state_dict(checkpoint_path: str, map_location='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + +def load_checkpoint(model, checkpoint_path, strict=True): + state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", design_details = None, jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + print("name", name) + if name in _MODELS: + # model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + model_path = _download(_MODELS[name], download_root or os.path.expanduser("/remote-home/iot_zhouqihang/root/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(name, state_dict or model.state_dict(), design_details).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def get_similarity_map(sm, shape): + side = int(sm.shape[1] ** 0.5) + sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2) + sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear') + sm = sm.permute(0, 2, 3, 1) + return sm + + +def compute_similarity(image_features, text_features, t=2): + prob_1 = image_features[:, :1, :] @ text_features.t() + b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2] + feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c) + similarity = feats.sum(-1) + return (similarity/0.07).softmax(-1), prob_1 diff --git a/AnomalyCLIP_lib/simple_tokenizer.py b/AnomalyCLIP_lib/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/AnomalyCLIP_lib/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/AnomalyCLIP_lib/transform.py b/AnomalyCLIP_lib/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..748884a3c7cb7ece1ca521ca1dbf40bb74855007 --- /dev/null +++ b/AnomalyCLIP_lib/transform.py @@ -0,0 +1,133 @@ +import warnings +from dataclasses import dataclass, asdict +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None + interpolation: Optional[str] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + normalize = Normalize(mean=mean, std=std) + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time + aug_cfg_dict.setdefault('interpolation', 'random') + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + **aug_cfg_dict, + ) + else: + train_transform = Compose([ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + if aug_cfg_dict: + warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..b7fbac0b908b35ce07c2716271aa1e76bd1c012c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,14 @@ +# ----------------------------------------------------------------------------- +# A sample Dockerfile to help you replicate our test environment +# ----------------------------------------------------------------------------- + +FROM pytorch/pytorch:2.4.1-cuda12.4-cudnn9-runtime +WORKDIR /app +COPY . . + +# Install your python and apt requirements +RUN pip install -r requirements.txt +RUN apt-get update && apt-get install $(cat apt_requirements.txt) -y +RUN chmod +x run.sh + +CMD ["python3", "runner.py"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..87a3e418a78680b6483591b1575fc62b037442de --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Qihang Zhou + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0e636ffbcbce06729349bfd3a61d9cd3ada7a59d --- /dev/null +++ b/README.md @@ -0,0 +1,142 @@ +# AnomalyCLIP (Train once and test other) +> [**ICLR 24**] [**AnomalyCLIP: Object-agnostic Prompt Learning for Zero-shot Anomaly Detection**](https://arxiv.org/pdf/2310.18961.pdf) +> +> by [Qihang Zhou*](), [Guansong Pang*](https://www.guansongpang.com/), [Yu Tian](https://yutianyt.com/), [Shibo He](https://scholar.google.com/citations?hl=zh-CN&user=5GOcb4gAAAAJ&view_op=list_works&sortby=pubdate), [Jiming Chen](https://scholar.google.com/citations?user=zK9tvo8AAAAJ&hl=zh-CN). + + +## Updates + +- **03.19.2024**: Code has been released !!! +- **08.08.2024**: Update the code for testing one image. + +## Introduction +Zero-shot anomaly detection (ZSAD) requires detection models trained using auxiliary data to detect anomalies without any training sample in a target dataset. It is a crucial task when training data is not accessible due to various concerns, e.g., data privacy, yet it is challenging since the models need to generalize to anomalies across different domains where the appearance of foreground objects, abnormal regions, and background features, such as defects/tumors on different products/organs, can vary significantly. Recently large pre-trained vision-language models (VLMs), such as CLIP, +have demonstrated strong zero-shot recognition ability in various vision tasks, including anomaly detection. However, their ZSAD performance is weak since the VLMs focus more on modeling the class semantics of the foreground objects rather than the abnormality/normality in the images. +In this paper we introduce a novel approach, namely AnomalyCLIP, to adapt CLIP for accurate ZSAD across different domains. The key insight of AnomalyCLIP is to learn object-agnostic text prompts that capture generic normality and abnormality in an image regardless of its foreground objects. This allows our model to focus on the abnormal image regions rather than the object semantics, enabling generalized normality and abnormality recognition on diverse types of objects. Large-scale experiments on 17 real-world anomaly detection datasets show that AnomalyCLIP achieves superior zero-shot performance of detecting and segmenting anomalies in datasets of highly diverse class semantics from various defect inspection and medical imaging domains. All experiments are conducted in PyTorch-2.0.0 with a single NVIDIA RTX 3090 24GB. + +## Overview of AnomalyCLIP +![overview](https://github.com/zqhang/AnomalyCLIP/assets/19222962/4ec3e5fc-9570-41f7-8067-6e7a515841be) + + +## Analysis of different text prompt templates +![analysis](./assets/analysis.png) + + +## How to Run +### Prepare your dataset +Download the dataset below: + +* Industrial Domain: +[MVTec](https://www.mvtec.com/company/research/datasets/mvtec-ad), [VisA](https://github.com/amazon-science/spot-diff), [MPDD](https://github.com/stepanje/MPDD), [BTAD](http://avires.dimi.uniud.it/papers/btad/btad.zip), [SDD](https://www.vicos.si/resources/kolektorsdd/), [DAGM](https://www.kaggle.com/datasets/mhskjelvareid/dagm-2007-competition-dataset-optical-inspection), [DTD-Synthetic](https://drive.google.com/drive/folders/10OyPzvI3H6llCZBxKxFlKWt1Pw1tkMK1) + +* Medical Domain: +[HeadCT](https://www.kaggle.com/datasets/felipekitamura/head-ct-hemorrhage), [BrainMRI](https://www.kaggle.com/datasets/navoneel/brain-mri-images-for-brain-tumor-detection), [Br35H](https://www.kaggle.com/datasets/ahmedhamada0/brain-tumor-detection), [COVID-19](https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database), [ISIC](https://isic-challenge-data.s3.amazonaws.com/2016/ISBI2016_ISIC_Part1_Test_Data.zip), [CVC-ColonDB](https://figshare.com/articles/figure/Polyp_DataSet_zip/21221579), [CVC-ClinicDB](https://figshare.com/articles/figure/Polyp_DataSet_zip/21221579), [Kvasir](https://figshare.com/articles/figure/Polyp_DataSet_zip/21221579), [Endo](https://drive.google.com/file/d/1LNpLkv5ZlEUzr_RPN5rdOHaqk0SkZa3m/view), [TN3K](https://github.com/haifangong/TRFE-Net-for-thyroid-nodule-segmentation?tab=readme-ov-file). + +* Google Drive link (frequently requested dataset): [SDD](https://drive.google.com/drive/folders/1oqaxUZYi44jlLT4WtT6D5T6onPTNZXsu?usp=drive_link), [Br35H](https://drive.google.com/file/d/1l9XODMBm4X23K70LtpxAxgoaBbNzr4Nc/view?usp=drive_link), [COVID-19](https://drive.google.com/file/d/1ECwI8DJmhEtcVHatxCAdFqnSmXs35WFL/view?usp=drive_link) +### Generate the dataset JSON +Take MVTec AD for example (With multiple anomaly categories) + +Structure of MVTec Folder: +``` +mvtec/ +│ +├── meta.json +│ +├── bottle/ +│ ├── ground_truth/ +│ │ ├── broken_large/ +│ │ │ └── 000_mask.png +| | | └── ... +│ │ └── ... +│ └── test/ +│ ├── broken_large/ +│ │ └── 000.png +| | └── ... +│ └── ... +│ +└── ... +``` + +```bash +cd generate_dataset_json +python mvtec.py +``` + +Take SDD for example (With single anomaly category) + +Structure of SDD Folder: +``` +SDD/ +│ +├── electrical_commutators/ +│ └── test/ +│ ├── defect/ +│ │ └── kos01_Part5_0.png +| | └── ... +│ └── good/ +│ └── kos01_Part0_0.png +│ └── ... +│ +└── meta.json +``` + +```bash +cd generate_dataset_json +python SDD.py +``` +Select the corresponding script and run it (we provide all scripts for datasets that AnomalyCLIP reported). The generated JSON stores all the information that AnomalyCLIP needs. + +### Custom dataset (optional) +1. Create a new JSON script in fold [generate_dataset_json](https://github.com/zqhang/AnomalyCLIP/tree/main/generate_dataset_json) according to the fold structure of your own datasets. +2. Add the related info of your dataset (i.e., dataset name and class names) in script [dataset\.py](https://github.com/zqhang/AnomalyCLIP/blob/main/dataset.py) + +### Run AnomalyCLIP +* Quick start (use the pre-trained weights) +```bash +bash test.sh +``` + +* Train your own weights +```bash +bash train.sh +``` + + +## Main results (We test all datasets by training once on MVTec AD. For MVTec AD, AnomalyCLIP is trained on VisA.) + +### Industrial dataset +![industrial](./assets/Industrial.png) + + +### Medical dataset +![medical](./assets/medical.png) + + +## Visualization + +![hazelnut](./assets/hazelnut.png) + +![capusle](./assets/capusle.png) + +![skin](./assets/skin.png) + +![brain](./assets/brain.png) + + +## We provide the reproduction of WinCLIP [here](https://github.com/zqhang/WinCLIP-pytorch) + + +* We thank for the code repository: [open_clip](https://github.com/mlfoundations/open_clip), [DualCoOp](https://github.com/sunxm2357/DualCoOp), [CLIP_Surgery](https://github.com/xmed-lab/CLIP_Surgery), and [VAND](https://github.com/ByChelsea/VAND-APRIL-GAN/tree/master). + +## BibTex Citation + +If you find this paper and repository useful, please cite our paper. + +``` +@inproceedings{zhou2023anomalyclip, + title={AnomalyCLIP: Object-agnostic Prompt Learning for Zero-shot Anomaly Detection}, + author={Zhou, Qihang and Pang, Guansong and Tian, Yu and He, Shibo and Chen, Jiming}, + booktitle={The Twelfth International Conference on Learning Representations}, + year={2023} +} +``` diff --git a/checkpoints/9_12_4_multiscale/epoch_1.pth b/checkpoints/9_12_4_multiscale/epoch_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..9de907be7641aee663990403416b681b8971321f --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a89d1ffe49d86995e936c8e91515efa878d4e1777c73888622091e89a8df9e5b +size 22631493 diff --git a/checkpoints/9_12_4_multiscale/epoch_10.pth b/checkpoints/9_12_4_multiscale/epoch_10.pth new file mode 100644 index 0000000000000000000000000000000000000000..76005cd8f1038b654fc079183208a716f76249bb --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_10.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7205c05df3319984b349686cbfd8cc01d3ac241a82f33943e9217cbb85604b0b +size 22631975 diff --git a/checkpoints/9_12_4_multiscale/epoch_11.pth b/checkpoints/9_12_4_multiscale/epoch_11.pth new file mode 100644 index 0000000000000000000000000000000000000000..1733d13d226ed327c66e61d363819167b4f1c487 --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_11.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40017b0588b3e41aea4cf3902b388bbee494201b4406583f0a9c96f90818a986 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale/epoch_12.pth b/checkpoints/9_12_4_multiscale/epoch_12.pth new file mode 100644 index 0000000000000000000000000000000000000000..8b9dd47c75c84053ac83aed6819e04beec37ebf5 --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_12.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef4bdfad5689797d48296eeceb57343aabba5ae5a2c7e57d4b9e225d2d254252 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale/epoch_13.pth b/checkpoints/9_12_4_multiscale/epoch_13.pth new file mode 100644 index 0000000000000000000000000000000000000000..0ad22500a4db53d88b3a97bf74aad2b0963dcf8f --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_13.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4381596b44bbaa33e7b04b4a19a46582980f1ee8742414d71147c8be95ef90d7 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale/epoch_14.pth b/checkpoints/9_12_4_multiscale/epoch_14.pth new file mode 100644 index 0000000000000000000000000000000000000000..7ffbe4128dc0e6ad5dd0bc1dd634defa50662eec --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_14.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd2a3865c4cf1363b80f301da7dc181a54787e3c218cc1f3464650a5f749cb26 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale/epoch_15.pth b/checkpoints/9_12_4_multiscale/epoch_15.pth new file mode 100644 index 0000000000000000000000000000000000000000..804b02d75965ac2c70d06e8c01939047c1307bbe --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_15.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94ce202da3e6486a864b904fdfed5057de75846c5834e446fd1d2fe7f97acb44 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale/epoch_2.pth b/checkpoints/9_12_4_multiscale/epoch_2.pth new file mode 100644 index 0000000000000000000000000000000000000000..1ff196af7cea6cd4e807a0b0725a44aa9362ff39 --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6bfcd2ed1725b3d58dd06d5d38f7ef6d3b9c49d817bb4714a16f3153c3d7450 +size 22631493 diff --git a/checkpoints/9_12_4_multiscale/epoch_3.pth b/checkpoints/9_12_4_multiscale/epoch_3.pth new file mode 100644 index 0000000000000000000000000000000000000000..06bc40dd259e2bfda4ecc1992720bf63b3a2f1d5 --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5af4c383158732845ac2ef195e5036e8528f187ed80173c8d993830a0abed64c +size 22631493 diff --git a/checkpoints/9_12_4_multiscale/epoch_4.pth b/checkpoints/9_12_4_multiscale/epoch_4.pth new file mode 100644 index 0000000000000000000000000000000000000000..b36f3a2a8d79e7ad3239d381bab089c2b0640e50 --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ab9a9909711c89cac5f02f0c46c7baac82b09bfaca59a83271a50b195cad89f +size 22631493 diff --git a/checkpoints/9_12_4_multiscale/epoch_5.pth b/checkpoints/9_12_4_multiscale/epoch_5.pth new file mode 100644 index 0000000000000000000000000000000000000000..bd343f47d3e8feb9dce9937bccd23cebc4ca5057 --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_5.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:317837a0ef5b46d2476c234d3fa77e8cfab7bbfa85711f5fe7eb7f50ea7151a0 +size 22631493 diff --git a/checkpoints/9_12_4_multiscale/epoch_6.pth b/checkpoints/9_12_4_multiscale/epoch_6.pth new file mode 100644 index 0000000000000000000000000000000000000000..f233177e95df66b960f79a745776fb9d201b4193 --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_6.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04379155c0df8d4e1194335427091e626df512a9747e47c1bbb7ee3a55708164 +size 22631493 diff --git a/checkpoints/9_12_4_multiscale/epoch_7.pth b/checkpoints/9_12_4_multiscale/epoch_7.pth new file mode 100644 index 0000000000000000000000000000000000000000..6e5ad06bb822ae887cbbcb23e7fd443aa91f4f5e --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_7.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:41c5a77a355c27266d6a9c7b6da4b3ee2c193596873d889822e68a797a2688b2 +size 22631493 diff --git a/checkpoints/9_12_4_multiscale/epoch_8.pth b/checkpoints/9_12_4_multiscale/epoch_8.pth new file mode 100644 index 0000000000000000000000000000000000000000..b4555671b56f7ae48760e1896e38841775b5c75e --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_8.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c92bfa088eccb2efb71b27c9703c0f21158903581efd7292f42938ad96940c82 +size 22631493 diff --git a/checkpoints/9_12_4_multiscale/epoch_9.pth b/checkpoints/9_12_4_multiscale/epoch_9.pth new file mode 100644 index 0000000000000000000000000000000000000000..02111f2d66924c7b21be86502cd2fb538be257bd --- /dev/null +++ b/checkpoints/9_12_4_multiscale/epoch_9.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43f0eca2d506b88370a06c94a6cd557360c7bcb179a4f3f24981230349a9581a +size 22631493 diff --git a/checkpoints/9_12_4_multiscale/log.txt b/checkpoints/9_12_4_multiscale/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_1.pth b/checkpoints/9_12_4_multiscale_visa/epoch_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..6ca517c272cfbb075af9d218e1b1bd08f8ad34c2 --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de5df7fc2ec18acb5709e65b1889d586974d365c39d1aa4df728336633e4ee70 +size 22631493 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_10.pth b/checkpoints/9_12_4_multiscale_visa/epoch_10.pth new file mode 100644 index 0000000000000000000000000000000000000000..9b58f49df4bd6b294e28b91c41284d149f59fb80 --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_10.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:397255934bd313beeab2b610fa901f113e12342974687147cad78f502e5ae7e5 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_11.pth b/checkpoints/9_12_4_multiscale_visa/epoch_11.pth new file mode 100644 index 0000000000000000000000000000000000000000..5694092e3aa6ab4761ab42915536ad069f6ee77a --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_11.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:843fb9df1c46da89f6976a42d10d5fe34675ad48eccb365e3f43785f925c2ae9 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_12.pth b/checkpoints/9_12_4_multiscale_visa/epoch_12.pth new file mode 100644 index 0000000000000000000000000000000000000000..e3774b7d82442ff2c1f715550c8372dd3a023793 --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_12.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17f69ad9ae4bcc5823fdd9ad56b51ec57cc641270280a1776c1014ea1969f282 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_13.pth b/checkpoints/9_12_4_multiscale_visa/epoch_13.pth new file mode 100644 index 0000000000000000000000000000000000000000..79809c4eeb736faeb889fdedc8646191bbcfe919 --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_13.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5bf5fd9c269e3f68e81134f4361c3239ba14d5f2cd4e3564f93f5b59f616cd19 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_14.pth b/checkpoints/9_12_4_multiscale_visa/epoch_14.pth new file mode 100644 index 0000000000000000000000000000000000000000..8850b5937a7f767f1e2cd88f76b3b78a9b68e690 --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_14.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:969dbaaa1a986f17d79dfb81d2ce90443d0e9dd9f19db7fd9a9190f97cc8e3d4 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_15.pth b/checkpoints/9_12_4_multiscale_visa/epoch_15.pth new file mode 100644 index 0000000000000000000000000000000000000000..29dfd88895087f41f2a73ec163f94e6c274b0f0a --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_15.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:415c5dcb52668b8c33fb9c1a351c686d632b919df5b384d63fa9ce7a2338ced4 +size 22631975 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_2.pth b/checkpoints/9_12_4_multiscale_visa/epoch_2.pth new file mode 100644 index 0000000000000000000000000000000000000000..0013bc03d261bd9657ca992782a3d49cb5db71f5 --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c98c722977ac0fc42c1067a8038656c10466728f6e9d448aad9e3f6b3d5368b6 +size 22631493 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_3.pth b/checkpoints/9_12_4_multiscale_visa/epoch_3.pth new file mode 100644 index 0000000000000000000000000000000000000000..d645b4d01bcc96c392d939f8be4ffbde6f96753b --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3e7a65d6b9ff057b5fa53bfc59bfa57a25619b5a5d9cd40ed37579e312ab4aa +size 22631493 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_4.pth b/checkpoints/9_12_4_multiscale_visa/epoch_4.pth new file mode 100644 index 0000000000000000000000000000000000000000..7f536754d050e89db80bab5e233ccacff2d532f0 --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f56b0ed7bd9da05f77780a3c4318e038c258b99a02ad1455652cad146b3dded5 +size 22631493 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_5.pth b/checkpoints/9_12_4_multiscale_visa/epoch_5.pth new file mode 100644 index 0000000000000000000000000000000000000000..8d6a71d147bed502a2a4b1774457e7c42724b026 --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_5.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2c44c082a19abde2993e80044466c1e45a620cc24aad39e85bd65ed60d3572d +size 22631493 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_6.pth b/checkpoints/9_12_4_multiscale_visa/epoch_6.pth new file mode 100644 index 0000000000000000000000000000000000000000..bd67d30d1c1f3ec650b4d4bcd16e95dfca67e68f --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_6.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:402d63bca2150631fb09d8d1c7529712a4ee8eea29bd7746412eae99b4ec6dc5 +size 22631493 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_7.pth b/checkpoints/9_12_4_multiscale_visa/epoch_7.pth new file mode 100644 index 0000000000000000000000000000000000000000..39c7004b005f92a91dbbd6799b567f5699d70e7e --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_7.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:081526236212ebc011ec53babaf8f0da7e25fbe92300aa7cc68eb41ca29b054f +size 22631493 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_8.pth b/checkpoints/9_12_4_multiscale_visa/epoch_8.pth new file mode 100644 index 0000000000000000000000000000000000000000..bf7a3a25532099d133340b84ee6da107fdf5ade5 --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_8.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f2587be72657ab30fc26bc5957e130ba7359ff53c32beb7984be517a818427c +size 22631493 diff --git a/checkpoints/9_12_4_multiscale_visa/epoch_9.pth b/checkpoints/9_12_4_multiscale_visa/epoch_9.pth new file mode 100644 index 0000000000000000000000000000000000000000..bdb5f0db4d41bb39065e2bc82ca8865545d753b9 --- /dev/null +++ b/checkpoints/9_12_4_multiscale_visa/epoch_9.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4850f209b34912c33718b86c13d2a01c340907d182236a8ef8903f35c80daec0 +size 22631493 diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dcaf4c2a0b435af8c7b5fa6a3f890421558617b2 --- /dev/null +++ b/dataset.py @@ -0,0 +1,50 @@ +import torch.utils.data as data +import json +import random +from PIL import Image +import numpy as np +import torch +import os + +class Dataset(data.Dataset): + def __init__(self, root, transform, target_transform, dataset_name, mode='test'): + self.root = root + self.transform = transform + self.target_transform = target_transform + self.data_all = [] + meta_info = json.load(open(f'{self.root}/meta.json', 'r')) + name = self.root.split('/')[-1] + meta_info = meta_info[mode] + + self.cls_names = list(meta_info.keys()) + for cls_name in self.cls_names: + self.data_all.extend(meta_info[cls_name]) + self.length = len(self.data_all) + + self.obj_list = [folder for folder in os.listdir(root) if os.path.isdir(os.path.join(root, folder)) and not folder.startswith('.')] + self.class_name_map_class_id = {o: i for i, o in enumerate(self.obj_list)} + + def __len__(self): + return self.length + + def __getitem__(self, index): + data = self.data_all[index] + img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \ + data['specie_name'], data['anomaly'] + img = Image.open(os.path.join(self.root, img_path)) + if anomaly == 0: + img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') + else: + if os.path.isdir(os.path.join(self.root, mask_path)): + # just for classification not report error + img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') + else: + img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0 + img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') + # transforms + img = self.transform(img) if self.transform is not None else img + img_mask = self.target_transform( + img_mask) if self.target_transform is not None and img_mask is not None else img_mask + img_mask = [] if img_mask is None else img_mask + return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly, + 'img_path': os.path.join(self.root, img_path), "cls_id": self.class_name_map_class_id[cls_name]} diff --git a/datasets/rayan_dataset.py b/datasets/rayan_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..de12e9dd57e57ca8b923c2eee315d6456d95b757 --- /dev/null +++ b/datasets/rayan_dataset.py @@ -0,0 +1,127 @@ +# ----------------------------------------------------------------------------- +# Do Not Alter This File! +# ----------------------------------------------------------------------------- +# The following code is part of the logic used for loading and evaluating your +# output scores. Please DO NOT modify this section, as upon your submission, +# the whole evaluation logic will be overwritten by the original code. +# ----------------------------------------------------------------------------- +# If you'd like to make modifications, you can create a completely new Dataset +# class or a child class that inherits from this one and use that with your +# data loader. +# ----------------------------------------------------------------------------- + +import os +from enum import Enum + +import PIL +import torch +from torchvision import transforms + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + + +class DatasetSplit(Enum): + TRAIN = "train" + VAL = "val" + TEST = "test" + + +class RayanDataset(torch.utils.data.Dataset): + def __init__( + self, + source, + classname, + input_size=518, + output_size=224, + split=DatasetSplit.TEST, + external_transform=None, + **kwargs, + ): + super().__init__() + self.source = source + self.split = split + self.classnames_to_use = [classname] + self.imgpaths_per_class, self.data_to_iterate = self.get_image_data() + + if external_transform is None: + self.transform_img = [ + transforms.Resize((input_size, input_size)), + transforms.CenterCrop(input_size), + transforms.ToTensor(), + transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ] + self.transform_img = transforms.Compose(self.transform_img) + else: + self.transform_img = external_transform + + # Output size of the mask has to be of shape: 1×224×224 + self.transform_mask = [ + transforms.Resize((output_size, output_size)), + transforms.CenterCrop(output_size), + transforms.ToTensor(), + ] + self.transform_mask = transforms.Compose(self.transform_mask) + self.output_shape = (1, output_size, output_size) + + def __getitem__(self, idx): + classname, anomaly, image_path, mask_path = self.data_to_iterate[idx] + image = PIL.Image.open(image_path).convert("RGB") + image = self.transform_img(image) + + if self.split == DatasetSplit.TEST and mask_path is not None: + mask = PIL.Image.open(mask_path).convert("L") + mask = self.transform_mask(mask) > 0 + else: + mask = torch.zeros([*self.output_shape]) + + return { + "image": image, + "mask": mask, + "is_anomaly": int(anomaly != "good"), + "image_path": image_path, + } + + def __len__(self): + return len(self.data_to_iterate) + + def get_image_data(self): + imgpaths_per_class = {} + maskpaths_per_class = {} + + for classname in self.classnames_to_use: + classpath = os.path.join(self.source, classname, self.split.value) + maskpath = os.path.join(self.source, classname, "ground_truth") + anomaly_types = os.listdir(classpath) + + imgpaths_per_class[classname] = {} + maskpaths_per_class[classname] = {} + + for anomaly in anomaly_types: + anomaly_path = os.path.join(classpath, anomaly) + anomaly_files = sorted(os.listdir(anomaly_path)) + imgpaths_per_class[classname][anomaly] = [ + os.path.join(anomaly_path, x) for x in anomaly_files + ] + + if self.split == DatasetSplit.TEST and anomaly != "good": + anomaly_mask_path = os.path.join(maskpath, anomaly) + anomaly_mask_files = sorted(os.listdir(anomaly_mask_path)) + maskpaths_per_class[classname][anomaly] = [ + os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files + ] + else: + maskpaths_per_class[classname]["good"] = None + + data_to_iterate = [] + for classname in sorted(imgpaths_per_class.keys()): + for anomaly in sorted(imgpaths_per_class[classname].keys()): + for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]): + data_tuple = [classname, anomaly, image_path] + if self.split == DatasetSplit.TEST and anomaly != "good": + data_tuple.append(maskpaths_per_class[classname][anomaly][i]) + else: + data_tuple.append(None) + data_to_iterate.append(data_tuple) + + return imgpaths_per_class, data_to_iterate diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..d6ea4778c91ff3d4e97f220a1db0e166099a857e --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,21 @@ +# ----------------------------------------------------------------------------- +# A sample Docker Compose file to help you replicate our test environment +# ----------------------------------------------------------------------------- + +services: + zsad-service: + image: zsad-image:1 + build: + context: . + container_name: zsad-container + volumes: + - ./shared_folder:/app/output + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + + command: [ "python3", "runner.py" ] diff --git a/evaluation/base_eval.py b/evaluation/base_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..4cca27668ae97cac64394a0d687ca7e707a343f7 --- /dev/null +++ b/evaluation/base_eval.py @@ -0,0 +1,293 @@ +# ----------------------------------------------------------------------------- +# Do Not Alter This File! +# ----------------------------------------------------------------------------- +# The following code is part of the logic used for loading and evaluating your +# output scores. Please DO NOT modify this section, as upon your submission, +# the whole evaluation logic will be overwritten by the original code. +# ----------------------------------------------------------------------------- + +import warnings +import os +from pathlib import Path +import csv +import json +import torch + +import datasets.rayan_dataset as rayan_dataset +from evaluation.utils.metrics import compute_metrics + +warnings.filterwarnings("ignore") + + +class BaseEval: + def __init__(self, cfg): + self.cfg = cfg + self.device = torch.device( + "cuda:{}".format(cfg["device"]) if torch.cuda.is_available() else "cpu" + ) + + self.path = cfg["datasets"]["data_path"] + self.dataset = cfg["datasets"]["dataset_name"] + self.save_csv = cfg["testing"]["save_csv"] + self.save_json = cfg["testing"]["save_json"] + self.categories = cfg["datasets"]["class_name"] + if isinstance(self.categories, str): + if self.categories.lower() == "all": + if self.dataset == "rayan_dataset": + self.categories = self.get_available_class_names(self.path) + else: + self.categories = [self.categories] + self.output_dir = cfg["testing"]["output_dir"] + os.makedirs(self.output_dir, exist_ok=True) + self.scores_dir = cfg["testing"]["output_scores_dir"] + self.class_name_mapping_dir = cfg["testing"]["class_name_mapping_dir"] + + self.leaderboard_metric_weights = { + "image_auroc": 1.2, + "image_ap": 1.1, + "image_f1": 1.1, + "pixel_auroc": 1.0, + "pixel_aupro": 1.4, + "pixel_ap": 1.3, + "pixel_f1": 1.3, + } + + def get_available_class_names(self, root_data_path): + all_items = os.listdir(root_data_path) + folder_names = [ + item + for item in all_items + if os.path.isdir(os.path.join(root_data_path, item)) + ] + + return folder_names + + def load_datasets(self, category): + dataset_classes = { + "rayan_dataset": rayan_dataset.RayanDataset, + } + + dataset_splits = { + "rayan_dataset": rayan_dataset.DatasetSplit.TEST, + } + + test_dataset = dataset_classes[self.dataset]( + source=self.path, + split=dataset_splits[self.dataset], + classname=category, + ) + return test_dataset + + def get_category_metrics(self, category): + print(f"Loading scores of '{category}'") + gt_sp, pr_sp, gt_px, pr_px, _ = self.load_category_scores(category) + + print(f"Computing metrics for '{category}'") + image_metric, pixel_metric = compute_metrics(gt_sp, pr_sp, gt_px, pr_px) + + return image_metric, pixel_metric + + def load_category_scores(self, category): + raise NotImplementedError() + + def get_scores_path_for_image(self, image_path): + """example image_path: './data/photovoltaic_module/test/good/037.png'""" + path = Path(image_path) + + category, split, anomaly_type = path.parts[-4:-1] + image_name = path.stem + + return os.path.join( + self.scores_dir, category, split, anomaly_type, f"{image_name}_scores.json" + ) + + def calc_leaderboard_score(self, **metrics): + weighted_sum = 0 + total_weight = 0 + for key, weight in self.leaderboard_metric_weights.items(): + metric = metrics.get(key) + weighted_sum += metric * weight + total_weight += weight + + if total_weight == 0: + return 0 + + return weighted_sum / total_weight + + def main(self): + image_auroc_list = [] + image_f1_list = [] + image_ap_list = [] + pixel_auroc_list = [] + pixel_f1_list = [] + pixel_ap_list = [] + pixel_aupro_list = [] + leaderboard_score_list = [] + for category in self.categories: + image_metric, pixel_metric = self.get_category_metrics( + category=category, + ) + image_auroc, image_f1, image_ap = image_metric + pixel_auroc, pixel_f1, pixel_ap, pixel_aupro = pixel_metric + leaderboard_score = self.calc_leaderboard_score( + image_auroc=image_auroc, + image_f1=image_f1, + image_ap=image_ap, + pixel_auroc=pixel_auroc, + pixel_aupro=pixel_aupro, + pixel_f1=pixel_f1, + pixel_ap=pixel_ap, + ) + + image_auroc_list.append(image_auroc) + image_f1_list.append(image_f1) + image_ap_list.append(image_ap) + pixel_auroc_list.append(pixel_auroc) + pixel_f1_list.append(pixel_f1) + pixel_ap_list.append(pixel_ap) + pixel_aupro_list.append(pixel_aupro) + leaderboard_score_list.append(leaderboard_score) + + print(category) + print( + "[image level] auroc:{}, f1:{}, ap:{}".format( + image_auroc * 100, + image_f1 * 100, + image_ap * 100, + ) + ) + print( + "[pixel level] auroc:{}, f1:{}, ap:{}, aupro:{}".format( + pixel_auroc * 100, + pixel_f1 * 100, + pixel_ap * 100, + pixel_aupro * 100, + ) + ) + print( + "leaderboard score:{}".format( + leaderboard_score * 100, + ) + ) + + image_auroc_mean = sum(image_auroc_list) / len(image_auroc_list) + image_f1_mean = sum(image_f1_list) / len(image_f1_list) + image_ap_mean = sum(image_ap_list) / len(image_ap_list) + pixel_auroc_mean = sum(pixel_auroc_list) / len(pixel_auroc_list) + pixel_f1_mean = sum(pixel_f1_list) / len(pixel_f1_list) + pixel_ap_mean = sum(pixel_ap_list) / len(pixel_ap_list) + pixel_aupro_mean = sum(pixel_aupro_list) / len(pixel_aupro_list) + leaderboard_score_mean = sum(leaderboard_score_list) / len( + leaderboard_score_list + ) + + print("mean") + print( + "[image level] auroc:{}, f1:{}, ap:{}".format( + image_auroc_mean * 100, image_f1_mean * 100, image_ap_mean * 100 + ) + ) + print( + "[pixel level] auroc:{}, f1:{}, ap:{}, aupro:{}".format( + pixel_auroc_mean * 100, + pixel_f1_mean * 100, + pixel_ap_mean * 100, + pixel_aupro_mean * 100, + ) + ) + print( + "leaderboard score:{}".format( + leaderboard_score_mean * 100, + ) + ) + + # Save the final results as a csv file + if self.save_csv: + with open(self.class_name_mapping_dir, "r") as f: + class_name_mapping_dict = json.load(f) + csv_data = [ + [ + "Category", + "pixel_auroc", + "pixel_f1", + "pixel_ap", + "pixel_aupro", + "image_auroc", + "image_f1", + "image_ap", + "leaderboard_score", + ] + ] + for i, category in enumerate(self.categories): + csv_data.append( + [ + class_name_mapping_dict[category], + pixel_auroc_list[i] * 100, + pixel_f1_list[i] * 100, + pixel_ap_list[i] * 100, + pixel_aupro_list[i] * 100, + image_auroc_list[i] * 100, + image_f1_list[i] * 100, + image_ap_list[i] * 100, + leaderboard_score_list[i] * 100, + ] + ) + csv_data.append( + [ + "mean", + pixel_auroc_mean * 100, + pixel_f1_mean * 100, + pixel_ap_mean * 100, + pixel_aupro_mean * 100, + image_auroc_mean * 100, + image_f1_mean * 100, + image_ap_mean * 100, + leaderboard_score_mean * 100, + ] + ) + + csv_file_path = os.path.join(self.output_dir, "results.csv") + with open(csv_file_path, mode="w", newline="") as file: + writer = csv.writer(file) + writer.writerows(csv_data) + + # Save the final results as a json file + if self.save_json: + json_data = [] + with open(self.class_name_mapping_dir, "r") as f: + class_name_mapping_dict = json.load(f) + for i, category in enumerate(self.categories): + json_data.append( + { + "Category": class_name_mapping_dict[category], + "pixel_auroc": pixel_auroc_list[i] * 100, + "pixel_f1": pixel_f1_list[i] * 100, + "pixel_ap": pixel_ap_list[i] * 100, + "pixel_aupro": pixel_aupro_list[i] * 100, + "image_auroc": image_auroc_list[i] * 100, + "image_f1": image_f1_list[i] * 100, + "image_ap": image_ap_list[i] * 100, + "leaderboard_score": leaderboard_score_list[i] * 100, + } + ) + json_data.append( + { + "Category": "mean", + "pixel_auroc": pixel_auroc_mean * 100, + "pixel_f1": pixel_f1_mean * 100, + "pixel_ap": pixel_ap_mean * 100, + "pixel_aupro": pixel_aupro_mean * 100, + "image_auroc": image_auroc_mean * 100, + "image_f1": image_f1_mean * 100, + "image_ap": image_ap_mean * 100, + "leaderboard_score": leaderboard_score_mean * 100, + } + ) + + json_file_path = os.path.join(self.output_dir, "results.json") + with open(json_file_path, mode="w") as file: + final_json = { + "result": leaderboard_score_mean * 100, + "metadata": json_data, + } + json.dump(final_json, file, indent=4) diff --git a/evaluation/class_name_mapping.json b/evaluation/class_name_mapping.json new file mode 100644 index 0000000000000000000000000000000000000000..f692cdfd79da985e6a4f2d71685ecb4c6e94fc62 --- /dev/null +++ b/evaluation/class_name_mapping.json @@ -0,0 +1,5 @@ +{ + "pill": "industrial_01", + "photovoltaic_module": "industrial_02", + "capsules": "industrial_03" +} diff --git a/evaluation/eval_main.py b/evaluation/eval_main.py new file mode 100644 index 0000000000000000000000000000000000000000..9c48160ad28cc624118c206342af3b75d051487f --- /dev/null +++ b/evaluation/eval_main.py @@ -0,0 +1,78 @@ +# ----------------------------------------------------------------------------- +# Do Not Alter This File! +# ----------------------------------------------------------------------------- +# The following code is part of the logic used for loading and evaluating your +# output scores. Please DO NOT modify this section, as upon your submission, +# the whole evaluation logic will be overwritten by the original code. +# ----------------------------------------------------------------------------- + +import warnings +import argparse +import os +import sys + +sys.path.append(os.getcwd()) +from evaluation.json_score import JsonScoreEvaluator + +warnings.filterwarnings("ignore") + + +def get_args(): + parser = argparse.ArgumentParser(description="Rayan ZSAD Evaluation Code") + parser.add_argument("--data_path", type=str, default=None, help="dataset path") + parser.add_argument("--dataset_name", type=str, default=None, help="dataset name") + parser.add_argument("--class_name", type=str, default=None, help="category") + parser.add_argument("--device", type=int, default=None, help="gpu id") + parser.add_argument( + "--output_dir", type=str, default=None, help="save results path" + ) + parser.add_argument( + "--output_scores_dir", type=str, default=None, help="save scores path" + ) + parser.add_argument("--save_csv", type=str, default=None, help="save csv") + parser.add_argument("--save_json", type=str, default=None, help="save json") + + parser.add_argument( + "--class_name_mapping_dir", + type=str, + default=None, + help="mapping from actual class names to class numbers", + ) + args = parser.parse_args() + return args + + +def load_args(cfg, args): + cfg["datasets"]["data_path"] = args.data_path + assert os.path.exists( + cfg["datasets"]["data_path"] + ), f"The dataset path {cfg['datasets']['data_path']} does not exist." + cfg["datasets"]["dataset_name"] = args.dataset_name + cfg["datasets"]["class_name"] = args.class_name + cfg["device"] = args.device + if isinstance(cfg["device"], int): + cfg["device"] = str(cfg["device"]) + cfg["testing"]["output_dir"] = args.output_dir + cfg["testing"]["output_scores_dir"] = args.output_scores_dir + os.makedirs(cfg["testing"]["output_scores_dir"], exist_ok=True) + + cfg["testing"]["class_name_mapping_dir"] = args.class_name_mapping_dir + if args.save_csv.lower() == "true": + cfg["testing"]["save_csv"] = True + else: + cfg["testing"]["save_csv"] = False + + if args.save_json.lower() == "true": + cfg["testing"]["save_json"] = True + else: + cfg["testing"]["save_json"] = False + + return cfg + + +if __name__ == "__main__": + args = get_args() + cfg = load_args(cfg={"datasets": {}, "testing": {}, "models": {}}, args=args) + print(cfg) + model = JsonScoreEvaluator(cfg=cfg) + model.main() diff --git a/evaluation/json_score.py b/evaluation/json_score.py new file mode 100644 index 0000000000000000000000000000000000000000..03f5a06e928e94dbf237ddf8e4bd37f1d53b6444 --- /dev/null +++ b/evaluation/json_score.py @@ -0,0 +1,98 @@ +# ----------------------------------------------------------------------------- +# Do Not Alter This File! +# ----------------------------------------------------------------------------- +# The following code is part of the logic used for loading and evaluating your +# output scores. Please DO NOT modify this section, as upon your submission, +# the whole evaluation logic will be overwritten by the original code. +# ----------------------------------------------------------------------------- + +import warnings +import numpy as np +import torch +from tqdm import tqdm + +from evaluation.base_eval import BaseEval +from evaluation.utils.json_helpers import json_to_dict + +warnings.filterwarnings("ignore") + + +class JsonScoreEvaluator(BaseEval): + """ + Evaluates anomaly detection performance based on pre-computed scores stored in JSON files. + + This class extends the BaseEval class and specializes in reading scores from JSON files, + computing evaluation metrics, and optionally saving results to CSV or JSON format. + + Notes: + - Score files are expected to follow the exact dataset structure. + `{category}/{split}/{anomaly_type}/{image_name}_scores.json` + e.g., `photovoltaic_module/test/good/037_scores.json` + - Score files are expected to be at `self.scores_dir`. + + Example usage: + >>> evaluator = JsonScoreEvaluator(cfg) + >>> results = evaluator.main() + """ + + def __init__(self, cfg): + super().__init__(cfg) + + def get_scores_for_image(self, image_path): + image_scores_path = self.get_scores_path_for_image(image_path) + image_scores = json_to_dict(image_scores_path) + + return image_scores + + def load_category_scores(self, category): + cls_scores_list = [] # image level prediction + anomaly_maps = [] # pixel level prediction + gt_list = [] # image level ground truth + img_masks = [] # pixel level ground truth + + image_path_list = [] + test_dataset = self.load_datasets(category) + test_dataloader = torch.utils.data.DataLoader( + test_dataset, + batch_size=1, + shuffle=False, + num_workers=0, + pin_memory=True, + ) + + for image_info in tqdm(test_dataloader): + if not isinstance(image_info, dict): + raise ValueError("Encountered non-dict image in dataloader") + + del image_info["image"] + + image_path = image_info["image_path"][0] + image_path_list.extend(image_path) + + img_masks.append(image_info["mask"]) + gt_list.extend(list(image_info["is_anomaly"].numpy())) + + image_scores = self.get_scores_for_image(image_path) + cls_scores = image_scores["img_level_score"] + anomaly_maps_iter = image_scores["pix_level_score"] + + cls_scores_list.append(cls_scores) + anomaly_maps.append(anomaly_maps_iter) + + pr_sp = np.array(cls_scores_list) + gt_sp = np.array(gt_list) + pr_px = np.array(anomaly_maps) + gt_px = torch.cat(img_masks, dim=0).numpy().astype(np.int32) + + assert pr_px.shape[1:] == ( + 1, + 224, + 224, + ), "Predicted output scores do not meet the expected shape!" + assert gt_px.shape[1:] == ( + 1, + 224, + 224, + ), "Loaded ground truth maps do not meet the expected shape!" + + return gt_sp, pr_sp, gt_px, pr_px, image_path_list diff --git a/evaluation/utils/json_helpers.py b/evaluation/utils/json_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..544666cd00e455f814aca7dda1da919816ae1aba --- /dev/null +++ b/evaluation/utils/json_helpers.py @@ -0,0 +1,46 @@ +# ----------------------------------------------------------------------------- +# Do Not Alter This File! +# ----------------------------------------------------------------------------- +# The following code is part of the logic used for loading and evaluating your +# output scores. Please DO NOT modify this section, as upon your submission, +# the whole evaluation logic will be overwritten by the original code. +# ----------------------------------------------------------------------------- + +import json +import numpy as np + + +class NumpyEncoder(json.JSONEncoder): + """Special json encoder for numpy types""" + + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return { + "__ndarray__": obj.tolist(), + "dtype": str(obj.dtype), + "shape": obj.shape, + } + else: + return super(NumpyEncoder, self).default(obj) + + +def dict_to_json(dct, filename): + """Save a dictionary to a JSON file""" + with open(filename, "w") as f: + json.dump(dct, f, cls=NumpyEncoder) + + +def json_to_dict(filename): + """Load a JSON file and convert it back to a dictionary of NumPy arrays""" + with open(filename, "r") as f: + dct = json.load(f) + + for k, v in dct.items(): + if isinstance(v, dict) and "__ndarray__" in v: + dct[k] = np.array(v["__ndarray__"], dtype=v["dtype"]).reshape(v["shape"]) + + return dct diff --git a/evaluation/utils/metrics.py b/evaluation/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..758964453215028b5c6af283d32b220dc052f75a --- /dev/null +++ b/evaluation/utils/metrics.py @@ -0,0 +1,78 @@ +# ----------------------------------------------------------------------------- +# Do Not Alter This File! +# ----------------------------------------------------------------------------- +# The following code is part of the logic used for loading and evaluating your +# output scores. Please DO NOT modify this section, as upon your submission, +# the whole evaluation logic will be overwritten by the original code. +# ----------------------------------------------------------------------------- + +import numpy as np +from sklearn.metrics import ( + auc, + roc_auc_score, + average_precision_score, + precision_recall_curve, +) +from skimage import measure + + +# ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py +def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3): + binary_amaps = np.zeros_like(amaps, dtype=bool) + min_th, max_th = amaps.min(), amaps.max() + delta = (max_th - min_th) / max_step + pros, fprs, ths = [], [], [] + for th in np.arange(min_th, max_th, delta): + binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1 + pro = [] + for binary_amap, mask in zip(binary_amaps, masks): + for region in measure.regionprops(measure.label(mask)): + tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum() + pro.append(tp_pixels / region.area) + inverse_masks = 1 - masks + fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum() + fpr = fp_pixels / inverse_masks.sum() + pros.append(np.array(pro).mean()) + fprs.append(fpr) + ths.append(th) + pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths) + idxes = fprs < expect_fpr + fprs = fprs[idxes] + fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min()) + pro_auc = auc(fprs, pros[idxes]) + return pro_auc + + +def compute_metrics(gt_sp=None, pr_sp=None, gt_px=None, pr_px=None): + # classification + if ( + gt_sp is None + or pr_sp is None + or gt_sp.sum() == 0 + or gt_sp.sum() == gt_sp.shape[0] + ): + auroc_sp, f1_sp, ap_sp = 0, 0, 0 + else: + auroc_sp = roc_auc_score(gt_sp, pr_sp) + ap_sp = average_precision_score(gt_sp, pr_sp) + precisions, recalls, thresholds = precision_recall_curve(gt_sp, pr_sp) + f1_scores = (2 * precisions * recalls) / (precisions + recalls) + f1_sp = np.max(f1_scores[np.isfinite(f1_scores)]) + + # segmentation + if gt_px is None or pr_px is None or gt_px.sum() == 0: + auroc_px, f1_px, ap_px, aupro = 0, 0, 0, 0 + else: + auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel()) + ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel()) + precisions, recalls, thresholds = precision_recall_curve( + gt_px.ravel(), pr_px.ravel() + ) + f1_scores = (2 * precisions * recalls) / (precisions + recalls) + f1_px = np.max(f1_scores[np.isfinite(f1_scores)]) + aupro = cal_pro_score(gt_px.squeeze(), pr_px.squeeze()) + + image_metric = [auroc_sp, f1_sp, ap_sp] + pixel_metric = [auroc_px, f1_px, ap_px, aupro] + + return image_metric, pixel_metric diff --git a/generate_dataset_json/DAGM.py b/generate_dataset_json/DAGM.py new file mode 100644 index 0000000000000000000000000000000000000000..f17e2113f6da80602af69fea767cc73333535b49 --- /dev/null +++ b/generate_dataset_json/DAGM.py @@ -0,0 +1,68 @@ +import os +import json +import pandas as pd + + +class DAGMSolver(object): + CLSNAMES = [ + 'Class1','Class2','Class3','Class4','Class5','Class6','Class7','Class8','Class9','Class10' + ] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(Train={}, Test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['Train', 'Test']: + cls_info = [] + x, y, mask_names_none= [], [], [] + img_dir = os.listdir(f'{cls_dir}/{phase}') + + mask_names = os.listdir(f'{cls_dir}/{phase}/Label') + + img_fpath_list = sorted([f + for f in img_dir + if f.endswith('.PNG')]) + gt_fpath_list = sorted([f + for f in mask_names + if f.endswith('.PNG')]) + + img_exclude_list = [f.split("_")[0] + ".PNG" for f in gt_fpath_list] + + img_normal_fpath_list = list(set(img_fpath_list) - set(img_exclude_list)) + + x.extend(img_normal_fpath_list + img_exclude_list) + + y.extend([0] * len(img_normal_fpath_list) + [1]* len(img_exclude_list)) + + mask_names_none.extend([None] * len(img_normal_fpath_list) + gt_fpath_list) + + for idx, img_name in enumerate(x): + info_img = dict( + img_path=f'{cls_name}/{phase}/{img_name}', + mask_path=f'{cls_name}/{phase}/Label/{mask_names_none[idx]}' if mask_names_none[idx] != None else '', + cls_name=cls_name, + specie_name='', + anomaly=1 if y[idx] == 1 else 0, + ) + cls_info.append(info_img) + if phase == 'Test': + if y[idx] == 1: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + + + +if __name__ == '__main__': + runner = DAGMSolver(root='/remote-home/iot_zhouqihang/data/DAGM_KaggleUpload') + runner.run() diff --git a/generate_dataset_json/DTD.py b/generate_dataset_json/DTD.py new file mode 100644 index 0000000000000000000000000000000000000000..495be0f72ec2005e719ceb33efd248ccaf9a9f7a --- /dev/null +++ b/generate_dataset_json/DTD.py @@ -0,0 +1,43 @@ + + +import os +import json + + +class DTDSolver(object): + CLSNAMES = ['Woven_001', 'Woven_127', 'Woven_104', 'Stratified_154', 'Blotchy_099', 'Woven_068', 'Woven_125', 'Marbled_078', 'Perforated_037', 'Mesh_114', 'Fibrous_183', 'Matted_069'] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = DTDSolver(root='/remote-home/iot_zhouqihang/data/DTD-Synthetic') + runner.run() diff --git a/generate_dataset_json/SDD.py b/generate_dataset_json/SDD.py new file mode 100644 index 0000000000000000000000000000000000000000..d293d828aef7d3af1b3d64a33d8fecf5dc23e18b --- /dev/null +++ b/generate_dataset_json/SDD.py @@ -0,0 +1,51 @@ + + +import os +import json +import sys + +class SDDSolver(object): + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + self.CLSNAMES = [folder for folder in os.listdir(root) if os.path.isdir(os.path.join(root, folder)) and not folder.startswith('.')] + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + if phase == 'test': + if is_abnormal: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + + +if __name__ == '__main__': + runner = SDDSolver(root=sys.argv[1]) + runner.run() diff --git a/generate_dataset_json/br35.py b/generate_dataset_json/br35.py new file mode 100644 index 0000000000000000000000000000000000000000..ed82b703c71025418dd1b21883630e2dfd9d248c --- /dev/null +++ b/generate_dataset_json/br35.py @@ -0,0 +1,38 @@ +import os +import json + + +class Br35Solver(object): + CLSNAMES = ['brain'] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/' + for phase in ['test']: + cls_info = [] + species = os.listdir(f'{cls_dir}') + for specie in species: + is_abnormal = True if specie not in ['no'] else False + img_names = os.listdir(f'{cls_dir}/{specie}') + img_names.sort() + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_dir}/{specie}/{img_name}', + cls_name=cls_name, + mask_path="", + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + +if __name__ == '__main__': + runner = Br35Solver(root='/remote-home/iot_zhouqihang/data/br35') + runner.run() diff --git a/generate_dataset_json/brainmri.py b/generate_dataset_json/brainmri.py new file mode 100644 index 0000000000000000000000000000000000000000..28d1d785454150403e0dcfb2893da038e081bcf3 --- /dev/null +++ b/generate_dataset_json/brainmri.py @@ -0,0 +1,38 @@ +import os +import json + + +class IsbiSolver(object): + CLSNAMES = ['brain'] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/brain_tumor_dataset' + for phase in ['test']: + cls_info = [] + species = os.listdir(f'{cls_dir}') + for specie in species: + is_abnormal = True if specie not in ['no'] else False + img_names = os.listdir(f'{cls_dir}/{specie}') + img_names.sort() + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_dir}/{specie}/{img_name}', + cls_name=cls_name, + mask_path="", + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + +if __name__ == '__main__': + runner = IsbiSolver(root='/remote-home/iot_zhouqihang/data/BrainMRI') + runner.run() diff --git a/generate_dataset_json/btad.py b/generate_dataset_json/btad.py new file mode 100644 index 0000000000000000000000000000000000000000..6c278fb81283a5d51fa765c3c920b8f1ffa362cc --- /dev/null +++ b/generate_dataset_json/btad.py @@ -0,0 +1,48 @@ +import os +import json + + +class BtadSolver(object): + CLSNAMES = ['01', '02', '03'] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['ok'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + if phase == 'test': + if is_abnormal: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + +if __name__ == '__main__': + runner = BtadSolver(root='/remote-home/iot_zhouqihang/data/BTech_Dataset_transformed') + runner.run() diff --git a/generate_dataset_json/clinicDB.py b/generate_dataset_json/clinicDB.py new file mode 100644 index 0000000000000000000000000000000000000000..ade868cf456d1da59ba767e3736b5fd3c56f5b1b --- /dev/null +++ b/generate_dataset_json/clinicDB.py @@ -0,0 +1,50 @@ +import os +import json +import pandas as pd + + +class ClinicDBSolver(object): + CLSNAMES = [ + 'colon', + ] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}' + for phase in ['test']: + cls_info = [] + # is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/images') + mask_names = os.listdir(f'{cls_dir}/masks') + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_dir}/images/{img_name}', + mask_path=f'{cls_dir}/masks/{mask_names[idx]}', + cls_name=cls_name, + specie_name='', + anomaly=1 + ) + cls_info.append(info_img) + if phase == 'test': + if True: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + + +if __name__ == '__main__': + runner = ClinicDBSolver(root='/remote-home/iot_zhouqihang/data/medical/CVC-ClinicDB') + runner.run() diff --git a/generate_dataset_json/colonDB.py b/generate_dataset_json/colonDB.py new file mode 100644 index 0000000000000000000000000000000000000000..1864b92c1627e8c7644274b50ef7708773982f4d --- /dev/null +++ b/generate_dataset_json/colonDB.py @@ -0,0 +1,50 @@ +import os +import json +import pandas as pd + + +class ClinicDBSolver(object): + CLSNAMES = [ + 'colon', + ] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}' + for phase in ['test']: + cls_info = [] + # is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/images') + mask_names = os.listdir(f'{cls_dir}/masks') + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_dir}/images/{img_name}', + mask_path=f'{cls_dir}/masks/{mask_names[idx]}', + cls_name=cls_name, + specie_name='', + anomaly=1 + ) + cls_info.append(info_img) + if phase == 'test': + if True: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + + +if __name__ == '__main__': + runner = ClinicDBSolver(root='/remote-home/iot_zhouqihang/data/medical/CVC-ColonDB') + runner.run() diff --git a/generate_dataset_json/covid.py b/generate_dataset_json/covid.py new file mode 100644 index 0000000000000000000000000000000000000000..4837e047b83e2eb470c6c3c9ce6fc44aba21d423 --- /dev/null +++ b/generate_dataset_json/covid.py @@ -0,0 +1,51 @@ +import os +import json + +import os +import json + + +class MpddSolver(object): + CLSNAMES = ['chest'] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}' + for phase in ['test']: + cls_info = [] + species = os.listdir(f'{cls_dir}') + for specie in species: + is_abnormal = True if specie not in ['NORMAL'] else False + img_names = os.listdir(f'{cls_dir}/{specie}/') + + img_names.sort() + + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{specie}/{img_name}', + mask_path="", + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + if phase == 'test': + if is_abnormal: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + +if __name__ == '__main__': + runner = MpddSolver(root='/remote-home/iot_zhouqihang/data/COVID-19_Radiography_Dataset') + runner.run() diff --git a/generate_dataset_json/endoTect.py b/generate_dataset_json/endoTect.py new file mode 100644 index 0000000000000000000000000000000000000000..e554e52538afd4ac9eabe6801b98ae88fc21809e --- /dev/null +++ b/generate_dataset_json/endoTect.py @@ -0,0 +1,55 @@ +import os +import json +import pandas as pd + + +class HyperSolver(object): + CLSNAMES = [ + 'colon', + ] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}' + for phase in ['test']: + cls_info = [] + species = set(os.listdir(f'{cls_dir}'))-set(['masks']) + print("species", species) + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{specie}') + mask_names = os.listdir(f'{cls_dir}/masks/') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + assert len(img_names) == len(mask_names) if mask_names is not None else True + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_dir}/{specie}/{img_name}', + mask_path=f'{cls_dir}/masks/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + if phase == 'test': + if is_abnormal: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + + + +if __name__ == '__main__': + runner = HyperSolver(root='/remote-home/iot_zhouqihang/data/medical/EndoTect_2020_Segmentation_Test_Dataset') + runner.run() diff --git a/generate_dataset_json/head_ct.py b/generate_dataset_json/head_ct.py new file mode 100644 index 0000000000000000000000000000000000000000..646b265bd246621ee6583a20d5eae09bcb0937c8 --- /dev/null +++ b/generate_dataset_json/head_ct.py @@ -0,0 +1,48 @@ +import os +import json + + +class MpddSolver(object): + CLSNAMES = ['brain'] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + + img_names.sort() + + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path="", + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + if phase == 'test': + if is_abnormal: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + +if __name__ == '__main__': + runner = MpddSolver(root='/remote-home/iot_zhouqihang/data/HeadCT_anomaly_detection') + runner.run() diff --git a/generate_dataset_json/isbi.py b/generate_dataset_json/isbi.py new file mode 100644 index 0000000000000000000000000000000000000000..3b70d8ddfff7638478ba191d24e1f28c785c0b2d --- /dev/null +++ b/generate_dataset_json/isbi.py @@ -0,0 +1,48 @@ +import os +import json + + +class IsbiSolver(object): + CLSNAMES = ['skin'] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ISBI2016_ISIC_Part1_Test_GroundTruth/') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ISBI2016_ISIC_Part1_Test_GroundTruth/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + if phase == 'test': + if is_abnormal: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + +if __name__ == '__main__': + runner = IsbiSolver(root='/remote-home/iot_zhouqihang/data/ISBI') + runner.run() diff --git a/generate_dataset_json/kvasir.py b/generate_dataset_json/kvasir.py new file mode 100644 index 0000000000000000000000000000000000000000..20ac90e7585bd815cbeb02cbac3b59f651d08ab9 --- /dev/null +++ b/generate_dataset_json/kvasir.py @@ -0,0 +1,51 @@ +import os +import json +import pandas as pd + + +class ClinicDBSolver(object): + CLSNAMES = [ + 'colon', + ] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}' + for phase in ['test']: + cls_info = [] + # is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/images') + mask_names = os.listdir(f'{cls_dir}/masks') + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_dir}/images/{img_name}', + mask_path=f'{cls_dir}/masks/{mask_names[idx]}', + cls_name=cls_name, + specie_name='', + anomaly=1 + ) + cls_info.append(info_img) + if phase == 'test': + if True: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + + + +if __name__ == '__main__': + runner = ClinicDBSolver(root='/remote-home/iot_zhouqihang/data/medical/Kvasir') + runner.run() diff --git a/generate_dataset_json/mpdd.py b/generate_dataset_json/mpdd.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0c825aba940bdcbc24f68c8008a8467c0939b1 --- /dev/null +++ b/generate_dataset_json/mpdd.py @@ -0,0 +1,48 @@ +import os +import json + + +class MpddSolver(object): + CLSNAMES = ['bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate', 'tubes'] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + if phase == 'test': + if is_abnormal: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + +if __name__ == '__main__': + runner = MpddSolver(root='/remote-home/iot_zhouqihang/data/mpdd') + runner.run() diff --git a/generate_dataset_json/mvtec.py b/generate_dataset_json/mvtec.py new file mode 100644 index 0000000000000000000000000000000000000000..fcbfe63b79e3ae95f71a0482f99d70a93c697956 --- /dev/null +++ b/generate_dataset_json/mvtec.py @@ -0,0 +1,51 @@ +import os +import json +import sys + +class MVTecSolver(object): + CLSNAMES = [ + 'bottle', 'cable', 'capsule', 'carpet', 'grid', + 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', + 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', + ] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + if phase == 'test': + if is_abnormal: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) +if __name__ == '__main__': + runner = MVTecSolver(sys.argv[1]) + runner.run() diff --git a/generate_dataset_json/tn3k.py b/generate_dataset_json/tn3k.py new file mode 100644 index 0000000000000000000000000000000000000000..e300bf548f62eaea01a2a665c321d47bbe159279 --- /dev/null +++ b/generate_dataset_json/tn3k.py @@ -0,0 +1,51 @@ +import os +import json +import pandas as pd + + +class ClinicDBSolver(object): + CLSNAMES = [ + 'thyroid', + ] + + def __init__(self, root='data/mvtec'): + self.root = root + self.meta_path = f'{root}/meta.json' + + def run(self): + info = dict(train={}, test={}) + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}' + for phase in ['test']: + cls_info = [] + # is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/test-image') + mask_names = os.listdir(f'{cls_dir}/test-mask') + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_dir}/test-image/{img_name}', + mask_path=f'{cls_dir}/test-mask/{mask_names[idx]}', + cls_name=cls_name, + specie_name='', + anomaly=1 + ) + cls_info.append(info_img) + if phase == 'test': + if True: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + + + +if __name__ == '__main__': + runner = ClinicDBSolver(root='/remote-home/iot_zhouqihang/data/tn3k') + runner.run() diff --git a/generate_dataset_json/visa.py b/generate_dataset_json/visa.py new file mode 100644 index 0000000000000000000000000000000000000000..0340505d69cca68c1c8c93e1854102fb12375ebe --- /dev/null +++ b/generate_dataset_json/visa.py @@ -0,0 +1,54 @@ +import os +import json +import pandas as pd + + +class VisASolver(object): + CLSNAMES = [ + 'candle', 'capsules', 'cashew', 'chewinggum', 'fryum', + 'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3', + 'pcb4', 'pipe_fryum', + ] + + def __init__(self, root='data/visa'): + self.root = root + self.meta_path = f'{root}/meta.json' + self.phases = ['train', 'test'] + self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0) + + def run(self): + columns = self.csv_data.columns # [object, split, label, image, mask] + info = {phase: {} for phase in self.phases} + anomaly_samples = 0 + normal_samples = 0 + for cls_name in self.CLSNAMES: + cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name] + for phase in self.phases: + cls_info = [] + cls_data_phase = cls_data[cls_data[columns[1]] == phase] + cls_data_phase.index = list(range(len(cls_data_phase))) + for idx in range(cls_data_phase.shape[0]): + data = cls_data_phase.loc[idx] + is_abnormal = True if data[2] == 'anomaly' else False + info_img = dict( + img_path=data[3], + mask_path=data[4] if is_abnormal else '', + cls_name=cls_name, + specie_name='', + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + if phase == 'test': + if is_abnormal: + anomaly_samples = anomaly_samples + 1 + else: + normal_samples = normal_samples + 1 + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) + + +if __name__ == '__main__': + runner = VisASolver(root='/remote-home/iot_zhouqihang/data/Visa') + runner.run() diff --git a/logger.py b/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..6f992d0a1fd3206d0c1071490bd2b98fd152f9a8 --- /dev/null +++ b/logger.py @@ -0,0 +1,25 @@ + +import logging +import os + +def get_logger(save_path): + if not os.path.exists(save_path): + os.makedirs(save_path) + + txt_path = os.path.join(save_path, 'log.txt') + # logger + root_logger = logging.getLogger() + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + root_logger.setLevel(logging.WARNING) + logger = logging.getLogger('test') + formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', + datefmt='%y-%m-%d %H:%M:%S') + logger.setLevel(logging.INFO) + file_handler = logging.FileHandler(txt_path, mode='a') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + return logger \ No newline at end of file diff --git a/loss.py b/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..26131610c079e355e1b26941d053df8da19db542 --- /dev/null +++ b/loss.py @@ -0,0 +1,125 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from math import exp + +class FocalLoss(nn.Module): + """ + copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py + This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in + 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' + Focal_Loss= -1*alpha*(1-pt)*log(pt) + :param alpha: (tensor) 3D or 4D the scalar factor for this criterion + :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more + focus on hard misclassified example + :param smooth: (float,double) smooth value when cross entropy + :param balance_index: (int) balance class index, should be specific when alpha is float + :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. + """ + + def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): + super(FocalLoss, self).__init__() + self.apply_nonlin = apply_nonlin + self.alpha = alpha + self.gamma = gamma + self.balance_index = balance_index + self.smooth = smooth + self.size_average = size_average + + if self.smooth is not None: + if self.smooth < 0 or self.smooth > 1.0: + raise ValueError('smooth value should be in [0,1]') + + def forward(self, logit, target): + if self.apply_nonlin is not None: + logit = self.apply_nonlin(logit) + num_class = logit.shape[1] + + if logit.dim() > 2: + # N,C,d1,d2 -> N,C,m (m=d1*d2*...) + logit = logit.view(logit.size(0), logit.size(1), -1) + logit = logit.permute(0, 2, 1).contiguous() + logit = logit.view(-1, logit.size(-1)) + target = torch.squeeze(target, 1) + target = target.view(-1, 1) + alpha = self.alpha + + if alpha is None: + alpha = torch.ones(num_class, 1) + elif isinstance(alpha, (list, np.ndarray)): + assert len(alpha) == num_class + alpha = torch.FloatTensor(alpha).view(num_class, 1) + alpha = alpha / alpha.sum() + elif isinstance(alpha, float): + alpha = torch.ones(num_class, 1) + alpha = alpha * (1 - self.alpha) + alpha[self.balance_index] = self.alpha + + else: + raise TypeError('Not support alpha type') + + if alpha.device != logit.device: + alpha = alpha.to(logit.device) + + idx = target.cpu().long() + + one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() + one_hot_key = one_hot_key.scatter_(1, idx, 1) + if one_hot_key.device != logit.device: + one_hot_key = one_hot_key.to(logit.device) + + if self.smooth: + one_hot_key = torch.clamp( + one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) + pt = (one_hot_key * logit).sum(1) + self.smooth + logpt = pt.log() + + gamma = self.gamma + + alpha = alpha[idx] + alpha = torch.squeeze(alpha) + loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt + + if self.size_average: + loss = loss.mean() + return loss + + +class BinaryDiceLoss(nn.Module): + def __init__(self): + super(BinaryDiceLoss, self).__init__() + + def forward(self, input, targets): + # 获取每个批次的大小 N + N = targets.size()[0] + # 平滑变量 + smooth = 1 + # 将宽高 reshape 到同一纬度 + input_flat = input.view(N, -1) + targets_flat = targets.view(N, -1) + + intersection = input_flat * targets_flat + N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth) + # 计算一个批次中平均每张图的损失 + loss = 1 - N_dice_eff.sum() / N + return loss + +def smooth(arr, lamda1): + new_array = arr + arr2 = torch.zeros_like(arr) + arr2[:, :-1, :] = arr[:, 1:, :] + arr2[:, -1, :] = arr[:, -1, :] + + new_array2 = torch.zeros_like(new_array) + new_array2[:, :, :-1] = new_array[:, :, 1:] + new_array2[:, :, -1] = new_array[:, :, -1] + loss = (torch.sum((arr2 - arr) ** 2) + torch.sum((new_array2 - new_array) ** 2)) / 2 + return lamda1 * loss + +def sparsity(arr, target, lamda2): + if target == 0: + loss = torch.mean(torch.norm(arr, dim=0)) + else: + loss = torch.mean(torch.norm(1-arr, dim=0)) + return lamda2 * loss \ No newline at end of file diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a77c74c3da6cdd01352ab967930825923d1c3b --- /dev/null +++ b/metrics.py @@ -0,0 +1,60 @@ +from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise +import numpy as np +from skimage import measure + +def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3): + # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py + binary_amaps = np.zeros_like(amaps, dtype=bool) + min_th, max_th = amaps.min(), amaps.max() + delta = (max_th - min_th) / max_step + pros, fprs, ths = [], [], [] + for th in np.arange(min_th, max_th, delta): + binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1 + pro = [] + for binary_amap, mask in zip(binary_amaps, masks): + for region in measure.regionprops(measure.label(mask)): + tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum() + pro.append(tp_pixels / region.area) + inverse_masks = 1 - masks + fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum() + fpr = fp_pixels / inverse_masks.sum() + pros.append(np.array(pro).mean()) + fprs.append(fpr) + ths.append(th) + pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths) + idxes = fprs < expect_fpr + fprs = fprs[idxes] + fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min()) + pro_auc = auc(fprs, pros[idxes]) + return pro_auc + + +def image_level_metrics(results, obj, metric): + gt = results[obj]['gt_sp'] + pr = results[obj]['pr_sp'] + gt = np.array(gt) + pr = np.array(pr) + if metric == 'image-auroc': + performance = roc_auc_score(gt, pr) + elif metric == 'image-ap': + performance = average_precision_score(gt, pr) + + return performance + # table.append(str(np.round(performance * 100, decimals=1))) + + +def pixel_level_metrics(results, obj, metric): + gt = results[obj]['imgs_masks'] + pr = results[obj]['anomaly_maps'] + gt = np.array(gt) + pr = np.array(pr) + if metric == 'pixel-auroc': + performance = roc_auc_score(gt.ravel(), pr.ravel()) + elif metric == 'pixel-aupro': + if len(gt.shape) == 4: + gt = gt.squeeze(1) + if len(pr.shape) == 4: + pr = pr.squeeze(1) + performance = cal_pro_score(gt, pr) + return performance + \ No newline at end of file diff --git a/prompt_ensemble.py b/prompt_ensemble.py new file mode 100644 index 0000000000000000000000000000000000000000..564772c3ec87975e41dec91c09626b3195f52f76 --- /dev/null +++ b/prompt_ensemble.py @@ -0,0 +1,264 @@ +import os +from typing import Union, List +from pkg_resources import packaging +import torch +import numpy as np +from AnomalyCLIP_lib.simple_tokenizer import SimpleTokenizer as _Tokenizer +# from open_clip import tokenizer +# simple_tokenizer = tokenizer.SimpleTokenizer() +from copy import deepcopy +import torch.nn as nn + +_tokenizer = _Tokenizer() + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + +def encode_text_with_prompt_ensemble(model, texts, device): + prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage'] + prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage'] + prompt_state = [prompt_normal, prompt_abnormal] + prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.'] + + text_features = [] + for i in range(len(prompt_state)): + prompted_state = [state.format(texts[0]) for state in prompt_state[i]] + prompted_sentence = [] + for s in prompted_state: + for template in prompt_templates: + prompted_sentence.append(template.format(s)) + prompted_sentence = tokenize(prompted_sentence) + class_embeddings = model.encode_text(prompted_sentence.to(device)) + class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) + class_embedding = class_embeddings.mean(dim=0) + class_embedding /= class_embedding.norm() + text_features.append(class_embedding) + + text_features = torch.stack(text_features, dim=1).to(device).t() + + return text_features + + + +def _get_clones(module, N): + return nn.ModuleList([deepcopy(module) for i in range(N)]) +class AnomalyCLIP_PromptLearner(nn.Module): + def __init__(self, clip_model, design_details): + super().__init__() + classnames = ["object"] + self.n_cls = len(classnames) + self.n_ctx = design_details["Prompt_length"] + n_ctx_pos = self.n_ctx + n_ctx_neg = self.n_ctx + self.text_encoder_n_ctx = design_details["learnabel_text_embedding_length"] + ctx_init_pos = "" + ctx_init_neg = "" + dtype = clip_model.transformer.get_cast_dtype() + + ctx_dim = clip_model.ln_final.weight.shape[0] + + + self.classnames = classnames + + self.state_normal_list = [ + "{}", + ] + + self.state_anomaly_list = [ + "damaged {}", + ] + + normal_num = len(self.state_normal_list) + anormaly_num = len(self.state_anomaly_list) + self.normal_num = normal_num + self.anormaly_num = anormaly_num + + if ctx_init_pos and ctx_init_neg: + # use given words to initialize context vectors + ctx_init_pos = ctx_init_pos.replace("_", " ") + ctx_init_neg = ctx_init_neg.replace("_", " ") + n_ctx_pos = len(ctx_init_pos.split(" ")) + n_ctx_neg = len(ctx_init_neg.split(" ")) + #初始化text成bpd编码 + prompt_pos = tokenize(ctx_init_pos) + prompt_neg = tokenize(ctx_init_neg) + with torch.no_grad(): + #生成相应的text embedding + embedding_pos = clip_model.token_embedding(prompt_pos).type(dtype) + embedding_neg = clip_model.token_embedding(prompt_neg).type(dtype) + #这些是去除出来EOS 和 # CLS, EOS, 获得可学习的textual prompt + ctx_vectors_pos = embedding_pos[0, 1: 1 + n_ctx_pos, :] + ctx_vectors_neg = embedding_neg[0, 1: 1 + n_ctx_neg, :] + prompt_prefix_pos = ctx_init_pos + prompt_prefix_neg = ctx_init_neg + if True: + ctx_vectors_pos_ = [] + ctx_vectors_neg_ = [] + for _ in range(self.n_cls): + ctx_vectors_pos_.append(deepcopy(ctx_vectors_pos)) + ctx_vectors_neg_.append(deepcopy(ctx_vectors_neg)) + ctx_vectors_pos = torch.stack(ctx_vectors_pos_, dim=0) + ctx_vectors_neg = torch.stack(ctx_vectors_neg_, dim=0) + + else: + # Random Initialization + if True: + print("Initializing class-specific contexts") + #这里是cls是类的个数,n_ctx_pos代表learnable token的长度,ctx_dim表示prompt的dimension + ctx_vectors_pos = torch.empty(self.n_cls, self.normal_num, n_ctx_pos, ctx_dim, dtype=dtype) + ctx_vectors_neg = torch.empty(self.n_cls, self.anormaly_num, n_ctx_neg, ctx_dim, dtype=dtype) + else: + print("Initializing a generic context") + ctx_vectors_pos = torch.empty(n_ctx_pos, ctx_dim, dtype=dtype) + ctx_vectors_neg = torch.empty(n_ctx_neg, ctx_dim, dtype=dtype) + nn.init.normal_(ctx_vectors_pos, std=0.02) + nn.init.normal_(ctx_vectors_neg, std=0.02) + prompt_prefix_pos = " ".join(["X"] * n_ctx_pos) + prompt_prefix_neg = " ".join(["X"] * n_ctx_neg) + self.compound_prompts_depth = design_details["learnabel_text_embedding_depth"] + self.compound_prompts_text = nn.ParameterList([nn.Parameter(torch.empty(self.text_encoder_n_ctx, ctx_dim)) + for _ in range(self.compound_prompts_depth - 1)]) + for single_para in self.compound_prompts_text: + print("single_para", single_para.shape) + nn.init.normal_(single_para, std=0.02) + + single_layer = nn.Linear(ctx_dim, 896) + self.compound_prompt_projections = _get_clones(single_layer, self.compound_prompts_depth - 1) + + + self.ctx_pos = nn.Parameter(ctx_vectors_pos) # to be optimized + self.ctx_neg = nn.Parameter(ctx_vectors_neg) # to be optimized + + classnames = [name.replace("_", " ") for name in classnames] + name_lens = [len(_tokenizer.encode(name)) for name in classnames] + + + prompts_pos = [prompt_prefix_pos + " " + template.format(name)+ "." for template in self.state_normal_list for name in classnames] + prompts_neg = [prompt_prefix_neg + " " + template.format(name)+ "." for template in self.state_anomaly_list for name in classnames] + + tokenized_prompts_pos = [] + tokenized_prompts_neg = [] + + for p_pos in prompts_pos: + tokenized_prompts_pos.append(tokenize(p_pos)) + for p_neg in prompts_neg: + tokenized_prompts_neg.append(tokenize(p_neg)) + tokenized_prompts_pos = torch.cat(tokenized_prompts_pos) + tokenized_prompts_neg = torch.cat(tokenized_prompts_neg) + #生成相应的text embedding + with torch.no_grad(): + embedding_pos = clip_model.token_embedding(tokenized_prompts_pos).type(dtype) + embedding_neg = clip_model.token_embedding(tokenized_prompts_neg).type(dtype) + n, l, d = embedding_pos.shape + print("embedding_pos", embedding_pos.shape) + embedding_pos = embedding_pos.reshape(normal_num, self.n_cls, l, d).permute(1, 0, 2, 3) + embedding_neg = embedding_neg.reshape(anormaly_num, self.n_cls, l, d).permute(1, 0, 2, 3) + + + self.register_buffer("token_prefix_pos", embedding_pos[:, :, :1, :] ) + self.register_buffer("token_suffix_pos", embedding_pos[:, :,1 + n_ctx_pos:, :]) + self.register_buffer("token_prefix_neg", embedding_neg[:,:, :1, :]) + self.register_buffer("token_suffix_neg", embedding_neg[:, :, 1 + n_ctx_neg:, :]) + + n, d = tokenized_prompts_pos.shape + tokenized_prompts_pos = tokenized_prompts_pos.reshape(normal_num, self.n_cls, d).permute(1, 0, 2) + + n, d = tokenized_prompts_neg.shape + tokenized_prompts_neg = tokenized_prompts_neg.reshape(anormaly_num, self.n_cls, d).permute(1, 0, 2) + + self.n_ctx_pos = n_ctx_pos + self.n_ctx_neg = n_ctx_neg + # tokenized_prompts = torch.cat([tokenized_prompts_pos, tokenized_prompts_neg], dim=0) # torch.Tensor + self.register_buffer("tokenized_prompts_pos", tokenized_prompts_pos) + self.register_buffer("tokenized_prompts_neg", tokenized_prompts_neg) + print("tokenized_prompts shape", self.tokenized_prompts_pos.shape, self.tokenized_prompts_neg.shape) + + + + def forward(self, cls_id =None): + + ctx_pos = self.ctx_pos + ctx_neg = self.ctx_neg + ctx_pos = self.ctx_pos + ctx_neg = self.ctx_neg + # print("shape", self.ctx_pos[0:1].shape, ctx_pos.shape) + prefix_pos = self.token_prefix_pos + prefix_neg = self.token_prefix_neg + suffix_pos = self.token_suffix_pos + suffix_neg = self.token_suffix_neg + + # print(prefix_pos.shape, prefix_neg.shape) + + prompts_pos = torch.cat( + [ + # N(the number of template), 1, dim + prefix_pos, # (n_cls, 1, dim) + ctx_pos, # (n_cls, n_ctx, dim) + suffix_pos, # (n_cls, *, dim) + ], + dim=2, + ) + + prompts_neg = torch.cat( + [ + prefix_neg, # (n_cls, 1, dim) + ctx_neg, # (n_cls, n_ctx, dim) + suffix_neg, # (n_cls, *, dim) + ], + dim=2, + ) + _, _, l, d = prompts_pos.shape + prompts_pos = prompts_pos.reshape(-1, l, d) + _, _, l, d = prompts_neg.shape + prompts_neg = prompts_neg.reshape(-1, l, d) + prompts = torch.cat([prompts_pos, prompts_neg], dim=0) + + + _, l, d = self.tokenized_prompts_pos.shape + tokenized_prompts_pos = self.tokenized_prompts_pos.reshape(-1, d) + _, l, d = self.tokenized_prompts_neg.shape + tokenized_prompts_neg = self.tokenized_prompts_neg.reshape(-1, d) + tokenized_prompts = torch.cat((tokenized_prompts_pos, tokenized_prompts_neg), dim = 0) + + + return prompts, tokenized_prompts, self.compound_prompts_text \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..18dffcaaec792523a85f18a91755036fcea58235 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +scikit-image==0.20.0 +scikit-learn==1.2.2 +scipy==1.9.1 +seaborn==0.12.2 +timm==0.6.13 +torch>=2.0.0 +torchsummary==1.5.1 +torchvision>=0.15.1 +tqdm==4.65.0 +dash-table==5.0.0 +thop==0.1.1.post2209072238 +ftfy==6.3.1 diff --git a/results/9_12_4_multiscale/zero_shot/log.txt b/results/9_12_4_multiscale/zero_shot/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..08cbbea84ad7143e31b27ea0787546d35c4fa32a --- /dev/null +++ b/results/9_12_4_multiscale/zero_shot/log.txt @@ -0,0 +1,19 @@ +24-12-18 12:08:53.440 - INFO: +| objects | pixel_auroc | pixel_aupro | image_auroc | image_ap | +|:-----------|--------------:|--------------:|--------------:|-----------:| +| carpet | 98.8 | 90 | 100 | 100 | +| bottle | 90.4 | 80.8 | 88.7 | 96.8 | +| hazelnut | 97.2 | 92.5 | 97.2 | 98.5 | +| leather | 98.6 | 92.2 | 99.8 | 99.9 | +| cable | 78.9 | 64 | 70.3 | 81.7 | +| capsule | 95.8 | 87.6 | 89.5 | 97.8 | +| grid | 97.3 | 75.4 | 97.8 | 99.3 | +| pill | 91.8 | 88.1 | 81.1 | 95.3 | +| transistor | 70.8 | 58.2 | 93.9 | 92.1 | +| metal_nut | 74.6 | 71.1 | 92.4 | 98.2 | +| screw | 97.5 | 88 | 82.1 | 92.9 | +| toothbrush | 91.9 | 88.5 | 85.3 | 93.9 | +| zipper | 91.3 | 65.4 | 98.4 | 99.5 | +| tile | 94.7 | 87.4 | 100 | 100 | +| wood | 96.4 | 91.5 | 96.9 | 99.2 | +| mean | 91.1 | 81.4 | 91.6 | 96.4 | diff --git a/results/9_12_4_multiscale_visa/zero_shot/log.txt b/results/9_12_4_multiscale_visa/zero_shot/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..19463a2a6c912894d88f6116f93aa46079e5c878 --- /dev/null +++ b/results/9_12_4_multiscale_visa/zero_shot/log.txt @@ -0,0 +1,16 @@ +24-12-18 12:19:01.770 - INFO: +| objects | pixel_auroc | pixel_aupro | image_auroc | image_ap | +|:-----------|--------------:|--------------:|--------------:|-----------:| +| candle | 98.8 | 96.5 | 80.9 | 82.6 | +| capsules | 95 | 78.9 | 82.8 | 89.4 | +| cashew | 93.8 | 91.9 | 76 | 89.3 | +| chewinggum | 99.3 | 90.9 | 97.2 | 98.8 | +| fryum | 94.6 | 86.9 | 92.7 | 96.6 | +| macaroni1 | 98.3 | 89.8 | 86.7 | 85.5 | +| macaroni2 | 97.6 | 84 | 72.2 | 70.8 | +| pcb1 | 94 | 80.7 | 85.2 | 86.7 | +| pcb2 | 92.4 | 78.9 | 62 | 64.4 | +| pcb3 | 88.4 | 76.8 | 61.7 | 69.4 | +| pcb4 | 95.7 | 89.4 | 93.9 | 94.3 | +| pipe_fryum | 98.2 | 96.2 | 92.3 | 96.3 | +| mean | 95.5 | 86.7 | 82 | 85.3 | diff --git a/run.sh b/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..9d19777257799c47ea430b4a4e99792299c9a3d3 --- /dev/null +++ b/run.sh @@ -0,0 +1,5 @@ +# You must complete this file for your proposed method. + +python3 generate_dataset_json/SDD.py data + +bash test_data.sh diff --git a/runner.py b/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..139e9bf515e9d55ddcfd2b426471d2b967e7027f --- /dev/null +++ b/runner.py @@ -0,0 +1,37 @@ +# ----------------------------------------------------------------------------- +# This Python script is the primary entry point called by our judge. It runs +# your code to generate anomaly scores, then evaluates those scores to produce +# the final results. +# ----------------------------------------------------------------------------- + +import subprocess + +# Step 1: Generate anomaly scores +subprocess.run(["./run.sh"], check=True) + +# Step 2: Evaluate the generated scores +subprocess.run( + [ + "python3", + "evaluation/eval_main.py", + "--device", + "0", + "--data_path", + "./data/", + "--dataset_name", + "rayan_dataset", + "--class_name", + "all", + "--output_dir", + "./output", + "--output_scores_dir", + "./output_scores", + "--save_csv", + "True", + "--save_json", + "True", + "--class_name_mapping_dir", + "./evaluation/class_name_mapping.json", + ], + check=True, +) diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..8c37a5cbddcb3a57c15c1622a4389f2f9c54e92c --- /dev/null +++ b/test.py @@ -0,0 +1,146 @@ +import AnomalyCLIP_lib +import torch +import argparse +import torch.nn.functional as F +from prompt_ensemble import AnomalyCLIP_PromptLearner +from loss import FocalLoss, BinaryDiceLoss +from utils import normalize +from dataset import Dataset +from logger import get_logger +from tqdm import tqdm + +import os +import random +import numpy as np +from tabulate import tabulate +from utils import get_transform +from evaluation.utils.json_helpers import dict_to_json + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +from visualization import visualizer + +from metrics import image_level_metrics, pixel_level_metrics +from tqdm import tqdm +from scipy.ndimage import gaussian_filter +def test(args): + img_size = args.image_size + features_list = args.features_list + dataset_dir = args.data_path + save_path = args.save_path + dataset_name = args.dataset + + logger = get_logger(args.save_path) + device = "cuda" if torch.cuda.is_available() else "cpu" + + AnomalyCLIP_parameters = {"Prompt_length": args.n_ctx, "learnabel_text_embedding_depth": args.depth, "learnabel_text_embedding_length": args.t_n_ctx} + + model, _ = AnomalyCLIP_lib.load("ViT-L-14-336px.pt", device=device, design_details = AnomalyCLIP_parameters) + model.eval() + + preprocess, target_transform = get_transform(args) + test_data = Dataset(root=args.data_path, transform=preprocess, target_transform=target_transform, dataset_name = args.dataset) + test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) + obj_list = test_data.obj_list + + + results = {} + metrics = {} + for obj in obj_list: + results[obj] = {} + results[obj]['gt_sp'] = [] + results[obj]['pr_sp'] = [] + results[obj]['imgs_masks'] = [] + results[obj]['anomaly_maps'] = [] + metrics[obj] = {} + metrics[obj]['pixel-auroc'] = 0 + metrics[obj]['pixel-aupro'] = 0 + metrics[obj]['image-auroc'] = 0 + metrics[obj]['image-ap'] = 0 + + prompt_learner = AnomalyCLIP_PromptLearner(model.to("cpu"), AnomalyCLIP_parameters) + checkpoint = torch.load(args.checkpoint_path) + prompt_learner.load_state_dict(checkpoint["prompt_learner"]) + prompt_learner.to(device) + model.to(device) + model.visual.DAPM_replace(DPAM_layer = 20) + + prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id = None) + text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float() + text_features = torch.stack(torch.chunk(text_features, dim = 0, chunks = 2), dim = 1) + text_features = text_features/text_features.norm(dim=-1, keepdim=True) + + + model.to(device) + for idx, items in enumerate(tqdm(test_dataloader)): + image = items['img'].to(device) + cls_name = items['cls_name'] + cls_id = items['cls_id'] + gt_mask = items['img_mask'] + gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0 + results[cls_name[0]]['imgs_masks'].append(gt_mask) # px + results[cls_name[0]]['gt_sp'].extend(items['anomaly'].detach().cpu()) + + with torch.no_grad(): + image_features, patch_features = model.encode_image(image, features_list, DPAM_layer = 20) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + + text_probs = image_features @ text_features.permute(0, 2, 1) + text_probs = (text_probs/0.07).softmax(-1) + text_probs = text_probs[:, 0, 1] + anomaly_map_list = [] + for idx, patch_feature in enumerate(patch_features): + if idx >= args.feature_map_layer[0]: + patch_feature = patch_feature/ patch_feature.norm(dim = -1, keepdim = True) + similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0]) + similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], args.image_size) + anomaly_map = (similarity_map[...,1] + 1 - similarity_map[...,0])/2.0 + anomaly_map_list.append(anomaly_map) + + anomaly_map = torch.stack(anomaly_map_list) + + anomaly_map = anomaly_map.sum(dim = 0) + results[cls_name[0]]['pr_sp'].extend(text_probs.detach().cpu()) + anomaly_map = torch.stack([torch.from_numpy(gaussian_filter(i, sigma = args.sigma)) for i in anomaly_map.detach().cpu()], dim = 0 ) + results[cls_name[0]]['anomaly_maps'].append(anomaly_map) + + # save + new_path = items['img_path'][0].replace(dataset_dir, './output_scores').replace('.png', '_scores.json') + interp_anomaly_map = torch.nn.functional.interpolate(anomaly_map[None], size=(224, 224), mode='bicubic') + dic = { + "img_level_score": text_probs.item(), + "pixel_level_score": interp_anomaly_map.squeeze(0).cpu().detach().numpy() + } + os.makedirs(os.path.dirname(new_path), exist_ok=True) + dict_to_json(dic, new_path) + # visualizer(items['img_path'], anomaly_map.detach().cpu().numpy(), args.image_size, args.save_path, cls_name) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True) + # paths + parser.add_argument("--data_path", type=str, default="./data/visa", help="path to test dataset") + parser.add_argument("--save_path", type=str, default='./results/', help='path to save results') + parser.add_argument("--checkpoint_path", type=str, default='./checkpoint/', help='path to checkpoint') + # model + parser.add_argument("--dataset", type=str, default='mvtec') + parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used") + parser.add_argument("--image_size", type=int, default=518, help="image size") + parser.add_argument("--depth", type=int, default=9, help="image size") + parser.add_argument("--n_ctx", type=int, default=12, help="zero shot") + parser.add_argument("--t_n_ctx", type=int, default=4, help="zero shot") + parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="zero shot") + parser.add_argument("--metrics", type=str, default='image-pixel-level') + parser.add_argument("--seed", type=int, default=111, help="random seed") + parser.add_argument("--sigma", type=int, default=4, help="zero shot") + + args = parser.parse_args() + print(args) + setup_seed(args.seed) + test(args) diff --git a/test.sh b/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..36184014f24a6d3c9cbdd97defc23e39c398724b --- /dev/null +++ b/test.sh @@ -0,0 +1,20 @@ + +device=0 + +LOG=${save_dir}"res.log" +echo ${LOG} +depth=(9) +n_ctx=(12) +t_n_ctx=(4) +for i in "${!depth[@]}";do + for j in "${!n_ctx[@]}";do + ## train on the VisA dataset + base_dir=${depth[i]}_${n_ctx[j]}_${t_n_ctx[0]}_multiscale + save_dir=./checkpoints/${base_dir}/ + CUDA_VISIBLE_DEVICES=${device} python test.py --dataset mvtec \ + --data_path /remote-home/iot_zhouqihang/data/mvdataset --save_path ./results/${base_dir}/zero_shot \ + --checkpoint_path ${save_dir}epoch_15.pth \ + --features_list 6 12 18 24 --image_size 518 --depth ${depth[i]} --n_ctx ${n_ctx[j]} --t_n_ctx ${t_n_ctx[0]} + wait + done +done diff --git a/test_ZSAD.sh b/test_ZSAD.sh new file mode 100644 index 0000000000000000000000000000000000000000..b157c704e8d0c6cdabb3c716279f452557ae6a29 --- /dev/null +++ b/test_ZSAD.sh @@ -0,0 +1,20 @@ + +device=0 + +LOG=${save_dir}"res.log" +echo ${LOG} +depth=(9) +n_ctx=(12) +t_n_ctx=(4) +for i in "${!depth[@]}";do + for j in "${!n_ctx[@]}";do + ## train on the VisA dataset + base_dir=${depth[i]}_${n_ctx[j]}_${t_n_ctx[0]}_multiscale + save_dir=./checkpoints/${base_dir}/ + CUDA_VISIBLE_DEVICES=${device} python test.py --dataset SDD \ + --data_path ./ZSAD-dataset --save_path ./results/${base_dir}/zero_shot \ + --checkpoint_path ${save_dir}epoch_15.pth \ + --features_list 6 12 18 24 --image_size 518 --depth ${depth[i]} --n_ctx ${n_ctx[j]} --t_n_ctx ${t_n_ctx[0]} + wait + done +done diff --git a/test_data.sh b/test_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..c9888ec97665517e7ef2abfb17f0332140fc6431 --- /dev/null +++ b/test_data.sh @@ -0,0 +1,19 @@ +device=0 + +LOG=${save_dir}"res.log" +echo ${LOG} +depth=(9) +n_ctx=(12) +t_n_ctx=(4) +for i in "${!depth[@]}";do + for j in "${!n_ctx[@]}";do + ## train on the VisA dataset + base_dir=${depth[i]}_${n_ctx[j]}_${t_n_ctx[0]}_multiscale + save_dir=./checkpoints/${base_dir}/ + CUDA_VISIBLE_DEVICES=${device} python test.py --dataset SDD \ + --data_path ./data --save_path ./results/${base_dir}/zero_shot \ + --checkpoint_path ${save_dir}epoch_15.pth \ + --features_list 6 12 18 24 --image_size 518 --depth ${depth[i]} --n_ctx ${n_ctx[j]} --t_n_ctx ${t_n_ctx[0]} + wait + done +done \ No newline at end of file diff --git a/test_one_example.py b/test_one_example.py new file mode 100644 index 0000000000000000000000000000000000000000..6f537f13207d68dd37f2bc7c3d6a532211bcc550 --- /dev/null +++ b/test_one_example.py @@ -0,0 +1,120 @@ +import AnomalyCLIP_lib +import torch +import argparse +import torch.nn.functional as F +from prompt_ensemble import AnomalyCLIP_PromptLearner +from PIL import Image + +import os +import random +import numpy as np +from utils import get_transform, normalize + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +# from visualization import visualizer +import cv2 + + +def apply_ad_scoremap(image, scoremap, alpha=0.5): + np_image = np.asarray(image, dtype=float) + scoremap = (scoremap * 255).astype(np.uint8) + scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET) + scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB) + return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8) + +def visualizer(path, anomaly_map, img_size): + filename = os.path.basename(path) + dirname = os.path.dirname(path) + vis = cv2.cvtColor(cv2.resize(cv2.imread(path), (img_size, img_size)), cv2.COLOR_BGR2RGB) # RGB + mask = normalize(anomaly_map[0]) + vis = apply_ad_scoremap(vis, mask) + vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR) # BGR + save_vis = os.path.join(dirname, f'anomaly_map_{filename}') + print(save_vis) + cv2.imwrite(save_vis, vis) + +from scipy.ndimage import gaussian_filter +def test(args): + img_size = args.image_size + features_list = args.features_list + image_path = args.image_path + + device = "cuda" if torch.cuda.is_available() else "cpu" + + AnomalyCLIP_parameters = {"Prompt_length": args.n_ctx, "learnabel_text_embedding_depth": args.depth, "learnabel_text_embedding_length": args.t_n_ctx} + + model, _ = AnomalyCLIP_lib.load("ViT-L/14@336px", device=device, design_details = AnomalyCLIP_parameters) + model.eval() + + preprocess, target_transform = get_transform(args) + + + prompt_learner = AnomalyCLIP_PromptLearner(model.to("cpu"), AnomalyCLIP_parameters) + checkpoint = torch.load(args.checkpoint_path) + prompt_learner.load_state_dict(checkpoint["prompt_learner"]) + prompt_learner.to(device) + model.to(device) + model.visual.DAPM_replace(DPAM_layer = 20) + + prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id = None) + text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float() + text_features = torch.stack(torch.chunk(text_features, dim = 0, chunks = 2), dim = 1) + text_features = text_features/text_features.norm(dim=-1, keepdim=True) + + img = Image.open(image_path) + img = preprocess(img) + + print("img", img.shape) + image = img.reshape(1, 3, img_size, img_size).to(device) + + with torch.no_grad(): + image_features, patch_features = model.encode_image(image, features_list, DPAM_layer = 20) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + + text_probs = image_features @ text_features.permute(0, 2, 1) + text_probs = (text_probs/0.07).softmax(-1) + text_probs = text_probs[:, 0, 1] + anomaly_map_list = [] + for idx, patch_feature in enumerate(patch_features): + if idx >= args.feature_map_layer[0]: + patch_feature = patch_feature/ patch_feature.norm(dim = -1, keepdim = True) + similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0]) + similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], args.image_size) + anomaly_map = (similarity_map[...,1] + 1 - similarity_map[...,0])/2.0 + anomaly_map_list.append(anomaly_map) + + anomaly_map = torch.stack(anomaly_map_list) + + anomaly_map = anomaly_map.sum(dim = 0) + + anomaly_map = torch.stack([torch.from_numpy(gaussian_filter(i, sigma = args.sigma)) for i in anomaly_map.detach().cpu()], dim = 0 ) + + visualizer(image_path, anomaly_map.detach().cpu().numpy(), args.image_size) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True) + # paths + parser.add_argument("--image_path", type=str, default="./data/visa", help="path to test dataset") + parser.add_argument("--checkpoint_path", type=str, default='./checkpoint/', help='path to checkpoint') + # model + parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used") + parser.add_argument("--image_size", type=int, default=518, help="image size") + parser.add_argument("--depth", type=int, default=9, help="image size") + parser.add_argument("--n_ctx", type=int, default=12, help="zero shot") + parser.add_argument("--t_n_ctx", type=int, default=4, help="zero shot") + parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="zero shot") + parser.add_argument("--seed", type=int, default=111, help="random seed") + parser.add_argument("--sigma", type=int, default=4, help="zero shot") + + args = parser.parse_args() + print(args) + setup_seed(args.seed) + test(args) diff --git a/test_one_example.sh b/test_one_example.sh new file mode 100644 index 0000000000000000000000000000000000000000..c72759c3b951125a8d22505afa879a78c241a512 --- /dev/null +++ b/test_one_example.sh @@ -0,0 +1,59 @@ + +device=0 + +# LOG=${save_dir}"res.log" +# echo ${LOG} +# depth=(9) +# n_ctx=(12) +# t_n_ctx=(4) +# for i in "${!depth[@]}";do +# for j in "${!n_ctx[@]}";do +# ## train on the VisA dataset +# base_dir=${depth[i]}_${n_ctx[j]}_${t_n_ctx[0]}_multiscale +# save_dir=./checkpoints_memory/${base_dir}/ +# CUDA_VISIBLE_DEVICES=${device} python test.py --dataset mvtec \ +# --data_path /root/data/mvdataset --save_path ./results/${base_dir}/zero_shot \ +# --checkpoint_path ${save_dir}epoch_15.pth \ +# --features_list 6 12 18 24 --image_size 518 --depth ${depth[i]} --n_ctx ${n_ctx[j]} --t_n_ctx ${t_n_ctx[0]} +# wait +# done +# done + + +# LOG=${save_dir}"res.log" +# echo ${LOG} +# depth=(9) +# n_ctx=(12) +# t_n_ctx=(4) +# for i in "${!depth[@]}";do +# for j in "${!n_ctx[@]}";do +# ## train on the VisA dataset +# base_dir=${depth[i]}_${n_ctx[j]}_${t_n_ctx[0]}_multiscale_visa +# save_dir=./checkpoints/${base_dir}/ +# CUDA_VISIBLE_DEVICES=${device} python test.py --dataset visa \ +# --data_path /remote-home/iot_zhouqihang/data/Visa --save_path ./results/${base_dir}/zero_shot \ +# --checkpoint_path ${save_dir}epoch_15.pth \ +# --features_list 6 12 18 24 --image_size 518 --depth ${depth[i]} --n_ctx ${n_ctx[j]} --t_n_ctx ${t_n_ctx[0]} +# wait +# done +# done + + +LOG=${save_dir}"res.log" +echo ${LOG} +depth=(9) +n_ctx=(12) +t_n_ctx=(4) +for i in "${!depth[@]}";do + for j in "${!n_ctx[@]}";do + ## train on the VisA dataset + base_dir=${depth[i]}_${n_ctx[j]}_${t_n_ctx[0]}_multiscale + save_dir=./checkpoints_mul/${base_dir}/ + CUDA_VISIBLE_DEVICES=${device} python test_one_example.py \ + --image_path /remote-home/iot_zhouqihang/old_root/data/mvdataset/carpet/test/hole/004.png \ + --checkpoint_path ${save_dir}epoch_15.pth \ + --features_list 6 12 18 24 --image_size 518 --depth ${depth[i]} --n_ctx ${n_ctx[j]} --t_n_ctx ${t_n_ctx[0]} + wait + done +done + diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7669852d1a507345508fe312bb8c9ac702a031ca --- /dev/null +++ b/train.py @@ -0,0 +1,137 @@ +import AnomalyCLIP_lib +import torch +import argparse +import torch.nn.functional as F +from prompt_ensemble import AnomalyCLIP_PromptLearner +from loss import FocalLoss, BinaryDiceLoss +from utils import normalize +from dataset import Dataset +from logger import get_logger +from tqdm import tqdm +import numpy as np +import os +import random +from utils import get_transform +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def train(args): + + logger = get_logger(args.save_path) + + preprocess, target_transform = get_transform(args) + device = "cuda" if torch.cuda.is_available() else "cpu" + + AnomalyCLIP_parameters = {"Prompt_length": args.n_ctx, "learnabel_text_embedding_depth": args.depth, "learnabel_text_embedding_length": args.t_n_ctx} + + model, _ = AnomalyCLIP_lib.load("ViT-L/14@336px", device=device, design_details = AnomalyCLIP_parameters) + model.eval() + + train_data = Dataset(root=args.train_data_path, transform=preprocess, target_transform=target_transform, dataset_name = args.dataset) + train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True) + + ########################################################################################## + prompt_learner = AnomalyCLIP_PromptLearner(model.to("cpu"), AnomalyCLIP_parameters) + prompt_learner.to(device) + model.to(device) + model.visual.DAPM_replace(DPAM_layer = 20) + ########################################################################################## + optimizer = torch.optim.Adam(list(prompt_learner.parameters()), lr=args.learning_rate, betas=(0.5, 0.999)) + + # losses + loss_focal = FocalLoss() + loss_dice = BinaryDiceLoss() + + + model.eval() + prompt_learner.train() + for epoch in tqdm(range(args.epoch)): + model.eval() + prompt_learner.train() + loss_list = [] + image_loss_list = [] + + for items in tqdm(train_dataloader): + image = items['img'].to(device) + label = items['anomaly'] + + gt = items['img_mask'].squeeze().to(device) + gt[gt > 0.5] = 1 + gt[gt <= 0.5] = 0 + + with torch.no_grad(): + # Apply DPAM to the layer from 6 to 24 + # DPAM_layer represents the number of layer refined by DPAM from top to bottom + # DPAM_layer = 1, no DPAM is used + # DPAM_layer = 20 as default + image_features, patch_features = model.encode_image(image, args.features_list, DPAM_layer = 20) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + + #################################### + prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id = None) + text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float() + text_features = torch.stack(torch.chunk(text_features, dim = 0, chunks = 2), dim = 1) + text_features = text_features/text_features.norm(dim=-1, keepdim=True) + # Apply DPAM surgery + text_probs = image_features.unsqueeze(1) @ text_features.permute(0, 2, 1) + text_probs = text_probs[:, 0, ...]/0.07 + image_loss = F.cross_entropy(text_probs.squeeze(), label.long().cuda()) + image_loss_list.append(image_loss.item()) + ######################################################################### + similarity_map_list = [] + # similarity_map_list.append(similarity_map) + for idx, patch_feature in enumerate(patch_features): + if idx >= args.feature_map_layer[0]: + patch_feature = patch_feature/ patch_feature.norm(dim = -1, keepdim = True) + similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0]) + similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], args.image_size).permute(0, 3, 1, 2) + similarity_map_list.append(similarity_map) + + loss = 0 + for i in range(len(similarity_map_list)): + loss += loss_focal(similarity_map_list[i], gt) + loss += loss_dice(similarity_map_list[i][:, 1, :, :], gt) + loss += loss_dice(similarity_map_list[i][:, 0, :, :], 1-gt) + + optimizer.zero_grad() + (loss+image_loss).backward() + optimizer.step() + loss_list.append(loss.item()) + # logs + if (epoch + 1) % args.print_freq == 0: + logger.info('epoch [{}/{}], loss:{:.4f}, image_loss:{:.4f}'.format(epoch + 1, args.epoch, np.mean(loss_list), np.mean(image_loss_list))) + + # save model + if (epoch + 1) % args.save_freq == 0: + ckp_path = os.path.join(args.save_path, 'epoch_' + str(epoch + 1) + '.pth') + torch.save({"prompt_learner": prompt_learner.state_dict()}, ckp_path) + +if __name__ == '__main__': + parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True) + parser.add_argument("--train_data_path", type=str, default="./data/visa", help="train dataset path") + parser.add_argument("--save_path", type=str, default='./checkpoint', help='path to save results') + + + parser.add_argument("--dataset", type=str, default='mvtec', help="train dataset name") + + parser.add_argument("--depth", type=int, default=9, help="image size") + parser.add_argument("--n_ctx", type=int, default=12, help="zero shot") + parser.add_argument("--t_n_ctx", type=int, default=4, help="zero shot") + parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="zero shot") + parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used") + + parser.add_argument("--epoch", type=int, default=15, help="epochs") + parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate") + parser.add_argument("--batch_size", type=int, default=8, help="batch size") + parser.add_argument("--image_size", type=int, default=518, help="image size") + parser.add_argument("--print_freq", type=int, default=1, help="print frequency") + parser.add_argument("--save_freq", type=int, default=1, help="save frequency") + parser.add_argument("--seed", type=int, default=111, help="random seed") + args = parser.parse_args() + setup_seed(args.seed) + train(args) diff --git a/train.sh b/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb0a7e4c472f306c50f757a4f34dd5188f5ef773 --- /dev/null +++ b/train.sh @@ -0,0 +1,39 @@ + +device=1 + +LOG=${save_dir}"res.log" +echo ${LOG} +depth=(9) +n_ctx=(12) +t_n_ctx=(4) +for i in "${!depth[@]}";do + for j in "${!n_ctx[@]}";do + ## train on the VisA dataset + base_dir=${depth[i]}_${n_ctx[j]}_${t_n_ctx[0]}_multiscale + save_dir=./checkpoints/${base_dir}/ + CUDA_VISIBLE_DEVICES=${device} python train.py --dataset visa --train_data_path /remote-home/iot_zhouqihang/data/Visa \ + --save_path ${save_dir} \ + --features_list 6 12 18 24 --image_size 518 --batch_size 8 --print_freq 1 \ + --epoch 15 --save_freq 1 --depth ${depth[i]} --n_ctx ${n_ctx[j]} --t_n_ctx ${t_n_ctx[0]} + wait + done +done + + +LOG=${save_dir}"res.log" +echo ${LOG} +depth=(9) +n_ctx=(12) +t_n_ctx=(4) +for i in "${!depth[@]}";do + for j in "${!n_ctx[@]}";do + ## train on the VisA dataset + base_dir=${depth[i]}_${n_ctx[j]}_${t_n_ctx[0]}_multiscale_visa + save_dir=./checkpoints/${base_dir}/ + CUDA_VISIBLE_DEVICES=${device} python train.py --dataset mvtec --train_data_path /remote-home/iot_zhouqihang/data/mvdataset \ + --save_path ${save_dir} \ + --features_list 6 12 18 24 --image_size 518 --batch_size 8 --print_freq 1 \ + --epoch 15 --save_freq 1 --depth ${depth[i]} --n_ctx ${n_ctx[j]} --t_n_ctx ${t_n_ctx[0]} + wait + done +done \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f28f094a9c689d2af0111f03c71aaf8d1ceb3ba1 --- /dev/null +++ b/utils.py @@ -0,0 +1,24 @@ +import torchvision.transforms as transforms +# from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode +from AnomalyCLIP_lib.transform import image_transform +from AnomalyCLIP_lib.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + + +def normalize(pred, max_value=None, min_value=None): + if max_value is None or min_value is None: + return (pred - pred.min()) / (pred.max() - pred.min()) + else: + return (pred - min_value) / (max_value - min_value) + +def get_transform(args): + preprocess = image_transform(args.image_size, is_train=False, mean = OPENAI_DATASET_MEAN, std = OPENAI_DATASET_STD) + target_transform = transforms.Compose([ + transforms.Resize((args.image_size, args.image_size)), + transforms.CenterCrop(args.image_size), + transforms.ToTensor() + ]) + preprocess.transforms[0] = transforms.Resize(size=(args.image_size, args.image_size), interpolation=transforms.InterpolationMode.BICUBIC, + max_size=None, antialias=None) + preprocess.transforms[1] = transforms.CenterCrop(size=(args.image_size, args.image_size)) + return preprocess, target_transform diff --git a/utils/dump_scores.py b/utils/dump_scores.py new file mode 100644 index 0000000000000000000000000000000000000000..52ffa91875b82bf01d83364afcd442e541f52633 --- /dev/null +++ b/utils/dump_scores.py @@ -0,0 +1,52 @@ +from tqdm import tqdm +import os +from pathlib import Path +import numpy as np + +from evaluation.utils.json_helpers import dict_to_json + + +# ----------------------------------------------------------------------------- +# This class is a sample code provided to help you with saving your predicted +# scores as JSON files. We strongly suggest that you use the provided methods, +# but you are NOT required to follow this structure. Feel free to adapt, +# modify, or extend this template in favor of your own workflow. +# ----------------------------------------------------------------------------- +class DumpScores: + def __init__(self): + self.scores_dir = "./output_scores" + self.save_scores_precision = 4 + + def save_scores(self, image_path_list, pred_img_level, pred_pix_level): + print( + f"Saving scores at '{self.scores_dir}' with precision: '{self.save_scores_precision}'" + ) + for i in tqdm(range(len(image_path_list)), desc=f"Saving scores"): + image_path = image_path_list[i] + image_score_path = self.get_scores_path_for_image(image_path) + os.makedirs(os.path.dirname(image_score_path), exist_ok=True) + + vectorized_enforce_precision = np.vectorize(self.enforce_precision) + d = { + "img_level_score": vectorized_enforce_precision( + pred_img_level[i], self.save_scores_precision + ), + "pix_level_score": vectorized_enforce_precision( + pred_pix_level[i], self.save_scores_precision + ), + } + dict_to_json(d, image_score_path) + + def get_scores_path_for_image(self, image_path): + """example image_path: './data/photovoltaic_module/test/good/037.png'""" + path = Path(image_path) + + category, split, anomaly_type = path.parts[-4:-1] + image_name = path.stem + + return os.path.join( + self.scores_dir, category, split, anomaly_type, f"{image_name}_scores.json" + ) + + def enforce_precision(self, x, precision): + return float(f"{x:.{precision}f}") diff --git a/visualization.py b/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..9cdb2dde1a49b4b8ebf0f0514e0d3d03dec477a4 --- /dev/null +++ b/visualization.py @@ -0,0 +1,24 @@ +import cv2 +import os +from utils import normalize +import numpy as np + +def visualizer(pathes, anomaly_map, img_size, save_path, cls_name): + for idx, path in enumerate(pathes): + cls = path.split('/')[-2] + filename = path.split('/')[-1] + vis = cv2.cvtColor(cv2.resize(cv2.imread(path), (img_size, img_size)), cv2.COLOR_BGR2RGB) # RGB + mask = normalize(anomaly_map[idx]) + vis = apply_ad_scoremap(vis, mask) + vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR) # BGR + save_vis = os.path.join(save_path, 'imgs', cls_name[idx], cls) + if not os.path.exists(save_vis): + os.makedirs(save_vis) + cv2.imwrite(os.path.join(save_vis, filename), vis) + +def apply_ad_scoremap(image, scoremap, alpha=0.5): + np_image = np.asarray(image, dtype=float) + scoremap = (scoremap * 255).astype(np.uint8) + scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET) + scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB) + return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8)