from typing import Iterable, Union, Tuple
from collections import Counter
import argparse
import os

import yaml
from pyarabic.araby import tokenize, strip_tatweel, strip_tashkeel
from tqdm import tqdm
import numpy as np
import torch as T
from torch.utils.data import DataLoader

from diac_utils import HARAKAT_MAP, shakkel_char, flat2_3head
from model_partial import PartialDD
from model_dd import DiacritizerD2
from data_utils import DatasetUtils
from dataloader import DataRetriever
from segment import segment
from partial_dd_metrics import (
    parse_data,
    load_data,
    make_mask_hard,
    make_mask_logits,
)


def apply_tashkeel(
    line: str,
    diacs: Union[np.ndarray, T.Tensor]
):
    """Return ``line`` with a diacritic string appended after every non-space character.

    Args:
        line: Text whose words are separated by single spaces.
        diacs: 2-D array of flat diacritic class ids, indexed ``[word, char]``.
            It is flattened, split into three heads by ``flat2_3head`` and each
            head is reshaped back to the original 2-D shape.

    Returns:
        The input characters interleaved with the output of ``shakkel_char``
        for each non-space character.
    """
    ts, tw = diacs.shape
    heads = flat2_3head(diacs.flatten())
    heads = tuple(head.reshape(ts, tw) for head in heads)

    pieces = []
    word_idx = 0
    char_idx = 0
    for ch in line:
        pieces.append(ch)
        if ch == " ":
            # A space starts the next word: advance the word row and restart
            # the per-word character counter. Spaces get no diacritic.
            char_idx = 0
            word_idx += 1
        else:
            tashkeel = (
                heads[0][word_idx][char_idx],
                heads[1][word_idx][char_idx],
                heads[2][word_idx][char_idx],
            )
            char_idx += 1
            pieces.append(shakkel_char(*tashkeel))
    return "".join(pieces)


def diac_text(data, model_output_base, model_output_ctxt, selection_mode='contrastive-hard', threshold=0.1):
    """Diacritize ``data`` keeping only the context predictions selected by a mask.

    Positions where the mask is false get class ``0``; everywhere else the
    arg-max class of ``model_output_ctxt`` is used.

    Args:
        data: Iterable of text lines; ``len(data)`` must equal
            ``len(model_output_base)``.
        model_output_base: Output of the base side, passed to the mask builder.
        model_output_ctxt: Output of the context side; its ``argmax(-1)``
            supplies the predicted classes.
        selection_mode: ``'contrastive-hard'`` selects ``make_mask_hard``; any
            other value is forwarded as ``version`` to ``make_mask_logits``.
        threshold: Forwarded to ``make_mask_logits`` (unused in hard mode).

    Returns:
        List of diacritized lines, one per input line.
    """
    if selection_mode == 'contrastive-hard':
        keep_mask = make_mask_hard(model_output_ctxt, model_output_base)
    else:
        keep_mask = make_mask_logits(
            model_output_ctxt,
            model_output_base,
            version=selection_mode,
            threshold=threshold,
        )

    # NOTE(review): each per-line slice is unpacked as a 2-D (word, char) grid
    # by apply_tashkeel, so this is presumably [batch, words, chars] — confirm.
    diacritics = np.where(
        keep_mask,
        model_output_ctxt.argmax(-1),
        0,
    ).astype(int)

    assert len(model_output_base) == len(data), \
        "model output and input data must contain the same number of lines"

    # Normalise the text the same way for every line: strip existing
    # diacritics and tatweel, then re-join the tokens with single spaces.
    data = [
        ' '.join(tokenize(
            line.strip(),
            morphs=[strip_tashkeel, strip_tatweel]
        ))
        for line in data
    ]

    output = []
    for line, line_diacs in zip(tqdm(data), diacritics):
        output.append(apply_tashkeel(line, line_diacs))
    return output


