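# app.py — Gradio demo for a breast-ultrasound classifier (benign / malignant / normal).
# A fine-tuned ResNet-50 predicts the class, and Captum's GradientShap overlays
# pixel attributions on the input image.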
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from captum.attr import GradientShap
from captum.attr import visualization as viz
from matplotlib.colors import LinearSegmentedColormap
from torchvision import transforms
from torchvision.models import resnet50
class Explainer:
    """Wraps a model and an input image and produces GradientShap explanations."""

    def __init__(self, model, img, class_names):
        self.model = model
        self.class_names = class_names
        # Colormap used for the attribution heat map.
        self.default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                              [(0, '#ffffff'),
                                                               (0.25, '#000000'),
                                                               (1, '#000000')], N=256)
        # Standard ImageNet preprocessing: resize, center-crop, then normalize.
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
        transform_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        self.transformed_img = transform(img)
        self.input = transform_normalize(self.transformed_img).unsqueeze(0)
        with torch.no_grad():
            self.output = self.model(self.input)
        self.output = F.softmax(self.output, dim=1)
        self.confidences = {class_names[i]: float(self.output[0, i])
                            for i in range(len(class_names))}
        self.pred_score, pred_label_idx = torch.topk(self.output, 1)
        self.pred_label_idx = pred_label_idx.squeeze().item()
        self.pred_label = self.class_names[self.pred_label_idx]
        self.fig_title = 'Predicted: ' + self.pred_label + ' (' + str(round(self.pred_score.squeeze().item(), 2)) + ')'
    def convert_fig_to_pil(self, fig):
        # The canvas must be rendered before its RGB buffer can be read.
        fig.canvas.draw()
        return PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
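    # Note (hedged): FigureCanvas.tostring_rgb was deprecated in Matplotlib 3.8 and
    # later removed; on newer Matplotlib an equivalent, untested here, would be
    #   return PIL.Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert('RGB')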
    def shap(self):
        gradient_shap = GradientShap(self.model)
        # Baseline distribution: the all-black image and the input itself.
        rand_img_dist = torch.cat([self.input * 0, self.input * 1])
        attributions_gs = gradient_shap.attribute(self.input,
                                                  n_samples=50,
                                                  stdevs=0.0001,
                                                  baselines=rand_img_dist,
                                                  target=self.pred_label_idx)
        # Render the original image next to the absolute-value attribution heat map.
        fig, _ = viz.visualize_image_attr_multiple(
            np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1, 2, 0)),
            np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
            ["original_image", "heat_map"],
            ["all", "absolute_value"],
            cmap=self.default_cmap,
            show_colorbar=True)
        fig.suptitle(self.fig_title, fontsize=12)
        return self.convert_fig_to_pil(fig)
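    # Example use outside Gradio (hypothetical filename, assuming an RGB image on disk):
    #   img = PIL.Image.open("sample.png").convert("RGB")
    #   Explainer(model, img, ["benign", "malignant", "normal"]).shap().save("shap.png")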
    @staticmethod
    def create_model_from_checkpoint():
        # Rebuilds the fine-tuned ResNet-50 with a three-class head and loads its weights.
        model = resnet50()
        model.fc = nn.Linear(model.fc.in_features, 3)
        model.load_state_dict(torch.load("best_model"))
        model.eval()
        return model
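    # Note (assumption): "best_model" is a state_dict checkpoint in the working
    # directory; on a CPU-only host one would load it with
    #   torch.load("best_model", map_location="cpu")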
model = Explainer.create_model_from_checkpoint()
labels = ["benign", "malignant", "normal"]

def predict(img):
    explainer = Explainer(model, img, labels)
    shap_img = explainer.shap()
    return [explainer.confidences, shap_img]

ui = gr.Interface(fn=predict,
                  inputs=gr.Image(type="pil"),
                  outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil")],
                  examples=["benign (52).png", "benign (243).png", "malignant (127).png",
                            "malignant (201).png", "normal (81).png", "normal (101).png"])
ui.launch(share=True)
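# Note: when this file runs on Hugging Face Spaces the platform serves the app
# itself; share=True only matters for local runs, where it opens a public tunnel.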