import cv2
import glob
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from timm import create_model

from .configuration import CTCropConfig

_PYDICOM_AVAILABLE = False
try:
    from pydicom import dcmread

    _PYDICOM_AVAILABLE = True
except ModuleNotFoundError:
    pass


class CTCropModel(PreTrainedModel):
    config_class = CTCropConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = create_model(
            model_name=config.backbone,
            pretrained=False,
            num_classes=0,
            global_pool="",
            features_only=False,
            in_chans=config.in_chans,
        )
        self.dropout = nn.Dropout(p=config.dropout)
        self.linear = nn.Linear(config.feature_dim, config.num_classes)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # [0, 255] -> [-1, 1]
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        x = (x - 0.5) * 2.0
        return x

    @staticmethod
    def window(x: np.ndarray, WL: int, WW: int) -> np.ndarray:
        # apply CT windowing: clip to [WL - WW/2, WL + WW/2], then scale to [0, 255]
        lower, upper = WL - WW // 2, WL + WW // 2
        x = np.clip(x, lower, upper)
        x = (x - lower) / (upper - lower)
        return (x * 255.0).astype("uint8")

    @staticmethod
    def validate_windows_type(windows):
        assert isinstance(windows, (tuple, list))
        if isinstance(windows, tuple):
            assert len(windows) == 2
            assert all(isinstance(_, int) for _ in windows)
        elif isinstance(windows, list):
            assert all(isinstance(_, tuple) for _ in windows)
            assert all(len(_) == 2 for _ in windows)
            assert all(isinstance(__, int) for _ in windows for __ in _)

    @staticmethod
    def determine_dicom_orientation(ds) -> int:
        iop = ds.ImageOrientationPatient
        # calculate the direction cosine for the normal vector of the plane
        normal_vector = np.cross(iop[:3], iop[3:])
        # determine the plane based on the largest component of the normal vector
        abs_normal = np.abs(normal_vector)
        if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
            return 0  # sagittal
        elif abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
            return 1  # coronal
        else:
            return 2  # axial

    def load_image_from_dicom(
        self, path: str, windows: tuple[int, int] | list[tuple[int, int]] | None = None
    ) -> np.ndarray:
        # windows can be a tuple of (WINDOW_LEVEL, WINDOW_WIDTH),
        # or a list of such tuples to generate a multi-channel image
        # using more than one window
        if not _PYDICOM_AVAILABLE:
            raise ImportError("`pydicom` is not installed")
        dicom = dcmread(path)
        array = dicom.pixel_array.astype("float32")
        m, b = float(dicom.RescaleSlope), float(dicom.RescaleIntercept)
        array = array * m + b
        if windows is None:
            return array
        self.validate_windows_type(windows)
        if isinstance(windows, tuple):
            windows = [windows]
        arr_list = []
        for WL, WW in windows:
            arr_list.append(self.window(array.copy(), WL, WW))
        array = np.stack(arr_list, axis=-1)
        if array.shape[-1] == 1:
            array = np.squeeze(array, axis=-1)
        return array
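    # Worked example of the windowing above (illustrative values): the soft
    # tissue window (WL=50, WW=400) used by `crop` below clips HU values to
    # [50 - 400 // 2, 50 + 400 // 2] = [-150, 250] before scaling to [0, 255].
    # A hypothetical multi-window call might look like:
    #   model.load_image_from_dicom("slice.dcm", windows=[(50, 400), (-600, 1500)])
    # which would return an (H, W, 2) uint8 array, one channel per window.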
    @staticmethod
    def is_valid_dicom(
        ds,
        fname: str = "",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
    ):
        attributes = [
            "pixel_array",
            "RescaleSlope",
            "RescaleIntercept",
        ]
        if sort_by_instance_number:
            attributes.append("InstanceNumber")
        else:
            attributes.append("ImagePositionPatient")
            attributes.append("ImageOrientationPatient")
        attributes_present = [hasattr(ds, attr) for attr in attributes]
        valid = all(attributes_present)
        if not valid and not exclude_invalid_dicoms:
            missing = [
                attr for attr, present in zip(attributes, attributes_present) if not present
            ]
            raise Exception(f"invalid DICOM file [{fname}]: missing attributes: {missing}")
        return valid

    @staticmethod
    def most_common_element(lst):
        return max(set(lst), key=lst.count)

    @staticmethod
    def center_crop_or_pad_borders(image, size):
        height, width = image.shape[:2]
        new_height, new_width = size
        if new_height < height:
            # crop top and bottom
            crop_top = (height - new_height) // 2
            crop_bottom = height - new_height - crop_top
            image = image[crop_top:-crop_bottom]
        elif new_height > height:
            # pad top and bottom
            pad_top = (new_height - height) // 2
            pad_bottom = new_height - height - pad_top
            image = np.pad(
                image,
                ((pad_top, pad_bottom), (0, 0)),
                mode="constant",
                constant_values=0,
            )
        if new_width < width:
            # crop left and right
            crop_left = (width - new_width) // 2
            crop_right = width - new_width - crop_left
            image = image[:, crop_left:-crop_right]
        elif new_width > width:
            # pad left and right
            pad_left = (new_width - width) // 2
            pad_right = new_width - width - pad_left
            image = np.pad(
                image,
                ((0, 0), (pad_left, pad_right)),
                mode="constant",
                constant_values=0,
            )
        return image
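    # Worked example of `center_crop_or_pad_borders` (illustrative): resizing a
    # (512, 512) slice to size (480, 520) crops 16 rows from the top and 16
    # from the bottom, then pads 4 columns on the left and 4 on the right.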
    def load_stack_from_dicom_folder(
        self,
        path: str,
        windows: tuple[int, int] | list[tuple[int, int]] | None = None,
        dicom_extension: str = ".dcm",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
        fix_unequal_shapes: str = "crop_pad",
        return_sorted_dicom_files: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, list[str]]:
        if not _PYDICOM_AVAILABLE:
            raise ImportError("`pydicom` is not installed")
        dicom_files = glob.glob(os.path.join(path, f"*{dicom_extension}"))
        if len(dicom_files) == 0:
            raise Exception(
                f"No DICOM files found in `{path}` using `dicom_extension={dicom_extension}`"
            )
        dicoms = [dcmread(f) for f in dicom_files]
        dicoms = [
            (d, dicom_files[idx])
            for idx, d in enumerate(dicoms)
            if self.is_valid_dicom(
                d, dicom_files[idx], sort_by_instance_number, exclude_invalid_dicoms
            )
        ]
        # handles exclude_invalid_dicoms=True and return_sorted_dicom_files=True
        # by only including valid DICOM filenames
        dicom_files = [_[1] for _ in dicoms]
        dicoms = [_[0] for _ in dicoms]
        slices = [dcm.pixel_array.astype("float32") for dcm in dicoms]
        shapes = np.stack([s.shape for s in slices], axis=0)
        if not np.all(shapes == shapes[0]):
            unique_shapes, counts = np.unique(shapes, axis=0, return_counts=True)
            standard_shape = tuple(unique_shapes[np.argmax(counts)])
            print(
                f"warning: different array shapes present, using {fix_unequal_shapes} -> {standard_shape}"
            )
            if fix_unequal_shapes == "crop_pad":
                slices = [
                    self.center_crop_or_pad_borders(s, standard_shape)
                    if s.shape != standard_shape
                    else s
                    for s in slices
                ]
            elif fix_unequal_shapes == "resize":
                # cv2.resize expects (width, height), whereas array shapes are (height, width)
                slices = [
                    cv2.resize(s, (standard_shape[1], standard_shape[0]))
                    if s.shape != standard_shape
                    else s
                    for s in slices
                ]
        slices = np.stack(slices, axis=0)
        if sort_by_instance_number:
            positions = [float(d.InstanceNumber) for d in dicoms]
        else:
            # sort using ImagePositionPatient: find the orientation of each
            # slice, take the most common, and use it as the index to sort by
            orientation = [self.determine_dicom_orientation(dcm) for dcm in dicoms]
            orientation = self.most_common_element(orientation)
            positions = [float(d.ImagePositionPatient[orientation]) for d in dicoms]
        indices = np.argsort(positions)
        slices = slices[indices]
        # rescale using the most common slope and intercept
        m = self.most_common_element([float(d.RescaleSlope) for d in dicoms])
        b = self.most_common_element([float(d.RescaleIntercept) for d in dicoms])
        slices = slices * m + b
        if windows is not None:
            self.validate_windows_type(windows)
            if isinstance(windows, tuple):
                windows = [windows]
            arr_list = []
            for WL, WW in windows:
                arr_list.append(self.window(slices.copy(), WL, WW))
            slices = np.stack(arr_list, axis=-1)
            if slices.shape[-1] == 1:
                slices = np.squeeze(slices, axis=-1)
        if return_sorted_dicom_files:
            return slices, [dicom_files[idx] for idx in indices]
        return slices

    @staticmethod
    def preprocess(x: np.ndarray, mode: str = "2d") -> np.ndarray:
        mode = mode.lower()
        if mode == "2d":
            x = cv2.resize(x, (256, 256))
            if x.ndim == 2:
                x = x[:, :, np.newaxis]
        elif mode == "3d":
            x = np.stack([cv2.resize(s, (256, 256)) for s in x], axis=0)
            if x.ndim == 3:
                x = x[:, :, :, np.newaxis]
        return x

    @staticmethod
    def add_buffer_to_coords(
        coords: torch.Tensor,
        buffer: float | tuple[float, float] = 0.05,
        empty_threshold: float = 1e-4,
    ):
        # coords is a torch.Tensor of shape (N, 4) containing
        # normalized x, y, w, h coordinates
        # buffer is applied to EACH SIDE (i.e., 0.05 adds 0.1 in total)
        coords = coords.clone()
        empty = (coords < empty_threshold).all(dim=1)
        assert coords.ndim == 2
        assert coords.shape[1] == 4
        if isinstance(buffer, (int, float)):
            buffer = (buffer, buffer)
        assert buffer[0] >= 0 and buffer[1] >= 0
        assert coords.min() >= 0 and coords.max() <= 1
        if buffer == (0, 0) or empty.sum() == coords.shape[0]:
            return coords
        # convert xywh -> xyxy
        x1, y1, w, h = coords.unbind(1)
        x2, y2 = x1 + w, y1 + h
        # since coords are normalized, the buffer values can be used directly
        w_buf, h_buf = buffer
        x1, y1, x2, y2 = x1 - w_buf, y1 - h_buf, x2 + w_buf, y2 + h_buf
        x1, y1 = torch.clamp_min(x1, 0), torch.clamp_min(y1, 0)
        x2, y2 = torch.clamp_max(x2, 1), torch.clamp_max(y2, 1)
        w, h = x2 - x1, y2 - y1
        coords = torch.stack([x1, y1, w, h], dim=1)
        coords[empty] = 0
        assert coords.min() >= 0 and coords.max() <= 1
        return coords

    def forward(
        self,
        x: torch.Tensor,
        img_shape: torch.Tensor | None = None,
        add_buffer: float | tuple[float, float] | None = None,
    ) -> torch.Tensor:
        # if img_shape is provided, coordinates are rescaled to that shape;
        # otherwise, normalized [0, 1] coordinates are returned
        # coords format is xywh
        if img_shape is not None:
            # img_shape = (batch_dim, 2)
            # img_shape[:, 0] = height, img_shape[:, 1] = width
            assert x.size(0) == img_shape.size(0), (
                f"x.size(0) [{x.size(0)}] must equal img_shape.size(0) [{img_shape.size(0)}]"
            )
        x = self.normalize(x)
        # global average pooling over the backbone feature map
        features = F.adaptive_avg_pool2d(self.backbone(x), 1).flatten(1)
        coords = self.linear(features).sigmoid()
        if add_buffer is not None:
            coords = self.add_buffer_to_coords(coords, add_buffer)
        if img_shape is None:
            return coords
        rescaled_coords = coords.clone()
        rescaled_coords[:, 0] = rescaled_coords[:, 0] * img_shape[:, 1]
        rescaled_coords[:, 1] = rescaled_coords[:, 1] * img_shape[:, 0]
        rescaled_coords[:, 2] = rescaled_coords[:, 2] * img_shape[:, 1]
        rescaled_coords[:, 3] = rescaled_coords[:, 3] * img_shape[:, 0]
        return rescaled_coords.int()
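    # Worked example of `add_buffer_to_coords` (illustrative): a normalized box
    # (x, y, w, h) = (0.20, 0.30, 0.40, 0.40) with buffer=0.05 has corners
    # (0.20, 0.30, 0.60, 0.70), which expand to (0.15, 0.25, 0.65, 0.75) and
    # convert back to (0.15, 0.25, 0.50, 0.50) in xywh format.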
    def crop(
        self,
        x: np.ndarray,
        mode: str,
        device: str | None = None,
        raw_hu: bool = False,
        remove_empty_slices: bool = False,
        add_buffer: float | tuple[float, float] | None = None,
    ) -> np.ndarray | tuple[np.ndarray, list[int]]:
        assert mode in ["2d", "3d"]
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        assert isinstance(x, np.ndarray)
        assert 2 <= x.ndim <= 4, f"# of dimensions should be 2, 3, or 4, got {x.ndim}"
        x0 = x
        if mode == "2d":
            x = np.expand_dims(x, axis=0)
        img_shapes = torch.tensor([_.shape[:2] for _ in x]).to(device)
        x = self.preprocess(x, mode="3d")
        if raw_hu:
            # if input is in Hounsfield units, apply a soft tissue window
            x = self.window(x, WL=50, WW=400)
        # torchify
        x = torch.from_numpy(x)
        x = x.permute(0, 3, 1, 2).float().to(device)
        if x.size(1) > 1:
            # if multi-channel, take the mean over channels
            x = x.mean(1, keepdim=True)
        coords = self.forward(x, img_shape=img_shapes, add_buffer=add_buffer)
        # take the union of all slice-wise bounding boxes, excluding empty boxes
        empty = coords.sum(dim=1) == 0
        coords = coords[~empty]
        # if all boxes are empty, return the original input
        if coords.shape[0] == 0:
            print("no foreground detected, returning original input ...")
            return x0
        x, y, w, h = coords.unbind(1)
        # xywh -> xyxy
        x1, y1, x2, y2 = x, y, x + w, y + h
        x1, y1 = x1.min().item(), y1.min().item()
        x2, y2 = x2.max().item(), y2.max().item()
        cropped = x0[:, y1:y2, x1:x2] if mode == "3d" else x0[y1:y2, x1:x2]
        if remove_empty_slices and empty.sum() > 0:
            print(f"removing {empty.sum().item()} empty slices ...")
            empty_indices = list(torch.where(empty)[0].cpu().numpy())
            cropped = cropped[~empty.cpu().numpy()]
            return cropped, empty_indices
        return cropped
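
# Minimal usage sketch (illustration only, not part of the model): the
# checkpoint name is a placeholder for a trained CTCropModel checkpoint,
# and the input volume is random data standing in for a real CT stack.
if __name__ == "__main__":
    model = CTCropModel.from_pretrained("your-username/ct-crop")  # hypothetical repo
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()
    # fake 3D stack in Hounsfield units: 16 slices of 512 x 512
    volume = np.random.randint(-1024, 1500, size=(16, 512, 512)).astype("float32")
    with torch.no_grad():
        cropped = model.crop(volume, mode="3d", device=device, raw_hu=True, add_buffer=0.025)
    print(f"{volume.shape} -> {cropped.shape}")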