class Predictor:
    """Loads a ``PartialDD`` model from a checkpoint and prepares data loaders for it."""

    def __init__(self, config):
        """Build the model described by ``config`` and load its weights.

        Reads ``config['predictor']['device']`` (default ``'cuda:0'``, used only
        when CUDA is available, otherwise ``'cpu'``) and
        ``config['paths']['load']`` (checkpoint path).
        """
        self.data_utils = DatasetUtils(config)
        vocab_size = len(self.data_utils.letter_list)
        word_embeddings = self.data_utils.embeddings

        self.config = config
        self.device = T.device(
            config['predictor'].get('device', 'cuda:0')
            if T.cuda.is_available() else 'cpu'
        )

        self.model = PartialDD(config)
        self.model.sentence_diac.build(word_embeddings, vocab_size)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = T.load(
            config["paths"]["load"],
            map_location=T.device(self.device),
        )['state_dict']
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def create_dataloader(self, text, do_partial, do_hard_mask, threshold):
        """Prepare ``self.data_loader`` (and related state) for ``text``.

        When ``do_hard_mask`` is true or ``do_partial`` is false, the whole text
        is treated as one line, cut into segments with ``segment`` and a
        segment-to-character mapping is stored in ``self.mapping``. Otherwise
        the text is split on newlines and each line is used as a segment; no
        mapping is stored in that case.

        Args:
            text: Raw input text.
            do_partial: Whether partial diacritization is requested.
            do_hard_mask: Whether the hard-mask path is requested; stored on
                the instance.
            threshold: Stored on the instance for later use by
                ``predict_partial``.
        """
        self.threshold = threshold
        self.do_hard_mask = do_hard_mask

        segment_cfg = self.config["segment"]
        stride = segment_cfg["stride"]
        window = segment_cfg["window"]
        min_window = segment_cfg["min-window"]

        if self.do_hard_mask or not do_partial:
            segments, mapping = segment([text], stride, window, min_window)
            mapping_lines = [
                f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}"
                for sent_idx, seg_idx, word_idx, char_idx in mapping
            ]
            self.mapping = self.data_utils.load_mapping_v3_from_list(mapping_lines)
            self.original_lines = [text]
            self.segments = segments
        else:
            segments = text.split('\n')
            self.segments = segments
            self.original_lines = text.split('\n')

        self.data_loader = DataLoader(
            DataRetriever(self.data_utils, segments),
            batch_size=self.config["predictor"].get("batch-size", 32),
            shuffle=False,
            num_workers=self.config['loader'].get('num-workers', 0),
        )


