# File: deprem-ocr-2/ocr/postprocess/cls_postprocess.py
# Last commit: 5b765fe ("paddleocr") by Goodsea
import paddle
class ClsPostProcess(object):
    """Decode classifier outputs into ``(label, score)`` pairs.

    Each row of the prediction matrix is reduced to its arg-max class index.
    That index is mapped through ``label_list`` and paired with the score
    found at that index in the same row.
    """

    def __init__(self, label_list=None, key=None, **kwargs):
        """Configure the decoder.

        Args:
            label_list: Indexable mapping from class index to label (for
                example a list or dict). When ``None``, an identity mapping
                sized to the last dimension of ``preds`` is built on each call.
            key: Optional key used to select the predictions from a container
                passed to ``__call__`` (``preds[key]``).
            **kwargs: Accepted and ignored.
        """
        super(ClsPostProcess, self).__init__()
        self.label_list = label_list
        self.key = key

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds`` and, optionally, ground-truth ``label`` indices.

        Args:
            preds: Scores indexed as ``preds[row, class_index]``. This is a
                ``paddle.Tensor`` (converted to NumPy via ``.numpy()``) or an
                array supporting ``argmax(axis=1)``, or a container holding
                one under ``self.key``. Expected to be 2-D (TODO confirm
                against callers).
            label: Optional iterable of class indices, or a ``paddle.Tensor``
                of them.
            *args, **kwargs: Accepted and ignored.

        Returns:
            When ``label`` is ``None``: a list with one ``(label, score)``
            tuple per row of ``preds``. Otherwise: the tuple
            ``(decode_out, labels)``, where ``labels`` is a list of
            ``(label, 1.0)`` tuples, one per element of ``label``.
        """
        if self.key is not None:
            preds = preds[self.key]

        label_list = self.label_list
        if label_list is None:
            # Identity mapping: report raw class indices as the labels.
            label_list = {idx: idx for idx in range(preds.shape[-1])}

        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        pred_idxs = preds.argmax(axis=1)
        decode_out = [
            (label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)
        ]
        if label is None:
            return decode_out

        # Convert a tensor label to NumPy, matching the handling of preds
        # above. Iteration then yields NumPy scalars rather than Tensor
        # elements when indexing label_list (assumes a 1-D label — TODO
        # confirm against callers).
        if isinstance(label, paddle.Tensor):
            label = label.numpy()
        label = [(label_list[idx], 1.0) for idx in label]
        return decode_out, label