# Spaces:
# Running
# Running
import torch | |
import gradio as gr | |
from PIL import Image | |
import cv2 | |
from AV.models.network import PGNet | |
from AV.Tools.AVclassifiation import AVclassifiation | |
from AV.Tools.utils_test import paint_border_overlap, extract_ordered_overlap_big, Normalize, sigmoid, recompone_overlap, \ | |
kill_border | |
from AV.config import config_test_general as cfg | |
import torch.autograd as autograd | |
import numpy as np | |
import os | |
from datetime import datetime | |
from huggingface_hub import hf_hub_download | |
# Hugging Face access token read from the "HF_token" environment variable
# (None when unset); passed to hf_hub_download when fetching checkpoints.
hf_token = os.getenv("HF_token")
def creatMask(Image, threshold=5):
    """Create a field-of-view (FOV) mask from an image.

    Pixels with intensity >= ``threshold`` are foreground. Only the largest
    contour of that foreground is kept (filled). The result is then cleaned
    with a 3-iteration morphological opening (3x3 cross kernel).

    Args:
        Image: image array. A 3-D array is converted to grayscale with
            ``cv2.COLOR_BGR2GRAY``; a 2-D array (e.g. a green channel) is
            thresholded directly. The name shadows PIL's ``Image`` inside this
            function only; it is kept so keyword callers keep working.
        threshold: intensity threshold (default 5). A negative value lets every
            pixel through.

    Returns:
        ``(ResultImg, Mask)``:
        - ``ResultImg`` is a copy of ``Image`` with pixels outside the FOV set
          to white (255).
        - ``Mask`` is a uint8 array of shape ``Image.shape[:2]`` with 255 inside
          the FOV and 0 outside.
        If no pixel reaches the threshold, ``Mask`` is all zeros and
        ``ResultImg`` is entirely white (previously this crashed in np.argmax).
    """
    if len(Image.shape) == 3:  # multi-channel image
        gray = cv2.cvtColor(Image, cv2.COLOR_BGR2GRAY)
        Mask0 = gray >= threshold
    else:  # single-channel (e.g. green channel) image
        Mask0 = Image >= threshold
    Mask0 = np.uint8(Mask0)

    # Keep only the largest blob. findContours returns (contours, hierarchy) in
    # OpenCV 4.x and (image, contours, hierarchy) in 3.x; [-2] selects the
    # contours in both.
    contours = cv2.findContours(Mask0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    Mask = np.zeros(Image.shape[:2], dtype=np.uint8)
    if len(contours) > 0:
        areas = [cv2.contourArea(c) for c in contours]
        # thickness=-1 fills the selected contour with value 1
        cv2.drawContours(Mask, contours, int(np.argmax(areas)), 1, -1)

    ResultImg = Image.copy()
    if len(Image.shape) == 3:
        ResultImg[Mask == 0] = (255, 255, 255)
    else:
        ResultImg[Mask == 0] = 255

    Mask[Mask > 0] = 255
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    Mask = cv2.morphologyEx(Mask, cv2.MORPH_OPEN, kernel, iterations=3)
    return ResultImg, Mask
def shift_rgb(img, *args):
    """Add a constant offset to each channel of an 8-bit image, clipping to [0, 255].

    Args:
        img: integer image array with values in [0, 255] (normally uint8).
            It may be 2-D (grayscale) or 3-D with channels on the last axis.
        *args: per-channel offsets; ``args[i]`` is applied to channel ``i``.
            Channels without an offset are copied unchanged; previously they
            were left uninitialized. For a 2-D image only the last offset is
            applied. This matches the earlier behaviour, where each offset
            overwrote the previous result.

    Returns:
        A new array with the same shape and dtype as ``img``. The input is not
        modified.
    """
    max_value = 255
    # Identity lookup table, built once; each shift offsets and clips it.
    base_lut = np.arange(0, max_value + 1).astype("float32")

    def _lut(shift):
        # astype truncates toward zero, like the table previously fed to cv2.LUT
        return np.clip(base_lut + shift, 0, max_value).astype(img.dtype)

    result_img = img.copy()
    if len(img.shape) == 2:
        if args:
            # Fancy indexing is equivalent to cv2.LUT for 8-bit images.
            result_img = _lut(args[-1])[img]
        return result_img

    for channel, shift in enumerate(args):
        result_img[..., channel] = _lut(shift)[img[..., channel]]
    return result_img
def CAM(x, img_path, rate=0.8, ind=0):
    """Shift the colour balance of ``x`` toward that of a reference image.

    Steps:
    1. For each RGB channel, take the reference image's mean minus ``x``'s
       mean.
    2. Scale the difference by ``rate`` and truncate it to ``int``.
    3. Add it to ``x`` with ``shift_rgb``, which clips to [0, 255].
    4. Zero the pixels outside the FOV mask from ``creatMask(x, threshold=10)``.

    The original notes describe the reference as the mean RGB of the whole
    training dataset, to be recomputed for every new dataset (pipeline
    "RGB --> R-shift --> CLAHE"). In this code the reference is the single
    image at ``img_path``.

    Args:
        x: 3-channel image array; it is indexed as ``x[:, :, c]`` and cast to
            uint8 first.
        img_path: path of the reference image, opened with PIL and converted
            to RGB.
        rate: fraction of the mean difference to apply.
        ind: unused; kept for call compatibility.

    Returns:
        uint8 array with the same shape as ``x``.
    """
    x = np.uint8(x)
    _, fov_mask = creatMask(x, threshold=10)

    # Context manager closes the file handle; convert("RGB") guarantees three
    # channels for grayscale/RGBA/palette inputs.
    with Image.open(img_path) as reference_file:
        reference = np.array(reference_file.convert("RGB"))

    shifts = [
        int((np.mean(reference[:, :, c]) - np.mean(x[:, :, c])) * rate)
        for c in range(3)
    ]
    y = shift_rgb(x, *shifts)
    y[fov_mask == 0, :] = 0
    return y
def modelEvalution_out_big(net, use_cuda=False, dataset='', is_kill_border=True, input_ch=3,
                           config=None, output_dir='', evaluate_metrics=False):
    """Build a PGNet from a state dict, segment one image, and colour-code the result.

    Args:
        net: state dict loaded into a fresh PGNet with
            ``load_state_dict(strict=False)``, so missing or unexpected keys
            are ignored silently.
        use_cuda: move the network to CUDA before inference.
        dataset: path of the input image (read with ``cv2.imread``).
        is_kill_border: forwarded to ``GetResult_out_big``.
        input_ch: number of input channels for PGNet.
        config: test configuration. Read here: ``use_global_semantic``,
            ``use_centerness``, ``centerness_map_size``, ``use_resize``.
            More fields are read in ``GetResult_out_big``.
        output_dir: forwarded to ``AVclassifiation``.
        evaluate_metrics: unused; kept for call compatibility.

    Returns:
        Whatever ``AVclassifiation`` returns for the artery/vein/vessel
        predictions. The Gradio caller passes it to ``Image.fromarray``, so it
        is presumably an image array.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``dataset``.
    """
    n_classes = 3
    model = PGNet(use_global_semantic=config.use_global_semantic, input_ch=input_ch,
                  num_classes=n_classes, use_cuda=use_cuda, pretrained=False,
                  centerness=config.use_centerness,
                  centerness_map_size=config.centerness_map_size)
    model.load_state_dict(net, strict=False)
    if use_cuda:
        model.cuda()
    model.eval()

    image_path = dataset
    image0 = cv2.imread(image_path)
    if image0 is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"cv2.imread could not read image: {image_path!r}")
    test_image_height, test_image_width = image0.shape[:2]

    if config.use_resize:
        short_side = min(test_image_height, test_image_width)
        long_side = max(test_image_height, test_image_width)
        scaling = None
        if short_side <= 256:
            # Small images: upscale so the short side becomes 512.
            scaling = 512 / short_side
        elif long_side >= 2048:
            # Large images: downscale so the long side becomes 2048.
            scaling = 2048 / long_side
        if scaling is not None:
            test_image_width = int(test_image_width * scaling)
            test_image_height = int(test_image_height * scaling)

    # One image -> batch of 1, single channel per prediction map.
    pred_shape = (1, 1, test_image_height, test_image_width)
    ArteryPredAll = np.zeros(pred_shape, np.float32)
    VeinPredAll = np.zeros(pred_shape, np.float32)
    VesselPredAll = np.zeros(pred_shape, np.float32)

    ArteryPred, VeinPred, VesselPred, _mask, _, _, _ = GetResult_out_big(
        model, 0,
        use_cuda=use_cuda,
        dataset=image_path,
        is_kill_border=is_kill_border,
        config=config,
        resize_w_h=(test_image_width, test_image_height),
    )
    ArteryPredAll[0] = ArteryPred
    VeinPredAll[0] = VeinPred
    VesselPredAll[0] = VesselPred

    image_color = AVclassifiation(output_dir, ArteryPredAll, VeinPredAll, VesselPredAll, 1, image_path)
    return image_color
