# X-ray_Classifier/Utils/Covid19_Utils.py
# (Header residue from the repository web page: author "Emms",
#  last commit 49106b8 "X-RAY BASE".)
import cv2
from PIL import Image
import torch
import matplotlib.pyplot as plt
import torch.functional as F
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
# Run on the GPU when PyTorch can see a CUDA device, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-channel normalization statistics (one value per input channel).
# These are the standard ImageNet values, presumably chosen to match the
# ImageNet-pretrained DenseNet-121 backbone -- TODO confirm against training code.
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

# Inference-time preprocessing: resize to exactly 150x150, convert to a float
# tensor in [0, 1], then normalize with the statistics above.
# NOTE: Resize((150, 150)) already yields a 150x150 image, so CenterCrop(150)
# is a no-op here; it is kept only so the pipeline stays unchanged.
val_transform = transforms.Compose(
    [
        transforms.Resize((150, 150)),
        transforms.CenterCrop(150),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_nums, std=std_nums),
    ]
)
def transform_image(image, transforms):
    """Apply ``transforms`` to ``image`` and prepend a batch dimension of size 1.

    Args:
        image: Whatever ``transforms`` accepts -- presumably a PIL image when
            used with ``val_transform`` (TODO confirm at call sites).
        transforms: Callable returning a tensor. Inside this function the
            parameter shadows the module-level ``transforms`` import.

    Returns:
        The transformed tensor with an extra leading dimension of size 1.
    """
    return transforms(image).unsqueeze(0)
class DenseNet(nn.Module):
    """Single-output classifier on top of torchvision's DenseNet-121 features.

    Pipeline: DenseNet-121 ``features`` -> adaptive average pool to 1x1 ->
    Linear(1024, 1000) -> Linear(1000, 1) -> sigmoid. ``forward`` therefore
    returns a tensor of shape ``(batch, 1)`` with values in (0, 1).

    The backbone is created with ``weights="DEFAULT"``, i.e. torchvision's
    default pretrained weights, which may be downloaded on first use.

    The attribute names ``base_model``, ``pool``, ``fc``, ``classify`` and
    ``classifier`` are read directly by ``ModelGradCam`` and also define the
    ``state_dict`` keys, so they must not be renamed.

    NOTE(review): torchvision's own DenseNet applies a ReLU after ``features``
    before pooling, and there is no nonlinearity between ``fc`` and
    ``classify``. Both are left as-is, presumably because any trained weights
    depend on this exact structure -- confirm before changing.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept so random initialisation of the heads is unchanged.
        self.base_model = torchvision.models.densenet121(weights="DEFAULT").features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024, 1000)
        self.classify = nn.Linear(1000, 1)
        self.classifier = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid scores of shape ``(batch, 1)`` for a batch of 3-channel images."""
        feature_map = self.base_model(x)
        # Pooling yields (batch, 1024, 1, 1); flatten to (batch, 1024) for the linear heads.
        flat = self.pool(feature_map).view(-1, 1024)
        logit = self.classify(self.fc(flat))
        return self.classifier(logit)
class ModelGradCam(nn.Module):
    """Grad-CAM view of an existing ``DenseNet``-style model.

    The submodules of ``base_model`` (``base_model``, ``pool``, ``fc``,
    ``classify``, ``classifier``) are reused by reference, not copied, so this
    wrapper shares parameters with the model it wraps.

    During ``forward`` a hook is attached to the convolutional feature map.
    After ``backward()`` the gradient of the output with respect to that map
    is available from ``get_activations_gradient()``.
    """

    def __init__(self, base_model):
        super().__init__()
        # Shared (not copied) layers of the wrapped model.
        self.features_conv = base_model.base_model
        self.pool = base_model.pool
        self.fc = base_model.fc
        self.classify = base_model.classify
        self.classifier = base_model.classifier
        # Gradient w.r.t. the feature map; None until a backward pass has run.
        self.gradients = None

    def activations_hook(self, grad):
        """Tensor hook: remember the gradient flowing into the feature map."""
        self.gradients = grad

    def forward(self, x):
        """Same computation as the wrapped model, plus a gradient hook on the feature map.

        ``register_hook`` raises if the feature map does not require grad
        (e.g. when called under ``torch.no_grad()``).
        """
        feature_map = self.features_conv(x)
        feature_map.register_hook(self.activations_hook)
        flat = self.pool(feature_map).view(-1, 1024)
        return self.classifier(self.classify(self.fc(flat)))

    def get_activations_gradient(self):
        """Return the feature-map gradient captured by the last backward pass (or None)."""
        return self.gradients

    def get_activations(self, x):
        """Recompute and return the convolutional feature map for ``x``."""
        return self.features_conv(x)
