Segformer3DBraTS2021 / Segformer3DBRATS2021.py
NhatNam214's picture
api brats
3b00cde
import os
import sys
import torch
import nibabel
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import animation
from monai.data import MetaTensor
from multiprocessing import Process, Pool
from sklearn.preprocessing import MinMaxScaler
import nibabel as nib
import gdown
import io
from monai.transforms import (
Orientation,
EnsureType,
ConvertToMultiChannelBasedOnBratsClasses,
)
from Segformer3D import SegFormer3D
def predict_from_folder(model, zip_ref, device, D, H, W):
"""
Dự đoán kết quả segmentation từ một thư mục chứa các file MRI: flair, t1, t1ce, t2.
Args:
model: Mô hình segmentation đã được load.
zip_ref: File zip chứa các file MRI.
device: Thiết bị chạy mô hình ("cuda" hoặc "cpu").
D, H, W: Kích thước của đầu vào sau khi crop.
Returns:
prediction: Mặt nạ segmentation dự đoán (numpy array).
inputs_rgb: Dữ liệu đầu vào đã chuẩn hóa về khoảng [0, 255] cho hiển thị màu.
"""
MRI_TYPE = ["flair", "t1", "t1ce", "t2"]
def load_nii_from_bytes(data_bytes):
"""Load file NIfTI từ bytes."""
file_like = io.BytesIO(data_bytes)
return nib.Nifti1Image.from_file_map({'header': nib.FileHolder(fileobj=file_like),
'image': nib.FileHolder(fileobj=file_like)})
def normalize(x):
"""Chuẩn hóa dữ liệu về khoảng [0, 1], đồng thời lưu min và max."""
min_val = np.min(x)
max_val = np.max(x)
scaler = MinMaxScaler(feature_range=(0, 1))
normalized_1D_array = scaler.fit_transform(x.reshape(-1, x.shape[-1]))
return normalized_1D_array.reshape(x.shape), min_val, max_val
def denormalize_to_rgb(x, min_val, max_val):
"""Chuyển dữ liệu từ [0, 1] về [0, 255]."""
return ((x * (max_val - min_val)) + min_val).clip(0, 255).astype(np.uint8)
def orient(x, affine):
"""Chuyển hệ tọa độ về chuẩn RAS."""
meta_tensor = MetaTensor(x=x, affine=affine)
oriented_tensor = Orientation(axcodes="RAS")(meta_tensor)
return EnsureType(data_type="numpy", track_meta=False)(oriented_tensor)
def crop_brats2021_zero_pixels(x):
"""Cắt giảm kích thước về (D, H, W)."""
H_start = (x.shape[1] - H) // 2
W_start = (x.shape[2] - W) // 2
D_start = (x.shape[3] - D) // 2
return x[:, H_start:H_start + H, W_start:W_start + W, D_start:D_start + D]
def preprocess_modality(zip_ref, mri_type):
"""Tiền xử lý cho từng modality."""
extracted_files = zip_ref.namelist()
nii_files = [f for f in extracted_files if f.lower().endswith(f'{mri_type}.nii')]
if not nii_files:
raise FileNotFoundError(f"No files ending with {mri_type}.nii found.")
nii_file = nii_files[0]
data_bytes = zip_ref.read(nii_file)
nii_image = load_nii_from_bytes(data_bytes)
data = nii_image.get_fdata()
affine = nii_image.affine
data, min_val, max_val = normalize(data)
data = data[np.newaxis, ...]
data = orient(data, affine)
data = crop_brats2021_zero_pixels(data)
return data, min_val, max_val
# Tiền xử lý cho các modality
modalities = []
min_max_values = [] # Lưu min và max cho mỗi modality
for mri_type in MRI_TYPE:
modality, min_val, max_val = preprocess_modality(zip_ref, mri_type)
modalities.append(modality)
min_max_values.append((min_val, max_val))
inputs = np.concatenate(modalities, axis=0) # (4, D, H, W)
inputs = torch.tensor(inputs).unsqueeze(0).to(device).float()
# Dự đoán với mô hình
model.eval()
with torch.no_grad():
logits = model(inputs)
probabilities = torch.sigmoid(logits)
prediction = (probabilities > 0.5).int()
inputs_rgb = (inputs.squeeze(0).cpu().numpy()*255).astype(np.int32)
return prediction.squeeze(0).cpu().numpy(),inputs_rgb
def load_model(checkpoint_path, device):
model = SegFormer3D()
model = model.to(device)
# model = torch.nn.DataParallel(model)
checkpoint = torch.load(checkpoint_path,weights_only=True, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'],strict=False)
model.eval()
return model
def overlay_mask(modalities, prediction):
# Giả sử prediction có kích thước (D, H, W, 3) và modalities có kích thước (D, H, W, C)
D, H, W = modalities.shape[:3]
# Khởi tạo một mảng để lưu ảnh overlay cuối cùng
overlay_all_slices = []
final_masks = []
flair_slice_colors = []
for slice_idx in range(D):
# Lấy modality flair và dự đoán cho slice này
flair_slice = modalities[slice_idx, :, :, 0] # (H, W) - Chọn flair modality
prediction_slice = prediction[slice_idx, :, :, :] # (H, W, 3)
# Tách các mask WT, TC, ET
wt_mask = prediction_slice[:, :, 1] # Kênh 2: WT
tc_mask = prediction_slice[:, :, 0] # Kênh 1: TC
et_mask = prediction_slice[:, :, 2] # Kênh 3: ET
# Chồng các kênh theo thứ tự ET > TC > WT
final_mask = np.zeros_like(wt_mask)
final_mask[et_mask > 0] = 3 # U tăng cường (ET)
final_mask[(tc_mask > 0) & (final_mask == 0)] = 2 # Lõi u (TC)
final_mask[(wt_mask > 0) & (final_mask == 0)] = 1 # Toàn bộ khối u (WT)
final_masks.append(final_mask)
# Chuyển flair_slice thành ảnh màu với 3 kênh
flair_slice_color = np.stack((flair_slice,) * 3, axis=-1) # (H, W, 3)
flair_slice_colors.append(np.copy(flair_slice_color))
# Overlay các vùng khác nhau bằng màu RGB
flair_slice_color[final_mask == 1] = [255, 255, 0] # WT - Đỏ
flair_slice_color[final_mask == 2] = [0, 255, 255] # TC - Xanh lá
flair_slice_color[final_mask == 3] = [255, 0, 255] # ET - Xanh dương
# Lưu ảnh overlay màu vào mảng kết quả
overlay_all_slices.append(flair_slice_color)
return np.stack(overlay_all_slices)
def __call__(zip_ref):
device = "cuda" if torch.cuda.is_available() else "cpu"
url = "https://drive.google.com/uc?id=1qtWBuwE8PVb-_dzLbl_ySEPX6fNtEGBS"
checkpoint_path = "Segformer3D_Brats2021_epoch_50_model.pth"
if not os.path.exists(checkpoint_path):
gdown.download(url, checkpoint_path, quiet=False)
model = load_model(checkpoint_path,device)
prediction,modalities = predict_from_folder(model, zip_ref, device, D=128, H=128, W=128)
modalities = np.transpose(modalities,(3,2,1,0))
prediction = np.transpose(prediction,(3,2,1,0))
overlay = overlay_mask(modalities,prediction)
return overlay.astype(np.uint8)