# NhatNam214 — commit 3b00cde: "api brats"
import zipfile
import nibabel as nib
import numpy as np
import gradio as gr
import Segformer3DBRATS2021 # Giả sử bạn đã định nghĩa mô hình này ở đâu đó.
import torch
import torch.nn.functional as F
from io import BytesIO
import tempfile
def predict_segmentation(zip_file):
    """Segment a zipped BraTS MRI case with Segformer3D and save it as NIfTI.

    The archive is passed to ``Segformer3DBRATS2021.__call__``. The returned
    4-D overlay then goes through these steps:

    * its axis order is reversed;
    * it is centred in a 240 x 240 x 155 grid (zero-padded, or cropped where
      larger);
    * its first two axes are flipped;
    * it is cast to ``uint8``;
    * it is written with an identity affine to a new temporary ``.nii`` file.

    Args:
        zip_file: The value delivered by the ``gr.File`` input. This can be
            anything ``zipfile.ZipFile`` accepts: a path or a binary
            file-like object.

    Returns:
        str: Path of the saved ``.nii`` file. The file is created with
        ``delete=False`` so it outlives this call and can be served; nothing
        here removes it.

    Raises:
        gr.Error: Wraps any failure so the Gradio UI displays the message.
    """
    target_shape = (240, 240, 155)
    try:
        with zipfile.ZipFile(zip_file) as zip_ref:
            # NOTE(review): the model package is expected to expose a
            # module-level ``__call__`` function taking the open archive.
            overlay = Segformer3DBRATS2021.__call__(zip_ref)

        # Reverse the axis order and add a batch dim. The padding below
        # treats dims 2-4 as spatial, which presumes the model returns a
        # channels-last (D, H, W, C) array -- TODO confirm.
        overlay_all_slices = np.transpose(overlay, (3, 2, 1, 0))
        overlay_tensor = torch.tensor(overlay_all_slices, dtype=torch.float32).unsqueeze(0)

        # Centre each of dims 2-4 inside target_shape. A negative amount
        # makes F.pad crop instead of pad.
        pad_pairs = []
        for target, current in zip(target_shape, overlay_tensor.shape[2:]):
            before = (target - current) // 2
            pad_pairs.append((before, target - current - before))
        # F.pad expects (last_before, last_after, ..., first_before, first_after).
        pad_spec = tuple(amount for pair in reversed(pad_pairs) for amount in pair)
        padded_tensor = F.pad(overlay_tensor, pad_spec, value=0)
        if tuple(padded_tensor.shape[2:]) != target_shape:
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            raise ValueError(
                f"Expected shape {target_shape}, got {tuple(padded_tensor.shape[2:])}"
            )

        # Move dim 1 to the end, drop the batch dim, and convert to NumPy.
        padded_slices = padded_tensor.permute(0, 2, 3, 4, 1).squeeze(0).numpy()
        # Flip every slice along axis 2 up-down and left-right in one
        # vectorised call. This gives the same result as applying
        # flipud(fliplr(...)) to each slice.
        padded_slices = np.flip(padded_slices, axis=(0, 1)).astype(np.uint8)

        nii_image = nib.Nifti1Image(padded_slices, np.eye(4))

        # Only reserve a unique path here, and write after the handle is
        # closed: re-opening a still-open NamedTemporaryFile by name fails
        # on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.nii') as temp_file:
            nii_file_path = temp_file.name
        nib.save(nii_image, nii_file_path)
        return nii_file_path
    except Exception as e:
        # Returning str(e) would feed the message to the gr.File output as if
        # it were a file path. gr.Error makes Gradio show it to the user.
        raise gr.Error(str(e)) from e
def main():
    """Build the Gradio interface around ``predict_segmentation`` and serve it."""
    upload = gr.File(label="Upload a ZIP file containing MRI modalities (flair, t1, t1ce, t2)")
    result = gr.File(label="Segmentation Result (.nii)")

    demo = gr.Interface(
        fn=predict_segmentation,
        inputs=upload,
        outputs=result,
        title="3D Brain Tumor Segmentation",
        description="Upload a ZIP file containing MRI modalities (flair, t1, t1ce, t2).",
        allow_flagging="never",
    )
    # show_error=True asks Gradio to display handler errors in the browser.
    demo.launch(show_error=True)
if __name__ == '__main__':
    # Start the app only when run as a script, not when imported.
    main()