sunnychenxiwang's picture
update all
24c4def
raw
history blame
7.51 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from mmengine.structures import InstanceData
from mmocr.models import Dictionary
from mmocr.models.textrecog.postprocessors import BaseTextRecogPostprocessor
from mmocr.registry import MODELS
from mmocr.structures import TextSpottingDataSample
from mmocr.utils import rescale_polygons
@MODELS.register_module()
class SPTSPostprocessor(BaseTextRecogPostprocessor):
    """PostProcessor for SPTS.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        num_bins (int): Number of bins dividing the image. A point index
            ``i`` is mapped back to pixel coordinates as
            ``i / num_bins * size``.
        rescale_fields (list[str], optional): The bbox/polygon field names to
            be rescaled. If None, no rescaling will be performed. Defaults to
            ``['points']``.
        max_seq_len (int): Maximum sequence length. In SPTS, a sequence
            encodes all the text instances in a sample. Defaults to 40, which
            will be overridden by SPTSDecoder.
        ignore_chars (list[str]): A list of characters to be ignored from the
            final results. Postprocessor will skip over these characters when
            converting raw indexes to characters. Apart from single
            characters, each item can be one of the following reserved
            keywords: 'padding', 'end' and 'unknown', which refer to their
            corresponding special tokens in the dictionary. Defaults to
            ``['padding']``.
    """

    # Number of tokens that describe one text instance in the flat output
    # sequence: two quantized point coordinates (x, y) followed by the
    # character tokens. NOTE(review): this 27 (= 2 + 25 chars) was
    # hard-coded originally; it must match the decoder's per-instance
    # layout — confirm there before changing it.
    _TOKENS_PER_INSTANCE = 27

    def __init__(self,
                 dictionary: Union[Dictionary, Dict],
                 num_bins: int,
                 rescale_fields: Optional[Sequence[str]] = ['points'],
                 max_seq_len: int = 40,
                 ignore_chars: Sequence[str] = ['padding'],
                 **kwargs) -> None:
        # The list defaults are kept for config compatibility; they are never
        # mutated here, and rescale_fields is copied below so an instance
        # never aliases the shared default list.
        assert rescale_fields is None or isinstance(rescale_fields, list)
        self.num_bins = num_bins
        self.rescale_fields = (None if rescale_fields is None else
                               list(rescale_fields))
        super().__init__(
            dictionary=dictionary,
            num_bins=num_bins,
            max_seq_len=max_seq_len,
            ignore_chars=ignore_chars)

    def get_single_prediction(
        self,
        max_probs: torch.Tensor,
        seq: torch.Tensor,
        data_sample: Optional[TextSpottingDataSample] = None,
    ) -> Tuple[List[List[int]], List[float], List[Tuple[float, float]],
               List[float]]:
        """Convert the output probabilities of a single image to character
        indexes, text scores, points and point scores.

        The flat sequence is split into chunks of ``_TOKENS_PER_INSTANCE``
        tokens, one chunk per text instance. In each chunk the first two
        tokens are the quantized x and y coordinates of the point; the
        remaining tokens are character indexes, read until the dictionary's
        end token.

        Args:
            max_probs (torch.Tensor): Probability of each predicted token,
                with shape :math:`(T)`.
            seq (torch.Tensor): Predicted token indexes with shape
                :math:`(T)`.
            data_sample (TextSpottingDataSample, optional): Datasample of an
                image. Although it defaults to None, it is required: its
                ``img_shape`` is used to map point bins back to pixels.

        Returns:
            tuple(list[list[int]], list[float], list[tuple(float, float)],
            list[float]): Per instance: the character indexes, the mean
            character score (0.0 if no character was kept), the (x, y) point
            in ``img_shape`` pixel coordinates, and the mean score of the two
            coordinate tokens. Each list has ``T // _TOKENS_PER_INSTANCE``
            entries.
        """
        h, w = data_sample.img_shape
        # The reshape requires T to be a multiple of _TOKENS_PER_INSTANCE;
        # presumably the decoder guarantees this — TODO confirm.
        max_probs = max_probs.reshape(-1, self._TOKENS_PER_INSTANCE)
        seq = seq.reshape(-1, self._TOKENS_PER_INSTANCE)
        output_indexes = seq.cpu().detach().numpy().tolist()
        output_scores = max_probs.cpu().detach().numpy().tolist()

        indexes, text_scores, points, pt_scores = [], [], [], []
        for output_index, output_score in zip(output_indexes, output_scores):
            # NOTE: every chunk is decoded; there is no early stop on an
            # end-of-sequence token in the point position.
            point_x = output_index[0] / self.num_bins * w
            point_y = output_index[1] / self.num_bins * h
            points.append((point_x, point_y))
            pt_scores.append(
                np.mean([
                    output_score[0],
                    output_score[1],
                ]).item())

            char_indexes, char_scores = [], []
            for char_index, char_score in zip(output_index[2:],
                                              output_score[2:]):
                # Ignored indexes are checked before the end token, so an
                # end token listed in ignore_chars does not stop decoding.
                if char_index in self.ignore_indexes:
                    continue
                if char_index == self.dictionary.end_idx:
                    break
                char_indexes.append(char_index)
                char_scores.append(char_score)
            indexes.append(char_indexes)
            # np.mean([]) yields nan plus a RuntimeWarning; score an empty
            # text explicitly as 0.0 instead.
            text_scores.append(
                np.mean(char_scores).item() if char_scores else 0.0)
        return indexes, text_scores, points, pt_scores

    def __call__(
        self, output: Tuple[torch.Tensor, torch.Tensor],
        data_samples: Sequence[TextSpottingDataSample]
    ) -> Sequence[TextSpottingDataSample]:
        """Convert outputs to texts, scores and points.

        Args:
            output (tuple(Tensor, Tensor)): A tuple of (max_probs, seq), each
                with shape :math:`(N, T)`, where ``N`` is the batch size.
            data_samples (list[TextSpottingDataSample]): The list of
                TextSpottingDataSample, one per batch item.

        Returns:
            list(TextSpottingDataSample): The same ``data_samples``, each with
            ``pred_instances`` set (``texts``, ``scores``, ``text_scores``,
            ``points``) and the fields in ``rescale_fields`` rescaled.
        """
        max_probs, seq = output
        batch_size = max_probs.size(0)
        for idx in range(batch_size):
            data_sample = data_samples[idx]
            (char_idxs, text_scores, points,
             pt_scores) = self.get_single_prediction(max_probs[idx, :],
                                                     seq[idx, :],
                                                     data_sample)
            pred_instances = InstanceData()
            pred_instances.texts = [
                self.dictionary.idx2str(index) for index in char_idxs
            ]
            # The "scores" field only accepts a float number per instance;
            # the point score is used for it.
            pred_instances.scores = [float(score) for score in pt_scores]
            pred_instances.text_scores = text_scores
            pred_instances.points = points
            data_sample.pred_instances = pred_instances
            # rescale() updates data_sample.pred_instances in place.
            self.rescale(data_sample, data_sample.scale_factor)
        return data_samples

    def rescale(self, results: TextSpottingDataSample,
                scale_factor: Sequence[float]) -> TextSpottingDataSample:
        """Rescale results in ``results.pred_instances`` according to
        ``scale_factor``, whose keys are defined in ``self.rescale_fields``.

        Each listed field is divided by ``scale_factor`` via
        ``rescale_polygons(..., mode='div')``. Usually used to rescale bboxes
        and/or polygons. Nothing is done when ``rescale_fields`` is None or
        empty.

        Args:
            results (TextSpottingDataSample): The post-processed prediction
                results.
            scale_factor (tuple(float)): (w_scale, h_scale)

        Returns:
            TextSpottingDataSample: Prediction results with rescaled results
            (the same object, updated in place).
        """
        if not self.rescale_fields:
            # None means "no rescaling", as documented in the class
            # docstring; iterating over None would raise TypeError.
            return results
        scale_factor = np.asarray(scale_factor)
        for key in self.rescale_fields:
            # TODO: this util may need an alias or to be renamed
            results.pred_instances[key] = rescale_polygons(
                results.pred_instances[key], scale_factor, mode='div')
        return results