import cv2
import numpy as np
cimport cython
cimport libcpp
cimport libcpp.pair
cimport libcpp.queue
cimport numpy as np
from libcpp.pair cimport *
from libcpp.queue cimport *


@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
                                        np.ndarray[np.int32_t, ndim=2] label,
                                        int kernel_num,
                                        int label_num,
                                        float min_area=0):
    """Grow seed labels outward through a stack of kernel masks (BFS).

    kernels: uint8 (N, H, W) masks to expand through. A pixel can be
        claimed at level ``k`` when ``kernels[k, x, y] != 0``. Levels are
        visited from the highest index down to 0.
    label: int32 (H, W) seed label map; 0 means background. Modified in
        place: seed components with fewer than ``min_area`` pixels are set
        to 0.
    kernel_num: requested number of expansion levels. It is clamped to
        ``kernels.shape[0]`` so the loop never indexes past the array.
    label_num: number of labels in ``label`` including background. Only
        labels ``1 .. label_num - 1`` are considered for area filtering.
    min_area: minimum pixel count a seed component needs to be kept.

    Returns an int32 (H, W) map. Each claimed pixel carries the label of
    the seed whose BFS front reached it first; unclaimed pixels stay 0.
    """
    cdef Py_ssize_t height = label.shape[0]
    cdef Py_ssize_t width = label.shape[1]
    cdef Py_ssize_t x, y
    cdef np.ndarray[np.int32_t, ndim=2] pred
    # Coordinates are int32 (formerly int16, which overflowed on images
    # larger than 32767 px in either dimension).
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int32_t, np.int32_t]] que = \
        queue[libcpp.pair.pair[np.int32_t, np.int32_t]]()
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int32_t, np.int32_t]] nxt_que = \
        queue[libcpp.pair.pair[np.int32_t, np.int32_t]]()
    cdef libcpp.pair.pair[np.int32_t, np.int32_t] cur
    # 4-connected neighbourhood offsets: up, down, left, right.
    cdef np.int32_t* dx = [-1, 1, 0, 0]
    cdef np.int32_t* dy = [0, 0, -1, 1]
    cdef np.int32_t tmpx, tmpy
    cdef int cur_label, j, kernel_idx
    # pse() passes kernels[:-1] (kernel_num - 1 slices) together with the
    # full stack's kernel_num. Starting at kernel_num - 1 therefore indexed
    # past the array with boundscheck disabled. Clamp to the real depth.
    cdef int num_levels = min(kernel_num, kernels.shape[0])
    cdef bint is_edge

    # Drop undersized seed components. A single bincount pass replaces the
    # former per-label ``np.sum(label == idx)`` scan (O(label_num * H * W)).
    if label_num > 1:
        areas = np.bincount(label.ravel(), minlength=label_num)
        too_small = areas < min_area
        too_small[0] = False           # background (0) is never filtered
        too_small[label_num:] = False  # only labels 1..label_num-1 are filtered
        label[too_small[label]] = 0

    pred = np.zeros((height, width), dtype=np.int32)

    # Seed the queue with every surviving labelled pixel in row-major order
    # (the same order np.where produced).
    for x in range(height):
        for y in range(width):
            if label[x, y] > 0:
                que.push(pair[np.int32_t, np.int32_t](x, y))
                pred[x, y] = label[x, y]

    for kernel_idx in range(num_levels - 1, -1, -1):
        while not que.empty():
            cur = que.front()
            que.pop()
            cur_label = pred[cur.first, cur.second]

            is_edge = True
            for j in range(4):
                tmpx = cur.first + dx[j]
                tmpy = cur.second + dy[j]
                if tmpx < 0 or tmpx >= height or tmpy < 0 or tmpy >= width:
                    continue
                # Only claim pixels inside this level's mask not yet labelled.
                if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
                    continue

                que.push(pair[np.int32_t, np.int32_t](tmpx, tmpy))
                pred[tmpx, tmpy] = cur_label
                is_edge = False
            # Pixels that could not grow at this level form the frontier
            # for the next (lower-index) level.
            if is_edge:
                nxt_que.push(cur)

        que, nxt_que = nxt_que, que

    return pred


def pse(kernels, min_area):
    """Run progressive scale expansion over a stack of kernel masks.

    kernels: uint8 array of shape (K, H, W), as required by ``_pse``'s
        buffer type. ``kernels[-1]`` is split into 4-connected seed
        components. Those seeds are then grown through ``kernels[K-2]``
        down to ``kernels[0]``. (Presumably ``kernels[-1]`` is the most
        shrunk kernel; confirm against the caller.)
    min_area: seed components with fewer pixels than this are discarded.

    Returns an int32 (H, W) label map; 0 is background.
    """
    kernel_num = kernels.shape[0]
    # label_num counts the background label 0 as well.
    label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
    return _pse(kernels[:-1], label, kernel_num, label_num, min_area)