---
library_name: transformers
tags:
- ct
- computed_tomography
- crop
- dicom
- radiology
license: apache-2.0
base_model:
- timm/mobilenetv3_small_100.lamb_in1k
pipeline_tag: object-detection
---

This model crops the foreground from the background in CT slices. It is a lightweight `mobilenetv3_small_100` model trained on CT examinations from the [public TotalSegmentator dataset](https://zenodo.org/records/10047292), version 2.0.1.

The following function was used to generate masks for each CT:

```
import nibabel as nib
import numpy as np
from scipy.ndimage import binary_closing, binary_fill_holes, minimum_filter
from skimage.measure import label


def generate_mask(array):
    # foreground = voxels with positive intensity after windowing
    mask = (array > 0).astype("uint8")
    # keep only the largest connected component
    mask_label = label(mask)
    labels, counts = np.unique(mask_label, return_counts=True)
    # drop the first entry (label 0, background)
    labels, counts = labels[1:], counts[1:]
    max_label = labels[np.argmax(counts)]
    mask = mask_label == max_label
    # close gaps and fill holes, slice by slice
    mask = np.stack([
        binary_fill_holes(binary_closing(mask[:, :, i]))
        for i in range(mask.shape[2])
    ], axis=2).astype("uint8")
    # shrink each slice mask with a 3x3 minimum filter
    mask = np.stack([
        minimum_filter(mask[:, :, i], size=3)
        for i in range(mask.shape[2])
    ], axis=2)
    return mask


array = nib.load("ct.nii.gz").get_fdata()
# apply soft tissue window
# note: apply_ct_window is a helper that is not defined in this snippet
array = apply_ct_window(array, window_level=50, window_width=400)
mask = generate_mask(array)
```

Bounding box coordinates were generated from the masks for individual slices. The model was then trained to predict normalized (0-1) `xywh` coordinates, given an individual CT slice. If the mask was empty, the coordinates were set to all zero.

Images were converted from Hounsfield units (HU) to 4 CT windows:

1. Soft tissue (level=50, width=400)
2. Brain (level=40, width=80)
3. Lung (level=-600, width=1500)
4. Bone (level=400, width=1800)

During training, random combinations of channels were selected. If more than 1 channel was selected, the images were averaged channel-wise to create a single-channel output. Strong data augmentation was also applied.
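
The windowing, channel averaging, and box-generation steps described above can be sketched as follows. This is a minimal sketch under stated assumptions, not the author's training code: `apply_ct_window` is referenced but not defined in this card, so the version here (clip to the window, rescale to 0-255 `uint8`) is an assumption, as are the helper names `average_windows` and `mask_to_xywh`, the `CT_WINDOWS` dictionary, and the reading of the box as (x, y, width, height) with x, y at the top-left corner.

```python
import numpy as np

# (level, width) presets for the four windows listed above
CT_WINDOWS = {
    "soft_tissue": (50, 400),
    "brain": (40, 80),
    "lung": (-600, 1500),
    "bone": (400, 1800),
}


def apply_ct_window(array, window_level, window_width):
    """Clip HU values to a window and rescale to 0-255 (assumed output range)."""
    lower = window_level - window_width / 2
    upper = window_level + window_width / 2
    array = np.clip(array, lower, upper)
    return ((array - lower) / (upper - lower) * 255.0).astype("uint8")


def average_windows(hu_slice, window_names):
    """Window a slice several ways and average the results into one channel."""
    channels = [
        apply_ct_window(hu_slice, *CT_WINDOWS[name]).astype("float32")
        for name in window_names
    ]
    return np.mean(channels, axis=0)


def mask_to_xywh(mask_slice):
    """Normalized (x, y, w, h) box of a 2D binary mask; all zeros if empty."""
    rows, cols = np.nonzero(mask_slice)
    if rows.size == 0:
        return np.zeros(4, dtype="float32")
    height, width = mask_slice.shape
    x0, x1 = cols.min(), cols.max() + 1
    y0, y1 = rows.min(), rows.max() + 1
    return np.array(
        [x0 / width, y0 / height, (x1 - x0) / width, (y1 - y0) / height],
        dtype="float32",
    )
```

For example, `average_windows(hu_slice, ["soft_tissue", "bone"])` would produce a single-channel image blending two windows, and `mask_to_xywh(mask[:, :, i])` would produce the regression target for slice `i` of a mask volume.
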
Because of this training scheme, the model should be robust to different CT windows and combinations thereof.

Example usage below:

```
import cv2
import numpy as np
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

cropper = AutoModel.from_pretrained("ianpan/ct-crop", trust_remote_code=True).eval().to(device)

# single image
img = cv2.imread("ct_slice.png", cv2.IMREAD_GRAYSCALE)
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=None)

# expand all 4 sides by 2.5% each
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=0.025)

# expand box height by 2.5% in each direction
# and box width by 5% in each direction
buffer = (0.05, 0.025)
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=buffer)

# stack of images
img_list = ["ct_slice_1.png", "ct_slice_2.png", ...]
stack = np.stack([cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in img_list], axis=0)
# pass the stacked array, not a single image
cropped_stack = cropper.crop(stack, mode="3d", device=device, raw_hu=False, add_buffer=None)
```

You can also get the coordinates directly and do the cropping yourself. You must separately preprocess the input.
Example below:

```
# single image
img0 = cv2.imread("ct_slice.png", cv2.IMREAD_GRAYSCALE)
img_shapes = torch.tensor([_.shape[:2] for _ in [img0]]).to(device)
img = cropper.preprocess(img0, mode="2d")
# if multi-channel, need to convert from channels-last -> channels-first
img = torch.from_numpy(img).expand(1, 1, -1, -1).float().to(device)
with torch.inference_mode():
    coords = cropper(img, img_shape=img_shapes, add_buffer=None)
# if you do not provide img_shapes, output will be normalized (0-1) coordinates
# otherwise will be scaled to img_shape
```

The model also contains methods to load DICOM images, if you have `pydicom` installed:

```
img = cropper.load_image_from_dicom(path_to_dicom_file, windows=None)
# note: RescaleSlope and RescaleIntercept already applied in the method

# apply CT window
brain_window = (40, 80)
img = cropper.load_image_from_dicom(path_to_dicom_file, windows=brain_window)

# or multiple windows
soft_tissue_window = (50, 400)
img = cropper.load_image_from_dicom(path_to_dicom_file, windows=[brain_window, soft_tissue_window])
# each window is a separate channel, img will be channels-last format
```

You can also load a stack of DICOM images from a folder:

```
dicom_folder = "/path/to/ct/head/images/"
# dicom_extension is used to filter files, default is ".dcm"
# can pass "" if you do not want to filter files
# default sort is by ImagePositionPatient using automatically determined
# orientation, can also sort by InstanceNumber
# can also apply CT windows, as above
stack = cropper.load_stack_from_dicom_folder(dicom_folder, windows=None, dicom_extension=".dcm", sort_by_instance_number=False)
# can input raw Hounsfield units into cropper
cropped_stack = cropper.crop(stack, mode="3d", device=device, raw_hu=True)
```

By default, the cropper will not remove slices in a stack, even if they are predicted to be empty.
You can enable removal by specifying `remove_empty_slices=True`, in which case the call also returns the indices (in the original input) of the empty slices that were removed.

```
cropped_stack, empty_slice_indices = cropper.crop(stack, mode="3d", remove_empty_slices=True)
```
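
If you take the coordinates directly and crop yourself, the remaining step is to turn a normalized box into pixel indices. The sketch below illustrates one way to do that. It does not reproduce the model's own `crop` method: the function name `crop_with_xywh` is hypothetical, and it assumes the box is (x, y, width, height) with x, y at the top-left corner and that `add_buffer` expands each side by a fraction of the box size, given as one float or as a (width fraction, height fraction) pair.

```python
import numpy as np


def crop_with_xywh(img, xywh, add_buffer=None):
    """Crop a 2D (or channels-last) image using a normalized (x, y, w, h) box.

    Assumptions, not confirmed by the model card: x, y is the top-left corner,
    and add_buffer is a fraction of the box size applied on each side.
    """
    height, width = img.shape[:2]
    x, y, w, h = (float(v) for v in xywh)

    if add_buffer is not None:
        if isinstance(add_buffer, (int, float)):
            buf_w = buf_h = float(add_buffer)
        else:
            buf_w, buf_h = add_buffer
        # shift the corner outward and widen the box on both sides
        x, w = x - w * buf_w, w * (1 + 2 * buf_w)
        y, h = y - h * buf_h, h * (1 + 2 * buf_h)

    # convert to pixel indices and clamp to the image bounds
    x0 = int(np.clip(round(x * width), 0, width))
    y0 = int(np.clip(round(y * height), 0, height))
    x1 = int(np.clip(round((x + w) * width), 0, width))
    y1 = int(np.clip(round((y + h) * height), 0, height))
    return img[y0:y1, x0:x1]
```

With this sketch, an all-zero box (the training target for an empty slice) yields an empty array, so callers may want to check for that case before using the crop.
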