|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import numpy as np |
|
|
|
import os |
|
import sys |
|
import json |
|
|
|
# Absolute path of the directory that contains this script.
__dir__ = os.path.dirname(os.path.abspath(__file__))

# Make both this directory (lowest priority, appended) and its parent
# (highest priority, inserted at index 0) importable. Presumably this is what
# lets the `ppocr` and `tools` imports further down resolve when the script is
# run directly -- confirm against the repository layout.
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, os.pardir)))

# Set before `import paddle` below; presumably the framework reads this flag
# when it is imported -- TODO confirm.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
|
|
import paddle |
|
|
|
from ppocr.data import create_operators, transform |
|
from ppocr.modeling.architectures import build_model |
|
from ppocr.postprocess import build_post_process |
|
from ppocr.utils.save_load import load_model |
|
from ppocr.utils.utility import get_image_file_list |
|
import tools.program as program |
|
|
|
|
|
def _set_head_out_channels(config, char_num):
    """Write the head output size(s) into ``config['Architecture']`` in place.

    Args:
        config: The full configuration mapping; only its ``Architecture`` and
            ``PostProcess`` sections are read, and ``Architecture`` is mutated.
        char_num: Length of the post-processor's ``character`` attribute.
    """
    arch = config['Architecture']
    post_name = config['PostProcess']['name']

    def multi_head_channels(sar_decode_name):
        # When the post-processor is the SAR decoder, its character count is
        # 2 larger than the CTC one, so the CTC size is derived by removing
        # those 2 entries. This is computed from the *original* char_num on
        # every call: the previous code decremented char_num inside the
        # per-model loop, shrinking it again for every additional model.
        ctc_num = char_num - 2 if post_name == sar_decode_name else char_num
        return {'CTCLabelDecode': ctc_num, 'SARLabelDecode': ctc_num + 2}

    if arch["algorithm"] in ["Distillation", ]:
        # Distillation: every sub-model carries its own head configuration.
        for key in arch["Models"]:
            head = arch['Models'][key]['Head']
            if head['name'] == 'MultiHead':
                head['out_channels_list'] = multi_head_channels(
                    'DistillationSARLabelDecode')
            else:
                head['out_channels'] = char_num
    elif arch['Head']['name'] == 'MultiHead':
        arch['Head']['out_channels_list'] = multi_head_channels(
            'SARLabelDecode')
    else:
        # Plain single-head recognition model.
        arch["Head"]['out_channels'] = char_num


def _build_infer_transforms(config):
    """Return the ``Eval`` dataset transforms adapted for label-free inference.

    Ops whose name contains ``'Label'`` are dropped, ``RecResizeImg`` gets
    ``infer_mode = True`` and ``KeepKeys`` is rewritten to the keys the
    configured algorithm consumes. The op dicts inside ``config`` are mutated
    in place, exactly as the original inline loop did.
    """
    algorithm = config['Architecture']['algorithm']
    keep_keys_by_algorithm = {
        "SRN": [
            'image', 'encoder_word_pos', 'gsrm_word_pos',
            'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
        ],
        "SAR": ['image', 'valid_ratio'],
        # NOTE: 'word_positons' (sic) is a runtime key and must stay as is.
        "RobustScanner": ['image', 'valid_ratio', 'word_positons'],
    }

    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        if op_name in ['RecResizeImg']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = list(
                keep_keys_by_algorithm.get(algorithm, ['image']))
        transforms.append(op)
    return transforms


def _format_post_result(post_result):
    """Turn a post-processor result into the text written after the file name.

    Returns:
        A JSON string for dict results, ``str(first_item)`` for a list whose
        first item is an int, ``"<label>\\t<score>"`` when the first item has
        at least two entries, or ``None`` when there is nothing to report.
    """
    if isinstance(post_result, dict):
        rec_info = {}
        for key in post_result:
            first = post_result[key][0]
            if len(first) >= 2:
                rec_info[key] = {
                    "label": first[0],
                    "score": float(first[1]),
                }
        return json.dumps(rec_info, ensure_ascii=False)

    # An empty result previously raised IndexError on post_result[0].
    if len(post_result) == 0:
        return None

    first = post_result[0]
    if isinstance(post_result, list) and isinstance(first, int):
        return str(first)
    if len(first) >= 2:
        return first[0] + "\t" + str(first[1])
    return None


