File size: 2,275 Bytes
020dd6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    :   demo_metauas.py
@Time    :   2025/03/26 23:49:14
@Author  :   Bin-Bin Gao
@Email   :   [email protected]
@Homepage:   https://csgaobb.github.io/
@Version :   1.0
@Desc    :   MetaUAS Demo
'''


import os
import cv2
import torch
import json
import shutil
import kornia as K
import numpy as np

from easydict import EasyDict
from argparse import ArgumentParser
from metauas import MetaUAS, set_random_seed, normalize, apply_ad_scoremap, read_image_as_tensor, safely_load_state_dict

if __name__ == "__main__":
    # Demo entry point: score one query image against one prompt image with
    # MetaUAS and write the score-map visualization to ./anomaly_map.jpg.

    # Seed the RNGs before the model is constructed.
    random_seed = 1
    set_random_seed(random_seed)

    # Checkpoint and the input resolution paired with it.
    ckt_path = 'weights/metauas-256.ckpt'
    img_size = 256
    # Alternative checkpoint / resolution pair:
    # ckt_path = "weights/metauas-512.ckpt"
    # img_size = 512

    # Model configuration, passed positionally to MetaUAS in this order.
    encoder = 'efficientnet-b4'
    decoder = 'unet'
    encoder_depth = 5
    decoder_depth = 5
    num_crossfa_layers = 3
    alignment_type = 'sa'
    fusion_policy = 'cat'

    model = MetaUAS(
        encoder,
        decoder,
        encoder_depth,
        decoder_depth,
        num_crossfa_layers,
        alignment_type,
        fusion_policy,
    )
    model = safely_load_state_dict(model, ckt_path)

    # Use the GPU when one is available; otherwise fall back to the CPU
    # instead of crashing on `.cuda()`.
    # NOTE(review): safely_load_state_dict / MetaUAS may themselves require
    # CUDA -- confirm before relying on the CPU path.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Load the prompt and query images.
    path_root = "./images/"
    path_to_prompt = os.path.join(path_root, "036.png")
    path_to_query = os.path.join(path_root, "024.png")

    query = read_image_as_tensor(path_to_query)
    prompt = read_image_as_tensor(path_to_prompt)

    # Both images are resized together, and only when the query's second
    # dimension differs from img_size; the prompt is never inspected.
    # NOTE(review): whether shape[1] is the height or the channel count
    # depends on the layout returned by read_image_as_tensor -- confirm.
    # NOTE(review): `return_transform` may not be accepted by every kornia
    # release -- verify against the pinned version.
    if query.shape[1] != img_size:
        resize_trans = K.augmentation.Resize([img_size, img_size], return_transform=True)
        query = resize_trans(query)[0]
        prompt = resize_trans(prompt)[0]

    test_data = {
        "query_image": query.to(device),
        "prompt_image": prompt.to(device),
    }

    # Forward pass. Gradients are never used in this demo, so autograd is
    # disabled to avoid building (and holding memory for) a graph.
    with torch.no_grad():
        predicted_masks = model(test_data)

    # Visualization: take the first query image, scale it by 255 and move
    # the leading axis last (assumes a (C, H, W) image -- TODO confirm).
    query_img = test_data["query_image"][0] * 255
    query_img = query_img.permute(1, 2, 0)

    # The model output is inverted (1 - mask) before display; presumably the
    # network predicts a "normal" score -- TODO confirm.
    # Assumes the squeezed mask is 2-D (H, W): a trailing axis is added and
    # repeated 3 times to get an (H, W, 3) array.
    inverted = 1 - predicted_masks.squeeze().detach()
    pred = inverted[:, :, None].cpu().numpy().repeat(3, 2)

    # Normalization here is only for easier visual analysis.
    scoremap_self = apply_ad_scoremap(query_img.cpu(), normalize(pred))

    # cv2.imwrite reports failure through its return value rather than an
    # exception, so check it explicitly.
    output_path = './anomaly_map.jpg'
    if not cv2.imwrite(output_path, scoremap_self):
        raise IOError(f"Failed to write anomaly map to {output_path}")