|
import gradio as gr |
|
import cv2 |
|
import numpy as np |
|
import os |
|
import tempfile |
|
from ultralytics import YOLO |
|
|
|
|
|
# Path to the YOLOv8 weights file. Set the YOLO_MODEL_PATH environment
# variable to load a different checkpoint without editing this file; the
# default is the DocGenome object-detection checkpoint used originally.
model_path = os.environ.get("YOLO_MODEL_PATH", "docgenome_object_detection_yolov8.pt")

# Loaded once at import time; detect_and_visualize() uses this global.
model = YOLO(model_path)
|
|
|
def _to_model_input(image):
    """Return a BGR copy of ``image`` for the detector.

    Ultralytics interprets raw NumPy arrays as BGR, whereas Gradio delivers
    RGB (or, with non-default ``image_mode``, grayscale / RGBA) arrays.
    """
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def _class_color(cls_id):
    """Return a deterministic 3-channel colour tuple for a class id.

    Seeding a private generator with the class id gives every box of the same
    class the same colour across calls, without touching NumPy's global RNG.
    """
    rng = np.random.default_rng(cls_id)
    # Plain ints: OpenCV drawing functions reject some NumPy scalar types.
    return tuple(int(c) for c in rng.integers(0, 256, size=3))


def _draw_labeled_box(canvas, x1, y1, x2, y2, label, color):
    """Draw a bounding box and a filled label tag onto ``canvas`` in place."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)

    (text_w, text_h), _ = cv2.getTextSize(label, font, 0.5, 1)
    # The tag normally sits directly above the box. When the box touches the
    # top edge there is no room above it, so draw the tag just inside the box
    # rather than off-canvas.
    tag_top = y1 - text_h - 5
    if tag_top < 0:
        tag_top = y1
    cv2.rectangle(canvas, (x1, tag_top), (x1 + text_w, tag_top + text_h + 5), color, -1)
    cv2.putText(canvas, label, (x1, tag_top + text_h), font, 0.5, (255, 255, 255), 1)


def _to_yolo_line(cls_id, x1, y1, x2, y2, img_width, img_height):
    """Format one pixel-space (x1, y1, x2, y2) box as a normalised YOLO label line."""
    x_center = (x1 + x2) / (2 * img_width)
    y_center = (y1 + y2) / (2 * img_height)
    width = (x2 - x1) / img_width
    height = (y2 - y1) / img_height
    return f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"


def detect_and_visualize(image):
    """Run object detection on an uploaded image and visualise the result.

    Args:
        image: The uploaded image as a NumPy array, expected in RGB channel
            order as delivered by ``gr.Image(type="numpy")``; ``None`` when
            nothing has been uploaded.

    Returns:
        A tuple ``(annotated_image, yolo_annotations)``:
            annotated_image: a copy of ``image`` with boxes and
                ``"<class> <conf>"`` labels drawn on it, or ``None`` if
                ``image`` is ``None``.
            yolo_annotations: newline-joined YOLO label lines
                (``"cls_id x_center y_center width height"``, coordinates
                normalised by the image size); ``""`` if nothing was detected.
    """
    if image is None:
        # Clicking "detect" before uploading passes None; clear the outputs
        # instead of crashing inside the model call.
        return None, ""

    result = model(_to_model_input(image))[0]

    annotated_image = image.copy()
    img_height, img_width = image.shape[:2]
    yolo_annotations = []

    for box in result.boxes:
        # Keep float coordinates for the label file; truncate only for drawing.
        fx1, fy1, fx2, fy2 = (float(v) for v in box.xyxy[0].cpu().numpy())
        conf = float(box.conf[0])
        cls_id = int(box.cls[0])
        label = f'{result.names[cls_id]} {conf:.2f}'

        _draw_labeled_box(
            annotated_image, int(fx1), int(fy1), int(fx2), int(fy2),
            label, _class_color(cls_id),
        )
        yolo_annotations.append(
            _to_yolo_line(cls_id, fx1, fy1, fx2, fy2, img_width, img_height)
        )

    return annotated_image, "\n".join(yolo_annotations)
|
|
|
def save_yolo_annotations(yolo_annotations_str):
    """Write YOLO annotations to a temporary ``.txt`` file and return its path.

    The file is created with ``delete=False`` so that it outlives this call
    and can be served for download. It is not removed automatically.

    Args:
        yolo_annotations_str: YOLO-format annotation text, one box per line.
            ``None`` is treated as an empty string.

    Returns:
        The filesystem path of the written annotation file.
    """
    # Write through the handle NamedTemporaryFile already opened, so no file
    # descriptor is leaked and the file is not opened a second time by name.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as temp_file:
        temp_file.write(yolo_annotations_str or "")
    return temp_file.name
|
|
|
|
|
with gr.Blocks(title="YOLOv8目标检测可视化") as demo:
    # Page header: title and a one-line usage hint.
    gr.Markdown("# YOLOv8目标检测可视化")
    gr.Markdown("上传图像,使用YOLOv8模型进行目标检测,并下载YOLO格式的标注。")

    with gr.Row():
        # Left column: image upload plus the button that triggers detection.
        with gr.Column():
            input_image = gr.Image(label="上传图像", type="numpy")
            detect_btn = gr.Button("开始检测")

        # Right column: annotated preview, YOLO label text, download controls.
        with gr.Column():
            output_image = gr.Image(label="检测结果")
            yolo_annotations = gr.Textbox(label="YOLO标注", lines=10)
            download_btn = gr.Button("下载YOLO标注")
            download_file = gr.File(label="下载文件")

    # Detection: uploaded image -> (annotated image, YOLO label text).
    detect_btn.click(
        detect_and_visualize,
        [input_image],
        [output_image, yolo_annotations],
    )

    # Download: current label text -> temporary .txt file offered for download.
    download_btn.click(
        save_yolo_annotations,
        [yolo_annotations],
        [download_file],
    )


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()