"""Single-volume anatomy segmentation inference on a T2 image with a MONAI UNet."""

import monai
import nibabel as nib
import numpy as np
import pandas as pd
import torch
from monai.data import DataLoader
from monai.data import Dataset
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activationsd,
    AsDiscreted,
    Compose,
    ConcatItemsd,
    KeepLargestConnectedComponentd,
    LoadImaged,
    EnsureChannelFirstd,
    EnsureTyped,
    SaveImaged,
    ScaleIntensityd,
    NormalizeIntensityd,
    Spacingd,
    Orientationd,
)
from monai.utils.enums import CommonKeys
from scipy import ndimage

# ---------------------------------------------------------------------------
# Legacy setup that also loaded the reference label ("t2_anatomy_reader1").
# Kept commented out for reference only.
# ---------------------------------------------------------------------------
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print("Using device:", device)
# model = monai.networks.nets.UNet(
#     in_channels=1,
#     out_channels=3,
#     spatial_dims=3,
#     channels=[16, 32, 64, 128, 256, 512],
#     strides=[2, 2, 2, 2, 2],
#     num_res_units=4,
#     act="PRELU",
#     norm="BATCH",
#     dropout=0.15,
# )
# model.load_state_dict(torch.load("anatomy.pt", map_location=device))
# keys = ("t2", "t2_anatomy_reader1")
# transforms = Compose(
#     [
#         LoadImaged(keys=keys, image_only=False),
#         EnsureChannelFirstd(keys=keys),
#         Spacingd(keys=keys, pixdim=[0.5, 0.5, 0.5], mode=("bilinear", "nearest")),
#         Orientationd(keys=keys, axcodes="RAS"),
#         ScaleIntensityd(keys=keys, minv=0, maxv=1),
#         NormalizeIntensityd(keys=keys),
#         EnsureTyped(keys=keys),
#         ConcatItemsd(keys=("t2"), name=CommonKeys.IMAGE, dim=0),
#         ConcatItemsd(keys=("t2_anatomy_reader1"), name=CommonKeys.LABEL, dim=0),
#     ],
# )
# postprocessing = Compose(
#     [
#         EnsureTyped(keys=[CommonKeys.PRED, CommonKeys.LABEL]),
#         KeepLargestConnectedComponentd(
#             keys=CommonKeys.PRED,
#             applied_labels=list(range(1, 3))
#         ),
#     ],
# )

# Only the T2 image is loaded for inference. The value is a plain string
# (the original `("t2")` is not a tuple), which the MONAI dictionary
# transforms below are given as their `keys` argument.
keys = "t2"

# Pre-processing: load, make channel-first, resample to 0.5 isotropic spacing,
# reorient to RAS, rescale to [0, 1], normalize, and expose the result under
# CommonKeys.IMAGE.
transforms = Compose(
    [
        LoadImaged(keys=keys, image_only=False),
        EnsureChannelFirstd(keys=keys),
        Spacingd(keys=keys, pixdim=[0.5, 0.5, 0.5], mode="bilinear"),
        Orientationd(keys=keys, axcodes="RAS"),
        ScaleIntensityd(keys=keys, minv=0, maxv=1),
        NormalizeIntensityd(keys=keys),
        EnsureTyped(keys=keys),
        ConcatItemsd(keys="t2", name=CommonKeys.IMAGE, dim=0),
    ],
)

# Post-processing: keep only the largest connected component of labels 1 and 2.
postprocessing = Compose(
    [
        EnsureTyped(keys=[CommonKeys.PRED]),
        KeepLargestConnectedComponentd(
            keys=CommonKeys.PRED,
            applied_labels=list(range(1, 3)),
        ),
    ],
)

inferer = monai.inferers.SlidingWindowInferer(
    roi_size=(96, 96, 96),
    sw_batch_size=4,
    overlap=0.5,
)


def resize_image(image: np.ndarray, target_shape: tuple, order: int = 1) -> np.ndarray:
    """Resample a 3D array so its shape matches ``target_shape``.

    Args:
        image: 3D array to resample.
        target_shape: Desired size along the three axes; only the first three
            entries are used.
        order: Spline order passed to ``scipy.ndimage.zoom``. The default (1,
            linear) matches the historical behaviour. Use 0 (nearest
            neighbour) for label maps, because linear interpolation can create
            class indices that were not predicted at a boundary.

    Returns:
        The resampled array.
    """
    zoom_factors = tuple(target_shape[axis] / image.shape[axis] for axis in range(3))
    return ndimage.zoom(image, zoom_factors, order=order)


