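"""Inference script for key information extraction (KIE) with a SER + RE cascade.

A Semantic Entity Recognition (SER) model runs first; its entities are then
packed into the inputs of a Relation Extraction (RE) model that links QUESTION
entities to their ANSWER entities. Results are written as JSON lines and as
visualizations under Global.save_res_path.

Typical invocation (config paths and option values below are illustrative, not fixed):

    python3 tools/infer_kie_token_ser_re.py \
        -c path/to/re_config.yml \
        -o Global.infer_img=path/to/images Global.save_res_path=./output/re \
        -c_ser path/to/ser_config.yml
"""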
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys

# Make the repository root importable so that `ppocr` and `tools` resolve
# when this script is run directly.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

# Let the Paddle allocator grow GPU memory on demand instead of pre-allocating it.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import paddle
import paddle.distributed as dist

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
from tools.program import ArgsParser, load_config, merge_config
from tools.infer_kie_token_ser import SerPredictor


class ReArgsParser(ArgsParser):
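    """Argument parser that adds a second config (-c_ser/-o_ser) for the SER stage."""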
|
    def __init__(self):
        super(ReArgsParser, self).__init__()
        self.add_argument(
            "-c_ser", "--config_ser", help="ser configuration file to use")
        self.add_argument(
            "-o_ser",
            "--opt_ser",
            nargs='+',
            help="set ser configuration options")

    def parse_args(self, argv=None):
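        """Parse arguments, additionally requiring and parsing the SER config options."""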
|
        args = super(ReArgsParser, self).parse_args(argv)
        assert args.config_ser is not None, \
            "Please specify --config_ser=<ser_config_file_path>."
        args.opt_ser = self._parse_opt(args.opt_ser)
        return args


def make_input(ser_inputs, ser_results):
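    """Convert SER outputs into the packed input expected by the RE model.

    Entities are packed into an int64 array of shape [max_seq_len + 1, 3]
    (columns: start, end, label) whose first row stores the counts; candidate
    relations pair every QUESTION entity (label 1) with every ANSWER entity
    (label 2) in a [num_pairs + 1, 2] array laid out the same way.
    """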
|
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
    batch_size, max_seq_len = ser_inputs[0].shape[:2]
    entities = ser_inputs[8][0]
    ser_results = ser_results[0]
    assert len(entities) == len(ser_results)

    # Collect the span and label of every non-background ("O") entity and
    # remember its original index so predictions can be mapped back later.
    start = []
    end = []
    label = []
    entity_idx_dict = {}
    for i, (res, entity) in enumerate(zip(ser_results, entities)):
        if res['pred'] == 'O':
            continue
        entity_idx_dict[len(start)] = i
        start.append(entity['start'])
        end.append(entity['end'])
        label.append(entities_labels[res['pred']])

    # Pack entities into a fixed-size array: row 0 holds the counts, the
    # following rows hold (start, end, label), padded with -1.
    entities = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
    entities[0, 0] = len(start)
    entities[1:len(start) + 1, 0] = start
    entities[0, 1] = len(end)
    entities[1:len(end) + 1, 1] = end
    entities[0, 2] = len(label)
    entities[1:len(label) + 1, 2] = label

    # Enumerate every (QUESTION, ANSWER) pair as a candidate relation.
    head = []
    tail = []
    for i in range(len(label)):
        for j in range(len(label)):
            if label[i] == 1 and label[j] == 2:
                head.append(i)
                tail.append(j)

    relations = np.full([len(head) + 1, 2], fill_value=-1, dtype=np.int64)
    relations[0, 0] = len(head)
    relations[1:len(head) + 1, 0] = head
    relations[0, 1] = len(tail)
    relations[1:len(tail) + 1, 1] = tail

    # Broadcast the packed arrays to every sample in the batch.
    entities = np.expand_dims(entities, axis=0)
    entities = np.repeat(entities, batch_size, axis=0)
    relations = np.expand_dims(relations, axis=0)
    relations = np.repeat(relations, batch_size, axis=0)

    if isinstance(ser_inputs[0], paddle.Tensor):
        entities = paddle.to_tensor(entities)
        relations = paddle.to_tensor(relations)
    # Keep only the first five SER inputs and append the entities/relations.
    ser_inputs = ser_inputs[:5] + [entities, relations]

    entity_idx_dict_batch = []
    for b in range(batch_size):
        entity_idx_dict_batch.append(entity_idx_dict)
    return ser_inputs, entity_idx_dict_batch


class SerRePredictor(object):
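    """Two-stage predictor: runs the SER engine, then the RE model on its output."""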
|
    def __init__(self, config, ser_config):
        global_config = config['Global']
        # Propagate infer_mode from the RE config to the SER config so both
        # stages read their inputs the same way.
        if "infer_mode" in global_config:
            ser_config["Global"]["infer_mode"] = global_config["infer_mode"]

        self.ser_engine = SerPredictor(ser_config)

        # Build the RE post-process and model, then load the trained weights.
        self.post_process_class = build_post_process(
            config['PostProcess'], global_config)

        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        self.model.eval()

    def __call__(self, data):
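        """Run SER on `data`, build RE inputs from its output, and return RE results."""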
|
        ser_results, ser_inputs = self.ser_engine(data)
        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        # Models without a visual backbone do not consume the image tensor (index 4).
        if self.model.backbone.use_visual_backbone is False:
            re_input.pop(4)
        preds = self.model(re_input)
        post_result = self.post_process_class(
            preds,
            ser_results=ser_results,
            entity_idx_dict_batch=entity_idx_dict_batch)
        return post_result


def preprocess():
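    """Parse CLI args, load the RE and SER configs, set the device, and build a logger."""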
|
    FLAGS = ReArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)

    ser_config = load_config(FLAGS.config_ser)
    ser_config = merge_config(ser_config, FLAGS.opt_ser)

    logger = get_logger()

    # Select the device: the current GPU when use_gpu is set, otherwise CPU.
    use_gpu = config['Global']['use_gpu']
    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    logger.info('{} re config {}'.format('*' * 10, '*' * 10))
    print_dict(config, logger)
    logger.info('\n')
    logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
    print_dict(ser_config, logger)
    logger.info('infer with paddle {} and device {}'.format(
        paddle.__version__, device))
    return config, ser_config, device, logger


if __name__ == '__main__':
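    # Build the SER -> RE pipeline and run it over either an eval label file
    # (infer_mode is False) or a list of raw images, saving JSON results and
    # visualizations under Global.save_res_path.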
|
    config, ser_config, device, logger = preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

    if config["Global"].get("infer_mode", None) is False:
        # Evaluation mode: read image paths and labels from the annotation file.
        data_dir = config['Eval']['dataset']['data_dir']
        with open(config['Global']['infer_img'], "rb") as f:
            infer_imgs = f.readlines()
    else:
        # Inference mode: treat infer_img as an image file or directory.
        infer_imgs = get_image_file_list(config['Global']['infer_img'])

    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, info in enumerate(infer_imgs):
            if config["Global"].get("infer_mode", None) is False:
                data_line = info.decode('utf-8')
                substr = data_line.strip("\n").split("\t")
                img_path = os.path.join(data_dir, substr[0])
                data = {'img_path': img_path, 'label': substr[1]}
            else:
                img_path = info
                data = {'img_path': img_path}

            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")

            result = ser_re_engine(data)
            result = result[0]
            # Write one line per image: <img_path>\t<json result>.
            fout.write(img_path + "\t" + json.dumps(
                result, ensure_ascii=False) + "\n")
            # Draw the predicted relations and save the visualization.
            img_res = draw_re_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))