Spaces:
Runtime error
Runtime error
File size: 6,049 Bytes
ebb41db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import matplotlib.pyplot as plt
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
def convert_back_image(image):
    """Undo CIFAR-10 normalisation and return a channel-last image for plotting.

    Args:
        image: Channel-first image ``(channels, height, width)`` that was
            normalised with the CIFAR-10 statistics hard-coded below.
            Usually a torch tensor; any object exposing ``.numpy()`` or
            convertible with ``np.asarray`` is accepted as well.

    Returns:
        A new ``float32`` NumPy array of shape ``(height, width, channels)``
        with values clipped to ``[0, 1]``. The input is never modified.

    Raises:
        IndexError: If the image has more channels than there are
            per-channel statistics.
    """
    cifar10_mean = (0.4914, 0.4822, 0.4471)
    cifar10_std = (0.2469, 0.2433, 0.2615)

    if hasattr(image, "detach"):
        # Tensors that track gradients or live on an accelerator cannot be
        # converted with .numpy() directly; detach and move to host first.
        image = image.detach().cpu().numpy()
    elif hasattr(image, "numpy"):
        image = image.numpy()
    # astype() copies, so the caller's data is left untouched.
    pixels = np.asarray(image).astype(np.float32)

    num_channels = pixels.shape[0]
    if num_channels > len(cifar10_mean):
        raise IndexError(
            f"image has {num_channels} channels but only "
            f"{len(cifar10_mean)} channel statistics are defined"
        )

    # Shape the statistics as (C, 1, 1) so they broadcast over height/width.
    mean = np.asarray(cifar10_mean[:num_channels], dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(cifar10_std[:num_channels], dtype=np.float32).reshape(-1, 1, 1)
    pixels = pixels * std + mean

    # Clip so imshow does not warn about pixel values outside the valid range.
    pixels = pixels.clip(0, 1)
    return np.transpose(pixels, (1, 2, 0))
def plot_sample_training_images(batch_data, batch_label, class_label, num_images=30):
    """Plot a grid of sample training images titled with their class names.

    Args:
        batch_data: Indexable batch of normalised channel-first images; each
            item is handed to ``convert_back_image`` before display.
        batch_label: Indexable batch of labels; each item must expose
            ``.item()`` returning a key/index into ``class_label``.
        class_label: Mapping or sequence from label value to display name.
        num_images: Upper bound on the number of images drawn.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``.
    """
    # Never try to draw more images than the batch contains.
    num_images = min(num_images, len(batch_data))

    num_cols = 5
    # plt.subplots rejects a zero-row grid, so keep at least one row even
    # when there is nothing to draw (e.g. an empty batch).
    num_rows = max(1, int(np.ceil(num_images / num_cols)))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10))
    cells = np.ravel(axs)

    # Switch every cell off up front so unused trailing cells do not show
    # an empty frame with default ticks.
    for ax in cells:
        ax.axis("off")

    for idx in range(num_images):
        ax = cells[idx]
        ax.imshow(convert_back_image(batch_data[idx]))
        ax.set_title(class_label[batch_label[idx].item()])
        ax.set_xticks([])
        ax.set_yticks([])

    # Lay out once, after every title exists, instead of once per image.
    fig.tight_layout()
    return fig, axs
def plot_train_test_metrics(results):
    """Draw train/test loss and accuracy curves side by side.

    Args:
        results: Mapping with the keys ``"train_loss"``, ``"train_acc"``,
            ``"test_loss"`` and ``"test_acc"``; each value is a sequence
            that ``Axes.plot`` can draw.

    Returns:
        The ``(fig, axs)`` pair from ``plt.subplots``; ``axs[0]`` holds the
        loss panel and ``axs[1]`` the accuracy panel.
    """
    # Pull all four series first so a missing key fails before a figure exists.
    loss_train = results["train_loss"]
    acc_train = results["train_acc"]
    loss_test = results["test_loss"]
    acc_test = results["test_acc"]

    fig, axs = plt.subplots(1, 2, figsize=(16, 8))

    # One entry per panel: (title, train series, test series).
    panels = (
        ("Loss", loss_train, loss_test),
        ("Accuracy", acc_train, acc_test),
    )
    for panel_ax, (heading, train_series, test_series) in zip(axs, panels):
        panel_ax.plot(train_series, label="Train")
        panel_ax.plot(test_series, label="Test")
        panel_ax.set_title(heading)
        panel_ax.legend(loc="upper right")

    return fig, axs