def plot_grad_cam(model, x_ray_image, class_names, threshold: float = 0.5, normalized=True):
    """Run a Grad-CAM pass and return the input image overlaid with the class-activation map.

    Despite the name, nothing is drawn. The overlay is returned for the caller
    to display.

    Args:
        model: A ``ModelGradCam``-style module. Its forward pass must attach a
            gradient hook to the feature map, and it must expose
            ``get_activations_gradient()`` and ``get_activations(x)``. The
            model is switched to eval mode as a side effect.
        x_ray_image: Batched image tensor of shape ``(1, 3, H, W)`` on the
            model's device (e.g. the output of ``transform_image``). Only
            batch size 1 is supported, because the score is read with
            ``.item()``.
        class_names: Labels indexed by prediction. Index 0 is chosen when
            the score is <= ``threshold``; index 1 when it is above.
        threshold: Score above which the class at index 1 is predicted.
        normalized: If False, the image tensor is divided by 255 before the
            overlay is built.

    Returns:
        A tuple ``(superimposed_img, predicted_label, confidences)``:

        - ``superimposed_img``: ``(H, W, 3)`` float array, the RGB heatmap
          added to the image (values are not clipped).
        - ``predicted_label``: ``class_names[0]`` or ``class_names[1]``.
        - ``confidences``: maps ``class_names[0]`` to ``1 - score`` and
          ``class_names[1]`` to ``score``, rounded to 3 decimals.
    """
    model.eval()
    image = x_ray_image

    outputs = model(image).view(-1)
    score = outputs.item()  # raises for batch size > 1
    conf = [1 - score, score]
    # Plain Python comparison: avoids mixing devices, which the previous
    # torch.where(...) against tensors on the module-level `device` could do.
    pred = int(score > threshold)

    # Backprop the score; the model's hook stores d(score)/d(feature map).
    outputs[0].backward()
    gradients = model.get_activations_gradient()
    # One weight per channel: gradients averaged over batch and spatial dims -> (C,).
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

    # Second feature pass only to read activations, so no autograd graph is needed.
    with torch.no_grad():
        activations = model.get_activations(image)
    activations = activations * pooled_gradients[:, None, None]

    # Channel-averaged, ReLU'd map for the single image in the batch.
    heatmap = torch.mean(activations, dim=1)[0].clamp(min=0).cpu()
    max_val = heatmap.max()
    if max_val > 0:  # an all-zero map stays zero instead of becoming NaN
        heatmap = heatmap / max_val

    img = image[0].detach().permute(1, 2, 0).cpu().numpy()
    img = img if normalized else img / 255.0

    heatmap = cv2.resize(heatmap.numpy(), (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap returns BGR, but the image channels are RGB; without this
    # conversion hot regions would appear blue instead of red.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # NOTE(review): 0.0045 is an unexplained blending weight from the original
    # code (255 * 0.0045 ~= 1.15 added at the hottest pixels) -- TODO document.
    superimposed_img = heatmap * 0.0045 + img

    output_dict = dict(zip(class_names, np.round(conf, 3)))
    return superimposed_img, class_names[pred], output_dict