# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Dict, Sequence, Union

import numpy as np
import torch
from torch import nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.module_losses import CEModuleLoss
from mmocr.registry import MODELS
from mmocr.structures import TextSpottingDataSample


@MODELS.register_module()
class SPTSModuleLoss(CEModuleLoss):
    """Implementation of loss module for SPTS with CrossEntropy loss.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        num_bins (int): Number of bins dividing the image. Defaults to 1000.
        seq_eos_coef (float): The loss weight coefficient of the seq_eos
            token. Defaults to 0.01.
        max_seq_len (int): Maximum sequence length. In SPTS, a sequence
            encodes all the text instances in a sample. Defaults to 40, which
            will be overridden by SPTSDecoder.
        max_text_len (int): Maximum length for each text instance in a
            sequence. Defaults to 25.
        letter_case (str): There are three options to alter the letter cases
            of gt texts:
            - unchanged: Do not change gt texts.
            - upper: Convert gt texts into uppercase characters.
            - lower: Convert gt texts into lowercase characters.
            Usually, it only works for English characters. Defaults to
            'unchanged'.
        pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
            Defaults to 'auto'. Options are:
            - 'auto': Use dictionary.padding_idx to pad gt texts, or
              dictionary.end_idx if dictionary.padding_idx is None.
            - 'padding': Always use dictionary.padding_idx to pad gt texts.
            - 'end': Always use dictionary.end_idx to pad gt texts.
            - 'none': Do not pad gt texts.
        ignore_char (int or str): Specifies a target value that is
            ignored and does not contribute to the input gradient.
            ignore_char can be int or str. If int, it is the index of
            the ignored char. If str, it is the character to ignore.
            Apart from single characters, each item can be one of the
            following reserved keywords: 'padding', 'start', 'end',
            and 'unknown', which refer to their corresponding special
            tokens in the dictionary. It will not ignore any special
            tokens when ignore_char == -1 or 'none'. Defaults to 'padding'.
        flatten (bool): Whether to flatten the output and target before
            computing CE loss. Defaults to False.
        reduction (str): Specifies the reduction to apply to the output,
            should be one of the following: ('none', 'mean', 'sum'). Defaults
            to 'none'.
        ignore_first_char (bool): Whether to ignore the first token in target
            (usually the start token). Defaults to ``True``.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 num_bins: int,
                 seq_eos_coef: float = 0.01,
                 max_seq_len: int = 40,
                 max_text_len: int = 25,
                 letter_case: str = 'unchanged',
                 pad_with: str = 'auto',
                 ignore_char: Union[int, str] = 'padding',
                 flatten: bool = False,
                 reduction: str = 'none',
                 ignore_first_char: bool = True):
        super().__init__(dictionary, max_seq_len, letter_case, pad_with,
                         ignore_char, flatten, reduction, ignore_first_char)
        # TODO: fix hardcode
        self.max_text_len = max_text_len
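        # Each text instance occupies 2 coordinate tokens (the binned x and y
        # of its center point) plus ``max_text_len`` character tokens; one
        # position of ``max_seq_len`` is left over, presumably for the
        # sequence-level special token.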
        self.max_num_text = (self.max_seq_len - 1) // (2 + max_text_len)
        self.num_bins = num_bins
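        # Build per-class weights for the cross-entropy loss, down-weighting
        # the sequence-end token by ``seq_eos_coef``.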
        weights = torch.ones(self.dictionary.num_classes, dtype=torch.float32)
        weights[self.dictionary.seq_end_idx] = seq_eos_coef
        weights.requires_grad = False
        self.loss_ce = nn.CrossEntropyLoss(
            ignore_index=self.ignore_index,
            reduction=reduction,
            weight=weights)

    def get_targets(
        self, data_samples: Sequence[TextSpottingDataSample]
    ) -> Sequence[TextSpottingDataSample]:
        """Target generator.

        Args:
            data_samples (list[TextSpottingDataSample]): It usually includes
                ``gt_instances`` information.

        Returns:
            list[TextSpottingDataSample]: Updated data_samples. Two keys will
            be added to data_sample:

            - indexes (torch.LongTensor): Character indexes representing gt
              texts. All special tokens are excluded, except for UKN.
            - padded_indexes (torch.LongTensor): Character indexes
              representing gt texts with BOS and EOS if applicable, followed
              by several padding indexes until the length reaches the maximum
              target length of the batch.
        """
        batch_max_len = 0
        for data_sample in data_samples:
            if data_sample.get('have_target', False):
                continue
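            # Randomly subsample text instances when more of them are present
            # than fit into a single target sequence.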
            if len(data_sample.gt_instances) > self.max_num_text:
                keep = random.sample(
                    range(len(data_sample.gt_instances)), self.max_num_text)
                data_sample.gt_instances = data_sample.gt_instances[keep]
            gt_instances = data_sample.gt_instances
            if len(gt_instances) > 0:
                center_pts = []
                # Slightly different from the original implementation
                # which gets the center points from bezier curves
                # for bezier_pt in gt_instances.beziers:
                #     bezier_pt = bezier_pt.reshape(8, 2)
                #     mid_pt1 = sample_bezier_curve(
                #         bezier_pt[:4], mid_point=True)
                #     mid_pt2 = sample_bezier_curve(
                #         bezier_pt[4:], mid_point=True)
                #     center_pt = (mid_pt1 + mid_pt2) / 2
                for polygon in gt_instances.polygons:
                    center_pt = polygon.reshape(-1, 2).mean(0)
                    center_pts.append(center_pt)
                center_pts = np.vstack(center_pts)
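                # Normalize (x, y) to [0, 1] by image width and height;
                # ``img_shape`` is (h, w), hence the reversal.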
                center_pts /= data_sample.img_shape[::-1]
                center_pts = torch.from_numpy(center_pts).type(torch.float32)
            else:
                center_pts = torch.ones(0).reshape(-1, 2).type(torch.float32)
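            # Quantize the normalized coordinates into integer bin indexes in
            # [0, num_bins).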
            center_pts = (center_pts * self.num_bins).floor().type(torch.long)
            center_pts = torch.clamp(center_pts, min=0, max=self.num_bins - 1)
            gt_indexes = []
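            # Encode each transcription into a fixed-length index tensor that
            # is padded (and terminated) with the end token.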
            for text in gt_instances.texts:
                if self.letter_case in ['upper', 'lower']:
                    text = getattr(text, self.letter_case)()
                indexes = self.dictionary.str2idx(text)
                indexes_tensor = torch.zeros(
                    self.max_text_len,
                    dtype=torch.long) + self.dictionary.end_idx
                max_len = min(self.max_text_len - 1, len(indexes))
                indexes_tensor[:max_len] = torch.LongTensor(indexes)[:max_len]
                gt_indexes.append(indexes_tensor)
            if len(gt_indexes) == 0:
                gt_indexes = torch.ones(0).reshape(-1, self.max_text_len)
            else:
                gt_indexes = torch.vstack(gt_indexes)
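            # Each instance becomes [x_bin, y_bin, char_1, ..., char_n];
            # flatten all instances into a single target sequence.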
            gt_indexes = torch.cat([center_pts, gt_indexes], dim=-1)
            gt_indexes = gt_indexes.flatten()
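            # Wrap the sequence with the start and sequence-end tokens when
            # the dictionary defines them.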
            if self.dictionary.start_idx is not None:
                gt_indexes = torch.cat([
                    torch.LongTensor([self.dictionary.start_idx]), gt_indexes
                ])
            if self.dictionary.seq_end_idx is not None:
                gt_indexes = torch.cat([
                    gt_indexes,
                    torch.LongTensor([self.dictionary.seq_end_idx])
                ])
            batch_max_len = max(batch_max_len, len(gt_indexes))
            gt_instances.set_metainfo(dict(indexes=gt_indexes))

        # A second pass is needed because the maximum length of the batch
        # must be known before the indexes can be padded (padding to the
        # batch maximum rather than ``max_seq_len`` saves memory).
        for data_sample in data_samples:
            if data_sample.get('have_target', False):
                continue
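            # Pad every target to the batch maximum length with the padding
            # index.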
            indexes = data_sample.gt_instances.indexes
            padded_indexes = (
                torch.zeros(batch_max_len, dtype=torch.long) +
                self.dictionary.padding_idx)
            padded_indexes[:len(indexes)] = indexes
            data_sample.gt_instances.set_metainfo(
                dict(padded_indexes=padded_indexes))
            data_sample.set_metainfo(dict(have_target=True))
        return data_samples

    def forward(self, outputs: torch.Tensor,
                data_samples: Sequence[TextSpottingDataSample]) -> Dict:
        """
        Args:
            outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`.
            data_samples (list[TextSpottingDataSample]): List of
                ``TextSpottingDataSample`` which are processed by
                ``get_targets``.

        Returns:
            dict: A loss dict with the key ``loss_ce``.
        """
        targets = list()
        for data_sample in data_samples:
            targets.append(data_sample.gt_instances.padded_indexes)
        targets = torch.stack(targets, dim=0).long()
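        # Optionally drop the first target token (usually the start token)
        # before computing the loss.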
        if self.ignore_first_char:
            targets = targets[:, 1:].contiguous()
            # outputs = outputs[:, :-1, :].contiguous()
        if self.flatten:
            outputs = outputs.view(-1, outputs.size(-1))
            targets = targets.view(-1)
        else:
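            # nn.CrossEntropyLoss expects class logits in dim 1, i.e. a
            # tensor of shape (N, C, T).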
            outputs = outputs.permute(0, 2, 1).contiguous()
        loss_ce = self.loss_ce(outputs, targets.to(outputs.device))
        losses = dict(loss_ce=loss_ce)
        return losses