# app.py by mawady ("Create app.py", commit f740d84, 3.63 kB).
import tensorflow as tf
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet_v2 import preprocess_input, decode_predictions
import matplotlib.pyplot as plt
from alibi.explainers import IntegratedGradients
from alibi.datasets import load_cats
from alibi.utils.visualization import visualize_image_attr
import numpy as np
from PIL import Image
import io
import time
import os
import copy
import pickle
import datetime
import urllib.request
import gradio as gr
# Download the two sample images offered as Gradio examples. They are written
# to the current working directory so the relative example paths
# ('./cat.jpg', './dog.jpg') resolve; a Colab-style '/content/' target does
# not exist outside Colab and makes urlretrieve fail.
for url, path_input in (
    ("https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg", "./cat.jpg"),
    ("https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg", "./dog.jpg"),
):
    urllib.request.urlretrieve(url, filename=path_input)

# ImageNet-pretrained classifier whose predictions are explained.
model = ResNet50V2(weights='imagenet')

# Integrated Gradients configuration: number of steps used to approximate the
# path integral from baseline to input, the quadrature method, and the batch
# size used internally when evaluating the interpolated inputs.
n_steps = 50
method = "gausslegendre"
internal_batch_size = 50
ig = IntegratedGradients(model,
                         n_steps=n_steps,
                         method=method,
                         internal_batch_size=internal_batch_size)
# refs:
# - fig2pil: https://stackoverflow.com/questions/57316491/how-to-convert-matplotlib-figure-to-pil-image-object-without-saving-image
def _make_baselines(baseline, instance):
    """Return an Integrated Gradients baseline array shaped like *instance*.

    *instance* has already been through ResNetV2's ``preprocess_input``, which
    rescales pixel values into [-1, 1]. Solid-colour baselines are therefore
    built in raw pixel space (0 = black, 255 = white) and passed through the
    same preprocessing, so "black" really is black to the model rather than
    the mid-grey an all-zeros array would represent.
    Any value other than 'white'/'black' yields uniform random noise in
    [0, 1), as before.
    """
    if baseline == 'white':
        pixels = np.full(instance.shape, 255.0, dtype=instance.dtype)
    elif baseline == 'black':
        pixels = np.zeros(instance.shape, dtype=instance.dtype)
    else:
        return np.random.random_sample(instance.shape).astype(instance.dtype)
    return preprocess_input(pixels)


def do_process(img, baseline):
    """Classify *img* with ResNet50V2 and explain the top class with Integrated Gradients.

    Args:
        img: Input image from the Gradio image component (``type="pil"``,
            configured as 224x224 RGB).
        baseline: ``'black'``, ``'white'`` or ``'random'``; any other value
            is treated as ``'random'``.

    Returns:
        ``(img_res, dctPreds)``: a PIL image of the attributions blended over
        the input, and a ``{label: score}`` dict of the top-3 ImageNet
        predictions with scores rounded to 2 decimals.
    """
    instance = image.img_to_array(img)
    instance = np.expand_dims(instance, axis=0)  # add the batch dimension
    instance = preprocess_input(instance)

    preds = model.predict(instance)
    # decode_predictions yields (class_id, label, score) tuples per sample.
    dctPreds = {label: round(float(score), 2)
                for _, label, score in decode_predictions(preds, top=3)[0]}
    # Explain the model's own top-1 class for each sample in the batch.
    predictions = preds.argmax(axis=1)

    explanation = ig.explain(instance,
                             baselines=_make_baselines(baseline, instance),
                             target=predictions)
    attrs = explanation.attributions[0]

    fig, ax = visualize_image_attr(attr=attrs.squeeze(), original_image=img,
                                   method='blended_heat_map', sign='all',
                                   show_colorbar=True, title='Overlaid Attributions',
                                   plt_fig_axis=None, use_pyplot=False)
    # Render the figure into an in-memory buffer and reopen it as a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img_res = Image.open(buf)
    return img_res, dctPreds
# Gradio UI: a 224x224 RGB image plus a baseline choice go in; the attribution
# overlay and the top-3 label scores come out.
# NOTE(review): gr.inputs / gr.outputs and test_launch() are legacy Gradio
# APIs; confirm the pinned Gradio version still provides them.
input_im = gr.inputs.Image(shape=(224, 224), image_mode='RGB',
                           invert_colors=False, source="upload",
                           type="pil")
input_drop = gr.inputs.Dropdown(label='Baseline (default: random)',
                                choices=sorted(['black', 'white', 'random']),
                                default='random', type='value')
output_img = gr.outputs.Image(label='Output image', type='pil')
output_label = gr.outputs.Label(num_top_classes=3)
title = "XAI - Integrated gradients"
description = "Playground: Integrated gradients for a ResNet model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
# Each example row supplies one value per input component: (image, baseline).
examples = [['./cat.jpg', 'random'], ['./dog.jpg', 'random']]
article = "<p style='text-align: center'><a href='https://github.com/mawady/colab-recipes-cv' target='_blank'>Colab recipes for computer vision - Dr. Mohamed Elawady</a></p>"
iface = gr.Interface(
    fn=do_process,
    inputs=[input_im, input_drop],
    outputs=[output_img, output_label],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.test_launch()
iface.launch(share=True, debug=True)