Anthony-Ml committed on
Commit
1c72a49
1 Parent(s): c0d91ef

Added init file

Files changed (1)
  1. utils.py +100 -0
utils.py ADDED
@@ -0,0 +1,100 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import matplotlib.cm
+ from PIL import Image
+
+ # Adapted from: https://colab.research.google.com/github/kevinzakka/clip_playground/blob/main/CLIP_GradCAM_Visualization.ipynb
+ class Hook:
+     """Attaches to a module and records its activations and gradients."""
+
+     def __init__(self, module: nn.Module):
+         self.data = None
+         self.hook = module.register_forward_hook(self.save_grad)
+
+     def save_grad(self, module, input, output):
+         self.data = output
+         output.requires_grad_(True)
+         output.retain_grad()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, exc_traceback):
+         self.hook.remove()
+
+     @property
+     def activation(self) -> torch.Tensor:
+         return self.data
+
+     @property
+     def gradient(self) -> torch.Tensor:
+         return self.data.grad
+
+
+ # Reference: https://arxiv.org/abs/1610.02391
+ def gradCAM(
+     model: nn.Module,
+     input: torch.Tensor,
+     target: torch.Tensor,
+     layer: nn.Module
+ ) -> torch.Tensor:
+     # Zero out any gradients at the input.
+     if input.grad is not None:
+         input.grad.data.zero_()
+
+     # Disable gradient settings.
+     requires_grad = {}
+     for name, param in model.named_parameters():
+         requires_grad[name] = param.requires_grad
+         param.requires_grad_(False)
+
+     # Attach a hook to the model at the desired layer.
+     assert isinstance(layer, nn.Module)
+     with Hook(layer) as hook:
+         # Do a forward and backward pass.
+         output = model(input)
+         output.backward(target)
+
+         grad = hook.gradient.float()
+         act = hook.activation.float()
+
+         # Global average pool gradient across spatial dimensions
+         # to obtain importance weights.
+         alpha = grad.mean(dim=(2, 3), keepdim=True)
+         # Weighted combination of activation maps over the channel
+         # dimension.
+         gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
+         # We only want neurons with positive influence, so we
+         # clamp any negative ones.
+         gradcam = torch.clamp(gradcam, min=0)
+
+     # Resize gradcam to input resolution.
+     gradcam = F.interpolate(
+         gradcam,
+         input.shape[2:],
+         mode='bicubic',
+         align_corners=False)
+
+     # Restore gradient settings.
+     for name, param in model.named_parameters():
+         param.requires_grad_(requires_grad[name])
+
+     return gradcam
+
+
+ # Modified from: https://github.com/salesforce/ALBEF/blob/main/visualization.ipynb
+ def getAttMap(img, attn_map):
+     # Normalize attention map to [0, 1].
+     attn_map = attn_map - attn_map.min()
+     if attn_map.max() > 0:
+         attn_map = attn_map / attn_map.max()
+
+     H = matplotlib.cm.jet(attn_map)
+     H = (H * 255).astype(np.uint8)[:, :, :3]
+     img_heatmap = Image.fromarray(H)
+     img_heatmap = img_heatmap.resize((256, 256))
+
+     return Image.blend(
+         img.resize((256, 256)), img_heatmap, 0.4)
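
For context, a minimal usage sketch (not part of this commit): gradCAM expects a model whose hooked layer produces spatial activations of shape (N, C, H, W), and getAttMap expects a PIL image plus a 2-D NumPy attention map. The torchvision ResNet-50 backbone, preprocessing, image path, and one-hot target below are illustrative assumptions, not code from utils.py.

import torch
from PIL import Image
from torchvision import models, transforms

from utils import gradCAM, getAttMap

# Hypothetical backbone; any CNN with a spatial feature layer would do.
model = models.resnet50(weights="IMAGENET1K_V1").eval()
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = Image.open("example.jpg").convert("RGB")  # hypothetical image path
input_tensor = preprocess(img).unsqueeze(0)

# Backpropagate from the top-scoring class by passing a one-hot gradient as `target`.
with torch.no_grad():
    logits = model(input_tensor)
target = torch.zeros_like(logits)
target[0, logits.argmax(dim=1)] = 1.0

# Grad-CAM at the last convolutional stage, then overlay the heatmap on the image.
saliency = gradCAM(model, input_tensor, target, layer=model.layer4)
saliency = saliency.squeeze().detach().cpu().numpy()
getAttMap(img, saliency).save("gradcam_overlay.png")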