# YOLOv11-seg / app.py
# Hugging Face Space file by robot2no1 — commit ecfd8bd ("Update app.py", verified), 5.97 kB.
import os
import zlib
from collections import defaultdict

import cv2
import gradio as gr
import numpy as np
import torch
from ultralytics import YOLO
# Module-level cache for the YOLO segmentation model; populated lazily by load_model().
model = None


def load_model():
    """Return the shared YOLO model, building it from './yolo11x-seg.pt' on first call.

    Later calls return the instance cached in the module-level ``model``.
    """
    global model
    if model is not None:
        return model
    model = YOLO('./yolo11x-seg.pt')
    return model
def _class_color(class_name):
    """Return a stable 3-element uint8 colour for *class_name*.

    ``zlib.crc32`` replaces the built-in ``hash`` because ``str`` hashes are
    salted per process, which made class colours change on every restart.
    The ``class_name * k`` scheme keeps the original per-channel derivation.
    """
    return np.array(
        [zlib.crc32((class_name * k).encode("utf-8")) % 256 for k in (1, 2, 3)],
        dtype=np.uint8,
    )


def _fit_mask_to_image(mask, height, width):
    """Map a single 2-D model mask onto an image of size (height, width).

    With ``retina_masks`` off, Ultralytics returns masks at the centre-letterboxed
    inference resolution. A plain resize would stretch the padding into the
    picture and shift the mask, so the padded border is cropped first (the same
    idea as ``ultralytics.utils.ops.scale_image``). If the mask already has the
    image size it is returned unchanged.
    """
    mask_h, mask_w = mask.shape[:2]
    if (mask_h, mask_w) == (height, width):
        return mask
    gain = min(mask_h / height, mask_w / width)
    pad_w = (mask_w - width * gain) / 2
    pad_h = (mask_h - height * gain) / 2
    top, left = int(round(pad_h - 0.1)), int(round(pad_w - 0.1))
    bottom = mask_h - int(round(pad_h + 0.1))
    right = mask_w - int(round(pad_w + 0.1))
    return cv2.resize(mask[top:bottom, left:right], (width, height))


def _draw_instance(canvas, mask_area, xyxy, label, color, thickness):
    """Draw one detection (translucent mask, box, filled label) onto *canvas* in place."""
    overlay = canvas.copy()
    overlay[mask_area] = color
    cv2.addWeighted(overlay, 0.4, canvas, 0.6, 0, canvas)

    x1, y1, x2, y2 = xyxy
    color_list = color.tolist()
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color_list, thickness)

    font_scale = 0.6 * thickness / 2
    (label_w, label_h), _ = cv2.getTextSize(
        label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
    # Push the label down when the box touches the top edge so it stays visible.
    base_y = max(y1, label_h + 10)
    cv2.rectangle(canvas,
                  (x1, base_y - label_h - 10),
                  (x1 + label_w, base_y),
                  color_list, -1)
    cv2.putText(canvas, label, (x1, base_y - 5),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale,
                (255, 255, 255), thickness, cv2.LINE_AA)


def segment_image(image, conf_threshold, iou_threshold, mask_threshold, line_thickness, use_retina_masks):
    """Run YOLO instance segmentation and build gallery images.

    Args:
        image: numpy image from ``gr.Image`` (RGB, RGBA or grayscale), or None.
        conf_threshold: minimum detection confidence passed to the model.
        iou_threshold: NMS IoU threshold passed to the model.
        mask_threshold: per-pixel threshold applied to each mask.
        line_thickness: box / text thickness in pixels (coerced to int >= 1).
        use_retina_masks: forwarded to the model as ``retina_masks``.

    Returns:
        A list of ``(image, caption)`` tuples: first the combined result
        ("完整结果"), then one image per detected class in detection order.
        Returns None when there is no input or nothing was segmented.
    """
    if image is None:
        return None

    model = load_model()

    # Normalise to 3-channel RGB (gr.Image's numpy output convention) for display.
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # OpenCV drawing needs an integer thickness; slider values may be floats.
    line_thickness = max(1, int(line_thickness))

    # Ultralytics interprets numpy input as BGR, so convert only the model input.
    results = model(
        cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
        conf=conf_threshold,
        iou=iou_threshold,
        device='cuda:0' if torch.cuda.is_available() else 'cpu',
        retina_masks=use_retina_masks,
    )
    result = results[0]
    if result.masks is None:
        return None

    height, width = image.shape[:2]
    names = model.names
    full_result = image.copy()
    # Insertion-ordered: class name -> its own annotated copy of the input.
    class_images = defaultdict(image.copy)

    for seg, box in zip(result.masks, result.boxes):
        class_name = names[int(box.cls)]
        mask = _fit_mask_to_image(seg.data[0].cpu().numpy(), height, width)
        # NOTE(review): recent Ultralytics versions already binarise masks to 0/1,
        # in which case this threshold mainly affects resized edges — confirm.
        mask_area = mask > mask_threshold
        xyxy = tuple(map(int, box.xyxy[0]))
        label = f"{class_name} {float(box.conf):.2f}"
        color = _class_color(class_name)
        # Draw on both the per-class image and the combined image so every class
        # appears at full strength in "完整结果" (no repeated 50/50 blending).
        for canvas in (class_images[class_name], full_result):
            _draw_instance(canvas, mask_area, xyxy, label, color, line_thickness)

    if not class_images:
        return None

    gallery_output = [(full_result, "完整结果")]
    gallery_output.extend((img, name) for name, img in class_images.items())
    return gallery_output
def create_demo():
    """Assemble the Gradio Blocks UI for the segmentation demo.

    The left column holds the input image and the inference controls; the
    right column holds the result gallery. The button below sends the
    controls, in ``segment_image``'s parameter order, to ``segment_image``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# YOLO 图像分割")
        gr.Markdown("上传一张图片,模型将对图片进行实例分割。每个检测到的类别将单独显示。")

        with gr.Row():
            # Left: input picture and tunable parameters.
            with gr.Column(scale=1):
                image_in = gr.Image()
                with gr.Row():
                    conf_slider = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.25, step=0.05,
                        label="置信度阈值", info="检测置信度的最小值",
                    )
                    iou_slider = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.7, step=0.05,
                        label="IOU阈值", info="非极大值抑制的IOU阈值",
                    )
                with gr.Row():
                    mask_slider = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.5, step=0.05,
                        label="掩码阈值", info="分割掩码的阈值",
                    )
                    thickness_slider = gr.Slider(
                        minimum=1, maximum=5, value=2, step=1,
                        label="线条粗细", info="边界框和文本的粗细",
                    )
                with gr.Row():
                    retina_toggle = gr.Checkbox(
                        label="高分辨率掩码",
                        value=True,
                        info="启用高分辨率分割掩码(可能会降低速度)",
                    )

            # Right: combined result followed by one image per class.
            with gr.Column(scale=1):
                gallery = gr.Gallery(
                    label="分割结果",
                    show_label=True,
                    columns=2,
                    rows=2,
                    height=600,
                    object_fit="contain",
                )

        run_button = gr.Button("开始分割")
        # Order must match segment_image's positional parameters.
        controls = [
            image_in,
            conf_slider,
            iou_slider,
            mask_slider,
            thickness_slider,
            retina_toggle,
        ]
        run_button.click(fn=segment_image, inputs=controls, outputs=gallery)

    return demo
if __name__ == "__main__":
    # Build the UI, then serve it; share=True also asks Gradio for a public link.
    demo = create_demo()
    demo.launch(
        share=True,
    )