import copy
from typing import Dict, List, Optional, Sequence, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from mmocr.models.common import Dictionary
from mmocr.models.textrecog.decoders import BaseDecoder
from mmocr.registry import MODELS
from mmocr.utils.typing_utils import TextSpottingDataSample
from .position_embedding import PositionEmbeddingSine


@MODELS.register_module()
class SPTSDecoder(BaseDecoder):
"""SPTS Decoder.
Args:
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
the instance of `Dictionary`.
num_bins (int): Number of bins dividing the image. Defaults to 1000.
n_head (int): Number of parallel attention heads. Defaults to 8.
d_model (int): Dimension :math:`D_m` of the input from previous model.
Defaults to 256.
d_feedforward (int): Dimension of the feedforward layer.
Defaults to 1024.
        normalize_before (bool): Whether to normalize the input before
            encoding/decoding. Defaults to True.
        dropout (float): Dropout probability in the embeddings and the
            transformer layers. Defaults to 0.1.
        max_num_text (int): Maximum number of text instances in a sample.
            Defaults to 60.
module_loss (dict, optional): Config to build loss. Defaults to None.
postprocessor (dict, optional): Config to build postprocessor.
Defaults to None.
init_cfg (dict or list[dict], optional): Initialization configs.
Defaults to None.
"""
def __init__(self,
dictionary: Union[Dict, Dictionary],
num_bins: int = 1000,
n_head: int = 8,
d_model: int = 256,
d_feedforward: int = 1024,
normalize_before: bool = True,
dropout: float = 0.1,
max_num_text: int = 60,
module_loss: Optional[Dict] = None,
postprocessor: Optional[Dict] = None,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
# TODO: fix hardcode
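        # Each text instance takes 2 point-coordinate tokens plus 25
        # recognition tokens (see the 27-token grouping in ``forward_test``);
        # the trailing +1 leaves room for the sequence-end token.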
self.max_seq_len = (2 + 25) * max_num_text + 1
super().__init__(
dictionary=dictionary,
module_loss=module_loss,
postprocessor=postprocessor,
max_seq_len=self.max_seq_len,
init_cfg=init_cfg)
self.num_bins = num_bins
self.embedding = DecoderEmbeddings(self.dictionary.num_classes,
self.dictionary.padding_idx,
d_model, self.max_seq_len, dropout)
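        # 2-D sine positional encoding (DETR-style); each spatial axis
        # contributes d_model // 2 channels.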
self.pos_embedding = PositionEmbeddingSine(d_model // 2)
self.vocab_embed = self._gen_vocab_embed(d_model, d_model,
self.dictionary.num_classes,
3)
encoder_layer = TransformerEncoderLayer(d_model, n_head, d_feedforward,
dropout, 'relu',
normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
num_encoder_layers = 6
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers,
encoder_norm)
decoder_layer = TransformerDecoderLayer(d_model, n_head, d_feedforward,
dropout, 'relu',
normalize_before)
decoder_norm = nn.LayerNorm(d_model)
num_decoder_layers = 6
self.decoder = TransformerDecoder(
decoder_layer,
num_decoder_layers,
decoder_norm,
return_intermediate=False)
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def _gen_vocab_embed(self, input_dim: int, hidden_dim: int,
output_dim: int, num_layers: int) -> nn.Module:
"""Generate vocab embedding layer."""
net = nn.Sequential()
h = [hidden_dim] * (num_layers - 1)
for i, (n, k) in enumerate(zip([input_dim] + h, h + [output_dim])):
net.add_module(f'layer-{i}', nn.Linear(n, k))
if i < num_layers - 1:
net.add_module(f'relu-{i}', nn.ReLU())
return net
def forward_train(
self,
feat: Optional[torch.Tensor] = None,
out_enc: Optional[torch.Tensor] = None,
data_samples: Optional[Sequence[TextSpottingDataSample]] = None
) -> torch.Tensor:
"""Forward for training.
Args:
feat (torch.Tensor, optional): The feature map from backbone of
shape :math:`(N, E, H, W)`. Defaults to None.
out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
data_samples (Sequence[TextRecogDataSample]): Batch of
TextRecogDataSample, containing gt_text information. Defaults
to None.
"""
mask, pos_embed, memory, query_embed = self._embed(
out_enc, data_samples)
padded_targets = [
data_sample.gt_instances.padded_indexes
for data_sample in data_samples
]
padded_targets = torch.stack(padded_targets, dim=0).to(out_enc.device)
        # Teacher forcing: the final target token is not needed as input.
tgt = self.embedding(padded_targets[:, :-1]).permute(1, 0, 2)
hs = self.decoder(
tgt,
memory,
memory_key_padding_mask=mask,
pos=pos_embed,
query_pos=query_embed[:len(tgt)],
tgt_mask=self._generate_square_subsequent_mask(len(tgt)).to(
tgt.device))
return self.vocab_embed(hs[-1].transpose(0, 1))
def forward_test(
self,
feat: Optional[torch.Tensor] = None,
out_enc: Optional[torch.Tensor] = None,
data_samples: Optional[Sequence[TextSpottingDataSample]] = None
) -> torch.Tensor:
"""Forward for testing.
Args:
feat (torch.Tensor, optional): The feature map from backbone of
shape :math:`(N, E, H, W)`. Defaults to None.
out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
data_samples (Sequence[TextRecogDataSample]): Batch of
TextRecogDataSample, containing gt_text information. Defaults
to None.
"""
batch_size = out_enc.shape[0]
mask, pos_embed, memory, query_embed = self._embed(
out_enc, data_samples)
max_probs = []
seq = torch.zeros(
batch_size, 1, dtype=torch.long).to(
out_enc.device) + self.dictionary.start_idx
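        # Greedy autoregressive decoding: start from the start token and
        # append the most probable token at every step.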
for i in range(self.max_seq_len):
tgt = self.embedding(seq).permute(1, 0, 2)
hs = self.decoder(
tgt,
memory,
memory_key_padding_mask=mask,
pos=pos_embed,
query_pos=query_embed[:len(tgt)],
tgt_mask=self._generate_square_subsequent_mask(len(tgt)).to(
tgt.device)) # bs, 1, E ?
out = self.vocab_embed(hs.transpose(1, 2)[-1, :, -1, :])
out = out.softmax(-1)
            # Dictionary layout: [coord bins | chars | unk | eos | seq_end |
            # sos | padding]
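            # Every instance spans 27 tokens (2 point coordinates followed by
            # 25 recognition tokens), so the position inside the current
            # group determines which part of the vocabulary is valid.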
            if i % 27 == 0:  # first point coordinate or <seq_end>
                out[:, self.num_bins:self.dictionary.seq_end_idx] = 0
                out[:, self.dictionary.seq_end_idx + 1:] = 0
            elif i % 27 == 1:  # second point coordinate
                out[:, self.num_bins:] = 0
            else:  # recognition tokens (chars, <unk>, <eos>)
                out[:, :self.num_bins] = 0
                out[:, self.dictionary.seq_end_idx:] = 0
max_prob, extra_seq = torch.max(out, dim=-1, keepdim=True)
# prob, extra_seq = out.topk(dim=-1, k=1)
# work for single batch only (original implementation)
# TODO: optimize for multi-batch
seq = torch.cat([seq, extra_seq], dim=-1)
max_probs.append(max_prob)
if extra_seq[0] == self.dictionary.seq_end_idx:
break
max_probs = torch.cat(max_probs, dim=-1)
max_probs = max_probs[:, :-1] # remove seq_eos
seq = seq[:, 1:-1] # remove start index and seq_eos
return max_probs, seq
def _embed(self, out_enc, data_samples):
bs, c, h, w = out_enc.shape
mask, pos_embed = self._gen_mask(out_enc, data_samples)
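        # Flatten the 2-D feature map and positional encoding to sequences of
        # shape (H*W, N, C) expected by the transformer.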
out_enc = out_enc.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
mask = mask.flatten(1)
# TODO move encoder to mmcv
memory = self.encoder(
out_enc, src_key_padding_mask=mask, pos=pos_embed.half())
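        # Reuse the decoder's learned position embeddings as target queries,
        # repeated for every sample in the batch.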
query_embed = self.embedding.position_embeddings.weight.unsqueeze(1)
query_embed = query_embed.repeat(1, bs, 1)
return mask, pos_embed, memory, query_embed
def _generate_square_subsequent_mask(self, size):
r"""Generate a square mask for the sequence. The masked positions are
filled with float('-inf'). Unmasked positions are filled with
float(0.0).
"""
mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
mask == 1, float(0.0))
return mask
def _gen_mask(self, out_enc, data_samples):
bs, _, h, w = out_enc.shape
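        # True marks padded positions to be ignored by attention; the valid
        # image region of each sample is set to False below.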
masks = torch.ones((bs, h, w), dtype=bool, device=out_enc.device)
for i, data_sample in enumerate(data_samples):
img_h, img_w = data_sample.img_shape
masks[i, :img_h, :img_w] = False
masks = F.interpolate(
masks[None].float(), size=(h, w)).to(torch.bool)[0]
return masks, self.pos_embedding(masks)


class DecoderEmbeddings(nn.Module):
def __init__(self, num_classes: int, padding_idx: int, hidden_dim,
max_position_embeddings, dropout):
super(DecoderEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(
num_classes, hidden_dim, padding_idx=padding_idx)
self.position_embeddings = nn.Embedding(max_position_embeddings,
hidden_dim)
self.LayerNorm = torch.nn.LayerNorm(hidden_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
input_shape = x.size()
seq_length = input_shape[1]
device = x.device
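        # Position ids 0..seq_length - 1, broadcast across the batch and
        # looked up in the learned position embedding table.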
position_ids = torch.arange(
seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
input_embeds = self.word_embeddings(x)
position_embeds = self.position_embeddings(position_ids)
embeddings = input_embeds + position_embeds
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings


class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self,
src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
output = src
for layer in self.layers:
output = layer(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
pos=pos)
if self.norm is not None:
output = self.norm(output)
return output


class TransformerDecoder(nn.Module):
def __init__(self,
decoder_layer,
num_layers,
norm=None,
return_intermediate=False):
super(TransformerDecoder, self).__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
output = tgt
for layer in self.layers:
output = layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos,
query_pos=query_pos)
if self.norm is not None:
            # Final nn.LayerNorm(d_model) applied after the last layer.
output = self.norm(output)
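        # Keep a leading dimension so callers can index the final layer
        # output with hs[-1], even though intermediate outputs are not
        # collected here.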
return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation='relu',
normalize_before=False):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(
q,
k,
value=src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(
q,
k,
value=src2,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation='relu',
normalize_before=False):
super(TransformerDecoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(
q,
k,
value=tgt,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(
q,
k,
value=tgt2,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask, pos, query_pos)
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask,
pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
"""Return an activation function given a string."""
if activation == 'relu':
return F.relu
if activation == 'gelu':
return F.gelu
if activation == 'glu':
return F.glu
    raise RuntimeError(
        f'activation should be relu/gelu/glu, not {activation}.')