def main():
    """Run text-recognition inference and write one result line per image.

    Reads the module-level ``config`` and ``logger`` (both assigned in the
    ``__main__`` guard from ``program.preprocess()``). Builds the
    post-processor, model and preprocessing operators from ``config``, runs
    the model on every file returned by ``get_image_file_list`` for
    ``config['Global']['infer_img']`` and writes ``"<file>\\t<info>\\n"`` to
    ``config['Global']['save_res_path']`` (default
    ``./output/rec/predicts_rec.txt``).
    """
    global_config = config['Global']
    algorithm = config['Architecture']['algorithm']

    # Post-processor first: its character set determines the head size.
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)
    if hasattr(post_process_class, 'character'):
        _set_head_out_channels(
            config, len(getattr(post_process_class, 'character')))

    model = build_model(config['Architecture'])
    load_model(config, model)

    # Preprocessing operators.
    transforms = _build_infer_transforms(config)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = global_config.get('save_res_path',
                                      "./output/rec/predicts_rec.txt")
    # dirname is '' for a bare file name; os.makedirs('') raises, so only
    # create a directory when there is one to create. exist_ok also removes
    # the check-then-create race of the previous exists()/makedirs() pair.
    save_dir = os.path.dirname(save_res_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.eval()

    def as_batch_tensor(array):
        # Prepend a batch axis of size 1 and convert to a paddle tensor.
        return paddle.to_tensor(np.expand_dims(array, axis=0))

    # Explicit UTF-8: labels may be non-ASCII (json.dumps below uses
    # ensure_ascii=False), which a locale-default encoding may not represent.
    with open(save_res_path, "w", encoding="utf-8") as fout:
        for image_path in get_image_file_list(global_config['infer_img']):
            logger.info("infer_img: {}".format(image_path))
            with open(image_path, 'rb') as f:
                data = {'image': f.read()}
            batch = transform(data, ops)

            image_array = np.expand_dims(batch[0], axis=0)
            images = paddle.to_tensor(image_array)

            if algorithm == "SRN":
                # batch[1..4] follow the KeepKeys order configured for SRN.
                others = [as_batch_tensor(batch[i]) for i in range(1, 5)]
                preds = model(images, others)
            elif algorithm == "SAR":
                img_metas = [as_batch_tensor(batch[-1])]
                preds = model(images, img_metas)
            elif algorithm == "RobustScanner":
                img_metas = [
                    as_batch_tensor(batch[1]),
                    as_batch_tensor(batch[2]),
                ]
                preds = model(images, img_metas)
            elif algorithm == "CAN":
                # All-ones mask with the same shape as the batched image and
                # an all-ones (1, 36) int64 label; the meaning of 36 is not
                # visible in this file -- value kept from the original code.
                image_mask = paddle.ones(image_array.shape, dtype='float32')
                label = paddle.ones((1, 36), dtype='int64')
                preds = model([images, image_mask, label])
            else:
                preds = model(images)

            info = _format_post_result(post_process_class(preds))
            if info is not None:
                logger.info("\t result: {}".format(info))
                fout.write(image_path + "\t" + info + "\n")
    logger.info("success!")
|
|
|
|
|
if __name__ == "__main__":
    # `config` and `logger` become module globals here; main() reads them
    # directly instead of taking them as arguments. `device` and `vdl_writer`
    # are returned by preprocess() as well but are not used in the visible
    # part of this file.
    config, device, logger, vdl_writer = program.preprocess()
    main()
|
|