|
import AnomalyCLIP_lib |
|
import torch |
|
import argparse |
|
import torch.nn.functional as F |
|
from prompt_ensemble import AnomalyCLIP_PromptLearner |
|
from loss import FocalLoss, BinaryDiceLoss |
|
from utils import normalize |
|
from dataset import Dataset |
|
from logger import get_logger |
|
from tqdm import tqdm |
|
|
|
import os |
|
import random |
|
import numpy as np |
|
from tabulate import tabulate |
|
from utils import get_transform |
|
from evaluation.utils.json_helpers import dict_to_json |
|
|
|
def setup_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: value passed unchanged to ``torch.manual_seed``,
            ``torch.cuda.manual_seed_all``, ``np.random.seed`` and
            ``random.seed``, in that order.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels so repeated
    # runs with the same seed select the same algorithms.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
|
|
|
from visualization import visualizer |
|
|
|
from metrics import image_level_metrics, pixel_level_metrics |
|
from tqdm import tqdm |
|
from scipy.ndimage import gaussian_filter |
|
def _score_json_path(img_path, dataset_dir, out_root="./output_scores"):
    """Return the JSON score path under *out_root* mirroring *img_path*'s place in *dataset_dir*.

    The image's location relative to *dataset_dir* is preserved and its file
    extension, whatever it is (``.png``, ``.JPG``, ...), is replaced by
    ``_scores.json``. Unlike plain string replacement this is insensitive to a
    trailing slash on *dataset_dir* and to ``.png`` appearing elsewhere in
    the path.

    Raises:
        ValueError: if *img_path* does not lie inside *dataset_dir*; the score
            file would otherwise be written outside *out_root*.
    """
    rel_path = os.path.relpath(img_path, dataset_dir)
    if rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep):
        raise ValueError(
            f"image path {img_path!r} is not inside dataset directory {dataset_dir!r}"
        )
    stem, _ext = os.path.splitext(rel_path)
    return os.path.join(out_root, stem + "_scores.json")


def _pixel_anomaly_map(patch_features, text_features, first_layer, image_size, sigma):
    """Build the Gaussian-smoothed pixel-level anomaly map for one batch.

    Per-layer maps are computed for every entry of *patch_features* whose index
    is >= *first_layer*, summed, moved to CPU and smoothed with
    ``gaussian_filter(sigma=sigma)`` one batch element at a time.

    Raises:
        ValueError: if *first_layer* selects no feature layer (``torch.stack``
            would otherwise fail on an empty list with an opaque message).
    """
    layer_maps = []
    for layer_idx, patch_feature in enumerate(patch_features):
        if layer_idx < first_layer:
            continue
        patch_feature = patch_feature / patch_feature.norm(dim=-1, keepdim=True)
        similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0])
        # Token 0 is dropped before building the spatial map (presumably the
        # class token -- TODO confirm against AnomalyCLIP_lib).
        similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], image_size)
        # Average of "channel 1 score" and "1 - channel 0 score".
        layer_maps.append((similarity_map[..., 1] + 1 - similarity_map[..., 0]) / 2.0)

    if not layer_maps:
        raise ValueError(f"feature_map_layer[0]={first_layer} selects no feature layer")

    summed = torch.stack(layer_maps).sum(dim=0).detach().cpu()
    return torch.stack(
        [torch.from_numpy(gaussian_filter(single, sigma=sigma)) for single in summed],
        dim=0,
    )


