Commit · 3b00cde
Parent(s): 045b86f

api brats

Files changed:
- Segformer3D.py (+632, -0)
- Segformer3DBRATS2021.py (+163, -0)
- app.py (+69, -0)
- requirements.txt (+9, -0)
Segformer3D.py
ADDED
@@ -0,0 +1,632 @@
import torch
import math
import copy
from torch import nn
from einops import rearrange
from functools import partial


def build_segformer3d_model(config=None):
    model = SegFormer3D(
        in_channels=config["model_parameters"]["in_channels"],
        sr_ratios=config["model_parameters"]["sr_ratios"],
        embed_dims=config["model_parameters"]["embed_dims"],
        patch_kernel_size=config["model_parameters"]["patch_kernel_size"],
        patch_stride=config["model_parameters"]["patch_stride"],
        patch_padding=config["model_parameters"]["patch_padding"],
        mlp_ratios=config["model_parameters"]["mlp_ratios"],
        num_heads=config["model_parameters"]["num_heads"],
        depths=config["model_parameters"]["depths"],
        decoder_head_embedding_dim=config["model_parameters"][
            "decoder_head_embedding_dim"
        ],
        num_classes=config["model_parameters"]["num_classes"],
        decoder_dropout=config["model_parameters"]["decoder_dropout"],
    )
    return model


class SegFormer3D(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        sr_ratios: list = [4, 2, 1, 1],
        embed_dims: list = [32, 64, 160, 256],
        patch_kernel_size: list = [7, 3, 3, 3],
        patch_stride: list = [4, 2, 2, 2],
        patch_padding: list = [3, 1, 1, 1],
        mlp_ratios: list = [4, 4, 4, 4],
        num_heads: list = [1, 2, 5, 8],
        depths: list = [2, 2, 2, 2],
        decoder_head_embedding_dim: int = 256,
        num_classes: int = 3,
        decoder_dropout: float = 0.0,
    ):
        """
        in_channels: number of input channels
        img_volume_dim: spatial resolution of the image volume (Depth, Width, Height)
        sr_ratios: the rates at which to down sample the sequence length of the embedded patch
        embed_dims: hidden size of the PatchEmbedded input
        patch_kernel_size: kernel size for the convolution in the patch embedding module
        patch_stride: stride for the convolution in the patch embedding module
        patch_padding: padding for the convolution in the patch embedding module
        mlp_ratios: the rate at which the projection dim of the hidden_state is increased in the mlp
        num_heads: number of attention heads
        depths: number of attention layers
        decoder_head_embedding_dim: projection dimension of the mlp layer in the all-mlp-decoder module
        num_classes: number of output channels of the network
        decoder_dropout: dropout rate of the concatenated feature maps
        """
        super().__init__()
        self.segformer_encoder = MixVisionTransformer(
            in_channels=in_channels,
            sr_ratios=sr_ratios,
            embed_dims=embed_dims,
            patch_kernel_size=patch_kernel_size,
            patch_stride=patch_stride,
            patch_padding=patch_padding,
            mlp_ratios=mlp_ratios,
            num_heads=num_heads,
            depths=depths,
        )
        # the decoder takes the feature maps in reversed order
        reversed_embed_dims = embed_dims[::-1]
        self.segformer_decoder = SegFormerDecoderHead(
            input_feature_dims=reversed_embed_dims,
            decoder_head_embedding_dim=decoder_head_embedding_dim,
            num_classes=num_classes,
            dropout=decoder_dropout,
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm3d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Conv3d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # embed the input
        x = self.segformer_encoder(x)
        # unpack the multi-scale features generated by the transformer encoder
        c1 = x[0]
        c2 = x[1]
        c3 = x[2]
        c4 = x[3]
        # decode the embedded features
        x = self.segformer_decoder(c1, c2, c3, c4)
        return x

# ----------------------------------------------------- encoder -----------------------------------------------------
class PatchEmbedding(nn.Module):
    def __init__(
        self,
        in_channel: int = 4,
        embed_dim: int = 768,
        kernel_size: int = 7,
        stride: int = 4,
        padding: int = 3,
    ):
        """
        in_channel: number of channels in the input volume
        embed_dim: embedding dimension of the patch
        """
        super().__init__()
        self.patch_embeddings = nn.Conv3d(
            in_channel,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # standard patch embedding
        patches = self.patch_embeddings(x)
        patches = patches.flatten(2).transpose(1, 2)
        patches = self.norm(patches)
        return patches


class SelfAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 8,
        sr_ratio: int = 2,
        qkv_bias: bool = False,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
    ):
        """
        embed_dim: hidden size of the PatchEmbedded input
        num_heads: number of attention heads
        sr_ratio: the rate at which to down sample the sequence length of the embedded patch
        qkv_bias: whether or not the linear projection has bias
        attn_dropout: the dropout rate of the attention component
        proj_dropout: the dropout rate of the final linear projection
        """
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), "Embedding dim should be divisible by number of heads!"

        self.num_heads = num_heads
        # embedding dimension of each attention head
        self.attention_head_dim = embed_dim // num_heads

        # The same input is used to generate the query, key, and value,
        # (batch_size, num_patches, hidden_size) -> (batch_size, num_patches, attention_head_size)
        self.query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=qkv_bias)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(proj_dropout)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv3d(
                embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio
            )
            self.sr_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # (batch_size, num_patches, hidden_size)
        B, N, C = x.shape

        # (batch_size, num_head, sequence_length, embed_dim)
        q = (
            self.query(x)
            .reshape(B, N, self.num_heads, self.attention_head_dim)
            .permute(0, 2, 1, 3)
        )

        if self.sr_ratio > 1:
            n = cube_root(N)
            # (batch_size, sequence_length, embed_dim) -> (batch_size, embed_dim, patch_D, patch_H, patch_W)
            x_ = x.permute(0, 2, 1).reshape(B, C, n, n, n)
            # (batch_size, embed_dim, patch_D, patch_H, patch_W) -> (batch_size, embed_dim, patch_D/sr_ratio, patch_H/sr_ratio, patch_W/sr_ratio)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            # (batch_size, embed_dim, patch_D/sr_ratio, patch_H/sr_ratio, patch_W/sr_ratio) -> (batch_size, sequence_length, embed_dim)
            # normalize the reduced sequence
            x_ = self.sr_norm(x_)
            # (batch_size, num_patches, hidden_size)
            kv = (
                self.key_value(x_)
                .reshape(B, -1, 2, self.num_heads, self.attention_head_dim)
                .permute(2, 0, 3, 1, 4)
            )
            # (2, batch_size, num_heads, num_sequence, attention_head_dim)
        else:
            # (batch_size, num_patches, hidden_size)
            kv = (
                self.key_value(x)
                .reshape(B, -1, 2, self.num_heads, self.attention_head_dim)
                .permute(2, 0, 3, 1, 4)
            )
            # (2, batch_size, num_heads, num_sequence, attention_head_dim)

        k, v = kv[0], kv[1]

        attention_score = (q @ k.transpose(-2, -1)) / math.sqrt(self.num_heads)
        attention_prob = attention_score.softmax(dim=-1)
        attention_prob = self.attn_dropout(attention_prob)
        out = (attention_prob @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_dropout(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        mlp_ratio: int = 2,
        num_heads: int = 8,
        sr_ratio: int = 2,
        qkv_bias: bool = False,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
    ):
        """
        embed_dim: hidden size of the PatchEmbedded input
        mlp_ratio: the rate at which the projection dim of the embedded patch is increased in the _MLP component
        num_heads: number of attention heads
        sr_ratio: the rate at which to down sample the sequence length of the embedded patch
        qkv_bias: whether or not the linear projection has bias
        attn_dropout: the dropout rate of the attention component
        proj_dropout: the dropout rate of the final linear projection
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention = SelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            sr_ratio=sr_ratio,
            qkv_bias=qkv_bias,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = _MLP(in_feature=embed_dim, mlp_ratio=mlp_ratio, dropout=0.0)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class MixVisionTransformer(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        sr_ratios: list = [8, 4, 2, 1],
        embed_dims: list = [64, 128, 320, 512],
        patch_kernel_size: list = [7, 3, 3, 3],
        patch_stride: list = [4, 2, 2, 2],
        patch_padding: list = [3, 1, 1, 1],
        mlp_ratios: list = [2, 2, 2, 2],
        num_heads: list = [1, 2, 5, 8],
        depths: list = [2, 2, 2, 2],
    ):
        """
        in_channels: number of the input channels
        img_volume_dim: spatial resolution of the image volume (Depth, Width, Height)
        sr_ratios: the rates at which to down sample the sequence length of the embedded patch
        embed_dims: hidden size of the PatchEmbedded input
        patch_kernel_size: kernel size for the convolution in the patch embedding module
        patch_stride: stride for the convolution in the patch embedding module
        patch_padding: padding for the convolution in the patch embedding module
        mlp_ratios: the rate at which the projection dim of the hidden_state is increased in the mlp
        num_heads: number of attention heads
        depths: number of attention layers
        """
        super().__init__()

        # patch embedding at each pyramid level
        self.embed_1 = PatchEmbedding(
            in_channel=in_channels,
            embed_dim=embed_dims[0],
            kernel_size=patch_kernel_size[0],
            stride=patch_stride[0],
            padding=patch_padding[0],
        )
        self.embed_2 = PatchEmbedding(
            in_channel=embed_dims[0],
            embed_dim=embed_dims[1],
            kernel_size=patch_kernel_size[1],
            stride=patch_stride[1],
            padding=patch_padding[1],
        )
        self.embed_3 = PatchEmbedding(
            in_channel=embed_dims[1],
            embed_dim=embed_dims[2],
            kernel_size=patch_kernel_size[2],
            stride=patch_stride[2],
            padding=patch_padding[2],
        )
        self.embed_4 = PatchEmbedding(
            in_channel=embed_dims[2],
            embed_dim=embed_dims[3],
            kernel_size=patch_kernel_size[3],
            stride=patch_stride[3],
            padding=patch_padding[3],
        )

        # block 1
        self.tf_block1 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[0],
                    num_heads=num_heads[0],
                    mlp_ratio=mlp_ratios[0],
                    sr_ratio=sr_ratios[0],
                    qkv_bias=True,
                )
                for _ in range(depths[0])
            ]
        )
        self.norm1 = nn.LayerNorm(embed_dims[0])

        # block 2
        self.tf_block2 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[1],
                    num_heads=num_heads[1],
                    mlp_ratio=mlp_ratios[1],
                    sr_ratio=sr_ratios[1],
                    qkv_bias=True,
                )
                for _ in range(depths[1])
            ]
        )
        self.norm2 = nn.LayerNorm(embed_dims[1])

        # block 3
        self.tf_block3 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[2],
                    num_heads=num_heads[2],
                    mlp_ratio=mlp_ratios[2],
                    sr_ratio=sr_ratios[2],
                    qkv_bias=True,
                )
                for _ in range(depths[2])
            ]
        )
        self.norm3 = nn.LayerNorm(embed_dims[2])

        # block 4
        self.tf_block4 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[3],
                    num_heads=num_heads[3],
                    mlp_ratio=mlp_ratios[3],
                    sr_ratio=sr_ratios[3],
                    qkv_bias=True,
                )
                for _ in range(depths[3])
            ]
        )
        self.norm4 = nn.LayerNorm(embed_dims[3])

    def forward(self, x):
        out = []
        # at each stage these are the mappings:
        # (batch_size, num_patches, hidden_state)
        # (num_patches,) -> (D, H, W)
        # (batch_size, num_patches, hidden_state) -> (batch_size, hidden_state, D, H, W)

        # stage 1
        x = self.embed_1(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block1):
            x = blk(x)
        x = self.norm1(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        # stage 2
        x = self.embed_2(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block2):
            x = blk(x)
        x = self.norm2(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        # stage 3
        x = self.embed_3(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block3):
            x = blk(x)
        x = self.norm3(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        # stage 4
        x = self.embed_4(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block4):
            x = blk(x)
        x = self.norm4(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        return out


class _MLP(nn.Module):
    def __init__(self, in_feature, mlp_ratio=2, dropout=0.0):
        super().__init__()
        out_feature = mlp_ratio * in_feature
        self.fc1 = nn.Linear(in_feature, out_feature)
        self.dwconv = DWConv(dim=out_feature)
        self.fc2 = nn.Linear(out_feature, in_feature)
        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv3d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        # added batchnorm (remove it?)
        self.bn = nn.BatchNorm3d(dim)

    def forward(self, x):
        B, N, C = x.shape
        # (batch, patch_cube, hidden_size) -> (batch, hidden_size, D, H, W)
        # assuming D = H = W, i.e. the cube root of the number of patches is an integer!
        n = cube_root(N)
        x = x.transpose(1, 2).view(B, C, n, n, n)
        x = self.dwconv(x)
        # added batchnorm (remove it?)
        x = self.bn(x)
        x = x.flatten(2).transpose(1, 2)
        return x


###################################################################################
def cube_root(n):
    return round(math.pow(n, (1 / 3)))


###################################################################################
# ----------------------------------------------------- decoder -----------------------------------------------------
class MLP_(nn.Module):
    """
    Linear Embedding
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)
        self.bn = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2).contiguous()
        x = self.proj(x)
        # added layernorm (remove it?)
        x = self.bn(x)
        return x


###################################################################################
class SegFormerDecoderHead(nn.Module):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
    """

    def __init__(
        self,
        input_feature_dims: list = [512, 320, 128, 64],
        decoder_head_embedding_dim: int = 256,
        num_classes: int = 3,
        dropout: float = 0.0,
    ):
        """
        input_feature_dims: list of the output feature channels generated by the transformer encoder
        decoder_head_embedding_dim: projection dimension of the mlp layer in the all-mlp-decoder module
        num_classes: number of the output channels
        dropout: dropout rate of the concatenated feature maps
        """
        super().__init__()
        self.linear_c4 = MLP_(
            input_dim=input_feature_dims[0],
            embed_dim=decoder_head_embedding_dim,
        )
        self.linear_c3 = MLP_(
            input_dim=input_feature_dims[1],
            embed_dim=decoder_head_embedding_dim,
        )
        self.linear_c2 = MLP_(
            input_dim=input_feature_dims[2],
            embed_dim=decoder_head_embedding_dim,
        )
        self.linear_c1 = MLP_(
            input_dim=input_feature_dims[3],
            embed_dim=decoder_head_embedding_dim,
        )
        # convolution module to combine the feature maps generated by the mlps
        self.linear_fuse = nn.Sequential(
            nn.Conv3d(
                in_channels=4 * decoder_head_embedding_dim,
                out_channels=decoder_head_embedding_dim,
                kernel_size=1,
                stride=1,
                bias=False,
            ),
            nn.BatchNorm3d(decoder_head_embedding_dim),
            nn.ReLU(),
        )
        self.dropout = nn.Dropout(dropout)

        # final linear projection layer
        self.linear_pred = nn.Conv3d(
            decoder_head_embedding_dim, num_classes, kernel_size=1
        )

        # the segformer decoder generates the final decoded feature map at 1/4 of the original input volume size
        self.upsample_volume = nn.Upsample(
            scale_factor=4.0, mode="trilinear", align_corners=False
        )

    def forward(self, c1, c2, c3, c4):
        ############## _MLP decoder on C1-C4 ###########
        n, _, _, _, _ = c4.shape

        _c4 = (
            self.linear_c4(c4)
            .permute(0, 2, 1)
            .reshape(n, -1, c4.shape[2], c4.shape[3], c4.shape[4])
            .contiguous()
        )
        _c4 = torch.nn.functional.interpolate(
            _c4,
            size=c1.size()[2:],
            mode="trilinear",
            align_corners=False,
        )

        _c3 = (
            self.linear_c3(c3)
            .permute(0, 2, 1)
            .reshape(n, -1, c3.shape[2], c3.shape[3], c3.shape[4])
            .contiguous()
        )
        _c3 = torch.nn.functional.interpolate(
            _c3,
            size=c1.size()[2:],
            mode="trilinear",
            align_corners=False,
        )

        _c2 = (
            self.linear_c2(c2)
            .permute(0, 2, 1)
            .reshape(n, -1, c2.shape[2], c2.shape[3], c2.shape[4])
            .contiguous()
        )
        _c2 = torch.nn.functional.interpolate(
            _c2,
            size=c1.size()[2:],
            mode="trilinear",
            align_corners=False,
        )

        _c1 = (
            self.linear_c1(c1)
            .permute(0, 2, 1)
            .reshape(n, -1, c1.shape[2], c1.shape[3], c1.shape[4])
            .contiguous()
        )

        _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

        x = self.dropout(_c)
        x = self.linear_pred(x)
        x = self.upsample_volume(x)
        return x
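Note (not part of the commit): a minimal sanity-check sketch for the model above, assuming a BraTS-style input of 4 modalities at 128x128x128 voxels (the crop size used by the inference script below). With the default strides, the prediction comes back at the input resolution with 3 output channels.

import torch
from Segformer3D import SegFormer3D

model = SegFormer3D()  # defaults: in_channels=4, num_classes=3
dummy = torch.randn(1, 4, 128, 128, 128)  # (batch, modalities, D, H, W)
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # expected: torch.Size([1, 3, 128, 128, 128])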
Segformer3DBRATS2021.py
ADDED
@@ -0,0 +1,163 @@
import os
import sys
import torch
import nibabel
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import animation
from monai.data import MetaTensor
from multiprocessing import Process, Pool
from sklearn.preprocessing import MinMaxScaler
import nibabel as nib
import gdown

import io
from monai.transforms import (
    Orientation,
    EnsureType,
    ConvertToMultiChannelBasedOnBratsClasses,
)
from Segformer3D import SegFormer3D


def predict_from_folder(model, zip_ref, device, D, H, W):
    """
    Predict a segmentation mask from an archive containing the MRI files: flair, t1, t1ce, t2.

    Args:
        model: The loaded segmentation model.
        zip_ref: Zip file containing the MRI files.
        device: Device to run the model on ("cuda" or "cpu").
        D, H, W: Input size after cropping.

    Returns:
        prediction: Predicted segmentation mask (numpy array).
        inputs_rgb: Input data rescaled to [0, 255] for color display.
    """
    MRI_TYPE = ["flair", "t1", "t1ce", "t2"]

    def load_nii_from_bytes(data_bytes):
        """Load a NIfTI file from bytes."""
        file_like = io.BytesIO(data_bytes)
        return nib.Nifti1Image.from_file_map({'header': nib.FileHolder(fileobj=file_like),
                                              'image': nib.FileHolder(fileobj=file_like)})

    def normalize(x):
        """Normalize the data to [0, 1] and keep the original min and max."""
        min_val = np.min(x)
        max_val = np.max(x)
        scaler = MinMaxScaler(feature_range=(0, 1))
        normalized_1D_array = scaler.fit_transform(x.reshape(-1, x.shape[-1]))
        return normalized_1D_array.reshape(x.shape), min_val, max_val

    def denormalize_to_rgb(x, min_val, max_val):
        """Map data from [0, 1] back to [0, 255]."""
        return ((x * (max_val - min_val)) + min_val).clip(0, 255).astype(np.uint8)

    def orient(x, affine):
        """Reorient the volume to the RAS coordinate convention."""
        meta_tensor = MetaTensor(x=x, affine=affine)
        oriented_tensor = Orientation(axcodes="RAS")(meta_tensor)
        return EnsureType(data_type="numpy", track_meta=False)(oriented_tensor)

    def crop_brats2021_zero_pixels(x):
        """Center-crop the volume to (D, H, W)."""
        H_start = (x.shape[1] - H) // 2
        W_start = (x.shape[2] - W) // 2
        D_start = (x.shape[3] - D) // 2
        return x[:, H_start:H_start + H, W_start:W_start + W, D_start:D_start + D]

    def preprocess_modality(zip_ref, mri_type):
        """Preprocess a single modality."""
        extracted_files = zip_ref.namelist()
        nii_files = [f for f in extracted_files if f.lower().endswith(f'{mri_type}.nii')]
        if not nii_files:
            raise FileNotFoundError(f"No files ending with {mri_type}.nii found.")

        nii_file = nii_files[0]
        data_bytes = zip_ref.read(nii_file)
        nii_image = load_nii_from_bytes(data_bytes)

        data = nii_image.get_fdata()
        affine = nii_image.affine
        data, min_val, max_val = normalize(data)
        data = data[np.newaxis, ...]
        data = orient(data, affine)
        data = crop_brats2021_zero_pixels(data)
        return data, min_val, max_val

    # preprocess each modality
    modalities = []
    min_max_values = []  # store the min and max of each modality
    for mri_type in MRI_TYPE:
        modality, min_val, max_val = preprocess_modality(zip_ref, mri_type)
        modalities.append(modality)
        min_max_values.append((min_val, max_val))

    inputs = np.concatenate(modalities, axis=0)  # (4, D, H, W)
    inputs = torch.tensor(inputs).unsqueeze(0).to(device).float()

    # run the model
    model.eval()
    with torch.no_grad():
        logits = model(inputs)
        probabilities = torch.sigmoid(logits)
        prediction = (probabilities > 0.5).int()
        inputs_rgb = (inputs.squeeze(0).cpu().numpy() * 255).astype(np.int32)
    return prediction.squeeze(0).cpu().numpy(), inputs_rgb


def load_model(checkpoint_path, device):
    model = SegFormer3D()
    model = model.to(device)
    # model = torch.nn.DataParallel(model)
    checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.eval()
    return model


def overlay_mask(modalities, prediction):
    # assumes prediction has shape (D, H, W, 3) and modalities has shape (D, H, W, C)
    D, H, W = modalities.shape[:3]

    # arrays that collect the final overlay images
    overlay_all_slices = []
    final_masks = []
    flair_slice_colors = []
    for slice_idx in range(D):
        # take the flair modality and the prediction for this slice
        flair_slice = modalities[slice_idx, :, :, 0]  # (H, W) - flair modality
        prediction_slice = prediction[slice_idx, :, :, :]  # (H, W, 3)

        # split into the WT, TC, ET masks
        wt_mask = prediction_slice[:, :, 1]  # channel 2: WT
        tc_mask = prediction_slice[:, :, 0]  # channel 1: TC
        et_mask = prediction_slice[:, :, 2]  # channel 3: ET

        # stack the channels with priority ET > TC > WT
        final_mask = np.zeros_like(wt_mask)

        final_mask[et_mask > 0] = 3  # enhancing tumor (ET)
        final_mask[(tc_mask > 0) & (final_mask == 0)] = 2  # tumor core (TC)
        final_mask[(wt_mask > 0) & (final_mask == 0)] = 1  # whole tumor (WT)
        final_masks.append(final_mask)
        # turn flair_slice into a 3-channel color image
        flair_slice_color = np.stack((flair_slice,) * 3, axis=-1)  # (H, W, 3)
        flair_slice_colors.append(np.copy(flair_slice_color))
        # overlay the regions with RGB colors
        flair_slice_color[final_mask == 1] = [255, 255, 0]  # WT - yellow
        flair_slice_color[final_mask == 2] = [0, 255, 255]  # TC - cyan
        flair_slice_color[final_mask == 3] = [255, 0, 255]  # ET - magenta

        # append the colored overlay to the results
        overlay_all_slices.append(flair_slice_color)
    return np.stack(overlay_all_slices)


def __call__(zip_ref):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    url = "https://drive.google.com/uc?id=1qtWBuwE8PVb-_dzLbl_ySEPX6fNtEGBS"
    checkpoint_path = "Segformer3D_Brats2021_epoch_50_model.pth"
    if not os.path.exists(checkpoint_path):
        gdown.download(url, checkpoint_path, quiet=False)
    model = load_model(checkpoint_path, device)
    prediction, modalities = predict_from_folder(model, zip_ref, device, D=128, H=128, W=128)
    modalities = np.transpose(modalities, (3, 2, 1, 0))
    prediction = np.transpose(prediction, (3, 2, 1, 0))
    overlay = overlay_mask(modalities, prediction)
    return overlay.astype(np.uint8)
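Note (not part of the commit): a sketch of calling the inference helpers above directly, outside the Gradio app. "brats_case.zip" is a placeholder archive assumed to contain the flair, t1, t1ce, and t2 NIfTI files; the checkpoint filename matches the one downloaded in __call__.

import zipfile
import torch
from Segformer3DBRATS2021 import load_model, predict_from_folder

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("Segformer3D_Brats2021_epoch_50_model.pth", device)
with zipfile.ZipFile("brats_case.zip") as zip_ref:
    prediction, inputs_rgb = predict_from_folder(model, zip_ref, device, D=128, H=128, W=128)
# prediction: (3, 128, 128, 128) binary masks (channels: TC, WT, ET)
# inputs_rgb: (4, 128, 128, 128) input volumes rescaled for display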
app.py
ADDED
@@ -0,0 +1,69 @@
import zipfile
import nibabel as nib
import numpy as np
import gradio as gr
import Segformer3DBRATS2021  # assumes this inference module is defined elsewhere in the repo
import torch
import torch.nn.functional as F
from io import BytesIO
import tempfile


def predict_segmentation(zip_file):
    """
    Process a zip file of MRI data, run the Segformer3D model, and return the resulting .nii file.
    """
    try:
        # open the zip file
        with zipfile.ZipFile(zip_file) as zip_ref:

            overlay = Segformer3DBRATS2021.__call__(zip_ref)
            overlay_all_slices = np.transpose(overlay, (3, 2, 1, 0))
            overlay_tensor = torch.tensor(overlay_all_slices, dtype=torch.float32).unsqueeze(0)
            target_shape = (240, 240, 155)
            # compute the padding needed to reach the target shape
            z_pad_before = (target_shape[0] - overlay_tensor.shape[2]) // 2
            z_pad_after = target_shape[0] - overlay_tensor.shape[2] - z_pad_before

            y_pad_before = (target_shape[1] - overlay_tensor.shape[3]) // 2
            y_pad_after = target_shape[1] - overlay_tensor.shape[3] - y_pad_before

            x_pad_before = (target_shape[2] - overlay_tensor.shape[4]) // 2
            x_pad_after = target_shape[2] - overlay_tensor.shape[4] - x_pad_before

            # apply zero (black) padding
            padded_tensor = F.pad(overlay_tensor, (x_pad_before, x_pad_after, y_pad_before, y_pad_after, z_pad_before, z_pad_after), value=0)
            assert padded_tensor.shape[2:] == target_shape, f"Expected shape {target_shape}, got {padded_tensor.shape[2:]}"
            padded_tensor = padded_tensor.permute(0, 2, 3, 4, 1)
            padded_slices = padded_tensor.squeeze(0).numpy()

            for i in range(padded_slices.shape[2]):
                padded_slices[:, :, i, :] = np.flipud(np.fliplr(padded_slices[:, :, i, :]))
            padded_slices = padded_slices.astype(np.uint8)
            affine = np.eye(4)
            nii_image = nib.Nifti1Image(padded_slices, affine)
            with tempfile.NamedTemporaryFile(delete=False, suffix='.nii') as temp_file:
                nii_file_path = temp_file.name
            nib.save(nii_image, nii_file_path)

            # return the path to the saved NIfTI file
            return nii_file_path

    except Exception as e:
        return str(e)


def main():
    # define the Gradio interface
    inputs = gr.File(label="Upload a ZIP file containing MRI modalities (flair, t1, t1ce, t2)")
    outputs = gr.File(label="Segmentation Result (.nii)")

    gr.Interface(
        fn=predict_segmentation,
        inputs=inputs,
        outputs=outputs,
        title="3D Brain Tumor Segmentation",
        description="Upload a ZIP file containing MRI modalities (flair, t1, t1ce, t2).",
        allow_flagging="never",
    ).launch(show_error=True)


if __name__ == '__main__':
    main()
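Note (not part of the commit): the Space serves this file directly, and running "python app.py" locally launches the same Gradio interface. The handler can also be exercised without the UI, as in this sketch ("brats_case.zip" is a placeholder path):

from app import predict_segmentation

nii_path = predict_segmentation("brats_case.zip")
print(nii_path)  # temporary .nii file holding the padded 240x240x155x3 RGB overlay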
requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch
numpy
nibabel
tqdm
matplotlib
monai
scikit-learn
gdown
gradio==5.8.0