File size: 3,643 Bytes
032c7aa
 
 
 
 
 
 
ad937a3
032c7aa
ad937a3
 
1a23377
032c7aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad937a3
 
 
 
 
 
 
 
 
 
 
 
 
 
1a23377
 
 
ad937a3
 
 
1a23377
ad937a3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import PIL
from captum.attr import GradientShap
from captum.attr import visualization as viz
import torch
from torchvision import transforms
from matplotlib.colors import LinearSegmentedColormap
import torch.nn.functional as F
import gradio as gr
from torchvision.models import resnet50
import torch.nn as nn
import torch
import numpy as np

class Explainer:
    """Classify one image and render a GradientSHAP attribution figure for it.

    Construction preprocesses the image, runs the model once and stores the
    class probabilities together with the top prediction. ``shap`` then
    explains that prediction.

    Attributes:
        model: The classifier passed to the constructor.
        class_names: Class names, indexed like the model's output columns.
        default_cmap: White-to-black colormap used for the heat map.
        transformed_img: Resized and cropped image tensor *before*
            normalisation (used as the displayed original).
        input: Normalised ``(1, 3, 224, 224)`` batch fed to the model.
        output: Softmax probabilities returned by the model for ``input``.
        confidences: ``{class name: probability}`` for every class.
        pred_score: Probability of the top class (tensor from ``topk``).
        pred_label_idx: Index of the top class (tensor from ``topk``).
        pred_label: Name of the top class.
        fig_title: Title string placed on the attribution figure.
    """

    def __init__(self, model, img, class_names):
        """Preprocess ``img``, run ``model`` on it and record the prediction.

        Args:
            model: Callable mapping a ``(1, 3, 224, 224)`` float tensor to
                one row of per-class scores.
            img: PIL image to classify. Any mode is accepted; it is
                converted to RGB first.
            class_names: Sequence of class names, one per model output.

        Raises:
            ValueError: If ``class_names`` and the model output disagree in
                length.
        """
        self.model = model
        self.class_names = class_names
        # Previously set in a first ``__init__`` that the second definition
        # silently replaced, so ``shap`` could never find it.
        self.default_cmap = LinearSegmentedColormap.from_list(
            'custom blue',
            [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')],
            N=256,
        )

        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        # Standard ImageNet channel statistics.
        transform_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        # Grayscale / RGBA uploads would otherwise yield 1 or 4 channels and
        # break the 3-channel normalisation.
        self.transformed_img = transform(img.convert('RGB'))
        self.input = transform_normalize(self.transformed_img).unsqueeze(0)

        with torch.no_grad():
            self.output = F.softmax(self.model(self.input), dim=1)
        self.confidences = self._confidences_from_probs(self.output, class_names)

        self.pred_score, self.pred_label_idx = torch.topk(self.output, 1)
        # Convert the one-element index tensor to a plain int before using it
        # to index a Python sequence.
        self.pred_label = self.class_names[int(self.pred_label_idx.item())]
        score = round(self.pred_score.squeeze().item(), 2)
        self.fig_title = f'Predicted: {self.pred_label} ({score})'

    @staticmethod
    def _confidences_from_probs(probs, class_names):
        """Map each class name to its probability in the first row of ``probs``.

        Args:
            probs: 2-D tensor of probabilities; only row 0 is read.
            class_names: Sequence of names, one per column of ``probs``.

        Returns:
            Dict of ``{name: float probability}`` in ``class_names`` order.

        Raises:
            ValueError: If the number of names differs from the number of
                columns.
        """
        row = probs[0].tolist()
        if len(class_names) != len(row):
            raise ValueError(
                f'expected {len(row)} class names, got {len(class_names)}'
            )
        return dict(zip(class_names, row))

    def convert_fig_to_pil(self, fig):
        """Rasterise a Matplotlib figure and return it as an RGB PIL image.

        Args:
            fig: Matplotlib figure whose canvas exposes ``buffer_rgba``
                (Agg-based canvases do).

        Returns:
            ``PIL.Image.Image`` in mode ``'RGB'``.
        """
        # ``import PIL`` alone does not guarantee the ``Image`` submodule is
        # loaded, so import it explicitly.
        from PIL import Image

        canvas = fig.canvas
        # The pixel buffer is only valid after an explicit draw.
        canvas.draw()
        # ``buffer_rgba`` replaces the deprecated ``tostring_rgb`` and carries
        # its own dimensions, so no separate width/height lookup is needed.
        rgba = np.asarray(canvas.buffer_rgba())
        return Image.fromarray(rgba).convert('RGB')

    def shap(self):
        """Compute GradientSHAP attributions for the predicted class.

        Returns:
            RGB PIL image showing the original image next to a heat map of
            absolute attributions, titled with the prediction.
        """
        import matplotlib.pyplot as plt

        gradient_shap = GradientShap(self.model)
        # Baseline distribution: an all-zero image and the input itself.
        baselines = torch.cat([self.input * 0, self.input * 1])
        attributions = gradient_shap.attribute(
            self.input,
            n_samples=50,
            stdevs=0.0001,
            baselines=baselines,
            target=self.pred_label_idx,
        )

        # Captum expects channel-last (H, W, C) numpy arrays.
        attr_hwc = np.transpose(
            attributions.squeeze().cpu().detach().numpy(), (1, 2, 0)
        )
        image_hwc = np.transpose(
            self.transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)
        )
        fig, _ = viz.visualize_image_attr_multiple(
            attr_hwc,
            image_hwc,
            ["original_image", "heat_map"],
            ["all", "absolute_value"],
            cmap=self.default_cmap,
            show_colorbar=True,
        )
        fig.suptitle(self.fig_title, fontsize=12)
        try:
            return self.convert_fig_to_pil(fig)
        finally:
            # Figures created through pyplot stay registered until closed;
            # without this every call would leak one figure.
            plt.close(fig)

