# File: object-detection/models/yolo_crack.py
# Repository page header: author avatar "mingyang91"; commit "add crack split.py" (684e6f5, verified)
from models.tools import split
from PIL import Image
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
import torch
import cv2
import numpy as np
import math
from models.tools.draw import add_bboxes2
class YoloModel:
    """Two-stage YOLO pipeline: segment the road surface, then detect on tiles.

    A segmentation model masks out everything that is not road surface, the
    masked image is split into smaller tiles, a detection model runs on each
    tile, and the per-tile boxes are mapped back to full-image coordinates and
    merged.
    """

    def __init__(self, seg_repo_name: str, seg_file_name: str, det_repo_name: str, det_file_name: str):
        """Download both weight files from the Hugging Face Hub and load them.

        :param seg_repo_name: Hub repository holding the segmentation weights
        :param seg_file_name: file name of the segmentation weights in that repository
        :param det_repo_name: Hub repository holding the detection weights
        :param det_file_name: file name of the detection weights in that repository
        """
        seg_weight_file = YoloModel.download_weight_file(seg_repo_name, seg_file_name)
        det_weight_file = YoloModel.download_weight_file(det_repo_name, det_file_name)
        self.seg_model = YOLO(seg_weight_file)
        self.det_model = YOLO(det_weight_file)

    @staticmethod
    def download_weight_file(repo_name: str, file_name: str):
        """Fetch ``file_name`` from Hub repository ``repo_name``.

        :return: whatever ``hf_hub_download`` returns (a local file path)
        """
        return hf_hub_download(repo_name, file_name)

    def preview_detect(self, im, confidence):
        """Run detection on an image file and return the image with boxes drawn.

        :param im: path of the image (passed to both ``detect`` and ``Image.open``)
        :param confidence: forwarded unchanged to ``add_bboxes2``
            (presumably the minimum confidence for a box to be drawn -- confirm
            against ``models.tools.draw``)
        :return: the image returned by ``add_bboxes2``
        """
        results = self.detect(im)
        res_img = Image.open(im)
        res = {
            'boxes': [
                {
                    'xyxy': [x1, y1, x2, y2],
                    'cls': cls,
                    'conf': conf
                } for x1, y1, x2, y2, conf, cls in results
            ]
        }
        return add_bboxes2(res_img, res, confidence)

    def detect(self, source, threshold=50, strategy="distance"):
        """Detect objects on the road surface of one image.

        :param source: path of the image
        :param threshold: bbox merge threshold handed to ``_merge_box``
            (defaults to the previously hard-coded value 50)
        :param strategy: bbox merge strategy handed to ``_merge_box``
            (defaults to the previously hard-coded ``"distance"``)
        :return: list of ``[x1, y1, x2, y2, conf, class]`` in full-image
            coordinates; ``conf`` and ``class`` are the elements taken from the
            detector's ``boxes.conf`` / ``boxes.cls`` (averaged for merged boxes)
        """
        pred_bbox_list = []  # every box found in this image, across all tiles

        seg_img_list = self._seg_ori_img(source)  # mask out non-road pixels
        assert len(seg_img_list) == 1, "seg_img_list out of range"

        # The segmented image is BGR (read by OpenCV); PIL expects RGB.
        road_img = Image.fromarray(cv2.cvtColor(seg_img_list[0], cv2.COLOR_BGR2RGB))
        # Split the road image into tiles.
        # NOTE(review): meaning of the size/ratio arguments is defined by
        # models.tools.split.split_image -- confirm there.
        small_imgs = split.split_image(road_img, (640, 640), (1080, 1080), 0.1)

        for small_img in small_imgs:
            # Offset of this tile inside the full image.
            w_bias = small_img["area"][0]
            h_bias = small_img["area"][1]
            for result in self.det_model(source=small_img["image"]):
                boxes = result.boxes
                # Tile-local xyxy boxes shifted into full-image coordinates.
                tile_boxes = self._bbox_map(boxes.xyxy, w_bias, h_bias)
                tile_cls = boxes.cls
                tile_conf = boxes.conf
                assert len(tile_boxes) == len(tile_cls) == len(
                    tile_conf), 'different number of matrix size'
                # Fold conf and class into each box: [x1, y1, x2, y2, conf, class].
                for box, conf, cls in zip(tile_boxes, tile_conf, tile_cls):
                    box.append(conf)
                    box.append(cls)
                pred_bbox_list += tile_boxes

        # Boxes from neighbouring tiles are combined by the selected strategy.
        return self._merge_box(pred_bbox_list, threshold, strategy=strategy)

    def _seg_ori_img(self, source):
        """Segment the road region of the original image.

        :param source: path of the image
        :return: list with one image per segmentation result; each has the
            size of the original image, with every pixel outside the road mask
            set to white (255). When no mask is found the whole image is white.
        :raises OSError: if OpenCV cannot read ``source``
        """
        ori_img = cv2.imread(source)
        if ori_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"failed to read image: {source!r}")
        img_h, img_w = ori_img.shape[:2]

        results = self.seg_model(source=source)
        seg_img_list = []
        for result in results:
            if result.masks is not None and len(result.masks) > 0:  # road found
                # Union of all instance masks, scaled to 0/255.
                road_mask = torch.any(result.masks.data, dim=0).int() * 255
                mask = road_mask.cpu().numpy().astype(np.uint8)
                mask_res = cv2.resize(mask, (img_w, img_h), interpolation=cv2.INTER_CUBIC)
            else:  # no road found: an all-zero mask whites out the whole image
                mask_res = np.zeros((img_h, img_w), dtype=np.uint8)
            # Work on a copy so each result gets its own independent image.
            seg_img = ori_img.copy()
            seg_img[mask_res == 0] = 255  # mask value 0 == background -> white
            seg_img_list.append(seg_img)
        return seg_img_list

    def _bbox_map(self, bbox_list, w, h):
        """Map tile-local bbox coordinates into the original image.

        :param bbox_list: xyxy boxes in tile coordinates (tensor or list of
            lists; a list is modified in place)
        :param w: horizontal offset of the tile inside the original image
        :param h: vertical offset of the tile inside the original image
        :return: the boxes as a list of lists in original-image coordinates
        """
        if isinstance(bbox_list, torch.Tensor):
            bbox_list = bbox_list.tolist()
        for bbox in bbox_list:
            bbox[0] += w
            bbox[1] += h
            bbox[2] += w
            bbox[3] += h
        return bbox_list

    def _xywh2xyxy(self, box_list):
        """Convert YOLO-style boxes from centre/size (xywh) to corners (xyxy).

        :param box_list: boxes as ``[cx, cy, w, h, ...]``
        :return: new list of ``[x1, y1, x2, y2]``
        """
        new_box_list = []
        for box in box_list:
            half_w = box[2] / 2
            half_h = box[3] / 2
            new_box_list.append([box[0] - half_w, box[1] - half_h,
                                 box[0] + half_w, box[1] + half_h])
        return new_box_list

    def _xyxy2xywh(self, box_list):
        """Convert boxes from corners (xyxy) to centre/size (xywh).

        Only the first four entries of each box are read, so boxes that also
        carry conf/class are accepted.

        :param box_list: boxes as ``[x1, y1, x2, y2, ...]``
        :return: new list of ``[cx, cy, w, h]``
        """
        new_box_list = []
        for box in box_list:
            cx = (box[0] + box[2]) / 2
            cy = (box[1] + box[3]) / 2
            w = box[2] - box[0]
            h = box[3] - box[1]
            new_box_list.append([cx, cy, w, h])
        return new_box_list

    def _nor2std(self, box_list, img_w, img_h):
        """Scale normalised YOLO coordinates to original-image coordinates.

        The boxes are scaled in place; the same list is also returned, as the
        docstring always promised.

        :param box_list: boxes in normalised coordinates
        :param img_w: width of the original image
        :param img_h: height of the original image
        :return: ``box_list`` with coordinates in the original image
        """
        for box in box_list:
            box[0] *= img_w
            box[1] *= img_h
            box[2] *= img_w
            box[3] *= img_h
        return box_list

    def _std2nor(self, box_list, img_w, img_h):
        """Scale original-image coordinates to normalised YOLO coordinates.

        The boxes are scaled in place; the same list is also returned, as the
        docstring always promised.

        :param box_list: boxes in original-image coordinates
        :param img_w: width of the original image
        :param img_h: height of the original image
        :return: ``box_list`` with normalised coordinates
        """
        for box in box_list:
            box[0] /= img_w
            box[1] /= img_h
            box[2] /= img_w
            box[3] /= img_h
        return box_list

    def _judge_merge_by_center_distance(self, center_box1, center_box2, distance_threshold):
        """Decide whether two boxes should merge based on centre distance.

        :param center_box1: box 1, with its centre at indices 0 and 1
        :param center_box2: box 2, with its centre at indices 0 and 1
        :param distance_threshold: distance threshold
        :return: True if the Euclidean distance between the centres is strictly
            below the threshold, otherwise False
        """
        dx = center_box2[0] - center_box1[0]
        dy = center_box2[1] - center_box1[1]
        distance = math.sqrt(dx ** 2 + dy ** 2)
        return distance < distance_threshold

    def _judge_merge_by_overlap_area(self, std_box1, std_box2, overlap_threshold):
        """Decide whether two boxes should merge based on intersection area.

        :param std_box1: box 1 as ``[x1, y1, x2, y2, ...]``
        :param std_box2: box 2 as ``[x1, y1, x2, y2, ...]``
        :param overlap_threshold: intersection-area threshold
        :return: True if the intersection area is at least the threshold,
            otherwise False
        """
        x1 = max(std_box1[0], std_box2[0])
        y1 = max(std_box1[1], std_box2[1])
        x2 = min(std_box1[2], std_box2[2])
        y2 = min(std_box1[3], std_box2[3])
        # Disjoint boxes give a negative extent; clamp it to zero.
        width = max(0, x2 - x1)
        height = max(0, y2 - y1)
        return not (width * height < overlap_threshold)

    def _basic_merge(self, box1, box2):
        """Merge two boxes into the smallest box enclosing both.

        :param box1: ``[x1, y1, x2, y2]`` or ``[x1, y1, x2, y2, conf, class]``
        :param box2: box of the same length as ``box1``
        :return: the enclosing box; for 6-element boxes the confidence is the
            mean of both and the class is taken from ``box1``
        """
        # Validate before touching the contents.
        assert len(box1) == len(box2), 'box1 and box2 has different size'
        xs = (box1[0], box1[2], box2[0], box2[2])
        ys = (box1[1], box1[3], box2[1], box2[3])
        new_box = [min(xs), min(ys), max(xs), max(ys)]
        if len(box1) == 6:  # boxes carry conf and class
            avg_conf = (box1[4] + box2[4]) / 2
            new_box += [avg_conf, box1[5]]
        return new_box

    def _update_list(self, bbox_list, del_index):
        """Remove one element from ``bbox_list`` in place.

        The last element is moved into the freed slot, so element order is not
        preserved.

        :param bbox_list: list to shrink
        :param del_index: index of the element to remove
        :return: the same list, one element shorter
        """
        assert len(bbox_list) > del_index >= 0, 'del_index out of boundary'
        bbox_list[del_index] = bbox_list[-1]
        bbox_list.pop()
        return bbox_list

    def _merge_box(self, std_bbox_list, threshold, strategy='overlap'):
        """Merge boxes pairwise according to the chosen strategy.

        A list argument is modified in place and returned.

        :param std_bbox_list: tensor or list, either ``[N, 4]`` as
            ``[x1, y1, x2, y2]`` or ``[N, 6]`` as ``[x1, y1, x2, y2, conf, class]``
        :param threshold: threshold for the chosen strategy
        :param strategy: ``'overlap'`` merges on intersection area; any other
            value merges on centre distance
        :return: list of merged boxes
        """
        if isinstance(std_bbox_list, torch.Tensor):
            std_bbox_list = std_bbox_list.tolist()
        # Centre/size view kept index-aligned with std_bbox_list.
        center_bbox_list = self._xyxy2xywh(std_bbox_list)

        i = 0
        while i < len(std_bbox_list):
            j = i + 1
            while j < len(std_bbox_list):
                if strategy == 'overlap':
                    should_merge = self._judge_merge_by_overlap_area(
                        std_bbox_list[i], std_bbox_list[j], threshold)
                else:
                    should_merge = self._judge_merge_by_center_distance(
                        center_bbox_list[i], center_bbox_list[j], threshold)
                if not should_merge:
                    j += 1
                    continue
                std_bbox_list[i] = self._basic_merge(std_bbox_list[i], std_bbox_list[j])
                # Refresh the centre of the merged box; otherwise later
                # distance checks would use the stale pre-merge centre.
                center_bbox_list[i] = self._xyxy2xywh([std_bbox_list[i]])[0]
                # Removal swaps the last element into slot j, so j is
                # re-examined on the next pass instead of being advanced.
                self._update_list(std_bbox_list, j)
                self._update_list(center_bbox_list, j)
            i += 1
        return std_bbox_list
def main():
    """Load the crack models from the Hub and show a detection preview."""
    repo = "SHOU-ISD/yolo-cracks"
    # Positional order: segmentation repo/file first, then detection repo/file.
    model = YoloModel(repo, "last4.pt", repo, "best.pt")
    preview = model.preview_detect('./datasets/Das1100209.jpg', 0.4)
    preview.show()


# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()