|
|
|
|
|
""" |
|
inference_brain2vec_PCA.py |
|
|
|
Loads a pre-trained PCA-based Brain2Vec model (saved with joblib) and performs |
|
inference on one or more input images. Produces embeddings (and optional |
|
reconstructions) for each image. |
|
|
|
Example usage: |
|
|
|
python inference_brain2vec_PCA.py \ |
|
--pca_model pca_model.joblib \ |
|
--input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \ |
|
--output_dir pca_output \ |
|
--embeddings_filename pca_embeddings_2 \ |
|
--save_recons |
|
|
|
Or, if you have a CSV with image paths: |
|
|
|
python inference_brain2vec_PCA.py \ |
|
--pca_model pca_model.joblib \ |
|
--csv_input inputs.csv \ |
|
--output_dir pca_output \ |
|
--embeddings_filename pca_embeddings_all |
|
""" |
|
|
|
import os |
|
import argparse |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from joblib import load |
|
import pandas as pd |
|
|
|
from monai.transforms import ( |
|
Compose, |
|
CopyItemsD, |
|
LoadImageD, |
|
EnsureChannelFirstD, |
|
SpacingD, |
|
ResizeWithPadOrCropD, |
|
ScaleIntensityD, |
|
) |
|
|
|
|
|
# Isotropic voxel spacing handed to SpacingD as ``pixdim`` (in the image's
# physical units — presumably mm for NIfTI inputs).
RESOLUTION = 2

# Fixed spatial (D, H, W) size every resampled volume is padded/cropped to.
INPUT_SHAPE_AE = (80, 96, 80)

# Voxel count of one preprocessed volume, i.e. the length of the flattened
# feature vector the PCA operates on (80 * 96 * 80 = 614400).
FLATTENED_DIM = int(np.prod(INPUT_SHAPE_AE))
|
|
|
|
|
# Dictionary-based MONAI preprocessing pipeline applied to every input volume.
# It reads the file path from the "image_path" key and leaves the processed
# volume under the "image" key. The steps run in the listed order: resampling
# happens before the pad/crop (so INPUT_SHAPE_AE is measured in resampled
# voxels), and intensity scaling is applied last, to the fixed-size volume.
transforms_fn = Compose([
    # Copy the path into "image"; "image_path" itself is kept unchanged.
    CopyItemsD(keys={'image_path'}, names=['image']),
    # Replace the path in "image" with the loaded image (image_only=True:
    # only the image is returned, not an (image, metadata) pair).
    LoadImageD(image_only=True, keys=['image']),
    # Make sure the channel axis comes first.
    EnsureChannelFirstD(keys=['image']),
    # Resample to isotropic RESOLUTION voxel spacing.
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    # Pad/crop to INPUT_SHAPE_AE; padding uses NumPy's 'minimum' pad mode.
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    # Rescale intensities into the range [0, 1].
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])
|
|
|
|
|
def preprocess_mri(image_path: str) -> torch.Tensor:
    """
    Load a single MRI volume and run it through ``transforms_fn``.

    Args:
        image_path (str): Location of the MRI file (e.g. a .nii.gz volume).

    Returns:
        torch.Tensor: float32 tensor with a leading batch axis of size 1,
        i.e. (1, C, D, H, W); for single-channel input this is
        (1, 1, D, H, W).
    """
    processed = transforms_fn({"image_path": image_path})
    volume = processed["image"]

    # Prepend the batch axis expected by the model's forward pass.
    batched = volume.unsqueeze(0)
    return batched.float()
