from typing import NamedTuple

import yaml
from tqdm import tqdm
import numpy as np
import torch as T
from torch import nn
from torch.nn import functional as F

from diac_utils import flat_2_3head
from model_dd import DiacritizerD2
from model_plm import Diacritizer


class Readout(nn.Module):
    def __init__(
        self,
        in_size: int,
        out_size: int,
    ):
        super().__init__()
        self.W1 = nn.Linear(in_size, in_size)
        self.W2 = nn.Linear(in_size, out_size)

    def forward(self, x: T.Tensor):
        z = self.W1(x)
        z = T.tanh(z)
        z = self.W2(z)  # project the activated hidden state (was applied to the raw input x)
        return z


class WordDD_LSTM(nn.Module):
    def __init__(
        self,
        feature_size: int,
        num_classes: int = 13,
        return_logits: bool = True,
    ):
        super().__init__()
        self.feature_size = feature_size
        self.num_classes = num_classes
        self.return_logits = return_logits
        # nn.LSTM needs both input and hidden sizes; hidden size assumed equal to feature_size.
        self.cell = nn.LSTM(feature_size, feature_size, batch_first=True)
        self.head = Readout(feature_size, num_classes)

    def forward(self, x: T.Tensor):
        #^ x: [b tc dc]
        z, _ = self.cell(x)  # nn.LSTM returns (output, (h_n, c_n)); keep only the output sequence
        #^ z: [b tc @dc]
        y = self.head(z)
        #^ y: [b tc Classes]
        yhat = y
        if not self.return_logits:
            yhat = F.softmax(yhat, dim=-1)  # normalize over the class dimension
        #^ yhat: [b tc @Classes]
        return yhat


class PartialDiacOutput(NamedTuple):
    preds_hard: T.Tensor
    preds_ctxt_logit: T.Tensor
    preds_base_logit: T.Tensor


class PartialDD(nn.Module):
    def __init__(
        self,
        config: dict,
        **kwargs
    ):
        super().__init__()
        self._built = False
        self.no_diac_id = 0
        self._dummy = nn.Parameter(T.ones(1, 1))

        # with open('./configs/dd/config_d2.yaml', 'r', encoding='utf-8') as fin:
        #     self.config_d2 = yaml.safe_load(fin)
        # self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')

        self.config = config
        self._use_d2 = config["model-name"] == "D2"
        if self._use_d2:
            self.sentence_diac = DiacritizerD2(self.config)
        else:
            self.sentence_diac = Diacritizer(self.config, load_pretrained=False)
        # self.sentence_diac.to(self.device)

        # self.build()
        # self.word_diac = WordDD_LSTM(feature_size, num_classes=13, return_logits=False)
        self.eval()

    @property
    def device(self):
        return self._dummy.device

    @property
    def tokenizer(self):
        return self.sentence_diac.tokenizer

    def load_state_dict(
        self,
        state_dict: dict,
        strict: bool = True,
    ):
        self.sentence_diac.load_state_dict(state_dict, strict=strict)

    def _slim_batch(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        #^ toke_ids: [b tt]
        #^ char_ids: [b tw tc]
        #^ diac_ids: [b tw tc "13"]
        #^ subword_lengths: [b tw]
        token_nonpad_mask = toke_ids.ne(self.tokenizer.pad_token_id)
        Ttoken = token_nonpad_mask.sum(1).max()
        toke_ids = toke_ids[:, :Ttoken]

        char_nonpad_mask = char_ids.ne(0)
        Tword = char_nonpad_mask.any(2).sum(1).max()
        Tchar = char_nonpad_mask.sum(2).max()
        char_ids = char_ids[:, :Tword, :Tchar]
        diac_ids = diac_ids[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]
        return toke_ids, char_ids, diac_ids, subword_lengths

    @T.jit.export
    def word_diac(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
        *,
        shape: tuple = None,
    ):
        if shape is None:
            toke_ids, char_ids, diac_ids, subword_lengths = self._slim_batch(
                toke_ids, char_ids, diac_ids, subword_lengths
            )
        else:
            Nb, Tw, Tc = shape
            toke_ids = toke_ids[:, :]
            char_ids = char_ids[:, :Tw, :Tc]
            diac_ids = diac_ids[:, :Tw, :Tc, :]
            subword_lengths = subword_lengths[:, :Tw]
        Nb, Tw, Tc = char_ids.shape
        # Tw = min(Tw, word_ids.shape[1])
        #^ word_ids: [b tt]
        #^ char_ids: [b tw tc]
        # wids_flat = word_ids[:, Tw].reshape(Nb * Tw, 1)
        # cids_flat = char_ids[:, Tw].reshape(Nb * Tw, 1, Tc)
        # z = self.sentence_diac(wids_flat, cids_flat)
        sent_word_strides = subword_lengths.cumsum(1)
        assert tuple(subword_lengths.shape) == (Nb, Tw), f"{subword_lengths.shape} != {(Nb, Tw)=}"
        max_tokens_per_word: int = subword_lengths.max().int().item()
        word_x = T.zeros(Nb, Tw, max_tokens_per_word).to(toke_ids)
        for i_b in range(toke_ids.shape[0]):
            sent_i = toke_ids[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
                if end_iw == start_iw:
                    break
                word = sent_i[start_iw:end_iw]
                word_x[i_b, i_word, 0 : end_iw - start_iw] = word
                start_iw = end_iw
        #^ word_x: [b tw tt]
        word_x = word_x.reshape(Nb * Tw, max_tokens_per_word)
        cids_flat = char_ids.reshape(Nb * Tw, 1, Tc)
        word_lengths = subword_lengths.reshape(Nb * Tw, 1)
        z = self.sentence_diac(
            word_x,
            cids_flat,
            diac_ids.reshape(Nb * Tw, Tc, -1),
            subword_lengths=word_lengths,
        )
        # Nc = z.shape[-1]
        #^ z: [b*tw, 1, tc, "13"]
        z = z.reshape(Nb, Tw, Tc, -1)
        return z

    @T.jit.ignore
    def forward(
        self,
        word_ids: T.Tensor,
        char_ids: T.Tensor,
        _labels: T.Tensor,
        # ground_truth: T.Tensor,
        # padding_mask: T.BoolTensor,
        *,
        eval_only: str = None,
        subword_lengths: T.Tensor,
        return_extra: bool = False,
        do_partial: bool = False,
    ):
        # assert self._built and not self.training
        assert not self.training
        #^ word_ids: [b tw]
        #^ char_ids: [b tw tc]
        #^ ground_truth: [b tw tc]
        padding_mask = char_ids.eq(0)
        #^ padding_mask: [b tw tc]

        if True or eval_only != 'base':
            y_ctxt = self.sentence_diac(
                word_ids,
                char_ids,
                _labels,
                subword_lengths=subword_lengths,
            )
            out_shape = y_ctxt.shape[:-1]
        else:
            out_shape = self.sentence_diac._slim_batch_size(
                word_ids,
                char_ids,
                _labels,
                subword_lengths,
            )[1].shape
        #^ y_ctxt: [b tw tc "13"]
        if eval_only == 'ctxt':
            return y_ctxt.argmax(-1)

        y_base = self.word_diac(
            word_ids, char_ids, _labels, subword_lengths,
            shape=out_shape,
        )
        #^ y_base: [b tw tc "13"]
        if eval_only == 'base':
            return y_base.argmax(-1)

        #! TODO: Return the logits.
        ypred_ctxt = y_ctxt.argmax(-1)
        ypred_base = y_base.argmax(-1)
        #^ ypred: [b tw tc _]

        # Maybe for eval
        # ypred_ctxt[~((ypred_base == ground_truth) & (~padding_mask))] = self.no_diac_id
        # return ypred_ctxt

        if do_partial:
            ypred_ctxt[(padding_mask) | (ypred_base == ypred_ctxt)] = self.no_diac_id

        if not return_extra:
            return ypred_ctxt
        else:
            return PartialDiacOutput(ypred_ctxt, y_ctxt, y_base)

    def step(self, xt, yt, mask=None):
        raise NotImplementedError
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)
        yt = yt.to(self.device)
        #^ yt: [b ts tw]
        diac, _ = self(*xt)  # xt: (word_ids, char_ids, _labels)
        loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))
        return loss

    def predict_partial(
        self,
        dataloader,
        return_extra=False,
        eval_only: str = None,
        do_partial=True,
    ):
        training = self.training
        self.eval()
        preds = {
            'haraka': [],
            'shadda': [],
            'tanween': [],
            'diacs': [],
            'y_ctxt': [],
            'y_base': [],
            'subword_lengths': [],
        }
        print("> Predicting...")
        # breakpoint()
        for i_batch, (inputs, _, subword_lengths) in enumerate(tqdm(dataloader)):
            # if i_batch > 10:
            #     break
            #^ inputs: [toke_ids, char_ids, diac_ids]
            inputs[0] = inputs[0].to(self.device)  #< toke_ids
            inputs[1] = inputs[1].to(self.device)  #< char_ids
            # inputs[2] = inputs[2].to(self.device)  #< diac_ids
            if self._use_d2:
                subword_lengths = T.ones_like(inputs[0])
                subword_lengths[inputs[0] == 0] = 0
            with T.no_grad():
                output = self(
                    *inputs,
                    subword_lengths=subword_lengths,
                    return_extra=return_extra,
                    eval_only=eval_only,
                    do_partial=do_partial,
                )
            # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
            if return_extra:
                assert isinstance(output, PartialDiacOutput)
                marks = output.preds_hard
                if eval_only == 'recalibrated':
                    marks = (output.preds_ctxt_logit + output.preds_base_logit).argmax(-1)
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))
                preds['y_ctxt'].extend(list(output.preds_ctxt_logit.detach().cpu().numpy()))
                preds['y_base'].extend(list(output.preds_base_logit.detach().cpu().numpy()))
                preds['subword_lengths'].extend(list(subword_lengths.detach().cpu().numpy()))
            else:
                assert isinstance(output, T.Tensor)
                marks = output
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))
            #^ [b ts tw]
            haraka, tanween, shadda = flat_2_3head(marks)
            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return {
            'diacritics': (
                #! FIXME! Due to batch slimming, output diacritics may need padding.
                np.array(preds['haraka']),
                np.array(preds["tanween"]),
                np.array(preds["shadda"]),
            ),
            'other': (
                # Would be empty when !return_extra
                np.array(preds['y_ctxt']),
                np.array(preds['y_base']),
                np.array(preds['diacs']),
                np.array(preds['subword_lengths']),
            ),
        }

    def predict(self, dataloader):
        training = self.training
        self.eval()
        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _ in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)

            output = self(*inputs)
            # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
            marks = output
            #^ [b ts tw]
            haraka, tanween, shadda = flat_2_3head(marks)

            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )
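

# ------------------------------------------------------------------
# Minimal sketch (assumption, not part of the training/eval pipeline):
# `PartialDD.forward` with do_partial=True keeps a context-model
# prediction only where it disagrees with the word-level (base)
# prediction; padded positions and agreements are reset to no_diac_id.
# The toy tensors below illustrate that masking rule in isolation,
# without loading any checkpoint or config.
# ------------------------------------------------------------------
if __name__ == "__main__":
    no_diac_id = 0
    # Fake hard predictions for a batch of 1 sentence, 2 words, 4 chars.
    ypred_ctxt = T.tensor([[[3, 5, 0, 0], [2, 2, 7, 0]]])
    ypred_base = T.tensor([[[3, 1, 0, 0], [2, 2, 4, 0]]])
    padding_mask = T.tensor([[[False, False, True, True],
                              [False, False, False, True]]])
    # Same rule as in PartialDD.forward when do_partial=True.
    ypred_ctxt[padding_mask | (ypred_base == ypred_ctxt)] = no_diac_id
    print(ypred_ctxt)
    # tensor([[[0, 5, 0, 0], [0, 0, 7, 0]]])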