class PredictTri(Predictor):
    """Predictor that merges per-segment three-head outputs into per-line text."""

    def __init__(self, config):
        """Initialise the base predictor plus a reusable vote counter."""
        super().__init__(config)
        self.diacritics = {
            "FATHA": 1,
            "KASRA": 2,
            "DAMMA": 3,
            "SUKUN": 4
        }
        # Reused across calls to count_votes; cleared before every count.
        self.votes: Union[Counter[int], Counter[bool]] = Counter()

    def count_votes(
        self,
        things: Union[Iterable[int], Iterable[bool]]
    ):
        """Return the most common element of ``things``.

        Raises:
            IndexError: If ``things`` is empty.
        """
        self.votes.clear()
        self.votes.update(things)
        return self.votes.most_common(1)[0][0]

    def predict_majority_vote(self):
        """Run ``model.predict`` on the data loader and return majority-voted lines."""
        y_gen_diac, y_gen_tanween, y_gen_shadda = self.model.predict(self.data_loader)
        diacritized_lines, _ = self.coalesce_votes_by_majority(
            y_gen_diac, y_gen_tanween, y_gen_shadda
        )
        return diacritized_lines

    def predict_partial(self, do_partial, lines):
        """Diacritize the loaded text and return it as one newline-joined string.

        Uses majority voting over segments when the hard-mask path is active or
        ``do_partial`` is false; otherwise applies ``diac_text`` with selection
        mode ``'1'`` and the stored threshold to ``lines``.
        """
        outputs = self.model.predict_partial(
            self.data_loader,
            return_extra=True,
            eval_only='both',
            do_partial=do_partial,
        )
        if self.do_hard_mask or not do_partial:
            y_gen_diac, y_gen_tanween, y_gen_shadda = outputs['diacritics']
            diac_lines, _ = self.coalesce_votes_by_majority(
                y_gen_diac, y_gen_tanween, y_gen_shadda
            )
        else:
            # outputs["other"][1] is passed as the base output and
            # outputs["other"][0] as the context output.
            diac_lines = diac_text(
                lines,
                outputs["other"][1],
                outputs["other"][0],
                selection_mode='1',
                threshold=self.threshold,
            )
        return '\n'.join(diac_lines)

    def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
        """Predict with ``predict_partial`` (cached on disk) and majority-vote the result.

        Segment outputs are stored in ``dataset/cache/cache.pt``. They are
        recomputed when that file is missing or ``overwrite_cache`` is true;
        otherwise the cached file is loaded. The cache is not keyed on the
        input, so pass ``overwrite_cache=True`` whenever the input changed.

        Returns:
            ``(diacritized_lines, extra_out)`` where ``extra_out`` has the keys
            ``'line_data'`` and ``'segment_data'``.
        """
        assert isinstance(self.model, PartialDD)

        cache_dir = "dataset/cache"
        cache_path = "dataset/cache/cache.pt"

        # Check the file that is actually loaded below; testing for a different
        # file name would either never hit the cache or crash in T.load.
        if overwrite_cache or not os.path.exists(cache_path):
            # makedirs also creates a missing "dataset" parent directory.
            os.makedirs(cache_dir, exist_ok=True)
            segment_outputs = self.model.predict_partial(
                self.data_loader,
                return_extra=False,
                eval_only='both',
            )
            T.save(segment_outputs, cache_path)
        else:
            # NOTE(review): torch.load unpickles the file; the cache directory
            # must not be writable by untrusted parties.
            segment_outputs = T.load(cache_path)

        y_gen_diac, y_gen_tanween, y_gen_shadda = segment_outputs['diacritics']
        diacritized_lines, extra_for_lines = self.coalesce_votes_by_majority(
            y_gen_diac, y_gen_tanween, y_gen_shadda,
        )

        extra_out = {
            'line_data': {
                **extra_for_lines,
            },
            'segment_data': {
                **segment_outputs,
            }
        }
        return diacritized_lines, extra_out

    def coalesce_votes_by_majority(
        self,
        y_gen_diac: np.ndarray,
        y_gen_tanween: np.ndarray,
        y_gen_shadda: np.ndarray,
    ):
        """Merge overlapping segment predictions into one diacritic per character.

        For every character of every original line, the predictions of all
        segment positions mapped to that character (via ``self.mapping``) are
        collected and the most frequent non-zero class wins.

        Args:
            y_gen_diac, y_gen_tanween, y_gen_shadda: Per-head predictions,
                each indexed as ``[segment, token, char]``.

        Returns:
            ``(diacritized_lines, extra)`` where ``extra['diac_pred']`` is an
            int array of per-character classes: ``-1`` where no value was
            written (padding, or lines absent from the mapping), ``0`` where
            no vote was cast.
        """
        prepped_lines_og = [
            ' '.join(tokenize(strip_tatweel(line)))
            for line in self.original_lines
        ]
        # default=0 keeps an empty input from raising in max().
        max_line_chars = max((len(line) for line in prepped_lines_og), default=0)
        diacritics_pred = np.full(
            (len(self.original_lines), max_line_chars),
            fill_value=-1,
            dtype=int,
        )

        count_processed_sents = 0
        do_break = False
        diacritized_lines = []
        for sent_idx, line in enumerate(tqdm(prepped_lines_og)):
            count_processed_sents = sent_idx + 1
            line = line.strip()
            line_pieces = []
            for char_idx, char in enumerate(line):
                line_pieces.append(char)
                char_vote_diacritic = []
                # Lines without a mapping entry are copied through undiacritized
                # and keep -1 in diacritics_pred.
                if sent_idx not in self.mapping:
                    continue
                mapping_s_i = self.mapping[sent_idx]
                for seg_idx in mapping_s_i:
                    # In debug mode, stop everything once segment 256 is reached.
                    if self.data_utils.debug and seg_idx >= 256:
                        do_break = True
                        break
                    mapping_g_i = mapping_s_i[seg_idx]
                    for t_idx in mapping_g_i:
                        mapping_t_i = mapping_g_i[t_idx]
                        if char_idx in mapping_t_i:
                            c_idx = mapping_t_i.index(char_idx)
                            output_idx = np.s_[seg_idx, t_idx, c_idx]
                            diac_h3 = (
                                y_gen_diac[output_idx],
                                y_gen_tanween[output_idx],
                                y_gen_shadda[output_idx],
                            )
                            diac_char_i = HARAKAT_MAP.index(diac_h3)
                            # NOTE(review): 13 is presumably the per-token
                            # character limit of the model — confirm.
                            # Class 0 never votes.
                            if c_idx < 13 and diac_char_i != 0:
                                char_vote_diacritic.append(diac_char_i)
                if do_break:
                    break
                if len(char_vote_diacritic) > 0:
                    char_mv_diac = self.count_votes(char_vote_diacritic)
                    line_pieces.append(shakkel_char(*HARAKAT_MAP[char_mv_diac]))
                    diacritics_pred[sent_idx, char_idx] = char_mv_diac
                else:
                    diacritics_pred[sent_idx, char_idx] = 0
            if do_break:
                # The partially processed line is intentionally not emitted.
                break
            diacritized_lines += ["".join(line_pieces).strip()]

        print(f'[INFO] Cutting stats from {len(diacritics_pred)} to {count_processed_sents}')
        extra = {
            'diac_pred': diacritics_pred[:count_processed_sents],
        }
        return diacritized_lines, extra