|
import contextlib
import copy
import json
import logging
import os
import pathlib
from typing import Optional, Sequence

import numpy as np
import torch
from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd
from monai.data.utils import decollate_batch, list_data_collate
from monai.networks.utils import eval_mode, train_mode
from monai.transforms import (
    CastToTyped,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    EnsureTyped,
    Invertd,
    Lambdad,
    LoadImaged,
    Orientationd,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    reset_ops_id,
)
from monai.utils import ForwardMode, optional_import, set_determinism
from monai.utils.enums import CommonKeys as Keys
from monai.utils.module import look_up_option
from transformers import AutoModel, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY

from scripts.inferer import Vista3dInferer
|
|
|
# einops is an optional dependency; the name is resolved through MONAI's
# optional_import helper.
rearrange, _ = optional_import("einops", name="rearrange")

# Directory containing this module; bundled assets (metadata.json, the
# pretrained model folder) are located relative to it.
FILE_PATH = os.path.dirname(__file__)

# NOTE(review): configuring the root logger at import time also affects any
# application importing this module; kept unchanged for compatibility.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
class VISTA3DPipeline(Pipeline):
    """Hugging Face ``Pipeline`` that runs VISTA3D segmentation with MONAI transforms.

    The three pipeline stages are:

    * ``preprocess``: load the image, resample/crop/normalize it and collate
      the sample into a batch of one;
    * ``_forward``: validate the prompts and run :class:`Vista3dInferer`;
    * ``postprocess``: invert the preprocessing on the prediction and,
      optionally, write it out with ``SaveImaged``.

    The transform chains and the inferer are built once in ``__init__`` from
    the construction-time parameters. Passing a different value for one of
    those parameters at call time has no effect; a warning is logged instead.
    """

    # Keyword arguments routed by ``_sanitize_parameters`` to each stage.
    PREPROCESSING_EXTRA_ARGS = [
        "image_key",
        "resample_spacing",
        "metadata_path",
    ]
    INFERENCE_EXTRA_ARGS = [
        "mode",
        "amp",
        "hyper_kwargs",
        "roi_size",
        "overlap",
        "sw_batch_size",
        "use_point_window",
    ]
    POSTPROCESSING_EXTRA_ARGS = [
        "pred_key",
        "image_key",
        "output_dir",
        "output_ext",
        "output_dtype",
        "output_postfix",
        "separate_folder",
        "save_output",
    ]

    # Subset of INFERENCE_EXTRA_ARGS consumed by ``_init_inferer`` in
    # ``__init__``; the others (mode, amp, hyper_kwargs) are ``_forward``
    # arguments and must not be passed to the inferer constructor.
    _INFERER_INIT_ARGS = ("roi_size", "overlap", "sw_batch_size", "use_point_window")

    # Label ids used when no prompt is given and
    # ``hyper_kwargs["everything_labels"]`` is a flag rather than a list:
    # ids 1..132 minus the excluded ones, in ascending order.
    EVERYTHING_LABEL = sorted(
        set(range(1, 133))
        - {2, 16, 18, 20, 21, 23, 24, 25, 26, 27, 128, 129, 130, 131, 132}
    )

    # Label ids accepted as class prompts when no point prompts are given.
    # Kept as a list: prompts may be tensors, which compare via ``==`` but
    # hash by identity, so set membership would not work for them.
    _SUPPORTED_LABELS = sorted(set(range(1, 133)) - {16, 18, 129, 130, 131})

    def __init__(self, model, **kwargs):
        """Build the pipeline and its transform chains / inferer.

        ``super().__init__`` runs ``_sanitize_parameters`` and stores the
        per-stage dicts in ``_preprocess_params``, ``_forward_params`` and
        ``_postprocess_params``, which are used here.
        """
        super().__init__(model, **kwargs)
        self.preprocessing_transforms = self._init_preprocessing_transforms(
            **self._preprocess_params
        )
        inferer_params = {
            key: value
            for key, value in self._forward_params.items()
            if key in self._INFERER_INIT_ARGS
        }
        self.inferer = self._init_inferer(**inferer_params)
        # Must come after the preprocessing chain: Invertd deep-copies it.
        self.postprocessing_transforms = self._init_postprocessing_transforms(
            **self._postprocess_params
        )

    def _init_inferer(
        self,
        roi_size: Sequence = (128, 128, 128),
        overlap: float = 0.3,
        sw_batch_size: int = 1,
        use_point_window: bool = True,
    ):
        """Create the :class:`Vista3dInferer` used by ``_forward``.

        Args:
            roi_size: sliding-window ROI size.
            overlap: overlap between neighbouring windows.
            sw_batch_size: number of windows per inferer batch.
            use_point_window: forwarded unchanged to ``Vista3dInferer``.
        """
        return Vista3dInferer(
            roi_size=roi_size,
            overlap=overlap,
            use_point_window=use_point_window,
            sw_batch_size=sw_batch_size,
        )

    def _init_preprocessing_transforms(
        self,
        image_key: str = "image",
        resample_spacing: Sequence = (1.5, 1.5, 1.5),
        metadata_path: str = os.path.join(FILE_PATH, "metadata.json"),
    ):
        """Build the MONAI preprocessing chain applied to every input sample.

        Args:
            image_key: key of the image entry in the input dictionary.
            resample_spacing: voxel spacing passed to ``Spacingd``.
            metadata_path: JSON file whose
                ``network_data_format.outputs.pred.channel_def`` entry provides
                ``labels_dict`` for ``VistaPreTransformd``.

        Returns:
            A ``Compose`` of the preprocessing transforms.
        """
        device = self.device
        # NOTE(review): presumably maps a parent label id to its sub-label ids
        # for VistaPreTransformd; confirm against the MONAI transform docs.
        subclass = {
            "2": [14, 5],
            "20": [28, 29, 30, 31, 32],
            "21": list(range(33, 57)) + list(range(63, 98)) + [114, 120, 122],
        }
        metadata = json.loads(pathlib.Path(metadata_path).read_text())
        labels_dict = metadata["network_data_format"]["outputs"]["pred"]["channel_def"]
        return Compose(
            [
                LoadImaged(keys=image_key, image_only=True),
                EnsureChannelFirstd(keys=image_key),
                EnsureTyped(keys=image_key, device=device, track_meta=True),
                Spacingd(keys=image_key, pixdim=resample_spacing, mode="bilinear"),
                CropForegroundd(
                    keys=image_key, allow_smaller=True, margin=10, source_key=image_key
                ),
                VistaPreTransformd(
                    keys=image_key, subclass=subclass, labels_dict=labels_dict
                ),
                # Clip intensities to a fixed window and rescale to [0, 1].
                ScaleIntensityRanged(
                    keys=image_key,
                    a_min=-963.8247715525971,
                    a_max=1053.678477684517,
                    b_min=0,
                    b_max=1,
                    clip=True,
                ),
                Orientationd(keys=image_key, axcodes="RAS"),
                CastToTyped(keys=image_key, dtype=torch.float32),
            ]
        )

    def _init_postprocessing_transforms(
        self,
        pred_key: str = "pred",
        image_key: str = "image",
        output_dir: str = "output_directory",
        output_ext: str = ".nii.gz",
        output_dtype: torch.dtype = torch.float32,
        output_postfix: str = "seg",
        separate_folder: bool = True,
        save_output: bool = True,
    ):
        """Build the chain that maps predictions back to the input image space.

        Args:
            pred_key: key of the prediction in the forward outputs.
            image_key: key of the preprocessed image used by ``Invertd``.
            output_dir, output_ext, output_dtype, output_postfix,
            separate_folder: forwarded to ``SaveImaged``.
            save_output: append ``SaveImaged`` only when true.

        Returns:
            A ``Compose`` of the postprocessing transforms.
        """
        transforms = [
            VistaPostTransformd(keys=pred_key),
            Invertd(
                keys=pred_key,
                transform=copy.deepcopy(self.preprocessing_transforms),
                orig_keys=image_key,
                nearest_interp=True,
                to_tensor=True,
            ),
            # Replace any NaNs in the prediction with 255.
            Lambdad(keys=pred_key, func=lambda x: torch.nan_to_num(x, nan=255)),
        ]
        if save_output:
            transforms.append(
                SaveImaged(
                    keys=pred_key,
                    resample=False,
                    output_dir=output_dir,
                    output_ext=output_ext,
                    output_dtype=output_dtype,
                    output_postfix=output_postfix,
                    separate_folder=separate_folder,
                ),
            )
        return Compose(transforms=transforms)

    def _sanitize_parameters(self, **kwargs):
        """Split user keyword arguments into per-stage parameter dicts.

        transformers calls this both at construction time
        (``pipeline(..., maybe_arg=4)``) and at call time
        (``pipe(..., maybe_arg=4)``). Only keys listed in the matching
        ``*_EXTRA_ARGS`` attribute are forwarded. Keys the caller did not pass
        are left out so the stage methods keep their own defaults.

        Returns:
            ``(preprocess_params, forward_params, postprocess_params)``.
        """

        def select(names):
            return {name: kwargs[name] for name in names if name in kwargs}

        return (
            select(self.PREPROCESSING_EXTRA_ARGS),
            select(self.INFERENCE_EXTRA_ARGS),
            select(self.POSTPROCESSING_EXTRA_ARGS),
        )

    @staticmethod
    def _check_call_time_params(kwargs, init_params, allowed_keys, stage):
        """Warn about parameters that cannot take effect at call time.

        Args:
            kwargs: parameters received by the stage method.
            init_params: the stage's construction-time parameters.
            allowed_keys: keys this stage understands.
            stage: stage name used in the warning message.
        """
        for key, value in kwargs.items():
            if key in init_params and value != init_params[key]:
                logger.warning("Please set the parameter %s during initialization.", key)
            if key not in allowed_keys:
                logger.warning("Cannot set parameter %s for %s.", key, stage)

    def check_prompts_format(self, label_prompt, points, point_labels):
        """Validate the user prompts, defaulting them when none are given.

        After ``list_data_collate`` each entry is presumably a one-element
        tensor, e.g. ``label_prompt=[tensor([1]), tensor([3])]``. TODO confirm.

        Args:
            label_prompt: list of class prompts, each of length one (or a list
                of such lists).
            points: list of (x, y, z) coordinates belonging to a single object.
            point_labels: one label per point, each in {-1, 0, 1, 2, 3}.

        Returns:
            ``(label_prompt, points, point_labels)``. When neither labels nor
            points are given, ``label_prompt`` is built from
            ``self.hyper_kwargs["everything_labels"]``. An integer/boolean flag
            there selects ``EVERYTHING_LABEL``.

        Raises:
            ValueError: if the prompts are missing or malformed.
        """
        if label_prompt is None and points is None:
            hyper_kwargs = getattr(self, "hyper_kwargs", None) or {}
            everything_labels = hyper_kwargs.get("everything_labels", None)
            if isinstance(everything_labels, int):
                # A flag (the ``_forward`` default is 1) cannot be iterated;
                # a truthy flag selects the built-in label list.
                everything_labels = self.EVERYTHING_LABEL if everything_labels else None
            if everything_labels is None:
                raise ValueError("Prompt must be given for inference.")
            label_prompt = [torch.tensor(_) for _ in everything_labels]
            return label_prompt, points, point_labels

        if label_prompt is not None:
            if not isinstance(label_prompt, list):
                raise ValueError("Label prompt must be a list, [1,2,3,4,...,].")
            if not label_prompt or not all(len(_) == 1 for _ in label_prompt):
                raise ValueError(
                    "Label prompt must be a list of single scalar, [1,2,3,4,...,]."
                )
            # Nested form: one prompt list per object; validate every entry.
            if isinstance(label_prompt[0], list):
                flat_prompts = [x for prompt in label_prompt for x in prompt]
            else:
                flat_prompts = label_prompt
            if not all((x < 255).item() for x in flat_prompts):
                raise ValueError(
                    "Current bundle only supports label prompt smaller than 255."
                )
            # Without point prompts only the supported classes can be segmented.
            if points is None and not all(
                x in self._SUPPORTED_LABELS for x in flat_prompts
            ):
                raise ValueError(
                    "Undefined label prompt detected. Provide point prompts for zero-shot."
                )

        if points is not None:
            if point_labels is None:
                raise ValueError("Point labels must be given if points are given.")
            if not all(len(_) == 3 for _ in points):
                raise ValueError(
                    "Points must be three dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]]."
                )
            if len(points) != len(point_labels):
                raise ValueError("Points must match point labels.")
            if not all(_ in [-1, 0, 1, 2, 3] for _ in point_labels):
                raise ValueError(
                    "Point labels can only be -1,0,1 and 2,3 for special flags."
                )

        if label_prompt is not None and points is not None and len(label_prompt) != 1:
            raise ValueError(
                "Label prompt can only be a single object if provided with point prompts."
            )

        if point_labels is not None and points is None:
            raise ValueError("Points must be given if point labels are given.")
        return label_prompt, points, point_labels

    def transform_points(self, point, affine):
        """Apply a 4x4 homogeneous ``affine`` to a batch of 3-D points.

        Args:
            point: array-like of shape [bs, N, 3].
            affine: 4x4 matrix applied to the homogeneous coordinates.

        Returns:
            numpy array of shape [bs, N, 3].
        """
        point = np.asarray(point)
        bs, n = point.shape[:2]
        homogeneous = np.concatenate((point, np.ones((bs, n, 1))), axis=-1)
        # [bs, N, 4] -> [4, bs*N] so a single matmul transforms every point.
        transformed = affine @ homogeneous.reshape(bs * n, -1).T
        return transformed.T.reshape(bs, n, -1)[:, :, :3]

    def preprocess(
        self,
        inputs,
        **kwargs,
    ):
        """Run the preprocessing chain on one sample and collate it to a batch.

        ``kwargs`` are only checked: the chain was built at initialization.
        """
        self._check_call_time_params(
            kwargs,
            self._preprocess_params,
            self.PREPROCESSING_EXTRA_ARGS,
            "preprocessing",
        )
        inputs = self.preprocessing_transforms(inputs)
        return list_data_collate([inputs])

    def _prepare_user_prompts(self, inputs):
        """Extract, validate and tensorize the user prompts of a batch.

        Returns:
            ``(image, label_prompt, points, point_labels)``. Each prompt is a
            tensor on the image device, or None.
        """
        image = inputs["image"]
        label_prompt, points, point_labels = self.check_prompts_format(
            inputs.get("label_prompt", None),
            inputs.get("points", None),
            inputs.get("point_labels", None),
        )
        image = image.to(self.device)
        if label_prompt is not None:
            # [1, N] batch of N objects -> drop the batch axis -> [N, 1].
            label_prompt = torch.as_tensor([label_prompt]).to(image.device)[0].unsqueeze(-1)
        if points is not None:
            # Points form [1, K, 3] for a single object. Map them with
            # inv(current affine) @ original affine, i.e. from original-image
            # voxel coordinates to preprocessed-image voxel coordinates.
            points = torch.as_tensor([points])
            points = self.transform_points(
                points,
                np.linalg.inv(image.affine[0])
                @ image.meta["original_affine"][0].numpy(),
            )
            points = torch.from_numpy(points).to(image.device)
        if point_labels is not None:
            point_labels = torch.as_tensor([point_labels]).to(image.device)
        return image, label_prompt, points, point_labels

    def _prepare_label_set_prompts(self, inputs, hyper_kwargs, label_set):
        """Build class prompts from a label set plus empty point placeholders.

        Returns:
            ``(image, labels, label_prompt, points, point_labels)``.
        """
        image, labels = inputs["image"], inputs["label"]
        if label_set is None:
            output_classes = hyper_kwargs.get("output_classes", None)
            label_set = np.arange(output_classes).tolist()
        label_prompt = torch.tensor(label_set).to(self.device).unsqueeze(-1)
        # Placeholder points: one zero point labelled -1 per class prompt.
        points = torch.zeros(label_prompt.shape[0], 1, 3).to(image.device)
        point_labels = -1 + torch.zeros(label_prompt.shape[0], 1).to(image.device)
        return image, labels, label_prompt, points, point_labels

    def _forward(
        self,
        inputs,
        mode: str = ForwardMode.EVAL,
        amp: bool = True,
        hyper_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """Run VISTA3D inference on one collated batch.

        Args:
            inputs: batch produced by :meth:`preprocess`. With a truthy
                ``hyper_kwargs["user_prompt"]`` it must contain ``"image"``
                and may contain ``"label_prompt"``, ``"points"`` and
                ``"point_labels"``. Otherwise it must contain ``"image"`` and
                ``"label"``.
            mode: ``ForwardMode`` value selecting ``eval_mode`` or ``train_mode``.
            amp: run the inferer under ``torch.autocast("cuda")``.
            hyper_kwargs: options read here and in :meth:`check_prompts_format`
                (``user_prompt``, ``everything_labels``, ``label_set``,
                ``val_label_set``, ``output_classes``, ``val_head``). ``None``
                means ``{"user_prompt": 1, "everything_labels": 1}``.
            **kwargs: inferer-construction parameters (``roi_size``,
                ``overlap``, ``sw_batch_size``, ``use_point_window``). The
                inferer is built in ``__init__``, so they are only checked here.

        Returns:
            dict with the image, label and prediction, plus ``"label_prompt"``,
            ``"points"`` and ``"point_labels"`` (each with a leading batch axis
            added, or None).

        Raises:
            ValueError: if ``inputs`` is None, ``mode`` is unsupported, or the
                prompts are invalid.
        """
        self._check_call_time_params(
            kwargs, self._forward_params, self.INFERENCE_EXTRA_ARGS, "inference"
        )
        set_determinism(seed=123)

        if inputs is None:
            raise ValueError("Must provide input data for inference.")
        if hyper_kwargs is None:
            # Fresh dict per call instead of a shared mutable default argument.
            hyper_kwargs = {"user_prompt": 1, "everything_labels": 1}
        # check_prompts_format reads the everything_labels option from here.
        self.hyper_kwargs = hyper_kwargs

        label_set = hyper_kwargs.get("label_set", None)
        val_label_set = hyper_kwargs.get("val_label_set", label_set)

        labels = None
        if hyper_kwargs["user_prompt"]:
            image, label_prompt, points, point_labels = self._prepare_user_prompts(inputs)
        else:
            image, labels, label_prompt, points, point_labels = (
                self._prepare_label_set_prompts(inputs, hyper_kwargs, label_set)
            )
            if hyper_kwargs.get("val_head", "auto") == "auto":
                # "auto" head: keep class prompts, pass no label_set to the inferer.
                val_label_set = None
            else:
                # Any other head: drop the class prompts.
                label_prompt = None

        mode = look_up_option(mode, ForwardMode)
        if mode == ForwardMode.EVAL:
            mode_context = eval_mode
        elif mode == ForwardMode.TRAIN:
            mode_context = train_mode
        else:
            raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.")

        outputs = {Keys.IMAGE: image, Keys.LABEL: labels}
        self.model.network.to(self.device)
        # One inferer call; autocast is entered only when amp is requested.
        autocast_context = torch.autocast("cuda") if amp else contextlib.nullcontext()
        with mode_context(self.model), autocast_context:
            outputs[Keys.PRED] = self.inferer(
                inputs=image,
                network=self.model.network,
                point_coords=points,
                point_labels=point_labels,
                class_vector=label_prompt,
                labels=labels,
                label_set=val_label_set,
            )
        # The original also discarded this result; MONAI documents
        # reset_ops_id as updating MetaTensors in place.
        reset_ops_id(image)

        outputs["label_prompt"] = (
            label_prompt.unsqueeze(0) if label_prompt is not None else None
        )
        outputs["points"] = points.unsqueeze(0) if points is not None else None
        outputs["point_labels"] = (
            point_labels.unsqueeze(0) if point_labels is not None else None
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return outputs

    def postprocess(self, outputs, **kwargs):
        """Decollate the forward outputs and run the postprocessing chain.

        ``kwargs`` are only checked: the chain was built at initialization.
        """
        self._check_call_time_params(
            kwargs,
            self._postprocess_params,
            self.POSTPROCESSING_EXTRA_ARGS,
            "postprocessing",
        )
        return self.postprocessing_transforms(decollate_batch(outputs))
|
|
|
|
|
def register_simple_pipeline():
    """Register :class:`VISTA3DPipeline` in the transformers pipeline registry.

    The pipeline is registered under the ``"vista3d"`` task with ``AutoModel``
    as the PyTorch model class and type ``"image"``. Its default ``"pt"`` model
    is the ``vista3d_pretrained_model`` directory next to this module, paired
    with an empty second tuple element.
    """
    pretrained_dir = os.path.join(FILE_PATH, "vista3d_pretrained_model")
    default_models = {"pt": (pretrained_dir, "")}
    PIPELINE_REGISTRY.register_pipeline(
        "vista3d",
        pipeline_class=VISTA3DPipeline,
        pt_model=AutoModel,
        default=default_models,
        type="image",
    )
|
|