def _clahe_on_lightness(img_rgb, clip_limit=2, tile_grid_size=(8, 8)):
    """Return a copy of an 8-bit RGB image with CLAHE applied to the L channel of LAB.

    The A and B channels are left unchanged; the result is converted back to RGB.
    """
    lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB)
    l_channel, a_channel, b_channel = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    lab_clahe = cv2.merge((clahe.apply(l_channel), a_channel, b_channel))
    return cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)


def GetResult_out_big(Net, k, use_cuda=False, dataset='', is_kill_border=False, config=None,
                      resize_w_h=None):
    """Run patch-wise artery/vein/vessel inference on a single image file.

    Pipeline:
    1. Read the image with cv2 (BGR) and optionally resize it to ``resize_w_h``.
    2. Convert BGR to RGB.
    3. Apply CLAHE on the LAB lightness channel when ``config.dataset == 'all'``.
    4. Apply the colour shift (``CAM``) when ``config.use_CAM``.
    5. Scale to [0, 1], pad, and cut into overlapping patches.
    6. Run the network in batches of 2 and apply a sigmoid.
    7. Stitch the patches back and crop to the image size.
    8. Optionally apply ``kill_border`` with the mask.

    Args:
        Net: model, already moved to CUDA by the caller when ``use_cuda``.
            It is called as ``Net(patches, global_patches)`` and must return a
            2-tuple whose first item is the per-patch logits.
        k: unused; kept for call compatibility.
        use_cuda: run the input tensors on CUDA.
        dataset: path of the image file.
        is_kill_border: if True, post-process predictions with
            ``kill_border(pred, Mask)``.
        config: test configuration. Reads ``use_resize``, ``patch_size``,
            ``stride_height``, ``stride_width``, ``dataset``, ``use_CAM`` and
            ``use_global_semantic``.
        resize_w_h: ``(width, height)`` for ``cv2.resize`` when
            ``config.use_resize``.

    Returns:
        7-tuple ``(ArteryPred, VeinPred, VesselPred, Mask, ArteryPred,
        VeinPred, VesselPred)``.
        - The prediction maps are float32 arrays of shape (1, H, W).
        - Output channel 0 is artery, 1 is vessel, 2 is vein.
        - The last three items repeat the first three; the caller unpacks them
          as label maps.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``dataset``.
    """
    Img0 = cv2.imread(dataset)
    if Img0 is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {dataset!r}")
    # NOTE(review): threshold=-1 lets every pixel pass, so this mask presumably
    # spans the whole frame rather than a true FOV - confirm this is intended.
    _, Mask0 = creatMask(Img0, threshold=-1)
    Mask = np.zeros(Img0.shape[:2], np.float32)
    Mask[Mask0 > 0] = 1
    if config.use_resize:
        Img0 = cv2.resize(Img0, resize_w_h)
        Mask = cv2.resize(Mask, resize_w_h, interpolation=cv2.INTER_NEAREST)
    height, width = Img0.shape[:2]

    n_classes = 3
    batch_size = 2
    patch_size = config.patch_size
    stride_height = config.stride_height
    stride_width = config.stride_width

    Img = cv2.cvtColor(Img0, cv2.COLOR_BGR2RGB)
    # Use the passed-in config (not the module-global cfg) so callers control preprocessing.
    if config.dataset == 'all':
        Img = _clahe_on_lightness(Img)
    if config.use_CAM:
        Img = CAM(Img, dataset)
    Img = np.float32(Img / 255.)

    Img_enlarged = paint_border_overlap(Img, patch_size, patch_size, stride_height, stride_width)
    patches_imgs, global_images = extract_ordered_overlap_big(Img_enlarged, patch_size, patch_size,
                                                              stride_height, stride_width)
    # Presumably (N, H, W, C) -> (N, C, H, W), the layout the network consumes.
    patches_imgs = Normalize(np.transpose(patches_imgs, (0, 3, 1, 2)))
    global_images = Normalize(np.transpose(global_images, (0, 3, 1, 2)))

    num_patches = patches_imgs.shape[0]
    pred_patches = np.zeros((num_patches, n_classes, patch_size, patch_size), np.float32)
    device = torch.device("cuda" if use_cuda else "cpu")

    def _to_tensor(batch):
        return torch.from_numpy(np.ascontiguousarray(batch, dtype=np.float32)).to(device)

    # Inference only: no_grad avoids building the autograd graph (saves memory).
    with torch.no_grad():
        for begin in range(0, num_patches, batch_size):
            end = begin + batch_size
            patch_batch = _to_tensor(patches_imgs[begin:end])
            if config.use_global_semantic:
                global_batch = _to_tensor(global_images[begin:end])
            else:
                # The patch tensor doubles as the "global" input; it is already
                # on `device`, so both inputs live on the same device.
                global_batch = patch_batch
            logits, _ = Net(patch_batch, global_batch)
            probs = sigmoid(np.float32(logits.cpu().numpy()))
            pred_patches[begin:end] = probs[:, :, :patch_size, :patch_size]

    enlarged_height, enlarged_width = Img_enlarged.shape[:2]
    pred_img = recompone_overlap(pred_patches, enlarged_height, enlarged_width,
                                 stride_height, stride_width)
    pred_img = pred_img[:, :height, :width]  # drop the padding added for patching
    if is_kill_border:
        pred_img = kill_border(pred_img, Mask)

    ArteryPred = np.float32(pred_img[0])[np.newaxis]
    VesselPred = np.float32(pred_img[1])[np.newaxis]
    VeinPred = np.float32(pred_img[2])[np.newaxis]
    Mask = Mask[np.newaxis]
    return ArteryPred, VeinPred, VesselPred, Mask, ArteryPred, VeinPred, VesselPred
