from typing import Iterable, Union, Tuple
from collections import Counter
import argparse
import os

import numpy as np
import torch as T
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm
from pyarabic.araby import tokenize, strip_tatweel, strip_tashkeel

from diac_utils import HARAKAT_MAP, shakkel_char, flat2_3head
from model_partial import PartialDD
from model_dd import DiacritizerD2
from data_utils import DatasetUtils
from dataloader import DataRetriever
from segment import segment
from partial_dd_metrics import (
    parse_data,
    load_data,
    make_mask_hard,
    make_mask_logits,
)

def apply_tashkeel(
    line: str,
    diacs: Union[np.ndarray, T.Tensor],
):
    # Split each flat diacritic id into its three heads
    # (haraka, tanween, shadda) before interleaving with the text.
    line_w_diacs = ""
    ts, tw = diacs.shape
    diacs = diacs.flatten()
    diacs_h3 = flat2_3head(diacs)
    diacs_h3 = tuple(x.reshape(ts, tw) for x in diacs_h3)

    diac_char_idx = 0
    diac_word_idx = 0
    for ch in line:
        line_w_diacs += ch
        if ch == " ":
            # Word boundary: move on to the next word's predictions.
            diac_char_idx = 0
            diac_word_idx += 1
        else:
            tashkeel = (
                diacs_h3[0][diac_word_idx][diac_char_idx],
                diacs_h3[1][diac_word_idx][diac_char_idx],
                diacs_h3[2][diac_word_idx][diac_char_idx],
            )
            diac_char_idx += 1
            line_w_diacs += shakkel_char(*tashkeel)
    return line_w_diacs

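# A hypothetical usage sketch for apply_tashkeel (shapes are assumptions
# inferred from the code above, not project guarantees): `diacs` carries one
# flat diacritic id per character, laid out as (num_words, max_word_len), and
# `line` must already be space-tokenized so that its k-th word aligns with
# row k of `diacs`:
#
#   line = "كتب ولد"                       # undiacritized, space-tokenized
#   diacs = np.zeros((2, 13), dtype=int)   # all-zero ids: "no diacritic"
#   diacritized = apply_tashkeel(line, diacs)
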
def diac_text(
    data,
    model_output_base,
    model_output_ctxt,
    selection_mode='contrastive-hard',
    threshold=0.1,
):
    mode = selection_mode
    # Keep the contextual model's prediction wherever the selection mask
    # fires; everywhere else emit 0 ("no diacritic").
    if mode == 'contrastive-hard':
        diacritics = np.where(
            make_mask_hard(model_output_ctxt, model_output_base),
            model_output_ctxt.argmax(-1),
            0,
        ).astype(int)
    else:
        # Any other mode string is forwarded to make_mask_logits as its
        # `version`, together with the confidence threshold.
        diacritics = np.where(
            make_mask_logits(
                model_output_ctxt, model_output_base,
                version=mode, threshold=threshold,
            ),
            model_output_ctxt.argmax(-1),
            0,
        ).astype(int)

    assert len(diacritics) == len(data)

    # Normalize the input lines: strip any existing diacritics and tatweel,
    # then re-join on single spaces so characters align with the predictions.
    data = [
        ' '.join(tokenize(
            line.strip(),
            morphs=[strip_tashkeel, strip_tatweel],
        ))
        for line in data
    ]

    output = []
    for line, line_diacs in zip(tqdm(data), diacritics):
        output.append(apply_tashkeel(line, line_diacs))
    return output

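# Shape sketch (assumed from the argmax/zip usage above, not verified against
# the model code): both model outputs are logit arrays of shape
# (num_lines, max_words, max_word_len, num_classes), so `diacritics` ends up
# (num_lines, max_words, max_word_len) and each row feeds apply_tashkeel.
# A hypothetical call:
#
#   lines_out = diac_text(
#       raw_lines,    # list[str], one entry per line
#       base_logits,  # context-free model output
#       ctxt_logits,  # contextual model output
#       selection_mode='contrastive-hard',
#   )
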
class Predictor:
    def __init__(self, config):
        self.data_utils = DatasetUtils(config)
        vocab_size = len(self.data_utils.letter_list)
        word_embeddings = self.data_utils.embeddings
        self.config = config

        self.device = T.device(
            config['predictor'].get('device', 'cuda:0')
            if T.cuda.is_available() else 'cpu'
        )

        # Build the partial diacritization model and restore its weights.
        self.model = PartialDD(config)
        self.model.sentence_diac.build(word_embeddings, vocab_size)
        state_dict = T.load(
            config["paths"]["load"],
            map_location=self.device,
        )['state_dict']
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def create_dataloader(self, text, do_partial, do_hard_mask, threshold):
        self.threshold = threshold
        self.do_hard_mask = do_hard_mask

        stride = self.config["segment"]["stride"]
        window = self.config["segment"]["window"]
        min_window = self.config["segment"]["min-window"]
        if self.do_hard_mask or not do_partial:
            # Hard-mask / full diacritization: split the text into overlapping
            # segments and keep the (sentence, segment, word, char) mapping so
            # per-character votes can be coalesced afterwards.
            segments, mapping = segment([text], stride, window, min_window)

            mapping_lines = []
            for sent_idx, seg_idx, word_idx, char_idx in mapping:
                mapping_lines += [f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}"]

            self.mapping = self.data_utils.load_mapping_v3_from_list(mapping_lines)
            self.original_lines = [text]
            self.segments = segments
        else:
            # Partial mode: one segment per input line, no overlap bookkeeping.
            segments = text.split('\n')
            self.segments = segments
            self.original_lines = text.split('\n')

        self.data_loader = DataLoader(
            DataRetriever(self.data_utils, segments),
            batch_size=self.config["predictor"].get("batch-size", 32),
            shuffle=False,
            num_workers=self.config['loader'].get('num-workers', 0),
        )

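# A minimal sketch of the YAML consumed by Predictor above; the keys are the
# ones actually read in this file, the values are illustrative assumptions:
#
#   predictor:
#     device: cuda:0
#     batch-size: 32
#   paths:
#     load: path/to/checkpoint.pt
#   segment:
#     stride: 2
#     window: 10
#     min-window: 1
#   loader:
#     num-workers: 0
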
class PredictTri(Predictor):
    def __init__(self, config):
        super().__init__(config)
        self.diacritics = {
            "FATHA": 1,
            "KASRA": 2,
            "DAMMA": 3,
            "SUKUN": 4,
        }
        self.votes: Union[Counter[int], Counter[bool]] = Counter()

    def count_votes(
        self,
        things: Union[Iterable[int], Iterable[bool]],
    ):
        # Majority vote over per-segment predictions,
        # e.g. count_votes([1, 1, 4]) -> 1.
        self.votes.clear()
        self.votes.update(things)
        return self.votes.most_common(1)[0][0]

    def predict_majority_vote(self):
        y_gen_diac, y_gen_tanween, y_gen_shadda = self.model.predict(self.data_loader)
        diacritized_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
        return diacritized_lines

    def predict_partial(self, do_partial, lines):
        outputs = self.model.predict_partial(
            self.data_loader,
            return_extra=True,
            eval_only='both',
            do_partial=do_partial,
        )

        if self.do_hard_mask or not do_partial:
            y_gen_diac, y_gen_tanween, y_gen_shadda = outputs['diacritics']
            diac_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
        else:
            # outputs["other"] appears to hold (contextual, base) logits;
            # diac_text expects (base, contextual).
            diac_lines = diac_text(
                lines,
                outputs["other"][1],
                outputs["other"][0],
                selection_mode='1',
                threshold=self.threshold,
            )

        return '\n'.join(diac_lines)

    def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
        assert isinstance(self.model, PartialDD)
        # Cache the segment-level predictions; recompute only when the cache
        # is missing or explicitly overwritten.
        cache_path = "dataset/cache/cache.pt"
        if overwrite_cache or not os.path.exists(cache_path):
            os.makedirs("dataset/cache", exist_ok=True)
            segment_outputs = self.model.predict_partial(
                self.data_loader,
                return_extra=False,
                eval_only='both',
            )
            T.save(segment_outputs, cache_path)
        else:
            segment_outputs = T.load(cache_path)

        y_gen_diac, y_gen_tanween, y_gen_shadda = segment_outputs['diacritics']
        diacritized_lines, extra_for_lines = self.coalesce_votes_by_majority(
            y_gen_diac, y_gen_tanween, y_gen_shadda,
        )
        extra_out = {
            'line_data': {
                **extra_for_lines,
            },
            'segment_data': {
                **segment_outputs,
            },
        }

        return diacritized_lines, extra_out

    def coalesce_votes_by_majority(
        self,
        y_gen_diac: np.ndarray,
        y_gen_tanween: np.ndarray,
        y_gen_shadda: np.ndarray,
    ):
        prepped_lines_og = [' '.join(tokenize(strip_tatweel(line))) for line in self.original_lines]
        max_line_chars = max(len(line) for line in prepped_lines_og)
        # -1 marks characters never covered by any segment.
        diacritics_pred = np.full((len(self.original_lines), max_line_chars), fill_value=-1, dtype=int)

        count_processed_sents = 0
        do_break = False
        diacritized_lines = []
        for sent_idx, line in enumerate(tqdm(prepped_lines_og)):
            count_processed_sents = sent_idx + 1
            line = line.strip()
            diacritized_line = ""
            for char_idx, char in enumerate(line):
                diacritized_line += char
                char_vote_diacritic = []

                if sent_idx not in self.mapping:
                    continue

                # Collect one vote per overlapping segment covering this
                # character, then keep the majority.
                mapping_s_i = self.mapping[sent_idx]
                for seg_idx in mapping_s_i:
                    if self.data_utils.debug and seg_idx >= 256:
                        do_break = True
                        break

                    mapping_g_i = mapping_s_i[seg_idx]
                    for t_idx in mapping_g_i:
                        mapping_t_i = mapping_g_i[t_idx]
                        if char_idx in mapping_t_i:
                            c_idx = mapping_t_i.index(char_idx)
                            output_idx = np.s_[seg_idx, t_idx, c_idx]
                            diac_h3 = (y_gen_diac[output_idx], y_gen_tanween[output_idx], y_gen_shadda[output_idx])
                            diac_char_i = HARAKAT_MAP.index(diac_h3)
                            # Ignore out-of-range positions (13 appears to be
                            # the model's maximum word length) and "no
                            # diacritic" predictions.
                            if c_idx < 13 and diac_char_i != 0:
                                char_vote_diacritic.append(diac_char_i)

                if do_break:
                    break
                if len(char_vote_diacritic) > 0:
                    char_mv_diac = self.count_votes(char_vote_diacritic)
                    diacritized_line += shakkel_char(*HARAKAT_MAP[char_mv_diac])
                    diacritics_pred[sent_idx, char_idx] = char_mv_diac
                else:
                    diacritics_pred[sent_idx, char_idx] = 0
            if do_break:
                break

            diacritized_lines += [diacritized_line.strip()]

        print(f'[INFO] Cutting stats from {len(diacritics_pred)} to {count_processed_sents}')
        extra = {
            'diac_pred': diacritics_pred[:count_processed_sents],
        }
        return diacritized_lines, extra
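
if __name__ == "__main__":
    # Hypothetical entry point, sketched for illustration only: the flag
    # names and file handling here are assumptions, not the project's
    # actual CLI.
    parser = argparse.ArgumentParser(description="Partial diacritization")
    parser.add_argument("--config", required=True, help="path to a YAML config")
    parser.add_argument("--input", required=True, help="path to an input text file")
    parser.add_argument("--partial", action="store_true")
    parser.add_argument("--hard-mask", action="store_true")
    parser.add_argument("--threshold", type=float, default=0.1)
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as fp:
        config = yaml.safe_load(fp)
    with open(args.input, encoding="utf-8") as fp:
        text = fp.read().strip()

    predictor = PredictTri(config)
    predictor.create_dataloader(text, args.partial, args.hard_mask, args.threshold)
    print(predictor.predict_partial(args.partial, text.split('\n')))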