def plot_misclassified_images(data, class_label, num_images=10):
    """Plot misclassified test images with their actual and predicted classes.

    Args:
        data: Mapping with the keys ``"images"``, ``"ground_truths"`` and
            ``"predicted_vals"``. Each value is indexable and its items
            expose ``.cpu()``; label items additionally expose ``.item()``.
        class_label: Mapping or sequence from label value to display name.
        num_images: Upper bound on the number of images drawn.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``.
    """
    # Never try to draw more images than were collected.
    num_images = min(num_images, len(data["ground_truths"]))

    num_cols = 5
    # plt.subplots rejects a zero-row grid, so keep at least one row even
    # when there are no misclassified images at all.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))
    cells = np.ravel(axs)

    # Switch every cell off up front so unused trailing cells do not show
    # an empty frame with default ticks.
    for ax in cells:
        ax.axis("off")

    for idx in range(num_images):
        label = data["ground_truths"][idx].cpu().item()
        pred = data["predicted_vals"][idx].cpu().item()
        image = data["images"][idx].cpu()

        ax = cells[idx]
        ax.imshow(convert_back_image(image))
        ax.set_title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")
        ax.set_xticks([])
        ax.set_yticks([])

    # Lay out once, after every title exists, instead of once per image.
    fig.tight_layout()
    return fig, axs
# Function to plot gradcam for misclassified images using pytorch_grad_cam
def plot_gradcam_images(
    model,
    data,
    class_label,
    target_layers,
    targets=None,
    num_images=10,
    image_weight=0.25,
):
    """Show Grad-CAM heatmaps overlaid on misclassified images.

    Args:
        model: Model handed to ``GradCAM``.
        data: Mapping with the keys ``"images"``, ``"ground_truths"`` and
            ``"predicted_vals"``. Each value is indexable and its items
            expose ``.cpu()``; label items additionally expose ``.item()``.
        class_label: Mapping or sequence from label value to display name.
        target_layers: Layers handed to ``GradCAM`` as ``target_layers``.
        targets: Passed unchanged to every CAM call; the same value is used
            for all images. ``None`` defers to the library's default.
        num_images: Upper bound on the number of images drawn.
        image_weight: Blend weight forwarded to ``show_cam_on_image``.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``.
    """
    # Never try to draw more images than were collected.
    num_images = min(num_images, len(data["ground_truths"]))

    num_cols = 5
    # plt.subplots rejects a zero-row grid, so keep at least one row even
    # when there are no misclassified images at all.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))
    cells = np.ravel(axs)

    # Switch every cell off up front so unused trailing cells do not show
    # an empty frame with default ticks.
    for ax in cells:
        ax.axis("off")

    # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/grad_cam.py
    # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/base_cam.py
    # NOTE: ``use_cuda`` is intentionally not passed. Newer pytorch-grad-cam
    # releases dropped that keyword (passing it raises TypeError), while older
    # releases default it to False, so omitting it keeps the CPU behaviour
    # there. The context manager releases the hooks GradCAM registers.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        for idx in range(num_images):
            label = data["ground_truths"][idx].cpu().item()
            pred = data["predicted_vals"][idx].cpu().item()
            image = data["images"][idx].cpu()

            # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py
            # unsqueeze(0) adds the batch dimension for the single image.
            grad_cam_output = cam(
                input_tensor=image.unsqueeze(0),
                targets=targets,
                aug_smooth=True,
                eigen_smooth=True,
            )
            # Only one image was submitted, so keep the first result.
            grad_cam_output = grad_cam_output[0, :]

            # Overlay the heatmap on the de-normalised image.
            overlayed_image = show_cam_on_image(
                convert_back_image(image),
                grad_cam_output,
                use_rgb=True,
                image_weight=image_weight,
            )

            ax = cells[idx]
            ax.imshow(overlayed_image)
            ax.set_title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")
            ax.set_xticks([])
            ax.set_yticks([])

    # Lay out once, after every title exists, instead of once per image.
    fig.tight_layout()
    return fig, axs
|