Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import torch | |
import nibabel | |
import numpy as np | |
from tqdm import tqdm | |
import matplotlib.pyplot as plt | |
from matplotlib import animation | |
from monai.data import MetaTensor | |
from multiprocessing import Process, Pool | |
from sklearn.preprocessing import MinMaxScaler | |
import nibabel as nib | |
import gdown | |
import io | |
from monai.transforms import ( | |
Orientation, | |
EnsureType, | |
ConvertToMultiChannelBasedOnBratsClasses, | |
) | |
from Segformer3D import SegFormer3D | |
def predict_from_folder(model, zip_ref, device, D, H, W): | |
""" | |
Dự đoán kết quả segmentation từ một thư mục chứa các file MRI: flair, t1, t1ce, t2. | |
Args: | |
model: Mô hình segmentation đã được load. | |
zip_ref: File zip chứa các file MRI. | |
device: Thiết bị chạy mô hình ("cuda" hoặc "cpu"). | |
D, H, W: Kích thước của đầu vào sau khi crop. | |
Returns: | |
prediction: Mặt nạ segmentation dự đoán (numpy array). | |
inputs_rgb: Dữ liệu đầu vào đã chuẩn hóa về khoảng [0, 255] cho hiển thị màu. | |
""" | |
MRI_TYPE = ["flair", "t1", "t1ce", "t2"] | |
def load_nii_from_bytes(data_bytes): | |
"""Load file NIfTI từ bytes.""" | |
file_like = io.BytesIO(data_bytes) | |
return nib.Nifti1Image.from_file_map({'header': nib.FileHolder(fileobj=file_like), | |
'image': nib.FileHolder(fileobj=file_like)}) | |
def normalize(x): | |
"""Chuẩn hóa dữ liệu về khoảng [0, 1], đồng thời lưu min và max.""" | |
min_val = np.min(x) | |
max_val = np.max(x) | |
scaler = MinMaxScaler(feature_range=(0, 1)) | |
normalized_1D_array = scaler.fit_transform(x.reshape(-1, x.shape[-1])) | |
return normalized_1D_array.reshape(x.shape), min_val, max_val | |
def denormalize_to_rgb(x, min_val, max_val): | |
"""Chuyển dữ liệu từ [0, 1] về [0, 255].""" | |
return ((x * (max_val - min_val)) + min_val).clip(0, 255).astype(np.uint8) | |
def orient(x, affine): | |
"""Chuyển hệ tọa độ về chuẩn RAS.""" | |
meta_tensor = MetaTensor(x=x, affine=affine) | |
oriented_tensor = Orientation(axcodes="RAS")(meta_tensor) | |
return EnsureType(data_type="numpy", track_meta=False)(oriented_tensor) | |
def crop_brats2021_zero_pixels(x): | |
"""Cắt giảm kích thước về (D, H, W).""" | |
H_start = (x.shape[1] - H) // 2 | |
W_start = (x.shape[2] - W) // 2 | |
D_start = (x.shape[3] - D) // 2 | |
return x[:, H_start:H_start + H, W_start:W_start + W, D_start:D_start + D] | |
def preprocess_modality(zip_ref, mri_type): | |
"""Tiền xử lý cho từng modality.""" | |
extracted_files = zip_ref.namelist() | |
nii_files = [f for f in extracted_files if f.lower().endswith(f'{mri_type}.nii')] | |
if not nii_files: | |
raise FileNotFoundError(f"No files ending with {mri_type}.nii found.") | |
nii_file = nii_files[0] | |
data_bytes = zip_ref.read(nii_file) | |
nii_image = load_nii_from_bytes(data_bytes) | |
data = nii_image.get_fdata() | |
affine = nii_image.affine | |
data, min_val, max_val = normalize(data) | |
data = data[np.newaxis, ...] | |
data = orient(data, affine) | |
data = crop_brats2021_zero_pixels(data) | |
return data, min_val, max_val | |
# Tiền xử lý cho các modality | |
modalities = [] | |
min_max_values = [] # Lưu min và max cho mỗi modality | |
for mri_type in MRI_TYPE: | |
modality, min_val, max_val = preprocess_modality(zip_ref, mri_type) | |
modalities.append(modality) | |
min_max_values.append((min_val, max_val)) | |
inputs = np.concatenate(modalities, axis=0) # (4, D, H, W) | |
inputs = torch.tensor(inputs).unsqueeze(0).to(device).float() | |
# Dự đoán với mô hình | |
model.eval() | |
with torch.no_grad(): | |
logits = model(inputs) | |
probabilities = torch.sigmoid(logits) | |
prediction = (probabilities > 0.5).int() | |
inputs_rgb = (inputs.squeeze(0).cpu().numpy()*255).astype(np.int32) | |
return prediction.squeeze(0).cpu().numpy(),inputs_rgb | |
def load_model(checkpoint_path, device): | |
model = SegFormer3D() | |
model = model.to(device) | |
# model = torch.nn.DataParallel(model) | |
checkpoint = torch.load(checkpoint_path,weights_only=True, map_location=device) | |
model.load_state_dict(checkpoint['model_state_dict'],strict=False) | |
model.eval() | |
return model | |
def overlay_mask(modalities, prediction): | |
# Giả sử prediction có kích thước (D, H, W, 3) và modalities có kích thước (D, H, W, C) | |
D, H, W = modalities.shape[:3] | |
# Khởi tạo một mảng để lưu ảnh overlay cuối cùng | |
overlay_all_slices = [] | |
final_masks = [] | |
flair_slice_colors = [] | |
for slice_idx in range(D): | |
# Lấy modality flair và dự đoán cho slice này | |
flair_slice = modalities[slice_idx, :, :, 0] # (H, W) - Chọn flair modality | |
prediction_slice = prediction[slice_idx, :, :, :] # (H, W, 3) | |
# Tách các mask WT, TC, ET | |
wt_mask = prediction_slice[:, :, 1] # Kênh 2: WT | |
tc_mask = prediction_slice[:, :, 0] # Kênh 1: TC | |
et_mask = prediction_slice[:, :, 2] # Kênh 3: ET | |
# Chồng các kênh theo thứ tự ET > TC > WT | |
final_mask = np.zeros_like(wt_mask) | |
final_mask[et_mask > 0] = 3 # U tăng cường (ET) | |
final_mask[(tc_mask > 0) & (final_mask == 0)] = 2 # Lõi u (TC) | |
final_mask[(wt_mask > 0) & (final_mask == 0)] = 1 # Toàn bộ khối u (WT) | |
final_masks.append(final_mask) | |
# Chuyển flair_slice thành ảnh màu với 3 kênh | |
flair_slice_color = np.stack((flair_slice,) * 3, axis=-1) # (H, W, 3) | |
flair_slice_colors.append(np.copy(flair_slice_color)) | |
# Overlay các vùng khác nhau bằng màu RGB | |
flair_slice_color[final_mask == 1] = [255, 255, 0] # WT - Đỏ | |
flair_slice_color[final_mask == 2] = [0, 255, 255] # TC - Xanh lá | |
flair_slice_color[final_mask == 3] = [255, 0, 255] # ET - Xanh dương | |
# Lưu ảnh overlay màu vào mảng kết quả | |
overlay_all_slices.append(flair_slice_color) | |
return np.stack(overlay_all_slices) | |
def __call__(zip_ref): | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
url = "https://drive.google.com/uc?id=1qtWBuwE8PVb-_dzLbl_ySEPX6fNtEGBS" | |
checkpoint_path = "Segformer3D_Brats2021_epoch_50_model.pth" | |
if not os.path.exists(checkpoint_path): | |
gdown.download(url, checkpoint_path, quiet=False) | |
model = load_model(checkpoint_path,device) | |
prediction,modalities = predict_from_folder(model, zip_ref, device, D=128, H=128, W=128) | |
modalities = np.transpose(modalities,(3,2,1,0)) | |
prediction = np.transpose(prediction,(3,2,1,0)) | |
overlay = overlay_mask(modalities,prediction) | |
return overlay.astype(np.uint8) | |