Spaces:
Sleeping
Sleeping
import gradio as gr | |
from src.inference import load_classifier, load_model, generate_images, convert_into_image, classify_image | |
from src.models import ResUNetGenerator | |
from src.explainer import GradCAM, preprocess_image | |
# Loading Models
# Checkpoint paths use forward slashes so they resolve on both Windows and
# POSIX hosts. Backslash separators only work on Windows, and sequences such
# as "\e" / "\g" inside a normal string literal are invalid escapes that
# Python warns about.
classifier_path = 'models/efficientnet_b1-epoch16-val_loss0.46_ft.ckpt'
g_NP_checkpoint = 'models/g_NP_best.ckpt'
g_PN_checkpoint = 'models/g_PN_best.ckpt'

# Two generators with the same architecture, one per translation direction.
# NOTE(review): "NP"/"PN" presumably mean negative->positive and
# positive->negative; channels=1 presumably means single-channel (grayscale)
# images -- confirm against src.models / training code.
g_NP = load_model(g_NP_checkpoint, ResUNetGenerator(gf=32, channels=1))
g_PN = load_model(g_PN_checkpoint, ResUNetGenerator(gf=32, channels=1))
classifier = load_classifier(classifier_path)

# Grad-CAM is attached to the last module of the classifier's feature stack.
target_layer = classifier.model.features[-1]
grad_cam = GradCAM(classifier, target_layer)
def counterfactual_generation(input_image):
    """Generate counterfactual images for one uploaded image.

    Passes the image, the classifier and both generators to
    ``generate_images``, then converts each result with ``convert_into_image``
    before returning it.

    Args:
        input_image: The uploaded image (the tab's input is
            ``gr.Image(type="pil")``), or ``None`` when the user submitted
            without uploading anything.

    Returns:
        A ``(translated_images, recon_images)`` pair, one value for each of
        the tab's two output image components.

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio hands the function None for an empty image input; stop here with
    # a message the user can act on instead of failing inside inference.
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # NOTE(review): g_PN is passed before g_NP -- presumably this matches
    # generate_images' parameter order; verify in src.inference.
    translated_images, recon_images = generate_images(input_image, classifier, g_PN, g_NP)
    return convert_into_image(translated_images), convert_into_image(recon_images)
def image_classification(input_image):
    """Classify one uploaded image and build its Grad-CAM visualization.

    Args:
        input_image: The uploaded image (the tab's input is
            ``gr.Image(type="pil")``), or ``None`` when the user submitted
            without uploading anything.

    Returns:
        A ``(result, cam_image)`` pair: the output of ``classify_image`` for
        the label component, and the Grad-CAM visualization for the image
        component.

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio hands the function None for an empty image input; stop here with
    # a message the user can act on instead of failing inside inference.
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    result, target_class = classify_image(input_image, classifier=classifier)

    # Grad-CAM works on a preprocessed tensor and explains the class that
    # classify_image reported, so the heat-map matches the shown prediction.
    input_tensor = preprocess_image(input_image)
    cam = grad_cam.generate_cam(input_tensor, target_class)
    cam_image = grad_cam.visualize_cam(cam, input_tensor)
    return result, cam_image
# Defining the components
# Separate input components for the two tabs; both receive uploads as PIL
# images.
inputs1 = gr.Image(type="pil", format="png")
inputs2 = gr.Image(type="pil", format="png")

# Tab 1 outputs: the translated images and their reconstructions.
outputs1 = [
    gr.Image(type="pil", label="Translated Images", format="png"),
    gr.Image(type="pil", label="Reconstructed Images", format="png"),
]

# Tab 2 outputs: the predicted label and the Grad-CAM visualization.
outputs2 = [
    gr.Label(label="Classification Result"),
    gr.Image(label="Grad-CAM", format="png"),
]

# One Blocks app with a tab per task; each tab hosts its own Interface.
with gr.Blocks() as demo:
    with gr.Tab("Counterfactual Generation"):
        app1 = gr.Interface(
            fn=counterfactual_generation,
            inputs=inputs1,
            outputs=outputs1,
            title="Counterfactual Image Generation",
            description="Generate counterfactual images to explain the classifier's decisions.",
            allow_flagging="never",
        )
    with gr.Tab("Classification"):
        app2 = gr.Interface(
            fn=image_classification,
            inputs=inputs2,
            outputs=outputs2,
            title="Image Classification",
            description="Classify the input medical image and visualize Grad-CAM.",
            allow_flagging="never",
        )

# Launch the app (share=True requests a public share link from Gradio).
demo.launch(share=True)