import gradio as gr
from PIL import Image
from torchvision import transforms

from gradcam import do_gradcam
from lrp import do_lrp, do_partial_lrp
from rollout import do_rollout
from tiba import do_tiba

# Standard ViT preprocessing: resize, center-crop to 224x224,
# and normalize each channel to [-1, 1].
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)

# Maps each dropdown choice to its explainability method.
# "tiba" is Transformer Interpretability Beyond Attention, the method of Chefer et al.
METHOD_MAP = {
    "tiba": do_tiba,
    "gradcam": do_gradcam,
    "lrp": do_lrp,
    "partial_lrp": do_partial_lrp,
    "rollout": do_rollout,
}


def generate_viz(image, method, class_index=None):
    # gr.Number yields a float (or None); the methods expect an integer class index.
    if class_index is not None:
        class_index = int(class_index)

    print(f"Image: {image.size}")
    print(f"Method: {method}")
    print(f"Class: {class_index}")

    viz_method = METHOD_MAP[method]
    viz = viz_method(TRANSFORM, image, class_index=class_index)

    # Save the matplotlib figure to disk, then reload it as a PIL image for Gradio.
    viz.savefig("visualization.png")
    return Image.open("visualization.png").convert("RGB")


title = "Compare different methods of explaining ViTs 🤖"
article = (
    "Different methods for explaining Vision Transformers, as explored by Chefer et al. in "
    "[Transformer Interpretability Beyond Attention Visualization]"
    "(https://arxiv.org/abs/2012.09838), a novel method to visualize classifications "
    "by Transformer-based networks."
)

iface = gr.Interface(
    generate_viz,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(
            list(METHOD_MAP.keys()),
            label="Method",
            info="Explainability method to investigate.",
        ),
        gr.Number(label="Class Index", info="Class index to inspect."),
    ],
    outputs=gr.Image(),
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=[
        ["Transformer-Explainability/samples/catdog.png", "tiba", None],
        ["Transformer-Explainability/samples/catdog.png", "rollout", 243],
        ["Transformer-Explainability/samples/el2.png", "tiba", None],
        ["Transformer-Explainability/samples/el2.png", "gradcam", 340],
        ["Transformer-Explainability/samples/dogbird.png", "lrp", 161],
    ],
)

iface.launch(debug=True)
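

# For reference, the five do_* callables in METHOD_MAP are assumed to share the
# interface sketched below: each takes the torchvision transform and the raw PIL
# image, and returns a matplotlib Figure that generate_viz() saves to disk.
# This is a hypothetical illustration, not code from the actual method modules,
# and since launch(debug=True) blocks above, it is never executed by this demo.
import matplotlib.pyplot as plt


def do_stub(transform, image, class_index=None):
    """Minimal METHOD_MAP-compatible sketch: shows the input, no attribution."""
    tensor = transform(image).unsqueeze(0)  # (1, 3, 224, 224) model input
    fig, ax = plt.subplots()
    ax.imshow(image)  # a real method would overlay its relevance heatmap here
    ax.set_title(f"class_index={class_index}, input shape={tuple(tensor.shape)}")
    ax.axis("off")
    return fig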