Spaces:
Runtime error
Runtime error
from __future__ import absolute_import, division, print_function | |
import os | |
import sys | |
import paddle | |
from .extract_textpoint_fast import * | |
from .extract_textpoint_slow import * | |
__dir__ = os.path.dirname(__file__) | |
sys.path.append(__dir__) | |
sys.path.append(os.path.join(__dir__, "..")) | |
class PGNet_PostProcess(object):
    """Turn raw PGNet head outputs into text polygons and recognized strings.

    Two alternative post-processing routes are offered:
    ``pg_postprocess_fast`` and ``pg_postprocess_slow``. Both read the same
    four output maps (``f_score``, ``f_border``, ``f_char``, ``f_direction``)
    from ``outs_dict``, use only the first element of each, and return a dict
    with the keys ``"points"`` and ``"texts"``.
    """

    def __init__(
        self, character_dict_path, valid_set, score_thresh, outs_dict, shape_list
    ):
        """Store the inputs and load the character lookup table.

        Args:
            character_dict_path: Path handed to ``get_dict`` to build
                ``Lexicon_Table``.
            valid_set: Dataset/format name. The slow path handles
                ``"totaltext"`` and ``"partvgg"``; it is forwarded as-is to
                ``restore_poly`` in the fast path.
            score_thresh: Score threshold forwarded to the pivot-list
                generators.
            outs_dict: Mapping with the keys ``"f_score"``, ``"f_border"``,
                ``"f_char"`` and ``"f_direction"``.
            shape_list: Sequence whose first item unpacks to
                ``(src_h, src_w, ratio_h, ratio_w)``.
        """
        self.Lexicon_Table = get_dict(character_dict_path)
        self.valid_set = valid_set
        self.score_thresh = score_thresh
        self.outs_dict = outs_dict
        self.shape_list = shape_list

    def _first_sample_outputs(self):
        """Return ``(p_score, p_border, p_char, p_direction)`` for item 0.

        Shared by both post-processing routes. As in the original code, the
        type of ``f_score`` alone decides whether all four outputs are
        converted with ``.numpy()`` (paddle tensors) or only indexed.
        """
        p_score = self.outs_dict["f_score"]
        p_border = self.outs_dict["f_border"]
        p_char = self.outs_dict["f_char"]
        p_direction = self.outs_dict["f_direction"]
        if isinstance(p_score, paddle.Tensor):
            return (
                p_score[0].numpy(),
                p_border[0].numpy(),
                p_char[0].numpy(),
                p_direction[0].numpy(),
            )
        return p_score[0], p_border[0], p_char[0], p_direction[0]

    @staticmethod
    def _center_line_to_point_pairs(
        p_border, yx_center_line, ratio_w, ratio_h, offset_expand
    ):
        """Convert center-line pixels into pairs of border points.

        Args:
            p_border: Array indexed as ``p_border[:, y, x]``; the four values
                found there are reshaped to a ``(2, 2)`` offset matrix.
            yx_center_line: Iterable of ``(batch_id, y, x)`` triples; the
                first component is ignored.
            ratio_w: Horizontal resize ratio used to rescale x coordinates.
            ratio_h: Vertical resize ratio used to rescale y coordinates.
            offset_expand: When not ``1.0`` every offset is lengthened by
                ``length * (offset_expand - 1)``, clipped to ``[0.5, 3.0]``.

        Returns:
            A list with one ``(2, 2)`` array per center-line pixel, holding
            ``(x, y)`` rows rescaled by ``4.0 / ratio``.
        """
        # NOTE(review): the constant 4.0 is presumably the stride between the
        # output maps and the network input -- confirm against the model.
        scale = np.array([ratio_w, ratio_h]).reshape(-1, 2)
        point_pairs = []
        for _, y, x in yx_center_line:
            offset = p_border[:, y, x].reshape(2, 2)
            if offset_expand != 1.0:
                offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
                expand_length = np.clip(
                    offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
                )
                # A zero-length offset has no direction; dividing by it would
                # inject NaN into the polygon. Divide by 1 instead so such an
                # offset simply stays zero. Non-zero rows are unaffected.
                safe_length = np.where(offset_length > 0, offset_length, 1.0)
                offset = offset + offset / safe_length * expand_length
            ori_yx = np.array([y, x], dtype=np.float32)
            # [:, ::-1] swaps the (y, x) columns into (x, y) order.
            point_pairs.append((ori_yx + offset)[:, ::-1] * 4.0 / scale)
        return point_pairs

    def pg_postprocess_fast(self):
        """Run the fast route: ``generate_pivot_list_fast`` + ``restore_poly``.

        Returns:
            dict: ``{"points": poly_list, "texts": keep_str_list}`` exactly as
            produced by ``restore_poly``.
        """
        p_score, p_border, p_char, p_direction = self._first_sample_outputs()
        src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
        instance_yxs_list, seq_strs = generate_pivot_list_fast(
            p_score,
            p_char,
            p_direction,
            self.Lexicon_Table,
            score_thresh=self.score_thresh,
        )
        poly_list, keep_str_list = restore_poly(
            instance_yxs_list,
            seq_strs,
            p_border,
            ratio_w,
            ratio_h,
            src_w,
            src_h,
            self.valid_set,
        )
        return {
            "points": poly_list,
            "texts": keep_str_list,
        }

    def pg_postprocess_slow(self):
        """Run the slow route, building each polygon point by point.

        Instances whose decoded string is shorter than two characters are
        dropped. For ``valid_set == "partvgg"`` each polygon is reduced to
        four points (first, the two around the middle, last); for
        ``"totaltext"`` the full polygon is kept.

        Returns:
            dict: ``{"points": poly_list, "texts": keep_str_list}`` where the
            polygons are int32 arrays clipped to ``[0, src_w]`` x
            ``[0, src_h]``.

        Raises:
            SystemExit: With code ``-1`` (after printing a message) when an
                instance is kept but ``valid_set`` is neither ``"partvgg"``
                nor ``"totaltext"``.
        """
        p_score, p_border, p_char, p_direction = self._first_sample_outputs()
        src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
        is_curved = self.valid_set == "totaltext"
        char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
            p_score,
            p_char,
            p_direction,
            score_thresh=self.score_thresh,
            is_backbone=True,
            is_curved=is_curved,
        )
        # Decode every index sequence through the character table.
        seq_strs = [
            "".join([self.Lexicon_Table[pos] for pos in char_idx_set])
            for char_idx_set in char_seq_idx_set
        ]

        # Border offsets are widened only for the "totaltext" format.
        offset_expand = 1.2 if is_curved else 1.0

        poly_list = []
        keep_str_list = []
        for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
            # Duplicate a lone point so the center line has two entries.
            if len(yx_center_line) == 1:
                yx_center_line.append(yx_center_line[-1])

            point_pair_list = self._center_line_to_point_pairs(
                p_border, yx_center_line, ratio_w, ratio_h, offset_expand
            )

            detected_poly, _ = point_pair2poly(point_pair_list)
            detected_poly = expand_poly_along_width(
                detected_poly, shrink_ratio_of_width=0.2
            )
            detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)

            if len(keep_str) < 2:
                continue

            keep_str_list.append(keep_str)
            detected_poly = np.round(detected_poly).astype("int32")
            if self.valid_set == "partvgg":
                middle_point = len(detected_poly) // 2
                detected_poly = detected_poly[
                    [0, middle_point - 1, middle_point, -1], :
                ]
                poly_list.append(detected_poly)
            elif self.valid_set == "totaltext":
                poly_list.append(detected_poly)
            else:
                print("--> Not supported format.")
                # sys.exit raises the same SystemExit(-1) as the site-provided
                # exit(), but is always available (e.g. under `python -S`).
                sys.exit(-1)

        return {
            "points": poly_list,
            "texts": keep_str_list,
        }
