from typing import List, Iterator, cast
import copy
import numpy as np
import torch as T
from torch import nn
from torch.nn import functional as F
from transformers import BertConfig, BertModel
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
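# Arabic diacritization model: a pretrained subword encoder (CAMeLBERT NER by
# default) supplies word-level context, which is concatenated with learned
# character embeddings and fed to a second, character-level BERT
# (`diac_emb_model`); a linear classifier then predicts one of 15 diacritic
# classes per character, decoded into haraka/tanween/shadda in `predict`.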
class Diacritizer(nn.Module):
def __init__(
self,
config,
device=None,
load_pretrained=True
) -> None:
super().__init__()
self._dummy = nn.Parameter(T.ones(1))
if 'modeling' in config:
config = config['modeling']
self.config = config
model_name = config.get('base_model', "CAMeL-Lab/bert-base-arabic-camelbert-mix-ner")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if load_pretrained:
self.token_model: BertModel = AutoModel.from_pretrained(model_name)
else:
base_config = AutoConfig.from_pretrained(model_name)
self.token_model = AutoModel.from_config(base_config)
self.num_classes = 15
self.diac_model_config = BertConfig(**config['diac_model_config'])
self.token_model_config: BertConfig = self.token_model.config
self.char_embs = nn.Embedding(config["num-chars"], embedding_dim=config["char-embed-dim"])
self.diac_emb_model = self.build_diac_model(self.token_model)
self.down_project_token_embeds_deep = None
self.down_project_token_embeds = None
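# If 'token_hidden_size' is configured, project the concatenated
# [token context ; char embedding] features down to the diac model's input size
# ('auto' uses the diac model's hidden size). 'deep-down-proj' adds a deeper
# Tanh MLP whose output is summed with the linear projection in `forward`.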
if 'token_hidden_size' in config:
if config['token_hidden_size'] == 'auto':
down_proj_size = self.diac_emb_model.config.hidden_size
else:
down_proj_size = config['token_hidden_size']
if config.get('deep-down-proj', False):
self.down_project_token_embeds_deep = nn.Sequential(
nn.Linear(
self.token_model_config.hidden_size + config["char-embed-dim"],
down_proj_size * 4,
bias=False,
),
nn.Tanh(),
nn.Linear(
down_proj_size * 4,
down_proj_size,
bias=False,
)
)
# else:
self.down_project_token_embeds = nn.Linear(
self.token_model_config.hidden_size + config["char-embed-dim"],
down_proj_size,
bias=False,
)
# assert self.down_project_token_embeds_deep is None or self.down_project_token_embeds is None
classifier_feature_size = self.diac_model_config.hidden_size
if config.get('deep-cls', False):
# classifier_feature_size = 512
self.final_feature_transform = nn.Linear(
self.diac_model_config.hidden_size
+ self.token_model_config.hidden_size,
#^ diac_features + [residual from token_model]
out_features=classifier_feature_size,
bias=False
)
else:
self.final_feature_transform = None
self.feature_layer_norm = nn.LayerNorm(classifier_feature_size)
self.classifier = nn.Linear(classifier_feature_size, self.num_classes, bias=True)
self.trim_model_(config)
self.dropout = nn.Dropout(config['dropout'])
self.sent_dropout_p = config['sentence_dropout']
self.closs = F.cross_entropy
def build_diac_model(self, token_model=None):
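# Build the character-level encoder that produces per-character diacritic features.
# With 'pre-init-diac-model', it is seeded from a copy of the pretrained token model:
# the pooler and word embeddings are dropped (characters enter via `inputs_embeds`)
# and the *next* 'keep-token-model-layers' encoder layers are reused, i.e. the ones
# after those kept by the token model; this path requires 'keep-token-model-layers'.
# Otherwise a fresh BertModel is built from `diac_model_config`.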
if self.config.get('pre-init-diac-model', False):
model = copy.deepcopy(self.token_model)
model.pooler = None
model.embeddings.word_embeddings = None
num_layers = self.config.get('keep-token-model-layers', None)
model.encoder.layer = nn.ModuleList(
list(model.encoder.layer[num_layers:num_layers*2])
)
model.encoder.config.num_hidden_layers = num_layers
else:
model = BertModel(self.diac_model_config)
return model
def trim_model_(self, config):
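# Drop unused submodules (both poolers and the diac model's word embeddings,
# since it is fed `inputs_embeds`), optionally keep only the first
# 'keep-token-model-layers' of the token encoder, and, unless 'full-finetune'
# is set, freeze the token model except for its last 'num-finetune-last-layers'
# layers.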
self.token_model.pooler = None
self.diac_emb_model.pooler = None
# self.diac_emb_model.embeddings = None
self.diac_emb_model.embeddings.word_embeddings = None
num_token_model_kept_layers = config.get('keep-token-model-layers', None)
if num_token_model_kept_layers is not None:
self.token_model.encoder.layer = nn.ModuleList(
list(self.token_model.encoder.layer[:num_token_model_kept_layers])
)
self.token_model.encoder.config.num_hidden_layers = num_token_model_kept_layers
if not config.get('full-finetune', False):
for param in self.token_model.parameters():
param.requires_grad = False
finetune_last_layers = config.get('num-finetune-last-layers', 4)
if finetune_last_layers > 0:
unfrozen_layers = self.token_model.encoder.layer[-finetune_last_layers:]
for layer in unfrozen_layers:
for param in layer.parameters():
param.requires_grad = True
def get_grouped_params(self):
downstream_params: Iterator[nn.Parameter] = cast(
Iterator,
(param
for module in (self.diac_emb_model, self.classifier, self.char_embs)
for param in module.parameters())
)
pg = {
'pretrained': self.token_model.parameters(),
'downstream': downstream_params,
}
return pg
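# Hedged usage sketch (not part of this file): the two groups are intended for
# per-group optimizer settings, e.g. a smaller learning rate for the pretrained
# encoder; the rates below are illustrative only.
#   pg = model.get_grouped_params()
#   optimizer = T.optim.AdamW([
#       {'params': pg['pretrained'], 'lr': 1e-5},
#       {'params': pg['downstream'], 'lr': 5e-4},
#   ])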
@property
def device(self):
return self._dummy.device
def step(self, xt, yt, mask=None, subword_lengths: T.Tensor=None):
# ^ word_x, char_x, diac_x are Indices
# ^ xt : self.preprocess((word_x, char_x, diac_x)),
# ^ yt : T.tensor(diac_y, dtype=T.long),
# ^ subword_lengths: T.tensor(subword_lengths, dtype=T.long)
#< Move char_x, diac_x to device because they're small and trainable
xt[0], xt[1], yt, subword_lengths = self._slim_batch_size(xt[0], xt[1], yt, subword_lengths)
xt[0] = xt[0].to(self.device)
xt[1] = xt[1].to(self.device)
# xt[2] = xt[2].to(self.device)
yt = yt.to(self.device)
#^ yt: [b tw tc]
Nb, Tword, Tchar = xt[1].shape
if Tword * Tchar < 500:
diac = self(*xt, subword_lengths)
loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1), reduction='sum')
else:
num_chunks = (Tword * Tchar + 299) // 300  # number of ~300-sample chunks (ceil)
loss = 0
for i in range(num_chunks):
_slice = slice(i*300, (i+1)*300)
chunk = self._slice_batch(xt, _slice)
diac = self(*chunk, subword_lengths[_slice])
chunk_loss = self.closs(diac.view(-1, self.num_classes), yt[_slice].view(-1), reduction='sum')
loss = loss + chunk_loss
return loss
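# Hedged training-step sketch (assumes the dataloader yields batches shaped like
# those consumed by `predict`: (inputs, labels, subword_lengths) with
# inputs = [toke_x, char_x, diac_x]):
#   for inputs, labels, subword_lengths in dataloader:
#       optimizer.zero_grad()
#       loss = model.step(list(inputs), labels, subword_lengths=subword_lengths)
#       loss.backward()
#       optimizer.step()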
def _slice_batch(self, xt: List[T.Tensor], _slice):
return [xt[0][_slice], xt[1][_slice], xt[2][_slice]]
def _slim_batch_size(
self,
tx: T.Tensor,
cx: T.Tensor,
yt: T.Tensor,
subword_lengths: T.Tensor
):
#^ tx : [b tt]
#^ cx : [b tw tc]
#^ yt : [b tw tc]
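# Trim trailing padding on the token, word, and character axes down to the
# longest occupied length in the batch, so the forward pass sees smaller tensors.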
token_nonpad_mask = tx.ne(self.tokenizer.pad_token_id)
Ttoken = token_nonpad_mask.sum(1).max()
tx = tx[:, :Ttoken]
char_nonpad_mask = cx.ne(0)
Tword = char_nonpad_mask.any(2).sum(1).max()
Tchar = char_nonpad_mask.sum(2).max()
cx = cx[:, :Tword, :Tchar]
yt = yt[:, :Tword, :Tchar]
subword_lengths = subword_lengths[:, :Tword]
return tx, cx, yt, subword_lengths
def token_dropout(self, toke_x):
#^ toke_x : [b tw]
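# Randomly replace whole tokens with PAD during training (prob = sentence_dropout).
# Note: both this and `sentence_dropout` below are alternative regularizers; the
# calls in `forward` are currently commented out.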
if self.training:
q = 1.0 - self.sent_dropout_p
sdo = T.bernoulli(T.full(toke_x.shape, q, device=toke_x.device))
toke_x[sdo == 0] = self.tokenizer.pad_token_id
return toke_x
def sentence_dropout(self, word_embs: T.Tensor):
#^ word_embs : [b tw dwe]
if self.training:
q = 1.0 - self.sent_dropout_p
sdo = T.bernoulli(T.full(word_embs.shape[:2], q))
sdo = sdo.detach().unsqueeze(-1).to(word_embs)
word_embs = word_embs * sdo
# toke_x[sdo == 0] = self.tokenizer.pad_token_id
return word_embs
def embed_tokens(self, input_ids: T.Tensor, attention_mask: T.Tensor):
y: BaseModelOutputWithPoolingAndCrossAttentions
y = self.token_model(input_ids, attention_mask=attention_mask)
z = y.last_hidden_state
return z
def forward(
self,
toke_x : T.Tensor,
char_x : T.Tensor,
diac_x : T.Tensor,
subword_lengths : T.Tensor,
):
#^ toke_x : [b tt]
#^ char_x : [b tw tc]
#^ diac_x/labels : [b tw tc]
#^ subword_lengths : [b, tw]
# !TODO Use `subword_lengths` to aggregate subword embeddings first before ...
# ... passing concatenated contextual embedding to chars in diac_model
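# Pipeline: (1) encode subword tokens with the pretrained model, (2) mean-pool
# subwords into word vectors using `subword_lengths`, (3) pair each word vector
# with the embeddings of its characters, (4) run the character-level diac model
# over each word, (5) classify every character into one of `num_classes` labels.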
token_nonpad_mask = toke_x.ne(self.tokenizer.pad_token_id)
char_nonpad_mask = char_x.ne(0)
Nb, Tw, Tc = char_x.shape
# assert Tw == Tw_0 and Tc == Tc_0, f"{Tw=} {Tw_0=}, {Tc=} {Tc_0=}"
# toke_x = self.token_dropout(toke_x)
token_embs = self.embed_tokens(toke_x, attention_mask=token_nonpad_mask)
# token_embs = self.sentence_dropout(token_embs)
#? Strip BOS,EOS
token_embs = token_embs[:, 1:-1, ...]
sent_word_strides = subword_lengths.cumsum(1)
sent_enc: T.Tensor = T.zeros(Nb, Tw, token_embs.shape[-1]).to(token_embs)
for i_b in range(Nb):
token_embs_ib = token_embs[i_b]
start_iw = 0
for i_word, end_iw in enumerate(sent_word_strides[i_b]):
if end_iw == start_iw: break
word_emb = token_embs_ib[start_iw : end_iw].sum(0) / (end_iw - start_iw)
sent_enc[i_b, i_word] = word_emb
start_iw = end_iw
#^ sent_enc: [b tw dwe]
char_x_flat = char_x.reshape(Nb*Tw, Tc)
char_nonpad_mask = char_x_flat.gt(0)
# ^ char_nonpad_mask [b*tw tc]
char_x_flat = char_x_flat * char_nonpad_mask
cembs = self.char_embs(char_x_flat)
#^ cembs: [b*tw tc dce]
wembs = sent_enc.unsqueeze(-2).expand(Nb, Tw, Tc, -1).view(Nb*Tw, Tc, -1)
#^ wembs: [b tw dwe] => [b tw _ dwe] => [b*tw tc dwe]
cw_embs = T.cat([cembs, wembs], dim=-1)
#^ char_embs : [b*tw tc dcw] ; dcw = dc + dwe
cw_embs = self.dropout(cw_embs)
cw_embs_ = cw_embs
if self.down_project_token_embeds is not None:
cw_embs_ = self.down_project_token_embeds(cw_embs)
if self.down_project_token_embeds_deep is not None:
cw_embs_ = cw_embs_ + self.down_project_token_embeds_deep(cw_embs)
cw_embs = cw_embs_
diac_enc: BaseModelOutputWithPoolingAndCrossAttentions
diac_enc = self.diac_emb_model(inputs_embeds=cw_embs, attention_mask=char_nonpad_mask)
diac_emb = diac_enc.last_hidden_state
diac_emb = self.dropout(diac_emb)
#^ diac_emb: [b*tw tc dce]
diac_emb = diac_emb.view(Nb, Tw, Tc, -1)
sent_residual = sent_enc.unsqueeze(2).expand(-1, -1, Tc, -1)
final_feature = T.cat([sent_residual, diac_emb], dim=-1)
if self.final_feature_transform is not None:
final_feature = self.final_feature_transform(final_feature)
final_feature = T.tanh(final_feature)
final_feature = self.dropout(final_feature)
else:
final_feature = diac_emb
# final_feature = self.feature_layer_norm(final_feature)
diac_out = self.classifier(final_feature)
# if T.isnan(diac_out).any():
# breakpoint()
return diac_out
def predict(self, dataloader):
from tqdm import tqdm
import diac_utils as du
training = self.training
self.eval()
preds = {'haraka': [], 'shadda': [], 'tanween': []}
print("> Predicting...")
for inputs, _, subword_lengths in tqdm(dataloader, total=len(dataloader)):
inputs[0] = inputs[0].to(self.device)
inputs[1] = inputs[1].to(self.device)
output = self(*inputs, subword_lengths).detach()
marks = np.argmax(T.softmax(output, dim=-1).cpu().numpy(), axis=-1)
#^ marks: [b tw tc]
haraka, tanween, shadda = du.flat_2_3head(marks)
preds['haraka'].extend(haraka)
preds['tanween'].extend(tanween)
preds['shadda'].extend(shadda)
self.train(training)
return (
np.array(preds['haraka']),
np.array(preds["tanween"]),
np.array(preds["shadda"]),
)
if __name__ == "__main__":
model = Diacritizer({
"num-chars": 36,
"hidden_size": 768,
"char-embed-dim": 32,
"dropout": 0.25,
"sentence_dropout": 0.2,
"diac_model_config": {
"num_layers": 4,
"hidden_size": 768 + 32,
"intermediate_size": (768 + 32) * 4,
},
}, load_pretrained=False)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(model)
print(f"{trainable_params:,}/{total_params:,} Trainable Parameters")