"""Zero-shot AnomalyCLIP inference script.

Loads a CLIP ViT-L/14-336 backbone plus a learned prompt checkpoint, runs
image- and pixel-level anomaly scoring over a test dataset, and writes one
JSON score file per image under ./output_scores (mirroring the dataset tree).
"""
import argparse
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from scipy.ndimage import gaussian_filter
from tabulate import tabulate
from tqdm import tqdm

import AnomalyCLIP_lib
from dataset import Dataset
from evaluation.utils.json_helpers import dict_to_json
from logger import get_logger
from loss import FocalLoss, BinaryDiceLoss
from metrics import image_level_metrics, pixel_level_metrics
from prompt_ensemble import AnomalyCLIP_PromptLearner
from utils import get_transform, normalize
from visualization import visualizer


def setup_seed(seed):
    """Seed torch (CPU + all GPUs), numpy and random; force deterministic cuDNN.

    Args:
        seed: integer seed applied to every RNG the pipeline touches.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Resize + center-crop used to bring each anomaly map onto the fixed
# 224x224 grid expected by the downstream JSON score consumer.
transform224 = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
])
to_tensor = transforms.ToTensor()


def test(args):
    """Run AnomalyCLIP inference over the test set and dump per-image scores.

    For every image this computes an image-level anomaly probability (text/image
    cosine similarity through a 0.07-temperature softmax) and a pixel-level
    anomaly map (summed over the selected patch-feature layers, then Gaussian
    smoothed), and writes both to a `<img>_scores.json` file under
    ./output_scores.

    Args:
        args: parsed CLI namespace (see the `__main__` block for the flags).
    """
    features_list = args.features_list
    dataset_dir = args.data_path
    logger = get_logger(args.save_path)  # side effect: ensures save_path / log file exist
    device = "cuda" if torch.cuda.is_available() else "cpu"

    AnomalyCLIP_parameters = {
        # NOTE: "learnabel" spelling is the key expected by AnomalyCLIP_lib; keep as-is.
        "Prompt_length": args.n_ctx,
        "learnabel_text_embedding_depth": args.depth,
        "learnabel_text_embedding_length": args.t_n_ctx,
    }

    model, _ = AnomalyCLIP_lib.load(
        "ViT-L-14-336px.pt", device=device, design_details=AnomalyCLIP_parameters
    )
    model.eval()

    preprocess, target_transform = get_transform(args)
    test_data = Dataset(
        root=args.data_path,
        transform=preprocess,
        target_transform=target_transform,
        dataset_name=args.dataset,
    )
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
    obj_list = test_data.obj_list

    # Per-class accumulators; `metrics` is initialized for parity with the
    # training-side evaluation code even though this script only exports scores.
    results = {}
    metrics = {}
    for obj in obj_list:
        results[obj] = {"gt_sp": [], "pr_sp": [], "imgs_masks": [], "anomaly_maps": []}
        metrics[obj] = {"pixel-auroc": 0, "pixel-aupro": 0, "image-auroc": 0, "image-ap": 0}

    # The prompt learner is built with the model on CPU, then both are moved
    # back to the compute device.
    prompt_learner = AnomalyCLIP_PromptLearner(model.to("cpu"), AnomalyCLIP_parameters)
    # map_location so a GPU-trained checkpoint still loads on a CPU-only host.
    checkpoint = torch.load(args.checkpoint_path, map_location=device)
    prompt_learner.load_state_dict(checkpoint["prompt_learner"])
    prompt_learner.to(device)
    model.to(device)
    model.visual.DAPM_replace(DPAM_layer=20)

    # Encode the learned normal/abnormal prompts once; text features are
    # reused for every image.  Shape after chunk+stack: (2 classes stacked on
    # dim 1), then L2-normalized along the embedding dim.
    prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id=None)
    text_features = model.encode_text_learn(
        prompts, tokenized_prompts, compound_prompts_text
    ).float()
    text_features = torch.stack(torch.chunk(text_features, dim=0, chunks=2), dim=1)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    for items in tqdm(test_dataloader):
        image = items["img"].to(device)
        cls_name = items["cls_name"]
        gt_mask = items["img_mask"]
        # Binarize the ground-truth mask in place.
        gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
        results[cls_name[0]]["imgs_masks"].append(gt_mask)  # px
        results[cls_name[0]]["gt_sp"].extend(items["anomaly"].detach().cpu())

        with torch.no_grad():
            image_features, patch_features = model.encode_image(
                image, features_list, DPAM_layer=20
            )
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Image-level score: probability of the "abnormal" text class.
            text_probs = image_features @ text_features.permute(0, 2, 1)
            text_probs = (text_probs / 0.07).softmax(-1)
            text_probs = text_probs[:, 0, 1]

            # Pixel-level map: per-layer similarity maps (abnormal minus
            # normal, rescaled to [0, 1]) summed over the selected layers.
            anomaly_map_list = []
            for layer_idx, patch_feature in enumerate(patch_features):
                if layer_idx >= args.feature_map_layer[0]:
                    patch_feature = patch_feature / patch_feature.norm(dim=-1, keepdim=True)
                    similarity, _ = AnomalyCLIP_lib.compute_similarity(
                        patch_feature, text_features[0]
                    )
                    similarity_map = AnomalyCLIP_lib.get_similarity_map(
                        similarity[:, 1:, :], args.image_size
                    )
                    anomaly_map = (similarity_map[..., 1] + 1 - similarity_map[..., 0]) / 2.0
                    anomaly_map_list.append(anomaly_map)

            anomaly_map = torch.stack(anomaly_map_list).sum(dim=0)

            results[cls_name[0]]["pr_sp"].extend(text_probs.detach().cpu())
            # Smooth each map on CPU with a Gaussian filter before export.
            anomaly_map = torch.stack(
                [
                    torch.from_numpy(gaussian_filter(i, sigma=args.sigma))
                    for i in anomaly_map.detach().cpu()
                ],
                dim=0,
            )
            results[cls_name[0]]["anomaly_maps"].append(anomaly_map)

            # Mirror the dataset layout under ./output_scores and save the
            # image-level score plus the 224x224 pixel-level map as JSON.
            new_path = (
                items["img_path"][0]
                .replace(dataset_dir, "./output_scores")
                .replace(".png", "_scores.json")
            )
            # PIL round-trip applies the 224 resize/crop to the anomaly map.
            pil_image = torchvision.transforms.functional.to_pil_image(anomaly_map.cpu())
            interp_anomaly_map = to_tensor(transform224(pil_image))

            dic = {
                "img_level_score": text_probs.item(),
                "pix_level_score": interp_anomaly_map.cpu().detach().numpy(),
            }
            os.makedirs(os.path.dirname(new_path), exist_ok=True)
            dict_to_json(dic, new_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True)
    # paths
    parser.add_argument("--data_path", type=str, default="./data/visa", help="path to test dataset")
    parser.add_argument("--save_path", type=str, default="./results/", help="path to save results")
    parser.add_argument("--checkpoint_path", type=str, default="./checkpoint/", help="path to checkpoint")
    # model
    parser.add_argument("--dataset", type=str, default="mvtec")
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
    parser.add_argument("--image_size", type=int, default=518, help="image size")
    parser.add_argument("--depth", type=int, default=9, help="depth of the learnable text embedding")
    parser.add_argument("--n_ctx", type=int, default=12, help="length of the learnable prompt context")
    parser.add_argument("--t_n_ctx", type=int, default=4, help="length of the learnable text embedding")
    parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="first patch-feature layer used for anomaly maps")
    parser.add_argument("--metrics", type=str, default="image-pixel-level")
    parser.add_argument("--seed", type=int, default=111, help="random seed")
    parser.add_argument("--sigma", type=int, default=4, help="Gaussian smoothing sigma for anomaly maps")
    args = parser.parse_args()
    print(args)
    setup_seed(args.seed)
    test(args)