File size: 7,513 Bytes
24c4def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmengine.structures import InstanceData

from mmocr.models import Dictionary
from mmocr.models.textrecog.postprocessors import BaseTextRecogPostprocessor
from mmocr.registry import MODELS
from mmocr.structures import TextSpottingDataSample
from mmocr.utils import rescale_polygons


@MODELS.register_module()
class SPTSPostprocessor(BaseTextRecogPostprocessor):
    """PostProcessor for SPTS.

    The flat token sequence predicted for an image is read in chunks of 27
    tokens, one chunk per text instance: two quantized point coordinates
    (x, y) followed by 25 character tokens.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        num_bins (int): Number of bins used to quantize point coordinates.
            A coordinate token ``b`` is mapped back to ``b / num_bins`` times
            the image width (x) or height (y).
        rescale_fields (list[str], optional): The bbox/polygon field names to
            be rescaled. If None, no rescaling will be performed. Defaults to
            ``['points']``.
        max_seq_len (int): Maximum sequence length. In SPTS, a sequence
            encodes all the text instances in a sample. Defaults to 40, which
            will be overridden by SPTSDecoder.
        ignore_chars (list[str]): A list of characters to be ignored from the
            final results. Postprocessor will skip over these characters when
            converting raw indexes to characters. Apart from single characters,
            each item can be one of the following reserved keywords: 'padding',
            'end' and 'unknown', which refer to their corresponding special
            tokens in the dictionary. Defaults to ``['padding']``.
    """

    # Tokens per text instance: 2 coordinate tokens + 25 character tokens.
    _TOKENS_PER_INSTANCE = 27

    def __init__(self,
                 dictionary: Union[Dictionary, Dict],
                 num_bins: int,
                 rescale_fields: Optional[Sequence[str]] = ['points'],
                 max_seq_len: int = 40,
                 ignore_chars: Sequence[str] = ['padding'],
                 **kwargs) -> None:
        assert rescale_fields is None or isinstance(rescale_fields, list)
        self.num_bins = num_bins
        # Copy so this instance never aliases the shared default list (or a
        # list still owned by the caller).
        self.rescale_fields = (None if rescale_fields is None else
                               list(rescale_fields))
        super().__init__(
            dictionary=dictionary,
            num_bins=num_bins,
            max_seq_len=max_seq_len,
            ignore_chars=ignore_chars)

    def get_single_prediction(
        self,
        max_probs: torch.Tensor,
        seq: torch.Tensor,
        data_sample: Optional[TextSpottingDataSample] = None,
    ) -> Tuple[List[List[int]], List[float], List[Tuple[float, float]],
               List[float]]:
        """Convert the output of a single image to character indexes, text
        scores, points and point scores.

        Args:
            max_probs (torch.Tensor): Per-token probabilities with shape
                :math:`(T)`. ``T`` must be a multiple of 27.
            seq (torch.Tensor): Predicted token indexes with shape
                :math:`(T)`.
            data_sample (TextSpottingDataSample, optional): Datasample of an
                image. Although it defaults to None, it is required: its
                ``img_shape`` is used to de-quantize the points.

        Returns:
            tuple(list[list[int]], list[float], list[(float, float)],
            list[float]): character indexes, text scores, points and point
            scores, one entry per 27-token chunk (``T // 27`` entries each).
            The text score of an instance without any kept character is NaN.

        Raises:
            ValueError: If ``data_sample`` is None.
        """
        if data_sample is None:
            raise ValueError('`data_sample` is required to de-quantize '
                             'points, but got None.')
        h, w = data_sample.img_shape
        chunk = self._TOKENS_PER_INSTANCE
        # T is expected to be a multiple of ``chunk`` (trailing outputs are
        # masked out in the decoder); ``reshape`` raises otherwise.
        max_probs = max_probs.reshape(-1, chunk)
        seq = seq.reshape(-1, chunk)
        output_indexes = seq.cpu().detach().numpy().tolist()
        output_scores = max_probs.cpu().detach().numpy().tolist()

        ignore_indexes = self.ignore_indexes
        end_idx = self.dictionary.end_idx
        indexes, text_scores, points, pt_scores = [], [], [], []
        for output_index, output_score in zip(output_indexes, output_scores):
            # Tokens 0 and 1 are the quantized x and y of the point.
            points.append((output_index[0] / self.num_bins * w,
                           output_index[1] / self.num_bins * h))
            pt_scores.append(np.mean(output_score[:2]).item())
            # The remaining tokens are characters: ignored indexes are
            # skipped, and decoding stops at the first (non-ignored) end token.
            char_indexes, char_scores = [], []
            for char_index, char_score in zip(output_index[2:],
                                              output_score[2:]):
                if char_index in ignore_indexes:
                    continue
                if char_index == end_idx:
                    break
                char_indexes.append(char_index)
                char_scores.append(char_score)
            indexes.append(char_indexes)
            # np.mean([]) is NaN but emits a RuntimeWarning; return the same
            # NaN explicitly for instances without any kept character.
            text_scores.append(
                np.mean(char_scores).item() if char_scores else float('nan'))
        return indexes, text_scores, points, pt_scores

    def __call__(
        self, output: Tuple[torch.Tensor, torch.Tensor],
        data_samples: Sequence[TextSpottingDataSample]
    ) -> Sequence[TextSpottingDataSample]:
        """Convert outputs to strings and scores.

        Args:
            output (tuple(Tensor, Tensor)): A tuple of (probs, seq), each has
                the shape of :math:`(N, T)`, where ``N`` is the batch size.
            data_samples (list[TextSpottingDataSample]): The list of
                TextSpottingDataSample, one per batch item.

        Returns:
            list(TextSpottingDataSample): The same list, with
            ``pred_instances`` (texts, scores, text_scores, points) set on
            each sample and the fields in ``rescale_fields`` rescaled by the
            sample's ``scale_factor``.
        """
        max_probs, seq = output
        for idx in range(max_probs.size(0)):
            data_sample = data_samples[idx]
            char_idxs, text_scores, points, pt_scores = (
                self.get_single_prediction(max_probs[idx, :], seq[idx, :],
                                           data_sample))
            pred_instances = InstanceData()
            pred_instances.texts = [
                self.dictionary.idx2str(index) for index in char_idxs
            ]
            # The "scores" field only accepts float numbers; the point score
            # of each instance is used as its score.
            pred_instances.scores = [float(score) for score in pt_scores]
            pred_instances.text_scores = text_scores
            pred_instances.points = points
            data_sample.pred_instances = pred_instances
            # rescale() writes the rescaled fields back into
            # data_sample.pred_instances, so its return value is not needed.
            self.rescale(data_sample, data_sample.scale_factor)
        return data_samples

    def rescale(self, results: TextSpottingDataSample,
                scale_factor: Sequence[float]) -> TextSpottingDataSample:
        """Rescale results in ``results.pred_instances`` according to
        ``scale_factor``, whose keys are defined in ``self.rescale_fields``.
        Usually used to rescale bboxes and/or polygons.

        Each listed field is divided by ``scale_factor`` through
        ``rescale_polygons`` (``mode='div'``) and written back into
        ``results.pred_instances``. Nothing is done when ``rescale_fields``
        is None or empty.

        Args:
            results (TextSpottingDataSample): The post-processed prediction
                results.
            scale_factor (tuple(float)): (w_scale, h_scale)

        Returns:
            TextSpottingDataSample: Prediction results with rescaled results.
        """
        if not self.rescale_fields:
            # None is documented as "no rescaling"; iterating it would raise.
            return results
        scale_factor = np.asarray(scale_factor)
        for key in self.rescale_fields:
            # TODO: this util may need an alias or to be renamed
            results.pred_instances[key] = rescale_polygons(
                results.pred_instances[key], scale_factor, mode='div')
        return results