farrell236 commited on
Commit
2ef318a
·
1 Parent(s): 273da2c

Upload app.py

Browse files
Files changed (2) hide show
  1. app.py +151 -0
  2. utils.py +137 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+ import numpy as np
4
+ import streamlit as st
5
+ import tensorflow as tf
6
+
7
+ from utils import _get_retina_bb, _pad_to_square
8
+
9
+
10
+ @st.cache(allow_output_mutation=True)
11
+ def load_model(model_file):
12
+ model = tf.keras.models.load_model(model_file, compile=False)
13
+ print(f'Model {model_file} Loaded!')
14
+ return model
15
+
16
+
17
+ @st.cache(allow_output_mutation=True)
18
+ def load_gatekeeper():
19
+ validator_model = tf.keras.models.load_model('checkpoints/ResNetV2-EyeQ-QA.tf')
20
+ print('Gatekeeper Model Loaded!')
21
+ return validator_model
22
+
23
+
24
+ def parse_function(image):
25
+ image = tf.image.resize(image, [512, 512])
26
+ image = tf.image.convert_image_dtype(image, tf.float32)
27
+ return image
28
+
29
+
30
+ def main():
31
+
32
+ st.title('Retina Segmentation')
33
+
34
+ st.sidebar.title('Segmentation Model')
35
+ options = st.sidebar.selectbox('Select Option:', ('Vessels', 'Lesions (BETA)'))
36
+ gatekeeper = st.sidebar.radio("Gatekeeper:", ('Enabled', 'Disabled'))
37
+
38
+ gatekeeper_model = load_gatekeeper()
39
+
40
+ if options == 'Vessels':
41
+
42
+ st.set_option('deprecation.showfileUploaderEncoding', False)
43
+ uploaded_file = st.file_uploader('Choose an image...', type=('png', 'jpg', 'jpeg'))
44
+
45
+ model = load_model('checkpoints/DeeplabV3Plus_DRIVE.tf')
46
+
47
+ if uploaded_file:
48
+ col1, col2 = st.columns(2)
49
+
50
+ # Load Image
51
+ file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
52
+ image = cv2.imdecode(file_bytes, 1)
53
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
54
+
55
+ # Check image
56
+ valid = np.argmax(gatekeeper_model(parse_function(image[None, ...])))
57
+ if valid == 2 and gatekeeper == 'Enabled':
58
+ st.image(image)
59
+ st.info('Image is of poor quality')
60
+ return
61
+
62
+ # Localise and center retina image
63
+ x, y, w, h, _ = _get_retina_bb(image)
64
+ image = image[y:y + h, x:x + w, :]
65
+ image = _pad_to_square(image, border=0)
66
+ image = cv2.resize(image, (1024, 1024))
67
+
68
+ with col1:
69
+ st.subheader("Uploaded Image")
70
+ st.image(image)
71
+
72
+ # Apply CLAHE pre-processing
73
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(16, 16))
74
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
75
+ image[:, :, 0] = clahe.apply(image[:, :, 0])
76
+ image = cv2.cvtColor(image, cv2.COLOR_LAB2RGB)
77
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
78
+ image = tf.image.convert_image_dtype(image, tf.float32)
79
+
80
+ # Run model on input
81
+ y_pred = model(image[None, ..., None])[0].numpy()
82
+
83
+ with col2:
84
+ st.subheader("Predicted Vessel")
85
+ st.image(y_pred)
86
+
87
+ elif options == 'Lesions (BETA)':
88
+
89
+ st.write('```--- WARNING: This model is highly experimental ---```')
90
+
91
+ st.set_option('deprecation.showfileUploaderEncoding', False)
92
+ uploaded_file = st.file_uploader('Choose an image...', type=('png', 'jpg', 'jpeg'))
93
+
94
+ model = load_model('checkpoints/DeeplabV3Plus_FGADR.tf')
95
+
96
+ if uploaded_file:
97
+ col1, col2, col3, = st.columns(3)
98
+
99
+ # Load Image
100
+ file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
101
+ image = cv2.imdecode(file_bytes, 1)
102
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
103
+
104
+ # Check image
105
+ valid = np.argmax(gatekeeper_model(parse_function(image[None, ...])))
106
+ if valid == 2 and gatekeeper == 'Enabled':
107
+ st.image(image)
108
+ st.info('Image is of poor quality')
109
+ return
110
+
111
+ # Localise and center retina image
112
+ x, y, w, h, _ = _get_retina_bb(image)
113
+ image = image[y:y + h, x:x + w, :]
114
+ image = _pad_to_square(image, border=0)
115
+ image = cv2.resize(image, (1024, 1024))
116
+
117
+ with col1:
118
+ st.subheader("Uploaded Image")
119
+ st.image(image)
120
+
121
+ # Apply CLAHE pre-processing
122
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(16, 16))
123
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
124
+ image[:, :, 0] = clahe.apply(image[:, :, 0])
125
+ image = cv2.cvtColor(image, cv2.COLOR_LAB2RGB)
126
+ image = tf.image.convert_image_dtype(image, tf.float32)
127
+
128
+ # Run model on input
129
+ y_pred = model(image[None, ..., None])[0].numpy()
130
+
131
+ with col2:
132
+ st.subheader(f'MA')
133
+ st.image(y_pred[..., 1])
134
+ with col3:
135
+ st.subheader(f'HE')
136
+ st.image(y_pred[..., 2])
137
+ with col1:
138
+ st.subheader(f'EX')
139
+ st.image(y_pred[..., 3])
140
+ with col2:
141
+ st.subheader(f'SE')
142
+ st.image(y_pred[..., 4])
143
+ with col3:
144
+ st.subheader(f'OD')
145
+ st.image(y_pred[..., 5])
146
+
147
+ if __name__ == '__main__':
148
+
149
+ tf.config.set_visible_devices([], 'GPU')
150
+
151
+ main()
utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+
6
+
7
+ def _pad_to_square(image, long_side=None, border=0):
8
+ h, w, _ = image.shape
9
+
10
+ if long_side == None: long_side = max(h, w)
11
+
12
+ l_pad = (long_side - w) // 2 + border
13
+ r_pad = (long_side - w) // 2 + border
14
+ t_pad = (long_side - h) // 2 + border
15
+ b_pad = (long_side - h) // 2 + border
16
+ if w % 2 != 0: r_pad = r_pad + 1
17
+ if h % 2 != 0: b_pad = b_pad + 1
18
+
19
+ image = np.pad(
20
+ image,
21
+ ((t_pad, b_pad),
22
+ (l_pad, r_pad),
23
+ (0, 0)),
24
+ 'constant')
25
+
26
+ return image
27
+
28
+
29
+ def _get_retina_bb(image):
30
+
31
+ # make image greyscale and normalise
32
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
33
+ image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX)
34
+
35
+ # calculate threshold perform threshold
36
+ threshold = np.mean(image)/3-7
37
+ ret, thresh = cv2.threshold(image, max(0, threshold), 255, cv2.THRESH_BINARY)
38
+
39
+ # median filter, erode and dilate to remove noise and holes
40
+ thresh = cv2.medianBlur(thresh, 25)
41
+ thresh = cv2.erode(thresh, None, iterations=2)
42
+ thresh = cv2.dilate(thresh, None, iterations=2)
43
+
44
+ # find mask contour
45
+ cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
46
+ cnts = cnts[0] if len(cnts) == 2 else cnts[1]
47
+ c = max(cnts, key=cv2.contourArea)
48
+
49
+ # Get bounding box from mask contour
50
+ x, y, w, h = cv2.boundingRect(c)
51
+
52
+ # Get mask from contour
53
+ mask = np.zeros_like(image)
54
+ cv2.drawContours(mask, [c], -1, (255, 255, 255), -1)
55
+
56
+ return x, y, w, h, mask
57
+
58
+
59
+ def _get_retina_bb2(image, skips=4):
60
+ '''
61
+ Experimental Retina Bounding Box detector based on Convexity Defect Points
62
+ '''
63
+
64
+ # make image greyscale and normalise
65
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
66
+ image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX)
67
+
68
+ # calculate threshold perform threshold
69
+ threshold = np.mean(image)/3-7
70
+ ret, thresh = cv2.threshold(image, max(0, threshold), 255, cv2.THRESH_BINARY)
71
+
72
+ # median filter, erode and dilate to remove noise and holes
73
+ thresh = cv2.medianBlur(thresh, 25)
74
+ thresh = cv2.erode(thresh, None, iterations=2)
75
+ thresh = cv2.dilate(thresh, None, iterations=2)
76
+
77
+ # Find contours
78
+ contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
79
+
80
+ # Find the index of the largest contour
81
+ areas = [cv2.contourArea(c) for c in contours]
82
+ max_index = np.argmax(areas)
83
+ cnt = contours[max_index]
84
+
85
+ # Get convexity defect points
86
+ hull = cv2.convexHull(cnt, returnPoints=False)
87
+ hull[::-1].sort(axis=0)
88
+ defects = cv2.convexityDefects(cnt, hull)
89
+
90
+ ConvexityDefectPoint = []
91
+ for i in range(0, defects.shape[0], skips):
92
+ s, e, f, d = defects[i, 0]
93
+ ConvexityDefectPoint.append(tuple(cnt[f][0]))
94
+
95
+ # Get minimum enclosing circle as retina estimate
96
+ (x, y), radius = cv2.minEnclosingCircle(np.array(ConvexityDefectPoint))
97
+
98
+ # Get mask from contour
99
+ mask = np.zeros_like(image)
100
+ cv2.circle(mask, (x, y), radius, (255, 255, 255), -1)
101
+
102
+ # return (x, y, w, h) bounding box
103
+ return int(x - radius), int(y - radius), int(2 * radius - 1), int(2 * radius - 1), mask
104
+
105
+
106
+ def rgb_clahe(image, clipLimit=2.0, tileGridSize=(16, 16)):
107
+ lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
108
+ clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
109
+ lab[..., 0] = clahe.apply(lab[..., 0])
110
+ return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
111
+
112
+
113
+ if __name__ == '__main__':
114
+
115
+ # image_file = '/vol/biomedic3/bh1511/retina/IDRID/segmentation/0_Original_Images/IDRiD_65.jpg'
116
+ # image_file = '/vol/biomedic3/bh1511/retina/CHASE_DB1/images/Image_08R.jpg'
117
+ # image_file = '/vol/vipdata/data/retina/kaggle-diabetic-retinopathy-detection/train/16_right.jpeg'
118
+ # image_file = '/vol/vipdata/data/retina/IDRID/a_segmentation/images/train/IDRiD_01.jpg'
119
+ image_file = 'preprocessing/Image_10L.png'
120
+
121
+ # Load Image
122
+ image = cv2.imread(image_file)
123
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
124
+ plt.imshow(image); plt.show()
125
+
126
+ # Localise and center retina image
127
+ x, y, w, h, _ = _get_retina_bb(image)
128
+ image = image[y:y + h, x:x + w, :]
129
+ image = _pad_to_square(image, border=0)
130
+ image = cv2.resize(image, (1024, 1024))
131
+
132
+ # Apply CLAHE pre-processing
133
+ image = rgb_clahe(image)
134
+
135
+ # Display or save image
136
+ plt.imshow(image); plt.show()
137
+ # cv2.imwrite('processed_image.png', image)