TotalClassifier: Slice-Level Organ Classification for CT Examinations

TotalClassifier is a classification model that predicts which organs are present on a 2D slice from a CT volume. It supports axial, sagittal, and coronal slices and a variety of windowing parameters. The model uses a tf_efficientnetv2_b0 backbone to extract features from each slice, followed by a gated recurrent unit (GRU) head that performs sequence modeling across the slice-level features. It also works with single 2D images.
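The backbone-plus-GRU design can be illustrated with a minimal PyTorch sketch. This is not the released implementation: the class name `SliceSequenceClassifier`, the bidirectional GRU, the hidden size, and the linear output layer are assumptions chosen for illustration. The backbone is any module that maps a `(N, C, H, W)` batch of slices to `(N, feature_dim)` feature vectors.

```python
import torch
import torch.nn as nn


class SliceSequenceClassifier(nn.Module):
    """Hypothetical sketch: per-slice CNN features -> GRU over slices -> per-slice label scores."""

    def __init__(self, backbone: nn.Module, feature_dim: int, num_labels: int = 117, hidden_dim: int = 256):
        super().__init__()
        self.backbone = backbone
        # assumption: bidirectional GRU so each slice sees context from both directions
        self.gru = nn.GRU(feature_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.head = nn.Linear(2 * hidden_dim, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_slices, channels, height, width)
        b, s, c, h, w = x.shape
        feats = self.backbone(x.reshape(b * s, c, h, w))  # (b * s, feature_dim)
        feats = feats.reshape(b, s, -1)                   # (b, s, feature_dim)
        seq, _ = self.gru(feats)                          # (b, s, 2 * hidden_dim)
        return torch.sigmoid(self.head(seq))              # (b, s, num_labels), scores in [0, 1]
```

With timm installed, a comparable backbone could be created with `timm.create_model("tf_efficientnetv2_b0", pretrained=False, in_chans=1, num_classes=0)`, whose pooled feature size is `backbone.num_features`. A single 2D image is simply a sequence of length 1, with shape `(1, 1, 1, H, W)`.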

The model is trained on the publicly available TotalSegmentator dataset, version 2.0.1. It predicts 117 labels, corresponding to the labels available in TotalSegmentator. The slice-level classification labels were derived from the dataset's segmentation masks.
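One plausible way to turn segmentation masks into slice-level labels is to mark a label as present on a slice if any voxel of that label appears there. The exact rule used for TotalClassifier is not stated, so treat this as an assumption. The sketch also assumes a single integer label map with 0 as background and values 1 to 117, which may require combining per-structure masks first; `slice_labels_from_segmentation` is a hypothetical helper.

```python
import numpy as np


def slice_labels_from_segmentation(seg: np.ndarray, num_labels: int = 117, axis: int = 0) -> np.ndarray:
    """Return a (num_slices, num_labels) 0/1 array; column k-1 marks label k as present on that slice.

    Assumes `seg` is an integer label map where 0 is background and 1..num_labels are structures.
    """
    slices = np.moveaxis(seg, axis, 0)  # axis 0 now indexes slices
    presence = np.zeros((slices.shape[0], num_labels), dtype=np.uint8)
    for i, plane in enumerate(slices):
        present = np.unique(plane)
        present = present[(present >= 1) & (present <= num_labels)]  # drop background (0)
        presence[i, present - 1] = 1
    return presence
```

Choosing `axis` selects which direction is treated as the slice axis; whether that corresponds to axial, coronal, or sagittal depends on the volume's orientation.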

Note that the model expects single-channel input. If you create a multi-channel image by applying multiple CT windows, take the mean across channels. The model also expects 8-bit pixel values (0 to 255, cast to float). If your CT volume is in Hounsfield units, apply a standard window first, such as soft tissue (level=50, width=400), before passing it to the model.
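Windowing maps a Hounsfield-unit range centered at `level` and spanning `width` onto 0 to 255. The sketch below assumes the common linear mapping with clipping; it is not necessarily identical to what `load_stack_from_dicom_folder` does internally. `apply_window` and `to_single_channel` are hypothetical helper names.

```python
import numpy as np


def apply_window(hu: np.ndarray, level: float, width: float) -> np.ndarray:
    """Linearly map [level - width/2, level + width/2] HU to [0, 255], clipping values outside."""
    lower = level - width / 2
    upper = level + width / 2
    clipped = np.clip(hu.astype(np.float32), lower, upper)
    return (clipped - lower) / (upper - lower) * 255.0


def to_single_channel(image: np.ndarray) -> np.ndarray:
    """Average over a trailing channel axis, e.g. (num_slices, H, W, C) -> (num_slices, H, W)."""
    return image.mean(axis=-1)


# Example: soft tissue (50, 400) and bone (400, 1800) windows, averaged to one channel
hu_volume = np.random.randint(-1024, 2000, size=(4, 8, 8))
stacked = np.stack([apply_window(hu_volume, 50, 400), apply_window(hu_volume, 400, 1800)], axis=-1)
single = to_single_channel(stacked)  # shape (4, 8, 8), values in [0, 255]
```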

Example Usage

```python
import numpy as np
import torch
from transformers import AutoModel

device = "cuda"
organ_model = AutoModel.from_pretrained("ianpan/total-classifier", trust_remote_code=True).eval().to(device)

# the model can load a CT from a folder of DICOM files, if pydicom is installed
# here we apply the soft tissue window (level=50, width=400)
ct_volume = organ_model.load_stack_from_dicom_folder("/path/to/dicom/folder", windows=[[50, 400]], dicom_extension=".dcm")

# ct_volume.shape is (num_slices, height, width, num_channels) if applying windows
# otherwise it is (num_slices, height, width) if using original Hounsfield units

# preprocess
x = organ_model.preprocess(ct_volume, mode="3d", torchify=True, add_batch_dim=True, device=device)

# here, ct_volume is a numpy array
# if you are loading volumes as torch.Tensors, you can skip the preprocess function
# and just resize the volume to a height and width of 256 x 256

# x is now a torch.Tensor with shape (1, num_slices, num_channels, height, width)
# note that these are the expected dims for the model's forward method

with torch.inference_mode():
    out = organ_model(x)
    out_df = organ_model(x, return_as_df=True)

# out is a torch.Tensor of shape (1, num_slices, 117) containing scores in [0, 1] for each organ label
# out_df is a list of pandas DataFrames, each of shape (num_slices, 117), whose column names are the organ names
# each element of the list corresponds to one sample in the batch
# if using batch sizes > 1, all samples must be padded to the same number of slices

# use out_df to get only the slices where an organ's score is at or above a threshold
out_df = out_df[0]
threshold = 0.5
liver_indices = np.where(out_df["liver"].values >= threshold)[0]

# or the slices where at least one of the specified organs is at or above the threshold
organs_of_interest = ["liver", "spleen", "pancreas"]
threshold = 0.5
slice_indices = np.where((out_df[organs_of_interest].values >= threshold).max(1))[0]

# organ_model.label2index converts organ label names to the indices 0-116
# organ_model.index2label is the inverse
```
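If you prefer to work with the raw tensor `out` instead of DataFrames, `organ_model.label2index` can select columns directly. The sketch assumes `label2index` is a plain dict from label name to integer column index, as described in the comments above, and that `scores` is a `(num_slices, 117)` NumPy array such as `out[0].cpu().numpy()`. Both helper names are hypothetical.

```python
import numpy as np


def slices_with_any_label(scores, label2index, organs, threshold=0.5):
    """Indices of slices where at least one of `organs` scores >= threshold."""
    columns = [label2index[name] for name in organs]
    return np.where((scores[:, columns] >= threshold).any(axis=1))[0]


def label_extent(scores, label2index, organ, threshold=0.5):
    """First and last slice index where `organ` scores >= threshold, or None if it never does."""
    indices = slices_with_any_label(scores, label2index, [organ], threshold)
    if indices.size == 0:
        return None
    return int(indices[0]), int(indices[-1])
```

For example, `label_extent(out[0].cpu().numpy(), organ_model.label2index, "liver")` would give a slice range that could be used to crop the volume; note that the range includes any low-scoring slices between the first and last hit.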

If you have a large number of slices and limited GPU memory, you can either process the volume in chunks, or downsample the volume along the slice dimension and interpolate the predictions back to the original number of slices.
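Both strategies can be sketched as follows, assuming `model` is a callable such as `organ_model` that maps a `(1, num_slices, C, H, W)` tensor to `(1, num_slices, num_labels)` scores. The function names and the default `chunk_size` and `step` values are illustrative assumptions. Because the GRU head models context across slices, each chunk in `predict_in_chunks` presumably loses context from neighboring chunks, so predictions near chunk boundaries may differ slightly from a full-volume pass.

```python
import numpy as np
import torch


@torch.inference_mode()
def predict_in_chunks(model, x: torch.Tensor, chunk_size: int = 64) -> torch.Tensor:
    """Run `model` on consecutive slice chunks of x (1, num_slices, C, H, W); concatenate along slices."""
    outputs = [model(x[:, start:start + chunk_size]) for start in range(0, x.shape[1], chunk_size)]
    return torch.cat(outputs, dim=1)


@torch.inference_mode()
def predict_downsampled(model, x: torch.Tensor, step: int = 2) -> np.ndarray:
    """Predict on every `step`-th slice, then linearly interpolate scores back to every slice.

    Returns a NumPy array of shape (num_slices, num_labels).
    """
    num_slices = x.shape[1]
    coarse = model(x[:, ::step])[0].float().cpu().numpy()  # (ceil(num_slices / step), num_labels)
    sampled_at = np.arange(0, num_slices, step)  # original indices of the predicted slices
    targets = np.arange(num_slices)
    # np.interp holds the last value constant past the final sampled slice
    return np.stack([np.interp(targets, sampled_at, coarse[:, j]) for j in range(coarse.shape[1])], axis=1)
```

Chunking keeps full resolution at the cost of cross-chunk context; downsampling keeps the whole sequence in one pass but relies on organ presence changing smoothly between sampled slices.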

Model size: 5.86M parameters (F32 tensors, Safetensors format).