Spaces:
Runtime error
Runtime error
File size: 918 Bytes
fc8c192 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import paddle
class ClsPostProcess(object):
    """Decode classifier score rows into ``(label, score)`` pairs.

    Each row of ``preds`` is reduced to its arg-max class index, which is
    mapped through ``label_list`` to a label; the reported score is the value
    of ``preds`` at that index. When ground-truth ``label`` indices are also
    supplied, they are mapped through ``label_list`` too, each paired with a
    score of 1.0.
    """

    def __init__(self, label_list=None, key=None, **kwargs):
        """
        Args:
            label_list: Sequence or mapping from class index to label. If
                None, each class index is reported as its own label.
            key: If not None, ``preds`` given to ``__call__`` is first
                indexed with this key (``preds[key]``).
            **kwargs: Accepted and ignored, so extra configuration entries
                do not break construction.
        """
        super().__init__()
        self.label_list = label_list
        self.key = key

    @staticmethod
    def _to_numpy(value):
        """Return ``value.numpy()`` for a ``paddle.Tensor``; otherwise ``value`` unchanged."""
        if isinstance(value, paddle.Tensor):
            return value.numpy()
        return value

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions and, optionally, ground-truth labels.

        Args:
            preds: Prediction scores, or a container holding them under
                ``self.key``. Indexed as ``preds[i, idx]`` after an
                ``argmax(axis=1)``, so presumably 2-D
                (batch, num_classes) -- TODO confirm. May be a
                ``paddle.Tensor``; it is converted to numpy first.
            label: Optional iterable of ground-truth class indices. May be a
                ``paddle.Tensor``; it is converted to numpy first.
            *args, **kwargs: Accepted and ignored.

        Returns:
            A list with one ``(label, score)`` tuple per row of ``preds``
            when ``label`` is None; otherwise the tuple
            ``(decoded_preds, decoded_labels)`` where ``decoded_labels`` is a
            list of ``(label, 1.0)`` tuples.
        """
        if self.key is not None:
            preds = preds[self.key]

        label_list = self.label_list
        if label_list is None:
            # Identity mapping: class index i is reported as label i.
            label_list = {idx: idx for idx in range(preds.shape[-1])}

        preds = self._to_numpy(preds)
        pred_idxs = preds.argmax(axis=1)
        decode_out = [
            (label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)
        ]
        if label is None:
            return decode_out

        # Mirror the preds handling: a Tensor label is converted to numpy so
        # its elements (not Tensor objects) are used to index label_list.
        label = self._to_numpy(label)
        label = [(label_list[idx], 1.0) for idx in label]
        return decode_out, label
|