Spaces:
Runtime error
Runtime error
PedroMartelleto
committed on
Commit
•
c396e65
1
Parent(s):
1b87171
Deploying to HF
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import PIL
|
2 |
-
from captum.attr import GradientShap
|
3 |
from captum.attr import visualization as viz
|
4 |
import torch
|
5 |
from torchvision import transforms
|
@@ -65,6 +65,58 @@ class Explainer:
|
|
65 |
fig.suptitle(self.fig_title, fontsize=12)
|
66 |
return self.convert_fig_to_pil(fig)
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
def create_model_from_checkpoint():
|
69 |
# Loads a model from a checkpoint
|
70 |
model = resnet50()
|
@@ -78,11 +130,10 @@ labels = [ "benign", "malignant", "normal" ]
|
|
78 |
|
79 |
def predict(img):
|
80 |
explainer = Explainer(model, img, labels)
|
81 |
-
|
82 |
-
return [explainer.confidences, shap_img]
|
83 |
|
84 |
ui = gr.Interface(fn=predict,
|
85 |
inputs=gr.Image(type="pil"),
|
86 |
-
outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil")],
|
87 |
examples=["benign (52).png", "benign (243).png", "malignant (127).png", "malignant (201).png", "normal (81).png", "normal (101).png"]).launch()
|
88 |
ui.launch(share=True)
|
|
|
1 |
import PIL
|
2 |
+
from captum.attr import GradientShap, Occlusion, LayerGradCam, LayerAttribution, IntegratedGradients
|
3 |
from captum.attr import visualization as viz
|
4 |
import torch
|
5 |
from torchvision import transforms
|
|
|
65 |
fig.suptitle(self.fig_title, fontsize=12)
|
66 |
return self.convert_fig_to_pil(fig)
|
67 |
|
68 |
+
def occlusion(self):
    """Explain the current prediction with Captum's Occlusion attribution.

    Slides a (3, 15, 15) occlusion window over the input (strides (3, 8, 8),
    zero baseline) and renders original / positive / negative / masked panels
    in one figure, returned as a PIL image.
    """
    # BUG FIX: the original built Occlusion(model) from the module-level
    # global; use the model this Explainer was constructed with, exactly as
    # gradcam() and integrated_gradients() do.
    occ = Occlusion(self.model)

    attributions_occ = occ.attribute(self.input,
                                     target=self.pred_label_idx,
                                     strides=(3, 8, 8),
                                     sliding_window_shapes=(3, 15, 15),
                                     baselines=0)

    # Captum's viz helpers want HxWxC numpy arrays (channels-first tensors in).
    fig, _ = viz.visualize_image_attr_multiple(
        np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1, 2, 0)),
        np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
        ["original_image", "heat_map", "heat_map", "masked_image"],
        ["all", "positive", "negative", "positive"],
        show_colorbar=True,
        titles=["Original", "Positive Attribution", "Negative Attribution", "Masked"],
        fig_size=(18, 6))
    fig.suptitle(self.fig_title, fontsize=12)
    return self.convert_fig_to_pil(fig)
|
87 |
+
|
88 |
+
def gradcam(self):
    """Explain the current prediction with Layer Grad-CAM.

    Computes Grad-CAM attributions at ``self.model.layer3[1].conv2``,
    upsamples them to the input resolution, and renders original / blended
    heat map / masked panels in one figure, returned as a PIL image.
    """
    layer_gradcam = LayerGradCam(self.model, self.model.layer3[1].conv2)
    attributions_lgc = layer_gradcam.attribute(self.input, target=self.pred_label_idx)

    # The layer attributions are coarse; upsample to the input's HxW so they
    # overlay the image. (Removed dead commented-out single-panel viz call.)
    upsamp_attr_lgc = LayerAttribution.interpolate(attributions_lgc, self.input.shape[2:])

    fig, _ = viz.visualize_image_attr_multiple(
        upsamp_attr_lgc[0].cpu().permute(1, 2, 0).detach().numpy(),
        # Consistency fix: siblings call .cpu().detach() before .numpy();
        # plain .numpy() raises on CUDA or grad-tracking tensors.
        self.transformed_img.squeeze().cpu().detach().permute(1, 2, 0).numpy(),
        ["original_image", "blended_heat_map", "masked_image"],
        ["all", "positive", "positive"],
        show_colorbar=True,
        titles=["Original", "Positive Attribution", "Masked"],
        fig_size=(18, 6))
    # Consistency fix: occlusion() and integrated_gradients() title their
    # figures; do the same here so all four outputs look alike in the UI.
    fig.suptitle(self.fig_title, fontsize=12)
    return self.convert_fig_to_pil(fig)
|
105 |
+
|
106 |
+
def integrated_gradients(self):
    """Explain the current prediction with Integrated Gradients (50 steps).

    Renders original / attribution heat map / masked panels in one titled
    figure and returns it as a PIL image.
    """
    # Avoid shadowing the method name with the local attribution object.
    ig = IntegratedGradients(self.model)
    attrs = ig.attribute(self.input, target=self.pred_label_idx, n_steps=50)

    def _to_hwc(tensor):
        # Channels-first tensor -> HxWxC numpy array for captum's viz helpers.
        return np.transpose(tensor.squeeze().cpu().detach().numpy(), (1, 2, 0))

    fig, _ = viz.visualize_image_attr_multiple(
        _to_hwc(attrs),
        _to_hwc(self.transformed_img),
        ["original_image", "heat_map", "masked_image"],
        ["all", "positive", "positive"],
        show_colorbar=True,
        titles=["Original", "Attribution", "Masked"],
        fig_size=(18, 6))
    fig.suptitle(self.fig_title, fontsize=12)
    return self.convert_fig_to_pil(fig)
|
119 |
+
|
120 |
def create_model_from_checkpoint():
|
121 |
# Loads a model from a checkpoint
|
122 |
model = resnet50()
|
|
|
130 |
|
131 |
def predict(img):
    """Classify *img* and return its confidences plus four explanation images.

    Output order matches the gr.Interface outputs: label confidences, then
    SHAP, occlusion, Grad-CAM and integrated-gradients visualizations.
    """
    explainer = Explainer(model, img, labels)
    visualizations = [
        explainer.shap(),
        explainer.occlusion(),
        explainer.gradcam(),
        explainer.integrated_gradients(),
    ]
    return [explainer.confidences] + visualizations
|
|
|
134 |
|
135 |
# Build the Gradio UI: one image in -> top-3 label confidences plus four
# explanation images (SHAP, occlusion, Grad-CAM, integrated gradients).
ui = gr.Interface(fn=predict,
                  inputs=gr.Image(type="pil"),
                  outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil"), gr.Image(type="pil"), gr.Image(type="pil"), gr.Image(type="pil")],
                  examples=["benign (52).png", "benign (243).png", "malignant (127).png", "malignant (201).png", "normal (81).png", "normal (101).png"])
# BUG FIX: the original chained .launch() onto gr.Interface(...), so `ui` was
# the return value of launch() and the subsequent ui.launch(share=True) was
# called on that return value, not the Interface. Launch exactly once, shared.
ui.launch(share=True)
|