import os import sys import torch import nibabel import numpy as np from tqdm import tqdm import matplotlib.pyplot as plt from matplotlib import animation from monai.data import MetaTensor from multiprocessing import Process, Pool from sklearn.preprocessing import MinMaxScaler import nibabel as nib import gdown import io from monai.transforms import ( Orientation, EnsureType, ConvertToMultiChannelBasedOnBratsClasses, ) from Segformer3D import SegFormer3D def predict_from_folder(model, zip_ref, device, D, H, W): """ Dự đoán kết quả segmentation từ một thư mục chứa các file MRI: flair, t1, t1ce, t2. Args: model: Mô hình segmentation đã được load. zip_ref: File zip chứa các file MRI. device: Thiết bị chạy mô hình ("cuda" hoặc "cpu"). D, H, W: Kích thước của đầu vào sau khi crop. Returns: prediction: Mặt nạ segmentation dự đoán (numpy array). inputs_rgb: Dữ liệu đầu vào đã chuẩn hóa về khoảng [0, 255] cho hiển thị màu. """ MRI_TYPE = ["flair", "t1", "t1ce", "t2"] def load_nii_from_bytes(data_bytes): """Load file NIfTI từ bytes.""" file_like = io.BytesIO(data_bytes) return nib.Nifti1Image.from_file_map({'header': nib.FileHolder(fileobj=file_like), 'image': nib.FileHolder(fileobj=file_like)}) def normalize(x): """Chuẩn hóa dữ liệu về khoảng [0, 1], đồng thời lưu min và max.""" min_val = np.min(x) max_val = np.max(x) scaler = MinMaxScaler(feature_range=(0, 1)) normalized_1D_array = scaler.fit_transform(x.reshape(-1, x.shape[-1])) return normalized_1D_array.reshape(x.shape), min_val, max_val def denormalize_to_rgb(x, min_val, max_val): """Chuyển dữ liệu từ [0, 1] về [0, 255].""" return ((x * (max_val - min_val)) + min_val).clip(0, 255).astype(np.uint8) def orient(x, affine): """Chuyển hệ tọa độ về chuẩn RAS.""" meta_tensor = MetaTensor(x=x, affine=affine) oriented_tensor = Orientation(axcodes="RAS")(meta_tensor) return EnsureType(data_type="numpy", track_meta=False)(oriented_tensor) def crop_brats2021_zero_pixels(x): """Cắt giảm kích thước về (D, H, W).""" H_start = (x.shape[1] - H) // 2 W_start = (x.shape[2] - W) // 2 D_start = (x.shape[3] - D) // 2 return x[:, H_start:H_start + H, W_start:W_start + W, D_start:D_start + D] def preprocess_modality(zip_ref, mri_type): """Tiền xử lý cho từng modality.""" extracted_files = zip_ref.namelist() nii_files = [f for f in extracted_files if f.lower().endswith(f'{mri_type}.nii')] if not nii_files: raise FileNotFoundError(f"No files ending with {mri_type}.nii found.") nii_file = nii_files[0] data_bytes = zip_ref.read(nii_file) nii_image = load_nii_from_bytes(data_bytes) data = nii_image.get_fdata() affine = nii_image.affine data, min_val, max_val = normalize(data) data = data[np.newaxis, ...] data = orient(data, affine) data = crop_brats2021_zero_pixels(data) return data, min_val, max_val # Tiền xử lý cho các modality modalities = [] min_max_values = [] # Lưu min và max cho mỗi modality for mri_type in MRI_TYPE: modality, min_val, max_val = preprocess_modality(zip_ref, mri_type) modalities.append(modality) min_max_values.append((min_val, max_val)) inputs = np.concatenate(modalities, axis=0) # (4, D, H, W) inputs = torch.tensor(inputs).unsqueeze(0).to(device).float() # Dự đoán với mô hình model.eval() with torch.no_grad(): logits = model(inputs) probabilities = torch.sigmoid(logits) prediction = (probabilities > 0.5).int() inputs_rgb = (inputs.squeeze(0).cpu().numpy()*255).astype(np.int32) return prediction.squeeze(0).cpu().numpy(),inputs_rgb def load_model(checkpoint_path, device): model = SegFormer3D() model = model.to(device) # model = torch.nn.DataParallel(model) checkpoint = torch.load(checkpoint_path,weights_only=True, map_location=device) model.load_state_dict(checkpoint['model_state_dict'],strict=False) model.eval() return model def overlay_mask(modalities, prediction): # Giả sử prediction có kích thước (D, H, W, 3) và modalities có kích thước (D, H, W, C) D, H, W = modalities.shape[:3] # Khởi tạo một mảng để lưu ảnh overlay cuối cùng overlay_all_slices = [] final_masks = [] flair_slice_colors = [] for slice_idx in range(D): # Lấy modality flair và dự đoán cho slice này flair_slice = modalities[slice_idx, :, :, 0] # (H, W) - Chọn flair modality prediction_slice = prediction[slice_idx, :, :, :] # (H, W, 3) # Tách các mask WT, TC, ET wt_mask = prediction_slice[:, :, 1] # Kênh 2: WT tc_mask = prediction_slice[:, :, 0] # Kênh 1: TC et_mask = prediction_slice[:, :, 2] # Kênh 3: ET # Chồng các kênh theo thứ tự ET > TC > WT final_mask = np.zeros_like(wt_mask) final_mask[et_mask > 0] = 3 # U tăng cường (ET) final_mask[(tc_mask > 0) & (final_mask == 0)] = 2 # Lõi u (TC) final_mask[(wt_mask > 0) & (final_mask == 0)] = 1 # Toàn bộ khối u (WT) final_masks.append(final_mask) # Chuyển flair_slice thành ảnh màu với 3 kênh flair_slice_color = np.stack((flair_slice,) * 3, axis=-1) # (H, W, 3) flair_slice_colors.append(np.copy(flair_slice_color)) # Overlay các vùng khác nhau bằng màu RGB flair_slice_color[final_mask == 1] = [255, 255, 0] # WT - Đỏ flair_slice_color[final_mask == 2] = [0, 255, 255] # TC - Xanh lá flair_slice_color[final_mask == 3] = [255, 0, 255] # ET - Xanh dương # Lưu ảnh overlay màu vào mảng kết quả overlay_all_slices.append(flair_slice_color) return np.stack(overlay_all_slices) def __call__(zip_ref): device = "cuda" if torch.cuda.is_available() else "cpu" url = "https://drive.google.com/uc?id=1qtWBuwE8PVb-_dzLbl_ySEPX6fNtEGBS" checkpoint_path = "Segformer3D_Brats2021_epoch_50_model.pth" if not os.path.exists(checkpoint_path): gdown.download(url, checkpoint_path, quiet=False) model = load_model(checkpoint_path,device) prediction,modalities = predict_from_folder(model, zip_ref, device, D=128, H=128, W=128) modalities = np.transpose(modalities,(3,2,1,0)) prediction = np.transpose(prediction,(3,2,1,0)) overlay = overlay_mask(modalities,prediction) return overlay.astype(np.uint8)