# NOTE: scraped page header, not part of the module: "Spaces: Runtime error".
import math

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from ultralytics import YOLO

from models.tools import split
from models.tools.draw import add_bboxes2
class YoloModel:
    """Two-stage YOLO pipeline: road segmentation followed by tiled detection.

    A segmentation model is run on the full image and every pixel outside the
    predicted masks is painted white. The masked image is then cut into tiles,
    a detection model is run on each tile, the tile-local boxes are shifted
    back into full-image coordinates and nearby boxes are merged.
    """

    def __init__(self, seg_repo_name: str, seg_file_name: str, det_repo_name: str, det_file_name: str):
        """Download both weight files and build the two YOLO models.

        :param seg_repo_name: Hugging Face repo holding the segmentation weights
        :param seg_file_name: file name of the segmentation weights in that repo
        :param det_repo_name: Hugging Face repo holding the detection weights
        :param det_file_name: file name of the detection weights in that repo
        """
        seg_weight_file = YoloModel.download_weight_file(seg_repo_name, seg_file_name)
        det_weight_file = YoloModel.download_weight_file(det_repo_name, det_file_name)
        self.seg_model = YOLO(seg_weight_file)
        self.det_model = YOLO(det_weight_file)

    @staticmethod
    def download_weight_file(repo_name: str, file_name: str):
        """Fetch ``file_name`` from ``repo_name`` via ``hf_hub_download``.

        Declared as a static method so it works both as
        ``YoloModel.download_weight_file(...)`` and through an instance.

        :return: whatever ``hf_hub_download`` returns for the file
        """
        return hf_hub_download(repo_name, file_name)

    def preview_detect(self, im, confidence):
        """Run :meth:`detect` on ``im`` and draw the resulting boxes on it.

        :param im: image source; passed both to :meth:`detect` and ``Image.open``
        :param confidence: forwarded unchanged to ``add_bboxes2``
            (presumably a display threshold - confirm against that helper)
        :return: the value returned by ``add_bboxes2``
        """
        results = self.detect(im)
        res_img = Image.open(im)
        res = {
            'boxes': [
                {'xyxy': [x1, y1, x2, y2], 'cls': cls, 'conf': conf}
                for x1, y1, x2, y2, conf, cls in results
            ]
        }
        return add_bboxes2(res_img, res, confidence)

    def detect(self, source, threshold=50, strategy="distance"):
        """Detect objects on the road area of one image.

        :param source: image path (read with ``cv2.imread`` during segmentation)
        :param threshold: merge threshold handed to :meth:`_merge_box`
            (defaults to the previously hard-coded 50)
        :param strategy: merge strategy handed to :meth:`_merge_box`
            (defaults to the previously hard-coded ``"distance"``)
        :return: list of ``[x1, y1, x2, y2, conf, cls]`` in full-image
            coordinates; ``conf`` and ``cls`` are the per-box elements of
            ``result.boxes.conf`` / ``result.boxes.cls`` and are not converted
        """
        pred_bbox_list = []  # boxes collected over every tile of this image
        seg_img_list = self._seg_ori_img(source)  # road segmentation
        assert len(seg_img_list) == 1, "seg_img_list out of range"
        road_img = Image.fromarray(cv2.cvtColor(seg_img_list[0], cv2.COLOR_BGR2RGB))
        # Cut the masked road image into tiles (argument semantics belong to
        # the project helper split.split_image).
        small_imgs = split.split_image(road_img, (640, 640), (1080, 1080), 0.1)
        for small_img in small_imgs:
            results = self.det_model(source=small_img["image"])
            for result in results:
                temp_bbox_list = result.boxes.xyxy  # tile-local xyxy boxes
                w_bias = small_img["area"][0]
                h_bias = small_img["area"][1]
                # Shift tile-local coordinates into the full-image frame.
                temp_bbox_list = self._bbox_map(temp_bbox_list, w_bias, h_bias)
                temp_bbox_cls = result.boxes.cls
                temp_bbox_conf = result.boxes.conf
                assert len(temp_bbox_list) == len(temp_bbox_cls) == len(
                    temp_bbox_conf), 'different number of matrix size'
                # Pack box, confidence and class into one row per detection.
                for box, conf, cls in zip(temp_bbox_list, temp_bbox_conf, temp_bbox_cls):
                    box.append(conf)
                    box.append(cls)
                pred_bbox_list.extend(temp_bbox_list)
        # Merge boxes gathered from all tiles with the selected strategy.
        pred_bbox_list = self._merge_box(pred_bbox_list, threshold, strategy=strategy)
        return pred_bbox_list

    def _seg_ori_img(self, source):
        """Mask everything outside the segmented road area with white.

        :param source: image path
        :return: list with one entry per segmentation result; each entry is
            the image (same size as the input) with non-mask pixels set to
            255. When no mask is predicted the whole image becomes white.
        :raises ValueError: if ``cv2.imread`` cannot read ``source``
        """
        ori_img = cv2.imread(source)
        if ori_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"cannot read image: {source!r}")
        ori_size = ori_img.shape
        results = self.seg_model(source=source)
        seg_img_list = []
        for result in results:
            if result.masks is not None and len(result.masks) > 0:  # road found
                masks_data = result.masks.data
                # Union of all predicted masks, scaled to 0/255.
                road_mask = torch.any(masks_data[:], dim=0).int() * 255
                mask = road_mask.cpu().numpy().astype(np.uint8)
                # Bring the mask to the original image resolution.
                mask_res = cv2.resize(mask, (ori_size[1], ori_size[0]), interpolation=cv2.INTER_CUBIC)
            else:  # no road found: an all-zero mask whitens the whole image
                mask_res = np.zeros((ori_size[0], ori_size[1], 3), dtype=np.uint8)
            mask_region = mask_res == 0
            ori_img[mask_region] = 255  # zero-mask pixels become white background
            # NOTE(review): the same array is mutated and appended for every
            # result; detect() asserts exactly one result, so this is benign.
            seg_img_list.append(ori_img)
        return seg_img_list

    def _bbox_map(self, bbox_list, w, h):
        """Shift tile-local xyxy boxes into full-image coordinates.

        :param bbox_list: boxes of one tile (tensor or list of lists); a list
            is modified in place
        :param w: horizontal offset added to x1 and x2
        :param h: vertical offset added to y1 and y2
        :return: list of shifted boxes
        """
        if isinstance(bbox_list, torch.Tensor):
            bbox_list = bbox_list.tolist()
        for bbox in bbox_list:
            bbox[0] += w
            bbox[1] += h
            bbox[2] += w
            bbox[3] += h
        return bbox_list

    def _xywh2xyxy(self, box_list):
        """Convert boxes from centre format (x, y, w, h) to corners (x1, y1, x2, y2).

        :param box_list: boxes in xywh format
        :return: new list of boxes in xyxy format
        """
        return [
            [box[0] - box[2] / 2, box[1] - box[3] / 2, box[0] + box[2] / 2, box[1] + box[3] / 2]
            for box in box_list
        ]

    def _xyxy2xywh(self, box_list):
        """Convert boxes from corners (x1, y1, x2, y2) to centre format (x, y, w, h).

        Only the first four entries of each box are read, so rows that also
        carry confidence and class are accepted.

        :param box_list: boxes in xyxy format
        :return: new list of boxes in xywh format
        """
        return [
            [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2, box[2] - box[0], box[3] - box[1]]
            for box in box_list
        ]

    def _nor2std(self, box_list, img_w, img_h):
        """Scale normalised box coordinates to pixel coordinates, in place.

        :param box_list: boxes with normalised coordinates (modified in place)
        :param img_w: image width
        :param img_h: image height
        :return: the same ``box_list``, now in pixel coordinates
        """
        for box in box_list:
            box[0] *= img_w
            box[1] *= img_h
            box[2] *= img_w
            box[3] *= img_h
        return box_list

    def _std2nor(self, box_list, img_w, img_h):
        """Scale pixel box coordinates to normalised coordinates, in place.

        :param box_list: boxes with pixel coordinates (modified in place)
        :param img_w: image width
        :param img_h: image height
        :return: the same ``box_list``, now normalised
        """
        for box in box_list:
            box[0] /= img_w
            box[1] /= img_h
            box[2] /= img_w
            box[3] /= img_h
        return box_list

    def _judge_merge_by_center_distance(self, center_box1, center_box2, distance_threshold):
        """Decide whether two boxes should merge based on centre distance.

        :param center_box1: first box; entries 0 and 1 are its centre x, y
        :param center_box2: second box; entries 0 and 1 are its centre x, y
        :param distance_threshold: distance limit
        :return: True when the Euclidean distance is strictly below the limit
        """
        dx = center_box2[0] - center_box1[0]
        dy = center_box2[1] - center_box1[1]
        return math.sqrt(dx ** 2 + dy ** 2) < distance_threshold

    def _judge_merge_by_overlap_area(self, std_box1, std_box2, overlap_threshold):
        """Decide whether two boxes should merge based on intersection area.

        :param std_box1: first box in xyxy format
        :param std_box2: second box in xyxy format
        :param overlap_threshold: minimum intersection area
        :return: True when the intersection area is at least the threshold
        """
        x1 = max(std_box1[0], std_box2[0])
        y1 = max(std_box1[1], std_box2[1])
        x2 = min(std_box1[2], std_box2[2])
        y2 = min(std_box1[3], std_box2[3])
        # Disjoint boxes give a negative extent; clamp it to zero.
        area = max(0, x2 - x1) * max(0, y2 - y1)
        return not area < overlap_threshold

    def _basic_merge(self, box1, box2):
        """Merge two boxes into their common bounding box.

        :param box1: ``[x1, y1, x2, y2]`` or ``[x1, y1, x2, y2, conf, class]``
        :param box2: box with the same layout as ``box1``
        :return: the enclosing box; for 6-entry boxes the confidence is the
            mean of both and the class is taken from ``box1``
        """
        assert len(box1) == len(box2), 'box1 and box2 has different size'
        xs = (box1[0], box1[2], box2[0], box2[2])
        ys = (box1[1], box1[3], box2[1], box2[3])
        new_box = [min(xs), min(ys), max(xs), max(ys)]
        if len(box1) == 6:  # layout [x1, y1, x2, y2, conf, class]
            new_box.append((box1[4] + box2[4]) / 2)
            new_box.append(box1[5])
        return new_box

    def _update_list(self, bbox_list, del_index):
        """Remove one element by overwriting it with the last one (order not kept).

        :param bbox_list: list to shrink (modified in place)
        :param del_index: index of the element to drop
        :return: the same list, one element shorter
        """
        assert len(bbox_list) > del_index >= 0, 'del_index out of boundary'
        bbox_list[del_index] = bbox_list[-1]
        bbox_list.pop()
        return bbox_list

    def _merge_box(self, std_bbox_list, threshold, strategy='overlap'):
        """Greedily merge boxes according to the chosen strategy.

        :param std_bbox_list: boxes as ``[x1, y1, x2, y2]`` or
            ``[x1, y1, x2, y2, conf, class]`` rows (tensor or list; a list is
            modified in place)
        :param threshold: area threshold for ``'overlap'``, distance threshold
            otherwise
        :param strategy: ``'overlap'`` selects the intersection-area test; any
            other value selects the centre-distance test
        :return: the merged list of boxes
        """
        if isinstance(std_bbox_list, torch.Tensor):
            std_bbox_list = std_bbox_list.tolist()
        # Parallel list of centre-format boxes, kept index-aligned with
        # std_bbox_list for the distance test.
        center_bbox_list = self._xyxy2xywh(std_bbox_list)
        use_overlap = strategy == 'overlap'
        i = 0
        while i < len(std_bbox_list):
            j = i + 1
            while j < len(std_bbox_list):
                if use_overlap:
                    should_merge = self._judge_merge_by_overlap_area(
                        std_bbox_list[i], std_bbox_list[j], threshold)
                else:
                    should_merge = self._judge_merge_by_center_distance(
                        center_bbox_list[i], center_bbox_list[j], threshold)
                if not should_merge:
                    j += 1
                    continue
                merged = self._basic_merge(std_bbox_list[i], std_bbox_list[j])
                std_bbox_list[i] = merged
                # Keep the centre entry in sync with the grown box; otherwise
                # later distance checks would use the pre-merge centre.
                center_bbox_list[i] = self._xyxy2xywh([merged])[0]
                # Swap-remove j from both lists; j is not advanced because a
                # different box now sits at that index.
                self._update_list(std_bbox_list, j)
                self._update_list(center_bbox_list, j)
            i += 1
        return std_bbox_list
def main():
    """Demo entry point: build the model and show detections for a sample image.

    Builds ``YoloModel`` from the ``SHOU-ISD/yolo-cracks`` repo, runs
    ``preview_detect`` on a fixed local image with a confidence argument of
    0.4 and calls ``.show()`` on the returned image.
    """
    model = YoloModel("SHOU-ISD/yolo-cracks", "last4.pt", "SHOU-ISD/yolo-cracks", "best.pt")
    model.preview_detect('./datasets/Das1100209.jpg', 0.4).show()


if __name__ == '__main__':
    main()