Hanf Chase
v1
71e7eab
raw
history blame
4.08 kB
import gradio as gr
import cv2
import numpy as np
import os
import tempfile
from ultralytics import YOLO
# 加载YOLOv8模型
model_path = "docgenome_object_detection_yolov8.pt"
model = YOLO(model_path)
def detect_and_visualize(image):
"""
对上传的图像进行目标检测并可视化结果
Args:
image: 上传的图像
Returns:
annotated_image: 带有检测框的图像
yolo_annotations: YOLO格式的标注内容
"""
# 运行检测
results = model(image)
# 获取第一帧的结果
result = results[0]
# 创建图像副本用于可视化
annotated_image = image.copy()
# 准备YOLO格式的标注内容
yolo_annotations = []
# 获取图像尺寸
img_height, img_width = image.shape[:2]
# 在原图上绘制检测结果
for box in result.boxes:
# 获取边界框坐标
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
# 获取置信度
conf = float(box.conf[0])
# 获取类别ID和名称
cls_id = int(box.cls[0])
cls_name = result.names[cls_id]
# 为每个类别生成不同的颜色
color = tuple(np.random.randint(0, 255, 3).tolist())
# 绘制边界框
cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
# 准备标签文本
label = f'{cls_name} {conf:.2f}'
# 计算标签大小
(label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
# 绘制标签背景
cv2.rectangle(annotated_image, (x1, y1-label_height-5), (x1+label_width, y1), color, -1)
# 绘制标签文本
cv2.putText(annotated_image, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
# 转换为YOLO格式 (x_center, y_center, width, height) 归一化到0-1
x_center = (x1 + x2) / (2 * img_width)
y_center = (y1 + y2) / (2 * img_height)
width = (x2 - x1) / img_width
height = (y2 - y1) / img_height
# 添加到YOLO标注列表
yolo_annotations.append(f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
# 将YOLO标注转换为字符串
yolo_annotations_str = "\n".join(yolo_annotations)
return annotated_image, yolo_annotations_str
def save_yolo_annotations(yolo_annotations_str):
"""
保存YOLO标注到临时文件并返回文件路径
Args:
yolo_annotations_str: YOLO格式的标注字符串
Returns:
file_path: 保存的标注文件路径
"""
# 创建临时文件
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
temp_file_path = temp_file.name
# 写入标注内容
with open(temp_file_path, "w") as f:
f.write(yolo_annotations_str)
return temp_file_path
# 创建Gradio界面
with gr.Blocks(title="YOLOv8目标检测可视化") as demo:
gr.Markdown("# YOLOv8目标检测可视化")
gr.Markdown("上传图像,使用YOLOv8模型进行目标检测,并下载YOLO格式的标注。")
with gr.Row():
with gr.Column():
input_image = gr.Image(label="上传图像", type="numpy")
detect_btn = gr.Button("开始检测")
with gr.Column():
output_image = gr.Image(label="检测结果")
yolo_annotations = gr.Textbox(label="YOLO标注", lines=10)
download_btn = gr.Button("下载YOLO标注")
download_file = gr.File(label="下载文件")
# 设置点击事件
detect_btn.click(
fn=detect_and_visualize,
inputs=[input_image],
outputs=[output_image, yolo_annotations]
)
download_btn.click(
fn=save_yolo_annotations,
inputs=[yolo_annotations],
outputs=[download_file]
)
# 启动应用
if __name__ == "__main__":
demo.launch()