File size: 4,216 Bytes
58c979f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
import torch

from Utilities.model import Net
from Utilities import config
from Utilities.utils import generate_confidences, generate_gradcam, generate_missclassified_imgs

# Input components for the Interface. Their order is positional: it must match
# the parameter order of generate_gradio_output and the column order of every
# row in `examples`.
inputs = [
    # Gradio 3.x `shape` is (width, height) to crop/resize the upload to before
    # it is passed to the function.
    gr.Image(shape=(32, 32), label="Input Image"),
    gr.Slider(minimum=1, maximum=10, step=1, label="Number of Top Predictions to Display"),
    # Gradio 3+ components take their initial state via `value`; `default` is
    # the deprecated Gradio 2.x spelling of the same option.
    gr.Checkbox(value=False, label="Show GradCAM"),
    gr.Slider(minimum=-2, maximum=-1, step=1, value=-1, label="GradCAM Layer (from the end)"),
    gr.Slider(minimum=0, maximum=1, value=0.5, label="GradCAM Heatmap Opacity"),
    gr.Checkbox(label="Show Incorrect Predictions"),
    gr.Slider(minimum=5, maximum=50, step=5, label="Number of Incorrect Predictions to Display"),
]

# Build the network. load_state_dict is strict by default, so the architecture
# created here must match the checkpoint stored at config.MODEL_PATH.
model = Net(
    num_classes=config.NUM_CLASSES,
    dropout_percentage=config.DROPOUT_PERCENTAGE,
    norm=config.LAYER_NORM,
    criterion=config.CRITERION,
    learning_rate=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
)

# Single device used for every checkpoint load below.
_device = torch.device(config.ACCELERATOR)

# NOTE(review): torch.load unpickles its input; these paths must only ever
# point at trusted, locally produced artifacts.
model.load_state_dict(torch.load(config.MODEL_PATH, map_location=_device))

# Stored predictions attached to the model; presumably read by
# generate_missclassified_imgs (it only receives `model`) - confirm in utils.
model.pred_store = torch.load(config.PRED_STORE_PATH, map_location=_device)

# The app only performs inference, so put dropout / normalization layers into
# evaluation behaviour. eval() is idempotent, so this is harmless even if the
# Utilities helpers also call it; it does not disable autograd, which GradCAM
# still needs.
model.eval()

def generate_gradio_output(
        input_img,
        num_top_preds,
        show_gradcam,
        gradcam_layer,
        gradcam_opacity,
        show_misclassified,
        num_misclassified,
):
    """Run the model on one image and produce the value for every output.

    Parameters arrive positionally from the components in ``inputs``.

    Args:
        input_img: Image from the "Input Image" component; Gradio passes None
            when nothing was uploaded.
        num_top_preds: How many top classes to report in the confidences.
        show_gradcam: Whether a GradCAM overlay should be produced.
        gradcam_layer: Layer index counted from the end (the UI slider offers
            -2 or -1).
        gradcam_opacity: Heatmap opacity from the 0..1 slider.
        show_misclassified: Whether to plot misclassified images.
        num_misclassified: How many misclassified images to plot.

    Returns:
        Tuple ``(confidences, visualization, plot)`` in the order of
        ``outputs``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_img is None:
        # Show a readable message in the UI instead of a traceback from the
        # preprocessing code.
        raise gr.Error("Please upload an image before submitting.")

    # The preprocessed tensor is reused for GradCAM; the original image is
    # passed too (presumably for overlaying the heatmap).
    processed_img, confidences = generate_confidences(
        model=model,
        input_img=input_img,
        num_top_preds=num_top_preds,
    )

    visualization = generate_gradcam(
        model=model,
        org_img=input_img,
        input_img=processed_img,
        show_gradcam=show_gradcam,
        gradcam_layer=gradcam_layer,
        gradcam_opacity=gradcam_opacity,
    )

    plot = generate_missclassified_imgs(
        model=model,
        show_misclassified=show_misclassified,
        num_misclassified=num_misclassified,
    )

    return confidences, visualization, plot

# Output components, listed in the same order in which generate_gradio_output
# returns its values.
_confidence_label = gr.Label(visible=True, scale=0.5, label="Classification Confidences")
# NOTE(review): Component.style() is deprecated in late Gradio 3.x and removed
# in Gradio 4; move width/height into the constructor when upgrading.
_gradcam_image = gr.Image(shape=(32, 32), label="GradCAM Visualization").style(
    width=256, height=256, visible=True
)
_misclassified_plot = gr.Plot(visible=True, label="Misclassified Images")

outputs = [_confidence_label, _gradcam_image, _misclassified_plot]

# Example presets. Each row is: image file name, then the remaining values in
# the positional order of `inputs` (top-k, show GradCAM, GradCAM layer,
# opacity, show misclassified, number of misclassified).
_EXAMPLE_ROWS = (
    ("cat.jpeg", 3, True, -2, 0.68, True, 40),
    ("horse.jpg", 3, True, -2, 0.59, True, 25),
    ("bird.webp", 10, True, -1, 0.55, True, 20),
    ("dog1.jpg", 10, True, -1, 0.33, True, 45),
    ("frog1.webp", 5, True, -1, 0.64, True, 40),
    ("deer.webp", 1, True, -2, 0.45, True, 20),
    ("airplane.png", 3, True, -2, 0.43, True, 40),
    ("shipp.jpg", 7, True, -1, 0.6, True, 30),
    ("car.jpg", 2, True, -1, 0.68, True, 30),
    ("truck1.jpg", 5, True, -2, 0.51, True, 35),
)

# Prefix each file name with the example directory; Gradio expects a list of
# per-example lists.
examples = [
    [config.EXAMPLE_IMG_PATH + file_name, *settings]
    for file_name, *settings in _EXAMPLE_ROWS
]

# Page heading and Markdown description shown at the top of the Interface.
title = "Image Classification (CIFAR10 - 10 Classes) with GradCAM"
description = """A simple Gradio interface to visualize the output of a CNN trained on the CIFAR10 dataset, with GradCAM and misclassified images.
The architecture is inspired by David Page's (myrtle.ai) DAWNBench-winning model architecture.
Please upload an image and select the number of top predictions to display - you will see the top predictions and their corresponding confidence scores.
You can also choose whether to show GradCAM for the image (it uses the gradients of the classification score with respect to a convolutional feature map to identify the parts of the input image that most impact the classification score).
You need to select the model layer the gradients are taken from - this affects how much of the image is used to compute the GradCAM.
You can also choose whether to show misclassified images - these are the images that the model misclassified.
Some examples are provided in the Examples section.
"""

# Wire the prediction function to its input/output components and start the
# web server. Binding the Interface to `demo` keeps the app object reachable
# by name (Gradio's auto-reload mode looks for `demo` by default).
demo = gr.Interface(
    fn=generate_gradio_output,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
)
demo.launch()