import torch
import gradio as gr
from PIL import Image
import cv2
from AV.models.network import PGNet
from AV.Tools.AVclassifiation import AVclassifiation
from AV.Tools.utils_test import paint_border_overlap, extract_ordered_overlap_big, Normalize, sigmoid, recompone_overlap, \
kill_border
from AV.config import config_test_general as cfg
import numpy as np
import os
from datetime import datetime
from huggingface_hub import hf_hub_download
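# Access token for the Hugging Face Hub; assumed to be configured as the
# HF_token secret (environment variable) of this Space.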
hf_token = os.environ.get("HF_token")
def creatMask(Image, threshold=5):
    ## Create the field-of-view mask for a fundus image.
    ## Input: original image (RGB or green channel), threshold (user-set parameter, default 5)
    ## Output: the image with the background whited out, and the field-of-view mask
    ## Note: the `Image` parameter shadows PIL.Image inside this function.
if len(Image.shape) == 3: ##RGB image
gray = cv2.cvtColor(Image, cv2.COLOR_BGR2GRAY)
Mask0 = gray >= threshold
else: # for green channel image
Mask0 = Image >= threshold
    # Keep only the largest blob as the field of view (takes ~0.18 s)
    Mask0 = np.uint8(Mask0)
contours, hierarchy = cv2.findContours(Mask0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
areas = [cv2.contourArea(c) for c in contours]
max_index = np.argmax(areas)
Mask = np.zeros(Image.shape[:2], dtype=np.uint8)
cv2.drawContours(Mask, contours, max_index, 1, -1)
ResultImg = Image.copy()
if len(Image.shape) == 3:
ResultImg[Mask == 0] = (255, 255, 255)
else:
ResultImg[Mask == 0] = 255
Mask[Mask > 0] = 255
kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
Mask = cv2.morphologyEx(Mask, cv2.MORPH_OPEN, kernel, iterations=3)
return ResultImg, Mask
def shift_rgb(img, *args):
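    """Shift each channel of `img` by the matching offset in `args`, via a
    lookup table clipped to the valid [0, 255] range."""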
result_img = np.empty_like(img)
shifts = args
max_value = 255
for i, shift in enumerate(shifts):
lut = np.arange(0, max_value + 1).astype("float32")
lut += shift
lut = np.clip(lut, 0, max_value).astype(img.dtype)
if len(img.shape) == 2:
            result_img = cv2.LUT(img, lut)
else:
result_img[..., i] = cv2.LUT(img[..., i], lut)
return result_img
def CAM(x, img_path, rate=0.8, ind=0):
"""
:param dataset_path: 计算整个训练数据集的平均RGB通道值
:param image: array, 单张图片的array 形式
:return: array形式的cam后的结果
"""
# 每次使用新数据集时都需要重新计算前面的RBG平均值
# RGB-->Rshift-->CLAHE
x = np.uint8(x)
_, Mask0 = creatMask(x, threshold=10)
Mask = np.zeros((x.shape[0], x.shape[1]), np.float32)
Mask[Mask0 > 0] = 1
    R_mea_num, G_mea_num, B_mea_num = [], [], []
    image = np.array(Image.open(img_path))
    R_mea_num.append(np.mean(image[:, :, 0]))
    G_mea_num.append(np.mean(image[:, :, 1]))
    B_mea_num.append(np.mean(image[:, :, 2]))
mea2stand = int((np.mean(R_mea_num) - np.mean(x[:, :, 0])) * rate)
mea2standg = int((np.mean(G_mea_num) - np.mean(x[:, :, 1])) * rate)
mea2standb = int((np.mean(B_mea_num) - np.mean(x[:, :, 2])) * rate)
y = shift_rgb(x, mea2stand, mea2standg, mea2standb)
y[Mask == 0, :] = 0
return y
def modelEvalution_out_big(net, use_cuda=False, dataset='', is_kill_border=True, input_ch=3,
config=None, output_dir='', evaluate_metrics=False):
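    """
    Load PGNet weights from `net`, run whole-image artery/vein inference on the
    single image at `dataset`, and return the colorized A/V map produced by
    AVclassifiation. `evaluate_metrics` is accepted for API compatibility only.
    """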
n_classes = 3
Net = PGNet(use_global_semantic=config.use_global_semantic, input_ch=input_ch,
num_classes=n_classes, use_cuda=use_cuda, pretrained=False, centerness=config.use_centerness,
centerness_map_size=config.centerness_map_size)
    Net.load_state_dict(net, strict=False)
if use_cuda:
Net.cuda()
Net.eval()
    image_basename = dataset
image0 = cv2.imread(image_basename)
test_image_height = image0.shape[0]
test_image_width = image0.shape[1]
    if config.use_resize:
        # Small images: upscale so the shorter side reaches 512
        if min(test_image_height, test_image_width) <= 256:
            scaling = 512 / min(test_image_height, test_image_width)
            new_width = int(test_image_width * scaling)
            new_height = int(test_image_height * scaling)
            test_image_width, test_image_height = new_width, new_height
        # Large images: downscale so the longer side is at most 2048
        elif max(test_image_height, test_image_width) >= 2048:
            scaling = 2048 / max(test_image_height, test_image_width)
            new_width = int(test_image_width * scaling)
            new_height = int(test_image_height * scaling)
            test_image_width, test_image_height = new_width, new_height
    ArteryPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
    VeinPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
    VesselPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
    MaskAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
    ArteryPred, VeinPred, VesselPred, Mask, _, _, _ = GetResult_out_big(
        Net, 0,
        use_cuda=use_cuda,
        dataset=image_basename,
        is_kill_border=is_kill_border,
        config=config,
        resize_w_h=(test_image_width, test_image_height),
    )
    ArteryPredAll[0, :, :, :] = ArteryPred
    VeinPredAll[0, :, :, :] = VeinPred
    VesselPredAll[0, :, :, :] = VesselPred
    MaskAll[0, :, :, :] = Mask
image_color = AVclassifiation(output_dir, ArteryPredAll, VeinPredAll, VesselPredAll, 1, image_basename)
return image_color
def GetResult_out_big(Net, k, use_cuda=False, dataset='', is_kill_border=False, config=None,
resize_w_h=None):
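    """
    Patch-based inference for one image: preprocess (optional CLAHE/CAM),
    split into overlapping patches, predict each batch with PGNet, and
    stitch the probability maps back to full resolution. Returns artery,
    vein and vessel probability maps plus the FOV mask; the last three
    return values repeat the predictions as label placeholders.
    """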
ImgName = dataset
Img0 = cv2.imread(ImgName)
_, Mask0 = creatMask(Img0, threshold=-1)
Mask = np.zeros((Img0.shape[0], Img0.shape[1]), np.float32)
Mask[Mask0 > 0] = 1
if config.use_resize:
Img0 = cv2.resize(Img0, resize_w_h)
Mask = cv2.resize(Mask, resize_w_h, interpolation=cv2.INTER_NEAREST)
Img = Img0
height, width = Img.shape[:2]
n_classes = 3
patch_height = config.patch_size
patch_width = config.patch_size
stride_height = config.stride_height
stride_width = config.stride_width
Img = cv2.cvtColor(Img, cv2.COLOR_BGR2RGB)
    if cfg.dataset == 'all':
        # Convert the image to LAB color space
        lab = cv2.cvtColor(Img, cv2.COLOR_RGB2LAB)
        # Split the LAB channels
        l, a, b = cv2.split(lab)
        # Create a CLAHE object and apply it to the L channel
        clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(8, 8))
        l_clahe = clahe.apply(l)
        # Merge the equalized L channel with the original A and B channels
        lab_clahe = cv2.merge((l_clahe, a, b))
        # Convert back to RGB color space
        Img = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)
    if cfg.use_CAM:
        Img = CAM(Img, dataset)
Img = np.float32(Img / 255.)
Img_enlarged = paint_border_overlap(Img, patch_height, patch_width, stride_height, stride_width)
patch_size = config.patch_size
batch_size = 2
patches_imgs, global_images = extract_ordered_overlap_big(Img_enlarged, patch_height, patch_width,
stride_height,
stride_width)
patches_imgs = np.transpose(patches_imgs, (0, 3, 1, 2))
patches_imgs = Normalize(patches_imgs)
global_images = np.transpose(global_images, (0, 3, 1, 2))
global_images = Normalize(global_images)
patchNum = patches_imgs.shape[0]
max_iter = int(np.ceil(patchNum / float(batch_size)))
pred_patches = np.zeros((patchNum, n_classes, patch_size, patch_size), np.float32)
for i in range(max_iter):
begin_index = i * batch_size
end_index = (i + 1) * batch_size
patches_temp1 = patches_imgs[begin_index:end_index, :, :, :]
patches_input_temp1 = torch.FloatTensor(patches_temp1)
global_input_temp1 = patches_input_temp1
if config.use_global_semantic:
global_temp1 = global_images[begin_index:end_index, :, :, :]
global_input_temp1 = torch.FloatTensor(global_temp1)
        # Move inputs to the GPU when available (the deprecated
        # autograd.Variable wrappers were no-ops and have been dropped)
        if use_cuda:
            patches_input_temp1 = patches_input_temp1.cuda()
            if config.use_global_semantic:
                global_input_temp1 = global_input_temp1.cuda()
        with torch.no_grad():
            output_temp, _ = Net(patches_input_temp1, global_input_temp1)
pred_patches_temp = np.float32(output_temp.data.cpu().numpy())
pred_patches_temp_sigmoid = sigmoid(pred_patches_temp)
pred_patches[begin_index:end_index, :, :, :] = pred_patches_temp_sigmoid[:, :, :patch_size, :patch_size]
        del patches_input_temp1, patches_temp1, output_temp, pred_patches_temp, pred_patches_temp_sigmoid
new_height, new_width = Img_enlarged.shape[0], Img_enlarged.shape[1]
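    # Stitch the overlapping patch predictions into a full-size map, then crop off the padding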
pred_img = recompone_overlap(pred_patches, new_height, new_width, stride_height, stride_width) # predictions
pred_img = pred_img[:, 0:height, 0:width]
if is_kill_border:
pred_img = kill_border(pred_img, Mask)
ArteryPred = np.float32(pred_img[0, :, :])
VeinPred = np.float32(pred_img[2, :, :])
VesselPred = np.float32(pred_img[1, :, :])
ArteryPred = ArteryPred[np.newaxis, :, :]
VeinPred = VeinPred[np.newaxis, :, :]
VesselPred = VesselPred[np.newaxis, :, :]
Mask = Mask[np.newaxis, :, :]
    # The last three slots mirror the predictions as label placeholders
    return ArteryPred, VeinPred, VesselPred, Mask, ArteryPred, VeinPred, VesselPred
def out_test(cfg, model_path='', output_dir='', evaluate_metrics=False, img_name='out_test'):
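    """Load the checkpoint at `model_path` and segment the image at `img_name`."""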
device = torch.device("cuda" if cfg.use_cuda else "cpu")
net = torch.load(model_path, map_location=device)
image_color = modelEvalution_out_big(net,
use_cuda=cfg.use_cuda,
dataset=img_name,
input_ch=cfg.input_nc,
config=cfg,
output_dir=output_dir, evaluate_metrics=evaluate_metrics)
return image_color
def segment_by_out_test(image, model_name):
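    """
    Gradio callback: validate the upload, fetch the selected model's weights
    from the Hub, save the image to a temporary file, and run segmentation.
    """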
print("✅ 传到后端的模型名:", model_name)
    if image is None:
        raise gr.Error("Please upload a fundus image.")
    model_path = hf_hub_download(
        repo_id="weidai00/RIP-AV-sulab",  # repository id on the Hub
        filename=f"G_{model_name}.pkl",  # weight file for the selected model
        repo_type="model",  # repo_type must be specified for model repos
        token=hf_token
    )
    cfg.set_dataset(model_name)
os.makedirs("./examples", exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
temp_path = f"./examples/tmp_upload_{timestamp}.png"
image.save(temp_path)
    image_color = out_test(cfg, model_path=model_path, output_dir='', evaluate_metrics=False, img_name=temp_path)
return Image.fromarray(image_color)
def gradio_interface():
model_info_md = """
### 📘 模型说明
| 模型(model name) | 数据集(dataset) | patch size |running time |
|------|--------|------------|--------|
| DRIVE | 小分辨率血管图像 | 256 |30s以内|
| HRF | 高分辨率图像(健康、青光眼等)| 256 | 2min以内|
| LES | 视盘中心图像适配 | 256 |2min以内|
| UKBB | UKBB图像 | 256 |2min以内 |
| 通用模型(512) | 超清图像,适配性强 | 512 |2min以内|
"""
    model_choices = [
        ("1: DRIVE-specific model", "DRIVE"),
        ("2: HRF-specific model", "hrf"),
        ("3: LES-specific model", "LES"),
        ("4: UKBB-specific model", "ukbb"),
        ("5: General model", "all"),
    ]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 👁️ 眼底图像动静脉血管分割(Retinal image artery and vein segmentation)")
gr.Markdown("上传眼底图像,选择一个模型开始处理,结果将自动生成。(Upload the retinal image, select a model to start processing, and the results will be generated automatically.)")
with gr.Row():
            image_input = gr.Image(type="pil", label="📤 Upload image", height=300)
with gr.Row():
with gr.Column():
                model_select = gr.Radio(
                    choices=model_choices,
                    label="🎯 Select a model",
                    value="DRIVE",
                    interactive=True
                )
                submit_btn = gr.Button("🚀 Run segmentation")
with gr.Column():
                output_image = gr.Image(label="🖼️ Segmentation result")
gr.Markdown("### 📁 示例图像examples(点击自动加载)")
gr.Examples(
examples=[
["examples/DRIVE.tif", "DRIVE"],
["examples/LES.png", "LES"],
["examples/hrf.png", "hrf"],
["examples/ukbb.png", "ukbb"],
["examples/all.jpg", "all"]
],
inputs=[image_input, model_select],
label="示例图像",
examples_per_page=5
)
        with gr.Accordion("📖 Model description (click to expand)", open=False):
gr.Markdown(model_info_md)
        # Wire the button to the segmentation callback
submit_btn.click(
fn=segment_by_out_test,
inputs=[image_input, model_select],
outputs=[output_image]
)
gr.Markdown("📚 **专用模型引用cite**: RIP-AV: Joint Representative Instance Pre-training with Context Aware Network for Retinal Artery/Vein Segmentation")
gr.Markdown("📚 **通用模型引用cite**: An Efficient and Interpretable Foundation Model for Retinal Image Analysis in Disease Diagnosis.")
demo.queue()
demo.launch()
if __name__ == '__main__':
    # Local test example (bypasses the UI):
    # cfg.set_dataset('all')
    # image_color = out_test(cfg=cfg, evaluate_metrics=False, img_name=r'.\AV\data\AV-DRIVE\test\images\01_test.tif')
    # Image.fromarray(image_color).save('image_color.png')
gradio_interface()