PedroMartelleto commited on
Commit
1b87171
·
1 Parent(s): 73a6c1b

Deploying to HF

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -32,6 +32,7 @@ class Explainer:
32
  )
33
 
34
  self.transformed_img = transform(img)
 
35
  self.input = transform_normalize(self.transformed_img)
36
  self.input = self.input.unsqueeze(0)
37
 
@@ -46,7 +47,10 @@ class Explainer:
46
  self.fig_title = 'Predicted: ' + self.pred_label + ' (' + str(round(self.pred_score.squeeze().item(), 2)) + ')'
47
 
48
  def convert_fig_to_pil(self, fig):
49
- return PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
 
 
 
50
 
51
  def shap(self):
52
  gradient_shap = GradientShap(self.model)
@@ -75,7 +79,7 @@ labels = [ "benign", "malignant", "normal" ]
75
  def predict(img):
76
  explainer = Explainer(model, img, labels)
77
  shap_img = explainer.shap()
78
- return [explainer.confidences, shap_img]
79
 
80
  ui = gr.Interface(fn=predict,
81
  inputs=gr.Image(type="pil"),
 
32
  )
33
 
34
  self.transformed_img = transform(img)
35
+
36
  self.input = transform_normalize(self.transformed_img)
37
  self.input = self.input.unsqueeze(0)
38
 
 
47
  self.fig_title = 'Predicted: ' + self.pred_label + ' (' + str(round(self.pred_score.squeeze().item(), 2)) + ')'
48
 
49
  def convert_fig_to_pil(self, fig):
50
+ fig.canvas.draw()
51
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
52
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
53
+ return PIL.Image.fromarray(data)
54
 
55
  def shap(self):
56
  gradient_shap = GradientShap(self.model)
 
79
  def predict(img):
80
  explainer = Explainer(model, img, labels)
81
  shap_img = explainer.shap()
82
+ return [explainer.confidences, shap_img]
83
 
84
  ui = gr.Interface(fn=predict,
85
  inputs=gr.Image(type="pil"),