import gradio as gr
from fastai.vision.all import *
from efficientnet_pytorch import EfficientNet 

import torch, torchvision
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image

title = "COVID_19 Infection Detectation App!"
head = (  
    "<body>"
        "<center>"
            "<img src='/file=Gradcam.png' width=200>"
            "<h2>"
                "This Space demonstrates a model based on efficientnetB7 base model. The Model was trained to classify chest xray image. To test it, "
            "</h2>"
            "<h3>"
                "Use the Example Images provided below the up or Upload your own xray images with the App." 
            "</h3>"
            "<h3>"
                "!!!PLEASE NOTE MODEL WAS TRAINED and VALIDATED USING PNG FILES!!!" 
            "</h3>"
    "</center>"
            "<p>"
                "<b>""<a href='https://www.kaggle.com/datasets/anasmohammedtahir/covidqu'>The model is trained using COVID-QU-Ex dataset</a>""</b>"
                "  that the researchers from Qatar University compiled,that consists of 33,920 chest X-ray (CXR) images including:"
            "</p>"
            "<ul>"
              "<li>" 
                    "11,956 COVID-19"
              "</li>"    
              "<li>" 
                    "11,263 Non-COVID infections (Viral or Bacterial Pneumonia)"
              "</li>"     
              "<li>" 
                    "10,701 Normal"
              "</li>"     
            "</ul>"
            "<p>"
                "Thanks to Kaggle & KaggleX, this is the largest ever created lung mask dataset, that I am aware of publicly available as of October 2023."
            "</p>"
    "</body>"
)
description = head

examples = [
    ['covid/covid_1038.png'], ['covid/covid_1034.png'], 
    ['covid/cd.png'], ['covid/covid_1021.png'], 
    ['covid/covid_1027.png'], ['covid/covid_1042.png'], 
    ['covid/covid_1031.png']
]

#learn = load_learner('model/predictcovidfastaifinal18102023.pkl')
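# Load the exported fastai Learner; the .pkl bundles the trained EfficientNet-B7
# model together with its preprocessing transforms and class vocabulary.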
learn = load_learner('model/final_20102023_eb7_model.pkl')

categories = learn.dls.vocab  # class labels, in the order the model returns probabilities

def predict_image(get_image):
    pred, idx, probs = learn.predict(get_image)
    return dict(zip(categories, map(float, probs)))
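# Illustrative usage (example path taken from the `examples` list below); the
# returned dict maps each class name in learn.dls.vocab to its probability:
#   predict_image('covid/covid_1038.png')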

def interpretation_function(image_path, model, target_layer, target_category=None):
    # Load and preprocess the image (force RGB so grayscale X-rays get 3 channels)
    image = Image.open(image_path).convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    input_image = preprocess(image).unsqueeze(0)  # Add a batch dimension
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_image = input_image.to(device)
    model = model.to(device).eval()

    # Create an instance of GradCAM (assumes pytorch-grad-cam >= 1.4, which takes
    # a list of target layers)
    cam = GradCAM(model=model, target_layers=[target_layer])

    # Compute the CAM; with targets=None the highest-scoring class is explained
    targets = None if target_category is None else [ClassifierOutputTarget(target_category)]
    grayscale_cam = cam(input_tensor=input_image, targets=targets)[0, :]

    # Overlay the CAM on the original image; show_cam_on_image expects an
    # HxWx3 float image in [0, 1] and returns a uint8 numpy array
    rgb_image = input_image[0].permute(1, 2, 0).cpu().numpy()
    visualization = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)
    return visualization

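# Standalone usage sketch (assumes the exported EfficientNet-B7 exposes
# `_conv_head` as its final convolution; pick another layer if the Learner
# wraps the network differently):
#   heatmap = interpretation_function('covid/covid_1038.png', learn.model, learn.model._conv_head)
#   Image.fromarray(heatmap)  # uint8 HxWx3 Grad-CAM overlay, viewable as a PIL image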

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Image", shape=(224, 224))
            with gr.Row():
                interpret = gr.Button("Interpret")
        with gr.Column():
            label = gr.Label(label="Predicted Class")
        with gr.Column():
            cam_output = gr.Image(label="Grad-CAM")
    # NOTE: the Grad-CAM target layer below is an assumption: with an
    # efficientnet_pytorch backbone the final convolution is `_conv_head`;
    # use the last conv layer of learn.model instead if it is wrapped differently.
    interpret.click(
        lambda img_path: interpretation_function(img_path, learn.model, learn.model._conv_head),
        inputs=input_img, outputs=cam_output)

#interpretation="default"
enable_queue=True

gr.Interface(fn=predict_image, inputs=gr.Image(shape=(224, 224)),
             outputs=gr.Label(num_top_classes=3), title=title, description=description,
             examples=examples, interpretation=interpretation,
             enable_queue=enable_queue).launch(share=False)