|
|
|
|
|
class PCABrain2vec(nn.Module):
    """
    A PCA-based 'autoencoder' exposing a typical VAE-style interface:

    - from_pretrained(...) loads a fitted PCA model from disk.
    - forward(...) returns (reconstruction, embedding, None).

    Steps performed by forward:
        1. Flatten the input volume (N, 1, D, H, W) => (N, D*H*W)
           (614400 features for an 80 x 96 x 80 volume).
        2. transform -> embeddings of shape (N, n_components).
        3. inverse_transform -> flat reconstruction of shape (N, D*H*W).
        4. Reshape back to the input's own shape (N, 1, D, H, W).

    The wrapped model only needs ``transform`` and ``inverse_transform``
    methods. It is presumably a fitted scikit-learn ``PCA``; confirm this
    against the training script.
    """

    def __init__(self, pca_model=None):
        super().__init__()
        # The PCA estimator is not an nn.Module, so .to()/.cuda() do not
        # affect it. forward() always computes on the CPU via NumPy.
        self.pca_model = pca_model

    def forward(self, x: torch.Tensor):
        """
        Perform a forward pass of the PCA-based "autoencoder".

        Args:
            x (torch.Tensor): Input of shape (N, 1, D, H, W). The product
                of all non-batch dimensions must equal the number of
                features the PCA model was fitted on.

        Returns:
            tuple(torch.Tensor, torch.Tensor, None):
                - reconstruction: float32 CPU tensor with the same shape as ``x``
                - embedding: float32 CPU tensor of shape (N, n_components)
                - None (to align with the typical VAE interface).

        Raises:
            RuntimeError: If no PCA model has been attached.
        """
        if self.pca_model is None:
            raise RuntimeError(
                "PCABrain2vec has no PCA model; pass pca_model=... or use "
                "PCABrain2vec.from_pretrained(...)."
            )

        input_shape = tuple(x.shape)
        n_samples = input_shape[0]

        # PCA runs in NumPy, so detach from autograd and move to the CPU.
        x_flat = x.detach().cpu().numpy().reshape(n_samples, -1)

        embedding_np = self.pca_model.transform(x_flat)
        recon_flat = self.pca_model.inverse_transform(embedding_np)

        # Undo the flatten using the input's own shape, not a fixed module
        # constant, so the voxel layout always round-trips exactly.
        recon_np = recon_flat.reshape(input_shape)

        reconstruction_torch = torch.from_numpy(recon_np).float()
        embedding_torch = torch.from_numpy(embedding_np).float()
        return reconstruction_torch, embedding_torch, None

    @classmethod
    def from_pretrained(cls, pca_path: str) -> "PCABrain2vec":
        """
        Load a pre-trained PCA model (pickled or joblib) from disk.

        Args:
            pca_path (str): File path to the PCA model.

        Returns:
            PCABrain2vec: An instance wrapping the loaded PCA model.

        Raises:
            FileNotFoundError: If ``pca_path`` does not exist.
        """
        if not os.path.exists(pca_path):
            raise FileNotFoundError(f"Could not find PCA model at {pca_path}")

        # SECURITY: joblib.load unpickles the file, which can execute
        # arbitrary code. Only load model files from trusted sources.
        pca_model = load(pca_path)
        return cls(pca_model=pca_model)
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this inference script."""
    parser = argparse.ArgumentParser(
        description="PCA-based Brain2Vec Inference Script"
    )
    parser.add_argument(
        "--pca_model", type=str, required=True,
        help="Path to the saved PCA model (.joblib)."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./pca_inference_outputs",
        help="Directory to save embeddings/reconstructions."
    )
    parser.add_argument(
        "--input_images", type=str, nargs="*",
        help="One or more image paths for inference."
    )
    parser.add_argument(
        "--csv_input", type=str, default=None,
        help="Path to a CSV containing column 'image_path'."
    )
    parser.add_argument(
        "--embeddings_filename",
        type=str,
        required=True,
        help="Filename (without path) to save the stacked embeddings (e.g., 'pca_embeddings.npy')."
    )
    parser.add_argument(
        "--save_recons",
        action="store_true",
        help="If set, save each reconstruction as .npy. Default is not to save."
    )
    return parser


def _resolve_image_paths(args: argparse.Namespace) -> list:
    """
    Collect the input image paths and check that every file exists.

    ``--csv_input`` takes precedence over ``--input_images`` when both are given.

    Raises:
        ValueError: If no input source is given, the CSV lacks an
            'image_path' column or has empty entries, or there are no images.
        FileNotFoundError: If any listed image does not exist.
    """
    if args.csv_input:
        df = pd.read_csv(args.csv_input)
        if "image_path" not in df.columns:
            raise ValueError("CSV must contain a column named 'image_path'.")
        # Empty cells are read as NaN, which os.path.exists would reject with a TypeError.
        if df["image_path"].isna().any():
            raise ValueError("CSV column 'image_path' contains empty entries.")
        image_paths = df["image_path"].astype(str).tolist()
    else:
        if not args.input_images:
            raise ValueError(
                "Must provide either --csv_input or --input_images."
            )
        image_paths = list(args.input_images)

    # An empty CSV would otherwise fail later inside np.vstack.
    if not image_paths:
        raise ValueError("No input images to process.")

    # Fail fast, before any model loading or per-image work is done.
    for img_path in image_paths:
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
    return image_paths


def main() -> None:
    """
    Parse command-line arguments and run inference with a pre-trained PCA
    Brain2Vec model.

    Writes the stacked embeddings, with one row per input image in input order,
    to ``<output_dir>/<embeddings_filename>.npy``. With ``--save_recons``, it
    also writes ``reconstruction_<i>.npy`` for each image.
    """
    args = _build_arg_parser().parse_args()

    # Validate inputs before creating directories or loading the model.
    image_paths = _resolve_image_paths(args)

    os.makedirs(args.output_dir, exist_ok=True)

    pca_brain2vec = PCABrain2vec.from_pretrained(args.pca_model)
    pca_brain2vec.eval()

    all_embeddings = []
    for i, img_path in enumerate(image_paths):
        img_tensor = preprocess_mri(img_path)

        with torch.no_grad():
            recon, embedding, _ = pca_brain2vec(img_tensor)

        all_embeddings.append(embedding.detach().cpu().numpy())

        if args.save_recons:
            recon_np = recon.detach().cpu().numpy()
            out_recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
            np.save(out_recon_path, recon_np)
            print(f"[INFO] Saved reconstruction to: {out_recon_path}")

    stacked_embeddings = np.vstack(all_embeddings)

    # np.save would append ".npy" itself; doing it here keeps the logged path accurate.
    filename = args.embeddings_filename
    if not filename.lower().endswith(".npy"):
        filename += ".npy"

    out_embed_path = os.path.join(args.output_dir, filename)
    np.save(out_embed_path, stacked_embeddings)
    print(f"[INFO] Saved embeddings of shape {stacked_embeddings.shape} to: {out_embed_path}")


if __name__ == "__main__":
    main()