File size: 4,339 Bytes
5b765fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import absolute_import, division, print_function

import cv2
import numpy as np
import paddle

from .locality_aware_nms import nms_locality


class EASTPostProcess(object):
    """
    The post process for EAST.

    Turns the score map and geometry map stored under the "f_score" and
    "f_geo" keys of the network output into per-image lists of quadrilateral
    text boxes.

    Args:
        score_thresh: a pixel is used as a box proposal only when its score
            is strictly greater than this value.
        cover_thresh: after NMS, a box is kept only when the mean score under
            its polygon is strictly greater than this value.
        nms_thresh: threshold forwarded to the NMS routine.
        **kwargs: accepted and ignored.
    """

    def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs):

        self.score_thresh = score_thresh
        self.cover_thresh = cover_thresh
        self.nms_thresh = nms_thresh

    def restore_rectangle_quad(self, origin, geometry):
        """
        Restore quadrangles from anchor points and per-point corner offsets.

        Args:
            origin: array of shape (n, 2), one anchor point per proposal.
            geometry: array of shape (n, 8), four (dx, dy)-style offset pairs
                per proposal; each pair is subtracted from the anchor.

        Returns:
            Array of shape (n, 4, 2) holding ``origin - offset`` for each of
            the four corners.
        """
        # Repeat the anchor four times so it lines up with the 8 offsets.
        origin_concat = np.concatenate(
            (origin, origin, origin, origin), axis=1
        )  # (n, 8)
        pred_quads = origin_concat - geometry
        pred_quads = pred_quads.reshape((-1, 4, 2))  # (n, 4, 2)
        return pred_quads

    def detect(
        self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2
    ):
        """
        Restore text boxes from score map and geo map.

        Args:
            score_map: array whose first entry ``score_map[0]`` is the 2-D
                score map.
            geo_map: 3-D array; its first axis is moved to the last position
                and is expected to hold the 8 corner offsets.
            score_thresh: minimum (exclusive) pixel score for a proposal.
            cover_thresh: minimum (exclusive) mean score under a box polygon.
            nms_thresh: threshold forwarded to the NMS routine.

        Returns:
            An empty list when no pixel passes ``score_thresh`` or NMS leaves
            no box; otherwise the array of surviving boxes, where columns 0-7
            are the corner coordinates and column 8 is the mean score under
            the box.
        """

        score_map = score_map[0]
        # Move axis 0 to the end so geo_map can be indexed as [row, col, :].
        geo_map = np.swapaxes(geo_map, 1, 0)
        geo_map = np.swapaxes(geo_map, 1, 2)
        # filter the score map
        xy_text = np.argwhere(score_map > score_thresh)
        if len(xy_text) == 0:
            return []
        # sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 0])]
        # restore quad proposals; argwhere yields (row, col), so reverse to
        # (x, y). The factor 4 presumably maps score-map pixels back to the
        # network input resolution -- TODO confirm against the model stride.
        text_box_restored = self.restore_rectangle_quad(
            xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :]
        )
        boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
        boxes[:, :8] = text_box_restored.reshape((-1, 8))
        boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]

        try:
            import lanms

            boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
        except Exception:
            # Best-effort fallback: any ordinary failure to import or run
            # lanms switches to the bundled implementation. ``Exception``
            # (not a bare ``except``) lets KeyboardInterrupt and SystemExit
            # propagate instead of being swallowed here.
            print(
                "you should install lanms by pip3 install lanms-nova to speed up nms_locality"
            )
            boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
        if boxes.shape[0] == 0:
            return []
        # Here we filter some low score boxes by the average score map,
        #   this is different from the original paper.
        for i, box in enumerate(boxes):
            mask = np.zeros_like(score_map, dtype=np.uint8)
            # Box coordinates were scaled by 4 above; divide to rasterise the
            # polygon on the score-map grid.
            cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
            boxes[i, 8] = cv2.mean(score_map, mask)[0]
        boxes = boxes[boxes[:, 8] > cover_thresh]
        return boxes

    def sort_poly(self, p):
        """
        Sort the four points of a polygon into a canonical order.

        The point with the smallest ``x + y`` becomes the first point and the
        original cyclic order is kept. If the edge from the first to the
        second point is not longer along x than along y, the traversal
        direction is reversed.

        Args:
            p: array of shape (4, 2).

        Returns:
            Array of shape (4, 2) with the reordered points.
        """
        min_axis = np.argmin(np.sum(p, axis=1))
        # Rotate so the point with the smallest coordinate sum comes first.
        p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
        if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
            return p
        # First edge is not x-dominant: keep the start point, flip direction.
        return p[[0, 3, 2, 1]]

    def __call__(self, outs_dict, shape_list):
        """
        Post-process a batch of network outputs.

        Args:
            outs_dict: mapping with "f_score" and "f_geo" entries, each
                indexable per image; ``paddle.Tensor`` values are converted
                to numpy first.
            shape_list: per-image sequence unpacked as
                ``(src_h, src_w, ratio_h, ratio_w)``; only the two ratios
                are used, to divide box coordinates.

        Returns:
            A list with one dict per image; ``"points"`` holds the array of
            kept boxes (int32 corner points), built with ``np.array`` from a
            possibly empty list.
        """
        score_list = outs_dict["f_score"]
        geo_list = outs_dict["f_geo"]
        if isinstance(score_list, paddle.Tensor):
            score_list = score_list.numpy()
            geo_list = geo_list.numpy()
        img_num = len(shape_list)
        dt_boxes_list = []
        for ino in range(img_num):
            boxes = self.detect(
                score_map=score_list[ino],
                geo_map=geo_list[ino],
                score_thresh=self.score_thresh,
                cover_thresh=self.cover_thresh,
                nms_thresh=self.nms_thresh,
            )
            boxes_norm = []
            if len(boxes) > 0:
                _src_h, _src_w, ratio_h, ratio_w = shape_list[ino]
                # Drop the score column and undo the resize ratios.
                boxes = boxes[:, :8].reshape((-1, 4, 2))
                boxes[:, :, 0] /= ratio_w
                boxes[:, :, 1] /= ratio_h
                for box in boxes:
                    box = self.sort_poly(box.astype(np.int32))
                    # Skip boxes whose first or last edge is shorter than 5.
                    if (
                        np.linalg.norm(box[0] - box[1]) < 5
                        or np.linalg.norm(box[3] - box[0]) < 5
                    ):
                        continue
                    boxes_norm.append(box)
            dt_boxes_list.append({"points": np.array(boxes_norm)})
        return dt_boxes_list