File size: 2,024 Bytes
c4b2b37 8e6f824 c4b2b37 26b53b0 c4b2b37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import gradio as gr
from PIL import Image
from torchvision import transforms
from gradcam import do_gradcam
from lrp import do_lrp, do_partial_lrp
from rollout import do_rollout
from tiba import do_tiba
# Channel-wise normalization with mean 0.5 and std 0.5 on each of the three
# channels; applied after ToTensor in the pipeline below.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing steps, applied in the order listed: resize, center-crop to
# 224, convert to a tensor, then normalize.
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
TRANSFORM = transforms.Compose(_preprocessing_steps)

# Maps the method name chosen in the UI to the helper that produces the
# visualization. The insertion order is the order in which the names are
# listed in the "Method" dropdown, which is built from these keys.
METHOD_MAP = {
    "tiba": do_tiba,
    "gradcam": do_gradcam,
    "lrp": do_lrp,
    "partial_lrp": do_partial_lrp,
    "rollout": do_rollout,
}
def generate_viz(image, method, class_index=None):
    """Render an explanation of ``image`` with the selected method.

    Args:
        image: Input image, forwarded unchanged to the explainability helper
            (the Gradio input component is declared with ``type="pil"``).
        method: Key of ``METHOD_MAP`` selecting the explainability helper.
        class_index: Optional class index to inspect. ``None`` is forwarded
            as-is; any other value is converted to ``int`` first.

    Returns:
        The rendered visualization as an RGB ``PIL.Image.Image``.

    Raises:
        KeyError: If ``method`` is not a key of ``METHOD_MAP`` (for example
            ``None`` when nothing was selected in the dropdown).
    """
    import io

    try:
        viz_method = METHOD_MAP[method]
    except KeyError:
        raise KeyError(
            f"Unknown method {method!r}; expected one of {list(METHOD_MAP)}"
        ) from None

    # Depending on the Gradio version and the component's precision, the
    # Number input may arrive as a float (e.g. 243.0). A class index is an
    # integer, so coerce it before handing it to the helper.
    if class_index is not None:
        class_index = int(class_index)

    viz = viz_method(TRANSFORM, image, class_index)

    # Render into memory instead of a fixed "visualization.png" on disk, so
    # that concurrent requests cannot overwrite each other's output and no
    # writable working directory is required.
    buffer = io.BytesIO()
    viz.savefig(buffer, format="png")
    buffer.seek(0)

    # convert() forces the image data to be read from the buffer right away.
    # NOTE(review): if ``viz`` is a matplotlib figure it is never closed
    # here - confirm whether the helpers manage figure lifetime.
    return Image.open(buffer).convert("RGB")
# Heading and footer text displayed by the interface.
title = "Compare different methods of explaining ViTs 🤖"
article = "Different methods for explaining Vision Transformers as explored by Chefer et al. in [Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications by Transformer based networks](https://arxiv.org/abs/2012.09838)."

# Input widgets, in the positional order of generate_viz's parameters:
# image, method, class_index.
_input_components = [
    gr.Image(type="pil", label="Input Image"),
    gr.Dropdown(
        list(METHOD_MAP.keys()),
        label="Method",
        info="Explainability method to investigate.",
    ),
    gr.Number(label="Class Index", info="Class index to inspect"),
]

# Example rows: [image path, method name, class index]. A class index of
# None is passed through to generate_viz unchanged.
_example_rows = [
    ["Transformer-Explainability/samples/catdog.png", "tiba", None],
    ["Transformer-Explainability/samples/catdog.png", "rollout", 243],
    ["Transformer-Explainability/samples/el2.png", "tiba", None],
    # Left disabled, as in the original file:
    # ["Transformer-Explainability/samples/el2.png", "gradcam", 340],
    ["Transformer-Explainability/samples/dogbird.png", "lrp", 161],
]

iface = gr.Interface(
    fn=generate_viz,
    inputs=_input_components,
    outputs=gr.Image(),
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=_example_rows,
)

# Starts the app as soon as this module is executed.
iface.launch(debug=True)
|