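"""Gradio demo for counterfactual image generation and Grad-CAM-based classification of medical images."""
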
import gradio as gr


from src.inference import load_classifier, load_model, generate_images, convert_into_image, classify_image
from src.models import ResUNetGenerator
from src.explainer import GradCAM, preprocess_image

# Load the trained classifier and the two counterfactual generators
# (forward-slash paths work on both Windows and POSIX systems)
classifier_path = 'models/efficientnet_b1-epoch16-val_loss0.46_ft.ckpt'
g_NP_checkpoint = 'models/g_NP_best.ckpt'
g_PN_checkpoint = 'models/g_PN_best.ckpt'
g_NP = load_model(g_NP_checkpoint, ResUNetGenerator(gf=32, channels=1))
g_PN = load_model(g_PN_checkpoint, ResUNetGenerator(gf=32, channels=1))
classifier = load_classifier(classifier_path)
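# Use the last feature block of the classifier backbone as the Grad-CAM target layer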
target_layer = classifier.model.features[-1]
grad_cam = GradCAM(classifier, target_layer)


def counterfactual_generation(input_image):
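    """Generate counterfactual (translated) and reconstructed versions of the input image."""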

    translated_images, recon_images = generate_images(input_image, classifier, g_PN, g_NP)
    translated_images = convert_into_image(translated_images)
    recon_images = convert_into_image(recon_images)
    return translated_images, recon_images

def image_classification(input_image):
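    """Classify the input image and return the prediction together with a Grad-CAM heatmap."""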

    result, target_class = classify_image(input_image, classifier=classifier)
    input_tensor = preprocess_image(input_image)
    cam = grad_cam.generate_cam(input_tensor, target_class)
    cam_image = grad_cam.visualize_cam(cam, input_tensor)

    return result, cam_image

# Define the Gradio input and output components
inputs1 = gr.Image(type="pil", format="png")
inputs2 = gr.Image(type="pil", format="png")
outputs1 = [gr.Image(type="pil", label="Translated Images", format="png"),
            gr.Image(type="pil", label="Reconstructed Images", format="png")]

outputs2 = [gr.Label(label="Classification Result"), gr.Image(label="Grad-CAM", format="png")]

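# Two-tab UI: counterfactual generation and classification with Grad-CAM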
with gr.Blocks() as demo:
    with gr.Tab("Counterfactual Generation"):
        app1 = gr.Interface(fn=counterfactual_generation, inputs=inputs1, outputs=outputs1,
                            title="Counterfactual Image Generation", allow_flagging="never",
                            description="Generate counterfactual images to explain the classifier's decisions.")

    with gr.Tab("Classification"):
        app2 = gr.Interface(fn=image_classification, inputs=inputs2, outputs=outputs2,
                            title="Image Classification", allow_flagging="never",
                            description="Classify the input medical image and visualize Grad-CAM.")

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)