DuyTa committed on
Commit 67e51f5 · verified · 1 Parent(s): 16e5ee4

6185cf2abcdb1f523dbfb37144730057281a2f2df5aad865ab54ef3f83045f8b

Files changed (28)
  1. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01661/BraTS2021_01661_t2.nii.gz +3 -0
  2. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01662/BraTS2021_01662_flair.nii.gz +3 -0
  3. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01662/BraTS2021_01662_seg.nii.gz +0 -0
  4. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01662/BraTS2021_01662_t1.nii.gz +3 -0
  5. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01662/BraTS2021_01662_t1ce.nii.gz +3 -0
  6. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01662/BraTS2021_01662_t2.nii.gz +3 -0
  7. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01663/BraTS2021_01663_flair.nii.gz +3 -0
  8. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01663/BraTS2021_01663_seg.nii.gz +0 -0
  9. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01663/BraTS2021_01663_t1.nii.gz +3 -0
  10. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01663/BraTS2021_01663_t1ce.nii.gz +3 -0
  11. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01663/BraTS2021_01663_t2.nii.gz +3 -0
  12. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01664/BraTS2021_01664_flair.nii.gz +3 -0
  13. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01664/BraTS2021_01664_seg.nii.gz +0 -0
  14. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01664/BraTS2021_01664_t1.nii.gz +3 -0
  15. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01664/BraTS2021_01664_t1ce.nii.gz +3 -0
  16. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01664/BraTS2021_01664_t2.nii.gz +3 -0
  17. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01665/BraTS2021_01665_flair.nii.gz +3 -0
  18. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01665/BraTS2021_01665_seg.nii.gz +0 -0
  19. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01665/BraTS2021_01665_t1.nii.gz +3 -0
  20. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01665/BraTS2021_01665_t1ce.nii.gz +3 -0
  21. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01665/BraTS2021_01665_t2.nii.gz +3 -0
  22. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01666/BraTS2021_01666_flair.nii.gz +3 -0
  23. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01666/BraTS2021_01666_seg.nii.gz +0 -0
  24. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01666/BraTS2021_01666_t1.nii.gz +3 -0
  25. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01666/BraTS2021_01666_t1ce.nii.gz +3 -0
  26. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01666/BraTS2021_01666_t2.nii.gz +3 -0
  27. info.json +0 -0
  28. main.ipynb +1478 -0
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01661/BraTS2021_01661_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c505210bdc5a038f0feec158af86cbd51fe9044c2e322116b8dedf0d4cfa0ae8
+ size 2142198
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01662/BraTS2021_01662_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82601944e7b5ed25b7d5260546bcf9fc05735380b1121c808a62a9444728710d
+ size 2245968
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01662/BraTS2021_01662_seg.nii.gz ADDED
Binary file (36.3 kB)
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01662/BraTS2021_01662_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33191e411b78a8ccfabb50385530526178b7c8fe71af12c3b05284d845b4a809
+ size 2350622
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01662/BraTS2021_01662_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ef3e7436f5911064d241bd785872de5953ea58b596d35069bad3daf7e915118
+ size 2363017
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01662/BraTS2021_01662_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a42815be77caed00d9c499717b2b098dcf982d8a2066f9050bb975c18a0bd590
+ size 2058223
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01663/BraTS2021_01663_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d9bab8b5bfa4cdec3e9e3503a0573c320d9bc8ad7c6329124284982686506553
+ size 2563621
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01663/BraTS2021_01663_seg.nii.gz ADDED
Binary file (17 kB)
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01663/BraTS2021_01663_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1b839418ba30c21cac53504dbde89a91e3c831c97a0418db52dc60b6453c9859
+ size 2654382
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01663/BraTS2021_01663_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c858f60833fd410ea37f9e769537096ce765a4222bd69f1e58e0cdc266c274a0
+ size 3038088
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01663/BraTS2021_01663_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea430675569e79890f0452cfdd729c48a5548470186dd21fa5a7e4d80e602d60
+ size 3031822
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01664/BraTS2021_01664_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a52c35931fa86a7454bbc8d1309eac4779b03f4a93b4fd0d553ae600c749c53
+ size 2414151
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01664/BraTS2021_01664_seg.nii.gz ADDED
Binary file (31.4 kB)
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01664/BraTS2021_01664_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6dba2effcf0b5a9c500fab9841e438dd1bc187284ab4c69a98a32b7587afa051
+ size 2461810
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01664/BraTS2021_01664_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e583078d125e3eb90d258eaab57cfc6163b3028671f0582cb8fed25c4b471e3
+ size 2815663
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01664/BraTS2021_01664_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f2dd092f11b2e3e87e63c2d19f497c3051b3bf991b23ae22037a278514f6afd
+ size 2845901
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01665/BraTS2021_01665_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:080b33accc7e6a96af2083cfd3db3d8b118caa4be48d5f34bca489d9e4fc2651
+ size 2297237
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01665/BraTS2021_01665_seg.nii.gz ADDED
Binary file (27.3 kB)
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01665/BraTS2021_01665_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2da4767e4fc6070d179cc4353847d92f135e9e875acd3fe764fec45f4c7a14cc
+ size 2327765
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01665/BraTS2021_01665_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0c09a49c57ba59e96eb705f616e49588b128a42c7f78ed83e0355e4ef240ebf9
+ size 2693477
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01665/BraTS2021_01665_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf5cb3c237e4533b456b0d10917adcd60398563458f4948b678807cdafe308c7
+ size 2639817
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01666/BraTS2021_01666_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f89012c90e6f4a8b2226fcdb26874f9f5ef49c7ef145dbc3b69cc00816510a3
+ size 2029939
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01666/BraTS2021_01666_seg.nii.gz ADDED
Binary file (68.5 kB)
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01666/BraTS2021_01666_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a23f6ac4f2feb8d5f91b920e3dc822c820403d7c698a3834943aba980d5a26b4
+ size 2244910
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01666/BraTS2021_01666_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ef132d2943864bc055c694fa9bdf5dec9d4332f404c0dc5fad4b997584fdf6a
+ size 2268398
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_01666/BraTS2021_01666_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3b4a9a57fcd9a525d18c84b0f42fb608c20c1eeb2cbad690870f5d8f2d83719c
+ size 2181088
info.json ADDED
The diff for this file is too large to render.
main.ipynb ADDED
@@ -0,0 +1,1478 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "MONAI version: 1.4.dev2409\n",
13
+ "Numpy version: 1.26.2\n",
14
+ "Pytorch version: 1.13.0+cu116\n",
15
+ "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
16
+ "MONAI rev id: 46c1b228091283fba829280a5d747f4237f76ed0\n",
17
+ "MONAI __file__: /usr/local/lib/python3.9/site-packages/monai/__init__.py\n",
18
+ "\n",
19
+ "Optional dependencies:\n",
20
+ "Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.\n",
21
+ "ITK version: NOT INSTALLED or UNKNOWN VERSION.\n",
22
+ "Nibabel version: 5.2.1\n",
23
+ "scikit-image version: NOT INSTALLED or UNKNOWN VERSION.\n",
24
+ "scipy version: 1.11.4\n",
25
+ "Pillow version: 10.1.0\n",
26
+ "Tensorboard version: 2.16.2\n",
27
+ "gdown version: NOT INSTALLED or UNKNOWN VERSION.\n",
28
+ "TorchVision version: 0.14.0+cu116\n",
29
+ "tqdm version: 4.66.1\n",
30
+ "lmdb version: NOT INSTALLED or UNKNOWN VERSION.\n",
31
+ "psutil version: 5.9.8\n",
32
+ "pandas version: 2.2.1\n",
33
+ "einops version: 0.7.0\n",
34
+ "transformers version: 4.35.2\n",
35
+ "mlflow version: NOT INSTALLED or UNKNOWN VERSION.\n",
36
+ "pynrrd version: NOT INSTALLED or UNKNOWN VERSION.\n",
37
+ "clearml version: NOT INSTALLED or UNKNOWN VERSION.\n",
38
+ "\n",
39
+ "For details about installing the optional dependencies, please visit:\n",
40
+ " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
41
+ "\n"
42
+ ]
43
+ }
44
+ ],
45
+ "source": [
46
+ "\n",
47
+ "import matplotlib.pyplot as plt\n",
48
+ "import numpy as np\n",
49
+ "from monai.config import print_config\n",
50
+ "from monai.losses import DiceLoss\n",
51
+ "from monai.inferers import sliding_window_inference\n",
52
+ "from monai.transforms import MapTransform\n",
53
+ "from monai.data import DataLoader, Dataset\n",
54
+ "from monai.utils import set_determinism\n",
55
+ "from monai import transforms\n",
56
+ "import torch\n",
57
+ "\n",
58
+ "print_config()"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 3,
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "set_determinism(seed=0)"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 9,
73
+ "metadata": {},
74
+ "outputs": [
75
+ {
76
+ "name": "stdout",
77
+ "output_type": "stream",
78
+ "text": [
79
+ "Sα»‘ lượng mαΊ«u trong '/app/brats_2021_task1/BraTS2021_Training_Data' lΓ : 1251\n"
80
+ ]
81
+ }
82
+ ],
83
+ "source": [
84
+ "import os\n",
85
+ "\n",
86
+ "parent_folder_path = '/app/brats_2021_task1/BraTS2021_Training_Data'\n",
87
+ "subfolders = [f for f in os.listdir(parent_folder_path) if os.path.isdir(os.path.join(parent_folder_path, f))]\n",
88
+ "num_folders = len(subfolders)\n",
89
+ "print(f\"Sα»‘ lượng mαΊ«u trong '{parent_folder_path}' lΓ : {num_folders}\")"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "import os\n",
99
+ "import json\n",
100
+ "\n",
101
+ "folder_data = []\n",
102
+ "\n",
103
+ "for fold_number in os.listdir(parent_folder_path):\n",
104
+ " fold_path = os.path.join(parent_folder_path, fold_number)\n",
105
+ "\n",
106
+ " if os.path.isdir(fold_path):\n",
107
+ " entry = {\"fold\": 0, \"image\": [], \"label\": \"\"}\n",
108
+ "\n",
109
+ " for file_type in ['flair', 't1ce', 't1', 't2']:\n",
110
+ " file_name = f\"{fold_number}_{file_type}.nii.gz\"\n",
111
+ " file_path = os.path.join(fold_path, file_name)\n",
112
+ "\n",
113
+ " if os.path.exists(file_path):\n",
114
+ "\n",
115
+ " entry[\"image\"].append(os.path.abspath(file_path))\n",
116
+ "\n",
117
+ " label_name = f\"{fold_number}_seg.nii.gz\"\n",
118
+ " label_path = os.path.join(fold_path, label_name)\n",
119
+ " if os.path.exists(label_path):\n",
120
+ " entry[\"label\"] = os.path.abspath(label_path)\n",
121
+ "\n",
122
+ " folder_data.append(entry)\n",
123
+ "\n",
124
+ "\n",
125
+ "json_data = {\"training\": folder_data}\n",
126
+ "\n",
127
+ "json_file_path = '/app/info.json'\n",
128
+ "with open(json_file_path, 'w') as json_file:\n",
129
+ " json.dump(json_data, json_file, indent=2)\n",
130
+ "\n",
131
+ "print(f\"ThΓ΄ng tin Δ‘Γ£ được ghi vΓ o {json_file_path}\")\n"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": 5,
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):\n",
141
+ " \"\"\"\n",
142
+ " Convert labels to multi channels based on brats classes:\n",
143
+ " label 1 is the necrotic and non-enhancing tumor core\n",
144
+ " label 2 is the peritumoral edema\n",
145
+ " label 4 is the GD-enhancing tumor\n",
146
+ " The possible classes are TC (Tumor core), WT (Whole tumor)\n",
147
+ " and ET (Enhancing tumor).\n",
148
+ "\n",
149
+ " \"\"\"\n",
150
+ "\n",
151
+ " def __call__(self, data):\n",
152
+ " d = dict(data)\n",
153
+ " for key in self.keys:\n",
154
+ " result = []\n",
155
+ " # merge label 1 and label 4 to construct TC\n",
156
+ " result.append(np.logical_or(d[key] == 1, d[key] == 4))\n",
157
+ " # merge labels 1, 2 and 4 to construct WT\n",
158
+ " result.append(\n",
159
+ " np.logical_or(\n",
160
+ " np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2\n",
161
+ " )\n",
162
+ " )\n",
163
+ " # label 4 is ET\n",
164
+ " result.append(d[key] == 4)\n",
165
+ " d[key] = np.stack(result, axis=0).astype(np.float32)\n",
166
+ " return d"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": 6,
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "def datafold_read(datalist, basedir, fold=0, key=\"training\"):\n",
176
+ " with open(datalist) as f:\n",
177
+ " json_data = json.load(f)\n",
178
+ "\n",
179
+ " json_data = json_data[key]\n",
180
+ "\n",
181
+ " for d in json_data:\n",
182
+ " for k in d:\n",
183
+ " if isinstance(d[k], list):\n",
184
+ " d[k] = [os.path.join(basedir, iv) for iv in d[k]]\n",
185
+ " elif isinstance(d[k], str):\n",
186
+ " d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]\n",
187
+ "\n",
188
+ " tr = []\n",
189
+ " val = []\n",
190
+ " for d in json_data:\n",
191
+ " if \"fold\" in d and d[\"fold\"] == fold:\n",
192
+ " val.append(d)\n",
193
+ " else:\n",
194
+ " tr.append(d)\n",
195
+ "\n",
196
+ " return tr, val"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": 7,
202
+ "metadata": {},
203
+ "outputs": [],
204
+ "source": [
205
+ "def split_train_test(datalist, basedir, fold,test_size = 0.2, volume : float = None) :\n",
206
+ " train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)\n",
207
+ " from sklearn.model_selection import train_test_split\n",
208
+ " if volume != None :\n",
209
+ " train_files, _ = train_test_split(train_files,test_size=volume,random_state=42)\n",
210
+ " \n",
211
+ " train_files,validation_files = train_test_split(train_files,test_size=test_size, random_state=42)\n",
212
+ " \n",
213
+ " validation_files,test_files = train_test_split(validation_files,test_size=test_size, random_state=42)\n",
214
+ " return train_files, validation_files, test_files"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": 8,
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": [
223
+ "def get_loader(batch_size, data_dir, json_list, fold, roi,volume :float = None,test_size = 0.2):\n",
224
+ " train_files,validation_files,test_files = split_train_test(datalist = json_list,basedir = data_dir,test_size=test_size,fold = fold,volume= volume)\n",
225
+ " \n",
226
+ " train_transform = transforms.Compose(\n",
227
+ " [\n",
228
+ " transforms.LoadImaged(keys=[\"image\", \"label\"]),\n",
229
+ " transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n",
230
+ " transforms.CropForegroundd(\n",
231
+ " keys=[\"image\", \"label\"],\n",
232
+ " source_key=\"image\",\n",
233
+ " k_divisible=[roi[0], roi[1], roi[2]],\n",
234
+ " ),\n",
235
+ " transforms.RandSpatialCropd(\n",
236
+ " keys=[\"image\", \"label\"],\n",
237
+ " roi_size=[roi[0], roi[1], roi[2]],\n",
238
+ " random_size=False,\n",
239
+ " ),\n",
240
+ " transforms.RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=0),\n",
241
+ " transforms.RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=1),\n",
242
+ " transforms.RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=2),\n",
243
+ " transforms.NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n",
244
+ " transforms.RandScaleIntensityd(keys=\"image\", factors=0.1, prob=1.0),\n",
245
+ " transforms.RandShiftIntensityd(keys=\"image\", offsets=0.1, prob=1.0),\n",
246
+ " ]\n",
247
+ " )\n",
248
+ " val_transform = transforms.Compose(\n",
249
+ " [\n",
250
+ " transforms.LoadImaged(keys=[\"image\", \"label\"]),\n",
251
+ " transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n",
252
+ " transforms.NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n",
253
+ " ]\n",
254
+ " )\n",
255
+ "\n",
256
+ " train_ds = Dataset(data=train_files, transform=train_transform)\n",
257
+ " train_loader = DataLoader(\n",
258
+ " train_ds,\n",
259
+ " batch_size=batch_size,\n",
260
+ " shuffle=True,\n",
261
+ " num_workers=2,\n",
262
+ " pin_memory=True,\n",
263
+ " )\n",
264
+ " val_ds = Dataset(data=validation_files, transform=val_transform)\n",
265
+ " val_loader = DataLoader(\n",
266
+ " val_ds,\n",
267
+ " batch_size=1,\n",
268
+ " shuffle=False,\n",
269
+ " num_workers=2,\n",
270
+ " pin_memory=True,\n",
271
+ " )\n",
272
+ " test_ds = Dataset(data=test_files, transform=val_transform)\n",
273
+ " test_loader = DataLoader(\n",
274
+ " test_ds,\n",
275
+ " batch_size=1,\n",
276
+ " shuffle=False,\n",
277
+ " num_workers=2,\n",
278
+ " pin_memory=True,\n",
279
+ " )\n",
280
+ " return train_loader, val_loader,test_loader"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": 9,
286
+ "metadata": {},
287
+ "outputs": [
288
+ {
289
+ "name": "stderr",
290
+ "output_type": "stream",
291
+ "text": [
292
+ "/usr/local/lib/python3.9/site-packages/monai/utils/deprecate_utils.py:321: FutureWarning: monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.\n",
293
+ " warn_deprecated(argname, msg, warning_category)\n"
294
+ ]
295
+ }
296
+ ],
297
+ "source": [
298
+ "import json\n",
299
+ "data_dir = \"/app/brats_2021_task1\"\n",
300
+ "json_list = \"/app/info.json\"\n",
301
+ "roi = (128, 128, 128)\n",
302
+ "batch_size = 1\n",
303
+ "sw_batch_size = 2\n",
304
+ "fold = 1\n",
305
+ "infer_overlap = 0.5\n",
306
+ "max_epochs = 100\n",
307
+ "val_every = 10\n",
308
+ "train_loader, val_loader,test_loader = get_loader(batch_size, data_dir, json_list, fold, roi, volume=0.5, test_size=0.2)"
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "code",
313
+ "execution_count": 45,
314
+ "metadata": {},
315
+ "outputs": [
316
+ {
317
+ "data": {
318
+ "text/plain": [
319
+ "100"
320
+ ]
321
+ },
322
+ "execution_count": 45,
323
+ "metadata": {},
324
+ "output_type": "execute_result"
325
+ }
326
+ ],
327
+ "source": [
328
+ "len(val_loader)"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": 10,
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
338
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "markdown",
343
+ "metadata": {},
344
+ "source": [
345
+ "#### Model design, base on SegResNet, VAE and TransBTS"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": 11,
351
+ "metadata": {},
352
+ "outputs": [],
353
+ "source": [
354
+ "import torch\n",
355
+ "import torch.nn as nn\n",
356
+ "\n",
357
+ "#Re-use from encoder block\n",
358
+ "def normalization(planes, norm = 'instance'):\n",
359
+ " if norm == 'bn':\n",
360
+ " m = nn.BatchNorm3d(planes)\n",
361
+ " elif norm == 'gn':\n",
362
+ " m = nn.GroupNorm(8, planes)\n",
363
+ " elif norm == 'instance':\n",
364
+ " m = nn.InstanceNorm3d(planes)\n",
365
+ " else:\n",
366
+ " raise ValueError(\"Does not support this kind of norm.\")\n",
367
+ " return m\n",
368
+ "class ResNetBlock(nn.Module):\n",
369
+ " def __init__(self, in_channels, norm = 'instance'):\n",
370
+ " super().__init__()\n",
371
+ " self.resnetblock = nn.Sequential(\n",
372
+ " normalization(in_channels, norm = norm),\n",
373
+ " nn.LeakyReLU(0.2, inplace=True),\n",
374
+ " nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1),\n",
375
+ " normalization(in_channels, norm = norm),\n",
376
+ " nn.LeakyReLU(0.2, inplace=True),\n",
377
+ " nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1)\n",
378
+ " )\n",
379
+ "\n",
380
+ " def forward(self, x):\n",
381
+ " y = self.resnetblock(x)\n",
382
+ " return y + x"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": 12,
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": [
391
+ "\n",
392
+ "\n",
393
+ "from torch.nn import functional as F\n",
394
+ "\n",
395
+ "def calculate_total_dimension(a):\n",
396
+ " res = 1\n",
397
+ " for x in a:\n",
398
+ " res *= x\n",
399
+ " return res\n",
400
+ "\n",
401
+ "class VAE(nn.Module):\n",
402
+ " def __init__(self, input_shape, latent_dim, num_channels):\n",
403
+ " super().__init__()\n",
404
+ " self.input_shape = input_shape\n",
405
+ " self.in_channels = input_shape[1] #input_shape[0] is batch size\n",
406
+ " self.latent_dim = latent_dim\n",
407
+ " self.encoder_channels = self.in_channels // 16\n",
408
+ "\n",
409
+ " #Encoder\n",
410
+ " self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,\n",
411
+ " kernel_size = 3, stride = 2, padding=1)\n",
412
+ " # self.VAE_reshape = nn.Sequential(\n",
413
+ " # nn.GroupNorm(8, self.in_channels),\n",
414
+ " # nn.ReLU(),\n",
415
+ " # nn.Conv3d(self.in_channels, self.encoder_channels,\n",
416
+ " # kernel_size = 3, stride = 2, padding=1),\n",
417
+ " # )\n",
418
+ "\n",
419
+ " flatten_input_shape = calculate_total_dimension(input_shape)\n",
420
+ " flatten_input_shape_after_vae_reshape = \\\n",
421
+ " flatten_input_shape * self.encoder_channels // (8 * self.in_channels)\n",
422
+ "\n",
423
+ " #Convert from total dimension to latent space\n",
424
+ " self.to_latent_space = nn.Linear(\n",
425
+ " flatten_input_shape_after_vae_reshape // self.in_channels, 1)\n",
426
+ "\n",
427
+ " self.mean = nn.Linear(self.in_channels, self.latent_dim)\n",
428
+ " self.logvar = nn.Linear(self.in_channels, self.latent_dim)\n",
429
+ "# self.epsilon = nn.Parameter(torch.randn(1, latent_dim))\n",
430
+ "\n",
431
+ " #Decoder\n",
432
+ " self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)\n",
433
+ " self.Reconstruct = nn.Sequential(\n",
434
+ " nn.LeakyReLU(0.2, inplace=True),\n",
435
+ " nn.Conv3d(\n",
436
+ " self.encoder_channels, self.in_channels,\n",
437
+ " stride = 1, kernel_size = 1),\n",
438
+ " nn.Upsample(scale_factor=2, mode = 'nearest'),\n",
439
+ "\n",
440
+ " nn.Conv3d(\n",
441
+ " self.in_channels, self.in_channels // 2,\n",
442
+ " stride = 1, kernel_size = 1),\n",
443
+ " nn.Upsample(scale_factor=2, mode = 'nearest'),\n",
444
+ " ResNetBlock(self.in_channels // 2),\n",
445
+ "\n",
446
+ " nn.Conv3d(\n",
447
+ " self.in_channels // 2, self.in_channels // 4,\n",
448
+ " stride = 1, kernel_size = 1),\n",
449
+ " nn.Upsample(scale_factor=2, mode = 'nearest'),\n",
450
+ " ResNetBlock(self.in_channels // 4),\n",
451
+ "\n",
452
+ " nn.Conv3d(\n",
453
+ " self.in_channels // 4, self.in_channels // 8,\n",
454
+ " stride = 1, kernel_size = 1),\n",
455
+ " nn.Upsample(scale_factor=2, mode = 'nearest'),\n",
456
+ " ResNetBlock(self.in_channels // 8),\n",
457
+ "\n",
458
+ " nn.InstanceNorm3d(self.in_channels // 8),\n",
459
+ " nn.LeakyReLU(0.2, inplace=True),\n",
460
+ " nn.Conv3d(\n",
461
+ " self.in_channels // 8, num_channels,\n",
462
+ " kernel_size = 3, padding = 1),\n",
463
+ "# nn.Sigmoid()\n",
464
+ " )\n",
465
+ "\n",
466
+ "\n",
467
+ " def forward(self, x): #x has shape = input_shape\n",
468
+ " #Encoder\n",
469
+ " # print(x.shape)\n",
470
+ " x = self.VAE_reshape(x)\n",
471
+ " shape = x.shape\n",
472
+ "\n",
473
+ " x = x.view(self.in_channels, -1)\n",
474
+ " x = self.to_latent_space(x)\n",
475
+ " x = x.view(1, self.in_channels)\n",
476
+ "\n",
477
+ " mean = self.mean(x)\n",
478
+ " logvar = self.logvar(x)\n",
479
+ "# sigma = torch.exp(0.5 * logvar)\n",
480
+ " # Reparameter\n",
481
+ " epsilon = torch.randn_like(logvar)\n",
482
+ " sample = mean + epsilon * torch.exp(0.5*logvar)\n",
483
+ "\n",
484
+ " #Decoder\n",
485
+ " y = self.to_original_dimension(sample)\n",
486
+ " y = y.view(*shape)\n",
487
+ " return self.Reconstruct(y), mean, logvar\n",
488
+ " def total_params(self):\n",
489
+ " total = sum(p.numel() for p in self.parameters())\n",
490
+ " return format(total, ',')\n",
491
+ "\n",
492
+ " def total_trainable_params(self):\n",
493
+ " total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)\n",
494
+ " return format(total_trainable, ',')\n",
495
+ "\n",
496
+ "\n",
497
+ "# x = torch.rand((1, 256, 16, 16, 16))\n",
498
+ "# vae = VAE(input_shape = x.shape, latent_dim = 256, num_channels = 4)\n",
499
+ "# y = vae(x)\n",
500
+ "# print(y[0].shape, y[1].shape, y[2].shape)\n",
501
+ "# print(vae.total_trainable_params())\n"
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "code",
506
+ "execution_count": 13,
507
+ "metadata": {},
508
+ "outputs": [],
509
+ "source": [
510
+ "import torch\n",
511
+ "from torch import nn\n",
512
+ "\n",
513
+ "from einops import rearrange\n",
514
+ "from einops.layers.torch import Rearrange\n",
515
+ "\n",
516
+ "def pair(t):\n",
517
+ " return t if isinstance(t, tuple) else (t, t)\n",
518
+ "\n",
519
+ "\n",
520
+ "class PreNorm(nn.Module):\n",
521
+ " def __init__(self, dim, function):\n",
522
+ " super().__init__()\n",
523
+ " self.norm = nn.LayerNorm(dim)\n",
524
+ " self.function = function\n",
525
+ "\n",
526
+ " def forward(self, x):\n",
527
+ " return self.function(self.norm(x))\n",
528
+ "\n",
529
+ "\n",
530
+ "class FeedForward(nn.Module):\n",
531
+ " def __init__(self, dim, hidden_dim, dropout = 0.0):\n",
532
+ " super().__init__()\n",
533
+ " self.net = nn.Sequential(\n",
534
+ " nn.Linear(dim, hidden_dim),\n",
535
+ " nn.GELU(),\n",
536
+ " nn.Dropout(dropout),\n",
537
+ " nn.Linear(hidden_dim, dim),\n",
538
+ " nn.Dropout(dropout)\n",
539
+ " )\n",
540
+ "\n",
541
+ " def forward(self, x):\n",
542
+ " return self.net(x)\n",
543
+ "\n",
544
+ "class Attention(nn.Module):\n",
545
+ " def __init__(self, dim, heads, dim_head, dropout = 0.0):\n",
546
+ " super().__init__()\n",
547
+ " all_head_size = heads * dim_head\n",
548
+ " project_out = not (heads == 1 and dim_head == dim)\n",
549
+ "\n",
550
+ " self.heads = heads\n",
551
+ " self.scale = dim_head ** -0.5\n",
552
+ "\n",
553
+ " self.softmax = nn.Softmax(dim = -1)\n",
554
+ " self.to_qkv = nn.Linear(dim, all_head_size * 3, bias = False)\n",
555
+ "\n",
556
+ " self.to_out = nn.Sequential(\n",
557
+ " nn.Linear(all_head_size, dim),\n",
558
+ " nn.Dropout(dropout)\n",
559
+ " ) if project_out else nn.Identity()\n",
560
+ "\n",
561
+ " def forward(self, x):\n",
562
+ " qkv = self.to_qkv(x).chunk(3, dim = -1)\n",
563
+ " #(batch, heads * dim_head) -> (batch, all_head_size)\n",
564
+ " q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)\n",
565
+ "\n",
566
+ " dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale\n",
567
+ "\n",
568
+ " atten = self.softmax(dots)\n",
569
+ "\n",
570
+ " out = torch.matmul(atten, v)\n",
571
+ " out = rearrange(out, 'b h n d -> b n (h d)')\n",
572
+ " return self.to_out(out)\n",
573
+ "\n",
574
+ "class Transformer(nn.Module):\n",
575
+ " def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.0):\n",
576
+ " super().__init__()\n",
577
+ " self.layers = nn.ModuleList([])\n",
578
+ " for _ in range(depth):\n",
579
+ " self.layers.append(nn.ModuleList([\n",
580
+ " PreNorm(dim, Attention(dim, heads, dim_head, dropout)),\n",
581
+ " PreNorm(dim, FeedForward(dim, mlp_dim, dropout))\n",
582
+ " ]))\n",
583
+ " def forward(self, x):\n",
584
+ " for attention, feedforward in self.layers:\n",
585
+ " x = attention(x) + x\n",
586
+ " x = feedforward(x) + x\n",
587
+ " return x\n",
588
+ "\n",
589
+ "class FixedPositionalEncoding(nn.Module):\n",
590
+ " def __init__(self, embedding_dim, max_length=768):\n",
591
+ " super(FixedPositionalEncoding, self).__init__()\n",
592
+ "\n",
593
+ " pe = torch.zeros(max_length, embedding_dim)\n",
594
+ " position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)\n",
595
+ " div_term = torch.exp(\n",
596
+ " torch.arange(0, embedding_dim, 2).float()\n",
597
+ " * (-torch.log(torch.tensor(10000.0)) / embedding_dim)\n",
598
+ " )\n",
599
+ " pe[:, 0::2] = torch.sin(position * div_term)\n",
600
+ " pe[:, 1::2] = torch.cos(position * div_term)\n",
601
+ " pe = pe.unsqueeze(0).transpose(0, 1)\n",
602
+ " self.register_buffer('pe', pe)\n",
603
+ "\n",
604
+ " def forward(self, x):\n",
605
+ " x = x + self.pe[: x.size(0), :]\n",
606
+ " return x\n",
607
+ "\n",
608
+ "\n",
609
+ "class LearnedPositionalEncoding(nn.Module):\n",
610
+ " def __init__(self, embedding_dim, seq_length):\n",
611
+ " super(LearnedPositionalEncoding, self).__init__()\n",
612
+ " self.seq_length = seq_length\n",
613
+ " self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, embedding_dim)) #8x\n",
614
+ "\n",
615
+ " def forward(self, x, position_ids=None):\n",
616
+ " position_embeddings = self.position_embeddings\n",
617
+ "# print(x.shape, self.position_embeddings.shape)\n",
618
+ " return x + position_embeddings"
619
+ ]
620
+ },
621
+ {
622
+ "cell_type": "code",
623
+ "execution_count": 14,
624
+ "metadata": {},
625
+ "outputs": [],
626
+ "source": [
627
+ "### Encoder ####\n",
628
+ "import torch.nn as nn\n",
629
+ "import torch.nn.functional as F\n",
630
+ "\n",
631
+ "class InitConv(nn.Module):\n",
632
+ " def __init__(self, in_channels = 4, out_channels = 16, dropout = 0.2):\n",
633
+ " super().__init__()\n",
634
+ " self.layer = nn.Sequential(\n",
635
+ " nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),\n",
636
+ " nn.Dropout3d(dropout)\n",
637
+ " )\n",
638
+ " def forward(self, x):\n",
639
+ " y = self.layer(x)\n",
640
+ " return y\n",
641
+ "\n",
642
+ "\n",
643
+ "class DownSample(nn.Module):\n",
644
+ " def __init__(self, in_channels, out_channels):\n",
645
+ " super().__init__()\n",
646
+ " self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)\n",
647
+ " def forward(self, x):\n",
648
+ " return self.conv(x)\n",
649
+ "\n",
650
+ "class Encoder(nn.Module):\n",
651
+ " def __init__(self, in_channels, base_channels, dropout = 0.2):\n",
652
+ " super().__init__()\n",
653
+ "\n",
654
+ " self.init_conv = InitConv(in_channels, base_channels, dropout = dropout)\n",
655
+ " self.encoder_block1 = ResNetBlock(in_channels = base_channels)\n",
656
+ " self.encoder_down1 = DownSample(base_channels, base_channels * 2)\n",
657
+ "\n",
658
+ " self.encoder_block2_1 = ResNetBlock(base_channels * 2)\n",
659
+ " self.encoder_block2_2 = ResNetBlock(base_channels * 2)\n",
660
+ " self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)\n",
661
+ "\n",
662
+ " self.encoder_block3_1 = ResNetBlock(base_channels * 4)\n",
663
+ " self.encoder_block3_2 = ResNetBlock(base_channels * 4)\n",
664
+ " self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)\n",
665
+ "\n",
666
+ " self.encoder_block4_1 = ResNetBlock(base_channels * 8)\n",
667
+ " self.encoder_block4_2 = ResNetBlock(base_channels * 8)\n",
668
+ " self.encoder_block4_3 = ResNetBlock(base_channels * 8)\n",
669
+ " self.encoder_block4_4 = ResNetBlock(base_channels * 8)\n",
670
+ " # self.encoder_down3 = EncoderDown(base_channels * 8, base_channels * 16)\n",
671
+ " def forward(self, x):\n",
672
+ " x = self.init_conv(x) #(1, 16, 128, 128, 128)\n",
673
+ "\n",
674
+ " x1 = self.encoder_block1(x)\n",
675
+ " x1_down = self.encoder_down1(x1) #(1, 32, 64, 64, 64)\n",
676
+ "\n",
677
+ " x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))\n",
678
+ " x2_down = self.encoder_down2(x2) #(1, 64, 32, 32, 32)\n",
679
+ "\n",
680
+ " x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))\n",
681
+ " x3_down = self.encoder_down3(x3) #(1, 128, 16, 16, 16)\n",
682
+ "\n",
683
+ " output = self.encoder_block4_4(\n",
684
+ " self.encoder_block4_3(\n",
685
+ " self.encoder_block4_2(\n",
686
+ " self.encoder_block4_1(x3_down)))) #(1, 256, 16, 16, 16)\n",
687
+ " return x1, x2, x3, output\n",
688
+ "\n",
689
+ "# x = torch.rand((1, 4, 128, 128, 128))\n",
690
+ "# Enc = Encoder(4, 32)\n",
691
+ "# _, _, _, y = Enc(x)\n",
692
+ "# print(y.shape) (1,256,16,16)"
693
+ ]
694
+ },
695
+ {
696
+ "cell_type": "code",
697
+ "execution_count": 15,
698
+ "metadata": {},
699
+ "outputs": [],
700
+ "source": [
701
+ "### Decoder ####\n",
702
+ "\n",
703
+ "import torch\n",
704
+ "import torch.nn as nn\n",
705
+ "\n",
706
+ "\n",
707
+ "class Upsample(nn.Module):\n",
708
+ " def __init__(self, in_channel, out_channel):\n",
709
+ " super().__init__()\n",
710
+ " self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size = 1)\n",
711
+ " self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size = 2, stride = 2)\n",
712
+ " self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size = 1)\n",
713
+ "\n",
714
+ " def forward(self, prev, x):\n",
715
+ " x = self.deconv(self.conv1(x))\n",
716
+ " y = torch.cat((prev, x), dim = 1)\n",
717
+ " return self.conv2(y)\n",
718
+ "\n",
719
+ "class FinalConv(nn.Module): # Input channels are equal to output channels\n",
720
+ " def __init__(self, in_channels, out_channels=32, norm=\"instance\"):\n",
721
+ " super(FinalConv, self).__init__()\n",
722
+ " if norm == \"batch\":\n",
723
+ " norm_layer = nn.BatchNorm3d(num_features=in_channels)\n",
724
+ " elif norm == \"group\":\n",
725
+ " norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)\n",
726
+ " elif norm == 'instance':\n",
727
+ " norm_layer = nn.InstanceNorm3d(in_channels)\n",
728
+ "\n",
729
+ " self.layer = nn.Sequential(\n",
730
+ " norm_layer,\n",
731
+ " nn.LeakyReLU(0.2, inplace=True),\n",
732
+ " nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)\n",
733
+ " )\n",
734
+ " def forward(self, x):\n",
735
+ " return self.layer(x)\n",
736
+ "\n",
737
+ "class Decoder(nn.Module):\n",
738
+ " def __init__(self, img_dim, patch_dim, embedding_dim, num_classes = 3):\n",
739
+ " super().__init__()\n",
740
+ " self.img_dim = img_dim\n",
741
+ " self.patch_dim = patch_dim\n",
742
+ " self.embedding_dim = embedding_dim\n",
743
+ "\n",
744
+ " self.decoder_upsample_1 = Upsample(128, 64)\n",
745
+ " self.decoder_block_1 = ResNetBlock(64)\n",
746
+ "\n",
747
+ " self.decoder_upsample_2 = Upsample(64, 32)\n",
748
+ " self.decoder_block_2 = ResNetBlock(32)\n",
749
+ "\n",
750
+ " self.decoder_upsample_3 = Upsample(32, 16)\n",
751
+ " self.decoder_block_3 = ResNetBlock(16)\n",
752
+ "\n",
753
+ " self.endconv = FinalConv(16, num_classes)\n",
754
+ " # self.normalize = nn.Sigmoid()\n",
755
+ "\n",
756
+ " def forward(self, x1, x2, x3, x):\n",
757
+ " x = self.decoder_upsample_1(x3, x)\n",
758
+ " x = self.decoder_block_1(x)\n",
759
+ "\n",
760
+ " x = self.decoder_upsample_2(x2, x)\n",
761
+ " x = self.decoder_block_2(x)\n",
762
+ "\n",
763
+ " x = self.decoder_upsample_3(x1, x)\n",
764
+ " x = self.decoder_block_3(x)\n",
765
+ "\n",
766
+ " y = self.endconv(x)\n",
767
+ " return y"
768
+ ]
769
+ },
770
+ {
771
+ "cell_type": "code",
772
+ "execution_count": 16,
773
+ "metadata": {},
774
+ "outputs": [],
775
+ "source": [
776
+ "class FeatureMapping(nn.Module):\n",
777
+ " def __init__(self, in_channel, out_channel, norm = 'instance'):\n",
778
+ " super().__init__()\n",
779
+ " if norm == 'bn':\n",
780
+ " norm_layer_1 = nn.BatchNorm3d(out_channel)\n",
781
+ " norm_layer_2 = nn.BatchNorm3d(out_channel)\n",
782
+ " elif norm == 'gn':\n",
783
+ " norm_layer_1 = nn.GroupNorm(8, out_channel)\n",
784
+ " norm_layer_2 = nn.GroupNorm(8, out_channel)\n",
785
+ " elif norm == 'instance':\n",
786
+ " norm_layer_1 = nn.InstanceNorm3d(out_channel)\n",
787
+ " norm_layer_2 = nn.InstanceNorm3d(out_channel)\n",
788
+ " self.feature_mapping = nn.Sequential(\n",
789
+ " nn.Conv3d(in_channel, out_channel, kernel_size = 3, padding = 1),\n",
790
+ " norm_layer_1,\n",
791
+ " nn.LeakyReLU(0.2, inplace=True),\n",
792
+ " nn.Conv3d(out_channel, out_channel, kernel_size = 3, padding = 1),\n",
793
+ " norm_layer_2,\n",
794
+ " nn.LeakyReLU(0.2, inplace=True)\n",
795
+ " )\n",
796
+ "\n",
797
+ " def forward(self, x):\n",
798
+ " return self.feature_mapping(x)\n",
799
+ "\n",
800
+ "\n",
801
+ "class FeatureMapping1(nn.Module):\n",
802
+ " def __init__(self, in_channel, norm = 'instance'):\n",
803
+ " super().__init__()\n",
804
+ " if norm == 'bn':\n",
805
+ " norm_layer_1 = nn.BatchNorm3d(in_channel)\n",
806
+ " norm_layer_2 = nn.BatchNorm3d(in_channel)\n",
807
+ " elif norm == 'gn':\n",
808
+ " norm_layer_1 = nn.GroupNorm(8, in_channel)\n",
809
+ " norm_layer_2 = nn.GroupNorm(8, in_channel)\n",
810
+ " elif norm == 'instance':\n",
811
+ " norm_layer_1 = nn.InstanceNorm3d(in_channel)\n",
812
+ " norm_layer_2 = nn.InstanceNorm3d(in_channel)\n",
813
+ " self.feature_mapping1 = nn.Sequential(\n",
814
+ " nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),\n",
815
+ " norm_layer_1,\n",
816
+ " nn.LeakyReLU(0.2, inplace=True),\n",
817
+ " nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),\n",
818
+ " norm_layer_2,\n",
819
+ " nn.LeakyReLU(0.2, inplace=True)\n",
820
+ " )\n",
821
+ " def forward(self, x):\n",
822
+ " y = self.feature_mapping1(x)\n",
823
+ " return x + y #Resnet Like"
824
+ ]
825
+ },
826
+ {
827
+ "cell_type": "code",
828
+ "execution_count": 17,
829
+ "metadata": {},
830
+ "outputs": [],
831
+ "source": [
832
+ "\n",
833
+ "class SegTransVAE(nn.Module):\n",
834
+ " def __init__(self, img_dim, patch_dim, num_channels, num_classes,\n",
835
+ " embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,\n",
836
+ " dropout = 0.0, attention_dropout = 0.0,\n",
837
+ " conv_patch_representation = True, positional_encoding = 'learned',\n",
838
+ " use_VAE = False):\n",
839
+ "\n",
840
+ " super().__init__()\n",
841
+ " assert embedding_dim % num_heads == 0\n",
842
+ " assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0\n",
843
+ "\n",
844
+ " self.img_dim = img_dim\n",
845
+ " self.embedding_dim = embedding_dim\n",
846
+ " self.num_heads = num_heads\n",
847
+ " self.num_classes = num_classes\n",
848
+ " self.patch_dim = patch_dim\n",
849
+ " self.num_channels = num_channels\n",
850
+ " self.in_channels_vae = in_channels_vae\n",
851
+ " self.dropout = dropout\n",
852
+ " self.attention_dropout = attention_dropout\n",
853
+ " self.conv_patch_representation = conv_patch_representation\n",
854
+ " self.use_VAE = use_VAE\n",
855
+ "\n",
856
+ " self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))\n",
857
+ " self.seq_length = self.num_patches\n",
858
+ " self.flatten_dim = 128 * num_channels\n",
859
+ "\n",
860
+ " self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)\n",
861
+ " if positional_encoding == \"learned\":\n",
862
+ " self.position_encoding = LearnedPositionalEncoding(\n",
863
+ " self.embedding_dim, self.seq_length\n",
864
+ " )\n",
865
+ " elif positional_encoding == \"fixed\":\n",
866
+ " self.position_encoding = FixedPositionalEncoding(\n",
867
+ " self.embedding_dim,\n",
868
+ " )\n",
869
+ " self.pe_dropout = nn.Dropout(self.dropout)\n",
870
+ "\n",
871
+ " self.transformer = Transformer(\n",
872
+ " embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout\n",
873
+ " )\n",
874
+ " self.pre_head_ln = nn.LayerNorm(embedding_dim)\n",
875
+ "\n",
876
+ " if self.conv_patch_representation:\n",
877
+ " self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)\n",
878
+ " self.encoder = Encoder(self.num_channels, 16)\n",
879
+ " self.bn = nn.InstanceNorm3d(128)\n",
880
+ " self.relu = nn.LeakyReLU(0.2, inplace=True)\n",
881
+ " self.FeatureMapping = FeatureMapping(in_channel = self.embedding_dim, out_channel= self.in_channels_vae)\n",
882
+ " self.FeatureMapping1 = FeatureMapping1(in_channel = self.in_channels_vae)\n",
883
+ " self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)\n",
884
+ "\n",
885
+ " self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)\n",
886
+ " if use_VAE:\n",
887
+ " self.vae = VAE(input_shape = self.vae_input , latent_dim= 256, num_channels= self.num_channels)\n",
888
+ " def encode(self, x):\n",
889
+ " if self.conv_patch_representation:\n",
890
+ " x1, x2, x3, x = self.encoder(x)\n",
891
+ " x = self.bn(x)\n",
892
+ " x = self.relu(x)\n",
893
+ " x = self.conv_x(x)\n",
894
+ " x = x.permute(0, 2, 3, 4, 1).contiguous()\n",
895
+ " x = x.view(x.size(0), -1, self.embedding_dim)\n",
896
+ " x = self.position_encoding(x)\n",
897
+ " x = self.pe_dropout(x)\n",
898
+ " x = self.transformer(x)\n",
899
+ " x = self.pre_head_ln(x)\n",
900
+ "\n",
901
+ " return x1, x2, x3, x\n",
902
+ "\n",
903
+ " def decode(self, x1, x2, x3, x):\n",
904
+ " #x: (1, 4096, 512) -> (1, 16, 16, 16, 512)\n",
905
+ "# print(\"In decode...\")\n",
906
+ "# print(\" x1: {} \\n x2: {} \\n x3: {} \\n x: {}\".format( x1.shape, x2.shape, x3.shape, x.shape))\n",
907
+ "# break\n",
908
+ " return self.decoder(x1, x2, x3, x)\n",
909
+ "\n",
910
+ " def forward(self, x, is_validation = True):\n",
911
+ " x1, x2, x3, x = self.encode(x)\n",
912
+ " x = x.view( x.size(0),\n",
913
+ " self.img_dim[0] // self.patch_dim,\n",
914
+ " self.img_dim[1] // self.patch_dim,\n",
915
+ " self.img_dim[2] // self.patch_dim,\n",
916
+ " self.embedding_dim)\n",
917
+ " x = x.permute(0, 4, 1, 2, 3).contiguous()\n",
918
+ " x = self.FeatureMapping(x)\n",
919
+ " x = self.FeatureMapping1(x)\n",
920
+ " if self.use_VAE and not is_validation:\n",
921
+ " vae_out, mu, sigma = self.vae(x)\n",
922
+ " y = self.decode(x1, x2, x3, x)\n",
923
+ " if self.use_VAE and not is_validation:\n",
924
+ " return y, vae_out, mu, sigma\n",
925
+ " else:\n",
926
+ " return y\n",
927
+ "\n",
928
+ "\n"
929
+ ]
930
+ },
931
+ {
932
+ "cell_type": "code",
933
+ "execution_count": 18,
934
+ "metadata": {},
935
+ "outputs": [
936
+ {
937
+ "name": "stdout",
938
+ "output_type": "stream",
939
+ "text": [
940
+ "CUDA (GPU) is available. Using GPU.\n"
941
+ ]
942
+ }
943
+ ],
944
+ "source": [
945
+ "import torch\n",
946
+ "\n",
947
+ "# Check if CUDA (GPU support) is available\n",
948
+ "if torch.cuda.is_available():\n",
949
+ " device = torch.device(\"cuda:0\")\n",
950
+ " print(\"CUDA (GPU) is available. Using GPU.\")\n",
951
+ "else:\n",
952
+ " device = torch.device(\"cpu\")\n",
953
+ " print(\"CUDA (GPU) is not available. Using CPU.\")"
954
+ ]
955
+ },
956
+ {
957
+ "cell_type": "code",
958
+ "execution_count": 18,
959
+ "metadata": {},
960
+ "outputs": [],
961
+ "source": [
962
+ "model = SegTransVAE(img_dim = (128, 128, 128),patch_dim= 8,num_channels =4,num_classes= 3,embedding_dim= 768,num_heads= 8,num_layers= 4, hidden_dim= 3072,in_channels_vae=128 , use_VAE = True)"
963
+ ]
964
+ },
965
+ {
966
+ "cell_type": "code",
967
+ "execution_count": 28,
968
+ "metadata": {},
969
+ "outputs": [
970
+ {
971
+ "name": "stdout",
972
+ "output_type": "stream",
973
+ "text": [
974
+ "Tα»•ng sα»‘ tham sα»‘ của mΓ΄ hΓ¬nh lΓ : 44727120\n",
975
+ "Tα»•ng sα»‘ tham sα»‘ cαΊ§n tΓ­nh gradient của mΓ΄ hΓ¬nh lΓ : 44727120\n"
976
+ ]
977
+ }
978
+ ],
979
+ "source": [
980
+ "total_params = sum(p.numel() for p in model.parameters())\n",
981
+ "print(f'Tα»•ng sα»‘ tham sα»‘ của mΓ΄ hΓ¬nh lΓ : {total_params}')\n",
982
+ "\n",
983
+ "total_params_requires_grad = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
984
+ "print(f'Tα»•ng sα»‘ tham sα»‘ cαΊ§n tΓ­nh gradient của mΓ΄ hΓ¬nh lΓ : {total_params_requires_grad}')\n"
985
+ ]
986
+ },
987
+ {
988
+ "cell_type": "code",
989
+ "execution_count": 19,
990
+ "metadata": {},
991
+ "outputs": [],
992
+ "source": [
993
+ "class Loss_VAE(nn.Module):\n",
994
+ " def __init__(self):\n",
995
+ " super().__init__()\n",
996
+ " self.mse = nn.MSELoss(reduction='sum')\n",
997
+ "\n",
998
+ " def forward(self, recon_x, x, mu, log_var):\n",
999
+ " mse = self.mse(recon_x, x)\n",
1000
+ " kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n",
1001
+ " loss = mse + kld\n",
1002
+ " return loss"
1003
+ ]
1004
+ },
1005
+ {
1006
+ "cell_type": "code",
1007
+ "execution_count": 20,
1008
+ "metadata": {},
1009
+ "outputs": [],
1010
+ "source": [
1011
+ "def DiceScore(\n",
1012
+ " y_pred: torch.Tensor,\n",
1013
+ " y: torch.Tensor,\n",
1014
+ " include_background: bool = True,\n",
1015
+ ") -> torch.Tensor:\n",
1016
+ " \"\"\"Computes Dice score metric from full size Tensor and collects average.\n",
1017
+ " Args:\n",
1018
+ " y_pred: input data to compute, typical segmentation model output.\n",
1019
+ " It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values\n",
1020
+ " should be binarized.\n",
1021
+ " y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.\n",
1022
+ " The values should be binarized.\n",
1023
+ " include_background: whether to skip Dice computation on the first channel of\n",
1024
+ " the predicted output. Defaults to True.\n",
1025
+ " Returns:\n",
1026
+ " Dice scores per batch and per class, (shape [batch_size, num_classes]).\n",
1027
+ " Raises:\n",
1028
+ " ValueError: when `y_pred` and `y` have different shapes.\n",
1029
+ " \"\"\"\n",
1030
+ "\n",
1031
+ " y = y.float()\n",
1032
+ " y_pred = y_pred.float()\n",
1033
+ "\n",
1034
+ " if y.shape != y_pred.shape:\n",
1035
+ " raise ValueError(\"y_pred and y should have same shapes.\")\n",
1036
+ "\n",
1037
+ " # reducing only spatial dimensions (not batch nor channels)\n",
1038
+ " n_len = len(y_pred.shape)\n",
1039
+ " reduce_axis = list(range(2, n_len))\n",
1040
+ " intersection = torch.sum(y * y_pred, dim=reduce_axis)\n",
1041
+ "\n",
1042
+ " y_o = torch.sum(y, reduce_axis)\n",
1043
+ " y_pred_o = torch.sum(y_pred, dim=reduce_axis)\n",
1044
+ " denominator = y_o + y_pred_o\n",
1045
+ "\n",
1046
+ " return torch.where(\n",
1047
+ " denominator > 0,\n",
1048
+ " (2.0 * intersection) / denominator,\n",
1049
+ " torch.tensor(float(\"1\"), device=y_o.device),\n",
1050
+ " )\n"
1051
+ ]
1052
+ },
1053
+ {
1054
+ "cell_type": "code",
1055
+ "execution_count": 21,
1056
+ "metadata": {},
1057
+ "outputs": [],
1058
+ "source": [
1059
+ "# Pytorch Lightning\n",
1060
+ "import pytorch_lightning as pl\n",
1061
+ "import matplotlib.pyplot as plt\n",
1062
+ "import csv\n",
1063
+ "from monai.transforms import AsDiscrete, Activations, Compose, EnsureType"
1064
+ ]
1065
+ },
1066
+ {
1067
+ "cell_type": "code",
1068
+ "execution_count": 24,
1069
+ "metadata": {},
1070
+ "outputs": [],
1071
+ "source": [
1072
+ "class BRATS(pl.LightningModule):\n",
1073
+ " def __init__(self, use_VAE = True, lr = 1e-4, ):\n",
1074
+ " super().__init__()\n",
1075
+ " \n",
1076
+ " self.use_vae = use_VAE\n",
1077
+ " self.lr = lr\n",
1078
+ " self.model = SegTransVAE((128, 128, 128), 8, 4, 3, 768, 8, 4, 3072, in_channels_vae=128, use_VAE = use_VAE)\n",
1079
+ "\n",
1080
+ " self.loss_vae = Loss_VAE()\n",
1081
+ " self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)\n",
1082
+ " self.post_trans_images = Compose(\n",
1083
+ " [EnsureType(),\n",
1084
+ " Activations(sigmoid=True), \n",
1085
+ " AsDiscrete(threshold_values=True), \n",
1086
+ " ]\n",
1087
+ " )\n",
1088
+ "\n",
1089
+ " self.best_val_dice = 0\n",
1090
+ " \n",
1091
+ " self.training_step_outputs = [] \n",
1092
+ " self.val_step_loss = [] \n",
1093
+ " self.val_step_dice = []\n",
1094
+ " self.val_step_dice_tc = [] \n",
1095
+ " self.val_step_dice_wt = []\n",
1096
+ " self.val_step_dice_et = [] \n",
1097
+ " self.test_step_loss = [] \n",
1098
+ " self.test_step_dice = []\n",
1099
+ " self.test_step_dice_tc = [] \n",
1100
+ " self.test_step_dice_wt = []\n",
1101
+ " self.test_step_dice_et = [] \n",
1102
+ "\n",
1103
+ " def forward(self, x, is_validation = True):\n",
1104
+ " return self.model(x, is_validation) \n",
1105
+ " def training_step(self, batch, batch_index):\n",
1106
+ " inputs, labels = (batch['image'], batch['label'])\n",
1107
+ " \n",
1108
+ " if not self.use_vae:\n",
1109
+ " outputs = self.forward(inputs, is_validation=False)\n",
1110
+ " loss = self.dice_loss(outputs, labels)\n",
1111
+ " else:\n",
1112
+ " outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)\n",
1113
+ " \n",
1114
+ " vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)\n",
1115
+ " dice_loss = self.dice_loss(outputs, labels)\n",
1116
+ " loss = dice_loss + 1/(4 * 128 * 128 * 128) * vae_loss\n",
1117
+ " self.training_step_outputs.append(loss)\n",
1118
+ " self.log('train/vae_loss', vae_loss)\n",
1119
+ " self.log('train/dice_loss', dice_loss)\n",
1120
+ " if batch_index == 10:\n",
1121
+ "\n",
1122
+ " tensorboard = self.logger.experiment \n",
1123
+ " fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(10, 5))\n",
1124
+ " \n",
1125
+ "\n",
1126
+ " ax[0].imshow(inputs.detach().cpu()[0][0][:, :, 80], cmap='gray')\n",
1127
+ " ax[0].set_title(\"Input\")\n",
1128
+ "\n",
1129
+ " ax[1].imshow(recon_batch.detach().cpu().float()[0][0][:,:, 80], cmap='gray')\n",
1130
+ " ax[1].set_title(\"Reconstruction\")\n",
1131
+ " \n",
1132
+ " ax[2].imshow(labels.detach().cpu().float()[0][0][:,:, 80], cmap='gray')\n",
1133
+ " ax[2].set_title(\"Labels TC\")\n",
1134
+ " \n",
1135
+ " ax[3].imshow(outputs.sigmoid().detach().cpu().float()[0][0][:,:, 80], cmap='gray')\n",
1136
+ " ax[3].set_title(\"TC\")\n",
1137
+ " \n",
1138
+ " ax[4].imshow(labels.detach().cpu().float()[0][2][:,:, 80], cmap='gray')\n",
1139
+ " ax[4].set_title(\"Labels ET\")\n",
1140
+ " \n",
1141
+ " ax[5].imshow(outputs.sigmoid().detach().cpu().float()[0][2][:,:, 80], cmap='gray')\n",
1142
+ " ax[5].set_title(\"ET\")\n",
1143
+ "\n",
1144
+ " \n",
1145
+ " tensorboard.add_figure('train_visualize', fig, self.current_epoch)\n",
1146
+ "\n",
1147
+ " self.log('train/loss', loss)\n",
1148
+ " \n",
1149
+ " return loss\n",
1150
+ " \n",
1151
+ " def on_train_epoch_end(self):\n",
1152
+ " ## F1 Macro all epoch saving outputs and target per batch\n",
1153
+ "\n",
1154
+ " # free up the memory\n",
1155
+ " # --> HERE STEP 3 <--\n",
1156
+ " epoch_average = torch.stack(self.training_step_outputs).mean()\n",
1157
+ " self.log(\"training_epoch_average\", epoch_average)\n",
1158
+ " self.training_step_outputs.clear() # free memory\n",
1159
+ "\n",
1160
+ " def validation_step(self, batch, batch_index):\n",
1161
+ " inputs, labels = (batch['image'], batch['label'])\n",
1162
+ " roi_size = (128, 128, 128)\n",
1163
+ " sw_batch_size = 1\n",
1164
+ " outputs = sliding_window_inference(\n",
1165
+ " inputs, roi_size, sw_batch_size, self.model, overlap = 0.5)\n",
1166
+ " loss = self.dice_loss(outputs, labels)\n",
1167
+ " \n",
1168
+ " \n",
1169
+ " val_outputs = self.post_trans_images(outputs)\n",
1170
+ " \n",
1171
+ " \n",
1172
+ " metric_tc = DiceScore(y_pred=val_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)\n",
1173
+ " metric_wt = DiceScore(y_pred=val_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)\n",
1174
+ " metric_et = DiceScore(y_pred=val_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)\n",
1175
+ " mean_val_dice = (metric_tc + metric_wt + metric_et)/3\n",
1176
+ " self.val_step_loss.append(loss) \n",
1177
+ " self.val_step_dice.append(mean_val_dice)\n",
1178
+ " self.val_step_dice_tc.append(metric_tc) \n",
1179
+ " self.val_step_dice_wt.append(metric_wt)\n",
1180
+ " self.val_step_dice_et.append(metric_et) \n",
1181
+ " return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,\n",
1182
+ " 'val_dice_wt': metric_wt, 'val_dice_et': metric_et}\n",
1183
+ " \n",
1184
+ " def on_validation_epoch_end(self):\n",
1185
+ "\n",
1186
+ " loss = torch.stack(self.val_step_loss).mean()\n",
1187
+ " mean_val_dice = torch.stack(self.val_step_dice).mean()\n",
1188
+ " metric_tc = torch.stack(self.val_step_dice_tc).mean()\n",
1189
+ " metric_wt = torch.stack(self.val_step_dice_wt).mean()\n",
1190
+ " metric_et = torch.stack(self.val_step_dice_et).mean()\n",
1191
+ " self.log('val/Loss', loss)\n",
1192
+ " self.log('val/MeanDiceScore', mean_val_dice)\n",
1193
+ " self.log('val/DiceTC', metric_tc)\n",
1194
+ " self.log('val/DiceWT', metric_wt)\n",
1195
+ " self.log('val/DiceET', metric_et)\n",
1196
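+ " # write a CSV header on the first epoch, then append one row of Dice metrics per epoch next to the TensorBoard logs\n",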
+ " os.makedirs(self.logger.log_dir, exist_ok=True)\n",
1197
+ " if self.current_epoch == 0:\n",
1198
+ " with open('{}/metric_log.csv'.format(self.logger.log_dir), 'w') as f:\n",
1199
+ " writer = csv.writer(f)\n",
1200
+ " writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])\n",
1201
+ " with open('{}/metric_log.csv'.format(self.logger.log_dir), 'a') as f:\n",
1202
+ " writer = csv.writer(f)\n",
1203
+ " writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(), metric_wt.item(), metric_et.item()])\n",
1204
+ "\n",
1205
+ " if mean_val_dice > self.best_val_dice:\n",
1206
+ " self.best_val_dice = mean_val_dice\n",
1207
+ " self.best_val_epoch = self.current_epoch\n",
1208
+ " print(\n",
1209
+ " f\"\\n Current epoch: {self.current_epoch} Current mean dice: {mean_val_dice:.4f}\"\n",
1210
+ " f\" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}\"\n",
1211
+ " f\"\\n Best mean dice: {self.best_val_dice}\"\n",
1212
+ " f\" at epoch: {self.best_val_epoch}\"\n",
1213
+ " )\n",
1214
+ " \n",
1215
+ " self.val_step_loss.clear() \n",
1216
+ " self.val_step_dice.clear()\n",
1217
+ " self.val_step_dice_tc.clear() \n",
1218
+ " self.val_step_dice_wt.clear()\n",
1219
+ " self.val_step_dice_et.clear()\n",
1220
+ " return {'val_MeanDiceScore': mean_val_dice}\n",
1221
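+ " # test_step mirrors validation_step: sliding-window inference followed by per-class Dice (TC, WT, ET)\n",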
+ " def test_step(self, batch, batch_index):\n",
1222
+ " inputs, labels = (batch['image'], batch['label'])\n",
1223
+ " \n",
1224
+ " roi_size = (128, 128, 128)\n",
1225
+ " sw_batch_size = 1\n",
1226
+ " test_outputs = sliding_window_inference(\n",
1227
+ " inputs, roi_size, sw_batch_size, self.forward, overlap = 0.5)\n",
1228
+ " loss = self.dice_loss(test_outputs, labels)\n",
1229
+ " test_outputs = self.post_trans_images(test_outputs)\n",
1230
+ " metric_tc = DiceScore(y_pred=test_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)\n",
1231
+ " metric_wt = DiceScore(y_pred=test_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)\n",
1232
+ " metric_et = DiceScore(y_pred=test_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)\n",
1233
+ " mean_test_dice = (metric_tc + metric_wt + metric_et)/3\n",
1234
+ " \n",
1235
+ " self.test_step_loss.append(loss) \n",
1236
+ " self.test_step_dice.append(mean_test_dice)\n",
1237
+ " self.test_step_dice_tc.append(metric_tc) \n",
1238
+ " self.test_step_dice_wt.append(metric_wt)\n",
1239
+ " self.test_step_dice_et.append(metric_et) \n",
1240
+ " \n",
1241
+ " return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,\n",
1242
+ " 'test_dice_wt': metric_wt, 'test_dice_et': metric_et}\n",
1243
+ " \n",
1244
+ " def test_epoch_end(self):\n",
1245
+ " loss = torch.stack(self.test_step_loss).mean()\n",
1246
+ " mean_test_dice = torch.stack(self.test_step_dice).mean()\n",
1247
+ " metric_tc = torch.stack(self.test_step_dice_tc).mean()\n",
1248
+ " metric_wt = torch.stack(self.test_step_dice_wt).mean()\n",
1249
+ " metric_et = torch.stack(self.test_step_dice_et).mean()\n",
1250
+ " self.log('test/Loss', loss)\n",
1251
+ " self.log('test/MeanDiceScore', mean_test_dice)\n",
1252
+ " self.log('test/DiceTC', metric_tc)\n",
1253
+ " self.log('test/DiceWT', metric_wt)\n",
1254
+ " self.log('test/DiceET', metric_et)\n",
1255
+ "\n",
1256
+ " with open('{}/test_log.csv'.format(self.logger.log_dir), 'w') as f:\n",
1257
+ " writer = csv.writer(f)\n",
1258
+ " writer.writerow([\"Mean Test Dice\", \"Dice TC\", \"Dice WT\", \"Dice ET\"])\n",
1259
+ " writer.writerow([mean_test_dice, metric_tc, metric_wt, metric_et])\n",
1260
+ "\n",
1261
+ " self.test_step_loss.clear() \n",
1262
+ " self.test_step_dice.clear()\n",
1263
+ " self.test_step_dice_tc.clear() \n",
1264
+ " self.test_step_dice_wt.clear()\n",
1265
+ " self.test_step_dice_et.clear()\n",
1266
+ " return {'test_MeanDiceScore': mean_test_dice}\n",
1267
+ " \n",
1268
+ " \n",
1269
+ " def configure_optimizers(self):\n",
1270
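+ " # Adam (weight decay 1e-5, AMSGrad) with cosine annealing over 200 epochs; the AdaBelief variant below is kept for reference\n",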
+ " optimizer = torch.optim.Adam(\n",
1271
+ " self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True\n",
1272
+ " )\n",
1273
+ "# optimizer = AdaBelief(self.model.parameters(), \n",
1274
+ "# lr=self.lr, eps=1e-16, \n",
1275
+ "# betas=(0.9,0.999), weight_decouple = True, \n",
1276
+ "# rectify = False)\n",
1277
+ " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)\n",
1278
+ " return [optimizer], [scheduler]\n",
1279
+ " \n",
1280
+ " def train_dataloader(self):\n",
1281
+ " return train_loader\n",
1282
+ " def val_dataloader(self):\n",
1283
+ " return val_loader\n",
1284
+ " \n",
1285
+ " def test_dataloader(self):\n",
1286
+ " return test_loader"
1287
+ ]
1288
+ },
1289
+ {
1290
+ "cell_type": "code",
1291
+ "execution_count": 1,
1292
+ "metadata": {},
1293
+ "outputs": [
1294
+ {
1295
+ "name": "stderr",
1296
+ "output_type": "stream",
1297
+ "text": [
1298
+ "/usr/local/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
1299
+ " from .autonotebook import tqdm as notebook_tqdm\n"
1300
+ ]
1301
+ }
1302
+ ],
1303
+ "source": [
1304
+ "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
1305
+ "import os \n",
1306
+ "from pytorch_lightning.loggers import TensorBoardLogger"
1307
+ ]
1308
+ },
1309
+ {
1310
+ "cell_type": "code",
1311
+ "execution_count": 25,
1312
+ "metadata": {},
1313
+ "outputs": [
1314
+ {
1315
+ "name": "stderr",
1316
+ "output_type": "stream",
1317
+ "text": [
1318
+ "sh: 1: cls: not found\n"
1319
+ ]
1320
+ },
1321
+ {
1322
+ "name": "stdout",
1323
+ "output_type": "stream",
1324
+ "text": [
1325
+ "\u001b[H\u001b[2JTraining ...\n"
1326
+ ]
1327
+ },
1328
+ {
1329
+ "name": "stderr",
1330
+ "output_type": "stream",
1331
+ "text": [
1332
+ "/usr/local/lib/python3.9/site-packages/lightning_fabric/connector.py:563: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!\n",
1333
+ "Using 16bit Automatic Mixed Precision (AMP)\n",
1334
+ "GPU available: True (cuda), used: True\n",
1335
+ "TPU available: False, using: 0 TPU cores\n",
1336
+ "IPU available: False, using: 0 IPUs\n",
1337
+ "HPU available: False, using: 0 HPUs\n",
1338
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
1339
+ "\n",
1340
+ " | Name | Type | Params\n",
1341
+ "------------------------------------------\n",
1342
+ "0 | model | SegTransVAE | 44.7 M\n",
1343
+ "1 | loss_vae | Loss_VAE | 0 \n",
1344
+ "2 | dice_loss | DiceLoss | 0 \n",
1345
+ "------------------------------------------\n",
1346
+ "44.7 M Trainable params\n",
1347
+ "0 Non-trainable params\n",
1348
+ "44.7 M Total params\n",
1349
+ "178.908 Total estimated model params size (MB)\n"
1350
+ ]
1351
+ },
1352
+ {
1353
+ "name": "stdout",
1354
+ "output_type": "stream",
1355
+ "text": [
1356
+ "Sanity Checking DataLoader 0: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2/2 [00:05<00:00, 0.37it/s]\n",
1357
+ " Current epoch: 0 Current mean dice: 0.0097 tc: 0.0029 wt: 0.0234 et: 0.0028\n",
1358
+ " Best mean dice: 0.009687595069408417 at epoch: 0\n",
1359
+ "Epoch 0: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 500/500 [05:38<00:00, 1.48it/s, v_num=6] \n",
1360
+ " Current epoch: 0 Current mean dice: 0.1927 tc: 0.1647 wt: 0.2843 et: 0.1290\n",
1361
+ " Best mean dice: 0.1926589012145996 at epoch: 0\n",
1362
+ "Epoch 1: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 500/500 [07:35<00:00, 1.10it/s, v_num=6]\n",
1363
+ " Current epoch: 1 Current mean dice: 0.3212 tc: 0.2691 wt: 0.4253 et: 0.2692\n",
1364
+ " Best mean dice: 0.32120221853256226 at epoch: 1\n",
1365
+ "Epoch 2: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 500/500 [08:11<00:00, 1.02it/s, v_num=6]\n",
1366
+ " Current epoch: 2 Current mean dice: 0.3912 tc: 0.3510 wt: 0.5087 et: 0.3137\n",
1367
+ " Best mean dice: 0.39115065336227417 at epoch: 2\n",
1368
+ "Epoch 3: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 500/500 [08:58<00:00, 0.93it/s, v_num=6]\n",
1369
+ " Current epoch: 3 Current mean dice: 0.4268 tc: 0.3828 wt: 0.5424 et: 0.3553\n",
1370
+ " Best mean dice: 0.42682838439941406 at epoch: 3\n",
1371
+ "Epoch 4: 41%|β–ˆβ–ˆβ–ˆβ–ˆβ– | 207/500 [02:51<04:03, 1.21it/s, v_num=6]"
1372
+ ]
1373
+ },
1374
+ {
1375
+ "ename": "",
1376
+ "evalue": "",
1377
+ "output_type": "error",
1378
+ "traceback": [
1379
+ "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
1380
+ "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
1381
+ "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
1382
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
1383
+ ]
1384
+ }
1385
+ ],
1386
+ "source": [
1387
+ "os.system('cls||clear')\n",
1388
+ "print(\"Training ...\")\n",
1389
+ "model = BRATS(use_VAE = True)\n",
1390
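+ "# keep the 3 best checkpoints ranked by validation mean Dice and always save the last epoch\n",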
+ "checkpoint_callback = ModelCheckpoint(\n",
1391
+ " monitor='val/MeanDiceScore',\n",
1392
+ " dirpath='./app/checkpoints/{}'.format(1),\n",
1393
+ " filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',\n",
1394
+ " save_top_k=3,\n",
1395
+ " mode='max',\n",
1396
+ " save_last= True,\n",
1397
+ " auto_insert_metric_name=False\n",
1398
+ ")\n",
1399
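+ "# stop training if validation mean Dice has not improved by at least 1e-4 for 15 epochs\n",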
+ "early_stop_callback = EarlyStopping(\n",
1400
+ " monitor='val/MeanDiceScore',\n",
1401
+ " min_delta=0.0001,\n",
1402
+ " patience=15,\n",
1403
+ " verbose=False,\n",
1404
+ " mode='max'\n",
1405
+ ")\n",
1406
+ "tensorboardlogger = TensorBoardLogger(\n",
1407
+ " 'logs', \n",
1408
+ " name = \"1\", \n",
1409
+ " default_hp_metric = None \n",
1410
+ ")\n",
1411
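+ "# single-GPU mixed-precision run capped at 200 epochs; the commented-out arguments are handy for quick debugging\n",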
+ "trainer = pl.Trainer(#fast_dev_run = 10, \n",
1412
+ "# accelerator='ddp',\n",
1413
+ " #overfit_batches=5,\n",
1414
+ " devices = [0], \n",
1415
+ " precision=16,\n",
1416
+ " max_epochs = 200, \n",
1417
+ " enable_progress_bar=True, \n",
1418
+ " callbacks=[checkpoint_callback, early_stop_callback], \n",
1419
+ "# auto_lr_find=True,\n",
1420
+ " num_sanity_val_steps=2,\n",
1421
+ " logger = tensorboardlogger,\n",
1422
+ "# limit_train_batches=0.01, \n",
1423
+ "# limit_val_batches=0.01\n",
1424
+ " )\n",
1425
+ "# trainer.tune(model)\n",
1426
+ "trainer.fit(model)\n",
1427
+ "\n",
1428
+ "\n",
1429
+ "\n"
1430
+ ]
1431
+ },
1432
+ {
1433
+ "cell_type": "code",
1434
+ "execution_count": null,
1435
+ "metadata": {},
1436
+ "outputs": [],
1437
+ "source": [
1438
+ "import pytorch_lightning as pl\n",
1439
+ "from trainer import BRATS\n",
1440
+ "import os \n",
1441
+ "import torch\n",
1442
+ "os.system('cls||clear')\n",
1443
+ "print(\"Testing ...\")\n",
1444
+ "\n",
1445
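+ "# path to a checkpoint written by the ModelCheckpoint above, e.g. './app/checkpoints/1/last.ckpt' (saved because save_last=True)\n",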
+ "CKPT = ''\n",
1446
+ "model = BRATS(use_VAE=True).load_from_checkpoint(CKPT).eval()\n",
1447
+ "val_dataloader = get_val_dataloader()\n",
1448
+ "test_dataloader = get_test_dataloader()\n",
1449
+ "trainer = pl.Trainer(gpus = [0], precision=32, progress_bar_refresh_rate=10)\n",
1450
+ "\n",
1451
+ "trainer.test(model, dataloaders = val_dataloader)\n",
1452
+ "trainer.test(model, dataloaders = test_dataloader)\n",
1453
+ "\n"
1454
+ ]
1455
+ }
1456
+ ],
1457
+ "metadata": {
1458
+ "kernelspec": {
1459
+ "display_name": "Python 3 (ipykernel)",
1460
+ "language": "python",
1461
+ "name": "python3"
1462
+ },
1463
+ "language_info": {
1464
+ "codemirror_mode": {
1465
+ "name": "ipython",
1466
+ "version": 3
1467
+ },
1468
+ "file_extension": ".py",
1469
+ "mimetype": "text/x-python",
1470
+ "name": "python",
1471
+ "nbconvert_exporter": "python",
1472
+ "pygments_lexer": "ipython3",
1473
+ "version": "3.9.18"
1474
+ }
1475
+ },
1476
+ "nbformat": 4,
1477
+ "nbformat_minor": 2
1478
+ }