|
import io

import gradio as gr
from PIL import Image
from torchvision import transforms

from gradcam import do_gradcam
from lrp import do_lrp, do_partial_lrp
from rollout import do_rollout
from tiba import do_tiba
|
|
|
# Per-channel normalization with mean 0.5 and std 0.5 on all three channels:
# (x - 0.5) / 0.5 maps the [0, 1] tensors produced by ToTensor onto [-1, 1].
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing applied to every input image before it reaches an
# explainability method: resize, center-crop to 224x224, convert to a tensor,
# then normalize.
_PREPROCESS_STEPS = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
TRANSFORM = transforms.Compose(_PREPROCESS_STEPS)

# Dropdown choice name -> explainability function. Insertion order determines
# the order the choices appear in the UI.
METHOD_MAP = dict(
    tiba=do_tiba,
    gradcam=do_gradcam,
    lrp=do_lrp,
    partial_lrp=do_partial_lrp,
    rollout=do_rollout,
)
|
|
|
|
|
def generate_viz(image, method, class_index=None):
    """Run the selected explainability method on *image* and return the plot.

    Args:
        image: The input image as delivered by the ``gr.Image(type="pil")``
            input component.
        method: Name of the explainability method; must be a key of
            ``METHOD_MAP``.
        class_index: Optional class index to inspect. ``None`` is forwarded
            unchanged; any other value is converted to ``int`` because
            ``gr.Number`` returns floats unless ``precision=0`` is set.

    Returns:
        PIL.Image.Image: The rendered visualization, converted to RGB.

    Raises:
        gr.Error: If *method* is not one of the keys of ``METHOD_MAP``.
    """
    import io  # local import so this block stays runnable on its own

    try:
        viz_method = METHOD_MAP[method]
    except KeyError:
        choices = ", ".join(METHOD_MAP)
        raise gr.Error(
            f"Unknown method {method!r}; choose one of: {choices}."
        ) from None

    if class_index is not None:
        class_index = int(class_index)

    viz = viz_method(TRANSFORM, image, class_index)

    # Render into memory instead of a fixed "visualization.png" on disk:
    # concurrent requests would otherwise overwrite each other's output, and
    # a relative path depends on the current working directory.
    # NOTE(review): viz appears to be a matplotlib Figure (it has savefig);
    # if it is pyplot-managed it is never closed here and figures may
    # accumulate across requests — consider plt.close(viz); confirm.
    with io.BytesIO() as buffer:
        viz.savefig(buffer, format="png")
        buffer.seek(0)
        with Image.open(buffer) as rendered:
            # convert() fully decodes into a new image, so it stays valid
            # after the buffer and the source image are closed.
            return rendered.convert("RGB")
|
|
|
|
|
# Title shown at the top of the Gradio page.
title = "Compare different methods of explaining ViTs 🤗"

# Markdown shown below the interface, linking to the source paper.
article = (
    "Different methods for explaining Vision Transformers as explored by "
    "Chefer et al. in [Transformer Interpretability Beyond Attention "
    "Visualization, a novel method to visualize classifications by "
    "Transformer based networks](https://arxiv.org/abs/2012.09838)."
)

_METHOD_CHOICES = list(METHOD_MAP.keys())

iface = gr.Interface(
    generate_viz,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(
            _METHOD_CHOICES,
            # Default to the first method so submitting without touching the
            # dropdown does not send None as the method name.
            value=_METHOD_CHOICES[0],
            label="Method",
            info="Explainability method to investigate.",
        ),
        # precision=0 makes Gradio round and return an int; class indices
        # are integers. An empty field is passed through as None.
        gr.Number(label="Class Index", info="Class index to inspect", precision=0),
    ],
    outputs=gr.Image(),
    title=title,
    article=article,
    allow_flagging="never",
    # Examples are computed once at startup and served from the cache.
    cache_examples=True,
    examples=[
        ["Transformer-Explainability/samples/catdog.png", "tiba", None],
        ["Transformer-Explainability/samples/catdog.png", "rollout", 243],
        ["Transformer-Explainability/samples/el2.png", "tiba", None],
        ["Transformer-Explainability/samples/dogbird.png", "lrp", 161],
    ],
)

iface.launch(debug=True)
|
|