TotalClassifier: Slice-Level Organ Classification for CT Examinations
TotalClassifier is a classification model that predicts which organs are present on a 2D slice of a CT volume.
It supports axial, sagittal, and coronal slices and a variety of windowing parameters.
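Because the model works on 2D slices, sagittal or coronal inputs are the same volume read along a different axis. The following is a minimal NumPy sketch. It assumes the volume is stored as `(num_slices, height, width)` with axial slices along the first axis, i.e. `(z, y, x)`. The helper name `reslice` is hypothetical and not part of the model's API.

```python
import numpy as np


def reslice(volume: np.ndarray, plane: str) -> np.ndarray:
    """Return a stack of 2D slices along the requested plane.

    Assumes `volume` is ordered (z, y, x), i.e. axial slices along axis 0.
    """
    if plane == "axial":
        return volume  # (z, y, x): one slice per z
    if plane == "coronal":
        return np.moveaxis(volume, 1, 0)  # (y, z, x): one slice per y
    if plane == "sagittal":
        return np.moveaxis(volume, 2, 0)  # (x, z, y): one slice per x
    raise ValueError(f"unknown plane: {plane}")
```

With anisotropic voxel spacing, coronal and sagittal slices will look stretched along z. The sketch does not correct for this.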
The model uses a tf_efficientnetv2_b0 backbone with a gated recurrent unit (GRU) head, which performs sequence modeling across the extracted slice-level features.
The model also works with single 2D images.
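The backbone-plus-GRU design can be sketched in PyTorch. This is not the model's actual code. A two-layer CNN stands in for tf_efficientnetv2_b0. The class name, feature sizes, the bidirectional GRU, and the sigmoid output are assumptions.

The sketch encodes each slice independently and then runs the GRU across the slice axis. A linear layer produces one score per label per slice. A single 2D image is simply a sequence of length one.

```python
import torch
import torch.nn as nn


class SliceSequenceClassifier(nn.Module):
    """Hypothetical sketch: per-slice CNN encoder + GRU over slices -> per-slice label scores."""

    def __init__(self, num_labels: int = 117, feat_dim: int = 64, hidden: int = 128):
        super().__init__()
        # Tiny stand-in for the tf_efficientnetv2_b0 backbone: (1, H, W) -> feat_dim vector
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, feat_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # Assumption: bidirectional, so each slice sees context from both directions
        self.gru = nn.GRU(feat_dim, hidden, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * hidden, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_slices, 1, H, W)
        b, n, c, h, w = x.shape
        feats = self.encoder(x.reshape(b * n, c, h, w))  # (b*n, feat_dim)
        seq, _ = self.gru(feats.reshape(b, n, -1))  # (b, n, 2*hidden)
        # Multi-label output: independent sigmoid score in [0, 1] per label per slice
        return torch.sigmoid(self.classifier(seq))  # (b, n, num_labels)
```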
The model is trained on the publicly available TotalSegmentator dataset, version 2.0.1, and predicts 117 labels, one for each structure annotated in TotalSegmentator. The slice-level classification labels were derived from the provided segmentation masks.
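The card does not state the exact rule used to turn masks into slice labels. One plausible approach is shown in the NumPy sketch below. It assumes the per-structure masks have been merged into a single integer label map, with 0 as background and 1 to 117 as structures. It also assumes a structure counts as present on a slice when it covers at least `min_voxels` pixels. The function name, the merged label map, the threshold, and the column ordering are all hypothetical. The column ordering need not match `organ_model.index2label`.

```python
import numpy as np


def slice_labels_from_masks(seg: np.ndarray, num_labels: int = 117, min_voxels: int = 1) -> np.ndarray:
    """Hypothetical: convert a (num_slices, H, W) integer label map into a
    binary (num_slices, num_labels) presence matrix.

    Label value k (1-based, 0 = background) maps to column k - 1.
    """
    num_slices = seg.shape[0]
    flat = seg.reshape(num_slices, -1)
    presence = np.zeros((num_slices, num_labels), dtype=np.uint8)
    for k in range(1, num_labels + 1):
        # Count pixels of structure k on each slice
        counts = (flat == k).sum(axis=1)
        presence[:, k - 1] = counts >= min_voxels
    return presence
```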
Note that the model expects single-channel input. If you create a multi-channel image by applying multiple CT windows, take the mean across channels. The model also expects 8-bit input (values from 0 to 255, converted to float). Thus, if your CT volume is in Hounsfield units, apply a standard window, such as soft tissue (level=50, width=400), before passing it to the model.
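This preprocessing can be sketched in NumPy. The sketch assumes the common linear window mapping: values at or below `level - width/2` map to 0, and values at or above `level + width/2` map to 255. The helper names `apply_window` and `to_single_channel` are hypothetical. The model's own `load_stack_from_dicom_folder(..., windows=...)` presumably does something similar, but its exact implementation is not documented here.

```python
import numpy as np


def apply_window(hu: np.ndarray, level: float, width: float) -> np.ndarray:
    """Map Hounsfield units to 8-bit values [0, 255] with a linear window, returned as float32."""
    lower = level - width / 2
    upper = level + width / 2
    clipped = np.clip(hu, lower, upper)
    scaled = (clipped - lower) / (upper - lower) * 255.0
    # Quantize to 8-bit, then convert to float as the model expects
    return np.round(scaled).astype(np.uint8).astype(np.float32)


def to_single_channel(hu: np.ndarray, windows) -> np.ndarray:
    """Apply each (level, width) window and average the results into one channel."""
    stacked = np.stack([apply_window(hu, lvl, wid) for lvl, wid in windows], axis=-1)
    return stacked.mean(axis=-1)
```

For example, `apply_window(volume_hu, 50, 400)` gives the soft tissue window described above. `to_single_channel(volume_hu, [[50, 400], [40, 80]])` averages a soft tissue window and a narrower brain-style window into one channel.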
Example Usage
```python
import numpy as np
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
organ_model = AutoModel.from_pretrained("ianpan/total-classifier", trust_remote_code=True).eval().to(device)

# the model can load a CT volume from a folder of DICOM files, if pydicom is installed
# here we apply a soft tissue window (level=50, width=400)
ct_volume = organ_model.load_stack_from_dicom_folder("/path/to/dicom/folder", windows=[[50, 400]], dicom_extension=".dcm")
# ct_volume.shape is (num_slices, height, width, num_channels) if windows are applied,
# otherwise (num_slices, height, width) if using the original Hounsfield units

# preprocess; here ct_volume is a numpy array
x = organ_model.preprocess(ct_volume, mode="3d", torchify=True, add_batch_dim=True, device=device)
# if you load volumes as torch.Tensors, you can skip the preprocess function
# and just resize the volume to a height and width of 256 x 256
# x is now a torch.Tensor of shape (1, num_slices, num_channels, height, width),
# which are the dims expected by the model's forward method

with torch.inference_mode():
    out = organ_model(x)
    out_df = organ_model(x, return_as_df=True)
# out is a torch.Tensor of shape (1, num_slices, 117) containing scores in [0, 1] for each organ label
# out_df is a list of pandas DataFrames of shape (num_slices, 117) whose column names are the organ names;
# each element of the list corresponds to one sample in the batch
# for batch sizes > 1, all samples must be padded to the same number of slices

# use out_df to select slices whose predicted score for an organ is at least a threshold
out_df = out_df[0]
threshold = 0.5
liver_indices = np.where(out_df["liver"].values >= threshold)[0]

# or slices where at least one of the specified organ labels meets the threshold
organs_of_interest = ["liver", "spleen", "pancreas"]
slice_indices = np.where((out_df[organs_of_interest].values >= threshold).any(axis=1))[0]

# organ_model.label2index converts organ label names to the indices 0-116;
# organ_model.index2label is the inverse
```
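Thresholded slice indices can also be collapsed into contiguous ranges. This is useful, for example, to report the first and last slice on which the liver is predicted. The following is a minimal NumPy sketch. `contiguous_ranges` is a hypothetical helper that takes one label's scores as a 1D array, such as `out_df["liver"].values`.

```python
import numpy as np


def contiguous_ranges(scores: np.ndarray, threshold: float = 0.5):
    """Return inclusive (start, end) slice index pairs where scores >= threshold."""
    positive = np.asarray(scores) >= threshold
    # Pad with False on both sides so every run has a rising and a falling edge
    padded = np.concatenate([[False], positive, [False]])
    edges = np.flatnonzero(np.diff(padded.astype(np.int8)))
    starts, ends = edges[0::2], edges[1::2] - 1
    return [(int(s), int(e)) for s, e in zip(starts, ends)]
```

For example, `contiguous_ranges(np.array([0.1, 0.9, 0.8, 0.2, 0.7]))` returns `[(1, 2), (4, 4)]`.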
If you have a large number of slices and limited GPU memory, you can either process the volume in chunks or downsample the volume along the slice dimension and interpolate the predictions back to the original number of slices.
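Both strategies can be sketched in NumPy. `predict_fn` is a hypothetical callable mapping a `(n, H, W)` slice stack to `(n, num_labels)` scores. In practice it would wrap `organ_model.preprocess` and the forward pass. The sketch uses linear interpolation, which is an assumption rather than a documented recommendation. Because the GRU head models context across slices, chunking means each chunk only sees the slices within it.

```python
import numpy as np


def predict_in_chunks(volume, predict_fn, chunk_size=64):
    """Run predict_fn on consecutive slice chunks and concatenate the scores."""
    outputs = [predict_fn(volume[i:i + chunk_size]) for i in range(0, len(volume), chunk_size)]
    return np.concatenate(outputs, axis=0)


def predict_downsampled(volume, predict_fn, stride=2):
    """Predict on every `stride`-th slice, then linearly interpolate scores back to all slices."""
    num_slices = len(volume)
    kept = np.arange(0, num_slices, stride)
    scores = predict_fn(volume[kept])  # (len(kept), num_labels)
    all_idx = np.arange(num_slices)
    # Interpolate each label independently along the slice axis;
    # slices past the last kept index take the last kept score (np.interp default)
    return np.stack(
        [np.interp(all_idx, kept, scores[:, j]) for j in range(scores.shape[1])], axis=1
    )
```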
Base model: timm/tf_efficientnetv2_b0.in1k