import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.cm
from PIL import Image


class Hook:
    """Attaches to a module and records its activations and gradients.

    Use as a context manager so the forward hook is removed on exit::

        with Hook(layer) as hook:
            model(x).backward(seed)
            act, grad = hook.activation, hook.gradient
    """

    def __init__(self, module: nn.Module):
        self.data = None
        self.hook = module.register_forward_hook(self.save_grad)

    def save_grad(self, module, input, output):
        """Forward hook: store ``output`` and make autograd keep its ``.grad``."""
        self.data = output
        # If every parameter (and the input) is frozen, the layer output is
        # not part of any autograd graph; forcing requires_grad makes it a
        # point that gradients can flow back to.
        output.requires_grad_(True)
        # Non-leaf tensors do not keep .grad after backward unless retained.
        output.retain_grad()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.hook.remove()

    @property
    def activation(self) -> torch.Tensor:
        """Output of the hooked module from the most recent forward pass."""
        return self.data

    @property
    def gradient(self) -> torch.Tensor:
        """``.grad`` of the recorded activation, populated by a backward pass."""
        return self.data.grad


# Reference: https://arxiv.org/abs/1610.02391
def gradCAM(
    model: nn.Module,
    input: torch.Tensor,
    target: torch.Tensor,
    layer: nn.Module
) -> torch.Tensor:
    """Compute a Grad-CAM saliency map for ``layer`` of ``model``.

    Args:
        model: Network to explain. The ``requires_grad`` flags of its
            parameters are temporarily disabled and are always restored,
            even if the forward/backward pass raises.
        input: Batch fed to ``model``. Assumed (N, C, H, W), since the map
            is resized to ``input.shape[2:]``.
        target: Gradient passed to ``model(input).backward``; must match the
            shape of the model output.
        layer: Submodule whose activations are weighted. Its output is
            assumed to be 4-D (N, K, h, w) because gradients are pooled over
            dims 2 and 3.

    Returns:
        Float tensor of shape (N, 1, H, W).

    Raises:
        TypeError: if ``layer`` is not an ``nn.Module``.
    """
    # Validate before touching any model state so a bad call cannot leave
    # the model with frozen parameters.
    if not isinstance(layer, nn.Module):
        raise TypeError(
            f"layer must be an nn.Module, got {type(layer).__name__}")

    # Zero out any gradients at the input.
    if input.grad is not None:
        input.grad.data.zero_()

    # Freeze parameters: only the gradient at ``layer``'s output is needed.
    # Remember the original flags so they can be restored afterwards.
    requires_grad = {}
    for name, param in model.named_parameters():
        requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    try:
        # Attach a hook to the desired layer, then run forward and backward.
        with Hook(layer) as hook:
            output = model(input)
            output.backward(target)
            grad = hook.gradient.float()
            act = hook.activation.float()
    finally:
        # Restore gradient settings even when the passes above fail.
        for name, param in model.named_parameters():
            param.requires_grad_(requires_grad[name])

    # Global-average-pool the gradient over the spatial dims to obtain one
    # importance weight per channel.
    alpha = grad.mean(dim=(2, 3), keepdim=True)
    # Weighted combination of activation maps over the channel dimension.
    gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
    # Keep only regions with positive influence (the ReLU in Grad-CAM).
    gradcam = torch.clamp(gradcam, min=0)
    # Resize gradcam to input resolution.
    # NOTE: bicubic interpolation can overshoot slightly below zero near
    # sharp edges; clamp again if a strictly non-negative map is required.
    gradcam = F.interpolate(
        gradcam, input.shape[2:], mode='bicubic', align_corners=False)
    return gradcam


# Modified from: https://github.com/salesforce/ALBEF/blob/main/visualization.ipynb
def getAttMap(img, attn_map, size=(256, 256), alpha=0.4):
    """Overlay ``attn_map`` as a jet-coloured heatmap on ``img``.

    Args:
        img: PIL image to draw on. Converted to RGB first so that grayscale,
            palette or RGBA inputs can be blended with the RGB heatmap.
        attn_map: 2-D array of scores in any range; min-max normalised to
            [0, 1] before colouring.
        size: (width, height) of the returned image.
        alpha: Heatmap weight in the blend (0 -> only ``img``, 1 -> only the
            heatmap).

    Returns:
        RGB ``PIL.Image`` of ``size``.
    """
    # Normalize attention map to [0, 1]; a constant map becomes all zeros.
    attn_map = attn_map - attn_map.min()
    if attn_map.max() > 0:
        attn_map = attn_map / attn_map.max()
    # The colormap yields RGBA floats in [0, 1]; drop alpha, go to 8-bit.
    heatmap = matplotlib.cm.jet(attn_map)
    heatmap = (heatmap * 255).astype(np.uint8)[:, :, :3]
    img_heatmap = Image.fromarray(heatmap).resize(size)
    # Image.blend requires both images to have the same mode and size.
    base = img.convert('RGB').resize(size)
    return Image.blend(base, img_heatmap, alpha)