# Scraped from PedroMartelleto's Hugging Face page:
# commit 3a35229 ("Deploying to HF"), file size 7.28 kB.
import PIL
from captum.attr import GradientShap, Occlusion, LayerGradCam, LayerAttribution, IntegratedGradients
from captum.attr import visualization as viz
import torch
from torchvision import transforms
from matplotlib.colors import LinearSegmentedColormap
import torch.nn.functional as F
import gradio as gr
from torchvision.models import resnet50
import torch.nn as nn
import torch
import numpy as np
class Explainer:
    """Classify one image and render Captum attribution maps for the prediction.

    ``__init__`` preprocesses the image and caches the model's softmax output
    and top-1 prediction. ``shap``, ``occlusion`` and ``gradcam`` each explain
    that top-1 class and return the rendered Matplotlib figure as a PIL image.
    """

    def __init__(self, model, img, class_names):
        """Preprocess ``img`` and run ``model`` on it once.

        Args:
            model: classifier taking a normalized (1, 3, 224, 224) batch and
                returning one logit per entry of ``class_names``.
            img: PIL image to explain; non-RGB modes are converted to RGB.
            class_names: label for each model output index.
        """
        self.model = model
        # White at zero, ramping to black by 0.25 and staying black above it.
        self.default_cmap = LinearSegmentedColormap.from_list(
            "custom blue",
            [(0, "#ffffff"), (0.25, "#000000"), (1, "#000000")],
            N=256,
        )
        self.class_names = class_names

        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        # ImageNet per-channel mean / std.
        transform_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        # Grayscale ("L") or RGBA images would give 1 or 4 channels and fail
        # the 3-channel normalization, so coerce to RGB first.
        if getattr(img, "mode", "RGB") != "RGB":
            img = img.convert("RGB")

        # Un-normalized (3, 224, 224) image, kept for display.
        self.transformed_img = transform(img)
        # Normalized batch of one: (1, 3, 224, 224).
        self.input = transform_normalize(self.transformed_img).unsqueeze(0)

        with torch.no_grad():
            self.output = F.softmax(self.model(self.input), dim=1)
        # One probability per class name.
        self.confidences = {
            name: float(prob)
            for name, prob in zip(self.class_names, self.output[0])
        }
        # Both are (1, 1) tensors; the index tensor is used as the Captum target.
        self.pred_score, self.pred_label_idx = torch.topk(self.output, 1)
        self.pred_label = self.class_names[self.pred_label_idx.item()]
        score = round(self.pred_score.item(), 2)
        self.fig_title = f"Predicted: {self.pred_label} ({score})"

    @staticmethod
    def _to_hwc(tensor):
        """Convert a (C, H, W) or (1, C, H, W) tensor to an (H, W, C) NumPy array."""
        if tensor.dim() == 4:
            tensor = tensor[0]
        return tensor.detach().cpu().permute(1, 2, 0).numpy()

    def convert_fig_to_pil(self, fig):
        """Rasterize a Matplotlib figure to an RGB PIL image and close the figure.

        Args:
            fig: Matplotlib figure with an Agg-based canvas.

        Returns:
            ``PIL.Image.Image`` in RGB mode.
        """
        import matplotlib.pyplot as plt
        from PIL import Image

        try:
            fig.canvas.draw()
            # buffer_rgba() already has the rendered (H, W, 4) shape, so no
            # manual reshape via get_width_height() is needed.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            rgb = np.ascontiguousarray(rgba[..., :3])
        finally:
            # Figures created through pyplot stay registered until closed.
            plt.close(fig)
        return Image.fromarray(rgb)

    def shap(self, n_samples, stdevs):
        """Explain the prediction with Captum GradientShap.

        Args:
            n_samples: number of samples drawn per input (cast to ``int``).
            stdevs: standard deviation of the Gaussian noise Captum adds to
                each sample.

        Returns:
            PIL image showing the original and the absolute-value heat map.
        """
        gradient_shap = GradientShap(self.model)
        # Baseline distribution: an all-zero image and the input itself.
        baselines = torch.cat([self.input * 0, self.input * 1])
        attributions = gradient_shap.attribute(
            self.input,
            n_samples=int(n_samples),
            stdevs=stdevs,
            baselines=baselines,
            target=self.pred_label_idx,
        )
        fig, _ = viz.visualize_image_attr_multiple(
            self._to_hwc(attributions),
            self._to_hwc(self.transformed_img),
            ["original_image", "heat_map"],
            ["all", "absolute_value"],
            cmap=self.default_cmap,
            show_colorbar=True,
        )
        fig.suptitle("SHAP | " + self.fig_title, fontsize=12)
        return self.convert_fig_to_pil(fig)

    def occlusion(self, stride, sliding_window):
        """Explain the prediction by sliding a zeroed patch over the input.

        Args:
            stride: step in pixels between patch positions (cast to ``int``).
            sliding_window: side length in pixels of the square patch, which
                spans all 3 channels (cast to ``int``).

        Returns:
            PIL image with original, positive/negative heat maps and masked view.
        """
        occlusion = Occlusion(self.model)
        attributions = occlusion.attribute(
            self.input,
            target=self.pred_label_idx,
            strides=(3, int(stride), int(stride)),
            sliding_window_shapes=(3, int(sliding_window), int(sliding_window)),
            baselines=0,
        )
        fig, _ = viz.visualize_image_attr_multiple(
            self._to_hwc(attributions),
            self._to_hwc(self.transformed_img),
            ["original_image", "heat_map", "heat_map", "masked_image"],
            ["all", "positive", "negative", "positive"],
            show_colorbar=True,
            titles=["Original", "Positive Attribution", "Negative Attribution", "Masked"],
            fig_size=(18, 6),
        )
        fig.suptitle("Occlusion | " + self.fig_title, fontsize=12)
        return self.convert_fig_to_pil(fig)

    def gradcam(self):
        """Explain the prediction with Grad-CAM on ``model.layer3[1].conv2``.

        Assumes a ResNet-style model exposing ``layer3`` (TODO confirm for
        other architectures).

        Returns:
            PIL image with original, blended heat map and masked view.
        """
        layer_gradcam = LayerGradCam(self.model, self.model.layer3[1].conv2)
        attributions = layer_gradcam.attribute(self.input, target=self.pred_label_idx)
        # The layer map is coarser than the input; upsample it to the input's
        # spatial size so it can be overlaid on the image.
        upsampled = LayerAttribution.interpolate(attributions, self.input.shape[2:])
        fig, _ = viz.visualize_image_attr_multiple(
            self._to_hwc(upsampled),
            self._to_hwc(self.transformed_img),
            ["original_image", "blended_heat_map", "masked_image"],
            ["all", "positive", "positive"],
            show_colorbar=True,
            titles=["Original", "Positive Attribution", "Masked"],
            fig_size=(18, 6),
        )
        fig.suptitle("GradCAM layer3[1].conv2 | " + self.fig_title, fontsize=12)
        return self.convert_fig_to_pil(fig)
