improved argparse stuff
Browse files- .gitignore +4 -0
- livermask/livermask.py +96 -78
- setup.py +2 -2
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
venv/
|
2 |
+
build/
|
3 |
+
dist/
|
4 |
+
livermask.egg-info/
|
livermask/livermask.py
CHANGED
@@ -10,120 +10,138 @@ import gdown
|
|
10 |
from skimage.morphology import remove_small_holes, binary_dilation, binary_erosion, ball
|
11 |
from skimage.measure import label, regionprops
|
12 |
import warnings
|
|
|
|
|
|
|
|
|
|
|
13 |
warnings.filterwarnings('ignore', '.*output shape of zoom.*')
|
14 |
|
15 |
|
16 |
def intensity_normalization(volume, intensity_clipping_range):
|
17 |
-
|
18 |
|
19 |
-
|
20 |
-
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
|
27 |
-
|
28 |
|
29 |
def post_process(pred):
|
30 |
-
|
31 |
|
32 |
def get_model():
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
|
38 |
def func(path, output):
|
39 |
|
40 |
-
|
41 |
|
42 |
-
|
43 |
|
44 |
-
|
45 |
-
|
46 |
|
47 |
-
|
48 |
-
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
|
56 |
-
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
|
77 |
-
|
78 |
-
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
|
106 |
-
|
107 |
-
|
108 |
|
109 |
-
|
110 |
-
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
|
118 |
|
119 |
def main():
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
-
|
123 |
-
|
|
|
124 |
|
125 |
-
|
126 |
|
127 |
|
128 |
if __name__ == "__main__":
|
129 |
-
|
|
|
10 |
from skimage.morphology import remove_small_holes, binary_dilation, binary_erosion, ball
|
11 |
from skimage.measure import label, regionprops
|
12 |
import warnings
|
13 |
+
import argparse
|
14 |
+
import pkg_resources
|
15 |
+
|
16 |
+
|
17 |
+
# mute some warnings
|
18 |
warnings.filterwarnings('ignore', '.*output shape of zoom.*')
|
19 |
|
20 |
|
21 |
def intensity_normalization(volume, intensity_clipping_range):
|
22 |
+
result = np.copy(volume)
|
23 |
|
24 |
+
result[volume < intensity_clipping_range[0]] = intensity_clipping_range[0]
|
25 |
+
result[volume > intensity_clipping_range[1]] = intensity_clipping_range[1]
|
26 |
|
27 |
+
min_val = np.amin(result)
|
28 |
+
max_val = np.amax(result)
|
29 |
+
if (max_val - min_val) != 0:
|
30 |
+
result = (result - min_val) / (max_val - min_val)
|
31 |
|
32 |
+
return result
|
33 |
|
34 |
def post_process(pred):
|
35 |
+
return pred
|
36 |
|
37 |
def get_model():
|
38 |
+
url = "https://drive.google.com/uc?id=12or5Q79at2BtLgQ7IaglNGPFGRlEgEHc"
|
39 |
+
output = "model.h5"
|
40 |
+
md5 = "ef5a6dfb794b39bea03f5496a9a49d4d"
|
41 |
+
gdown.cached_download(url, output, md5=md5) #, postprocess=gdown.extractall)
|
42 |
|
43 |
def func(path, output):
|
44 |
|
45 |
+
cwd = "/".join(os.path.realpath(__file__).replace("\\", "/").split("/")[:-1]) + "/"
|
46 |
|
47 |
+
name = cwd + "model.h5"
|
48 |
|
49 |
+
# get model
|
50 |
+
get_model()
|
51 |
|
52 |
+
# load model
|
53 |
+
model = load_model(name, compile=False)
|
54 |
|
55 |
+
print("preprocessing...")
|
56 |
+
nib_volume = nib.load(path)
|
57 |
+
new_spacing = [1., 1., 1.]
|
58 |
+
resampled_volume = resample_to_output(nib_volume, new_spacing, order=1)
|
59 |
+
data = resampled_volume.get_data().astype('float32')
|
60 |
|
61 |
+
curr_shape = data.shape
|
62 |
|
63 |
+
# resize to get (512, 512) output images
|
64 |
+
img_size = 512
|
65 |
+
data = zoom(data, [img_size / data.shape[0], img_size / data.shape[1], 1.0], order=1)
|
66 |
|
67 |
+
# intensity normalization
|
68 |
+
intensity_clipping_range = [-150, 250] # HU clipping limits (Pravdaray's configs)
|
69 |
+
data = intensity_normalization(volume=data, intensity_clipping_range=intensity_clipping_range)
|
70 |
|
71 |
+
# fix orientation
|
72 |
+
data = np.rot90(data, k=1, axes=(0, 1))
|
73 |
+
data = np.flip(data, axis=0)
|
74 |
|
75 |
+
print("predicting...")
|
76 |
+
# predict on data
|
77 |
+
pred = np.zeros_like(data).astype(np.float32)
|
78 |
+
for i in tqdm(range(data.shape[-1]), "pred: "):
|
79 |
+
pred[..., i] = model.predict(np.expand_dims(np.expand_dims(np.expand_dims(data[..., i], axis=0), axis=-1), axis=0))[0, ..., 1]
|
80 |
+
del data
|
81 |
|
82 |
+
# threshold
|
83 |
+
pred = (pred >= 0.4).astype(int)
|
84 |
|
85 |
+
# fix orientation back
|
86 |
+
pred = np.flip(pred, axis=0)
|
87 |
+
pred = np.rot90(pred, k=-1, axes=(0, 1))
|
88 |
|
89 |
+
print("resize back...")
|
90 |
+
# resize back from 512x512
|
91 |
+
pred = zoom(pred, [curr_shape[0] / img_size, curr_shape[1] / img_size, 1.0], order=1)
|
92 |
+
pred = (pred >= 0.5).astype(np.float32)
|
93 |
|
94 |
+
print("morphological post-processing...")
|
95 |
+
# morpological post-processing
|
96 |
+
# 1) first erode
|
97 |
+
pred = binary_erosion(pred.astype(bool), ball(3)).astype(np.float32)
|
98 |
|
99 |
+
# 2) keep only largest connected component
|
100 |
+
labels = label(pred)
|
101 |
+
regions = regionprops(labels)
|
102 |
+
area_sizes = []
|
103 |
+
for region in regions:
|
104 |
+
area_sizes.append([region.label, region.area])
|
105 |
+
area_sizes = np.array(area_sizes)
|
106 |
+
tmp = np.zeros_like(pred)
|
107 |
+
tmp[labels == area_sizes[np.argmax(area_sizes[:, 1]), 0]] = 1
|
108 |
+
pred = tmp.copy()
|
109 |
+
del tmp, labels, regions, area_sizes
|
110 |
|
111 |
+
# 3) dilate
|
112 |
+
pred = binary_dilation(pred.astype(bool), ball(3))
|
113 |
|
114 |
+
# 4) remove small holes
|
115 |
+
pred = remove_small_holes(pred.astype(bool), area_threshold=0.001*np.prod(pred.shape)).astype(np.float32)
|
116 |
|
117 |
+
print("saving...")
|
118 |
+
pred = pred.astype(np.uint8)
|
119 |
+
img = nib.Nifti1Image(pred, affine=resampled_volume.affine)
|
120 |
+
resampled_lab = resample_from_to(img, nib_volume, order=0)
|
121 |
+
nib.save(resampled_lab, output)
|
122 |
|
123 |
|
124 |
def main():
|
125 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # disable GPU
|
126 |
+
version = pkg_resources.require("livermask")[0].version
|
127 |
+
|
128 |
+
parser = argparse.ArgumentParser()
|
129 |
+
parser.add_argument('--input', metavar='--i', type=str, nargs='?',
|
130 |
+
help="set path of which image to use.")
|
131 |
+
parser.add_argument('--output', metavar='--o', type=str, nargs='?',
|
132 |
+
help="set path to store the output.")
|
133 |
+
parser.add_argument('--cpu', metavar='--o', action='store_true',
|
134 |
+
help="force using the CPU even if a GPU is available.")
|
135 |
+
parser.add_argument('--version', metavar='--v',
|
136 |
+
help='shows the current version of livermask.', version=version)
|
137 |
+
ret = parser.parse_args(sys.argv[1:]); print(ret)
|
138 |
|
139 |
+
# fix paths
|
140 |
+
ret.input = ret.input.replace("\\", "/")
|
141 |
+
ret.output = ret.output.replace("\\", "/")
|
142 |
|
143 |
+
func(*vars(ret).values())
|
144 |
|
145 |
|
146 |
if __name__ == "__main__":
|
147 |
+
main()
|
setup.py
CHANGED
@@ -20,8 +20,8 @@ setuptools.setup(
|
|
20 |
]
|
21 |
},
|
22 |
install_requires=[
|
23 |
-
'
|
24 |
-
'
|
25 |
'scipy',
|
26 |
'tqdm',
|
27 |
'nibabel',
|
|
|
20 |
]
|
21 |
},
|
22 |
install_requires=[
|
23 |
+
'numpy'
|
24 |
+
'tensorflow==2.6',
|
25 |
'scipy',
|
26 |
'tqdm',
|
27 |
'nibabel',
|