def out_test(cfg, model_path='', output_dir='', evaluate_metrics=False, img_name='out_test'):
    """Load a checkpoint from disk and segment a single image with it.

    Args:
        cfg: test configuration. Reads ``use_cuda`` and ``input_nc`` and
            forwards it as ``config``.
        model_path: checkpoint file for ``torch.load``. The loaded object is
            handed to ``modelEvalution_out_big`` as the state dict.
        output_dir: forwarded unchanged.
        evaluate_metrics: forwarded unchanged.
        img_name: path of the image to segment.

    Returns:
        The value returned by ``modelEvalution_out_big``.
    """
    target_device = torch.device("cuda" if cfg.use_cuda else "cpu")
    # NOTE(review): torch.load unpickles the file (unless weights_only is in
    # effect for the installed torch version); only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=target_device)
    return modelEvalution_out_big(
        state_dict,
        use_cuda=cfg.use_cuda,
        dataset=img_name,
        input_ch=cfg.input_nc,
        config=cfg,
        output_dir=output_dir,
        evaluate_metrics=evaluate_metrics,
    )
def segment_by_out_test(image, model_name):
    """Gradio callback: segment an uploaded fundus image with the chosen model.

    Steps:
    1. Validate the upload.
    2. Download the checkpoint ``G_{model_name}.pkl`` from the Hugging Face
       model repo.
    3. Switch the module-global ``cfg`` to ``model_name``.
    4. Write the upload to a uniquely named temporary PNG under ``./examples``.
    5. Run ``out_test`` on that file.

    The temporary file is always removed afterwards.

    Args:
        image: PIL image from the Gradio upload component, or None.
        model_name: model/dataset identifier selected in the UI.

    Returns:
        ``PIL.Image`` built from the colour-coded prediction.

    Raises:
        gr.Error: if no image was uploaded.
    """
    import tempfile

    print("✅ 传到后端的模型名:", model_name)
    # Validate first so an empty submission triggers no download or config change.
    if image is None:
        raise gr.Error("请上传一张图像(upload a fundus image)。")

    model_path = hf_hub_download(
        repo_id="weidai00/RIP-AV-sulab",  # model repository name
        filename=f"G_{model_name}.pkl",  # checkpoint file name
        repo_type="model",  # repo_type must be given for model repos
        token=hf_token,
    )
    # NOTE: mutates the module-global config shared by all requests.
    cfg.set_dataset(model_name)

    os.makedirs("./examples", exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # mkstemp guarantees a unique name even for several uploads in the same second.
    fd, temp_path = tempfile.mkstemp(prefix=f"tmp_upload_{timestamp}_", suffix=".png",
                                     dir="./examples")
    os.close(fd)
    try:
        image.save(temp_path)
        image_color = out_test(cfg, model_path=model_path, output_dir='', evaluate_metrics=False,
                               img_name=temp_path)
    finally:
        # The upload is only needed while out_test runs; remove it so uploads
        # do not accumulate. Best-effort: a failed delete must not hide the result.
        try:
            os.remove(temp_path)
        except OSError:
            pass
    return Image.fromarray(image_color)
def gradio_interface():
    """Build the Gradio Blocks UI for retinal artery/vein segmentation and launch it.

    Gradio lays components out in creation order, which here is:
    1. Title and instructions.
    2. The image upload box.
    3. A row holding [model radio + run button] and [result image].
    4. Clickable examples.
    5. A collapsible model-description table.
    6. Two citation lines.

    The run button calls ``segment_by_out_test(image, model_name)`` and shows
    the returned image.
    """
    # Markdown table describing each model; rendered inside the accordion.
    model_info_md = """
### 📘 模型说明
| 模型(model name) | 数据集(dataset) | patch size |running time |
|------|--------|------------|--------|
| DRIVE | 小分辨率血管图像 | 256 |30s以内|
| HRF | 高分辨率图像(健康、青光眼等)| 256 | 2min以内|
| LES | 视盘中心图像适配 | 256 |2min以内|
| UKBB | UKBB图像 | 256 |2min以内 |
| 通用模型(512) | 超清图像,适配性强 | 512 |2min以内|
"""
    # (label shown in the radio, value passed to segment_by_out_test)
    model_choices = [
        ("1: DRIVE专用模型", "DRIVE"),
        ("2: HRF专用模型", "hrf"),
        ("3: LES专用模型", "LES"),
        ("4: UKBB专用模型", "ukbb"),
        ("5: 通用模型(general)", "all"),
    ]
    # Bundled example images, each paired with the model value to preselect.
    example_rows = [
        ["examples/DRIVE.tif", "DRIVE"],
        ["examples/LES.png", "LES"],
        ["examples/hrf.png", "hrf"],
        ["examples/ukbb.png", "ukbb"],
        ["examples/all.jpg", "all"],
    ]
    citations = (
        "📚 **专用模型引用cite**: RIP-AV: Joint Representative Instance Pre-training with Context Aware Network for Retinal Artery/Vein Segmentation",
        "📚 **通用模型引用cite**: An Efficient and Interpretable Foundation Model for Retinal Image Analysis in Disease Diagnosis.",
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 👁️ 眼底图像动静脉血管分割(Retinal image artery and vein segmentation)")
        gr.Markdown("上传眼底图像,选择一个模型开始处理,结果将自动生成。(Upload the retinal image, select a model to start processing, and the results will be generated automatically.)")

        with gr.Row():
            image_input = gr.Image(type="pil", label="📤 上传图像(upload)", height=300)

        with gr.Row():
            with gr.Column():
                model_select = gr.Radio(
                    choices=model_choices,
                    label="🎯 选择模型",
                    value="DRIVE",
                    interactive=True,
                )
                submit_btn = gr.Button("🚀 开始分割(RUN)")
            with gr.Column():
                output_image = gr.Image(label="🖼️ 分割结果(Result)")

        gr.Markdown("### 📁 示例图像examples(点击自动加载)")
        gr.Examples(
            examples=example_rows,
            inputs=[image_input, model_select],
            label="示例图像",
            examples_per_page=len(example_rows),
        )

        with gr.Accordion("📖 模型说明-Description(点击展开)", open=False):
            gr.Markdown(model_info_md)

        # Wire the run button to the segmentation callback.
        submit_btn.click(
            fn=segment_by_out_test,
            inputs=[image_input, model_select],
            outputs=[output_image],
        )

        for citation in citations:
            gr.Markdown(citation)

    demo.queue()
    demo.launch()
if __name__ == '__main__':
    # Entry point: start the Gradio web UI.
    # For a one-off local run without the UI, call cfg.set_dataset(<name>), then
    # out_test(cfg=cfg, img_name=<image path>) and save the returned array
    # with Image.fromarray(...).save(...).
    gradio_interface()