Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -2,12 +2,22 @@ from ultralytics import YOLO
|
|
2 |
import gradio as gr
|
3 |
import cv2
|
4 |
import numpy as np
|
5 |
-
import torch
|
6 |
from collections import defaultdict
|
|
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def segment_image(image, conf_threshold, iou_threshold, mask_threshold, line_thickness, use_retina_masks):
|
|
|
|
|
|
|
11 |
# 确保图像是BGR格式
|
12 |
if len(image.shape) == 2:
|
13 |
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
@@ -92,69 +102,66 @@ def segment_image(image, conf_threshold, iou_threshold, mask_threshold, line_thi
|
|
92 |
|
93 |
return gallery_output if gallery_output else None
|
94 |
|
95 |
-
|
96 |
-
with gr.Blocks() as
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
)
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
inputs=input_image,
|
143 |
-
)
|
144 |
|
145 |
-
|
146 |
-
fn=segment_image,
|
147 |
-
inputs=[
|
148 |
-
input_image,
|
149 |
-
conf_threshold,
|
150 |
-
iou_threshold,
|
151 |
-
mask_threshold,
|
152 |
-
line_thickness,
|
153 |
-
retina_masks
|
154 |
-
],
|
155 |
-
outputs=output_gallery
|
156 |
-
)
|
157 |
|
158 |
-
# 启动应用
|
159 |
if __name__ == "__main__":
|
160 |
-
|
|
|
|
2 |
import gradio as gr
|
3 |
import cv2
|
4 |
import numpy as np
|
|
|
5 |
from collections import defaultdict
|
6 |
+
import os
|
7 |
|
8 |
+
# 初始化模型
|
9 |
+
model = None
|
10 |
+
|
11 |
+
def load_model():
|
12 |
+
global model
|
13 |
+
if model is None:
|
14 |
+
model = YOLO('yolo11x-seg.pt')
|
15 |
+
return model
|
16 |
|
17 |
def segment_image(image, conf_threshold, iou_threshold, mask_threshold, line_thickness, use_retina_masks):
|
18 |
+
# 加载模型
|
19 |
+
model = load_model()
|
20 |
+
|
21 |
# 确保图像是BGR格式
|
22 |
if len(image.shape) == 2:
|
23 |
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
|
|
102 |
|
103 |
return gallery_output if gallery_output else None
|
104 |
|
105 |
+
def create_demo():
|
106 |
+
with gr.Blocks() as demo:
|
107 |
+
gr.Markdown("# YOLO 图像分割")
|
108 |
+
gr.Markdown("上传一张图片,模型将对图片进行实例分割。每个检测到的类别将单独显示。")
|
109 |
+
|
110 |
+
with gr.Row():
|
111 |
+
with gr.Column(scale=1):
|
112 |
+
input_image = gr.Image()
|
113 |
+
with gr.Row():
|
114 |
+
conf_threshold = gr.Slider(
|
115 |
+
minimum=0.1, maximum=1.0, value=0.25, step=0.05,
|
116 |
+
label="置信度阈值", info="检测置信度的最小值"
|
117 |
+
)
|
118 |
+
iou_threshold = gr.Slider(
|
119 |
+
minimum=0.1, maximum=1.0, value=0.7, step=0.05,
|
120 |
+
label="IOU阈值", info="非极大值抑制的IOU阈值"
|
121 |
+
)
|
122 |
+
with gr.Row():
|
123 |
+
mask_threshold = gr.Slider(
|
124 |
+
minimum=0.1, maximum=1.0, value=0.5, step=0.05,
|
125 |
+
label="掩码阈值", info="分割掩码的阈值"
|
126 |
+
)
|
127 |
+
line_thickness = gr.Slider(
|
128 |
+
minimum=1, maximum=5, value=2, step=1,
|
129 |
+
label="线条粗细", info="边界框和文本的粗细"
|
130 |
+
)
|
131 |
+
with gr.Row():
|
132 |
+
retina_masks = gr.Checkbox(
|
133 |
+
label="高分辨率掩码",
|
134 |
+
value=True,
|
135 |
+
info="启用高分辨率分割掩码(可能会降低速度)"
|
136 |
+
)
|
137 |
+
|
138 |
+
with gr.Column(scale=1):
|
139 |
+
output_gallery = gr.Gallery(
|
140 |
+
label="分割结果",
|
141 |
+
show_label=True,
|
142 |
+
columns=2,
|
143 |
+
rows=2,
|
144 |
+
height=600,
|
145 |
+
object_fit="contain"
|
146 |
)
|
147 |
|
148 |
+
submit_btn = gr.Button("开始分割")
|
149 |
+
|
150 |
+
submit_btn.click(
|
151 |
+
fn=segment_image,
|
152 |
+
inputs=[
|
153 |
+
input_image,
|
154 |
+
conf_threshold,
|
155 |
+
iou_threshold,
|
156 |
+
mask_threshold,
|
157 |
+
line_thickness,
|
158 |
+
retina_masks
|
159 |
+
],
|
160 |
+
outputs=output_gallery
|
161 |
+
)
|
|
|
|
|
162 |
|
163 |
+
return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
|
|
165 |
if __name__ == "__main__":
|
166 |
+
demo = create_demo()
|
167 |
+
demo.launch()
|