NhatNam214 committed on
Commit 3b00cde · 1 Parent(s): 045b86f
Files changed (4)
  1. Segformer3D.py +632 -0
  2. Segformer3DBRATS2021.py +163 -0
  3. app.py +69 -0
  4. requirements.txt +9 -0
Segformer3D.py ADDED
@@ -0,0 +1,632 @@
+ import torch
+ import math
+ import copy
+ from torch import nn
+ from einops import rearrange
+ from functools import partial
+
+
+ def build_segformer3d_model(config=None):
+     model = SegFormer3D(
+         in_channels=config["model_parameters"]["in_channels"],
+         sr_ratios=config["model_parameters"]["sr_ratios"],
+         embed_dims=config["model_parameters"]["embed_dims"],
+         patch_kernel_size=config["model_parameters"]["patch_kernel_size"],
+         patch_stride=config["model_parameters"]["patch_stride"],
+         patch_padding=config["model_parameters"]["patch_padding"],
+         mlp_ratios=config["model_parameters"]["mlp_ratios"],
+         num_heads=config["model_parameters"]["num_heads"],
+         depths=config["model_parameters"]["depths"],
+         decoder_head_embedding_dim=config["model_parameters"][
+             "decoder_head_embedding_dim"
+         ],
+         num_classes=config["model_parameters"]["num_classes"],
+         decoder_dropout=config["model_parameters"]["decoder_dropout"],
+     )
+     return model
+
+
+ class SegFormer3D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int = 4,
+         sr_ratios: list = [4, 2, 1, 1],
+         embed_dims: list = [32, 64, 160, 256],
+         patch_kernel_size: list = [7, 3, 3, 3],
+         patch_stride: list = [4, 2, 2, 2],
+         patch_padding: list = [3, 1, 1, 1],
+         mlp_ratios: list = [4, 4, 4, 4],
+         num_heads: list = [1, 2, 5, 8],
+         depths: list = [2, 2, 2, 2],
+         decoder_head_embedding_dim: int = 256,
+         num_classes: int = 3,
+         decoder_dropout: float = 0.0,
+     ):
+         """
+         in_channels: number of input channels
+         sr_ratios: the rates at which to downsample the sequence length of the embedded patch
+         embed_dims: hidden size of the patch-embedded input
+         patch_kernel_size: kernel size for the convolution in the patch embedding module
+         patch_stride: stride for the convolution in the patch embedding module
+         patch_padding: padding for the convolution in the patch embedding module
+         mlp_ratios: the rates at which the projection dim of the hidden state is expanded in the mlp
+         num_heads: number of attention heads
+         depths: number of attention layers
+         decoder_head_embedding_dim: projection dimension of the mlp layer in the all-mlp decoder module
+         num_classes: number of output channels of the network
+         decoder_dropout: dropout rate of the concatenated feature maps
+         """
+         super().__init__()
+         self.segformer_encoder = MixVisionTransformer(
+             in_channels=in_channels,
+             sr_ratios=sr_ratios,
+             embed_dims=embed_dims,
+             patch_kernel_size=patch_kernel_size,
+             patch_stride=patch_stride,
+             patch_padding=patch_padding,
+             mlp_ratios=mlp_ratios,
+             num_heads=num_heads,
+             depths=depths,
+         )
+         # the decoder takes the feature maps in reversed order
+         reversed_embed_dims = embed_dims[::-1]
+         self.segformer_decoder = SegFormerDecoderHead(
+             input_feature_dims=reversed_embed_dims,
+             decoder_head_embedding_dim=decoder_head_embedding_dim,
+             num_classes=num_classes,
+             dropout=decoder_dropout,
+         )
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.BatchNorm2d):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.BatchNorm3d):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+         elif isinstance(m, nn.Conv3d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+     def forward(self, x):
+         # embed the input
+         x = self.segformer_encoder(x)
+         # unpack the multi-scale features generated by the transformer encoder
+         c1 = x[0]
+         c2 = x[1]
+         c3 = x[2]
+         c4 = x[3]
+         # decode the embedded features
+         x = self.segformer_decoder(c1, c2, c3, c4)
+         return x
+
+
+ # ----------------------------------------------------- encoder -----------------------------------------------------
+ class PatchEmbedding(nn.Module):
+     def __init__(
+         self,
+         in_channel: int = 4,
+         embed_dim: int = 768,
+         kernel_size: int = 7,
+         stride: int = 4,
+         padding: int = 3,
+     ):
+         """
+         in_channel: number of channels in the input volume
+         embed_dim: embedding dimension of the patch
+         """
+         super().__init__()
+         self.patch_embeddings = nn.Conv3d(
+             in_channel,
+             embed_dim,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+         )
+         self.norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x):
+         # standard patch embedding
+         patches = self.patch_embeddings(x)
+         patches = patches.flatten(2).transpose(1, 2)
+         patches = self.norm(patches)
+         return patches
+
+
+ class SelfAttention(nn.Module):
+     def __init__(
+         self,
+         embed_dim: int = 768,
+         num_heads: int = 8,
+         sr_ratio: int = 2,
+         qkv_bias: bool = False,
+         attn_dropout: float = 0.0,
+         proj_dropout: float = 0.0,
+     ):
+         """
+         embed_dim: hidden size of the patch-embedded input
+         num_heads: number of attention heads
+         sr_ratio: the rate at which to downsample the sequence length of the embedded patch
+         qkv_bias: whether or not the linear projection has bias
+         attn_dropout: the dropout rate of the attention component
+         proj_dropout: the dropout rate of the final linear projection
+         """
+         super().__init__()
+         assert (
+             embed_dim % num_heads == 0
+         ), "Embedding dim should be divisible by number of heads!"
+
+         self.num_heads = num_heads
+         # embedding dimension of each attention head
+         self.attention_head_dim = embed_dim // num_heads
+
+         # the same input is used to generate the query, key, and value,
+         # (batch_size, num_patches, hidden_size) -> (batch_size, num_patches, attention_head_size)
+         self.query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
+         self.key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=qkv_bias)
+         self.attn_dropout = nn.Dropout(attn_dropout)
+         self.proj = nn.Linear(embed_dim, embed_dim)
+         self.proj_dropout = nn.Dropout(proj_dropout)
+
+         self.sr_ratio = sr_ratio
+         if sr_ratio > 1:
+             self.sr = nn.Conv3d(
+                 embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio
+             )
+             self.sr_norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x):
+         # (batch_size, num_patches, hidden_size)
+         B, N, C = x.shape
+
+         # (batch_size, num_heads, sequence_length, attention_head_dim)
+         q = (
+             self.query(x)
+             .reshape(B, N, self.num_heads, self.attention_head_dim)
+             .permute(0, 2, 1, 3)
+         )
+
+         if self.sr_ratio > 1:
+             n = cube_root(N)
+             # (batch_size, sequence_length, embed_dim) -> (batch_size, embed_dim, patch_D, patch_H, patch_W)
+             x_ = x.permute(0, 2, 1).reshape(B, C, n, n, n)
+             # (batch_size, embed_dim, patch_D, patch_H, patch_W) -> (batch_size, embed_dim, patch_D/sr_ratio, patch_H/sr_ratio, patch_W/sr_ratio)
+             x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
+             # (batch_size, embed_dim, patch_D/sr_ratio, patch_H/sr_ratio, patch_W/sr_ratio) -> (batch_size, sequence_length, embed_dim)
+             # normalize the reduced sequence
+             x_ = self.sr_norm(x_)
+             # (batch_size, num_patches, hidden_size)
+             kv = (
+                 self.key_value(x_)
+                 .reshape(B, -1, 2, self.num_heads, self.attention_head_dim)
+                 .permute(2, 0, 3, 1, 4)
+             )
+             # (2, batch_size, num_heads, num_sequence, attention_head_dim)
+         else:
+             # (batch_size, num_patches, hidden_size)
+             kv = (
+                 self.key_value(x)
+                 .reshape(B, -1, 2, self.num_heads, self.attention_head_dim)
+                 .permute(2, 0, 3, 1, 4)
+             )
+             # (2, batch_size, num_heads, num_sequence, attention_head_dim)
+
+         k, v = kv[0], kv[1]
+
+         attention_score = (q @ k.transpose(-2, -1)) / math.sqrt(self.num_heads)
+         attention_prob = attention_score.softmax(dim=-1)
+         attention_prob = self.attn_dropout(attention_prob)
+         out = (attention_prob @ v).transpose(1, 2).reshape(B, N, C)
+         out = self.proj(out)
+         out = self.proj_dropout(out)
+         return out
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(
+         self,
+         embed_dim: int = 768,
+         mlp_ratio: int = 2,
+         num_heads: int = 8,
+         sr_ratio: int = 2,
+         qkv_bias: bool = False,
+         attn_dropout: float = 0.0,
+         proj_dropout: float = 0.0,
+     ):
+         """
+         embed_dim: hidden size of the patch-embedded input
+         mlp_ratio: the rate at which the projection dim of the embedded patch is expanded in the _MLP component
+         num_heads: number of attention heads
+         sr_ratio: the rate at which to downsample the sequence length of the embedded patch
+         qkv_bias: whether or not the linear projection has bias
+         attn_dropout: the dropout rate of the attention component
+         proj_dropout: the dropout rate of the final linear projection
+         """
+         super().__init__()
+         self.norm1 = nn.LayerNorm(embed_dim)
+         self.attention = SelfAttention(
+             embed_dim=embed_dim,
+             num_heads=num_heads,
+             sr_ratio=sr_ratio,
+             qkv_bias=qkv_bias,
+             attn_dropout=attn_dropout,
+             proj_dropout=proj_dropout,
+         )
+         self.norm2 = nn.LayerNorm(embed_dim)
+         self.mlp = _MLP(in_feature=embed_dim, mlp_ratio=mlp_ratio, dropout=0.0)
+
+     def forward(self, x):
+         x = x + self.attention(self.norm1(x))
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ class MixVisionTransformer(nn.Module):
+     def __init__(
+         self,
+         in_channels: int = 4,
+         sr_ratios: list = [8, 4, 2, 1],
+         embed_dims: list = [64, 128, 320, 512],
+         patch_kernel_size: list = [7, 3, 3, 3],
+         patch_stride: list = [4, 2, 2, 2],
+         patch_padding: list = [3, 1, 1, 1],
+         mlp_ratios: list = [2, 2, 2, 2],
+         num_heads: list = [1, 2, 5, 8],
+         depths: list = [2, 2, 2, 2],
+     ):
+         """
+         in_channels: number of input channels
+         sr_ratios: the rates at which to downsample the sequence length of the embedded patch
+         embed_dims: hidden size of the patch-embedded input
+         patch_kernel_size: kernel size for the convolution in the patch embedding module
+         patch_stride: stride for the convolution in the patch embedding module
+         patch_padding: padding for the convolution in the patch embedding module
+         mlp_ratios: the rates at which the projection dim of the hidden state is expanded in the mlp
+         num_heads: number of attention heads
+         depths: number of attention layers
+         """
+         super().__init__()
+
+         # patch embedding at each pyramid level
+         self.embed_1 = PatchEmbedding(
+             in_channel=in_channels,
+             embed_dim=embed_dims[0],
+             kernel_size=patch_kernel_size[0],
+             stride=patch_stride[0],
+             padding=patch_padding[0],
+         )
+         self.embed_2 = PatchEmbedding(
+             in_channel=embed_dims[0],
+             embed_dim=embed_dims[1],
+             kernel_size=patch_kernel_size[1],
+             stride=patch_stride[1],
+             padding=patch_padding[1],
+         )
+         self.embed_3 = PatchEmbedding(
+             in_channel=embed_dims[1],
+             embed_dim=embed_dims[2],
+             kernel_size=patch_kernel_size[2],
+             stride=patch_stride[2],
+             padding=patch_padding[2],
+         )
+         self.embed_4 = PatchEmbedding(
+             in_channel=embed_dims[2],
+             embed_dim=embed_dims[3],
+             kernel_size=patch_kernel_size[3],
+             stride=patch_stride[3],
+             padding=patch_padding[3],
+         )
+
+         # block 1
+         self.tf_block1 = nn.ModuleList(
+             [
+                 TransformerBlock(
+                     embed_dim=embed_dims[0],
+                     num_heads=num_heads[0],
+                     mlp_ratio=mlp_ratios[0],
+                     sr_ratio=sr_ratios[0],
+                     qkv_bias=True,
+                 )
+                 for _ in range(depths[0])
+             ]
+         )
+         self.norm1 = nn.LayerNorm(embed_dims[0])
+
+         # block 2
+         self.tf_block2 = nn.ModuleList(
+             [
+                 TransformerBlock(
+                     embed_dim=embed_dims[1],
+                     num_heads=num_heads[1],
+                     mlp_ratio=mlp_ratios[1],
+                     sr_ratio=sr_ratios[1],
+                     qkv_bias=True,
+                 )
+                 for _ in range(depths[1])
+             ]
+         )
+         self.norm2 = nn.LayerNorm(embed_dims[1])
+
+         # block 3
+         self.tf_block3 = nn.ModuleList(
+             [
+                 TransformerBlock(
+                     embed_dim=embed_dims[2],
+                     num_heads=num_heads[2],
+                     mlp_ratio=mlp_ratios[2],
+                     sr_ratio=sr_ratios[2],
+                     qkv_bias=True,
+                 )
+                 for _ in range(depths[2])
+             ]
+         )
+         self.norm3 = nn.LayerNorm(embed_dims[2])
+
+         # block 4
+         self.tf_block4 = nn.ModuleList(
+             [
+                 TransformerBlock(
+                     embed_dim=embed_dims[3],
+                     num_heads=num_heads[3],
+                     mlp_ratio=mlp_ratios[3],
+                     sr_ratio=sr_ratios[3],
+                     qkv_bias=True,
+                 )
+                 for _ in range(depths[3])
+             ]
+         )
+         self.norm4 = nn.LayerNorm(embed_dims[3])
+
+     def forward(self, x):
+         out = []
+         # at each stage these are the mappings:
+         # (batch_size, num_patches, hidden_state)
+         # (num_patches,) -> (D, H, W)
+         # (batch_size, num_patches, hidden_state) -> (batch_size, hidden_state, D, H, W)
+
+         # stage 1
+         x = self.embed_1(x)
+         B, N, C = x.shape
+         n = cube_root(N)
+         for i, blk in enumerate(self.tf_block1):
+             x = blk(x)
+         x = self.norm1(x)
+         # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
+         x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
+         out.append(x)
+
+         # stage 2
+         x = self.embed_2(x)
+         B, N, C = x.shape
+         n = cube_root(N)
+         for i, blk in enumerate(self.tf_block2):
+             x = blk(x)
+         x = self.norm2(x)
+         # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
+         x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
+         out.append(x)
+
+         # stage 3
+         x = self.embed_3(x)
+         B, N, C = x.shape
+         n = cube_root(N)
+         for i, blk in enumerate(self.tf_block3):
+             x = blk(x)
+         x = self.norm3(x)
+         # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
+         x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
+         out.append(x)
+
+         # stage 4
+         x = self.embed_4(x)
+         B, N, C = x.shape
+         n = cube_root(N)
+         for i, blk in enumerate(self.tf_block4):
+             x = blk(x)
+         x = self.norm4(x)
+         # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
+         x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
+         out.append(x)
+
+         return out
+
+
+ class _MLP(nn.Module):
+     def __init__(self, in_feature, mlp_ratio=2, dropout=0.0):
+         super().__init__()
+         out_feature = mlp_ratio * in_feature
+         self.fc1 = nn.Linear(in_feature, out_feature)
+         self.dwconv = DWConv(dim=out_feature)
+         self.fc2 = nn.Linear(out_feature, in_feature)
+         self.act_fn = nn.GELU()
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.dwconv(x)
+         x = self.act_fn(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+         return x
+
+
+ class DWConv(nn.Module):
+     def __init__(self, dim=768):
+         super().__init__()
+         self.dwconv = nn.Conv3d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+         # added batchnorm (remove it ?)
+         self.bn = nn.BatchNorm3d(dim)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         # (batch, patch_cube, hidden_size) -> (batch, hidden_size, D, H, W)
+         # assuming D = H = W, i.e. the cube root of the patch count is an integer!
+         n = cube_root(N)
+         x = x.transpose(1, 2).view(B, C, n, n, n)
+         x = self.dwconv(x)
+         # added batchnorm (remove it ?)
+         x = self.bn(x)
+         x = x.flatten(2).transpose(1, 2)
+         return x
+
+
+ ###################################################################################
+ def cube_root(n):
+     return round(math.pow(n, (1 / 3)))
+
+
+ ###################################################################################
+ # ----------------------------------------------------- decoder -------------------
+ class MLP_(nn.Module):
+     """
+     Linear Embedding
+     """
+
+     def __init__(self, input_dim=2048, embed_dim=768):
+         super().__init__()
+         self.proj = nn.Linear(input_dim, embed_dim)
+         self.bn = nn.LayerNorm(embed_dim)
+
+     def forward(self, x):
+         x = x.flatten(2).transpose(1, 2).contiguous()
+         x = self.proj(x)
+         # added layernorm (remove it ?)
+         x = self.bn(x)
+         return x
+
+
+ ###################################################################################
+ class SegFormerDecoderHead(nn.Module):
+     """
+     SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
+     """
+
+     def __init__(
+         self,
+         input_feature_dims: list = [512, 320, 128, 64],
+         decoder_head_embedding_dim: int = 256,
+         num_classes: int = 3,
+         dropout: float = 0.0,
+     ):
+         """
+         input_feature_dims: list of the output feature channels generated by the transformer encoder
+         decoder_head_embedding_dim: projection dimension of the mlp layer in the all-mlp decoder module
+         num_classes: number of output channels
+         dropout: dropout rate of the concatenated feature maps
+         """
+         super().__init__()
+         self.linear_c4 = MLP_(
+             input_dim=input_feature_dims[0],
+             embed_dim=decoder_head_embedding_dim,
+         )
+         self.linear_c3 = MLP_(
+             input_dim=input_feature_dims[1],
+             embed_dim=decoder_head_embedding_dim,
+         )
+         self.linear_c2 = MLP_(
+             input_dim=input_feature_dims[2],
+             embed_dim=decoder_head_embedding_dim,
+         )
+         self.linear_c1 = MLP_(
+             input_dim=input_feature_dims[3],
+             embed_dim=decoder_head_embedding_dim,
+         )
+         # convolution module to combine the feature maps generated by the mlps
+         self.linear_fuse = nn.Sequential(
+             nn.Conv3d(
+                 in_channels=4 * decoder_head_embedding_dim,
+                 out_channels=decoder_head_embedding_dim,
+                 kernel_size=1,
+                 stride=1,
+                 bias=False,
+             ),
+             nn.BatchNorm3d(decoder_head_embedding_dim),
+             nn.ReLU(),
+         )
+         self.dropout = nn.Dropout(dropout)
+
+         # final linear projection layer
+         self.linear_pred = nn.Conv3d(
+             decoder_head_embedding_dim, num_classes, kernel_size=1
+         )
+
+         # the segformer decoder generates the final decoded feature map at 1/4 of the original input volume size
+         self.upsample_volume = nn.Upsample(
+             scale_factor=4.0, mode="trilinear", align_corners=False
+         )
+
+     def forward(self, c1, c2, c3, c4):
+         ############## _MLP decoder on C1-C4 ###########
+         n, _, _, _, _ = c4.shape
+
+         _c4 = (
+             self.linear_c4(c4)
+             .permute(0, 2, 1)
+             .reshape(n, -1, c4.shape[2], c4.shape[3], c4.shape[4])
+             .contiguous()
+         )
+         _c4 = torch.nn.functional.interpolate(
+             _c4,
+             size=c1.size()[2:],
+             mode="trilinear",
+             align_corners=False,
+         )
+
+         _c3 = (
+             self.linear_c3(c3)
+             .permute(0, 2, 1)
+             .reshape(n, -1, c3.shape[2], c3.shape[3], c3.shape[4])
+             .contiguous()
+         )
+         _c3 = torch.nn.functional.interpolate(
+             _c3,
+             size=c1.size()[2:],
+             mode="trilinear",
+             align_corners=False,
+         )
+
+         _c2 = (
+             self.linear_c2(c2)
+             .permute(0, 2, 1)
+             .reshape(n, -1, c2.shape[2], c2.shape[3], c2.shape[4])
+             .contiguous()
+         )
+         _c2 = torch.nn.functional.interpolate(
+             _c2,
+             size=c1.size()[2:],
+             mode="trilinear",
+             align_corners=False,
+         )
+
+         _c1 = (
+             self.linear_c1(c1)
+             .permute(0, 2, 1)
+             .reshape(n, -1, c1.shape[2], c1.shape[3], c1.shape[4])
+             .contiguous()
+         )
+
+         _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
+
+         x = self.dropout(_c)
+         x = self.linear_pred(x)
+         x = self.upsample_volume(x)
+         return x
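
For reference, a minimal usage sketch of the model added above (not part of the commit): it assumes a single BraTS-style input of 4 modalities at 128×128×128 voxels, which matches the constructor defaults, and it simply checks that the decoder returns a full-resolution map with one channel per class.

import torch
from Segformer3D import SegFormer3D

model = SegFormer3D()  # defaults: 4 input channels, 3 output classes
model.eval()

# dummy input volume: (batch, modalities, D, H, W)
x = torch.randn(1, 4, 128, 128, 128)
with torch.no_grad():
    y = model(x)

print(y.shape)  # expected: torch.Size([1, 3, 128, 128, 128])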
Segformer3DBRATS2021.py ADDED
@@ -0,0 +1,163 @@
+ import os
+ import sys
+ import torch
+ import nibabel
+ import numpy as np
+ from tqdm import tqdm
+ import matplotlib.pyplot as plt
+ from matplotlib import animation
+ from monai.data import MetaTensor
+ from multiprocessing import Process, Pool
+ from sklearn.preprocessing import MinMaxScaler
+ import nibabel as nib
+ import gdown
+
+ import io
+ from monai.transforms import (
+     Orientation,
+     EnsureType,
+     ConvertToMultiChannelBasedOnBratsClasses,
+ )
+ from Segformer3D import SegFormer3D
+
+
+ def predict_from_folder(model, zip_ref, device, D, H, W):
+     """
+     Predict a segmentation from an archive containing the MRI modalities: flair, t1, t1ce, t2.
+
+     Args:
+         model: the loaded segmentation model.
+         zip_ref: zip archive containing the MRI files.
+         device: device to run the model on ("cuda" or "cpu").
+         D, H, W: input size after cropping.
+
+     Returns:
+         prediction: predicted segmentation mask (numpy array).
+         inputs_rgb: input data rescaled to [0, 255] for color display.
+     """
+     MRI_TYPE = ["flair", "t1", "t1ce", "t2"]
+
+     def load_nii_from_bytes(data_bytes):
+         """Load a NIfTI file from raw bytes."""
+         file_like = io.BytesIO(data_bytes)
+         return nib.Nifti1Image.from_file_map({'header': nib.FileHolder(fileobj=file_like),
+                                               'image': nib.FileHolder(fileobj=file_like)})
+
+     def normalize(x):
+         """Normalize the data to [0, 1] and keep the original min and max."""
+         min_val = np.min(x)
+         max_val = np.max(x)
+         scaler = MinMaxScaler(feature_range=(0, 1))
+         normalized_1D_array = scaler.fit_transform(x.reshape(-1, x.shape[-1]))
+         return normalized_1D_array.reshape(x.shape), min_val, max_val
+
+     def denormalize_to_rgb(x, min_val, max_val):
+         """Map data from [0, 1] back to [0, 255]."""
+         return ((x * (max_val - min_val)) + min_val).clip(0, 255).astype(np.uint8)
+
+     def orient(x, affine):
+         """Reorient the volume to the RAS coordinate system."""
+         meta_tensor = MetaTensor(x=x, affine=affine)
+         oriented_tensor = Orientation(axcodes="RAS")(meta_tensor)
+         return EnsureType(data_type="numpy", track_meta=False)(oriented_tensor)
+
+     def crop_brats2021_zero_pixels(x):
+         """Center-crop the volume to (D, H, W)."""
+         H_start = (x.shape[1] - H) // 2
+         W_start = (x.shape[2] - W) // 2
+         D_start = (x.shape[3] - D) // 2
+         return x[:, H_start:H_start + H, W_start:W_start + W, D_start:D_start + D]
+
+     def preprocess_modality(zip_ref, mri_type):
+         """Preprocess a single modality."""
+         extracted_files = zip_ref.namelist()
+         nii_files = [f for f in extracted_files if f.lower().endswith(f'{mri_type}.nii')]
+         if not nii_files:
+             raise FileNotFoundError(f"No files ending with {mri_type}.nii found.")
+
+         nii_file = nii_files[0]
+         data_bytes = zip_ref.read(nii_file)
+         nii_image = load_nii_from_bytes(data_bytes)
+
+         data = nii_image.get_fdata()
+         affine = nii_image.affine
+         data, min_val, max_val = normalize(data)
+         data = data[np.newaxis, ...]
+         data = orient(data, affine)
+         data = crop_brats2021_zero_pixels(data)
+         return data, min_val, max_val
+
+     # preprocess each modality
+     modalities = []
+     min_max_values = []  # store min and max for each modality
+     for mri_type in MRI_TYPE:
+         modality, min_val, max_val = preprocess_modality(zip_ref, mri_type)
+         modalities.append(modality)
+         min_max_values.append((min_val, max_val))
+
+     inputs = np.concatenate(modalities, axis=0)  # (4, D, H, W)
+     inputs = torch.tensor(inputs).unsqueeze(0).to(device).float()
+
+     # run the model
+     model.eval()
+     with torch.no_grad():
+         logits = model(inputs)
+         probabilities = torch.sigmoid(logits)
+         prediction = (probabilities > 0.5).int()
+     inputs_rgb = (inputs.squeeze(0).cpu().numpy() * 255).astype(np.int32)
+     return prediction.squeeze(0).cpu().numpy(), inputs_rgb
+
+
+ def load_model(checkpoint_path, device):
+     model = SegFormer3D()
+     model = model.to(device)
+     # model = torch.nn.DataParallel(model)
+     checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
+     model.load_state_dict(checkpoint['model_state_dict'], strict=False)
+     model.eval()
+     return model
+
+
+ def overlay_mask(modalities, prediction):
+     # prediction has shape (D, H, W, 3) and modalities has shape (D, H, W, C)
+     D, H, W = modalities.shape[:3]
+
+     # arrays holding the final overlay images
+     overlay_all_slices = []
+     final_masks = []
+     flair_slice_colors = []
+     for slice_idx in range(D):
+         # take the flair modality and the prediction for this slice
+         flair_slice = modalities[slice_idx, :, :, 0]  # (H, W) - flair modality
+         prediction_slice = prediction[slice_idx, :, :, :]  # (H, W, 3)
+
+         # split the WT, TC and ET masks
+         wt_mask = prediction_slice[:, :, 1]  # channel 1: WT (whole tumor)
+         tc_mask = prediction_slice[:, :, 0]  # channel 0: TC (tumor core)
+         et_mask = prediction_slice[:, :, 2]  # channel 2: ET (enhancing tumor)
+
+         # stack the channels with priority ET > TC > WT
+         final_mask = np.zeros_like(wt_mask)
+
+         final_mask[et_mask > 0] = 3  # enhancing tumor (ET)
+         final_mask[(tc_mask > 0) & (final_mask == 0)] = 2  # tumor core (TC)
+         final_mask[(wt_mask > 0) & (final_mask == 0)] = 1  # whole tumor (WT)
+         final_masks.append(final_mask)
+         # turn the flair slice into a 3-channel color image
+         flair_slice_color = np.stack((flair_slice,) * 3, axis=-1)  # (H, W, 3)
+         flair_slice_colors.append(np.copy(flair_slice_color))
+         # overlay the different regions with RGB colors
+         flair_slice_color[final_mask == 1] = [255, 255, 0]  # WT - yellow
+         flair_slice_color[final_mask == 2] = [0, 255, 255]  # TC - cyan
+         flair_slice_color[final_mask == 3] = [255, 0, 255]  # ET - magenta
+
+         # store the color overlay for this slice
+         overlay_all_slices.append(flair_slice_color)
+     return np.stack(overlay_all_slices)
+
+
+ def __call__(zip_ref):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     url = "https://drive.google.com/uc?id=1qtWBuwE8PVb-_dzLbl_ySEPX6fNtEGBS"
+     checkpoint_path = "Segformer3D_Brats2021_epoch_50_model.pth"
+     if not os.path.exists(checkpoint_path):
+         gdown.download(url, checkpoint_path, quiet=False)
+     model = load_model(checkpoint_path, device)
+     prediction, modalities = predict_from_folder(model, zip_ref, device, D=128, H=128, W=128)
+     modalities = np.transpose(modalities, (3, 2, 1, 0))
+     prediction = np.transpose(prediction, (3, 2, 1, 0))
+     overlay = overlay_mask(modalities, prediction)
+     return overlay.astype(np.uint8)
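
A rough sketch of driving this inference module outside the Gradio app (not part of the commit): the archive name case.zip is hypothetical, its members are assumed to end in flair.nii, t1.nii, t1ce.nii and t2.nii as predict_from_folder expects, and the module-level __call__ downloads the checkpoint on first use.

import zipfile
import Segformer3DBRATS2021

with zipfile.ZipFile("case.zip") as zip_ref:  # hypothetical BraTS case archive
    overlay = Segformer3DBRATS2021.__call__(zip_ref)

# RGB overlay of the FLAIR volume with the WT/TC/ET regions colored in
print(overlay.shape, overlay.dtype)  # roughly (128, 128, 128, 3) uint8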
app.py ADDED
@@ -0,0 +1,69 @@
+ import zipfile
+ import nibabel as nib
+ import numpy as np
+ import gradio as gr
+ import Segformer3DBRATS2021  # inference module defined in Segformer3DBRATS2021.py
+ import torch
+ import torch.nn.functional as F
+ from io import BytesIO
+ import tempfile
+
+
+ def predict_segmentation(zip_file):
+     """
+     Process a zip file containing the MRI data, run the Segformer3D model, and return the resulting .nii file.
+     """
+     try:
+         # open the zip archive
+         with zipfile.ZipFile(zip_file) as zip_ref:
+
+             overlay = Segformer3DBRATS2021.__call__(zip_ref)
+             overlay_all_slices = np.transpose(overlay, (3, 2, 1, 0))
+             overlay_tensor = torch.tensor(overlay_all_slices, dtype=torch.float32).unsqueeze(0)
+             target_shape = (240, 240, 155)
+             # compute the padding needed to reach the target shape
+             z_pad_before = (target_shape[0] - overlay_tensor.shape[2]) // 2
+             z_pad_after = target_shape[0] - overlay_tensor.shape[2] - z_pad_before
+
+             y_pad_before = (target_shape[1] - overlay_tensor.shape[3]) // 2
+             y_pad_after = target_shape[1] - overlay_tensor.shape[3] - y_pad_before
+
+             x_pad_before = (target_shape[2] - overlay_tensor.shape[4]) // 2
+             x_pad_after = target_shape[2] - overlay_tensor.shape[4] - x_pad_before
+
+             # apply the (black) padding
+             padded_tensor = F.pad(overlay_tensor, (x_pad_before, x_pad_after, y_pad_before, y_pad_after, z_pad_before, z_pad_after), value=0)
+             assert padded_tensor.shape[2:] == target_shape, f"Expected shape {target_shape}, got {padded_tensor.shape[2:]}"
+             padded_tensor = padded_tensor.permute(0, 2, 3, 4, 1)
+             padded_slices = padded_tensor.squeeze(0).numpy()
+
+             for i in range(padded_slices.shape[2]):
+                 padded_slices[:, :, i, :] = np.flipud(np.fliplr(padded_slices[:, :, i, :]))
+             padded_slices = padded_slices.astype(np.uint8)
+             affine = np.eye(4)
+             nii_image = nib.Nifti1Image(padded_slices, affine)
+             with tempfile.NamedTemporaryFile(delete=False, suffix='.nii') as temp_file:
+                 nii_file_path = temp_file.name
+                 nib.save(nii_image, nii_file_path)
+
+             # return the path to the saved NIfTI file
+             return nii_file_path
+
+     except Exception as e:
+         return str(e)
+
+
+ def main():
+     # define the Gradio interface
+     inputs = gr.File(label="Upload a ZIP file containing MRI modalities (flair, t1, t1ce, t2)")
+     outputs = gr.File(label="Segmentation Result (.nii)")
+
+     gr.Interface(
+         fn=predict_segmentation,
+         inputs=inputs,
+         outputs=outputs,
+         title="3D Brain Tumor Segmentation",
+         description="Upload a ZIP file containing MRI modalities (flair, t1, t1ce, t2).",
+         allow_flagging="never",
+     ).launch(show_error=True)
+
+
+ if __name__ == '__main__':
+     main()
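
A short sketch of inspecting the file returned by predict_segmentation, assuming the call succeeds and returns the path of the padded (240, 240, 155, 3) overlay written above; case.zip is again a hypothetical input archive.

import nibabel as nib
import matplotlib.pyplot as plt
from app import predict_segmentation

nii_path = predict_segmentation("case.zip")
overlay = nib.load(nii_path).get_fdata()  # (240, 240, 155, 3)

# display the middle axial slice of the RGB overlay
plt.imshow(overlay[:, :, overlay.shape[2] // 2, :].astype("uint8"))
plt.axis("off")
plt.show()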
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ numpy
+ nibabel
+ tqdm
+ matplotlib
+ monai
+ scikit-learn
+ gdown
+ gradio==5.8.0