|
from typing import List, Iterator, cast |
|
|
|
import copy |
|
import numpy as np |
|
|
|
import torch as T |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from transformers import BertConfig, BertModel |
|
from transformers import AutoTokenizer, AutoModel, AutoConfig |
|
from transformers import PreTrainedModel |
|
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions |
|
|
|
class Diacritizer(nn.Module): |
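    """BERT-based Arabic diacritizer.

    A pretrained subword encoder (``token_model``) produces word-level
    representations that are paired with character embeddings and fed to a
    second, character-level BERT encoder (``diac_emb_model``); a linear
    classifier then predicts one of ``num_classes`` diacritic labels per
    character.
    """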
|
def __init__( |
|
self, |
|
config, |
|
device=None, |
|
load_pretrained=True |
|
) -> None: |
|
super().__init__() |
|
self._dummy = nn.Parameter(T.ones(1)) |
|
|
|
if 'modeling' in config: |
|
config = config['modeling'] |
|
self.config = config |
|
|
|
model_name = config.get('base_model', "CAMeL-Lab/bert-base-arabic-camelbert-mix-ner") |
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
|
if load_pretrained: |
|
self.token_model: BertModel = AutoModel.from_pretrained(model_name) |
|
        else:
            # Build the same architecture with randomly initialized weights.
            base_config = AutoConfig.from_pretrained(model_name)
            self.token_model = AutoModel.from_config(base_config)
|
|
|
        self.num_classes = 15  # flat diacritic labels, decoded into haraka/tanween/shadda in predict()
|
self.diac_model_config = BertConfig(**config['diac_model_config']) |
|
self.token_model_config: BertConfig = self.token_model.config |
|
|
|
self.char_embs = nn.Embedding(config["num-chars"], embedding_dim=config["char-embed-dim"]) |
|
self.diac_emb_model = self.build_diac_model(self.token_model) |
|
|
|
self.down_project_token_embeds_deep = None |
|
self.down_project_token_embeds = None |
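        # Optionally down-project the concatenated character/word features to
        # 'token_hidden_size' ('auto' = the diacritic model's hidden size);
        # 'deep-down-proj' adds a two-layer projection on top of the linear one.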
|
if 'token_hidden_size' in config: |
|
if config['token_hidden_size'] == 'auto': |
|
down_proj_size = self.diac_emb_model.config.hidden_size |
|
else: |
|
down_proj_size = config['token_hidden_size'] |
|
if config.get('deep-down-proj', False): |
|
self.down_project_token_embeds_deep = nn.Sequential( |
|
nn.Linear( |
|
self.token_model_config.hidden_size + config["char-embed-dim"], |
|
down_proj_size * 4, |
|
bias=False, |
|
), |
|
nn.Tanh(), |
|
nn.Linear( |
|
down_proj_size * 4, |
|
down_proj_size, |
|
bias=False, |
|
) |
|
) |
|
|
|
self.down_project_token_embeds = nn.Linear( |
|
self.token_model_config.hidden_size + config["char-embed-dim"], |
|
down_proj_size, |
|
bias=False, |
|
) |
|
|
|
|
|
classifier_feature_size = self.diac_model_config.hidden_size |
|
if config.get('deep-cls', False): |
|
|
|
self.final_feature_transform = nn.Linear( |
|
self.diac_model_config.hidden_size |
|
+ self.token_model_config.hidden_size, |
|
|
|
out_features=classifier_feature_size, |
|
bias=False |
|
) |
|
else: |
|
self.final_feature_transform = None |
|
|
|
        self.feature_layer_norm = nn.LayerNorm(classifier_feature_size)  # note: not applied in forward()
|
self.classifier = nn.Linear(classifier_feature_size, self.num_classes, bias=True) |
|
|
|
self.trim_model_(config) |
|
|
|
self.dropout = nn.Dropout(config['dropout']) |
|
self.sent_dropout_p = config['sentence_dropout'] |
|
self.closs = F.cross_entropy |
|
|
|
def build_diac_model(self, token_model=None): |
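        """Build the character-level diacritic encoder.

        Either a fresh ``BertModel`` from ``diac_model_config`` or, when
        'pre-init-diac-model' is set, a trimmed copy of the token model.
        """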
|
        if self.config.get('pre-init-diac-model', False):
            # Initialize the diacritic encoder from a copy of the token model,
            # keeping the encoder-layer slice after the ones the token model keeps.
            model = copy.deepcopy(token_model if token_model is not None else self.token_model)
            model.pooler = None
            model.embeddings.word_embeddings = None

            num_layers = self.config.get('keep-token-model-layers', None)
            model.encoder.layer = nn.ModuleList(
                list(model.encoder.layer[num_layers:num_layers * 2])
            )
            model.encoder.config.num_hidden_layers = num_layers
|
else: |
|
model = BertModel(self.diac_model_config) |
|
return model |
|
|
|
def trim_model_(self, config): |
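        """Remove unused submodules and freeze most of the token model.

        Drops both poolers and the diacritic model's word embeddings,
        optionally keeps only the first 'keep-token-model-layers' encoder
        layers of the token model, and (unless 'full-finetune' is set)
        freezes the token model except its last 'num-finetune-last-layers'
        encoder layers (default 4).
        """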
|
self.token_model.pooler = None |
|
self.diac_emb_model.pooler = None |
|
|
|
self.diac_emb_model.embeddings.word_embeddings = None |
|
|
|
num_token_model_kept_layers = config.get('keep-token-model-layers', None) |
|
if num_token_model_kept_layers is not None: |
|
self.token_model.encoder.layer = nn.ModuleList( |
|
list(self.token_model.encoder.layer[:num_token_model_kept_layers]) |
|
) |
|
self.token_model.encoder.config.num_hidden_layers = num_token_model_kept_layers |
|
|
|
if not config.get('full-finetune', False): |
|
for param in self.token_model.parameters(): |
|
param.requires_grad = False |
|
finetune_last_layers = config.get('num-finetune-last-layers', 4) |
|
if finetune_last_layers > 0: |
|
unfrozen_layers = self.token_model.encoder.layer[-finetune_last_layers:] |
|
for layer in unfrozen_layers: |
|
for param in layer.parameters(): |
|
param.requires_grad = True |
|
|
|
def get_grouped_params(self): |
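        """Return parameter groups for the optimizer.

        'pretrained' covers the token-model backbone; 'downstream' covers the
        diacritic encoder, classifier, and character embeddings. The optional
        projection / transform layers are not included in either group.
        """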
|
downstream_params: Iterator[nn.Parameter] = cast( |
|
Iterator, |
|
(param |
|
for module in (self.diac_emb_model, self.classifier, self.char_embs) |
|
for param in module.parameters()) |
|
) |
|
pg = { |
|
'pretrained': self.token_model.parameters(), |
|
'downstream': downstream_params, |
|
} |
|
return pg |
|
|
|
@property |
|
def device(self): |
|
return self._dummy.device |
|
|
|
def step(self, xt, yt, mask=None, subword_lengths: T.Tensor=None): |
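        """Compute the summed cross-entropy loss for one batch.

        ``xt`` holds (token_ids, char_ids, diac_inputs) and ``yt`` the
        per-character diacritic labels; batches with long sequences are
        processed in chunks to bound memory.
        """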
|
|
|
|
|
|
|
|
|
|
|
xt[0], xt[1], yt, subword_lengths = self._slim_batch_size(xt[0], xt[1], yt, subword_lengths) |
|
xt[0] = xt[0].to(self.device) |
|
xt[1] = xt[1].to(self.device) |
|
|
|
|
|
yt = yt.to(self.device) |
|
|
|
|
|
Nb, Tword, Tchar = xt[1].shape |
|
if Tword * Tchar < 500: |
|
diac = self(*xt, subword_lengths) |
|
loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1), reduction='sum') |
|
        else:
            # Process very long batches in chunks to bound peak memory.
            num_chunks = -(-(Tword * Tchar) // 300)  # ceiling division
            loss = 0
            for i in range(num_chunks):
                _slice = slice(i * 300, (i + 1) * 300)
                chunk = self._slice_batch(xt, _slice)
                diac = self(*chunk, subword_lengths[_slice])
                chunk_loss = self.closs(diac.view(-1, self.num_classes), yt[_slice].view(-1), reduction='sum')
                loss = loss + chunk_loss
|
|
|
return loss |
|
|
|
def _slice_batch(self, xt: List[T.Tensor], _slice): |
|
return [xt[0][_slice], xt[1][_slice], xt[2][_slice]] |
|
|
|
def _slim_batch_size( |
|
self, |
|
tx: T.Tensor, |
|
cx: T.Tensor, |
|
yt: T.Tensor, |
|
subword_lengths: T.Tensor |
|
): |
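        """Trim batch padding down to the longest sequences actually present.

        Shortens the token, character, label, and subword-length tensors to the
        maximum non-pad token / word / character counts in the batch.
        """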
|
|
|
|
|
|
|
token_nonpad_mask = tx.ne(self.tokenizer.pad_token_id) |
|
Ttoken = token_nonpad_mask.sum(1).max() |
|
tx = tx[:, :Ttoken] |
|
|
|
char_nonpad_mask = cx.ne(0) |
|
Tword = char_nonpad_mask.any(2).sum(1).max() |
|
Tchar = char_nonpad_mask.sum(2).max() |
|
cx = cx[:, :Tword, :Tchar] |
|
yt = yt[:, :Tword, :Tchar] |
|
subword_lengths = subword_lengths[:, :Tword] |
|
|
|
return tx, cx, yt, subword_lengths |
|
|
|
def token_dropout(self, toke_x): |
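        """Randomly replace tokens with the pad id during training.

        Each token is dropped with probability ``self.sent_dropout_p``;
        ``toke_x`` is modified in place.
        """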
|
|
|
if self.training: |
|
q = 1.0 - self.sent_dropout_p |
|
            sdo = T.bernoulli(T.full(toke_x.shape, q, device=toke_x.device))
|
toke_x[sdo == 0] = self.tokenizer.pad_token_id |
|
return toke_x |
|
|
|
def sentence_dropout(self, word_embs: T.Tensor): |
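        """Zero out whole word embeddings with probability ``self.sent_dropout_p`` during training."""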
|
|
|
if self.training: |
|
q = 1.0 - self.sent_dropout_p |
|
sdo = T.bernoulli(T.full(word_embs.shape[:2], q)) |
|
sdo = sdo.detach().unsqueeze(-1).to(word_embs) |
|
word_embs = word_embs * sdo |
|
|
|
return word_embs |
|
|
|
def embed_tokens(self, input_ids: T.Tensor, attention_mask: T.Tensor): |
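        """Encode subword ids with the token model and return its last hidden states."""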
|
y: BaseModelOutputWithPoolingAndCrossAttentions |
|
y = self.token_model(input_ids, attention_mask=attention_mask) |
|
z = y.last_hidden_state |
|
return z |
|
|
|
def forward( |
|
self, |
|
toke_x : T.Tensor, |
|
char_x : T.Tensor, |
|
diac_x : T.Tensor, |
|
subword_lengths : T.Tensor, |
|
): |
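        """Predict a diacritic class for every character.

        Args (shapes): ``toke_x`` (Nb, Tt) subword ids, ``char_x`` (Nb, Tw, Tc)
        character ids, ``diac_x`` auxiliary diacritic inputs (unused in this
        forward pass), ``subword_lengths`` (Nb, Tw) subwords per word.
        Returns logits of shape (Nb, Tw, Tc, num_classes).
        """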
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
token_nonpad_mask = toke_x.ne(self.tokenizer.pad_token_id) |
|
char_nonpad_mask = char_x.ne(0) |
|
|
|
Nb, Tw, Tc = char_x.shape |
|
|
|
|
|
|
|
token_embs = self.embed_tokens(toke_x, attention_mask=token_nonpad_mask) |
|
|
|
|
|
        # Drop the special [CLS]/[SEP] positions so subword offsets start at the first real token.
        token_embs = token_embs[:, 1:-1, ...]
|
|
|
        # Pool each word's subword embeddings into a single word vector (mean).
        sent_word_strides = subword_lengths.cumsum(1)
|
sent_enc: T.Tensor = T.zeros(Nb, Tw, token_embs.shape[-1]).to(token_embs) |
|
for i_b in range(Nb): |
|
token_embs_ib = token_embs[i_b] |
|
start_iw = 0 |
|
for i_word, end_iw in enumerate(sent_word_strides[i_b]): |
|
if end_iw == start_iw: break |
|
word_emb = token_embs_ib[start_iw : end_iw].sum(0) / (end_iw - start_iw) |
|
sent_enc[i_b, i_word] = word_emb |
|
start_iw = end_iw |
|
|
|
|
|
        # Flatten to one word per row so the diacritic encoder treats each word's
        # characters as a sequence.
        char_x_flat = char_x.reshape(Nb*Tw, Tc)
|
char_nonpad_mask = char_x_flat.gt(0) |
|
|
|
|
|
char_x_flat = char_x_flat * char_nonpad_mask |
|
|
|
cembs = self.char_embs(char_x_flat) |
|
|
|
|
|
        # Pair every character with its word-level embedding.
        wembs = sent_enc.unsqueeze(-2).expand(Nb, Tw, Tc, -1).view(Nb*Tw, Tc, -1)
|
|
|
cw_embs = T.cat([cembs, wembs], dim=-1) |
|
|
|
cw_embs = self.dropout(cw_embs) |
|
|
|
cw_embs_ = cw_embs |
|
if self.down_project_token_embeds is not None: |
|
cw_embs_ = self.down_project_token_embeds(cw_embs) |
|
if self.down_project_token_embeds_deep is not None: |
|
cw_embs_ = cw_embs_ + self.down_project_token_embeds_deep(cw_embs) |
|
cw_embs = cw_embs_ |
|
|
|
diac_enc: BaseModelOutputWithPoolingAndCrossAttentions |
|
diac_enc = self.diac_emb_model(inputs_embeds=cw_embs, attention_mask=char_nonpad_mask) |
|
diac_emb = diac_enc.last_hidden_state |
|
diac_emb = self.dropout(diac_emb) |
|
|
|
diac_emb = diac_emb.view(Nb, Tw, Tc, -1) |
|
|
|
        # Optional deep classification head: concatenate the word-level encoding
        # as a residual signal with the character-level diacritic features.
        sent_residual = sent_enc.unsqueeze(2).expand(-1, -1, Tc, -1)
        final_feature = T.cat([sent_residual, diac_emb], dim=-1)
        if self.final_feature_transform is not None:
            final_feature = self.final_feature_transform(final_feature)
            final_feature = T.tanh(final_feature)
|
final_feature = self.dropout(final_feature) |
|
else: |
|
final_feature = diac_emb |
|
|
|
|
|
diac_out = self.classifier(final_feature) |
|
|
|
|
|
return diac_out |
|
|
|
def predict(self, dataloader): |
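        """Run inference over ``dataloader`` and collect per-character predictions.

        Returns three numpy arrays (haraka, tanween, shadda) decoded from the
        flat class predictions by ``diac_utils.flat_2_3head``.
        """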
|
from tqdm import tqdm |
|
import diac_utils as du |
|
training = self.training |
|
self.eval() |
|
|
|
preds = {'haraka': [], 'shadda': [], 'tanween': []} |
|
print("> Predicting...") |
|
for inputs, _, subword_lengths in tqdm(dataloader, total=len(dataloader)): |
|
inputs[0] = inputs[0].to(self.device) |
|
inputs[1] = inputs[1].to(self.device) |
|
output = self(*inputs, subword_lengths).detach() |
|
|
|
marks = np.argmax(T.softmax(output, dim=-1).cpu().numpy(), axis=-1) |
|
|
|
|
|
haraka, tanween, shadda = du.flat_2_3head(marks) |
|
|
|
preds['haraka'].extend(haraka) |
|
preds['tanween'].extend(tanween) |
|
preds['shadda'].extend(shadda) |
|
|
|
self.train(training) |
|
return ( |
|
np.array(preds['haraka']), |
|
np.array(preds["tanween"]), |
|
np.array(preds["shadda"]), |
|
) |
|
|
|
if __name__ == "__main__": |
|
model = Diacritizer({ |
|
"num-chars": 36, |
|
"hidden_size": 768, |
|
"char-embed-dim": 32, |
|
"dropout": 0.25, |
|
"sentence_dropout": 0.2, |
|
"diac_model_config": { |
|
"num_layers": 4, |
|
"hidden_size": 768 + 32, |
|
"intermediate_size": (768 + 32) * 4, |
|
}, |
|
}, load_pretrained=False) |
|
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
|
print(model) |
|
print(f"{trainable_params:,}/{total_params:,} Trainable Parameters") |