Spaces:
Runtime error
Runtime error
from __future__ import absolute_import, division, print_function | |
import cv2 | |
import numpy as np | |
import paddle | |
from .locality_aware_nms import nms_locality | |
class EASTPostProcess(object):
    """
    The post process for EAST.

    Turns the ``f_score`` / ``f_geo`` network outputs into text-box
    quadrangles, merges them with NMS and filters weak or tiny boxes.
    """

    # Outcome of the one-time attempt to import the optional ``lanms``
    # accelerator: ``None`` = not attempted yet, ``False`` = unavailable,
    # otherwise the imported module. Cached so the import (and the install
    # hint) is not repeated for every image.
    _lanms = None

    def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs):
        """
        Args:
            score_thresh: score-map values strictly above this seed a box.
            cover_thresh: merged boxes whose mean score-map value inside the
                polygon is not strictly above this are dropped.
            nms_thresh: threshold forwarded to the NMS merge step
                (``lanms.merge_quadrangle_n9`` or ``nms_locality``).
            **kwargs: accepted and ignored.
        """
        self.score_thresh = score_thresh
        self.cover_thresh = cover_thresh
        self.nms_thresh = nms_thresh

    @classmethod
    def _get_lanms(cls):
        """
        Return the ``lanms`` module, or ``None`` when it cannot be imported.

        The import is attempted at most once; the result is cached on the
        class so the fallback hint is printed once instead of per image.
        """
        if cls._lanms is None:
            try:
                import lanms
            except Exception:
                # Not just ImportError: some lanms distributions compile a
                # native extension on import and raise other errors on failure.
                print(
                    "you should install lanms by pip3 install lanms-nova to speed up nms_locality"
                )
                cls._lanms = False
            else:
                cls._lanms = lanms
        return cls._lanms or None

    def restore_rectangle_quad(self, origin, geometry):
        """
        Restore quadrangles from anchor points and per-anchor offsets.

        Args:
            origin: (n, 2) array of anchor points.
            geometry: (n, 8) array of offsets; each row is subtracted from its
                anchor repeated four times.

        Returns:
            (n, 4, 2) array of quadrangle corners.
        """
        origin_concat = np.concatenate(
            (origin, origin, origin, origin), axis=1
        )  # (n, 8)
        pred_quads = origin_concat - geometry
        return pred_quads.reshape((-1, 4, 2))  # (n, 4, 2)

    def detect(
        self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2
    ):
        """
        Restore text boxes from a score map and a geometry map.

        Args:
            score_map: score map; only its first channel ``score_map[0]``
                is used.
            geo_map: channel-first geometry map, transposed here to
                (H, W, C); presumably C == 8 quad offsets per pixel.
            score_thresh: pixels with a score strictly above this seed boxes.
            cover_thresh: merged boxes are kept only if their mean score-map
                value is strictly above this.
            nms_thresh: threshold forwarded to the NMS merge.

        Returns:
            ``[]`` when no pixel passes ``score_thresh`` or NMS leaves
            nothing; otherwise a (possibly empty) (m, 9) array whose rows are
            8 corner coordinates followed by the mean-score value.
        """
        score_map = score_map[0]
        # Move the channel axis last: (C, H, W) -> (H, C, W) -> (H, W, C).
        geo_map = np.swapaxes(geo_map, 1, 0)
        geo_map = np.swapaxes(geo_map, 1, 2)

        # filter the score map
        xy_text = np.argwhere(score_map > score_thresh)
        if len(xy_text) == 0:
            return []
        # sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 0])]
        # argwhere yields (row, col); flip to (x, y) and scale by 4 (the score
        # map is presumably 1/4 of the input resolution, matching the // 4
        # used when rasterizing below).
        text_box_restored = self.restore_rectangle_quad(
            xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :]
        )
        boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
        boxes[:, :8] = text_box_restored.reshape((-1, 8))
        boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]

        # Prefer the compiled lanms merge; fall back to the Python version.
        # Errors raised by the merge itself are no longer silently swallowed.
        lanms = self._get_lanms()
        if lanms is not None:
            boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
        else:
            boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
        if boxes.shape[0] == 0:
            return []

        # Here we filter some low score boxes by the average score map,
        # this is different from the original paper.
        for i, box in enumerate(boxes):
            mask = np.zeros_like(score_map, dtype=np.uint8)
            # // 4 maps the corners back onto the score-map grid.
            cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
            boxes[i, 8] = cv2.mean(score_map, mask)[0]
        return boxes[boxes[:, 8] > cover_thresh]

    def sort_poly(self, p):
        """
        Reorder the four corners of a quadrangle into a canonical order.

        The corner with the smallest x + y becomes the first point and the
        others follow cyclically. The traversal direction is reversed when the
        first edge is not predominantly horizontal (|dx| <= |dy|).
        """
        min_axis = np.argmin(np.sum(p, axis=1))
        p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
        if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
            return p
        return p[[0, 3, 2, 1]]

    def __call__(self, outs_dict, shape_list):
        """
        Convert a batch of EAST outputs into text boxes.

        Args:
            outs_dict: mapping with ``"f_score"`` and ``"f_geo"`` entries,
                indexable per image; ``paddle.Tensor`` values are converted
                to numpy first.
            shape_list: one 4-element entry per image; the last two values
                are the (ratio_h, ratio_w) factors the box coordinates are
                divided by.

        Returns:
            A list with one ``{"points": ndarray}`` dict per image. ``points``
            stacks the kept int32 quadrangles, each ordered by ``sort_poly``.
            It is an empty array when no box survives.
        """
        score_list = outs_dict["f_score"]
        geo_list = outs_dict["f_geo"]
        # Convert each output independently so a mix of tensor/ndarray works.
        if isinstance(score_list, paddle.Tensor):
            score_list = score_list.numpy()
        if isinstance(geo_list, paddle.Tensor):
            geo_list = geo_list.numpy()

        dt_boxes_list = []
        for ino in range(len(shape_list)):
            boxes = self.detect(
                score_map=score_list[ino],
                geo_map=geo_list[ino],
                score_thresh=self.score_thresh,
                cover_thresh=self.cover_thresh,
                nms_thresh=self.nms_thresh,
            )
            boxes_norm = []
            if len(boxes) > 0:
                # Only the ratios are used; the first two entries are ignored.
                _, _, ratio_h, ratio_w = shape_list[ino]
                boxes = boxes[:, :8].reshape((-1, 4, 2))
                # Presumably undoes the input resize (maps to source pixels).
                boxes[:, :, 0] /= ratio_w
                boxes[:, :, 1] /= ratio_h
                for box in boxes:
                    box = self.sort_poly(box.astype(np.int32))
                    # Drop degenerate boxes with a side shorter than 5 units.
                    if (
                        np.linalg.norm(box[0] - box[1]) < 5
                        or np.linalg.norm(box[3] - box[0]) < 5
                    ):
                        continue
                    boxes_norm.append(box)
            dt_boxes_list.append({"points": np.array(boxes_norm)})
        return dt_boxes_list