Spaces:
Sleeping
Sleeping
import numpy as np | |
import cv2 | |
import torchvision.transforms as transforms | |
class GradCAM: | |
def __init__(self, model, target_layer): | |
self.model = model | |
self.target_layer = target_layer | |
self.gradients = None | |
self.activations = None | |
self._register_hooks() | |
def _register_hooks(self): | |
def forward_hook(module, input, output): | |
self.activations = output | |
def backward_hook(module, grad_in, grad_out): | |
self.gradients = grad_out[0] | |
self.target_layer.register_forward_hook(forward_hook) | |
self.target_layer.register_backward_hook(backward_hook) | |
def generate_cam(self, input_image, target_class): | |
self.model.zero_grad() | |
output = self.model(input_image) | |
loss = output[:, target_class].sum() | |
loss.backward() | |
weights = self.gradients.mean(dim=(2, 3), keepdim=True) | |
cam = (weights * self.activations).sum(dim=1, keepdim=True) | |
cam = cam.detach().cpu().numpy() | |
cam = np.maximum(cam, 0) | |
cam = cam / cam.max() | |
cam = cam.squeeze() | |
return cam | |
def visualize_cam(self, cam, input_image): | |
cam = cv2.resize(cam, (input_image.shape[2], input_image.shape[3])) | |
heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET) | |
heatmap = np.float32(heatmap) / 255 | |
input_image = np.moveaxis(input_image.cpu().numpy()[0], 0, -1) | |
input_image = np.float32(input_image) | |
cam_image = heatmap + input_image | |
cam_image = cam_image / cam_image.max() | |
return cam_image | |
def preprocess_image(image): | |
preprocess = transforms.Compose([ | |
transforms.Grayscale(num_output_channels=1), | |
transforms.Resize((512, 512)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.5], std=[0.5]) | |
]) | |
image = preprocess(image) | |
image = image.unsqueeze(0) | |
return image | |