import glob
import os
from collections import Counter

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from transformers import PreTrainedModel

from .configuration import CTCropConfig

try:
    from pydicom import dcmread
except ModuleNotFoundError:
    pass
class CTCropModel(PreTrainedModel):
    """Regress a foreground bounding box for CT images and crop to it.

    A timm backbone is followed by global average pooling, dropout and a
    linear head with a sigmoid. The rest of this class treats the head output
    as a normalized ``(x, y, w, h)`` box in ``[0, 1]`` (this presumes
    ``config.num_classes == 4``).

    The class also bundles helpers for reading DICOM files, intensity
    windowing and resizing so that :meth:`crop` can be called directly on
    NumPy arrays.
    """

    config_class = CTCropConfig

    def __init__(self, config):
        """Build the backbone, dropout and regression head from ``config``.

        Reads ``config.backbone``, ``config.in_chans``, ``config.dropout``,
        ``config.feature_dim`` and ``config.num_classes``.
        """
        super().__init__(config)
        # `num_classes=0` and `global_pool=""` make timm return the unpooled
        # feature map; pooling is done explicitly in `forward`.
        self.backbone = create_model(
            model_name=config.backbone,
            pretrained=False,
            num_classes=0,
            global_pool="",
            features_only=False,
            in_chans=config.in_chans,
        )
        self.dropout = nn.Dropout(p=config.dropout)
        self.linear = nn.Linear(config.feature_dim, config.num_classes)

    # ------------------------------------------------------------------
    # Intensity helpers
    # ------------------------------------------------------------------
    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Linearly map intensities from ``[0, 255]`` to ``[-1, 1]``."""
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        return (x - 0.5) * 2.0

    @staticmethod
    def window(x: np.ndarray, WL: int, WW: int) -> np.ndarray:
        """Clip ``x`` to an intensity window and rescale it to ``uint8``.

        Args:
            x: Array of intensities.
            WL: Window level (center of the window).
            WW: Window width. The window spans ``WL - WW // 2`` to
                ``WL + WW // 2``.

        Returns:
            A ``uint8`` array of the same shape with values in ``[0, 255]``.

        Raises:
            ValueError: If the window is empty (``WW < 2``), which would
                otherwise divide by zero.
        """
        lower, upper = WL - WW // 2, WL + WW // 2
        if upper <= lower:
            raise ValueError(f"window width must be at least 2, got WW={WW}")
        x = np.clip(x, lower, upper)
        x = (x - lower) / (upper - lower)
        return (x * 255.0).astype("uint8")

    @staticmethod
    def validate_windows_type(windows):
        """Assert that ``windows`` is a ``(WL, WW)`` tuple or a list of them.

        Raises:
            AssertionError: If ``windows`` is neither a 2-tuple of ints nor a
                list of 2-tuples of ints.
        """
        assert isinstance(
            windows, (tuple, list)
        ), "windows must be a tuple or a list of tuples"
        pairs = [windows] if isinstance(windows, tuple) else windows
        for pair in pairs:
            assert isinstance(pair, tuple), "each window must be a tuple"
            assert len(pair) == 2, "each window must be a (WL, WW) pair"
            # The tuple branch previously asserted a non-empty list, which is
            # always truthy, so element types were never checked.
            assert all(
                isinstance(value, int) for value in pair
            ), "WL and WW must be ints"

    def _apply_windows(self, array: np.ndarray, windows) -> np.ndarray:
        """Window ``array`` with each ``(WL, WW)`` pair, stacked on a new last axis.

        With a single window the extra axis is squeezed away.
        """
        self.validate_windows_type(windows)
        if isinstance(windows, tuple):
            windows = [windows]
        # `window` does not modify its input, so no defensive copy is needed.
        channels = [self.window(array, WL, WW) for WL, WW in windows]
        stacked = np.stack(channels, axis=-1)
        if stacked.shape[-1] == 1:
            stacked = np.squeeze(stacked, axis=-1)
        return stacked

    # ------------------------------------------------------------------
    # DICOM helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _get_dcmread():
        """Return ``pydicom.dcmread``, raising only if pydicom is missing."""
        try:
            from pydicom import dcmread as reader
        except ModuleNotFoundError as err:
            # Subclass of `Exception`, so callers that catch `Exception` still work.
            raise ModuleNotFoundError("`pydicom` is not installed") from err
        return reader

    @staticmethod
    def determine_dicom_orientation(ds) -> int:
        """Return the index (0, 1 or 2) of the dominant axis of the slice normal.

        The normal is the cross product of the row and column direction
        cosines in ``ds.ImageOrientationPatient``. The result is later used to
        index ``ImagePositionPatient``. In the DICOM patient coordinate system
        0/1/2 conventionally correspond to sagittal/coronal/axial slices.
        Ties fall through to 2.
        """
        iop = np.asarray(ds.ImageOrientationPatient, dtype=float)
        abs_normal = np.abs(np.cross(iop[:3], iop[3:]))
        if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
            return 0
        if abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
            return 1
        return 2

    def load_image_from_dicom(
        self, path: str, windows: tuple[int, int] | list[tuple[int, int]] | None = None
    ) -> np.ndarray:
        """Read one DICOM file and optionally window it.

        The pixel data is rescaled with the file's ``RescaleSlope`` and
        ``RescaleIntercept`` (for CT this typically yields Hounsfield units).

        Args:
            path: Path to the DICOM file.
            windows: ``None`` to return the rescaled ``float32`` array, a
                single ``(WL, WW)`` tuple, or a list of such tuples.

        Returns:
            The rescaled ``float32`` array, or a ``uint8`` array with one
            trailing channel per window (squeezed for a single window).

        Raises:
            ModuleNotFoundError: If ``pydicom`` is not installed.
        """
        dcmread = self._get_dcmread()
        dicom = dcmread(path)
        array = dicom.pixel_array.astype("float32")
        m, b = float(dicom.RescaleSlope), float(dicom.RescaleIntercept)
        array = array * m + b
        if windows is None:
            return array
        return self._apply_windows(array, windows)

    @staticmethod
    def is_valid_dicom(
        ds,
        fname: str = "",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
    ) -> bool:
        """Check that ``ds`` has every attribute needed to build a stack.

        Args:
            ds: Dataset object to inspect with ``hasattr``.
            fname: File name, used only in the error message.
            sort_by_instance_number: If true, require ``InstanceNumber``;
                otherwise require ``ImagePositionPatient`` and
                ``ImageOrientationPatient``.
            exclude_invalid_dicoms: If true, return ``False`` for an invalid
                dataset instead of raising.

        Returns:
            ``True`` when all required attributes are present.

        Raises:
            Exception: If attributes are missing and
                ``exclude_invalid_dicoms`` is false.
        """
        attributes = ["pixel_array", "RescaleSlope", "RescaleIntercept"]
        if sort_by_instance_number:
            attributes.append("InstanceNumber")
        else:
            attributes.extend(["ImagePositionPatient", "ImageOrientationPatient"])
        # Plain strings keep the message readable; numpy string scalars print
        # as `np.str_('...')` under NumPy 2.
        missing = [attr for attr in attributes if not hasattr(ds, attr)]
        if missing and not exclude_invalid_dicoms:
            raise Exception(
                f"invalid DICOM file [{fname}]: missing attributes: {missing}"
            )
        return not missing

    @staticmethod
    def most_common_element(lst):
        """Return the most frequent element of ``lst`` (first seen wins ties).

        Raises:
            ValueError: If ``lst`` is empty.
        """
        if len(lst) == 0:
            raise ValueError(
                "cannot determine the most common element of an empty sequence"
            )
        return Counter(lst).most_common(1)[0][0]

    @staticmethod
    def center_crop_or_pad_borders(image, size):
        """Center-crop and/or zero-pad ``image`` to ``size = (height, width)``.

        Each of the first two axes is handled independently: cropped when too
        large, zero-padded when too small. Any trailing axes (for example
        channels) are left untouched.
        """
        height, width = image.shape[:2]
        new_height, new_width = (int(v) for v in size)
        # No padding on axes beyond height/width.
        trailing = [(0, 0)] * (image.ndim - 2)

        if new_height < height:
            top = (height - new_height) // 2
            image = image[top : top + new_height]
        elif new_height > height:
            top = (new_height - height) // 2
            bottom = new_height - height - top
            image = np.pad(
                image,
                [(top, bottom), (0, 0), *trailing],
                mode="constant",
                constant_values=0,
            )

        if new_width < width:
            left = (width - new_width) // 2
            image = image[:, left : left + new_width]
        elif new_width > width:
            left = (new_width - width) // 2
            right = new_width - width - left
            image = np.pad(
                image,
                [(0, 0), (left, right), *trailing],
                mode="constant",
                constant_values=0,
            )
        return image

    def load_stack_from_dicom_folder(
        self,
        path: str,
        windows: tuple[int, int] | list[tuple[int, int]] | None = None,
        dicom_extension: str = ".dcm",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
        fix_unequal_shapes: str = "crop_pad",
        return_sorted_dicom_files: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, list[str]]:
        """Load every DICOM file in a folder as one sorted stack of slices.

        Args:
            path: Folder to search (non-recursively).
            windows: ``None`` for rescaled ``float32`` values, a single
                ``(WL, WW)`` tuple, or a list of such tuples.
            dicom_extension: Suffix used to glob for files.
            sort_by_instance_number: Sort by ``InstanceNumber`` instead of by
                ``ImagePositionPatient`` along the dominant slice normal.
            exclude_invalid_dicoms: Drop files missing required attributes
                instead of raising.
            fix_unequal_shapes: ``"crop_pad"`` or ``"resize"``; how slices
                whose shape differs from the most common one are brought to
                that shape.
            return_sorted_dicom_files: Also return the file names in the same
                order as the slices.

        Returns:
            The stacked slices (first axis is the slice index), and the sorted
            file names when ``return_sorted_dicom_files`` is true.

        Raises:
            ModuleNotFoundError: If ``pydicom`` is not installed.
            Exception: If no files match, no valid files remain, or a file is
                invalid and ``exclude_invalid_dicoms`` is false.
            ValueError: If shapes differ and ``fix_unequal_shapes`` is unknown.
        """
        dcmread = self._get_dcmread()

        # Escape the folder so characters like `[` in `path` are not treated
        # as glob wildcards.
        pattern = os.path.join(glob.escape(os.fspath(path)), f"*{dicom_extension}")
        dicom_files = sorted(glob.glob(pattern))
        if not dicom_files:
            raise Exception(
                f"No DICOM files found in `{path}` using `dicom_extension={dicom_extension}`"
            )

        dicoms, valid_files = [], []
        for fname in dicom_files:
            ds = dcmread(fname)
            if self.is_valid_dicom(
                ds, fname, sort_by_instance_number, exclude_invalid_dicoms
            ):
                dicoms.append(ds)
                valid_files.append(fname)
        dicom_files = valid_files
        if not dicoms:
            raise Exception(f"No valid DICOM files found in `{path}`")

        slices = [ds.pixel_array.astype("float32") for ds in dicoms]
        shapes = np.stack([s.shape for s in slices], axis=0)
        if not np.all(shapes == shapes[0]):
            # Use the most frequent shape as the target for the outliers.
            unique_shapes, counts = np.unique(shapes, axis=0, return_counts=True)
            standard_shape = tuple(int(v) for v in unique_shapes[np.argmax(counts)])
            print(
                f"warning: different array shapes present, using {fix_unequal_shapes} -> {standard_shape}"
            )
            if fix_unequal_shapes == "crop_pad":
                slices = [
                    s
                    if s.shape == standard_shape
                    else self.center_crop_or_pad_borders(s, standard_shape[:2])
                    for s in slices
                ]
            elif fix_unequal_shapes == "resize":
                # cv2 expects (width, height); array shapes are (height, width).
                dsize = (standard_shape[1], standard_shape[0])
                slices = [
                    s if s.shape == standard_shape else cv2.resize(s, dsize)
                    for s in slices
                ]
            else:
                raise ValueError(
                    "fix_unequal_shapes must be 'crop_pad' or 'resize', "
                    f"got {fix_unequal_shapes!r}"
                )
        volume = np.stack(slices, axis=0)

        if sort_by_instance_number:
            positions = [float(ds.InstanceNumber) for ds in dicoms]
        else:
            # Orientation is only needed (and only validated as present) when
            # sorting by patient position.
            orientation = self.most_common_element(
                [self.determine_dicom_orientation(ds) for ds in dicoms]
            )
            positions = [float(ds.ImagePositionPatient[orientation]) for ds in dicoms]
        indices = np.argsort(positions, kind="stable")
        volume = volume[indices]

        # One slope/intercept (the most common) is applied to the whole stack.
        slope = self.most_common_element([float(ds.RescaleSlope) for ds in dicoms])
        intercept = self.most_common_element(
            [float(ds.RescaleIntercept) for ds in dicoms]
        )
        volume = volume * slope + intercept

        if windows is not None:
            volume = self._apply_windows(volume, windows)

        if return_sorted_dicom_files:
            return volume, [dicom_files[idx] for idx in indices]
        return volume

    # ------------------------------------------------------------------
    # Model input / output helpers
    # ------------------------------------------------------------------
    @staticmethod
    def preprocess(x: np.ndarray, mode="2d") -> np.ndarray:
        """Resize to 256x256 and ensure a trailing channel axis.

        Args:
            x: A single image for ``"2d"``, or a sequence of images along the
                first axis for ``"3d"``.
            mode: ``"2d"`` or ``"3d"`` (case-insensitive).

        Returns:
            ``(256, 256, C)`` for ``"2d"`` or ``(N, 256, 256, C)`` for ``"3d"``.

        Raises:
            ValueError: If ``mode`` is not ``"2d"`` or ``"3d"``.
        """
        mode = mode.lower()
        if mode == "2d":
            x = cv2.resize(x, (256, 256))
            if x.ndim == 2:
                x = x[:, :, np.newaxis]
        elif mode == "3d":
            x = np.stack([cv2.resize(s, (256, 256)) for s in x], axis=0)
            if x.ndim == 3:
                x = x[:, :, :, np.newaxis]
        else:
            raise ValueError(f"mode must be '2d' or '3d', got {mode!r}")
        return x

    @staticmethod
    def add_buffer_to_coords(
        coords: torch.Tensor,
        buffer: float | tuple[float, float] = 0.05,
        empty_threshold: float = 1e-4,
    ) -> torch.Tensor:
        """Expand normalized ``(x, y, w, h)`` boxes by a margin on every side.

        Args:
            coords: Tensor of shape ``(N, 4)`` with values in ``[0, 1]``.
            buffer: Margin in normalized units; a single number is used for
                both axes, a pair is ``(x_margin, y_margin)``.
            empty_threshold: Boxes whose four values are all below this are
                treated as empty and set to zero instead of being expanded.

        Returns:
            A new ``(N, 4)`` tensor; boxes are clamped to ``[0, 1]``. The
            input is never modified.

        Raises:
            AssertionError: On a wrong shape, a negative buffer, or values
                outside ``[0, 1]``.
        """
        # Validate the shape before anything relies on it.
        assert coords.ndim == 2, "coords must be 2-dimensional"
        assert coords.shape[1] == 4, "coords must have 4 columns (x, y, w, h)"
        if isinstance(buffer, (int, float)):
            buffer = (buffer, buffer)
        w_buf, h_buf = buffer
        assert w_buf >= 0 and h_buf >= 0, "buffer must be non-negative"

        coords = coords.clone()
        if coords.shape[0] == 0:
            # `min`/`max` are undefined for an empty tensor.
            return coords
        assert coords.min() >= 0 and coords.max() <= 1
        empty = (coords < empty_threshold).all(dim=1)
        if (w_buf == 0 and h_buf == 0) or bool(empty.all()):
            return coords

        x1, y1, w, h = coords.unbind(1)
        x2, y2 = x1 + w, y1 + h
        x1 = torch.clamp_min(x1 - w_buf, 0)
        y1 = torch.clamp_min(y1 - h_buf, 0)
        x2 = torch.clamp_max(x2 + w_buf, 1)
        y2 = torch.clamp_max(y2 + h_buf, 1)
        coords = torch.stack([x1, y1, x2 - x1, y2 - y1], dim=1)
        # Empty boxes stay empty rather than becoming a buffer-sized box.
        coords[empty] = 0
        assert coords.min() >= 0 and coords.max() <= 1
        return coords

    def forward(
        self,
        x: torch.Tensor,
        img_shape: torch.Tensor | None = None,
        add_buffer: float | tuple[float, float] | None = None,
    ) -> torch.Tensor:
        """Predict one box per input image.

        Args:
            x: Input batch with intensities in ``[0, 255]``; it is normalized
                to ``[-1, 1]`` before the backbone.
            img_shape: Optional ``(N, 2)`` tensor of ``(height, width)`` per
                image. When given, boxes are scaled to pixels and truncated to
                integers.
            add_buffer: Optional margin passed to
                :meth:`add_buffer_to_coords`.

        Returns:
            Normalized float boxes when ``img_shape`` is ``None``, otherwise
            integer pixel boxes, both as ``(x, y, w, h)``.
        """
        if img_shape is not None:
            assert (
                x.size(0) == img_shape.size(0)
            ), f"x.size(0) [{x.size(0)}] must equal img_shape.size(0) [{img_shape.size(0)}]"
        x = self.normalize(x)
        features = F.adaptive_avg_pool2d(self.backbone(x), 1).flatten(1)
        coords = self.linear(self.dropout(features)).sigmoid()
        if add_buffer is not None:
            coords = self.add_buffer_to_coords(coords, add_buffer)
        if img_shape is None:
            return coords
        # x and w scale with width (column 1), y and h with height (column 0).
        scale = img_shape[:, [1, 0, 1, 0]].to(device=coords.device, dtype=coords.dtype)
        rescaled_coords = coords.clone()
        rescaled_coords[:, :4] = rescaled_coords[:, :4] * scale
        return rescaled_coords.int()

    @torch.no_grad()
    def crop(
        self,
        x: np.ndarray,
        mode: str,
        device: str | None = None,
        raw_hu: bool = False,
        remove_empty_slices: bool = False,
        add_buffer: float | tuple[float, float] | None = None,
        return_coords: bool = False,
    ) -> (
        np.ndarray
        | tuple[np.ndarray, list[int]]
        | tuple[np.ndarray, list[int], list[int]]
    ):
        """Crop an image or stack to the union of the predicted boxes.

        Args:
            x: ``(H, W)`` or ``(H, W, C)`` for ``mode="2d"``; ``(N, H, W)`` or
                ``(N, H, W, C)`` for ``mode="3d"``.
            mode: ``"2d"`` or ``"3d"``.
            device: Device for inference. Defaults to the device the model
                weights are on.
            raw_hu: If true, window the input with ``WL=50, WW=400`` before
                inference (presumably for unwindowed Hounsfield-unit input).
            remove_empty_slices: If true and some slices have an all-zero box,
                drop them from the result and also return their indices.
            add_buffer: Optional margin forwarded to :meth:`forward`.
            return_coords: If true, also return ``[x1, y1, x2, y2]`` of the
                crop in pixels.

        Returns:
            The cropped array alone, or a tuple of the cropped array followed
            by the empty-slice indices (only when slices were removed) and/or
            the crop coordinates (only when ``return_coords``). If no
            foreground is found, the original input is returned on its own
            regardless of the other flags.
        """
        assert mode in ["2d", "3d"]
        assert isinstance(x, np.ndarray)
        assert (
            x.ndim <= 4 and x.ndim >= 2
        ), f"# of dimensions should be 2, 3, or 4, got {x.ndim}"
        allowed_ndim = (2, 3) if mode == "2d" else (3, 4)
        assert (
            x.ndim in allowed_ndim
        ), f"mode `{mode}` expects {allowed_ndim[0]} or {allowed_ndim[1]} dimensions, got {x.ndim}"
        if device is None:
            # Follow the weights so inputs and parameters share a device.
            device = next(self.parameters()).device

        original = x
        batch = np.expand_dims(x, axis=0) if mode == "2d" else x
        img_shapes = torch.tensor([s.shape[:2] for s in batch]).to(device)
        batch = self.preprocess(batch, mode="3d")
        if raw_hu:
            batch = self.window(batch, WL=50, WW=400)
        inputs = torch.from_numpy(batch).permute(0, 3, 1, 2).float().to(device)
        if inputs.size(1) > 1:
            # Collapse multi-channel input to a single channel.
            inputs = inputs.mean(1, keepdim=True)

        coords = self.forward(inputs, img_shape=img_shapes, add_buffer=add_buffer)
        empty = coords.sum(dim=1) == 0
        coords = coords[~empty]
        if coords.shape[0] == 0:
            print("no foreground detected, returning original input ...")
            return original

        # Union of all non-empty boxes.
        bx, by, bw, bh = coords.unbind(1)
        x1, y1 = bx.min().item(), by.min().item()
        x2, y2 = (bx + bw).max().item(), (by + bh).max().item()
        if mode == "3d":
            cropped = original[:, y1:y2, x1:x2]
        else:
            cropped = original[y1:y2, x1:x2]

        result = (cropped,)
        num_empty = int(empty.sum())
        if remove_empty_slices and num_empty > 0:
            empty_mask = empty.cpu().numpy()
            print(f"removing {num_empty} empty slices ...")
            empty_indices = [int(i) for i in np.flatnonzero(empty_mask)]
            result = (cropped[~empty_mask], empty_indices)
        if return_coords:
            result = result + ([x1, y1, x2, y2],)
        return result[0] if len(result) == 1 else result