|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
from functools import partial |
|
from itertools import permutations |
|
from typing import Sequence, Any, Optional |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch import Tensor |
|
|
|
from pytorch_lightning.utilities.types import STEP_OUTPUT |
|
from timm.models.helpers import named_apply |
|
|
|
from strhub.models.base import CrossEntropySystem |
|
from strhub.models.utils import init_weights |
|
from .modules import DecoderLayer, Decoder, Encoder, TokenEmbedding |
|
|
|
|
|
class PARSeq(CrossEntropySystem):
    """Text-recognition model: an image ``Encoder`` followed by a query/content ``Decoder``.

    Training decodes each batch under several token-order permutations (see ``gen_tgt_perms`` and
    ``generate_attn_masks``). Inference (``forward``) decodes either autoregressively or in one
    non-autoregressive pass, optionally followed by ``refine_iters`` cloze-style refinement passes.
    """

    def __init__(self, charset_train: str, charset_test: str, max_label_length: int,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float,
                 img_size: Sequence[int], patch_size: Sequence[int], embed_dim: int,
                 enc_num_heads: int, enc_mlp_ratio: int, enc_depth: int,
                 dec_num_heads: int, dec_mlp_ratio: int, dec_depth: int,
                 perm_num: int, perm_forward: bool, perm_mirrored: bool,
                 decode_ar: bool, refine_iters: int, dropout: float, **kwargs: Any) -> None:
        super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.save_hyperparameters()

        self.max_label_length = max_label_length
        self.decode_ar = decode_ar
        self.refine_iters = refine_iters

        self.encoder = Encoder(img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads,
                               mlp_ratio=enc_mlp_ratio)
        decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
        self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim))

        # Permutation settings used by gen_tgt_perms().
        self.rng = np.random.default_rng()
        # With mirroring every generated permutation is later paired with its reverse,
        # so only half of `perm_num` needs to be generated.
        self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
        self.perm_forward = perm_forward
        self.perm_mirrored = perm_mirrored

        # The head has two fewer outputs than the tokenizer vocabulary; presumably the
        # BOS and PAD tokens are never predicted -- TODO confirm against the tokenizer ordering.
        self.head = nn.Linear(embed_dim, len(self.tokenizer) - 2)
        self.text_embed = TokenEmbedding(len(self.tokenizer), embed_dim)

        # One learned position query per output step: max_label_length characters + 1 (EOS).
        self.pos_queries = nn.Parameter(torch.empty(1, max_label_length + 1, embed_dim))
        self.dropout = nn.Dropout(p=dropout)

        # The encoder is skipped by init_weights; pos_queries gets its own truncated-normal init.
        named_apply(partial(init_weights, exclude=['encoder']), self)
        nn.init.trunc_normal_(self.pos_queries, std=.02)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the set of parameter names that must be excluded from weight decay."""
        param_names = {'text_embed.embedding.weight', 'pos_queries'}
        enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()}
        return param_names.union(enc_param_names)

    def encode(self, img: torch.Tensor):
        """Run the image encoder and return its output (the decoder's ``memory``)."""
        return self.encoder(img)

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[Tensor] = None,
               tgt_padding_mask: Optional[Tensor] = None, tgt_query: Optional[Tensor] = None,
               tgt_query_mask: Optional[Tensor] = None):
        """Run the decoder on a batch of context tokens.

        :param tgt: token ids of shape (N, L); column 0 is the start-of-context token (callers pass BOS).
        :param memory: encoder output from ``encode``.
        :param tgt_mask: additive attention mask for the content (context) stream.
        :param tgt_padding_mask: boolean key-padding mask over the L context positions.
        :param tgt_query: position queries; defaults to the first L learned ``pos_queries``.
        :param tgt_query_mask: additive attention mask for the query stream.
        :return: the decoder output for the query stream.
        """
        N, L = tgt.shape
        # The first token gets no position embedding ("null context"); the remaining tokens are
        # offset by one, so token k (k >= 1) is paired with pos_queries[:, k - 1].
        null_ctx = self.text_embed(tgt[:, :1])
        tgt_emb = self.pos_queries[:, :L - 1] + self.text_embed(tgt[:, 1:])
        tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
        if tgt_query is None:
            tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
        tgt_query = self.dropout(tgt_query)
        return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)

    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        """Predict per-step logits for a batch of images.

        :param images: input image batch; dimension 0 is the batch size.
        :param max_length: optional cap on the label length (clamped to ``max_label_length``).
            When omitted, the model is in "testing" mode and AR decoding may stop early once
            every sequence has emitted EOS.
        :return: logits for up to ``max_length + 1`` steps (characters + EOS).
        """
        testing = max_length is None
        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
        bs = images.shape[0]
        # +1 for the EOS step.
        num_steps = max_length + 1
        memory = self.encode(images)

        # Query positions up to the number of decoding steps.
        pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)

        # Special case for the forward (left-to-right) order: the causal mask is identical for the
        # content and query streams, so build it directly instead of via generate_attn_masks().
        # query_mask must be a separate tensor: refinement below edits it in place and that edit
        # must not leak into the content-stream mask.
        tgt_mask = torch.triu(torch.full((num_steps, num_steps), float('-inf'), device=self._device), 1)
        query_mask = tgt_mask.clone()

        if self.decode_ar:
            tgt_in = torch.full((bs, num_steps), self.pad_id, dtype=torch.long, device=self._device)
            tgt_in[:, 0] = self.bos_id

            logits = []
            for i in range(num_steps):
                j = i + 1
                # Step i: context tokens 0..i, one query for output position i. Its mask row is
                # query_mask[i:j, :j], so the query sees the whole context decoded so far.
                tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j],
                                      tgt_query_mask=query_mask[i:j, :j])
                # Single-step logits, shape (bs, 1, num_classes).
                p_i = self.head(tgt_out)
                logits.append(p_i)
                if j < num_steps:
                    # Greedy decoding: feed the arg-max back in as the next context token.
                    # squeeze(1) drops only the step axis, so a batch of one keeps its batch dim.
                    tgt_in[:, j] = p_i.squeeze(1).argmax(-1)
                    # At test time, stop as soon as every sequence contains an EOS token.
                    if testing and (tgt_in == self.eos_id).any(dim=-1).all():
                        break

            logits = torch.cat(logits, dim=1)
        else:
            # Non-autoregressive: BOS is the only context; all position queries are decoded at once.
            tgt_in = torch.full((bs, 1), self.bos_id, dtype=torch.long, device=self._device)
            tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
            logits = self.head(tgt_out)

        if self.refine_iters:
            # Refinement uses a 'cloze' query mask: derived from the forward mask by unmasking all
            # context to the right, so query i only hides the token it predicts (context index i + 1).
            query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0
            bos = torch.full((bs, 1), self.bos_id, dtype=torch.long, device=self._device)
            for _ in range(self.refine_iters):
                # The previous prediction becomes the context.
                tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
                # After an early AR exit the context can be shorter than num_steps,
                # so both masks are sliced to the actual context length.
                ctx_len = tgt_in.shape[1]
                # Mask the first EOS token and everything after it.
                tgt_padding_mask = ((tgt_in == self.eos_id).int().cumsum(-1) > 0)
                tgt_out = self.decode(tgt_in, memory, tgt_mask[:ctx_len, :ctx_len], tgt_padding_mask,
                                      tgt_query=pos_queries, tgt_query_mask=query_mask[:, :ctx_len])
                logits = self.head(tgt_out)

        return logits

    def gen_tgt_perms(self, tgt):
        """Generate shared permutations for the whole batch.

        This works because the same attention mask can be used for the shorter sequences
        because of the padding mask.

        :param tgt: encoded targets; the code assumes two of the columns are BOS and EOS
            (hence ``tgt.shape[1] - 2`` character positions).
        :return: integer tensor of shape (num_perms, max_num_chars + 2). Each row is an ordering of
            sequence positions that starts with 0 (BOS) and, except for row 1 (see below), ends with
            max_num_chars + 1 (EOS).
        """
        # Number of character positions, excluding BOS and EOS.
        max_num_chars = tgt.shape[1] - 2
        # With a single character only the forward order is possible.
        if max_num_chars == 1:
            return torch.arange(3, device=self._device).unsqueeze(0)
        # The forward order, if requested, is always included and is not sampled again below.
        perms = [torch.arange(max_num_chars, device=self._device)] if self.perm_forward else []
        # Maximum number of distinct permutations; with mirroring a permutation and its reverse count once.
        max_perms = math.factorial(max_num_chars)
        if self.perm_mirrored:
            max_perms //= 2
        num_gen_perms = min(self.max_gen_perms, max_perms)
        # For 4-char sequences and shorter, enumerate every permutation and sample from that pool
        # without replacement, so no permutation is drawn twice.
        if max_num_chars < 5:
            # When mirroring, only a mirror-free half of the pool is needed. For n <= 3 the first
            # (lexicographic) half of permutations() works; for n == 4 it does not, hence the
            # explicit index list.
            if max_num_chars == 4 and self.perm_mirrored:
                selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
            else:
                selector = list(range(max_perms))
            perm_pool = torch.as_tensor(list(permutations(range(max_num_chars), max_num_chars)),
                                        device=self._device)[selector]
            # permutations() yields the identity first; drop it if it is already in `perms`.
            if self.perm_forward:
                perm_pool = perm_pool[1:]
            # torch.stack() rejects an empty list (the perm_forward=False case), so start from an
            # empty (0, max_num_chars) tensor of the same dtype/device instead.
            perms = torch.stack(perms) if perms else perm_pool.new_empty((0, max_num_chars))
            if len(perm_pool):
                i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(perms), replace=False)
                perms = torch.cat([perms, perm_pool[i]])
        else:
            perms.extend([torch.randperm(max_num_chars, device=self._device) for _ in range(num_gen_perms - len(perms))])
            perms = torch.stack(perms)
        if self.perm_mirrored:
            # Pair each permutation with its reverse and interleave: [p0, rev(p0), p1, rev(p1), ...].
            comp = perms.flip(-1)
            perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars)
        # Convert character orders into full-sequence orders: BOS (position 0) first, characters
        # shifted by one, EOS (position max_num_chars + 1) last.
        bos_idx = perms.new_zeros((len(perms), 1))
        eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
        perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
        # Special handling for the reverse direction. Row 1 becomes BOS, EOS, then characters
        # right-to-left. This gives EOS a null context (only BOS), which is needed to learn to
        # output EOS when decoding in reverse.
        # NOTE(review): row 1 is the reverse of the forward order only when both perm_forward and
        # perm_mirrored are set; otherwise this overwrites whichever permutation sits there.
        if len(perms) > 1:
            perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=self._device)
        return perms

    def generate_attn_masks(self, perm):
        """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)

        :param perm: the permutation sequence. i = 0 is always the BOS
        :return: (content_mask, query_mask) additive lookahead masks (0 = attend, -inf = blocked)
        """
        sz = perm.shape[0]
        mask = torch.zeros((sz, sz), device=self._device)
        # A position may not attend to any position that comes after it in the permutation.
        for i in range(sz):
            query_idx = perm[i]
            masked_keys = perm[i + 1:]
            mask[query_idx, masked_keys] = float('-inf')
        # The content stream excludes the last position (EOS) and may attend to itself.
        content_mask = mask[:-1, :-1].clone()
        # A query must not see its own token.
        mask[torch.eye(sz, dtype=torch.bool, device=self._device)] = float('-inf')
        # Query rows start at position 1 (BOS is never predicted); keys exclude EOS.
        query_mask = mask[1:, :-1]
        return content_mask, query_mask

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """Compute the token-weighted cross-entropy averaged over all generated permutations."""
        images, labels = batch
        tgt = self.tokenizer.encode(labels, self._device)

        # Encode the images once; the memory is reused for every permutation.
        memory = self.encode(images)

        # One set of permutations is shared by the whole batch.
        tgt_perms = self.gen_tgt_perms(tgt)
        tgt_in = tgt[:, :-1]
        tgt_out = tgt[:, 1:]
        # EOS is treated like padding in the context: it is never needed as context for another token.
        tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)

        loss = 0
        loss_numel = 0
        n = (tgt_out != self.pad_id).sum().item()
        for i, perm in enumerate(tgt_perms):
            tgt_mask, query_mask = self.generate_attn_masks(perm)
            out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask)
            logits = self.head(out).flatten(end_dim=1)
            # Weight each permutation's mean loss by its number of target tokens, so the final
            # value is a per-token average across all permutations.
            loss += n * F.cross_entropy(logits, tgt_out.flatten(), ignore_index=self.pad_id)
            loss_numel += n
            # After the first two orders (forward and reverse, see gen_tgt_perms), EOS targets
            # are turned into padding for the remaining permutations.
            if i == 1:
                tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, tgt_out)
                n = (tgt_out != self.pad_id).sum().item()
        loss /= loss_numel

        self.log('loss', loss)
        return loss
|
|