File size: 6,946 Bytes
3b00cde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
import sys
import torch
import nibabel
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import animation
from monai.data import MetaTensor
from multiprocessing import Process, Pool
from sklearn.preprocessing import MinMaxScaler 
import nibabel as nib
import gdown

import io
from monai.transforms import (
    Orientation, 
    EnsureType,
    ConvertToMultiChannelBasedOnBratsClasses,
)
from Segformer3D import SegFormer3D
def predict_from_folder(model, zip_ref, device, D, H, W):
    """
    Dự đoán kết quả segmentation từ một thư mục chứa các file MRI: flair, t1, t1ce, t2.
    
    Args:
        model: Mô hình segmentation đã được load.
        zip_ref: File zip chứa các file MRI.
        device: Thiết bị chạy mô hình ("cuda" hoặc "cpu").
        D, H, W: Kích thước của đầu vào sau khi crop.

    Returns:
        prediction: Mặt nạ segmentation dự đoán (numpy array).
        inputs_rgb: Dữ liệu đầu vào đã chuẩn hóa về khoảng [0, 255] cho hiển thị màu.
    """
    MRI_TYPE = ["flair", "t1", "t1ce", "t2"]

    def load_nii_from_bytes(data_bytes):
        """Load file NIfTI từ bytes."""
        file_like = io.BytesIO(data_bytes)
        return nib.Nifti1Image.from_file_map({'header': nib.FileHolder(fileobj=file_like),
                                              'image': nib.FileHolder(fileobj=file_like)})

    def normalize(x):
        """Chuẩn hóa dữ liệu về khoảng [0, 1], đồng thời lưu min và max."""
        min_val = np.min(x)
        max_val = np.max(x)
        scaler = MinMaxScaler(feature_range=(0, 1))
        normalized_1D_array = scaler.fit_transform(x.reshape(-1, x.shape[-1]))
        return normalized_1D_array.reshape(x.shape), min_val, max_val

    def denormalize_to_rgb(x, min_val, max_val):
        """Chuyển dữ liệu từ [0, 1] về [0, 255]."""
        return ((x * (max_val - min_val)) + min_val).clip(0, 255).astype(np.uint8)

    def orient(x, affine):
        """Chuyển hệ tọa độ về chuẩn RAS."""
        meta_tensor = MetaTensor(x=x, affine=affine)
        oriented_tensor = Orientation(axcodes="RAS")(meta_tensor)
        return EnsureType(data_type="numpy", track_meta=False)(oriented_tensor)

    def crop_brats2021_zero_pixels(x):
        """Cắt giảm kích thước về (D, H, W)."""
        H_start = (x.shape[1] - H) // 2
        W_start = (x.shape[2] - W) // 2
        D_start = (x.shape[3] - D) // 2
        return x[:, H_start:H_start + H, W_start:W_start + W, D_start:D_start + D]

    def preprocess_modality(zip_ref, mri_type):
        """Tiền xử lý cho từng modality."""
        extracted_files = zip_ref.namelist()
        nii_files = [f for f in extracted_files if f.lower().endswith(f'{mri_type}.nii')]
        if not nii_files:
            raise FileNotFoundError(f"No files ending with {mri_type}.nii found.")

        nii_file = nii_files[0]
        data_bytes = zip_ref.read(nii_file)
        nii_image = load_nii_from_bytes(data_bytes)

        data = nii_image.get_fdata()
        affine = nii_image.affine
        data, min_val, max_val = normalize(data)
        data = data[np.newaxis, ...]
        data = orient(data, affine)
        data = crop_brats2021_zero_pixels(data)
        return data, min_val, max_val

    # Tiền xử lý cho các modality
    modalities = []
    min_max_values = []  # Lưu min và max cho mỗi modality
    for mri_type in MRI_TYPE:
        modality, min_val, max_val = preprocess_modality(zip_ref, mri_type)
        modalities.append(modality)
        min_max_values.append((min_val, max_val))

    inputs = np.concatenate(modalities, axis=0)  # (4, D, H, W)
    inputs = torch.tensor(inputs).unsqueeze(0).to(device).float()

    # Dự đoán với mô hình
    model.eval()
    with torch.no_grad():
        logits = model(inputs)
        probabilities = torch.sigmoid(logits)
        prediction = (probabilities > 0.5).int()
    inputs_rgb = (inputs.squeeze(0).cpu().numpy()*255).astype(np.int32)
    return prediction.squeeze(0).cpu().numpy(),inputs_rgb

def load_model(checkpoint_path, device):
    model = SegFormer3D()
    model = model.to(device)
    # model = torch.nn.DataParallel(model)
    checkpoint = torch.load(checkpoint_path,weights_only=True, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'],strict=False)
    model.eval()
    return model
def overlay_mask(modalities, prediction):
    # Giả sử prediction có kích thước (D, H, W, 3) và modalities có kích thước (D, H, W, C)
    D, H, W = modalities.shape[:3]

    # Khởi tạo một mảng để lưu ảnh overlay cuối cùng
    overlay_all_slices = []
    final_masks = []
    flair_slice_colors = []
    for slice_idx in range(D):
        # Lấy modality flair và dự đoán cho slice này
        flair_slice = modalities[slice_idx, :, :, 0]  # (H, W) - Chọn flair modality
        prediction_slice = prediction[slice_idx, :, :, :]  # (H, W, 3)

        # Tách các mask WT, TC, ET
        wt_mask = prediction_slice[:, :, 1]  # Kênh 2: WT
        tc_mask = prediction_slice[:, :, 0]  # Kênh 1: TC
        et_mask = prediction_slice[:, :, 2]  # Kênh 3: ET

        # Chồng các kênh theo thứ tự ET > TC > WT
        final_mask = np.zeros_like(wt_mask)

        final_mask[et_mask > 0] = 3  # U tăng cường (ET)
        final_mask[(tc_mask > 0) & (final_mask == 0)] = 2  # Lõi u (TC)
        final_mask[(wt_mask > 0) & (final_mask == 0)] = 1  # Toàn bộ khối u (WT)
        final_masks.append(final_mask)
        # Chuyển flair_slice thành ảnh màu với 3 kênh
        flair_slice_color = np.stack((flair_slice,) * 3, axis=-1)  # (H, W, 3)
        flair_slice_colors.append(np.copy(flair_slice_color))
        # Overlay các vùng khác nhau bằng màu RGB
        flair_slice_color[final_mask == 1] = [255, 255, 0]    # WT - Đỏ
        flair_slice_color[final_mask == 2] = [0, 255, 255]    # TC - Xanh lá
        flair_slice_color[final_mask == 3] = [255, 0, 255]    # ET - Xanh dương
        
        # Lưu ảnh overlay màu vào mảng kết quả
        overlay_all_slices.append(flair_slice_color)
    return np.stack(overlay_all_slices)
def __call__(zip_ref):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    url = "https://drive.google.com/uc?id=1qtWBuwE8PVb-_dzLbl_ySEPX6fNtEGBS"
    checkpoint_path = "Segformer3D_Brats2021_epoch_50_model.pth"
    if not os.path.exists(checkpoint_path):
        gdown.download(url, checkpoint_path, quiet=False)
    model = load_model(checkpoint_path,device)
    prediction,modalities = predict_from_folder(model, zip_ref, device, D=128, H=128, W=128)
    modalities = np.transpose(modalities,(3,2,1,0))
    prediction = np.transpose(prediction,(3,2,1,0))
    overlay = overlay_mask(modalities,prediction)
    return overlay.astype(np.uint8)