weidai00 commited on
Commit
04cd202
·
verified ·
1 Parent(s): b91ca17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +384 -384
app.py CHANGED
@@ -1,384 +1,384 @@
1
- import torch
2
- import gradio as gr
3
- from PIL import Image
4
- import cv2
5
- from AV.models.network import PGNet
6
- from AV.Tools.AVclassifiation import AVclassifiation
7
- from AV.Tools.utils_test import paint_border_overlap, extract_ordered_overlap_big, Normalize, sigmoid, recompone_overlap, \
8
- kill_border
9
- from AV.config import config_test_general as cfg
10
- import torch.autograd as autograd
11
- import numpy as np
12
- import os
13
- from datetime import datetime
14
-
15
- def creatMask(Image, threshold=5):
16
- ##This program try to creat the mask for the filed-of-view
17
- ##Input original image (RGB or green channel), threshold (user set parameter, default 10)
18
- ##Output: the filed-of-view mask
19
-
20
- if len(Image.shape) == 3: ##RGB image
21
- gray = cv2.cvtColor(Image, cv2.COLOR_BGR2GRAY)
22
- Mask0 = gray >= threshold
23
-
24
- else: # for green channel image
25
- Mask0 = Image >= threshold
26
-
27
- # ######get the largest blob, this takes 0.18s
28
- cvVersion = int(cv2.__version__.split('.')[0])
29
-
30
- Mask0 = np.uint8(Mask0)
31
-
32
- contours, hierarchy = cv2.findContours(Mask0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
33
-
34
- areas = [cv2.contourArea(c) for c in contours]
35
- max_index = np.argmax(areas)
36
- Mask = np.zeros(Image.shape[:2], dtype=np.uint8)
37
- cv2.drawContours(Mask, contours, max_index, 1, -1)
38
-
39
- ResultImg = Image.copy()
40
- if len(Image.shape) == 3:
41
- ResultImg[Mask == 0] = (255, 255, 255)
42
- else:
43
- ResultImg[Mask == 0] = 255
44
- Mask[Mask > 0] = 255
45
- kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
46
- Mask = cv2.morphologyEx(Mask, cv2.MORPH_OPEN, kernel, iterations=3)
47
- return ResultImg, Mask
48
-
49
-
50
- def shift_rgb(img, *args):
51
- result_img = np.empty_like(img)
52
- shifts = args
53
- max_value = 255
54
- # print(shifts)
55
- for i, shift in enumerate(shifts):
56
- lut = np.arange(0, max_value + 1).astype("float32")
57
- lut += shift
58
-
59
- lut = np.clip(lut, 0, max_value).astype(img.dtype)
60
- if len(img.shape) == 2:
61
- print(f'=========grey image=======')
62
- result_img = cv2.LUT(img, lut)
63
- else:
64
- result_img[..., i] = cv2.LUT(img[..., i], lut)
65
-
66
- return result_img
67
-
68
-
69
- def CAM(x, img_path, rate=0.8, ind=0):
70
- """
71
- :param dataset_path: 计算整个训练数据集的平均RGB通道值
72
- :param image: array, 单张图片的array 形式
73
- :return: array形式的cam后的结果
74
- """
75
- # 每次使用新数据集时都需要重新计算前面的RBG平均值
76
- # RGB-->Rshift-->CLAHE
77
-
78
- x = np.uint8(x)
79
- _, Mask0 = creatMask(x, threshold=10)
80
- Mask = np.zeros((x.shape[0], x.shape[1]), np.float32)
81
- Mask[Mask0 > 0] = 1
82
-
83
- resize = False
84
- R_mea_num, G_mea_num, B_mea_num = [], [], []
85
-
86
- dataset_path = img_path
87
- image = np.array(Image.open(dataset_path))
88
- R_mea_num.append(np.mean(image[:, :, 0]))
89
- G_mea_num.append(np.mean(image[:, :, 1]))
90
- B_mea_num.append(np.mean(image[:, :, 2]))
91
-
92
- mea2stand = int((np.mean(R_mea_num) - np.mean(x[:, :, 0])) * rate)
93
- mea2standg = int((np.mean(G_mea_num) - np.mean(x[:, :, 1])) * rate)
94
- mea2standb = int((np.mean(B_mea_num) - np.mean(x[:, :, 2])) * rate)
95
-
96
- y = shift_rgb(x, mea2stand, mea2standg, mea2standb)
97
-
98
- y[Mask == 0, :] = 0
99
-
100
- return y
101
-
102
-
103
- def modelEvalution_out_big(net, use_cuda=False, dataset='', is_kill_border=True, input_ch=3,
104
- config=None, output_dir='', evaluate_metrics=False):
105
- # path for images to save
106
- n_classes = 3
107
- Net = PGNet(use_global_semantic=config.use_global_semantic, input_ch=input_ch,
108
- num_classes=n_classes, use_cuda=use_cuda, pretrained=False, centerness=config.use_centerness,
109
- centerness_map_size=config.centerness_map_size)
110
- msg = Net.load_state_dict(net, strict=False)
111
-
112
- if use_cuda:
113
- Net.cuda()
114
- Net.eval()
115
-
116
- image_basename = dataset
117
-
118
- # if not os.path.exists(output_dir):
119
- # os.makedirs(output_dir)
120
-
121
- step = 1
122
- # every step of between star and end for loop until len(image_basename)
123
-
124
- # for start_end in start_end_list:
125
- image0 = cv2.imread(image_basename)
126
- test_image_height = image0.shape[0]
127
- test_image_width = image0.shape[1]
128
-
129
- if config.use_resize:
130
-
131
- if min(test_image_height, test_image_width) <= 256:
132
- scaling = 512 / min(test_image_height, test_image_width)
133
- new_width = int(test_image_width * scaling)
134
- new_height = int(test_image_height * scaling)
135
- test_image_width, test_image_height = new_width, new_height
136
-
137
- # 大尺寸处理:确保最长边≤1536
138
- elif max(test_image_height, test_image_width) >= 2048:
139
- scaling = 2048 / max(test_image_height, test_image_width)
140
- new_width = int(test_image_width * scaling)
141
- new_height = int(test_image_height * scaling)
142
- test_image_width, test_image_height = new_width, new_height
143
-
144
- ArteryPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
145
- VeinPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
146
- VesselPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
147
- ProMap = np.zeros((1, 3, test_image_height, test_image_width), np.float32)
148
- MaskAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
149
- ArteryPred, VeinPred, VesselPred, Mask, LabelArtery, LabelVein, LabelVessel = GetResult_out_big(Net, 0,
150
- use_cuda=use_cuda,
151
- dataset=image_basename,
152
- is_kill_border=is_kill_border,
153
- config=config,
154
- resize_w_h=(
155
- test_image_width,
156
- test_image_height)
157
- )
158
- ArteryPredAll[0 % step, :, :, :] = ArteryPred
159
- VeinPredAll[0 % step, :, :, :] = VeinPred
160
- VesselPredAll[0 % step, :, :, :] = VesselPred
161
-
162
- MaskAll[0 % step, :, :, :] = Mask
163
-
164
- image_color = AVclassifiation(output_dir, ArteryPredAll, VeinPredAll, VesselPredAll, 1, image_basename)
165
-
166
- return image_color
167
-
168
-
169
- def GetResult_out_big(Net, k, use_cuda=False, dataset='', is_kill_border=False, config=None,
170
- resize_w_h=None):
171
- ImgName = dataset
172
- Img0 = cv2.imread(ImgName)
173
-
174
- _, Mask0 = creatMask(Img0, threshold=-1)
175
- Mask = np.zeros((Img0.shape[0], Img0.shape[1]), np.float32)
176
- Mask[Mask0 > 0] = 1
177
-
178
- if config.use_resize:
179
- Img0 = cv2.resize(Img0, resize_w_h)
180
- Mask = cv2.resize(Mask, resize_w_h, interpolation=cv2.INTER_NEAREST)
181
-
182
- Img = Img0
183
- height, width = Img.shape[:2]
184
- n_classes = 3
185
- patch_height = config.patch_size
186
- patch_width = config.patch_size
187
- stride_height = config.stride_height
188
- stride_width = config.stride_width
189
-
190
- Img = cv2.cvtColor(Img, cv2.COLOR_BGR2RGB)
191
- if cfg.dataset == 'all':
192
- # # # 将图像转换为 LAB 颜色空间
193
- lab = cv2.cvtColor(Img, cv2.COLOR_RGB2LAB)
194
-
195
- # 拆分 LAB 通道
196
- l, a, b = cv2.split(lab)
197
-
198
- # 创建 CLAHE 对象并应用到 L 通道
199
- clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(8, 8))
200
- l_clahe = clahe.apply(l)
201
-
202
- # 将 CLAHE 处理后的 L 通道与原始的 A 和 B 通道合并
203
- lab_clahe = cv2.merge((l_clahe, a, b))
204
-
205
- # 将图像转换回 BGR 颜色空间
206
- Img = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)
207
-
208
- if cfg.use_CAM:
209
- Img = CAM(Img, dataset)
210
-
211
- Img = np.float32(Img / 255.)
212
- Img_enlarged = paint_border_overlap(Img, patch_height, patch_width, stride_height, stride_width)
213
- patch_size = config.patch_size
214
- batch_size = 2
215
- patches_imgs, global_images = extract_ordered_overlap_big(Img_enlarged, patch_height, patch_width,
216
- stride_height,
217
- stride_width)
218
-
219
- patches_imgs = np.transpose(patches_imgs, (0, 3, 1, 2))
220
- patches_imgs = Normalize(patches_imgs)
221
- global_images = np.transpose(global_images, (0, 3, 1, 2))
222
- global_images = Normalize(global_images)
223
- patchNum = patches_imgs.shape[0]
224
- max_iter = int(np.ceil(patchNum / float(batch_size)))
225
-
226
- pred_patches = np.zeros((patchNum, n_classes, patch_size, patch_size), np.float32)
227
-
228
- for i in range(max_iter):
229
- begin_index = i * batch_size
230
- end_index = (i + 1) * batch_size
231
-
232
- patches_temp1 = patches_imgs[begin_index:end_index, :, :, :]
233
-
234
- patches_input_temp1 = torch.FloatTensor(patches_temp1)
235
- global_input_temp1 = patches_input_temp1
236
- if config.use_global_semantic:
237
- global_temp1 = global_images[begin_index:end_index, :, :, :]
238
- global_input_temp1 = torch.FloatTensor(global_temp1)
239
- if use_cuda:
240
- patches_input_temp1 = autograd.Variable(patches_input_temp1.cuda())
241
- if config.use_global_semantic:
242
- global_input_temp1 = autograd.Variable(global_input_temp1.cuda())
243
- else:
244
- patches_input_temp1 = autograd.Variable(patches_input_temp1)
245
- if config.use_global_semantic:
246
- global_input_temp1 = autograd.Variable(global_input_temp1)
247
-
248
- output_temp, _1, = Net(patches_input_temp1, global_input_temp1)
249
-
250
- pred_patches_temp = np.float32(output_temp.data.cpu().numpy())
251
-
252
- pred_patches_temp_sigmoid = sigmoid(pred_patches_temp)
253
-
254
- pred_patches[begin_index:end_index, :, :, :] = pred_patches_temp_sigmoid[:, :, :patch_size, :patch_size]
255
-
256
- del patches_input_temp1
257
- del pred_patches_temp
258
- del patches_temp1
259
- del output_temp
260
- del pred_patches_temp_sigmoid
261
-
262
- new_height, new_width = Img_enlarged.shape[0], Img_enlarged.shape[1]
263
-
264
- pred_img = recompone_overlap(pred_patches, new_height, new_width, stride_height, stride_width) # predictions
265
- pred_img = pred_img[:, 0:height, 0:width]
266
-
267
- if is_kill_border:
268
- pred_img = kill_border(pred_img, Mask)
269
-
270
- ArteryPred = np.float32(pred_img[0, :, :])
271
- VeinPred = np.float32(pred_img[2, :, :])
272
- VesselPred = np.float32(pred_img[1, :, :])
273
-
274
- ArteryPred = ArteryPred[np.newaxis, :, :]
275
- VeinPred = VeinPred[np.newaxis, :, :]
276
- VesselPred = VesselPred[np.newaxis, :, :]
277
- Mask = Mask[np.newaxis, :, :]
278
-
279
- return ArteryPred, VeinPred, VesselPred, Mask, ArteryPred, VeinPred, VesselPred,
280
-
281
-
282
- def out_test(cfg, output_dir='', evaluate_metrics=False, img_name='out_test'):
283
- device = torch.device("cuda" if cfg.use_cuda else "cpu")
284
- model_root = cfg.model_path_pretrained_G
285
- model_path = os.path.join(model_root, 'G_' + str(cfg.model_step_pretrained_G) + '.pkl')
286
- net = torch.load(model_path, map_location=device)
287
-
288
- image_color = modelEvalution_out_big(net,
289
- use_cuda=cfg.use_cuda,
290
- dataset=img_name,
291
- input_ch=cfg.input_nc,
292
- config=cfg,
293
- output_dir=output_dir, evaluate_metrics=evaluate_metrics)
294
-
295
- return image_color
296
-
297
-
298
- def segment_by_out_test(image,model_name):
299
- print("✅ 传到后端的模型名:", model_name)
300
-
301
- cfg.set_dataset(model_name)
302
- if image is None:
303
- raise gr.Error("请上传一张图像。")
304
- os.makedirs("./examples", exist_ok=True)
305
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
306
- temp_path = f"./examples/tmp_upload_{timestamp}.png"
307
- image.save(temp_path)
308
-
309
- image_color = out_test(cfg, output_dir='', evaluate_metrics=False, img_name=temp_path)
310
- return Image.fromarray(image_color)
311
-
312
- def gradio_interface():
313
- model_info_md = """
314
- ### 📘 模型说明
315
-
316
- | 模型 | 数据集 | patch size |running time |
317
- |------|--------|------------|--------|
318
- | DRIVE | 小分辨率血管图像 | 256 |30s以内|
319
- | HRF | 高分辨率图像(健康、青光眼等) | 2min以内|
320
- | LES | 视盘中心图像适配 | 256 |2min以内|
321
- | UKBB | UKBB图像 | 256 |2min以内 |
322
- | 通用模型(512) | 超清图像,适配性强 | 512 |2min以内|
323
- """
324
- model_choices = [
325
- ("1: DRIVE专用模型", "DRIVE"),
326
- ("2: HRF专用模型", "hrf"),
327
- ("3: LES专用模型","LES"),
328
- ("4: UKBB专用模型", "ukbb"),
329
- ("5: 通用模型", "all"),
330
- ]
331
-
332
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
333
- gr.Markdown("# 👁️ 眼底图像动静脉血管分割")
334
- gr.Markdown("上传眼底图像,选择一个模型开始处理,结果将自动生成。")
335
-
336
- with gr.Row():
337
- image_input = gr.Image(type="pil", label="📤 上传图像",height=300)
338
-
339
- with gr.Row():
340
- with gr.Column():
341
- model_select = gr.Radio(
342
- choices=model_choices,
343
- label="🎯 选择模型",
344
- value="DRIVE",
345
- interactive = True
346
- )
347
- submit_btn = gr.Button("🚀 开始分割")
348
- with gr.Column():
349
- output_image = gr.Image(label="🖼️ 分割结果")
350
-
351
- gr.Markdown("### 📁 示例图像(点击自动加载)")
352
- gr.Examples(
353
- examples=[
354
- ["examples/DRIVE.tif", "DRIVE"],
355
- ["examples/LES.png", "LES"],
356
- ["examples/hrf.png", "hrf"],
357
- ["examples/ukbb.png", "ukbb"],
358
- ["examples/all.jpg", "all"]
359
- ],
360
- inputs=[image_input, model_select],
361
- label="示例图像",
362
- examples_per_page=5
363
- )
364
- with gr.Accordion("📖 模型说明(点击展开)", open=False):
365
- gr.Markdown(model_info_md)
366
-
367
- # 功能连接
368
- submit_btn.click(
369
- fn=segment_by_out_test,
370
- inputs=[image_input, model_select],
371
- outputs=[output_image]
372
- )
373
- gr.Markdown("📚 **专用模型**: RIP-AV: Joint Representative Instance Pre-training with Context Aware Network for Retinal Artery/Vein Segmentation")
374
- gr.Markdown("📚 **通用模型**: An Efficient and Interpretable Foundation Model for Retinal Image Analysis in Disease Diagnosis.")
375
- demo.queue()
376
- demo.launch()
377
-
378
-
379
- if __name__ == '__main__':
380
- # cfg.set_dataset('all')
381
- # image_color = out_test(cfg = cfg, evaluate_metrics=False, img_name=r'.\AV\data\AV-DRIVE\test\images\01_test.tif')
382
- # Image.fromarray(image_color).save('image_color.png')
383
- #print(cfg.patch_size)
384
- gradio_interface()
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import cv2
5
+ from AV.models.network import PGNet
6
+ from AV.Tools.AVclassifiation import AVclassifiation
7
+ from AV.Tools.utils_test import paint_border_overlap, extract_ordered_overlap_big, Normalize, sigmoid, recompone_overlap, \
8
+ kill_border
9
+ from AV.config import config_test_general as cfg
10
+ import torch.autograd as autograd
11
+ import numpy as np
12
+ import os
13
+ from datetime import datetime
14
+
15
+ def creatMask(Image, threshold=5):
16
+ ##This program try to creat the mask for the filed-of-view
17
+ ##Input original image (RGB or green channel), threshold (user set parameter, default 10)
18
+ ##Output: the filed-of-view mask
19
+
20
+ if len(Image.shape) == 3: ##RGB image
21
+ gray = cv2.cvtColor(Image, cv2.COLOR_BGR2GRAY)
22
+ Mask0 = gray >= threshold
23
+
24
+ else: # for green channel image
25
+ Mask0 = Image >= threshold
26
+
27
+ # ######get the largest blob, this takes 0.18s
28
+ cvVersion = int(cv2.__version__.split('.')[0])
29
+
30
+ Mask0 = np.uint8(Mask0)
31
+
32
+ contours, hierarchy = cv2.findContours(Mask0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
33
+
34
+ areas = [cv2.contourArea(c) for c in contours]
35
+ max_index = np.argmax(areas)
36
+ Mask = np.zeros(Image.shape[:2], dtype=np.uint8)
37
+ cv2.drawContours(Mask, contours, max_index, 1, -1)
38
+
39
+ ResultImg = Image.copy()
40
+ if len(Image.shape) == 3:
41
+ ResultImg[Mask == 0] = (255, 255, 255)
42
+ else:
43
+ ResultImg[Mask == 0] = 255
44
+ Mask[Mask > 0] = 255
45
+ kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
46
+ Mask = cv2.morphologyEx(Mask, cv2.MORPH_OPEN, kernel, iterations=3)
47
+ return ResultImg, Mask
48
+
49
+
50
+ def shift_rgb(img, *args):
51
+ result_img = np.empty_like(img)
52
+ shifts = args
53
+ max_value = 255
54
+ # print(shifts)
55
+ for i, shift in enumerate(shifts):
56
+ lut = np.arange(0, max_value + 1).astype("float32")
57
+ lut += shift
58
+
59
+ lut = np.clip(lut, 0, max_value).astype(img.dtype)
60
+ if len(img.shape) == 2:
61
+ print(f'=========grey image=======')
62
+ result_img = cv2.LUT(img, lut)
63
+ else:
64
+ result_img[..., i] = cv2.LUT(img[..., i], lut)
65
+
66
+ return result_img
67
+
68
+
69
+ def CAM(x, img_path, rate=0.8, ind=0):
70
+ """
71
+ :param dataset_path: 计算整个训练数据集的平均RGB通道值
72
+ :param image: array, 单张图片的array 形式
73
+ :return: array形式的cam后的结果
74
+ """
75
+ # 每次使用新数据集时都需要重新计算前面的RBG平均值
76
+ # RGB-->Rshift-->CLAHE
77
+
78
+ x = np.uint8(x)
79
+ _, Mask0 = creatMask(x, threshold=10)
80
+ Mask = np.zeros((x.shape[0], x.shape[1]), np.float32)
81
+ Mask[Mask0 > 0] = 1
82
+
83
+ resize = False
84
+ R_mea_num, G_mea_num, B_mea_num = [], [], []
85
+
86
+ dataset_path = img_path
87
+ image = np.array(Image.open(dataset_path))
88
+ R_mea_num.append(np.mean(image[:, :, 0]))
89
+ G_mea_num.append(np.mean(image[:, :, 1]))
90
+ B_mea_num.append(np.mean(image[:, :, 2]))
91
+
92
+ mea2stand = int((np.mean(R_mea_num) - np.mean(x[:, :, 0])) * rate)
93
+ mea2standg = int((np.mean(G_mea_num) - np.mean(x[:, :, 1])) * rate)
94
+ mea2standb = int((np.mean(B_mea_num) - np.mean(x[:, :, 2])) * rate)
95
+
96
+ y = shift_rgb(x, mea2stand, mea2standg, mea2standb)
97
+
98
+ y[Mask == 0, :] = 0
99
+
100
+ return y
101
+
102
+
103
+ def modelEvalution_out_big(net, use_cuda=False, dataset='', is_kill_border=True, input_ch=3,
104
+ config=None, output_dir='', evaluate_metrics=False):
105
+ # path for images to save
106
+ n_classes = 3
107
+ Net = PGNet(use_global_semantic=config.use_global_semantic, input_ch=input_ch,
108
+ num_classes=n_classes, use_cuda=use_cuda, pretrained=False, centerness=config.use_centerness,
109
+ centerness_map_size=config.centerness_map_size)
110
+ msg = Net.load_state_dict(net, strict=False)
111
+
112
+ if use_cuda:
113
+ Net.cuda()
114
+ Net.eval()
115
+
116
+ image_basename = dataset
117
+
118
+ # if not os.path.exists(output_dir):
119
+ # os.makedirs(output_dir)
120
+
121
+ step = 1
122
+ # every step of between star and end for loop until len(image_basename)
123
+
124
+ # for start_end in start_end_list:
125
+ image0 = cv2.imread(image_basename)
126
+ test_image_height = image0.shape[0]
127
+ test_image_width = image0.shape[1]
128
+
129
+ if config.use_resize:
130
+
131
+ if min(test_image_height, test_image_width) <= 256:
132
+ scaling = 512 / min(test_image_height, test_image_width)
133
+ new_width = int(test_image_width * scaling)
134
+ new_height = int(test_image_height * scaling)
135
+ test_image_width, test_image_height = new_width, new_height
136
+
137
+ # 大尺寸处理:确保最长边≤1536
138
+ elif max(test_image_height, test_image_width) >= 2048:
139
+ scaling = 2048 / max(test_image_height, test_image_width)
140
+ new_width = int(test_image_width * scaling)
141
+ new_height = int(test_image_height * scaling)
142
+ test_image_width, test_image_height = new_width, new_height
143
+
144
+ ArteryPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
145
+ VeinPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
146
+ VesselPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
147
+ ProMap = np.zeros((1, 3, test_image_height, test_image_width), np.float32)
148
+ MaskAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
149
+ ArteryPred, VeinPred, VesselPred, Mask, LabelArtery, LabelVein, LabelVessel = GetResult_out_big(Net, 0,
150
+ use_cuda=use_cuda,
151
+ dataset=image_basename,
152
+ is_kill_border=is_kill_border,
153
+ config=config,
154
+ resize_w_h=(
155
+ test_image_width,
156
+ test_image_height)
157
+ )
158
+ ArteryPredAll[0 % step, :, :, :] = ArteryPred
159
+ VeinPredAll[0 % step, :, :, :] = VeinPred
160
+ VesselPredAll[0 % step, :, :, :] = VesselPred
161
+
162
+ MaskAll[0 % step, :, :, :] = Mask
163
+
164
+ image_color = AVclassifiation(output_dir, ArteryPredAll, VeinPredAll, VesselPredAll, 1, image_basename)
165
+
166
+ return image_color
167
+
168
+
169
+ def GetResult_out_big(Net, k, use_cuda=False, dataset='', is_kill_border=False, config=None,
170
+ resize_w_h=None):
171
+ ImgName = dataset
172
+ Img0 = cv2.imread(ImgName)
173
+
174
+ _, Mask0 = creatMask(Img0, threshold=-1)
175
+ Mask = np.zeros((Img0.shape[0], Img0.shape[1]), np.float32)
176
+ Mask[Mask0 > 0] = 1
177
+
178
+ if config.use_resize:
179
+ Img0 = cv2.resize(Img0, resize_w_h)
180
+ Mask = cv2.resize(Mask, resize_w_h, interpolation=cv2.INTER_NEAREST)
181
+
182
+ Img = Img0
183
+ height, width = Img.shape[:2]
184
+ n_classes = 3
185
+ patch_height = config.patch_size
186
+ patch_width = config.patch_size
187
+ stride_height = config.stride_height
188
+ stride_width = config.stride_width
189
+
190
+ Img = cv2.cvtColor(Img, cv2.COLOR_BGR2RGB)
191
+ if cfg.dataset == 'all':
192
+ # # # 将图像转换为 LAB 颜色空间
193
+ lab = cv2.cvtColor(Img, cv2.COLOR_RGB2LAB)
194
+
195
+ # 拆分 LAB 通道
196
+ l, a, b = cv2.split(lab)
197
+
198
+ # 创建 CLAHE 对象并应用到 L 通道
199
+ clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(8, 8))
200
+ l_clahe = clahe.apply(l)
201
+
202
+ # 将 CLAHE 处理后的 L 通道与原始的 A 和 B 通道合并
203
+ lab_clahe = cv2.merge((l_clahe, a, b))
204
+
205
+ # 将图像转换回 BGR 颜色空间
206
+ Img = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)
207
+
208
+ if cfg.use_CAM:
209
+ Img = CAM(Img, dataset)
210
+
211
+ Img = np.float32(Img / 255.)
212
+ Img_enlarged = paint_border_overlap(Img, patch_height, patch_width, stride_height, stride_width)
213
+ patch_size = config.patch_size
214
+ batch_size = 2
215
+ patches_imgs, global_images = extract_ordered_overlap_big(Img_enlarged, patch_height, patch_width,
216
+ stride_height,
217
+ stride_width)
218
+
219
+ patches_imgs = np.transpose(patches_imgs, (0, 3, 1, 2))
220
+ patches_imgs = Normalize(patches_imgs)
221
+ global_images = np.transpose(global_images, (0, 3, 1, 2))
222
+ global_images = Normalize(global_images)
223
+ patchNum = patches_imgs.shape[0]
224
+ max_iter = int(np.ceil(patchNum / float(batch_size)))
225
+
226
+ pred_patches = np.zeros((patchNum, n_classes, patch_size, patch_size), np.float32)
227
+
228
+ for i in range(max_iter):
229
+ begin_index = i * batch_size
230
+ end_index = (i + 1) * batch_size
231
+
232
+ patches_temp1 = patches_imgs[begin_index:end_index, :, :, :]
233
+
234
+ patches_input_temp1 = torch.FloatTensor(patches_temp1)
235
+ global_input_temp1 = patches_input_temp1
236
+ if config.use_global_semantic:
237
+ global_temp1 = global_images[begin_index:end_index, :, :, :]
238
+ global_input_temp1 = torch.FloatTensor(global_temp1)
239
+ if use_cuda:
240
+ patches_input_temp1 = autograd.Variable(patches_input_temp1.cuda())
241
+ if config.use_global_semantic:
242
+ global_input_temp1 = autograd.Variable(global_input_temp1.cuda())
243
+ else:
244
+ patches_input_temp1 = autograd.Variable(patches_input_temp1)
245
+ if config.use_global_semantic:
246
+ global_input_temp1 = autograd.Variable(global_input_temp1)
247
+
248
+ output_temp, _1, = Net(patches_input_temp1, global_input_temp1)
249
+
250
+ pred_patches_temp = np.float32(output_temp.data.cpu().numpy())
251
+
252
+ pred_patches_temp_sigmoid = sigmoid(pred_patches_temp)
253
+
254
+ pred_patches[begin_index:end_index, :, :, :] = pred_patches_temp_sigmoid[:, :, :patch_size, :patch_size]
255
+
256
+ del patches_input_temp1
257
+ del pred_patches_temp
258
+ del patches_temp1
259
+ del output_temp
260
+ del pred_patches_temp_sigmoid
261
+
262
+ new_height, new_width = Img_enlarged.shape[0], Img_enlarged.shape[1]
263
+
264
+ pred_img = recompone_overlap(pred_patches, new_height, new_width, stride_height, stride_width) # predictions
265
+ pred_img = pred_img[:, 0:height, 0:width]
266
+
267
+ if is_kill_border:
268
+ pred_img = kill_border(pred_img, Mask)
269
+
270
+ ArteryPred = np.float32(pred_img[0, :, :])
271
+ VeinPred = np.float32(pred_img[2, :, :])
272
+ VesselPred = np.float32(pred_img[1, :, :])
273
+
274
+ ArteryPred = ArteryPred[np.newaxis, :, :]
275
+ VeinPred = VeinPred[np.newaxis, :, :]
276
+ VesselPred = VesselPred[np.newaxis, :, :]
277
+ Mask = Mask[np.newaxis, :, :]
278
+
279
+ return ArteryPred, VeinPred, VesselPred, Mask, ArteryPred, VeinPred, VesselPred,
280
+
281
+
282
+ def out_test(cfg, output_dir='', evaluate_metrics=False, img_name='out_test'):
283
+ device = torch.device("cuda" if cfg.use_cuda else "cpu")
284
+ model_root = cfg.model_path_pretrained_G
285
+ model_path = os.path.join(model_root, 'G_' + str(cfg.model_step_pretrained_G) + '.pkl')
286
+ net = torch.load(model_path, map_location=device)
287
+
288
+ image_color = modelEvalution_out_big(net,
289
+ use_cuda=cfg.use_cuda,
290
+ dataset=img_name,
291
+ input_ch=cfg.input_nc,
292
+ config=cfg,
293
+ output_dir=output_dir, evaluate_metrics=evaluate_metrics)
294
+
295
+ return image_color
296
+
297
+
298
+ def segment_by_out_test(image,model_name):
299
+ print("✅ 传到后端的模型名:", model_name)
300
+
301
+ cfg.set_dataset(model_name)
302
+ if image is None:
303
+ raise gr.Error("请上传一张图像(upload a fundus image)。")
304
+ os.makedirs("./examples", exist_ok=True)
305
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
306
+ temp_path = f"./examples/tmp_upload_{timestamp}.png"
307
+ image.save(temp_path)
308
+
309
+ image_color = out_test(cfg, output_dir='', evaluate_metrics=False, img_name=temp_path)
310
+ return Image.fromarray(image_color)
311
+
312
+ def gradio_interface():
313
+ model_info_md = """
314
+ ### 📘 模型说明
315
+
316
+ | 模型(model name) | 数据集(dataset) | patch size |running time |
317
+ |------|--------|------------|--------|
318
+ | DRIVE | 小分辨率血管图像 | 256 |30s以内|
319
+ | HRF | 高分辨率图像(健康、青光眼等)| 256 | 2min以内|
320
+ | LES | 视盘中心图像适配 | 256 |2min以内|
321
+ | UKBB | UKBB图像 | 256 |2min以内 |
322
+ | 通用模型(512) | 超清图像,适配性强 | 512 |2min以内|
323
+ """
324
+ model_choices = [
325
+ ("1: DRIVE专用模型", "DRIVE"),
326
+ ("2: HRF专用模型", "hrf"),
327
+ ("3: LES专用模型","LES"),
328
+ ("4: UKBB专用模型", "ukbb"),
329
+ ("5: 通用模型(general)", "all"),
330
+ ]
331
+
332
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
333
+ gr.Markdown("# 👁️ 眼底图像动静脉血管分割(Retinal image artery and vein segmentation)")
334
+ gr.Markdown("上传眼底图像,选择一个模型开始处理,结果将自动生成。(Upload the retinal image, select a model to start processing, and the results will be generated automatically.)")
335
+
336
+ with gr.Row():
337
+ image_input = gr.Image(type="pil", label="📤 上传图像(upload)",height=300)
338
+
339
+ with gr.Row():
340
+ with gr.Column():
341
+ model_select = gr.Radio(
342
+ choices=model_choices,
343
+ label="🎯 选择模型",
344
+ value="DRIVE",
345
+ interactive = True
346
+ )
347
+ submit_btn = gr.Button("🚀 开始分割(RUN)")
348
+ with gr.Column():
349
+ output_image = gr.Image(label="🖼️ 分割结果(Result)")
350
+
351
+ gr.Markdown("### 📁 示例图像examples(点击自动加载)")
352
+ gr.Examples(
353
+ examples=[
354
+ ["examples/DRIVE.tif", "DRIVE"],
355
+ ["examples/LES.png", "LES"],
356
+ ["examples/hrf.png", "hrf"],
357
+ ["examples/ukbb.png", "ukbb"],
358
+ ["examples/all.jpg", "all"]
359
+ ],
360
+ inputs=[image_input, model_select],
361
+ label="示例图像",
362
+ examples_per_page=5
363
+ )
364
+ with gr.Accordion("📖 模型���明desciption(点击展开)", open=False):
365
+ gr.Markdown(model_info_md)
366
+
367
+ # 功能连接
368
+ submit_btn.click(
369
+ fn=segment_by_out_test,
370
+ inputs=[image_input, model_select],
371
+ outputs=[output_image]
372
+ )
373
+ gr.Markdown("📚 **专用模型引用cite**: RIP-AV: Joint Representative Instance Pre-training with Context Aware Network for Retinal Artery/Vein Segmentation")
374
+ gr.Markdown("📚 **通用模型引用cite**: An Efficient and Interpretable Foundation Model for Retinal Image Analysis in Disease Diagnosis.")
375
+ demo.queue()
376
+ demo.launch()
377
+
378
+
379
+ if __name__ == '__main__':
380
+ # cfg.set_dataset('all')
381
+ # image_color = out_test(cfg = cfg, evaluate_metrics=False, img_name=r'.\AV\data\AV-DRIVE\test\images\01_test.tif')
382
+ # Image.fromarray(image_color).save('image_color.png')
383
+ #print(cfg.patch_size)
384
+ gradio_interface()