import numpy as np
import torch as T
from tqdm import tqdm
from torch import nn
from torch.nn import functional as F

from components.k_lstm import K_LSTM
from components.attention import Attention
from data_utils import DatasetUtils
from diac_utils import flat2_3head, flat_2_3head


class DiacritizerD2(nn.Module):
    """Sentence/word hierarchical diacritizer with a per-character linear classifier.

    A bidirectional sentence LSTM encodes word embeddings; its output is broadcast
    to every character of each word, concatenated with character embeddings, and
    fed to a bidirectional word (character-level) LSTM. A dot attention over the
    sentence encoding (excluding each word's own position) is concatenated with
    the character encoding and classified into ``num_classes`` labels.

    Call ``build`` after construction to create the sub-modules.
    """

    def __init__(self, config, device=None):
        """Read hyper-parameters from ``config["train"]``.

        Args:
            config: parsed configuration mapping; only its ``"train"`` section is read.
            device: device string used for tensors created during the forward pass.
                ``None`` (default) selects ``'cuda'`` when available, else ``'cpu'``.
        """
        super().__init__()
        self.max_word_len = config["train"]["max-word-len"]
        self.max_sent_len = config["train"]["max-sent-len"]
        self.char_embed_dim = config["train"]["char-embed-dim"]

        self.final_dropout_p = config["train"]["final-dropout"]
        self.sent_dropout_p = config["train"]["sent-dropout"]
        self.diac_dropout_p = config["train"]["diac-dropout"]
        self.vertical_dropout = config['train']['vertical-dropout']
        self.recurrent_dropout = config['train']['recurrent-dropout']
        self.recurrent_dropout_mode = config['train'].get('recurrent-dropout-mode', 'gal_tied')
        self.recurrent_activation = config['train'].get('recurrent-activation', 'sigmoid')

        self.sent_lstm_units = config["train"]["sent-lstm-units"]
        self.word_lstm_units = config["train"]["word-lstm-units"]
        self.decoder_units = config["train"]["decoder-units"]

        self.sent_lstm_layers = config["train"]["sent-lstm-layers"]
        self.word_lstm_layers = config["train"]["word-lstm-layers"]

        self.cell = config['train'].get('rnn-cell', 'lstm')
        self.num_layers = config["train"].get("num-layers", 2)
        self.RNN_Layer = K_LSTM

        self.batch_first = config['train'].get('batch-first', True)
        if device is None:
            device = 'cuda' if T.cuda.is_available() else 'cpu'
        self.device = device
        self.num_classes = 15

    def build(self, wembs: T.Tensor, abjad_size: int):
        """Create all sub-modules.

        Args:
            wembs: word-embedding matrix (tensor or array-like) indexed by the token
                ids in ``sents``; the sentence LSTM is built with ``input_size=300``,
                so rows are expected to be 300-dimensional.
            abjad_size: number of character ids (rows of the character embedding).
        """
        self.closs = F.cross_entropy
        self.bloss = F.binary_cross_entropy_with_logits

        rnn_kargs = dict(
            recurrent_dropout_mode=self.recurrent_dropout_mode,
            recurrent_activation=self.recurrent_activation,
        )

        self.sent_lstm = self.RNN_Layer(
            input_size=300,
            hidden_size=self.sent_lstm_units,
            num_layers=self.sent_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            **rnn_kargs,
        )

        self.word_lstm = self.RNN_Layer(
            input_size=self.sent_lstm_units * 2 + self.char_embed_dim,
            hidden_size=self.word_lstm_units,
            num_layers=self.word_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kargs,
        )

        self.char_embs = nn.Embedding(
            abjad_size,
            self.char_embed_dim,
            padding_idx=0,
        )

        self.attention = Attention(
            kind="dot",
            query_dim=self.word_lstm_units * 2,
            input_dim=self.sent_lstm_units * 2,
        )

        # Plain attribute, not a registered buffer: it is absent from state_dict
        # and is not moved by Module.to(), hence the explicit device move here.
        # as_tensor(...).detach().clone() always copies, like T.tensor did, without
        # the copy-construct warning when ``wembs`` is already a tensor.
        self.word_embs = T.as_tensor(wembs, dtype=T.float32).detach().clone()
        self.word_embs = self.word_embs.to(self.device)

        self.classifier = nn.Linear(self.attention.Dout + self.word_lstm_units * 2, self.num_classes)
        self.dropout = nn.Dropout(self.final_dropout_p)

    def forward(self, sents, words, labels=None, subword_lengths=None):
        """Compute per-character class logits.

        Args:
            sents: token ids, [b ts].
            words: character ids, [b ts tw]; id 0 is treated as padding.
            labels, subword_lengths: unused by this model.

        Returns:
            Logits of shape [b ts tw num_classes] (axes 0 and 1 swapped when
            ``batch_first`` is False).
        """
        #^ sents : [b ts]
        #^ words : [b ts tw]
        max_words = min(self.max_sent_len, sents.shape[1])

        word_mask = words.ne(0.).float()
        #^ word_mask: [b ts tw]

        if self.training:
            # Word-level dropout: dropped positions have their token id replaced by 0.
            # The mask is created on sents' device so the product works on CPU or GPU.
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(sents.shape, q, device=sents.device))
            sents_do = sents * sdo.long()
            #^ sents_do : [b ts] ; DO(ts)
            wembs = self.word_embs[sents_do]
            #^ wembs : [b ts dw] ; DO(ts)
        else:
            wembs = self.word_embs[sents]
            #^ wembs : [b ts dw]

        sent_enc = self.sent_lstm(wembs.to(self.device))
        #^ sent_enc : [b ts dwe]

        # Broadcast each word's sentence encoding to its (non-padding) characters.
        sentword_do = sent_enc.unsqueeze(2)
        #^ sentword_do : [b ts _ dwe]
        sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
        #^ sentword_do : [b ts tw dwe]

        word_index = words.view(-1, self.max_word_len)
        #^ word_index: [b*ts tw]
        cembs = self.char_embs(word_index)
        #^ cembs : [b*ts tw dc]
        sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
        #^ sentword_do : [b*ts tw dwe]
        char_embs = T.cat([cembs, sentword_do], dim=-1)
        #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe

        char_enc, _ = self.word_lstm(char_embs)
        #^ char_enc: [b*ts tw dce]

        char_enc_reshaped = char_enc.view(-1, max_words, self.max_word_len, self.word_lstm_units * 2)
        #^ char_enc_reshaped: [b ts tw dce]

        # Zero diagonal: a word does not attend to its own sentence position.
        omit_self_mask = (1.0 - T.eye(max_words)).unsqueeze(0).to(self.device)

        attn_enc, _ = self.attention(char_enc_reshaped, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        #^ attn_enc: [b ts tw dae]
        attn_enc = attn_enc.reshape(-1, self.max_word_len, self.attention.Dout)
        #^ attn_enc: [b*ts tw dae]

        final_vec = T.cat([attn_enc, char_enc], dim=-1)
        diac_out = self.classifier(self.dropout(final_vec))
        #^ diac_out: [b*ts tw num_classes]

        diac_out = diac_out.view(-1, max_words, self.max_word_len, self.num_classes)
        #^ diac_out: [b ts tw num_classes]

        if not self.batch_first:
            diac_out = diac_out.swapaxes(1, 0)

        return diac_out

    def step(self, xt, yt, mask=None):
        """Return the cross-entropy loss for one batch.

        Args:
            xt: list of forward inputs; ``xt[1]`` and ``xt[2]`` are moved to
                ``self.device`` in place (the list is mutated).
            yt: target class ids, [b ts tw].
            mask: unused.
        """
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)

        yt = yt.to(self.device)
        #^ yt: [b ts tw]

        # forward() returns a single logits tensor (no attention map), so it must
        # not be tuple-unpacked here.
        diac = self(*xt)
        loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))

        return loss

    def predict(self, dataloader):
        """Predict diacritics for every batch in ``dataloader``.

        Runs in eval mode without gradient tracking and restores the previous
        training mode afterwards, even if an error occurs.

        Returns:
            Tuple of numpy arrays ``(haraka, tanween, shadda)`` as produced by
            ``flat_2_3head`` for each batch, concatenated along the batch axis.
        """
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        try:
            with T.no_grad():
                for inputs, _ in tqdm(dataloader, total=len(dataloader)):
                    inputs[0] = inputs[0].to(self.device)
                    inputs[1] = inputs[1].to(self.device)
                    diac = self(*inputs)
                    output = np.argmax(T.softmax(diac.detach(), dim=-1).cpu().numpy(), axis=-1)
                    #^ [b ts tw]

                    haraka, tanween, shadda = flat_2_3head(output)

                    preds['haraka'].extend(haraka)
                    preds['tanween'].extend(tanween)
                    preds['shadda'].extend(shadda)
        finally:
            self.train(training)

        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )


class DiacritizerD3(nn.Module):
    """Hierarchical diacritizer with an autoregressive character-level LSTM decoder.

    Shares the sentence/word encoders and attention of ``DiacritizerD2``, then runs
    a unidirectional LSTM decoder whose input additionally includes an 8-dim label
    vector. ``forward`` consumes label vectors supplied by the caller; inference
    (``predict_sample``) feeds back the decoder's own previous prediction as a
    6-way haraka one-hot plus tanween and shadda flags.
    """

    def __init__(self, config, device='cuda'):
        """Read hyper-parameters from ``config["train"]``.

        Args:
            config: parsed configuration mapping; only its ``"train"`` section is read.
            device: device string used for tensors created during the forward pass.
        """
        super().__init__()
        self.max_word_len = config["train"]["max-word-len"]
        self.max_sent_len = config["train"]["max-sent-len"]
        self.char_embed_dim = config["train"]["char-embed-dim"]

        self.sent_dropout_p = config["train"]["sent-dropout"]
        self.diac_dropout_p = config["train"]["diac-dropout"]
        self.vertical_dropout = config['train']['vertical-dropout']
        self.recurrent_dropout = config['train']['recurrent-dropout']
        self.recurrent_dropout_mode = config['train'].get('recurrent-dropout-mode', 'gal_tied')
        self.recurrent_activation = config['train'].get('recurrent-activation', 'sigmoid')

        self.sent_lstm_units = config["train"]["sent-lstm-units"]
        self.word_lstm_units = config["train"]["word-lstm-units"]
        self.decoder_units = config["train"]["decoder-units"]

        self.sent_lstm_layers = config["train"]["sent-lstm-layers"]
        self.word_lstm_layers = config["train"]["word-lstm-layers"]

        self.cell = config['train'].get('rnn-cell', 'lstm')
        self.num_layers = config["train"].get("num-layers", 2)
        self.RNN_Layer = K_LSTM

        self.batch_first = config['train'].get('batch-first', True)
        self.baseline = config["train"].get("baseline", False)
        self.device = device
        self.num_classes = 15

    def build(self, wembs: T.Tensor, abjad_size: int):
        """Create all sub-modules.

        Args:
            wembs: word-embedding matrix (tensor or array-like) indexed by the token
                ids in ``sents``; the sentence LSTM is built with ``input_size=300``.
            abjad_size: number of character ids (rows of the character embedding).
        """
        self.closs = F.cross_entropy
        self.bloss = F.binary_cross_entropy_with_logits

        rnn_kargs = dict(
            recurrent_dropout_mode=self.recurrent_dropout_mode,
            recurrent_activation=self.recurrent_activation,
        )

        self.sent_lstm = self.RNN_Layer(
            input_size=300,
            hidden_size=self.sent_lstm_units,
            num_layers=self.sent_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            **rnn_kargs,
        )

        self.word_lstm = self.RNN_Layer(
            input_size=self.sent_lstm_units * 2 + self.char_embed_dim,
            hidden_size=self.word_lstm_units,
            num_layers=self.word_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kargs,
        )

        self.char_embs = nn.Embedding(
            abjad_size,
            self.char_embed_dim,
            padding_idx=0,
        )

        self.attention = Attention(
            kind="dot",
            query_dim=self.word_lstm_units * 2,
            input_dim=self.sent_lstm_units * 2,
        )

        # Input = character encoding + attention vector + 8-dim label vector.
        self.lstm_decoder = self.RNN_Layer(
            input_size=self.word_lstm_units * 2 + self.attention.Dout + 8,
            hidden_size=self.word_lstm_units * 2,
            num_layers=1,
            bidirectional=False,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kargs,
        )

        # Kept on its creation device (not a registered buffer); looked-up rows
        # are moved to self.device in the forward pass.
        self.word_embs = T.as_tensor(wembs, dtype=T.float32).detach().clone()

        self.classifier = nn.Linear(self.lstm_decoder.hidden_size, self.num_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, sents, words, labels):
        """Teacher-forced forward pass.

        Args:
            sents: token ids, [b ts].
            words: character ids, [b ts tw]; id 0 is treated as padding.
            labels: decoder label inputs; the ``view(..., 8)`` below requires
                [b ts tw 8] (presumably previous-step label vectors supplied by the
                data loader — TODO confirm).

        Returns:
            ``(logits, attn_map)`` with logits [b ts tw num_classes] (axes 0 and 1
            swapped when ``batch_first`` is False).

        NOTE(review): here the decoder runs over the flattened ts*tw sequence, so its
        state carries across words, whereas ``predict_sample`` resets the decoder
        state at every word — confirm this train/inference difference is intended.
        """
        #^ sents : [b ts]
        #^ words : [b ts tw]
        word_mask = words.ne(0.).float()
        #^ word_mask: [b ts tw]

        if self.training:
            # Word-level dropout; mask created on sents' device so it works on CPU or GPU.
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(sents.shape, q, device=sents.device))
            sents_do = sents * sdo.long()
            #^ sents_do : [b ts] ; DO(ts)
            wembs = self.word_embs[sents_do]
            #^ wembs : [b ts dw] ; DO(ts)
        else:
            wembs = self.word_embs[sents]
            #^ wembs : [b ts dw]

        sent_enc = self.sent_lstm(wembs.to(self.device))
        #^ sent_enc : [b ts dwe]

        sentword_do = sent_enc.unsqueeze(2)
        #^ sentword_do : [b ts _ dwe]
        sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
        #^ sentword_do : [b ts tw dwe]

        word_index = words.view(-1, self.max_word_len)
        #^ word_index: [b*ts tw]
        cembs = self.char_embs(word_index)
        #^ cembs : [b*ts tw dc]
        sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
        #^ sentword_do : [b*ts tw dwe]
        char_embs = T.cat([cembs, sentword_do], dim=-1)
        #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe

        char_enc, _ = self.word_lstm(char_embs)
        #^ char_enc: [b*ts tw dce]

        char_enc_reshaped = char_enc.view(-1, self.max_sent_len, self.max_word_len, self.word_lstm_units * 2)
        #^ char_enc_reshaped: [b ts tw dce]

        omit_self_mask = (1.0 - T.eye(self.max_sent_len)).unsqueeze(0).to(self.device)

        attn_enc, attn_map = self.attention(char_enc_reshaped, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        #^ attn_enc: [b ts tw dae]
        attn_enc = attn_enc.view(-1, self.max_sent_len * self.max_word_len, self.attention.Dout)
        #^ attn_enc: [b ts*tw dae]

        if self.training and self.diac_dropout_p > 0:
            # Label dropout: zero the whole label vector of randomly chosen characters.
            q = 1.0 - self.diac_dropout_p
            ddo = T.bernoulli(T.full(labels.shape[:-1], q))
            labels = labels * ddo.unsqueeze(-1).long().to(self.device)

        labels = labels.view(-1, self.max_sent_len * self.max_word_len, 8).float()
        #^ labels: [b ts*tw 8]

        char_enc = char_enc.view(-1, self.max_sent_len * self.max_word_len, self.word_lstm_units * 2)

        final_vec = T.cat([attn_enc, char_enc, labels], dim=-1)
        #^ final_vec: [b ts*tw dae+dce+8]

        dec_out, _ = self.lstm_decoder(final_vec)
        #^ dec_out: [b ts*tw du]

        dec_out = dec_out.reshape(-1, self.max_word_len, self.lstm_decoder.hidden_size)
        diac_out = self.classifier(self.dropout(dec_out))
        #^ diac_out: [b*ts tw num_classes]

        diac_out = diac_out.view(-1, self.max_sent_len, self.max_word_len, self.num_classes)
        #^ diac_out: [b ts tw num_classes]

        if not self.batch_first:
            diac_out = diac_out.swapaxes(1, 0)

        return diac_out, attn_map

    def predict_sample(self, sents, words, labels):
        """Greedy autoregressive decoding, one character at a time.

        The decoder state is reset to zeros at the start of every word. The label
        input starts as a fixed BOS vector and is then replaced by the previous
        character's predicted (haraka one-hot, tanween, shadda). ``labels`` is unused.

        Returns:
            Logits of shape [b ts tw num_classes] (axes 0 and 1 swapped when
            ``batch_first`` is False).
        """
        word_mask = words.ne(0.).float()
        #^ word_mask: [b ts tw]

        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(sents.shape, q, device=sents.device))
            sents_do = sents * sdo.long()
            #^ sents_do : [b ts] ; DO(ts)
            wembs = self.word_embs[sents_do]
            #^ wembs : [b ts dw] ; DO(ts)
        else:
            wembs = self.word_embs[sents]
            #^ wembs : [b ts dw]

        sent_enc = self.sent_lstm(wembs.to(self.device))
        #^ sent_enc : [b ts dwe]

        sentword_do = sent_enc.unsqueeze(2)
        #^ sentword_do : [b ts _ dwe]
        sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
        #^ sentword_do : [b ts tw dwe]

        word_index = words.view(-1, self.max_word_len)
        #^ word_index: [b*ts tw]
        cembs = self.char_embs(word_index)
        #^ cembs : [b*ts tw dc]
        sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
        #^ sentword_do : [b*ts tw dwe]
        char_embs = T.cat([cembs, sentword_do], dim=-1)
        #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe

        char_enc, _ = self.word_lstm(char_embs)
        #^ char_enc: [b*ts tw dce]
        char_enc = char_enc.view(-1, self.max_sent_len, self.max_word_len, self.word_lstm_units * 2)
        #^ char_enc: [b ts tw dce]

        omit_self_mask = (1.0 - T.eye(self.max_sent_len)).unsqueeze(0).to(self.device)
        attn_enc, _ = self.attention(char_enc, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        #^ attn_enc: [b ts tw dae]

        all_out = T.zeros(*char_enc.size()[:-1], self.num_classes).to(self.device)
        #^ all_out: [b ts tw num_classes]

        batch_sz = char_enc.size()[0]
        hidden = self.lstm_decoder.hidden_size
        zeros = T.zeros(1, batch_sz, hidden).to(self.device)
        #^ zeros: [1 b du]

        bos_tag = T.tensor([0, 0, 0, 0, 0, 1, 0, 0]).unsqueeze(0)
        #^ bos_tag: [1 8]
        prev_label = T.cat([bos_tag] * batch_sz).to(self.device).float()
        #^ prev_label: [b 8]

        for ts in range(self.max_sent_len):
            dec_hx = (zeros, zeros)
            for tw in range(self.max_word_len):
                final_vec = T.cat([attn_enc[:, ts, tw, :], char_enc[:, ts, tw, :], prev_label], dim=-1).unsqueeze(1)
                #^ final_vec: [b 1 dae+dce+8]
                dec_out, dec_hx = self.lstm_decoder(final_vec, dec_hx)
                #^ dec_out: [b 1 du]

                # Flatten explicitly to [b du]; squeeze/transpose would collapse the
                # batch axis when b == 1.
                dec_out = dec_out.reshape(batch_sz, hidden)
                logits_raw = self.classifier(self.dropout(dec_out))
                #^ logits_raw: [b num_classes]

                out_idx = T.max(T.softmax(logits_raw, dim=-1), dim=-1)[1]
                #^ out_idx: [b]
                haraka, tanween, shadda = flat2_3head(out_idx.detach().cpu().numpy())

                # Next step's label input: 6-way haraka one-hot + tanween + shadda = 8 dims.
                haraka_onehot = T.eye(6)[haraka].float().to(self.device)
                #^ haraka_onehot: [b 6]
                tanween = T.tensor(tanween).float().unsqueeze(-1).to(self.device)
                shadda = T.tensor(shadda).float().unsqueeze(-1).to(self.device)
                prev_label = T.cat([haraka_onehot, tanween, shadda], dim=-1)

                all_out[:, ts, tw, :] = logits_raw

        if not self.batch_first:
            all_out = all_out.swapaxes(1, 0)

        return all_out

    def step(self, xt, yt, mask=None):
        """Return the cross-entropy loss for one batch.

        Uses the teacher-forced ``forward`` while training and autoregressive
        ``predict_sample`` otherwise. ``xt[1]`` and ``xt[2]`` are moved to
        ``self.device`` in place (the list is mutated). ``mask`` is unused.
        """
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)
        yt = yt.to(self.device)
        #^ yt: [b ts tw]

        if self.training:
            diac, _ = self(*xt)
        else:
            diac = self.predict_sample(*xt)
        #^ diac: [b ts tw num_classes]

        loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))
        return loss

    def predict(self, dataloader):
        """Autoregressively predict diacritics for every batch in ``dataloader``.

        Runs in eval mode without gradient tracking and restores the previous
        training mode afterwards, even if an error occurs.

        Returns:
            Tuple of numpy arrays ``(haraka, tanween, shadda)`` as produced by
            ``flat_2_3head`` for each batch, concatenated along the batch axis.
        """
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        try:
            with T.no_grad():
                for inputs, _ in tqdm(dataloader, total=len(dataloader)):
                    inputs[1] = inputs[1].to(self.device)
                    inputs[2] = inputs[2].to(self.device)
                    diac = self.predict_sample(*inputs)
                    output = np.argmax(T.softmax(diac.detach(), dim=-1).cpu().numpy(), axis=-1)
                    #^ [b ts tw]

                    haraka, tanween, shadda = flat_2_3head(output)

                    preds['haraka'].extend(haraka)
                    preds['tanween'].extend(tanween)
                    preds['shadda'].extend(shadda)
        finally:
            self.train(training)

        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )


if __name__ == "__main__":
    import yaml

    config_path = "configs/dd/config_d2.yaml"
    model_path = "models/tashkeela-d2.pt"

    with open(config_path, 'r', encoding="utf-8") as file:
        # NOTE(review): FullLoader can construct some Python objects; yaml.safe_load
        # is sufficient for a plain config if this file may come from elsewhere.
        config = yaml.load(file, Loader=yaml.FullLoader)

    data_utils = DatasetUtils(config)
    vocab_size = len(data_utils.letter_list)
    word_embeddings = data_utils.embeddings

    model = DiacritizerD2(config, device='cpu')
    model.build(word_embeddings, vocab_size)
    model.load_state_dict(T.load(model_path, map_location=T.device('cpu'))["state_dict"])