{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "TdEse3Kwq3JD" }, "source": [ "# Import Necessary Libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WRKzuv_5owuz" }, "outputs": [], "source": [ "import numpy as np\n", "import nibabel as nib\n", "import glob\n", "from tensorflow.keras.utils import to_categorical # multiclass semantic segmentation, therefore the volumes to categorical\n", "import matplotlib.pyplot as plt\n", "from tifffile import imsave\n", "from sklearn.preprocessing import MinMaxScaler #scale values\n", "import tensorflow as tf\n", "import random\n", "import os.path\n", "!pip install split-folders\n", "!pip3 install -U segmentation-models-3D\n", "import splitfolders\n", "!pip install -q -U keras-tuner" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vEtRg2vutWru" }, "outputs": [], "source": [ "# To always ensure that the GPU is available\n", "import tensorflow as tf\n", "device_name = tf.test.gpu_device_name()\n", "if device_name != '/device:GPU:0':\n", " raise SystemError('GPU device not found')\n", "print('Found GPU at: {}'.format(device_name))" ] }, { "cell_type": "markdown", "metadata": { "id": "L5yBxROtvDAI" }, "source": [ "# Define the MinMax Scaler + Mount Drive to access Dataset\n", "\n", "* The MinMax scaler is necessary for transforming the scans' features to a range between 0 and 1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sqMRiba8q-30" }, "outputs": [], "source": [ "scaler = MinMaxScaler()\n", "\n", "from google.colab import drive\n", "drive.mount('/content/drive')" ] }, { "cell_type": "markdown", "metadata": { "id": "XH4_Z5f2sfxZ" }, "source": [ "# Load sample images and visualize\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SvfI9iTrrZuN" }, "outputs": [], "source": [ "DATASET_PATH = ''\n", "\n", "test_image_flair = nib.load(DATASET_PATH + 'flair.nii').get_fdata()\n", 
"print(test_image_flair[156][98][78])\n", "test_image_flair = scaler.fit_transform(test_image_flair.reshape(-1, test_image_flair.shape[-1])).reshape(test_image_flair.shape)\n", "print(test_image_flair[156][98][78])\n", "\n", "test_image_t1 = nib.load(DATASET_PATH + 't1.nii').get_fdata()\n", "test_image_t1 = scaler.fit_transform(test_image_t1.reshape(-1, test_image_t1.shape[-1])).reshape(test_image_t1.shape)\n", "\n", "test_image_t1ce = nib.load(DATASET_PATH + 't1ce.nii').get_fdata()\n", "test_image_t1ce = scaler.fit_transform(test_image_t1ce.reshape(-1, test_image_t1ce.shape[-1])).reshape(test_image_t1ce.shape)\n", "\n", "test_image_t2 = nib.load(DATASET_PATH + 't2.nii').get_fdata()\n", "test_image_t2 = scaler.fit_transform(test_image_t2.reshape(-1, test_image_t2.shape[-1])).reshape(test_image_t2.shape)\n", "\n", "test_mask = nib.load(DATASET_PATH + 'seg.nii').get_fdata()\n", "test_mask = test_mask.astype(np.uint8)\n", "\n", "print(np.unique(test_mask))\n", "# Reassign label value 4 to 3\n", "test_mask[test_mask==4] = 3\n", "print(np.unique(test_mask))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aTkjA-mgwecE" }, "outputs": [], "source": [ "n_slice = random.randint(0, test_mask.shape[2])\n", "\n", "plt.figure(figsize=(12,8))\n", "plt.subplot(231)\n", "plt.imshow(test_image_flair[:, :, n_slice], cmap='gray')\n", "plt.title('Flair Scan')\n", "\n", "plt.subplot(232)\n", "plt.imshow(test_image_t1[:, :, n_slice], cmap='gray')\n", "plt.title('T1 Scan')\n", "\n", "plt.subplot(233)\n", "plt.imshow(test_image_t1ce[:, :, n_slice], cmap='gray')\n", "plt.title('T1ce Scan')\n", "\n", "plt.subplot(234)\n", "plt.imshow(test_image_t2[:, :, n_slice], cmap='gray')\n", "plt.title('T2 Scan')\n", "\n", "plt.subplot(235)\n", "plt.imshow(test_mask[:, :, n_slice])\n", "plt.title('Mask')\n", "\n", "plt.show()\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "EORoZoj7yPfW" }, "source": [ "# Data Processing: Combining the volumes of scans to one + 
Cropping the scans and masks\n", "\n", "* The numpy array is reshaped to 2D, the dimensions the scaler can take as input, the array is transformed and then reshaped back to 3D\n", "* Result: the feature at position [156][98][78] of the loaded FLAIR scan numpy array is transformed from 1920.0 to 0.7683...\n", "* The three scans to be used are stacked together to forme a combined scan.\n", "* Result: A FLAIR scan, a T1CE scan and a T2 scan, all of dimensions 255 x 255 x 155 are stacked to form a combined scan of dimensions 255 x 255 x 155 x 3\n", "* The combined scan is cropped to 128 x 128 x 128 x 3\n", "* Label 4 in the dataset is reassigned to label 3 resulting to a continuous list of labels: 0, 1, 2, 3" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-3u91yIqybn-" }, "outputs": [], "source": [ "combined_x = np.stack([test_image_flair, test_image_t1ce, test_image_t2], axis=3)\n", "combined_x = combined_x[56:184, 56:184, 13:141] #crop to 128 x 128 x 128 X 3\n", "\n", "test_mask = test_mask[56:184, 56:184, 13:141]\n", "n_slice = random.randint(0, test_mask.shape[1])\n", "plt.figure(figsize=(12, 8))\n", "\n", "plt.subplot(231)\n", "plt.imshow(combined_x[:, :, n_slice, 0], cmap='gray')\n", "plt.title('Flair Scan')\n", "\n", "plt.subplot(232)\n", "plt.imshow(combined_x[:, :, n_slice, 1], cmap='gray')\n", "plt.title('T1ce Scan')\n", "\n", "plt.subplot(233)\n", "plt.imshow(combined_x[:, :, n_slice, 2], cmap='gray')\n", "plt.title('T2 Scan')\n", "\n", "plt.subplot(234)\n", "plt.imshow(test_mask[:, :, n_slice])\n", "plt.title('Mask')\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T8r7sy4QND41" }, "outputs": [], "source": [ "from tensorflow.keras import backend as K\n", "\n", "print(K.int_shape(test_image_flair))\n", "\n", "print(K.int_shape(combined_x))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WeD_PqCv6Vww" }, "outputs": [], "source": [ "flair_list = 
sorted(glob.glob(DATASET_PATH + '*/flair.nii'))\n", "t1_list = sorted(glob.glob(DATASET_PATH + '*/t1.nii'))\n", "t1ce_list = sorted(glob.glob(DATASET_PATH + '*/t1ce.nii'))\n", "t2_list = sorted(glob.glob(DATASET_PATH + '*/t2.nii'))\n", "mask_list = sorted(glob.glob(DATASET_PATH + '*/seg.nii'))\n", "\n", "\n", "for img in range(len(flair_list)):\n", " print('Now processing image and masks no: ', img)\n", "\n", " temp_image_flair = nib.load(flair_list[img]).get_fdata()\n", " temp_image_flair = scaler.fit_transform(temp_image_flair.reshape(-1, temp_image_flair.shape[-1])).reshape(temp_image_flair.shape)\n", "\n", " temp_image_t1 = nib.load(t1_list[img]).get_fdata()\n", " temp_image_t1 = scaler.fit_transform(temp_image_t1.reshape(-1, temp_image_t1.shape[-1])).reshape(temp_image_t1.shape)\n", "\n", " temp_image_t1ce = nib.load(t1ce_list[img]).get_fdata()\n", " temp_image_t1ce = scaler.fit_transform(temp_image_t1ce.reshape(-1, temp_image_t1ce.shape[-1])).reshape(temp_image_t1ce.shape)\n", "\n", " temp_image_t2 = nib.load(t2_list[img]).get_fdata()\n", " temp_image_t2 = scaler.fit_transform(temp_image_t2.reshape(-1, temp_image_t2.shape[-1])).reshape(temp_image_t2.shape)\n", "\n", " temp_mask = nib.load(mask_list[img]).get_fdata()\n", " temp_mask = temp_mask.astype(np.uint8)\n", " temp_mask[temp_mask == 4] = 3\n", "\n", " temp_combined_images = np.stack([temp_image_flair, temp_image_t1, temp_image_t1ce, temp_image_t2], axis = 3)\n", " temp_combined_images = temp_combined_images[56:184, 56:184, 13:141]\n", " temp_mask = temp_mask[56:184, 56:184, 13:141]\n", "\n", " val, counts = np.unique(temp_mask, return_counts=True)\n", "\n", " if(1 - (counts[0]/counts.sum())) > 0.01:\n", " temp_mask = to_categorical(temp_mask, num_classes=4)\n", " np.save(DATASET_PATH + 'final_dataset/scans/image_' + str(img) + '.npy', temp_combined_images)\n", " np.save(DATASET_PATH + 'final_dataset/masks/image_' + str(img) + '.npy', temp_mask)\n", " print(\"Saved\")\n", " else:\n", " print(\"Not 
def load_img(img_dir, img_list):
    """Load .npy volumes from img_dir and stack them into one numpy array.

    Parameters
    ----------
    img_dir : str
        Directory containing the .npy files (with or without trailing slash).
    img_list : list[str]
        Filenames to load; entries that do not end in '.npy' are skipped.

    Returns
    -------
    np.ndarray with the loaded volumes stacked along axis 0.
    """
    images = []
    for image_name in img_list:
        # endswith is safer than the original split('.')[1] check, which
        # raised IndexError on extensionless names (e.g. 'README') and
        # wrongly accepted names like 'x.npy.bak'.
        if image_name.endswith('.npy'):
            # os.path.join tolerates img_dir both with and without a
            # trailing separator, unlike plain string concatenation.
            images.append(np.load(os.path.join(img_dir, image_name)))
    return np.array(images)


def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size):
    """Infinite batch generator of (X, Y) tuples for Keras model.fit.

    Yields tuples of two numpy arrays holding up to batch_size scan/mask
    volumes; the final batch of each pass may be smaller when
    len(img_list) is not divisible by batch_size. Assumes img_list and
    mask_list are index-aligned (scan i corresponds to mask i).
    """
    L = len(img_list)
    # Keras expects the generator to be infinite, hence the outer loop.
    while True:
        batch_start = 0
        while batch_start < L:
            limit = min(batch_start + batch_size, L)
            X = load_img(img_dir, img_list[batch_start:limit])
            Y = load_img(mask_dir, mask_list[batch_start:limit])
            yield (X, Y)  # one batch of scans and matching masks
            batch_start += batch_size
batch_size)\n", "\n", "# Verify generator - In python 3 next() is renamed as __next__()\n", "img, msk = train_img_datagen.__next__()\n", "\n", "img_num = random.randint(0, img.shape[0]-1)\n", "\n", "test_img = img[img_num]\n", "test_mask = msk[img_num]\n", "test_mask = np.argmax(test_mask, axis=3)\n", "\n", "n_slice = random.randint(0, test_mask.shape[2])\n", "plt.figure(figsize=(12,8))\n", "\n", "plt.subplot(221)\n", "plt.imshow(test_img[:, :, n_slice, 0], cmap='gray')\n", "plt.title('Flair Scan')\n", "\n", "plt.subplot(222)\n", "plt.imshow(test_img[:, :, n_slice, 1], cmap='gray')\n", "plt.title('T1ce Scan')\n", "\n", "plt.subplot(223)\n", "plt.imshow(test_img[:, :, n_slice, 2], cmap='gray')\n", "plt.title('T2 Scan')\n", "\n", "plt.subplot(224)\n", "plt.imshow(test_mask[:, :, n_slice])\n", "plt.title('Mask')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "ReTmFPr0QV17" }, "source": [ "# Define image generators for training, validation and testing" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HS9Dihs_QbqU" }, "outputs": [], "source": [ "DATASET_PATH = ''\n", "train_img_dir = DATASET_PATH + 'train/scans/'\n", "train_mask_dir = DATASET_PATH + 'train/masks/'\n", "\n", "val_img_dir = DATASET_PATH + 'val/scans/'\n", "val_mask_dir = DATASET_PATH + 'val/masks/'\n", "\n", "test_img_dir = DATASET_PATH + 'test/scans/'\n", "test_mask_dir = DATASET_PATH + 'test/masks/'\n", "\n", "train_img_list = os.listdir(train_img_dir)\n", "train_mask_list = os.listdir(train_mask_dir)\n", "\n", "val_img_list = os.listdir(val_img_dir)\n", "val_mask_list = os.listdir(val_mask_dir)\n", "\n", "test_img_list = os.listdir(test_img_dir)\n", "test_mask_list = os.listdir(test_mask_dir)\n", "\n", "batch_size = 2\n", "train_img_datagen = imageLoader(train_img_dir, train_img_list,\n", " train_mask_dir, train_mask_list, batch_size)\n", "\n", "val_img_datagen = imageLoader(val_img_dir, val_img_list,\n", " val_mask_dir, val_mask_list, 
import tensorflow.keras.backend as K


