Spaces:
Build error
Build error
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import matplotlib.cm | |
from PIL import Image | |
# Adapted from: https://colab.research.google.com/github/kevinzakka/clip_playground/blob/main/CLIP_GradCAM_Visualization.ipynb
class Hook:
    """Attaches to a module and records its activations and gradients.

    A forward hook is registered on ``module`` at construction time. Every
    time the module runs, the hook stores the module's output and asks
    autograd to keep that tensor's gradient, so that after a backward pass
    both the activation and its gradient can be read back.

    Use it as a context manager so the hook is removed on exit::

        with Hook(layer) as hook:
            model(x).backward(grad)
            grad, act = hook.gradient, hook.activation
    """

    def __init__(self, module: nn.Module):
        # Output of the most recent forward pass through ``module``;
        # stays None until the module has been called at least once.
        self.data = None
        self.hook = module.register_forward_hook(self.save_grad)

    def save_grad(self, module, input, output):
        """Forward-hook callback: remember ``output`` and retain its grad."""
        self.data = output
        # Make sure autograd tracks the output even when nothing upstream
        # requires grad, and keep ``.grad`` populated on non-leaf tensors.
        output.requires_grad_(True)
        output.retain_grad()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always detach from the module, including when the body raised.
        self.hook.remove()

    # Exposed as properties because callers read them as plain attributes
    # (``hook.activation`` / ``hook.gradient``), not as method calls.
    @property
    def activation(self) -> torch.Tensor:
        """Output captured by the last forward pass (None if none ran)."""
        return self.data

    @property
    def gradient(self) -> torch.Tensor:
        """Gradient of the captured output.

        None until a backward pass has reached the captured tensor. Raises
        AttributeError if no forward pass has been captured yet.
        """
        return self.data.grad
# Reference: https://arxiv.org/abs/1610.02391
def gradCAM(
    model: nn.Module,
    input: torch.Tensor,
    target: torch.Tensor,
    layer: nn.Module
) -> torch.Tensor:
    """Compute a Grad-CAM map for ``input`` with respect to ``layer``.

    The model is run forward on ``input`` and ``target`` is back-propagated
    from its output. The gradient captured at ``layer`` is averaged over the
    spatial dimensions to weight the layer's activation channels; the
    weighted sum is clamped at zero and resized to the input resolution.

    Args:
        model: Network to explain. Its parameters are temporarily frozen
            and their ``requires_grad`` flags are restored before
            returning, including when the forward/backward pass raises.
        input: Batch fed to ``model``. Its trailing dimensions
            (``input.shape[2:]``) give the output resolution, so it is
            expected to be 4-D with two spatial dimensions.
        target: Passed as the ``gradient`` argument of
            ``output.backward``, so it must be compatible with the
            model's output.
        layer: Sub-module of ``model`` whose output is inspected. Its
            output must be 4-D, since dimensions 2 and 3 are pooled as the
            spatial dimensions and dimension 1 as channels.

    Returns:
        Tensor with a single channel (dimension 1) and the spatial size of
        ``input``, holding non-negative Grad-CAM values. It is not detached
        from the autograd graph.
    """
    # Validate before touching any state so a bad argument cannot leave
    # the model with frozen parameters.
    assert isinstance(layer, nn.Module)

    # Zero out any gradients at the input.
    if input.grad is not None:
        input.grad.data.zero_()

    # Disable gradient settings, remembering each parameter's flag.
    saved_flags = [
        (param, param.requires_grad) for param in model.parameters()
    ]
    try:
        for param, _ in saved_flags:
            param.requires_grad_(False)

        # Attach a hook to the model at the desired layer.
        with Hook(layer) as hook:
            # Do a forward and backward pass.
            output = model(input)
            output.backward(target)

            grad = hook.gradient.float()
            act = hook.activation.float()
    finally:
        # Restore gradient settings even if the passes above failed.
        for param, flag in saved_flags:
            param.requires_grad_(flag)

    # Global average pool gradient across spatial dimension
    # to obtain importance weights.
    alpha = grad.mean(dim=(2, 3), keepdim=True)
    # Weighted combination of activation maps over channel
    # dimension.
    gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
    # We only want neurons with positive influence so we
    # clamp any negative ones.
    gradcam = torch.clamp(gradcam, min=0)

    # Resize gradcam to input resolution.
    gradcam = F.interpolate(
        gradcam,
        input.shape[2:],
        mode='bicubic',
        align_corners=False)

    return gradcam
# Modified from: https://github.com/salesforce/ALBEF/blob/main/visualization.ipynb
def getAttMap(img, attn_map):
    """Blend a jet-colored heatmap of ``attn_map`` over ``img``.

    Args:
        img: PIL image to use as the background. It is converted to RGB
            so that images in other modes (RGBA, grayscale, palette) can
            be blended with the RGB heatmap.
        attn_map: 2-D array of attention values (indexed as
            ``[row, col]``); presumably a NumPy array -- confirm against
            callers. It is min-max normalized before coloring.

    Returns:
        A 256x256 RGB PIL image: the resized input blended with the
        heatmap at an alpha of 0.4.
    """
    # Normalize attention map
    attn_map = attn_map - attn_map.min()
    peak = attn_map.max()
    # A constant map has a peak of 0 after shifting; skip the division.
    if peak > 0:
        attn_map = attn_map / peak

    # The colormap yields RGBA floats in [0, 1]; scale to bytes and drop
    # the alpha channel.
    heatmap = matplotlib.cm.jet(attn_map)
    heatmap = (heatmap * 255).astype(np.uint8)[:, :, :3]
    img_heatmap = Image.fromarray(heatmap).resize((256, 256))

    # Image.blend requires both images to share size and mode.
    base = img.convert("RGB").resize((256, 256))
    return Image.blend(base, img_heatmap, 0.4)