Spaces:
Runtime error
Runtime error
File size: 1,420 Bytes
fc8c192 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
class VQAReTokenLayoutLMPostProcess(object):
    """Post-process relation-extraction (RE) predictions.

    With a ``label`` (evaluation) the predicted relations are returned
    together with two ground-truth fields taken from ``label``. Without a
    ``label`` (inference) each predicted relation is mapped back to the
    per-entity SER/OCR results, producing ``(head, tail)`` pairs.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored so the class can
        # be constructed from a generic config dict.
        super().__init__()

    def __call__(self, preds, label=None, *args, **kwargs):
        """Dispatch to the metric or inference post-processing path.

        Args:
            preds: mapping containing a ``"pred_relations"`` entry.
            label: ground-truth batch; when not ``None`` the metric path
                is used.
            *args, **kwargs: forwarded to ``_infer`` (inference path only).

        Returns:
            The result of ``_metric`` when ``label`` is given, otherwise the
            result of ``_infer``.
        """
        if label is not None:
            return self._metric(preds, label)
        return self._infer(preds, *args, **kwargs)

    def _metric(self, preds, label):
        """Return ``(preds["pred_relations"], label[6], label[5])``.

        NOTE(review): label[6] / label[5] presumably hold the gold relations
        and gold entities respectively -- confirm against the dataset
        transform that builds the batch.
        """
        return preds["pred_relations"], label[6], label[5]

    def _infer(self, preds, *args, **kwargs):
        """Pair up SER/OCR results for each predicted relation.

        Keyword Args:
            ser_results: per-sample sequence of SER/OCR results, indexed by
                the values stored in ``entity_idx_dict``.
            entity_idx_dict_batch: per-sample mapping from entity id (as used
                in a relation's ``head_id`` / ``tail_id``) to an index into
                that sample's ``ser_results`` entry.

        Returns:
            A list with one entry per sample. Each entry is a list of
            ``(ocr_info_head, ocr_info_tail)`` tuples. Each tail entity is
            linked at most once: only the first relation that mentions a
            given ``tail_id`` is kept.

        Raises:
            KeyError: if ``ser_results`` or ``entity_idx_dict_batch`` is not
                passed as a keyword argument.
        """
        ser_results = kwargs["ser_results"]
        entity_idx_dict_batch = kwargs["entity_idx_dict_batch"]
        pred_relations = preds["pred_relations"]

        # merge relations and ocr info
        results = []
        # zip stops at the shortest input: samples lacking a matching SER
        # result or index mapping are dropped silently.
        for pred_relation, ser_result, entity_idx_dict in zip(
            pred_relations, ser_results, entity_idx_dict_batch
        ):
            result = []
            # tail ids are used as keys into entity_idx_dict below, so they
            # are hashable; a set gives O(1) membership instead of a list scan.
            used_tail_ids = set()
            for relation in pred_relation:
                tail_id = relation["tail_id"]
                if tail_id in used_tail_ids:
                    continue
                used_tail_ids.add(tail_id)
                ocr_info_head = ser_result[entity_idx_dict[relation["head_id"]]]
                ocr_info_tail = ser_result[entity_idx_dict[tail_id]]
                result.append((ocr_info_head, ocr_info_tail))
            results.append(result)
        return results
|