File size: 10,189 Bytes
24c4def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Dict, Sequence, Union

import numpy as np
import torch
from torch import nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.module_losses import CEModuleLoss
from mmocr.registry import MODELS
from mmocr.structures import TextSpottingDataSample


@MODELS.register_module()
class SPTSModuleLoss(CEModuleLoss):
    """Implementation of loss module for SPTS with CrossEntropy loss.

    Each sample's ground truth is encoded as a single flat token sequence::

        [start] (x_bin, y_bin, c_1, ..., c_{max_text_len}) * num_instances
        [seq_end]

    where ``[start]`` / ``[seq_end]`` are only present when the dictionary
    defines ``start_idx`` / ``seq_end_idx``, ``x_bin`` / ``y_bin`` are the
    quantized center coordinates of an instance, and ``c_i`` are character
    indexes padded with ``dictionary.end_idx``.

    NOTE(review): bin values are written into the same token space as the
    character indexes, so the dictionary presumably reserves room for them
    — confirm against the dictionary config.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        num_bins (int): Number of bins used to quantize the normalized center
            coordinates of each text instance. Required (no default).
        seq_eos_coef (float): Loss weight of the ``dictionary.seq_end_idx``
            class in the CE loss. Defaults to 0.01.
        max_seq_len (int): Maximum sequence length. In SPTS, a sequence
            encodes all the text instances in a sample. Defaults to 40, which
            will be overridden by SPTSDecoder.
        max_text_len (int): Maximum length for each text instance in a
            sequence (at most ``max_text_len - 1`` characters are kept, the
            rest is filled with the end token). Defaults to 25.
        letter_case (str): There are three options to alter the letter cases
            of gt texts:
            - unchanged: Do not change gt texts.
            - upper: Convert gt texts into uppercase characters.
            - lower: Convert gt texts into lowercase characters.
            Usually, it only works for English characters. Defaults to
            'unchanged'.
        pad_with (str): Forwarded to the parent class. Note that
            :meth:`get_targets` of this class always pads with
            ``dictionary.padding_idx``. Defaults to 'auto'.
        ignore_char (int or str): Specifies a target value that is
            ignored and does not contribute to the input gradient.
            ignore_char can be int or str. If int, it is the index of
            the ignored char. If str, it is the character to ignore.
            Apart from single characters, each item can be one of the
            following reserved keywords: 'padding', 'start', 'end',
            and 'unknown', which refer to their corresponding special
            tokens in the dictionary. It will not ignore any special
            tokens when ignore_char == -1 or 'none'. Defaults to 'padding'.
        flatten (bool): Whether to flatten the output and target before
            computing CE loss. Defaults to False.
        reduction (str): Specifies the reduction to apply to the output,
            should be one of the following: ('none', 'mean', 'sum'). Defaults
            to 'none'.
        ignore_first_char (bool): Whether to ignore the first token in target
            (usually the start token). Defaults to ``True``.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 num_bins: int,
                 seq_eos_coef: float = 0.01,
                 max_seq_len: int = 40,
                 max_text_len: int = 25,
                 letter_case: str = 'unchanged',
                 pad_with: str = 'auto',
                 ignore_char: Union[int, str] = 'padding',
                 flatten: bool = False,
                 reduction: str = 'none',
                 ignore_first_char: bool = True):
        super().__init__(dictionary, max_seq_len, letter_case, pad_with,
                         ignore_char, flatten, reduction, ignore_first_char)
        # TODO: fix hardcode
        self.max_text_len = max_text_len
        # Each instance occupies 2 coordinate tokens + max_text_len tokens.
        self.max_num_text = (self.max_seq_len - 1) // (2 + max_text_len)
        self.num_bins = num_bins

        # Per-class CE weights: only the seq_eos class is down-weighted.
        weights = torch.ones(self.dictionary.num_classes, dtype=torch.float32)
        # Guarded: indexing with ``None`` would broadcast and overwrite
        # every class weight instead of a single one.
        if self.dictionary.seq_end_idx is not None:
            weights[self.dictionary.seq_end_idx] = seq_eos_coef
        # The class weights are constants, never learned.
        weights.requires_grad_(False)
        self.loss_ce = nn.CrossEntropyLoss(
            ignore_index=self.ignore_index,
            reduction=reduction,
            weight=weights)

    def _get_center_bins(self, gt_instances,
                         img_shape: Sequence[int]) -> torch.Tensor:
        """Quantize the center point of every instance polygon into bins.

        Args:
            gt_instances: Instances providing ``polygons`` and ``__len__``.
            img_shape (Sequence[int]): Image shape; it is reversed before
                dividing the (x, y) points. NOTE(review): assumes
                ``(h, w)`` ordering — confirm.

        Returns:
            torch.LongTensor: Shape ``(num_instances, 2)`` with values
            clamped to ``[0, num_bins - 1]``.
        """
        if len(gt_instances) > 0:
            # Slightly different from the original SPTS implementation, which
            # averages the midpoints of the two Bezier curves of each
            # instance (``sample_bezier_curve(..., mid_point=True)``); here
            # the mean of the polygon vertices is used instead.
            center_pts = np.vstack([
                polygon.reshape(-1, 2).mean(0)
                for polygon in gt_instances.polygons
            ])
            # Normalize the points by the reversed image shape.
            center_pts /= img_shape[::-1]
            center_pts = torch.from_numpy(center_pts).type(torch.float32)
        else:
            center_pts = torch.zeros((0, 2), dtype=torch.float32)

        center_pts = (center_pts * self.num_bins).floor().type(torch.long)
        return torch.clamp(center_pts, min=0, max=self.num_bins - 1)

    def _encode_texts(self, texts: Sequence[str]) -> torch.Tensor:
        """Encode texts into a ``(num_texts, max_text_len)`` LongTensor.

        Each row keeps at most ``max_text_len - 1`` character indexes and is
        filled up with ``dictionary.end_idx``, so every row ends with at
        least one end token.
        """
        rows = []
        for text in texts:
            if self.letter_case in ['upper', 'lower']:
                text = getattr(text, self.letter_case)()

            indexes = self.dictionary.str2idx(text)
            row = torch.full((self.max_text_len, ),
                             self.dictionary.end_idx,
                             dtype=torch.long)
            num_chars = min(self.max_text_len - 1, len(indexes))
            row[:num_chars] = torch.LongTensor(indexes)[:num_chars]
            rows.append(row)

        if not rows:
            # Must stay long: a float tensor here would promote the whole
            # concatenated target sequence to float.
            return torch.zeros((0, self.max_text_len), dtype=torch.long)
        return torch.vstack(rows)

    def get_targets(
        self, data_samples: Sequence[TextSpottingDataSample]
    ) -> Sequence[TextSpottingDataSample]:
        """Target generator.

        Samples whose metainfo already has ``have_target`` set are skipped.
        If a sample has more than ``max_num_text`` instances, a random subset
        of ``max_num_text`` instances is kept (``gt_instances`` is replaced).

        Args:
            data_samples (list[TextSpottingDataSample]): It usually includes
                ``gt_instances`` information.

        Returns:
            list[TextSpottingDataSample]: Updated data_samples. Two keys will
            be added to ``gt_instances`` metainfo, and ``have_target=True`` to
            the sample's metainfo:

            - indexes (torch.LongTensor): The flat token sequence: optional
              start token, then ``(x_bin, y_bin, chars...)`` per instance,
              then an optional seq_end token.
            - padded_indexes (torch.LongTensor): ``indexes`` followed by
              ``dictionary.padding_idx`` up to the longest sequence among the
              samples processed in this call.
        """
        batch_max_len = 0

        for data_sample in data_samples:
            if data_sample.get('have_target', False):
                continue

            # Randomly drop instances so the sequence fits in max_seq_len.
            num_instances = len(data_sample.gt_instances)
            if num_instances > self.max_num_text:
                keep = random.sample(range(num_instances), self.max_num_text)
                data_sample.gt_instances = data_sample.gt_instances[keep]

            gt_instances = data_sample.gt_instances

            center_pts = self._get_center_bins(gt_instances,
                                               data_sample.img_shape)
            text_indexes = self._encode_texts(gt_instances.texts)
            # (num_instances, 2 + max_text_len) -> one flat sequence.
            gt_indexes = torch.cat([center_pts, text_indexes],
                                   dim=-1).flatten()

            if self.dictionary.start_idx is not None:
                gt_indexes = torch.cat([
                    torch.LongTensor([self.dictionary.start_idx]), gt_indexes
                ])
            if self.dictionary.seq_end_idx is not None:
                gt_indexes = torch.cat([
                    gt_indexes,
                    torch.LongTensor([self.dictionary.seq_end_idx])
                ])

            batch_max_len = max(batch_max_len, len(gt_indexes))

            gt_instances.set_metainfo(dict(indexes=gt_indexes))

        # Second pass: the batch max length is only known now. Padding to it
        # (instead of max_seq_len) saves memory.
        for data_sample in data_samples:
            if data_sample.get('have_target', False):
                continue

            indexes = data_sample.gt_instances.indexes
            padded_indexes = torch.full((batch_max_len, ),
                                        self.dictionary.padding_idx,
                                        dtype=torch.long)
            padded_indexes[:len(indexes)] = indexes
            data_sample.gt_instances.set_metainfo(
                dict(padded_indexes=padded_indexes))
            data_sample.set_metainfo(dict(have_target=True))

        return data_samples

    def forward(self, outputs: torch.Tensor,
                data_samples: Sequence[TextSpottingDataSample]) -> Dict:
        """Compute the CE loss between logits and the padded targets.

        Args:
            outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`.
            data_samples (list[TextSpottingDataSample]): List of
                ``TextSpottingDataSample`` which are processed by
                ``get_targets``.

        Returns:
            dict: A loss dict with the key ``loss_ce``.
        """
        targets = torch.stack(
            [sample.gt_instances.padded_indexes for sample in data_samples],
            dim=0).long()
        if self.ignore_first_char:
            # Only the targets are shifted; ``outputs`` is used as-is, so its
            # T must equal the shifted target length.
            targets = targets[:, 1:].contiguous()
        if self.flatten:
            outputs = outputs.reshape(-1, outputs.size(-1))
            targets = targets.reshape(-1)
        else:
            # nn.CrossEntropyLoss expects the class dimension at index 1.
            outputs = outputs.permute(0, 2, 1).contiguous()

        loss_ce = self.loss_ce(outputs, targets.to(outputs.device))
        return dict(loss_ce=loss_ce)