# ---------------------------------------------------------------------------
# Legacy batch loop over a whole test set, kept commented out for reference.
# ---------------------------------------------------------------------------
# model.eval()
# with torch.no_grad():
#     for i in range(len(test_ds)):
#         example = test_ds[i]
#         label = example["t2_anatomy_reader1"]
#         input_tensor = example["t2"].unsqueeze(0)
#         input_tensor = input_tensor.to(device)
#         output_tensor = inferer(input_tensor, model)
#         output_tensor = output_tensor.argmax(dim=1, keepdim=False)
#         output_tensor = output_tensor.squeeze(0).to(torch.device("cpu"))
#         output_tensor = postprocessing({"pred": output_tensor, "label": label})["pred"]
#         output_tensor = output_tensor.numpy().astype(np.uint8)
#         target_shape = example["t2_meta_dict"]["spatial_shape"]
#         output_tensor = resize_image(output_tensor, target_shape)
#         # flip first two dimensions
#         output_tensor = np.flip(output_tensor, axis=0)
#         output_tensor = np.flip(output_tensor, axis=1)
#         new_image = nib.Nifti1Image(output_tensor, affine=example["t2_meta_dict"]["affine"])
#         nib.save(new_image, f"test/{i+1:03}/predicted.nii.gz")
#         print("Saved", i+1)


def make_inference(data_dict: list) -> str:
    """Segment the first item of ``data_dict`` and save the label map as NIfTI.

    Loads the UNet weights from ``anatomy.pt``, runs sliding-window inference
    on the ``"t2"`` image of the first entry, post-processes the prediction,
    resizes it to the shape stored in the image metadata and writes it to
    ``predicted.nii.gz`` in the current working directory.

    Args:
        data_dict: List of MONAI data dictionaries, each with a ``"t2"`` entry
            that ``LoadImaged`` can read. Only the first element is processed.

    Returns:
        The path of the written file, ``"predicted.nii.gz"``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    model = monai.networks.nets.UNet(
        in_channels=1,
        out_channels=3,
        spatial_dims=3,
        channels=[16, 32, 64, 128, 256, 512],
        strides=[2, 2, 2, 2, 2],
        num_res_units=4,
        act="PRELU",
        norm="BATCH",
        dropout=0.15,
    )
    model.load_state_dict(torch.load("anatomy.pt", map_location=device))
    # The input is moved to `device` below, so the network must live there
    # too; otherwise the forward pass fails with a device mismatch on CUDA.
    model = model.to(device)
    model.eval()

    test_ds = Dataset(
        data=data_dict,
        transform=transforms,
    )

    with torch.no_grad():
        example = test_ds[0]
        input_tensor = example["t2"].unsqueeze(0).to(device)
        logits = inferer(input_tensor, model)
        # keepdim=True leaves a singleton channel axis after the batch axis is
        # squeezed away. KeepLargestConnectedComponentd expects channel-first
        # input and treats multi-channel data as one-hot, so a bare (H, W, D)
        # volume would have its first spatial axis misread as channels.
        prediction = logits.argmax(dim=1, keepdim=True)
        prediction = prediction.squeeze(0).to(torch.device("cpu"))
        prediction = postprocessing({"pred": prediction})["pred"]

    # Drop the singleton channel axis before converting to a plain volume.
    label_map = prediction.squeeze(0).numpy().astype(np.uint8)

    target_shape = example["t2_meta_dict"]["spatial_shape"]
    # Nearest-neighbour resampling keeps the output restricted to the class
    # indices the model actually predicted.
    label_map = resize_image(label_map, target_shape, order=0)

    # Flip the first two dimensions.
    # NOTE(review): presumably this undoes the RAS reorientation for inputs
    # stored in LPS order - confirm it holds for every source volume.
    label_map = np.flip(label_map, axis=0)
    label_map = np.flip(label_map, axis=1)

    # NOTE(review): assumes meta["affine"] still describes the original image
    # grid after the spatial transforms - verify against the MONAI version used.
    new_image = nib.Nifti1Image(label_map, affine=example["t2_meta_dict"]["affine"])
    nib.save(new_image, "predicted.nii.gz")
    return "predicted.nii.gz"