# sunnychenxiwang's picture
# Upload 1600 files
# 14c9181 verified
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Dict, Optional, Sequence, Union
import torch
import torch.nn as nn
from torch.nn import functional as F
from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.decoders.base import BaseDecoder
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
@MODELS.register_module()
class ABCNetRecDecoder(BaseDecoder):
    """Attention-based GRU recognition decoder for ABCNet.

    At every decoding step an additive attention over the encoder sequence
    is concatenated with the embedding of the previous character, fused by a
    linear layer and fed to a single-layer GRU whose output is projected to
    per-class logits.

    Args:
        in_channels (int): Number of input channels. The same value is used
            as the embedding size and the GRU hidden size. Defaults to 256.
        dropout_prob (float): Probability of dropout applied to the character
            embedding. Defaults to 0.5.
        teach_prob (float): Probability of teacher forcing, i.e. of feeding
            the ground-truth character (instead of the greedy prediction) as
            the next decoder input during training. Defaults to 0.5.
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        module_loss (dict, optional): Config to build module_loss. Defaults
            to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        max_seq_len (int, optional): Max sequence length. Defaults to 30.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to ``dict(type='Xavier', layer='Conv2d')``.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(self,
                 in_channels: int = 256,
                 dropout_prob: float = 0.5,
                 teach_prob: float = 0.5,
                 dictionary: Union[Dictionary, Dict] = None,
                 module_loss: Dict = None,
                 postprocessor: Dict = None,
                 max_seq_len: int = 30,
                 init_cfg=dict(type='Xavier', layer='Conv2d'),
                 **kwargs):
        super().__init__(
            init_cfg=init_cfg,
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len)
        self.in_channels = in_channels
        self.teach_prob = teach_prob

        num_classes = self.dictionary.num_classes
        self.embedding = nn.Embedding(num_classes, in_channels)
        # Fuses [char embedding ; attention context] back to ``in_channels``.
        self.attn_combine = nn.Linear(in_channels * 2, in_channels)
        self.dropout = nn.Dropout(dropout_prob)
        self.gru = nn.GRU(in_channels, in_channels)
        self.out = nn.Linear(in_channels, num_classes)
        # Scores each encoder time step for the additive attention.
        self.vat = nn.Linear(in_channels, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward_train(
        self,
        feat: torch.Tensor,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Decode with (stochastic) teacher forcing.

        Args:
            feat (Tensor): Backbone feature. Not used by this decoder.
            out_enc (torch.Tensor, optional): Encoder output of shape
                :math:`(T, N, C)` (time-major; the batch size is read from
                dimension 1). Required despite the ``None`` default.
            data_samples (list[TextRecogDataSample], optional): Batch of
                TextRecogDataSample whose ``gt_text.padded_indexes`` hold
                the target indexes. Required despite the ``None`` default.

        Returns:
            Tensor: The raw logit tensor of shape :math:`(N, L, C)` where
            :math:`L` is the padded target length and :math:`C` is
            ``num_classes``.
        """
        device = out_enc.device
        bs = out_enc.size(1)
        # Keep the targets on the same device as every other decoder tensor.
        trg_seq = torch.stack(
            [
                sample.gt_text.padded_indexes.to(device)
                for sample in data_samples
            ],
            dim=0)
        # Index 0 is used as the initial decoder input.
        decoder_input = torch.zeros(bs, dtype=torch.long, device=device)
        decoder_hidden = torch.zeros(1, bs, self.in_channels, device=device)

        decoder_outputs = []
        for target_step in trg_seq.unbind(dim=1):
            # decoder_output (nbatch, ncls)
            decoder_output, decoder_hidden = self._attention(
                decoder_input, decoder_hidden, out_enc)
            # ``teach_prob`` is the probability OF teacher forcing, so the
            # ground truth is fed when the draw falls below it.
            if random.random() < self.teach_prob:
                decoder_input = target_step
            else:
                decoder_input = self._greedy_next_input(decoder_output)
            decoder_outputs.append(decoder_output)
        return torch.stack(decoder_outputs, dim=1)

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Greedy decoding for ``max_seq_len`` steps.

        Args:
            feat (Tensor, optional): Backbone feature. Not used by this
                decoder.
            out_enc (torch.Tensor, optional): Encoder output of shape
                :math:`(T, N, C)` (time-major; the batch size is read from
                dimension 1). Required despite the ``None`` default.
            data_samples (list[TextRecogDataSample], optional): Not used at
                test time. Defaults to None.

        Returns:
            Tensor: Character probabilities of shape
            :math:`(N, self.max_seq_len, C)` where :math:`C` is
            ``num_classes``.
        """
        device = out_enc.device
        bs = out_enc.size(1)
        decoder_input = torch.zeros(bs, dtype=torch.long, device=device)
        decoder_hidden = torch.zeros(1, bs, self.in_channels, device=device)

        outputs = []
        for _ in range(self.max_seq_len):
            # decoder_output (nbatch, ncls)
            decoder_output, decoder_hidden = self._attention(
                decoder_input, decoder_hidden, out_enc)
            decoder_input = self._greedy_next_input(decoder_output)
            outputs.append(decoder_output)
        return self.softmax(torch.stack(outputs, dim=1))

    @staticmethod
    def _greedy_next_input(decoder_output: torch.Tensor) -> torch.Tensor:
        """Return the top-1 class index per sample as the next input.

        Only the class dimension is squeezed, so the result keeps shape
        ``(nbatch,)`` even when the batch size is 1.
        """
        _, top_index = decoder_output.detach().topk(1)
        return top_index.squeeze(-1)

    def _attention(self, input, hidden, encoder_outputs):
        """Run one attention + GRU decoding step.

        Args:
            input (Tensor): Previous character indexes, shape ``(n,)``. A
                0-dim tensor is also accepted and treated as a batch of one.
            hidden (Tensor): Previous GRU hidden state, shape
                ``(1, n, hidden_size)``.
            encoder_outputs (Tensor): Encoder sequence, shape
                ``(T, n, hidden_size)``.

        Returns:
            tuple(Tensor, Tensor): Logits of shape ``(n, num_classes)`` and
            the new hidden state of shape ``(1, n, hidden_size)``.
        """
        embedded = self.dropout(self.embedding(input))
        if embedded.dim() == 1:
            # 0-dim ``input`` yields a single vector; restore the batch dim.
            embedded = embedded.unsqueeze(0)

        # Additive attention: ``hidden`` broadcasts over the T time steps.
        scores = self.vat(torch.tanh(hidden + encoder_outputs))  # (T, n, 1)
        attn_weights = F.softmax(scores.permute(1, 2, 0), dim=2)  # (n, 1, T)
        attn_applied = torch.matmul(
            attn_weights, encoder_outputs.permute(1, 0, 2))  # (n, 1, C)

        output = torch.cat((embedded, attn_applied.squeeze(1)), 1)
        output = self.attn_combine(output).unsqueeze(0)  # (1, n, hidden_size)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)  # (1, n, hidden_size)
        output = self.out(output[0])
        return output, hidden