|
from typing import NamedTuple, Optional

import yaml

from tqdm import tqdm
import numpy as np

import torch as T
from torch import nn
from torch.nn import functional as F

from diac_utils import flat_2_3head

from model_dd import DiacritizerD2
from model_plm import Diacritizer
|
class Readout(nn.Module):
    """Two-layer readout head: Linear -> tanh -> Linear."""

    def __init__(
        self,
        in_size: int,
        out_size: int,
    ):
        super().__init__()
        self.W1 = nn.Linear(in_size, in_size)
        self.W2 = nn.Linear(in_size, out_size)

    def forward(self, x: T.Tensor):
        z = self.W1(x)
        z = T.tanh(z)
        # Project the tanh-activated features (not the raw input) to the output size.
        z = self.W2(z)
        return z
|
|
|
class WordDD_LSTM(nn.Module):
    """Word-level diacritization head: an LSTM encoder followed by a Readout classifier."""

    def __init__(
        self,
        feature_size: int,
        num_classes: int = 13,
        return_logits: bool = True,
    ):
        super().__init__()
        self.feature_size = feature_size
        self.num_classes = num_classes
        self.return_logits = return_logits
        # nn.LSTM needs both an input size and a hidden size; the hidden size is kept
        # equal to feature_size so it matches the Readout input dimension
        # (batch-first layout is assumed here).
        self.cell = nn.LSTM(feature_size, feature_size, batch_first=True)
        self.head = Readout(feature_size, num_classes)

    def forward(self, x: T.Tensor):
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step outputs are needed.
        z, _ = self.cell(x)
        y = self.head(z)

        yhat = y
        if not self.return_logits:
            # Normalize over the class dimension.
            yhat = F.softmax(yhat, dim=-1)

        return yhat
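
# Usage sketch for WordDD_LSTM: the shapes below are an illustrative assumption
# (batch-first layout), not part of the original training pipeline.
#
#   model = WordDD_LSTM(feature_size=128, num_classes=13, return_logits=False)
#   x = T.randn(8, 20, 128)      # (batch, time, feature_size)
#   probs = model(x)             # (batch, time, num_classes), rows sum to 1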
|
|
|
class PartialDiacOutput(NamedTuple):
    preds_hard: T.Tensor
    preds_ctxt_logit: T.Tensor
    preds_base_logit: T.Tensor
|
|
|
class PartialDD(nn.Module):
    """Partial diacritizer: runs a sentence-level (contextual) diacritizer and a
    word-level (context-free) pass, and keeps only the contextual predictions where
    the two disagree."""

    def __init__(
        self,
        config: dict,
        **kwargs
    ):
        super().__init__()
        self._built = False
        self.no_diac_id = 0
        # Dummy parameter used only to track the module's device (see the `device` property).
        self._dummy = nn.Parameter(T.ones(1, 1))

        self.config = config
        self._use_d2 = config["model-name"] == "D2"
        if self._use_d2:
            self.sentence_diac = DiacritizerD2(self.config)
        else:
            self.sentence_diac = Diacritizer(self.config, load_pretrained=False)

        self.eval()
|
|
|
    @property
    def device(self):
        return self._dummy.device

    @property
    def tokenizer(self):
        return self.sentence_diac.tokenizer

    def load_state_dict(
        self,
        state_dict: dict,
        strict: bool = True,
    ):
        # Delegate to the wrapped sentence-level diacritizer; PartialDD itself holds
        # no trainable weights besides the dummy device parameter.
        self.sentence_diac.load_state_dict(state_dict, strict=strict)
|
|
|
    def _slim_batch(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        # Trim trailing padding so the batch is only as long as its longest sentence.
        token_nonpad_mask = toke_ids.ne(self.tokenizer.pad_token_id)
        Ttoken = token_nonpad_mask.sum(1).max()
        toke_ids = toke_ids[:, :Ttoken]

        char_nonpad_mask = char_ids.ne(0)
        Tword = char_nonpad_mask.any(2).sum(1).max()
        Tchar = char_nonpad_mask.sum(2).max()
        char_ids = char_ids[:, :Tword, :Tchar]
        diac_ids = diac_ids[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]

        return toke_ids, char_ids, diac_ids, subword_lengths
|
|
|
    @T.jit.export
    def word_diac(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
        *,
        shape: Optional[tuple] = None,
    ):
        if shape is None:
            toke_ids, char_ids, diac_ids, subword_lengths = self._slim_batch(
                toke_ids, char_ids, diac_ids, subword_lengths
            )
        else:
            Nb, Tw, Tc = shape
            toke_ids = toke_ids[:, :]
            char_ids = char_ids[:, :Tw, :Tc]
            diac_ids = diac_ids[:, :Tw, :Tc, :]
            subword_lengths = subword_lengths[:, :Tw]
        Nb, Tw, Tc = char_ids.shape

        # Regroup the flat subword-token sequence into per-word token windows so each
        # word can be diacritized in isolation, without sentence context.
        sent_word_strides = subword_lengths.cumsum(1)
        assert tuple(subword_lengths.shape) == (Nb, Tw), f"{subword_lengths.shape} != {(Nb, Tw)=}"
        max_tokens_per_word: int = subword_lengths.max().int().item()
        word_x = T.zeros(Nb, Tw, max_tokens_per_word).to(toke_ids)
        for i_b in range(toke_ids.shape[0]):
            sent_i = toke_ids[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
                if end_iw == start_iw:
                    break
                word = sent_i[start_iw:end_iw]
                word_x[i_b, i_word, 0 : end_iw - start_iw] = word
                start_iw = end_iw

        # Flatten to (batch * words, ...) and run the sentence diacritizer on each word
        # as if it were its own sentence.
        word_x = word_x.reshape(Nb * Tw, max_tokens_per_word)
        cids_flat = char_ids.reshape(Nb * Tw, 1, Tc)
        word_lengths = subword_lengths.reshape(Nb * Tw, 1)

        z = self.sentence_diac(
            word_x,
            cids_flat,
            diac_ids.reshape(Nb * Tw, Tc, -1),
            subword_lengths=word_lengths,
        )

        z = z.reshape(Nb, Tw, Tc, -1)
        return z
|
|
|
    @T.jit.ignore
    def forward(
        self,
        word_ids: T.Tensor,
        char_ids: T.Tensor,
        _labels: T.Tensor,
        *,
        eval_only: Optional[str] = None,
        subword_lengths: T.Tensor,
        return_extra: bool = False,
        do_partial: bool = False,
    ):
        # Inference-only module.
        assert not self.training

        padding_mask = char_ids.eq(0)

        # NOTE: the `True or` forces the contextual pass even when eval_only == 'base';
        # the shortcut branch below is currently unreachable.
        if True or eval_only != 'base':
            y_ctxt = self.sentence_diac(
                word_ids,
                char_ids,
                _labels,
                subword_lengths=subword_lengths,
            )
            out_shape = y_ctxt.shape[:-1]
        else:
            out_shape = self.sentence_diac._slim_batch_size(
                word_ids,
                char_ids,
                _labels,
                subword_lengths,
            )[1].shape

        if eval_only == 'ctxt':
            return y_ctxt.argmax(-1)

        y_base = self.word_diac(
            word_ids,
            char_ids,
            _labels,
            subword_lengths,
            shape=out_shape,
        )

        if eval_only == 'base':
            return y_base.argmax(-1)

        ypred_ctxt = y_ctxt.argmax(-1)
        ypred_base = y_base.argmax(-1)

        # Partial diacritization: keep the contextual prediction only where it disagrees
        # with the context-free (word-level) prediction; padding positions and agreements
        # are mapped to the "no diacritic" class.
        if do_partial:
            ypred_ctxt[(padding_mask) | (ypred_base == ypred_ctxt)] = self.no_diac_id

        if not return_extra:
            return ypred_ctxt
        else:
            return PartialDiacOutput(ypred_ctxt, y_ctxt, y_base)
|
|
|
    def step(self, xt, yt, mask=None):
        raise NotImplementedError
        # Unreachable training-step code kept from an earlier version; note that
        # `self.closs` and `self.num_classes` are not defined on this class.
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)

        yt = yt.to(self.device)

        diac, _ = self(*xt)
        loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))

        return loss
|
|
|
    def predict_partial(
        self,
        dataloader,
        return_extra=False,
        eval_only: Optional[str] = None,
        do_partial=True,
    ):
        training = self.training
        self.eval()

        preds = {
            'haraka': [],
            'shadda': [],
            'tanween': [],
            'diacs': [],
            'y_ctxt': [],
            'y_base': [],
            'subword_lengths': [],
        }
        print("> Predicting...")

        for i_batch, (inputs, _, subword_lengths) in enumerate(tqdm(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)

            if self._use_d2:
                # For the D2 model, derive subword lengths on the fly: every non-pad
                # token counts as a single subword.
                subword_lengths = T.ones_like(inputs[0])
                subword_lengths[inputs[0] == 0] = 0

            with T.no_grad():
                output = self(
                    *inputs,
                    subword_lengths=subword_lengths,
                    return_extra=return_extra,
                    eval_only=eval_only,
                    do_partial=do_partial,
                )

            if return_extra:
                assert isinstance(output, PartialDiacOutput)
                marks = output.preds_hard
                if eval_only == 'recalibrated':
                    marks = (output.preds_ctxt_logit + output.preds_base_logit).argmax(-1)
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))
                preds['y_ctxt'].extend(list(output.preds_ctxt_logit.detach().cpu().numpy()))
                preds['y_base'].extend(list(output.preds_base_logit.detach().cpu().numpy()))
                preds['subword_lengths'].extend(list(subword_lengths.detach().cpu().numpy()))
            else:
                assert isinstance(output, T.Tensor)
                marks = output
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))

            # Split the flat diacritic ids into the three prediction heads.
            haraka, tanween, shadda = flat_2_3head(marks)

            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return {
            'diacritics': (
                np.array(preds['haraka']),
                np.array(preds['tanween']),
                np.array(preds['shadda']),
            ),
            'other': (
                np.array(preds['y_ctxt']),
                np.array(preds['y_base']),
                np.array(preds['diacs']),
                np.array(preds['subword_lengths']),
            ),
        }
|
|
|
    def predict(self, dataloader):
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _ in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)
            output = self(*inputs)

            marks = output

            haraka, tanween, shadda = flat_2_3head(marks)

            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds['tanween']),
            np.array(preds['shadda']),
        )
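

# Usage sketch for PartialDD (assumptions, not part of the original pipeline): the
# config path, checkpoint path, and dataloader below are illustrative placeholders.
#
#   with open("config.yaml") as fin:                                  # hypothetical path
#       config = yaml.safe_load(fin)
#   model = PartialDD(config)
#   model.load_state_dict(T.load("ckpt.pt", map_location="cpu"))     # hypothetical checkpoint
#   outputs = model.predict_partial(dataloader, return_extra=True, do_partial=True)
#   haraka, tanween, shadda = outputs['diacritics']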