class PGPostProcess(object):
    """
    Callable wrapper that runs PGNet post-processing in a configured mode.

    The constructor only records its settings; each call builds a fresh
    ``PGNet_PostProcess`` for the given outputs and runs the "fast" route
    when ``mode == "fast"``, otherwise the slow route.
    """

    def __init__(self, character_dict_path, valid_set, score_thresh, mode, **kwargs):
        """Remember the configuration; extra keyword arguments are ignored."""
        self.mode = mode
        self.score_thresh = score_thresh
        self.valid_set = valid_set
        self.character_dict_path = character_dict_path

        # c++ la-nms is faster, but only supports python 3.5, so note whether
        # the interpreter is exactly 3.5. The flag is not read in this class.
        self.is_python35 = tuple(sys.version_info[:2]) == (3, 5)

    def __call__(self, outs_dict, shape_list):
        """Post-process one batch of network outputs.

        Args:
            outs_dict: Network outputs, passed through to ``PGNet_PostProcess``.
            shape_list: Shape/ratio information, passed through unchanged.

        Returns:
            Whatever the selected ``PGNet_PostProcess`` route returns.
        """
        processor = PGNet_PostProcess(
            self.character_dict_path,
            self.valid_set,
            self.score_thresh,
            outs_dict,
            shape_list,
        )
        # Any mode other than "fast" falls back to the slow route.
        if self.mode == "fast":
            route = processor.pg_postprocess_fast
        else:
            route = processor.pg_postprocess_slow
        return route()