import PIL.Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from captum.attr import GradientShap, Occlusion, LayerGradCam, LayerAttribution
from captum.attr import visualization as viz
from matplotlib.colors import LinearSegmentedColormap
from torchvision import transforms
from torchvision.models import resnet50
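# Gradio app that classifies an input image as benign / malignant / normal with a
# fine-tuned ResNet-50 and explains each prediction with three Captum attribution
# methods: GradientShap, Occlusion, and Layer Grad-CAM.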
class Explainer:
    def __init__(self, model, img, class_names):
        self.model = model
        self.default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                              [(0, '#ffffff'),
                                                               (0.25, '#000000'),
                                                               (1, '#000000')], N=256)
        self.class_names = class_names
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
        transform_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        self.transformed_img = transform(img)
        self.input = transform_normalize(self.transformed_img)
        self.input = self.input.unsqueeze(0)
        with torch.no_grad():
            self.output = self.model(self.input)
        self.output = F.softmax(self.output, dim=1)
        self.confidences = {class_names[i]: float(self.output[0, i]) for i in range(len(class_names))}
        self.pred_score, self.pred_label_idx = torch.topk(self.output, 1)
        # Convert the top-1 index to a plain int so it can index the label list and be
        # passed directly as the attribution target.
        self.pred_label_idx = self.pred_label_idx.item()
        self.pred_label = self.class_names[self.pred_label_idx]
        self.fig_title = 'Predicted: ' + self.pred_label + ' (' + str(round(self.pred_score.squeeze().item(), 2)) + ')'
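    # Captum's viz helpers return Matplotlib figures; rasterise them into PIL images so the
    # gr.Image outputs can display them. Note that canvas.tostring_rgb() is deprecated in
    # recent Matplotlib releases; on newer versions buffer_rgba() may be needed instead.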
    def convert_fig_to_pil(self, fig):
        fig.canvas.draw()
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        return PIL.Image.fromarray(data)
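    # GradientShap approximates SHAP values by averaging the gradients of the predicted
    # class score at random points between the baseline distribution and the input,
    # adding Gaussian noise with standard deviation `stdevs` to each of the n_samples draws.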
    def shap(self, n_samples, stdevs):
        gradient_shap = GradientShap(self.model)
        rand_img_dist = torch.cat([self.input * 0, self.input * 1])
        attributions_gs = gradient_shap.attribute(self.input,
                                                  n_samples=int(n_samples),
                                                  stdevs=stdevs,
                                                  baselines=rand_img_dist,
                                                  target=self.pred_label_idx)
        fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                                   np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                                   ["original_image", "heat_map"],
                                                   ["all", "absolute_value"],
                                                   cmap=self.default_cmap,
                                                   show_colorbar=True)
        fig.suptitle("SHAP | " + self.fig_title, fontsize=12)
        return self.convert_fig_to_pil(fig)
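    # Occlusion slides a (3, window, window) baseline patch across the image with the given
    # stride and measures how the predicted class score changes; regions whose occlusion
    # lowers the score most receive the largest positive attribution.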
    def occlusion(self, stride, sliding_window):
        occlusion = Occlusion(self.model)
        attributions_occ = occlusion.attribute(self.input,
                                               target=self.pred_label_idx,
                                               strides=(3, int(stride), int(stride)),
                                               sliding_window_shapes=(3, int(sliding_window), int(sliding_window)),
                                               baselines=0)
        fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                                   np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                                   ["original_image", "heat_map", "heat_map", "masked_image"],
                                                   ["all", "positive", "negative", "positive"],
                                                   show_colorbar=True,
                                                   titles=["Original", "Positive Attribution", "Negative Attribution", "Masked"],
                                                   fig_size=(18, 6))
        fig.suptitle("Occlusion | " + self.fig_title, fontsize=12)
        return self.convert_fig_to_pil(fig)
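    # Layer Grad-CAM weights the activations of layer3[1].conv2 by the gradients of the
    # predicted class score, producing a coarse relevance map that is upsampled to the
    # input resolution with LayerAttribution.interpolate.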
    def gradcam(self):
        layer_gradcam = LayerGradCam(self.model, self.model.layer3[1].conv2)
        attributions_lgc = layer_gradcam.attribute(self.input, target=self.pred_label_idx)
        # _ = viz.visualize_image_attr(attributions_lgc[0].cpu().permute(1, 2, 0).detach().numpy(),
        #                              sign="all",
        #                              title="Layer 3 Block 1 Conv 2")
        upsamp_attr_lgc = LayerAttribution.interpolate(attributions_lgc, self.input.shape[2:])
        fig, _ = viz.visualize_image_attr_multiple(upsamp_attr_lgc[0].cpu().permute(1, 2, 0).detach().numpy(),
                                                   self.transformed_img.permute(1, 2, 0).numpy(),
                                                   ["original_image", "blended_heat_map", "masked_image"],
                                                   ["all", "positive", "positive"],
                                                   show_colorbar=True,
                                                   titles=["Original", "Positive Attribution", "Masked"],
                                                   fig_size=(18, 6))
        fig.suptitle("GradCAM layer3[1].conv2 | " + self.fig_title, fontsize=12)
        return self.convert_fig_to_pil(fig)
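# The checkpoint "best_model" is expected to sit next to this script and to hold the
# state_dict of a ResNet-50 whose final fully connected layer was replaced with a
# 3-output head matching the labels below.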
def create_model_from_checkpoint():
    # Loads a model from a checkpoint
    model = resnet50()
    model.fc = nn.Linear(model.fc.in_features, 3)
    model.load_state_dict(torch.load("best_model", map_location=torch.device('cpu')))
    model.eval()
    return model

model = create_model_from_checkpoint()
labels = ["benign", "malignant", "normal"]
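# Builds a fresh Explainer for every request and returns the outputs in the order the
# Gradio interface below expects: confidences, SHAP figure, Occlusion figure, Grad-CAM figure.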
def predict(img, shap_samples, shap_stdevs, occlusion_stride, occlusion_window):
    explainer = Explainer(model, img, labels)
    return [explainer.confidences,
            explainer.shap(shap_samples, shap_stdevs),
            explainer.occlusion(occlusion_stride, occlusion_window),
            explainer.gradcam()]
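# Quick local smoke test, left commented out so it never runs on the Space (assumes one of
# the example images, e.g. "benign (52).png", sits next to this script):
#   img = PIL.Image.open("benign (52).png").convert("RGB")
#   confidences, shap_fig, occlusion_fig, gradcam_fig = predict(img, 50, 0.0001, 8, 15)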
# Note: gr.Slider takes `value=` (not `default=`) for its initial value in current Gradio releases.
ui = gr.Interface(fn=predict,
                  inputs=[
                      gr.Image(type="pil"),
                      gr.Slider(minimum=10, maximum=100, value=50, label="SHAP Samples", step=1),
                      gr.Slider(minimum=0.0001, maximum=0.01, value=0.0001, label="SHAP Stdevs", step=0.0001),
                      gr.Slider(minimum=4, maximum=80, value=8, label="Occlusion Stride", step=1),
                      gr.Slider(minimum=4, maximum=80, value=15, label="Occlusion Window", step=1)
                  ],
                  outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil"), gr.Image(type="pil"), gr.Image(type="pil")],
                  examples=[["benign (52).png", 50, 0.0001, 8, 15],
                            ["benign (243).png", 50, 0.0001, 8, 15],
                            ["malignant (127).png", 50, 0.0001, 8, 15],
                            ["malignant (201).png", 50, 0.0001, 8, 15],
                            ["normal (81).png", 50, 0.0001, 8, 15],
                            ["normal (101).png", 50, 0.0001, 8, 15]])
ui.launch(share=True)