File size: 4,092 Bytes
a25563f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import warnings
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import json
import os
import torch
from scipy.ndimage import gaussian_filter
import cv2
from method import AdaCLIP_Trainer
import numpy as np

############ Init Model
# Checkpoints that process_image can switch between per request.
ckt_path1 = 'weights/pretrained_mvtec_colondb.pth'
ckt_path2 = "weights/pretrained_visa_clinicdb.pth"
ckt_path3 = 'weights/pretrained_all.pth'

# Configurations
image_size = 518
# Use the GPU when available; hard-code 'cpu' here to force CPU inference.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Backbone identifier. Kept under its own name (instead of reusing `model`,
# which is bound to the trainer below) so the two never get confused.
backbone_name = "ViT-L-14-336"
prompting_depth = 4
prompting_length = 5
prompting_type = 'SD'
prompting_branch = 'VL'
use_hsf = True
k_clusters = 20

config_path = os.path.join('./model_configs', f'{backbone_name}.json')

# Prepare model: read the backbone's architecture description.
# JSON is UTF-8 by specification, so do not rely on the platform's locale codec.
with open(config_path, 'r', encoding='utf-8') as f:
    model_configs = json.load(f)

# Set up the feature hierarchy: tap the vision tower at four evenly spaced
# depths (1/4, 2/4, 3/4 and 4/4 of its layer count).
n_layers = model_configs['vision_cfg']['layers']
substage = n_layers // 4
features_list = [substage * stage for stage in range(1, 5)]

# Inference-only trainer (learning_rate=0.); weights are loaded per request
# in process_image.
model = AdaCLIP_Trainer(
    backbone=backbone_name,
    feat_list=features_list,
    input_dim=model_configs['vision_cfg']['width'],
    output_dim=model_configs['embed_dim'],
    learning_rate=0.,
    device=device,
    image_size=image_size,
    prompting_depth=prompting_depth,
    prompting_length=prompting_length,
    prompting_branch=prompting_branch,
    prompting_type=prompting_type,
    use_hsf=use_hsf,
    k_clusters=k_clusters
).to(device)


def _select_checkpoint(options):
    """Return the checkpoint path for the selected pre-training option.

    ``options`` is the value of the "Pre-trained Datasets" input. It may be
    ``None`` or empty when nothing is selected. Labels are tested in a fixed
    order with ``in``, so a string and a list of labels both work. Anything
    that matches no label falls back to the 'All' checkpoint.
    """
    choices = (
        ('MVTec AD+Colondb', ckt_path1),
        ('VisA+Clinicdb', ckt_path2),
        ('All', ckt_path3),
    )
    # Guard first: `label in None` would raise TypeError and skip the fallback.
    if options:
        for label, path in choices:
            if label in options:
                return path
    # Default to 'All' if no valid option is provided
    print('Invalid option. Defaulting to All.')
    return ckt_path3


def process_image(image, text, options):
    """Run zero-shot anomaly detection on one image and visualise the result.

    Args:
        image: PIL image from the upload widget (``None`` if nothing was uploaded).
        text: class name passed to the model as the text prompt.
        options: selected pre-training dataset label; chooses the checkpoint.

    Returns:
        A ``(PIL.Image, str)`` pair. The image is the JET-coloured anomaly map
        blended 50/50 over the input resized to ``image_size``. The string is
        the anomaly score formatted to three decimals.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading; show a
        # readable message instead of failing on ``.convert``.
        raise gr.Error('Please upload an image.')

    # Load the weights for the selected pre-training datasets.
    model.load(_select_checkpoint(options))

    # Ensure image is in RGB mode
    image = image.convert('RGB')

    # OpenCV expects BGR; resize to the model resolution so the picture can
    # be blended with the anomaly map below.
    np_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    np_image = cv2.resize(np_image, (image_size, image_size))

    # Preprocess the image and run the model
    img_input = model.preprocess(image).unsqueeze(0).to(model.device)

    with torch.no_grad():
        anomaly_map, anomaly_score = model.clip_model(img_input, [text], aggregation=True)

    # Take the first (only) item of the batch.
    anomaly_map = anomaly_map[0, :, :].cpu().numpy()
    score = anomaly_score[0].item()

    anomaly_map = gaussian_filter(anomaly_map, sigma=4)
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around instead of saturating at 0 / 255.
    anomaly_map = (np.clip(anomaly_map, 0.0, 1.0) * 255).astype(np.uint8)

    # Apply color map and blend with original image
    heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
    vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)

    # Convert OpenCV image back to PIL image for Gradio
    vis_map_pil = Image.fromarray(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB))

    return vis_map_pil, f'{score:.3f}'

# Define examples: [image path, class name, pre-trained dataset option].
examples = [
    ["asset/img.png", "candle", "MVTec AD+Colondb"],
    ["asset/img2.png", "bottle", "VisA+Clinicdb"],
    ["asset/img3.png", "button", "All"],
]

# Gradio interface layout. The inputs map positionally onto
# process_image(image, text, options).
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Class Name"),
        # Pre-select "All" so `options` is never None on a fresh page. It is
        # the same checkpoint process_image uses as its fallback.
        gr.Radio(
            ["MVTec AD+Colondb",
             "VisA+Clinicdb",
             "All"],
            value="All",
            label="Pre-trained Datasets",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Textbox(label="Anomaly Score"),
    ],
    examples=examples,
    title="AdaCLIP -- Zero-shot Anomaly Detection",
    description="Upload an image, enter class name, and select pre-trained datasets to do zero-shot anomaly detection"
)

# Launch the demo. To expose it on the network instead, use e.g.
# demo.launch(server_name="0.0.0.0", server_port=10002).
demo.launch()