Spaces · Runtime error

PedroMartelleto committed
Commit • 032c7aa
1 Parent(s): 1a23377

Fix import error
app.py CHANGED
@@ -1,9 +1,68 @@
+import PIL
+from captum.attr import GradientShap
+from captum.attr import visualization as viz
+import torch
+from torchvision import transforms
+from matplotlib.colors import LinearSegmentedColormap
+import torch.nn.functional as F
 import gradio as gr
-from torchvision.models import resnet50
+from torchvision.models import resnet50
 import torch.nn as nn
 import torch
 import numpy as np
-
+
+class Explainer:
+    def __init__(self, model, img, class_names):
+        self.model = model
+        self.class_names = class_names
+        # Colormap for the attribution heat map.
+        self.default_cmap = LinearSegmentedColormap.from_list('custom blue',
+                                                              [(0, '#ffffff'),
+                                                               (0.25, '#000000'),
+                                                               (1, '#000000')], N=256)
+
+        # Standard ImageNet preprocessing; normalization is kept separate so
+        # the un-normalized image can be shown next to the heat map.
+        transform = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor()
+        ])
+        transform_normalize = transforms.Normalize(
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225]
+        )
+
+        self.transformed_img = transform(img)
+        self.input = transform_normalize(self.transformed_img)
+        self.input = self.input.unsqueeze(0)  # add batch dimension
+
+        with torch.no_grad():
+            self.output = self.model(self.input)
+            self.output = F.softmax(self.output, dim=1)
+            self.confidences = {class_names[i]: float(self.output[0, i])
+                                for i in range(len(class_names))}
+
+        # Top-1 prediction, used as the figure title.
+        self.pred_score, self.pred_label_idx = torch.topk(self.output, 1)
+        self.pred_label = self.class_names[self.pred_label_idx.item()]
+        self.fig_title = 'Predicted: ' + self.pred_label + ' (' + str(round(self.pred_score.squeeze().item(), 2)) + ')'
+
+    def convert_fig_to_pil(self, fig):
+        fig.canvas.draw()  # render the figure before reading its pixel buffer
+        return PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
+
+    def shap(self):
+        # GradientShap attributions over a two-image baseline distribution
+        # (the all-black image and the input itself), shown as the original
+        # image side by side with an absolute-value heat map.
+        gradient_shap = GradientShap(self.model)
+        rand_img_dist = torch.cat([self.input * 0, self.input * 1])
+        attributions_gs = gradient_shap.attribute(self.input, n_samples=50, stdevs=0.0001,
+                                                  baselines=rand_img_dist, target=self.pred_label_idx)
+        fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1, 2, 0)),
+                                                   np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
+                                                   ["original_image", "heat_map"],
+                                                   ["all", "absolute_value"],
+                                                   cmap=self.default_cmap,
+                                                   show_colorbar=True)
+        fig.suptitle(self.fig_title, fontsize=12)
+        return self.convert_fig_to_pil(fig)
 
 @staticmethod
 def create_model_from_checkpoint():