def create_model_from_checkpoint(checkpoint_path="best_model", num_classes=3):
    """Build a ResNet-50 classifier and load its weights from a checkpoint.

    The final fully connected layer is replaced with a ``num_classes``-way
    linear head before the state dict is loaded, and the model is returned in
    evaluation mode.

    Args:
        checkpoint_path: Path to a file holding the model's ``state_dict``.
        num_classes: Number of outputs of the replacement head; must match
            the checkpoint.

    Returns:
        The ``resnet50`` instance with loaded weights, in ``eval()`` mode.
    """
    # NOTE: this was decorated with ``@staticmethod`` at module level, which
    # makes the name uncallable on Python < 3.10; it is a plain function.
    model = resnet50()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # ``map_location`` keeps a checkpoint saved on a GPU loadable on a
    # CPU-only host. NOTE(review): ``torch.load`` unpickles the file, so the
    # checkpoint must come from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Load the classifier once at import time; every request reuses this instance.
model = create_model_from_checkpoint()

# Class names handed to Explainer, which pairs them with model outputs by index.
labels = [
    "benign",
    "malignant",
    "normal",
]

def predict(img):
    """Classify ``img`` and explain the prediction.

    Args:
        img: Image supplied by the UI, forwarded unchanged to ``Explainer``.

    Returns:
        Two-element list: the per-class confidence dict and the attribution
        image produced by ``Explainer.shap()``.
    """
    result = Explainer(model, img, labels)
    attribution_image = result.shap()
    return [result.confidences, attribution_image]

# Build the web UI: one image in, class confidences and the attribution
# image out.
ui = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil")],
    examples=[
        "benign (52).png",
        "benign (243).png",
        "malignant (127).png",
        "malignant (201).png",
        "normal (81).png",
        "normal (101).png",
    ],
)

# Launch exactly once. The original chained ``.launch()`` onto the
# constructor, which bound ``ui`` to launch()'s return value instead of the
# Interface, so the second ``ui.launch(...)`` call could not work.
ui.launch(share=True)