anindya-hf-2002 commited on
Commit
84393df
1 Parent(s): 149b827

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -54
app.py CHANGED
@@ -1,55 +1,55 @@
1
- import gradio as gr
2
-
3
-
4
- from src.inference import load_classifier, load_model, generate_images, convert_into_image, classify_image
5
- from src.models import ResUNetGenerator
6
- from src.explainer import GradCAM, preprocess_image
7
-
8
- # Loading Models
9
- classifier_path = 'models\\efficientnet_b1-epoch16-val_loss0.46_ft.ckpt'
10
- g_NP_checkpoint = 'models\\g_NP_best.ckpt'
11
- g_PN_checkpoint = 'models\\g_PN_best.ckpt'
12
- g_NP = load_model(g_NP_checkpoint, ResUNetGenerator(gf=32, channels=1))
13
- g_PN = load_model(g_PN_checkpoint, ResUNetGenerator(gf=32, channels=1))
14
- classifier = load_classifier(classifier_path)
15
- target_layer = classifier.model.features[-1]
16
- grad_cam = GradCAM(classifier, target_layer)
17
-
18
-
19
- def counterfactual_generation(input_image):
20
-
21
- translated_images, recon_images = generate_images(input_image, classifier, g_PN, g_NP)
22
- translated_images = convert_into_image(translated_images)
23
- recon_images = convert_into_image(recon_images)
24
- return translated_images, recon_images
25
-
26
- def image_classification(input_image):
27
-
28
- result, target_class = classify_image(input_image, classifier=classifier)
29
- input_tensor = preprocess_image(input_image)
30
- cam = grad_cam.generate_cam(input_tensor, target_class)
31
- cam_image = grad_cam.visualize_cam(cam, input_tensor)
32
-
33
- return result, cam_image
34
-
35
- # Defining the components
36
- inputs1 = gr.Image(type="pil", format="png")
37
- inputs2 = gr.Image(type="pil", format="png")
38
- outputs1 = [gr.Image(type="pil", label="Translated Images", format="png"),
39
- gr.Image(type="pil", label="Reconstructed Images", format="png")]
40
-
41
- outputs2 = [gr.Label(label="Classification Result"), gr.Image(label="Grad-CAM", format="png")]
42
-
43
- with gr.Blocks() as demo:
44
- with gr.Tab("Counterfactual Generation"):
45
- app1 = gr.Interface(fn=counterfactual_generation, inputs=inputs1, outputs=outputs1,
46
- title="Counterfactual Image Generation", allow_flagging="never",
47
- description="Generate counterfactual images to explain the classifier's decisions.")
48
-
49
- with gr.Tab("Classification"):
50
- app2 = gr.Interface(fn=image_classification, inputs=inputs2, outputs=outputs2,
51
- title="Image Classification", allow_flagging="never",
52
- description="Classify the input medical image and visualize Grad-CAM.")
53
-
54
- # Launch the app
55
  demo.launch(share=True)
 
1
+ import gradio as gr
2
+
3
+
4
+ from src.inference import load_classifier, load_model, generate_images, convert_into_image, classify_image
5
+ from src.models import ResUNetGenerator
6
+ from src.explainer import GradCAM, preprocess_image
7
+
8
+ # Loading Models
9
+ classifier_path = 'models\efficientnet_b1-epoch16-val_loss0.46_ft.ckpt'
10
+ g_NP_checkpoint = 'models\g_NP_best.ckpt'
11
+ g_PN_checkpoint = 'models\g_PN_best.ckpt'
12
+ g_NP = load_model(g_NP_checkpoint, ResUNetGenerator(gf=32, channels=1))
13
+ g_PN = load_model(g_PN_checkpoint, ResUNetGenerator(gf=32, channels=1))
14
+ classifier = load_classifier(classifier_path)
15
+ target_layer = classifier.model.features[-1]
16
+ grad_cam = GradCAM(classifier, target_layer)
17
+
18
+
19
+ def counterfactual_generation(input_image):
20
+
21
+ translated_images, recon_images = generate_images(input_image, classifier, g_PN, g_NP)
22
+ translated_images = convert_into_image(translated_images)
23
+ recon_images = convert_into_image(recon_images)
24
+ return translated_images, recon_images
25
+
26
+ def image_classification(input_image):
27
+
28
+ result, target_class = classify_image(input_image, classifier=classifier)
29
+ input_tensor = preprocess_image(input_image)
30
+ cam = grad_cam.generate_cam(input_tensor, target_class)
31
+ cam_image = grad_cam.visualize_cam(cam, input_tensor)
32
+
33
+ return result, cam_image
34
+
35
+ # Defining the components
36
+ inputs1 = gr.Image(type="pil", format="png")
37
+ inputs2 = gr.Image(type="pil", format="png")
38
+ outputs1 = [gr.Image(type="pil", label="Translated Images", format="png"),
39
+ gr.Image(type="pil", label="Reconstructed Images", format="png")]
40
+
41
+ outputs2 = [gr.Label(label="Classification Result"), gr.Image(label="Grad-CAM", format="png")]
42
+
43
+ with gr.Blocks() as demo:
44
+ with gr.Tab("Counterfactual Generation"):
45
+ app1 = gr.Interface(fn=counterfactual_generation, inputs=inputs1, outputs=outputs1,
46
+ title="Counterfactual Image Generation", allow_flagging="never",
47
+ description="Generate counterfactual images to explain the classifier's decisions.")
48
+
49
+ with gr.Tab("Classification"):
50
+ app2 = gr.Interface(fn=image_classification, inputs=inputs2, outputs=outputs2,
51
+ title="Image Classification", allow_flagging="never",
52
+ description="Classify the input medical image and visualize Grad-CAM.")
53
+
54
+ # Launch the app
55
  demo.launch(share=True)