File size: 1,722 Bytes
be419be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
import clip
import torch

import utils

clip_model = "RN50x4"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(clip_model, device=device, jit=False)
model.eval()


def grad_cam_fn(text, img, saliency_layer):
    resize = model.visual.input_resolution
    img = img.resize((resize, resize))

    text_input = clip.tokenize([text]).to(device)
    text_feature = model.encode_text(text_input).float()
    image_input = preprocess(img).unsqueeze(0).to(device)

    attn_map = utils.gradCAM(
        model.visual,
        image_input,
        text_feature,
        getattr(model.visual, saliency_layer)
    )
    attn_map = attn_map.squeeze().detach().cpu().numpy()
    attn_map = utils.getAttMap(img, attn_map)

    return attn_map


if __name__ == '__main__':
    interface = gr.Interface(
        fn=grad_cam_fn,
        inputs=[
            gr.inputs.Textbox(
                label="Target Text",
                lines=1),
            gr.inputs.Image(
                label='Input Image',
                image_mode="RGB",
                type='pil',
                shape=(512, 512)),
            gr.inputs.Dropdown(
                ["layer4", "layer3", "layer2", "layer1"],
                default="layer4",
                label="Saliency Layer")
        ],
        outputs=gr.outputs.Image(
            type="pil", 
            label="Attention Map"),
        examples=[
            ['a cat lying on the floor', 'assets/cat_dog.jpg', 'layer4'],
            ['a dog sitting', 'assets/cat_dog.jpg', 'layer4']
        ],
        description="OpenAI CLIP Grad CAM")
    interface.launch(
        server_name='0.0.0.0',
        server_port=7861,
        share=False)