import zipfile
import nibabel as nib
import numpy as np
import gradio as gr
import Segformer3DBRATS2021  # Assumes this model module is defined elsewhere in the project.
import torch
import torch.nn.functional as F
from io import BytesIO
import tempfile


def _center_pad(overlay, target_shape=(240, 240, 155)):
    """Reverse the axes of a 4-D *overlay* and centre-pad its spatial axes.

    After ``np.transpose(overlay, (3, 2, 1, 0))`` the first axis is treated
    as the channel axis and the remaining three as spatial axes. Each spatial
    axis is zero-padded symmetrically to the matching entry of
    *target_shape*; when the total padding is odd, the extra voxel goes after.

    Args:
        overlay: 4-D array-like model output. Its LAST axis ends up as the
            channel axis of the result.
        target_shape: Three spatial sizes to pad to.

    Returns:
        float32 ``np.ndarray`` of shape ``(*target_shape, C)``, where ``C``
        is the size of *overlay*'s last axis.

    Raises:
        ValueError: if the padded spatial shape does not equal *target_shape*.
    """
    volume = np.transpose(overlay, (3, 2, 1, 0))
    # Add a batch axis -> (1, C, d0, d1, d2), the 5-D layout used below.
    tensor = torch.tensor(volume, dtype=torch.float32).unsqueeze(0)

    pad = []
    # F.pad expects (before, after) pairs starting from the LAST dimension,
    # so walk spatial dims 4, 3, 2 alongside target_shape reversed.
    for dim, target in zip((4, 3, 2), reversed(target_shape)):
        missing = target - tensor.shape[dim]
        before = missing // 2
        pad += [before, missing - before]

    padded = F.pad(tensor, tuple(pad), value=0)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if tuple(padded.shape[2:]) != tuple(target_shape):
        raise ValueError(f"Expected shape {target_shape}, got {padded.shape[2:]}")

    # (1, C, d0, d1, d2) -> (d0, d1, d2, C)
    return padded.permute(0, 2, 3, 4, 1).squeeze(0).numpy()


def _save_nifti(volume):
    """Write *volume* to a new temporary ``.nii`` file and return its path.

    The file is intentionally not deleted here: the caller (Gradio) needs it
    to exist after this function returns.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.nii') as temp_file:
        nii_file_path = temp_file.name
    # Save only after the handle is closed: re-opening a still-open
    # NamedTemporaryFile by name is not portable (it fails on Windows).
    # NOTE(review): the identity affine discards the input scans' voxel
    # spacing/orientation; confirm this is acceptable for downstream viewers.
    nib.save(nib.Nifti1Image(volume, np.eye(4)), nii_file_path)
    return nii_file_path


def predict_segmentation(zip_file):
    """Run the Segformer3D model on an uploaded ZIP of MRI data and save the result.

    Args:
        zip_file: The uploaded ZIP as provided by ``gr.File``, either a path or a
            file-like object. Both are accepted by ``zipfile.ZipFile``.

    Returns:
        Path to a temporary ``.nii`` file containing the model overlay. The
        overlay is centre-padded to 240x240x155, flipped in-plane and cast to
        uint8.

    Raises:
        gr.Error: on any failure, so Gradio shows the message in the UI. The
            previous behaviour returned the message as if it were a file path.
    """
    try:
        # Open the uploaded archive and hand it to the model.
        with zipfile.ZipFile(zip_file) as zip_ref:
            # NOTE(review): relies on the Segformer3DBRATS2021 module exposing
            # a module-level function literally named `__call__`; modules are
            # not callable themselves.
            overlay = Segformer3DBRATS2021.__call__(zip_ref)

        volume = _center_pad(overlay)
        # Flip axes 0 and 1 of every slice along axis 2 (a 180-degree in-plane
        # rotation), vectorised over all slices at once. Presumably this
        # matches the orientation of the input scans; confirm against the
        # data loader.
        volume = np.flip(volume, axis=(0, 1)).astype(np.uint8)

        # Return the path to the saved NIfTI file.
        return _save_nifti(volume)
    except Exception as e:
        raise gr.Error(str(e)) from e


def main():
    """Build and launch the Gradio interface for 3D brain tumour segmentation."""
    # Define the Gradio interface.
    inputs = gr.File(label="Upload a ZIP file containing MRI modalities (flair, t1, t1ce, t2)")
    outputs = gr.File(label="Segmentation Result (.nii)")
    gr.Interface(
        fn=predict_segmentation,
        inputs=inputs,
        outputs=outputs,
        title="3D Brain Tumor Segmentation",
        description="Upload a ZIP file containing MRI modalities (flair, t1, t1ce, t2).",
        allow_flagging="never",
    ).launch(show_error=True)


if __name__ == '__main__':
    main()