def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient over the flattened volumes.

    smooth avoids division by zero and stabilizes gradients on empty masks.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (2. * overlap + smooth) / (total + smooth)


def dice_coef_loss(y_true, y_pred):
    """Dice loss: 1 minus the soft Dice coefficient."""
    return 1 - dice_coef(y_true, y_pred)


def tversky(y_true, y_pred, smooth=1, alpha=0.7):
    """Tversky index; alpha weights false negatives vs false positives.

    alpha > 0.5 penalizes false negatives more, which suits the class
    imbalance in tumor segmentation.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    tp = K.sum(truth * pred)
    fn = K.sum(truth * (1 - pred))
    fp = K.sum((1 - truth) * pred)
    denominator = tp + alpha * fn + (1 - alpha) * fp + smooth
    return (tp + smooth) / denominator


def tversky_loss(y_true, y_pred):
    """Tversky loss: 1 minus the Tversky index."""
    return 1 - tversky(y_true, y_pred)


def focal_tversky_loss(y_true, y_pred, gamma=0.75):
    """Focal Tversky loss: focuses training on hard examples via gamma."""
    tv = tversky(y_true, y_pred)
    return K.pow((1 - tv), gamma)
def _conv_block(inputs, n_filters, dropout_rate):
    """Two 3x3x3 same-padded ReLU Conv3D layers with Dropout in between.

    Factors out the double-convolution pattern repeated at every U-Net
    level; architecture and layer order are identical to the original
    copy-pasted version.
    """
    c = Conv3D(filters=n_filters, kernel_size=3, strides=1, activation='relu',
               kernel_initializer=kernel_initializer, padding='same')(inputs)
    c = Dropout(dropout_rate)(c)
    c = Conv3D(filters=n_filters, kernel_size=3, strides=1, activation='relu',
               kernel_initializer=kernel_initializer, padding='same')(c)
    return c


def UNet(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes):
    """Build a standard 3D U-Net for multiclass semantic segmentation.

    Parameters are the input volume dimensions/channels and the number of
    output classes. Returns an uncompiled Keras Model whose output is a
    per-voxel softmax over num_classes. Input dims must be divisible by
    16 (four 2x2x2 poolings). Prints the model summary as a side effect.
    """
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))

    # Contracting path: filters double at each level, dropout grows with depth.
    c1 = _conv_block(inputs, 16, 0.1)
    p1 = MaxPooling3D((2, 2, 2))(c1)

    c2 = _conv_block(p1, 32, 0.1)
    p2 = MaxPooling3D((2, 2, 2))(c2)

    c3 = _conv_block(p2, 64, 0.2)
    p3 = MaxPooling3D((2, 2, 2))(c3)

    c4 = _conv_block(p3, 128, 0.2)
    p4 = MaxPooling3D((2, 2, 2))(c4)

    # Bottleneck
    c5 = _conv_block(p4, 256, 0.3)

    # Expanding path: transpose-conv upsampling, skip concatenation,
    # then the same double-conv block.
    u6 = Conv3DTranspose(128, (2, 2, 2), strides=(2, 2, 2), padding='same')(c5)
    c6 = _conv_block(concatenate([u6, c4]), 128, 0.2)

    u7 = Conv3DTranspose(64, (2, 2, 2), strides=(2, 2, 2), padding='same')(c6)
    c7 = _conv_block(concatenate([u7, c3]), 64, 0.2)

    u8 = Conv3DTranspose(32, (2, 2, 2), strides=(2, 2, 2), padding='same')(c7)
    c8 = _conv_block(concatenate([u8, c2]), 32, 0.1)

    u9 = Conv3DTranspose(16, (2, 2, 2), strides=(2, 2, 2), padding='same')(c8)
    c9 = _conv_block(concatenate([u9, c1]), 16, 0.1)

    # Per-voxel class probabilities.
    outputs = Conv3D(num_classes, (1, 1, 1), activation='softmax')(c9)

    model = Model(inputs=[inputs], outputs=[outputs])
    model.summary()

    return model
def repeat_elem(tensor, rep):
    """Repeat a 5D tensor's elements rep times along the channel axis (4).

    E.g. a (None, 128, 128, 128, 1) tensor becomes (None, 128, 128, 128, rep),
    so a single-channel attention map can multiply a multi-channel signal.
    """
    return layers.Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=4),
                         arguments={'repnum': rep})(tensor)


def attention_block(x, gating, inter_shape):
    """Additive attention gate for a 3D U-Net skip connection.

    x is the encoder feature map (5D: batch, H, W, D, C); gating is the
    coarser gating signal from the deeper level; inter_shape is the number
    of intermediate filters used for the additive attention computation.
    Returns x re-weighted by learned per-voxel attention coefficients.
    """
    shape_x = K.int_shape(x)
    shape_g = K.int_shape(gating)

    # Project the gating signal to inter_shape filters.
    phi_g = Conv3D(filters=inter_shape, kernel_size=1, strides=1, padding='same')(gating)

    # Downsample x to the gating signal's spatial size; strides are the
    # per-axis ratio of the two spatial shapes (assumed exact multiples).
    theta_x = Conv3D(filters=inter_shape, kernel_size=3, strides=(
        shape_x[1] // shape_g[1],
        shape_x[2] // shape_g[2],
        shape_x[3] // shape_g[3]
    ), padding='same')(x)

    # Element-wise addition of the gating and x projections.
    xg_sum = add([phi_g, theta_x])
    xg_sum = Activation('relu')(xg_sum)

    # 1x1x1 convolution down to a single-channel attention map.
    psi = Conv3D(filters=1, kernel_size=1, padding='same')(xg_sum)
    sigmoid_psi = Activation('sigmoid')(psi)
    shape_sigmoid = K.int_shape(sigmoid_psi)

    # Upsample psi back to x's spatial dimensions so it can be multiplied
    # element-wise with the original signal.
    upsampled_sigmoid_psi = UpSampling3D(size=(
        shape_x[1] // shape_sigmoid[1],
        shape_x[2] // shape_sigmoid[2],
        shape_x[3] // shape_sigmoid[3]
    ))(sigmoid_psi)

    # Expand the single attention channel to x's channel count.
    upsampled_sigmoid_psi = repeat_elem(upsampled_sigmoid_psi, shape_x[4])

    # Apply the attention coefficients to the original x signal.
    attention_coeffs = multiply([upsampled_sigmoid_psi, x])

    # BUG FIX: restore x's CHANNEL count, which for a 5D tensor is axis 4.
    # The original used shape_x[3] (the depth dimension, e.g. 128), which
    # contradicts its own stated intent and inflates the layer's filters.
    output = Conv3D(filters=shape_x[4], kernel_size=1, strides=1, padding='same')(attention_coeffs)
    output = BatchNormalization()(output)
    return output


# Gating signal
def gating_signal(input, output_size, batch_norm=False):
    """1x1x1 conv resizing a deep feature map into a gating feature map.

    Returns a tensor with output_size channels (optionally batch-normed)
    and a ReLU activation, spatially unchanged from the input.
    """
    x = Conv3D(output_size, (1, 1, 1), padding='same')(input)
    if batch_norm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return x
def _leaky_conv_block(inputs, n_filters, dropout_rate):
    """Two 3x3x3 same-padded Conv3D layers (LeakyReLU 0.1) with Dropout between."""
    c = Conv3D(filters=n_filters, kernel_size=3, strides=1,
               activation=LeakyReLU(alpha=0.1),
               kernel_initializer=kernel_initializer, padding='same')(inputs)
    c = Dropout(dropout_rate)(c)
    c = Conv3D(filters=n_filters, kernel_size=3, strides=1,
               activation=LeakyReLU(alpha=0.1),
               kernel_initializer=kernel_initializer, padding='same')(c)
    return c


# Attention UNet
def attention_unet(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes, batch_norm=True):
    """Build a 3D Attention U-Net with attention-gated skip connections.

    Same encoder as the plain UNet (LeakyReLU activations); each decoder
    level gates its skip connection with an attention block before
    concatenation. Returns an uncompiled Keras Model with a per-voxel
    softmax output; prints the summary as a side effect.
    """
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))

    # Contracting path.
    c1 = _leaky_conv_block(inputs, 16, 0.1)
    p1 = MaxPooling3D((2, 2, 2))(c1)

    c2 = _leaky_conv_block(p1, 32, 0.1)
    p2 = MaxPooling3D((2, 2, 2))(c2)

    c3 = _leaky_conv_block(p2, 64, 0.2)
    p3 = MaxPooling3D((2, 2, 2))(c3)

    c4 = _leaky_conv_block(p3, 128, 0.2)
    p4 = MaxPooling3D((2, 2, 2))(c4)

    # Bottleneck.
    c5 = _leaky_conv_block(p4, 256, 0.3)

    # Decoder. BUG FIX: the original passed gating_6 to every attention
    # block (gating_7/8/9 were computed but never used); each level now
    # uses its own gating signal. The gating/inter filter counts at the
    # two shallowest levels are also matched to their skip-connection
    # channel counts (32 and 16) instead of a copy-pasted 64.
    gating_6 = gating_signal(c5, 128, batch_norm)
    att_6 = attention_block(c4, gating_6, 128)
    u6 = UpSampling3D((2, 2, 2), data_format='channels_last')(c5)
    c6 = _leaky_conv_block(concatenate([u6, att_6]), 128, 0.2)

    gating_7 = gating_signal(c6, 64, batch_norm)
    att_7 = attention_block(c3, gating_7, 64)
    u7 = UpSampling3D((2, 2, 2), data_format='channels_last')(c6)
    c7 = _leaky_conv_block(concatenate([u7, att_7]), 64, 0.2)

    gating_8 = gating_signal(c7, 32, batch_norm)
    att_8 = attention_block(c2, gating_8, 32)
    u8 = UpSampling3D((2, 2, 2), data_format='channels_last')(c7)
    c8 = _leaky_conv_block(concatenate([u8, att_8]), 32, 0.1)

    gating_9 = gating_signal(c8, 16, batch_norm)
    att_9 = attention_block(c1, gating_9, 16)
    u9 = UpSampling3D((2, 2, 2), data_format='channels_last')(c8)
    c9 = _leaky_conv_block(concatenate([u9, att_9]), 16, 0.1)

    # Per-voxel class probabilities (BN before the softmax, as original).
    outputs = Conv3D(num_classes, (1, 1, 1))(c9)
    outputs = BatchNormalization()(outputs)
    outputs = Activation('softmax')(outputs)

    model = Model(inputs=[inputs], outputs=[outputs], name="Attention_UNet")
    model.summary()

    return model
"checkpoint_path = ''\n", "log_path = ''\n", "\n", "callbacks = [\n", " EarlyStopping(monitor='val_loss', patience=4, verbose=1),\n", " ReduceLROnPlateau(factor=0.1,\n", " monitor='val_loss',\n", " patience=4,\n", " min_lr=0.0001,\n", " verbose=1,\n", " mode='min'),\n", " ModelCheckpoint(checkpoint_path,\n", " monitor='val_loss',\n", " mode='min',\n", " verbose=0,\n", " save_best_only=True),\n", " CSVLogger(log_path, separator=',', append=True),\n", " TerminateOnNaN()\n", "]\n", "\n", "history = model.fit(train_img_datagen,\n", " steps_per_epoch=steps_per_epoch,\n", " epochs=100,\n", " verbose=1,\n", " validation_data=val_img_datagen,\n", " validation_steps=val_steps_per_epoch,\n", " callbacks=callbacks\n", " )\n", "\n", "history_callback = np.save('', history.history)" ] }, { "cell_type": "markdown", "metadata": { "id": "pfcKmJv4jP2J" }, "source": [ "# Load Model for more training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7RXukeY_jiad" }, "outputs": [], "source": [ "import tensorflow.keras.models as load\n", "import keras\n", "model = load.load_model('', custom_objects={\n", " 'tversky_loss': tversky_loss,\n", " 'dice_coef': dice_coef\n", "})\n", "\n", "checkpoint_path = ''\n", "log_path = ''\n", "\n", "callbacks = [\n", " EarlyStopping(monitor='val_loss', patience=4, verbose=1),\n", " ReduceLROnPlateau(factor=0.1,\n", " monitor='val_loss',\n", " patience=4,\n", " min_lr=0.0001,\n", " verbose=1,\n", " mode='min'),\n", " ModelCheckpoint(checkpoint_path,\n", " monitor='val_loss',\n", " mode='min',\n", " verbose=0,\n", " save_best_only=True),\n", " CSVLogger(log_path, separator=',', append=True),\n", " TerminateOnNaN()\n", "]\n", "\n", "history = model.fit(train_img_datagen,\n", " steps_per_epoch=steps_per_epoch,\n", " epochs=100,\n", " verbose=1,\n", " validation_data=val_img_datagen,\n", " validation_steps=val_steps_per_epoch,\n", " callbacks=callbacks\n", " )\n", "\n", "history_callback = np.save('', history.history)" ] }, { 
"cell_type": "markdown", "metadata": { "id": "SPBUC1HIfqDt" }, "source": [ "# Plot the training and validation loss (tversky) and dice coefficient (metric) at each epoch" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I7e4YkM5f1Jg" }, "outputs": [], "source": [ "history = np.load('',allow_pickle='TRUE').item()\n", "\n", "print(history)\n", "loss = history['loss']\n", "val_loss = history['val_loss']\n", "epochs = range(1, len(loss) + 1)\n", "plt.plot(epochs, loss, 'y', label='Training loss')\n", "plt.plot(epochs, val_loss, 'r', label='Validation loss')\n", "plt.title('Training and Validation Loss')\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Loss')\n", "plt.legend()\n", "plt.show()\n", "\n", "acc = history['dice_coef']\n", "val_acc = history['val_dice_coef']\n", "\n", "plt.plot(epochs, acc, 'y', label='Training accuracy')\n", "plt.plot(epochs, val_acc, 'r', label='Validation accuracy')\n", "plt.title('Trainign and Validation Accuracy')\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Accuracy')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "XV8kjMkemQ-W" }, "source": [ "# Model Evaluation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ChhYHB8PmTnK" }, "outputs": [], "source": [ "from tensorflow.keras.models import load_model\n", "my_model = load_model('', custom_objects={\n", " 'tversky_loss': tversky_loss,\n", " 'dice_coef': dice_coef},\n", " compile = True)\n", "\n", "# Verify IoU on a batch of images from the test dataset\n", "batch_size = 8\n", "test_img_datagen = imageLoader(val_img_dir, val_img_list,\n", " val_mask_dir, val_mask_list, batch_size)\n", "\n", "test_image_batch, test_mask_batch = test_img_datagen.__next__()\n", "\n", "test_mask_batch_argmax = np.argmax(test_mask_batch, axis=4)\n", "\n", "results = my_model.evaluate(test_image_batch, test_mask_batch, batch_size=batch_size)\n", "print(\"test acc, test loss:\", results)" ] }, { "cell_type": "markdown", 
"metadata": { "id": "xvEqiU6SqY2y" }, "source": [ "# Predict on a test scan" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8-MUQpCiqcxd" }, "outputs": [], "source": [ "from tensorflow.keras.models import load_model\n", "my_model = load_model('', compile=False)\n", "\n", "img_num = 53\n", "test_scan = np.load('' + str(img_num) + '.npy')\n", "\n", "test_mask = np.load('' + str(img_num) + '.npy')\n", "test_mask_argmax = np.argmax(test_mask, axis = 3)\n", "\n", "test_scan_input = np.expand_dims(test_scan, axis = 0)\n", "test_prediction = my_model.predict(test_scan_input)\n", "test_prediction_argmax = np.argmax(test_prediction, axis = 4)[0, :, :, :]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true }, "id": "65FmAMNhmX8E" }, "outputs": [], "source": [ "# n_slice = 55\n", "n_slice = random.randint(0, test_mask_argmax.shape[2])\n", "\n", "plt.figure(figsize=(12,8))\n", "plt.subplot(231)\n", "plt.imshow(test_scan[:, :, n_slice, 1], cmap='gray')\n", "plt.title('Testing Scan')\n", "\n", "plt.subplot(232)\n", "plt.imshow(test_mask_argmax[:, :, n_slice])\n", "plt.title('Testing Label')\n", "\n", "plt.subplot(235)\n", "plt.imshow(test_prediction_argmax[:, :, n_slice])\n", "plt.title('Prediction on test image')\n", "\n", "plt.show()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "TdEse3Kwq3JD", "L5yBxROtvDAI", "EORoZoj7yPfW", "-wICUx56ugDz", "nq3p80zN2ew2", "dBKMHMn96Z3c", "PR2Ugre0YP-v", "-Aw_Peb9iJYb", "e6Cvn6hWvars", "xndmsEwjVhn7", "pfcKmJv4jP2J" ], "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }