import cv2
from PIL import Image
import torch
import matplotlib.pyplot as plt
import torch.functional as F
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transform

# !pip install efficientnet_pytorch -q
from efficientnet_pytorch import EfficientNet

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Inference-time preprocessing: resize the smaller edge to 255, center-crop to
# 224x224 and convert to a float tensor (no mean/std normalization is applied).
val_transform = transform.Compose([
    transform.Resize(255),
    transform.CenterCrop(224),
    transform.ToTensor(),
])


def transform_image(image, transforms):
    """Apply ``transforms`` to ``image`` and add a leading batch dimension.

    Args:
        image: Input accepted by ``transforms`` (e.g. a PIL image for
            ``val_transform``).
        transforms: Callable preprocessing pipeline that returns a tensor.

    Returns:
        The transformed tensor with a new leading dimension of size 1.
    """
    return transforms(image).unsqueeze(0)


# DenseNet-161 backbone (weights="DEFAULT" selects torchvision's default
# pretrained weights, downloaded on first use) with a new 2-way linear head.
DenseNet = torchvision.models.densenet161(weights="DEFAULT")
for param in DenseNet.parameters():
    param.requires_grad = True  # keep the whole backbone trainable
in_features = DenseNet.classifier.in_features
DenseNet.classifier = nn.Linear(in_features, 2)


class ModelGradCam(nn.Module):
    """Wrap a torchvision DenseNet so Grad-CAM can read feature-map gradients.

    The forward pass mirrors ``torchvision.models.DenseNet.forward``
    (features -> ReLU -> global average pool -> flatten -> classifier). It
    also registers a hook on the output of ``base_model.features`` that
    stores that tensor's gradient during ``backward()``.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.features_conv = self.base_model.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = self.base_model.classifier
        self.gradients = None  # set by activations_hook during backward()

    def activations_hook(self, grad):
        """Store the gradient flowing into the feature maps."""
        self.gradients = grad

    def forward(self, x):
        x = self.features_conv(x)
        # A hook can only be registered on a tensor that requires grad; skip it
        # (e.g. under torch.no_grad()) instead of raising.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        # DenseNet.features ends with a BatchNorm layer and torchvision's own
        # forward applies ReLU before pooling; do the same so the logits match
        # the base model. Not in-place because x is the hooked tensor.
        x = torch.nn.functional.relu(x)
        x = self.pool(x)
        # flatten rather than view(-1, 2208) so other DenseNet widths work too
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def get_activations_gradient(self):
        """Return the gradient captured by the most recent backward pass."""
        return self.gradients

    def get_activations(self, x):
        """Return the output of ``base_model.features`` (before the ReLU)."""
        return self.features_conv(x)


def plot_grad_cam(model, x_ray_image, class_names, normalized=True):
    """Classify one image and overlay its Grad-CAM heatmap on it.

    Args:
        model: A ``ModelGradCam``-style module exposing ``get_activations``
            and ``get_activations_gradient``.
        x_ray_image: Batch holding a single image, shape (1, C, H, W). Assumed
            to be 3-channel RGB (e.g. from ``transform_image``).
        class_names: Labels indexed by the model's output units.
        normalized: True if pixel values are already in [0, 1]; False if they
            are in [0, 255] and must be divided by 255.

    Returns:
        ``(superimposed_img, predicted_label, probabilities)``:
        ``superimposed_img`` is an (H, W, 3) float RGB array clipped to
        [0, 1]; ``predicted_label`` is ``class_names[argmax]``; and
        ``probabilities`` maps each class name to its softmax probability
        rounded to 3 decimals.
    """
    model.eval()  # inference behaviour for BatchNorm etc.; grads still flow
    image = x_ray_image

    # Grad-CAM differentiates the raw class score (logit), not the softmax
    # probability, which saturates for confident predictions.
    logits = model(image)
    probs = torch.nn.functional.softmax(logits.detach(), dim=1)
    pred = int(torch.argmax(probs, dim=1)[0])

    model.zero_grad()  # don't accumulate parameter grads across calls
    logits[0, pred].backward()

    gradients = model.get_activations_gradient()
    # Per-channel weights: gradient averaged over the batch and spatial dims.
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
    with torch.no_grad():  # activations are only read, no graph needed
        activations = model.get_activations(image)
    activations *= pooled_gradients.unsqueeze(-1).unsqueeze(-1)

    heatmap = torch.mean(activations, dim=1).squeeze().cpu()
    heatmap = torch.clamp(heatmap, min=0)  # keep only positive evidence
    max_val = torch.max(heatmap)
    if max_val > 0:  # avoid 0/0 -> NaN when nothing contributes positively
        heatmap /= max_val

    img = image.squeeze().permute(1, 2, 0).detach().cpu().numpy()
    img = img if normalized else img / 255.0

    heatmap = cv2.resize(heatmap.numpy(), (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; the image being overlaid is RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    superimposed_img = np.clip(heatmap * 0.0025 + img, 0.0, 1.0)

    probabilities = probs[0].cpu().tolist()
    output_dict = dict(zip(class_names, np.round(probabilities, 3)))
    return superimposed_img, class_names[pred], output_dict