def test(args):
    """Run AnomalyCLIP inference over a dataset and write one score file per image.

    For each test image this computes:
      * an image-level score: softmax (temperature 0.07) of the image embedding
        against the two text embeddings, keeping index 1;
      * a pixel-level map: per-layer similarity maps from layer
        ``args.feature_map_layer[0]`` onward, summed, Gaussian-smoothed with
        ``args.sigma`` and bicubically resized to 224x224.
    Both are written via ``dict_to_json`` to ``./output_scores/<path relative to
    args.data_path minus extension>_scores.json`` under the keys
    ``"img_level_score"`` and ``"pixel_level_score"``.

    Nothing is accumulated in memory across images and nothing is returned.
    Despite the name, this is the script's inference entry point, not a pytest
    test.

    Args:
        args: parsed command-line namespace. Fields read here: ``data_path``,
            ``dataset``, ``save_path`` (passed to ``get_logger``), ``n_ctx``,
            ``depth``, ``t_n_ctx``, ``checkpoint_path``, ``features_list``,
            ``feature_map_layer``, ``image_size``, ``sigma``.
            ``get_transform(args)`` may read further fields.

    Raises:
        ValueError: if an image path lies outside ``args.data_path``, or if
            ``feature_map_layer[0]`` selects no feature layer.
    """
    # Fixed inference hyper-parameters carried over from the original script.
    dpam_layer = 20               # passed to DAPM_replace and encode_image
    temperature = 0.07            # divides the logits before the softmax
    score_map_size = (224, 224)   # resolution of the pixel map written to JSON
    out_root = "./output_scores"

    dataset_dir = args.data_path
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Kept for its side effects (presumably log setup under save_path); the
    # returned logger is not used in this function.
    get_logger(args.save_path)

    AnomalyCLIP_parameters = {
        "Prompt_length": args.n_ctx,
        "learnabel_text_embedding_depth": args.depth,
        "learnabel_text_embedding_length": args.t_n_ctx,
    }
    model, _ = AnomalyCLIP_lib.load("ViT-L-14-336px.pt", device=device, design_details=AnomalyCLIP_parameters)
    model.eval()

    preprocess, target_transform = get_transform(args)
    test_data = Dataset(root=args.data_path, transform=preprocess, target_transform=target_transform, dataset_name=args.dataset)
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)

    prompt_learner = AnomalyCLIP_PromptLearner(model.to("cpu"), AnomalyCLIP_parameters)
    # The learner is still on CPU here; loading onto CPU also lets a checkpoint
    # saved from a GPU run be used on a CPU-only machine.
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
    prompt_learner.load_state_dict(checkpoint["prompt_learner"])
    prompt_learner.to(device)
    model.to(device)
    model.visual.DAPM_replace(DPAM_layer=dpam_layer)

    # Pure inference: no_grad also covers the text encoder, so its autograd
    # graph is not kept alive for the whole evaluation loop.
    with torch.no_grad():
        prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id=None)
        text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()
        # Split the prompt embeddings into two halves and pair them along dim 1;
        # index 1 of that pair is what is reported as the anomaly score below.
        text_features = torch.stack(torch.chunk(text_features, dim=0, chunks=2), dim=1)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for items in tqdm(test_dataloader):
            image = items['img'].to(device)
            image_features, patch_features = model.encode_image(image, args.features_list, DPAM_layer=dpam_layer)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Image-level score.
            text_probs = image_features @ text_features.permute(0, 2, 1)
            text_probs = (text_probs / temperature).softmax(-1)
            text_probs = text_probs[:, 0, 1]

            # Pixel-level score.
            anomaly_map = _pixel_anomaly_map(
                patch_features, text_features, args.feature_map_layer[0], args.image_size, args.sigma
            )
            interp_anomaly_map = F.interpolate(anomaly_map[None], size=score_map_size, mode='bicubic')

            # Assumes items['img_path'] lies under args.data_path (the Dataset
            # is built with root=args.data_path) -- TODO confirm in Dataset.
            new_path = _score_json_path(items['img_path'][0], dataset_dir, out_root)
            os.makedirs(os.path.dirname(new_path), exist_ok=True)
            dic = {
                "img_level_score": text_probs.item(),  # batch_size=1, so a single value
                "pixel_level_score": interp_anomaly_map.squeeze(0).cpu().detach().numpy(),
            }
            dict_to_json(dic, new_path)
|
|
|
|
|
|
|
def build_arg_parser():
    """Build the command-line parser for this AnomalyCLIP scoring script.

    Returns:
        argparse.ArgumentParser: parser whose namespace is consumed by
        ``setup_seed`` and ``test``.
    """
    parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True)

    # NOTE(review): the default data_path points at VisA while the default
    # dataset name is 'mvtec'; pass both explicitly to avoid a mismatch.
    parser.add_argument("--data_path", type=str, default="./data/visa", help="path to test dataset")
    parser.add_argument("--save_path", type=str, default='./results/', help='path to save results')
    parser.add_argument("--checkpoint_path", type=str, default='./checkpoint/',
                        help='path to the trained prompt-learner checkpoint file')

    parser.add_argument("--dataset", type=str, default='mvtec', help="dataset name passed to Dataset")
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24],
                        help="encoder layers whose patch features are extracted")
    parser.add_argument("--image_size", type=int, default=518, help="image size")
    parser.add_argument("--depth", type=int, default=9, help="depth of the learnable text embeddings")
    parser.add_argument("--n_ctx", type=int, default=12, help="length of the learnable prompt")
    parser.add_argument("--t_n_ctx", type=int, default=4, help="length of the learnable text embeddings")
    parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3],
                        help="feature layers from index feature_map_layer[0] onward build the pixel map "
                             "(only the first value is read)")
    parser.add_argument("--metrics", type=str, default='image-pixel-level', help="currently unused")
    parser.add_argument("--seed", type=int, default=111, help="random seed")
    # float so fractional sigmas are accepted; integer inputs behave as before.
    parser.add_argument("--sigma", type=float, default=4,
                        help="sigma of the Gaussian filter applied to anomaly maps")
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    print(args)
    setup_seed(args.seed)
    test(args)
|
|