File size: 918 Bytes
5b765fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import paddle


class ClsPostProcess(object):
    """Decode classifier scores into ``(label, score)`` pairs.

    Attributes are set directly from the constructor arguments:
      * ``label_list`` -- lookup from class index to label; when ``None`` an
        identity mapping (index -> index) is built on every call.
      * ``key`` -- when not ``None``, the scores are fetched from
        ``preds[key]`` before decoding.
    """

    def __init__(self, label_list=None, key=None, **kwargs):
        """Store the label lookup and the optional key; extra kwargs are ignored."""
        super().__init__()
        self.key = key
        self.label_list = label_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Pick the highest-scoring class of every row of ``preds``.

        Args:
            preds: Scores, or a container holding them under ``self.key``.
                A ``paddle.Tensor`` is converted with ``.numpy()``; anything
                else must already support ``argmax(axis=1)`` and
                ``[row, col]`` indexing (presumably a NumPy array -- confirm
                against callers).
            label: Optional iterable of ground-truth class indices.
            *args, **kwargs: Accepted and ignored.

        Returns:
            A list of ``(label, score)`` tuples, one per row of the scores.
            When ``label`` is given, a 2-tuple of that list and a list of
            ``(label, 1.0)`` tuples built from ``label``.
        """
        scores = preds if self.key is None else preds[self.key]

        names = self.label_list
        if names is None:
            # No labels configured: fall back to an index -> index mapping
            # sized by the last dimension of the scores.
            num_classes = scores.shape[-1]
            names = {k: k for k in range(num_classes)}

        if isinstance(scores, paddle.Tensor):
            scores = scores.numpy()

        decoded = []
        for row, best in enumerate(scores.argmax(axis=1)):
            decoded.append((names[best], scores[row, best]))

        if label is not None:
            # Ground-truth entries are paired with a fixed score of 1.0.
            truth = [(names[idx], 1.0) for idx in label]
            return decoded, truth
        return decoded