# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Dict, Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.decoders.base import BaseDecoder
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample


@MODELS.register_module()
class ABCNetRecDecoder(BaseDecoder):
    """Decoder for ABCNet.

    Args:
        in_channels (int): Number of input channels. Defaults to 256.
        dropout_prob (float): Probability of dropout. Defaults to 0.5.
        teach_prob (float): Probability of teacher forcing. Defaults to 0.5.
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        module_loss (dict, optional): Config to build module_loss. Defaults
            to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        max_seq_len (int, optional): Max sequence length. Defaults to 30.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int = 256,
                 dropout_prob: float = 0.5,
                 teach_prob: float = 0.5,
                 dictionary: Union[Dictionary, Dict] = None,
                 module_loss: Dict = None,
                 postprocessor: Dict = None,
                 max_seq_len: int = 30,
                 init_cfg=dict(type='Xavier', layer='Conv2d'),
                 **kwargs):
        super().__init__(
            init_cfg=init_cfg,
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len)
        self.in_channels = in_channels
        self.teach_prob = teach_prob
        self.embedding = nn.Embedding(self.dictionary.num_classes,
                                      in_channels)
        self.attn_combine = nn.Linear(in_channels * 2, in_channels)
        self.dropout = nn.Dropout(dropout_prob)
        self.gru = nn.GRU(in_channels, in_channels)
        self.out = nn.Linear(in_channels, self.dictionary.num_classes)
        # Projects the additive attention energies to a scalar score per
        # encoder position (see ``_attention``).
        self.vat = nn.Linear(in_channels, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward_train(
        self,
        feat: torch.Tensor,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """
        Args:
            feat (Tensor): A Tensor of shape :math:`(N, C, 1, W)`.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            data_samples (list[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, W, C)` where
            :math:`C` is ``num_classes``.
        """
        bs = out_enc.size()[1]
        trg_seq = []
        for target in data_samples:
            trg_seq.append(target.gt_text.padded_indexes.to(feat.device))
        decoder_input = torch.zeros(bs).long().to(out_enc.device)
        trg_seq = torch.stack(trg_seq, dim=0)
        decoder_hidden = torch.zeros(1, bs,
                                     self.in_channels).to(out_enc.device)
        decoder_outputs = []
        for index in range(trg_seq.shape[1]):
            # decoder_output (nbatch, ncls)
            decoder_output, decoder_hidden = self._attention(
                decoder_input, decoder_hidden, out_enc)
            teach_forcing = random.random() > self.teach_prob
            if teach_forcing:
                decoder_input = trg_seq[:, index]  # Teacher forcing
            else:
                _, topi = decoder_output.data.topk(1)
                decoder_input = topi.squeeze()
            decoder_outputs.append(decoder_output)
        return torch.stack(decoder_outputs, dim=1)

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """
        Args:
            feat (Tensor): A Tensor of shape :math:`(N, C, 1, W)`.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            data_samples (list[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing ``gt_text`` information.
                Defaults to None.

        Returns:
            Tensor: Character probabilities of shape
            :math:`(N, self.max_seq_len, C)` where :math:`C` is
            ``num_classes``.
        """
        bs = out_enc.size()[1]
        outputs = []
        decoder_input = torch.zeros(bs).long().to(out_enc.device)
        decoder_hidden = torch.zeros(1, bs,
                                     self.in_channels).to(out_enc.device)
        for _ in range(self.max_seq_len):
            # decoder_output (nbatch, ncls)
            decoder_output, decoder_hidden = self._attention(
                decoder_input, decoder_hidden, out_enc)
            _, topi = decoder_output.data.topk(1)
            decoder_input = topi.squeeze()
            outputs.append(decoder_output)
        outputs = torch.stack(outputs, dim=1)
        return self.softmax(outputs)

    def _attention(self, input, hidden, encoder_outputs):
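        """Run one additive-attention decoding step.

        Args:
            input (Tensor): Previous character indices of shape :math:`(N, )`.
            hidden (Tensor): Previous GRU hidden state of shape
                :math:`(1, N, C)`.
            encoder_outputs (Tensor): Encoder output of shape
                :math:`(T, N, C)`.

        Returns:
            tuple(Tensor, Tensor): Raw logits of shape ``(N, num_classes)``
            and the updated hidden state of shape :math:`(1, N, C)`.
        """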
        embedded = self.embedding(input)
        embedded = self.dropout(embedded)
        batch_size = encoder_outputs.shape[1]
        # Additive attention: score every encoder position against the
        # current hidden state.
        alpha = hidden + encoder_outputs  # (T, n, hidden_size)
        alpha = alpha.view(-1, alpha.shape[-1])  # (T * n, hidden_size)
        attn_weights = self.vat(torch.tanh(alpha))  # (T * n, 1)
        attn_weights = attn_weights.view(-1, 1, batch_size).permute(
            (2, 1, 0))  # (T, 1, n) -> (n, 1, T)
        attn_weights = F.softmax(attn_weights, dim=2)
        attn_applied = torch.matmul(attn_weights,
                                    encoder_outputs.permute((1, 0, 2)))
        if embedded.dim() == 1:
            # A batch of size 1 may have been squeezed to a 1-D tensor.
            embedded = embedded.unsqueeze(0)
        output = torch.cat((embedded, attn_applied.squeeze(1)), 1)
        output = self.attn_combine(output).unsqueeze(0)  # (1, n, hidden_size)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)  # (1, n, hidden_size)
        output = self.out(output[0])
        return output, hidden
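

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream module). It assumes that a
# plain character list written to a temporary file is sufficient to build
# ``Dictionary``, and feeds random encoder features of shape
# ``(seq_len, batch_size, in_channels)`` to ``forward_test``. All parameter
# values below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import tempfile

    # Write a tiny character dictionary (one character per line) so that
    # ``Dictionary`` can be built without external assets.
    with tempfile.NamedTemporaryFile(
            mode='w', suffix='.txt', delete=False) as f:
        f.write('\n'.join('0123456789abcdefghijklmnopqrstuvwxyz'))
        dict_file = f.name

    dictionary = Dictionary(
        dict_file=dict_file, with_padding=True, with_unknown=True)
    decoder = ABCNetRecDecoder(
        in_channels=256, dictionary=dictionary, max_seq_len=25)

    # Dummy encoder output: (seq_len, batch_size, in_channels).
    out_enc = torch.randn(32, 2, 256)
    probs = decoder.forward_test(feat=None, out_enc=out_enc)
    print(probs.shape)  # e.g. torch.Size([2, 25, 38])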