def create_model_from_checkpoint(checkpoint_path="best_model", num_classes=3):
    """Build a ResNet-50 with a ``num_classes``-way head and load saved weights.

    Args:
        checkpoint_path: path to a file holding the model's ``state_dict``.
        num_classes: output size of the replacement ``fc`` layer; must match
            the checkpoint.

    Returns:
        The model on CPU, in evaluation mode.
    """
    model = resnet50()
    # Replace the stock classification head with one sized for this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # map_location places every saved tensor on the CPU regardless of the
    # device it was saved from.
    # NOTE(review): torch.load unpickles its input; only load trusted files.
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Class names indexed by model output position (assumed to match the
# checkpoint's training label order — TODO confirm).
labels = ["benign", "malignant", "normal"]
# Loaded once at import; every predict() call reuses this instance.
model = create_model_from_checkpoint()
def predict(img, shap_samples, shap_stdevs, occlusion_stride, occlusion_window):
    """Gradio callback: classify ``img`` and build three explanation images.

    Uses the module-level ``model`` and ``labels``.

    Returns:
        ``[confidences, shap_image, occlusion_image, gradcam_image]``, matching
        the order of the Interface outputs.
    """
    explainer = Explainer(model, img, labels)
    confidences = explainer.confidences
    shap_image = explainer.shap(shap_samples, shap_stdevs)
    occlusion_image = explainer.occlusion(occlusion_stride, occlusion_window)
    gradcam_image = explainer.gradcam()
    return [confidences, shap_image, occlusion_image, gradcam_image]
# Gradio front end: an image plus four attribution settings in; the class
# confidences and the SHAP, Occlusion and Grad-CAM images out, in the order
# predict() returns them.
ui = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        # Initial slider positions are set with `value=`; `default=` is not a
        # parameter of the gr.Slider component.
        gr.Slider(minimum=10, maximum=100, value=50, label="SHAP Samples", step=1),
        gr.Slider(minimum=0.0001, maximum=0.01, value=0.0001, label="SHAP Stdevs", step=0.0001),
        gr.Slider(minimum=4, maximum=80, value=8, label="Occlusion Stride", step=1),
        gr.Slider(minimum=4, maximum=80, value=15, label="Occlusion Window", step=1),
    ],
    outputs=[
        gr.Label(num_top_classes=3),
        gr.Image(type="pil"),
        gr.Image(type="pil"),
        gr.Image(type="pil"),
    ],
    # Each row: example image file, then the four slider values in input order.
    examples=[
        ["benign (52).png", 50, 0.0001, 8, 15],
        ["benign (243).png", 50, 0.0001, 8, 15],
        ["malignant (127).png", 50, 0.0001, 8, 15],
        ["malignant (201).png", 50, 0.0001, 8, 15],
        ["normal (81).png", 50, 0.0001, 8, 15],
        ["normal (101).png", 50, 0.0001, 8, 15],
    ],
)
# Keep `ui` bound to the Interface itself and launch it exactly once.
ui.launch()