File size: 2,577 Bytes
5b765fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

import cv2
import numpy as np

cimport cython
cimport libcpp
cimport libcpp.pair
cimport libcpp.queue
cimport numpy as np
from libcpp.pair cimport *
from libcpp.queue cimport *


@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
                                         np.ndarray[np.int32_t, ndim=2] label,
                                         int kernel_num,
                                         int label_num,
                                         float min_area=0):
    """Grow seed labels outward through a stack of binary kernel maps.

    kernels:    (K, H, W) uint8 maps to expand through. A pixel may receive a
                label at level k only if kernels[k, x, y] != 0. Levels are
                visited from index K-1 down to 0.
    label:      (H, W) int32 seed labels, 0 = background. Modified in place:
                seeds whose pixel count is below ``min_area`` are zeroed.
    kernel_num: total number of kernel maps the caller started with
                (including the seed map). Kept for interface compatibility
                only; the number of expansion levels is read from
                ``kernels.shape[0]`` so indexing never leaves the buffer.
    label_num:  labels 1..label_num-1 are candidate seeds; 0 is background.
    min_area:   seeds with fewer pixels than this are discarded.

    Returns an (H, W) int32 map in which every reached pixel carries the label
    of the seed it grew from; unreached pixels are 0.
    """
    cdef np.ndarray[np.int32_t, ndim=2] pred
    pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)

    # Drop small seeds. A single bincount pass replaces the previous
    # per-label np.sum(label == idx), which rescanned the image for every label.
    if label_num > 1:
        areas = np.bincount(label.ravel(), minlength=label_num)
        too_small = areas < min_area
        too_small[0] = False              # background is never a seed
        too_small[label_num:] = False     # only labels 1..label_num-1 are filtered
        label[too_small[label]] = 0

    # int32 coordinates: int16 wrapped for dimensions > 32767, which (with
    # wraparound disabled) turned into out-of-bounds writes into `pred`.
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int32_t, np.int32_t]] que = \
        queue[libcpp.pair.pair[np.int32_t, np.int32_t]]()
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int32_t, np.int32_t]] nxt_que = \
        queue[libcpp.pair.pair[np.int32_t, np.int32_t]]()
    # 4-neighbourhood offsets (up, down, left, right).
    cdef np.int32_t* dx = [-1, 1, 0, 0]
    cdef np.int32_t* dy = [0, 0, -1, 1]
    cdef np.int32_t tmpx, tmpy
    cdef Py_ssize_t h = label.shape[0]
    cdef Py_ssize_t w = label.shape[1]
    cdef int j, kernel_idx
    cdef bint is_edge

    # Seed the queue with every surviving labelled pixel (row-major order).
    points = np.argwhere(label > 0)
    for point_idx in range(points.shape[0]):
        tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
        que.push(pair[np.int32_t, np.int32_t](tmpx, tmpy))
        pred[tmpx, tmpy] = label[tmpx, tmpy]

    cdef libcpp.pair.pair[np.int32_t, np.int32_t] cur
    cdef int cur_label
    # Bound the levels by the buffer itself. The caller passes kernels[:-1]
    # with the full count, so starting at kernel_num - 1 used to read one
    # plane past the declared buffer (hidden by boundscheck(False)).
    for kernel_idx in range(kernels.shape[0] - 1, -1, -1):
        # Breadth-first growth within this level; pixels that claimed no new
        # neighbour are kept as the frontier for the next level.
        while not que.empty():
            cur = que.front()
            que.pop()
            cur_label = pred[cur.first, cur.second]

            is_edge = True
            for j in range(4):
                tmpx = cur.first + dx[j]
                tmpy = cur.second + dy[j]
                if tmpx < 0 or tmpx >= h or tmpy < 0 or tmpy >= w:
                    continue
                # Only expand into pixels inside this kernel that no seed owns yet.
                if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
                    continue

                que.push(pair[np.int32_t, np.int32_t](tmpx, tmpy))
                pred[tmpx, tmpy] = cur_label
                is_edge = False
            if is_edge:
                nxt_que.push(cur)

        # que is now empty; the frontier becomes the next level's work queue.
        que, nxt_que = nxt_que, que

    return pred

def pse(kernels, min_area):
    """Run progressive scale expansion over a stack of binary kernel maps.

    Seeds are the 4-connected components of the last map, ``kernels[-1]``.
    They are then grown through the remaining maps, ``kernels[:-1]``, by _pse.

    Args:
        kernels: array indexed as (kernel, row, col). It must be uint8,
            because _pse declares a ``np.uint8_t`` 3-D buffer. Presumably it
            is ordered from largest (index 0) to smallest (index -1) kernel;
            confirm with the caller.
        min_area: seed components with fewer pixels than this are dropped.

    Returns:
        The int32 label map produced by _pse, with 0 for background.
    """
    num_kernels = kernels.shape[0]
    # 4-connectivity matches the 4-neighbour expansion performed in _pse.
    seed_count, seed_labels = cv2.connectedComponents(kernels[-1], connectivity=4)
    expansion_kernels = kernels[:-1]
    return _pse(expansion_kernels, seed_labels, num_kernels, seed_count, min_area)