added batch mode. Can now provide directory of CTs/.nii files to be predicted on. Also refactored build workflow
Browse files- .github/workflows/build.yaml +1 -4
- livermask/livermask.py +99 -74
.github/workflows/build.yaml
CHANGED
@@ -34,10 +34,7 @@ jobs:
|
|
34 |
run: pip install wheel setuptools
|
35 |
|
36 |
- name: Build wheel
|
37 |
-
run:
|
38 |
-
python setup.py bdist_wheel --universal
|
39 |
-
cd dist
|
40 |
-
ls
|
41 |
|
42 |
- name: Install program for ${{matrix.TARGET}}
|
43 |
run: ${{matrix.CMD_BUILD}}
|
|
|
34 |
run: pip install wheel setuptools
|
35 |
|
36 |
- name: Build wheel
|
37 |
+
run: python setup.py bdist_wheel --universal
|
|
|
|
|
|
|
38 |
|
39 |
- name: Install program for ${{matrix.TARGET}}
|
40 |
run: ${{matrix.CMD_BUILD}}
|
livermask/livermask.py
CHANGED
@@ -12,6 +12,7 @@ import warnings
|
|
12 |
import argparse
|
13 |
import pkg_resources
|
14 |
import tensorflow as tf
|
|
|
15 |
|
16 |
|
17 |
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # due to this: https://github.com/tensorflow/tensorflow/issues/35029
|
@@ -38,7 +39,17 @@ def get_model(output):
|
|
38 |
md5 = "ef5a6dfb794b39bea03f5496a9a49d4d"
|
39 |
gdown.cached_download(url, output, md5=md5) #, postprocess=gdown.extractall)
|
40 |
|
41 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
cwd = "/".join(os.path.realpath(__file__).replace("\\", "/").split("/")[:-1]) + "/"
|
43 |
name = cwd + "model.h5"
|
44 |
|
@@ -48,83 +59,97 @@ def func(path, output, cpu):
|
|
48 |
# load model
|
49 |
model = load_model(name, compile=False)
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
def main():
|
121 |
parser = argparse.ArgumentParser()
|
122 |
parser.add_argument('--input', metavar='--i', type=str, nargs='?',
|
123 |
-
help="set path of which image to use.")
|
124 |
parser.add_argument('--output', metavar='--o', type=str, nargs='?',
|
125 |
help="set path to store the output.")
|
126 |
parser.add_argument('--cpu', action='store_true',
|
127 |
help="force using the CPU even if a GPU is available.")
|
|
|
|
|
128 |
ret = parser.parse_args(sys.argv[1:]); print(ret)
|
129 |
|
130 |
if ret.cpu:
|
@@ -148,10 +173,10 @@ def main():
|
|
148 |
raise ValueError("Please, provide an input.")
|
149 |
if ret.output is None:
|
150 |
raise ValueError("Please, provide an output.")
|
151 |
-
if not ret.input.endswith(".nii"):
|
152 |
-
raise ValueError("
|
153 |
-
if
|
154 |
-
raise ValueError("Output
|
155 |
|
156 |
# fix paths
|
157 |
ret.input = ret.input.replace("\\", "/")
|
|
|
12 |
import argparse
|
13 |
import pkg_resources
|
14 |
import tensorflow as tf
|
15 |
+
import logging as log
|
16 |
|
17 |
|
18 |
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # due to this: https://github.com/tensorflow/tensorflow/issues/35029
|
|
|
39 |
md5 = "ef5a6dfb794b39bea03f5496a9a49d4d"
|
40 |
gdown.cached_download(url, output, md5=md5) #, postprocess=gdown.extractall)
|
41 |
|
42 |
+
def verboseHandler(verbose):
|
43 |
+
if verbose:
|
44 |
+
log.basicConfig(format="%(levelname)s: %(message)s", level=log.DEBUG)
|
45 |
+
log.info("Verbose output.")
|
46 |
+
else:
|
47 |
+
log.basicConfig(format="%(levelname)s: %(message)s")
|
48 |
+
|
49 |
+
def func(path, output, cpu, verbose):
|
50 |
+
# enable verbose or not
|
51 |
+
verboseHandler(verbose)
|
52 |
+
|
53 |
cwd = "/".join(os.path.realpath(__file__).replace("\\", "/").split("/")[:-1]) + "/"
|
54 |
name = cwd + "model.h5"
|
55 |
|
|
|
59 |
# load model
|
60 |
model = load_model(name, compile=False)
|
61 |
|
62 |
+
if not os.path.isdir(path):
|
63 |
+
paths = [path]
|
64 |
+
else:
|
65 |
+
paths = [path + "/" + p for p in os.listdir(path)]
|
66 |
+
|
67 |
+
multiple_flag = len(paths) > 1
|
68 |
+
if multiple_flag:
|
69 |
+
os.makedirs(output + "/", exist_ok=True)
|
70 |
+
|
71 |
+
for curr in tqdm(paths, "CT:"):
|
72 |
+
log.info("preprocessing...")
|
73 |
+
nib_volume = nib.load(curr)
|
74 |
+
new_spacing = [1., 1., 1.]
|
75 |
+
resampled_volume = resample_to_output(nib_volume, new_spacing, order=1)
|
76 |
+
data = resampled_volume.get_data().astype('float32')
|
77 |
+
|
78 |
+
curr_shape = data.shape
|
79 |
+
|
80 |
+
# resize to get (512, 512) output images
|
81 |
+
img_size = 512
|
82 |
+
data = zoom(data, [img_size / data.shape[0], img_size / data.shape[1], 1.0], order=1)
|
83 |
+
|
84 |
+
# intensity normalization
|
85 |
+
intensity_clipping_range = [-150, 250] # HU clipping limits (Pravdaray's configs)
|
86 |
+
data = intensity_normalization(volume=data, intensity_clipping_range=intensity_clipping_range)
|
87 |
+
|
88 |
+
# fix orientation
|
89 |
+
data = np.rot90(data, k=1, axes=(0, 1))
|
90 |
+
data = np.flip(data, axis=0)
|
91 |
+
|
92 |
+
log.info("predicting...")
|
93 |
+
# predict on data
|
94 |
+
pred = np.zeros_like(data).astype(np.float32)
|
95 |
+
for i in tqdm(range(data.shape[-1]), "pred: ", disable=not verbose):
|
96 |
+
pred[..., i] = model.predict(np.expand_dims(np.expand_dims(np.expand_dims(data[..., i], axis=0), axis=-1), axis=0))[0, ..., 1]
|
97 |
+
del data
|
98 |
+
|
99 |
+
# threshold
|
100 |
+
pred = (pred >= 0.4).astype(int)
|
101 |
+
|
102 |
+
# fix orientation back
|
103 |
+
pred = np.flip(pred, axis=0)
|
104 |
+
pred = np.rot90(pred, k=-1, axes=(0, 1))
|
105 |
+
|
106 |
+
log.info("resize back...")
|
107 |
+
# resize back from 512x512
|
108 |
+
pred = zoom(pred, [curr_shape[0] / img_size, curr_shape[1] / img_size, 1.0], order=1)
|
109 |
+
pred = (pred >= 0.5).astype(np.float32)
|
110 |
+
|
111 |
+
log.info("morphological post-processing...")
|
112 |
+
# morpological post-processing
|
113 |
+
# 1) first erode
|
114 |
+
pred = binary_erosion(pred.astype(bool), ball(3)).astype(np.float32)
|
115 |
+
|
116 |
+
# 2) keep only largest connected component
|
117 |
+
labels = label(pred)
|
118 |
+
regions = regionprops(labels)
|
119 |
+
area_sizes = []
|
120 |
+
for region in regions:
|
121 |
+
area_sizes.append([region.label, region.area])
|
122 |
+
area_sizes = np.array(area_sizes)
|
123 |
+
tmp = np.zeros_like(pred)
|
124 |
+
tmp[labels == area_sizes[np.argmax(area_sizes[:, 1]), 0]] = 1
|
125 |
+
pred = tmp.copy()
|
126 |
+
del tmp, labels, regions, area_sizes
|
127 |
+
|
128 |
+
# 3) dilate
|
129 |
+
pred = binary_dilation(pred.astype(bool), ball(3))
|
130 |
+
|
131 |
+
# 4) remove small holes
|
132 |
+
pred = remove_small_holes(pred.astype(bool), area_threshold=0.001*np.prod(pred.shape)).astype(np.float32)
|
133 |
+
|
134 |
+
log.info("saving...")
|
135 |
+
pred = pred.astype(np.uint8)
|
136 |
+
img = nib.Nifti1Image(pred, affine=resampled_volume.affine)
|
137 |
+
resampled_lab = resample_from_to(img, nib_volume, order=0)
|
138 |
+
if multiple_flag:
|
139 |
+
nib.save(resampled_lab, output + "/" + curr.split("/")[-1].split(".")[0] + "-livermask" + ".nii")
|
140 |
+
else:
|
141 |
+
nib.save(resampled_lab, output + ".nii")
|
142 |
|
143 |
def main():
|
144 |
parser = argparse.ArgumentParser()
|
145 |
parser.add_argument('--input', metavar='--i', type=str, nargs='?',
|
146 |
+
help="set path of which image(s) to use.")
|
147 |
parser.add_argument('--output', metavar='--o', type=str, nargs='?',
|
148 |
help="set path to store the output.")
|
149 |
parser.add_argument('--cpu', action='store_true',
|
150 |
help="force using the CPU even if a GPU is available.")
|
151 |
+
parser.add_argument('--verbose', action='store_true',
|
152 |
+
help="enable verbose.")
|
153 |
ret = parser.parse_args(sys.argv[1:]); print(ret)
|
154 |
|
155 |
if ret.cpu:
|
|
|
173 |
raise ValueError("Please, provide an input.")
|
174 |
if ret.output is None:
|
175 |
raise ValueError("Please, provide an output.")
|
176 |
+
if not os.path.isdir(ret.input) and not ret.input.endswith(".nii"):
|
177 |
+
raise ValueError("Input path provided is not in the supported '.nii' format or a directory.")
|
178 |
+
if ret.output.endswith(".nii") or not os.path.isdir(ret.output) or "." in ret.output.split("/")[-1]:
|
179 |
+
raise ValueError("Output path provided is not a directory or a name (remove *.nii format from name).")
|
180 |
|
181 |
# fix paths
|
182 |
ret.input = ret.input.replace("\\", "/")
|