PedroMartelleto commited on
Commit
032c7aa
1 Parent(s): 1a23377

Fix import error

Browse files
Files changed (1) hide show
  1. app.py +61 -2
app.py CHANGED
@@ -1,9 +1,68 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from torchvision.models import resnet50, ResNet50_Weights
3
  import torch.nn as nn
4
  import torch
5
  import numpy as np
6
- from explain import Explainer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  @staticmethod
9
  def create_model_from_checkpoint():
 
1
+ import PIL
2
+ from captum.attr import GradientShap
3
+ from captum.attr import visualization as viz
4
+ import torch
5
+ from torchvision import transforms
6
+ from matplotlib.colors import LinearSegmentedColormap
7
+ import torch.nn.functional as F
8
  import gradio as gr
9
+ from torchvision.models import resnet50
10
  import torch.nn as nn
11
  import torch
12
  import numpy as np
13
+
14
+ class Explainer:
15
+ def __init__(self, model):
16
+ self.model = model
17
+ self.default_cmap = LinearSegmentedColormap.from_list('custom blue',
18
+ [(0, '#ffffff'),
19
+ (0.25, '#000000'),
20
+ (1, '#000000')], N=256)
21
+
22
+ def __init__(self, model, img, class_names):
23
+ self.model = model
24
+ self.class_names = class_names
25
+
26
+ transform = transforms.Compose([
27
+ transforms.Resize(256),
28
+ transforms.CenterCrop(224),
29
+ transforms.ToTensor()
30
+ ])
31
+
32
+ transform_normalize = transforms.Normalize(
33
+ mean=[0.485, 0.456, 0.406],
34
+ std=[0.229, 0.224, 0.225]
35
+ )
36
+
37
+ self.transformed_img = transform(img)
38
+ self.input = transform_normalize(self.transformed_img)
39
+ self.input = input.unsqueeze(0)
40
+
41
+ with torch.no_grad():
42
+ self.output = self.model(input)
43
+ self.output = F.softmax(self.output, dim=1)
44
+ print(self.output.shape)
45
+ self.confidences = {class_names[i]: float(self.output[0, i]) for i in range(3)}
46
+
47
+ self.pred_score, self.pred_label_idx = torch.topk(self.output, 1)
48
+ self.pred_label = self.class_names[self.pred_label_idx]
49
+ self.fig_title = 'Predicted: ' + self.pred_label + ' (' + str(round(self.pred_score.squeeze().item(), 2)) + ')'
50
+
51
+ def convert_fig_to_pil(self, fig):
52
+ return PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
53
+
54
+ def shap(self):
55
+ gradient_shap = GradientShap(self.model)
56
+ rand_img_dist = torch.cat([self.input * 0, self.input * 1])
57
+ attributions_gs = gradient_shap.attribute(self.input, n_samples=50, stdevs=0.0001, baselines=rand_img_dist, target=self.pred_label_idx)
58
+ fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1,2,0)),
59
+ np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
60
+ ["original_image", "heat_map"],
61
+ ["all", "absolute_value"],
62
+ cmap=self.default_cmap,
63
+ show_colorbar=True)
64
+ fig.suptitle(self.fig_title, fontsize=12)
65
+ return self.convert_fig_to_pil(fig)
66
 
67
  @staticmethod
68
  def create_model_from_checkpoint():