|
import os |
|
import pickle |
|
import numpy as np |
|
|
|
from tqdm import tqdm |
|
from prettytable import PrettyTable |
|
from pyarabic.araby import tokenize, strip_tashkeel |
|
import diac_utils as du |
|
|
|
class DatasetUtils: |
|
def __init__(self, config):
    """Cache the configuration values used by the dataset helpers.

    Args:
        config: Nested mapping. The keys read here are ``paths``
            (``base``, ``constants``, optional ``word-embs``),
            ``sentence-break`` (``delimeters``, ``stride``, ``window``,
            optional ``val-stride``), ``predictor`` (``stride``, ``window``,
            ``gt-signal-prob`` and, when that is positive, ``seed-idx``),
            ``train`` (``max-word-len``, ``max-sent-len``,
            ``max-token-count``, optional ``markov-signal`` and
            ``batch-first``), ``debug`` and, when word embeddings are
            configured, ``loader`` (``wembs-limit``).

    Raises:
        KeyError: If a required configuration key is missing.
        OSError: If the mask file (``gt-signal-prob`` > 0) cannot be opened.
    """
    paths = config["paths"]
    self.base_path = paths["base"]
    self.special_tokens = ['<pad>', '<unk>', '<num>', '<punc>']

    sentence_break = config["sentence-break"]
    self.delimeters = sentence_break["delimeters"]
    self.load_constants(paths["constants"])
    self.debug = config["debug"]

    # Segmentation settings for training / validation data.
    self.stride = sentence_break["stride"]
    self.window = sentence_break["window"]
    # Validation stride falls back to the training stride when not configured.
    self.val_stride = sentence_break.get("val-stride", self.stride)

    # Segmentation settings used at prediction time.
    predictor = config["predictor"]
    self.test_stride = predictor["stride"]
    self.test_window = predictor["window"]

    train = config["train"]
    self.max_word_len = train["max-word-len"]
    self.max_sent_len = train["max-sent-len"]
    self.max_token_count = train["max-token-count"]
    self.pad_target_val = -100
    self.pad_char_id = du.LETTER_LIST.index('<pad>')

    self.markov_signal = train.get('markov-signal', False)
    self.batch_first = train.get('batch-first', True)

    self.gt_prob = predictor["gt-signal-prob"]
    if self.gt_prob > 0:
        # The mask file name mirrors the one written by create_gt_mask().
        # Note that self.s_idx and self.gt_mask only exist on this branch.
        self.s_idx = predictor["seed-idx"]
        mask_name = f"test_gt_mask_{self.gt_prob}_{self.s_idx}.txt"
        mask_path = os.path.join(self.base_path, "test", mask_name)
        with open(mask_path, 'r') as fin:
            self.gt_mask = list(fin)

    # Word embeddings are optional: a missing or blank path disables them.
    if "word-embs" in paths and paths["word-embs"].strip() != "":
        self.pad_val = self.special_tokens.index("<pad>")
        self.embeddings, self.vocab = self.load_embeddings(
            paths["word-embs"], config["loader"]["wembs-limit"]
        )
        self.embeddings = self.normalize(self.embeddings, ["unit", "centeremb", "unit"])
        self.w2idx = dict(zip(self.vocab, range(len(self.vocab))))
|
|
|
def load_file(self, path):
    """Unpickle the object stored at ``path`` and return it as a list.

    Args:
        path: Location of a pickle file whose payload is iterable.

    Returns:
        list: The unpickled payload converted with ``list()``.

    Note:
        ``pickle.load`` can execute arbitrary code; only use this on files
        from a trusted source.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
        return list(payload)
|
|
|
def normalize(self, matrix, actions, mean=None):
    """Apply a sequence of normalization steps to a 2-D matrix.

    Supported actions, applied in the order given:

    * ``'unit'``      - scale every row to unit Euclidean length.
    * ``'center'``    - subtract ``mean``; when ``mean`` is ``None`` the
      column-wise mean of the current matrix is used.
    * ``'unitdim'``   - scale every column to unit Euclidean length.
    * ``'centeremb'`` - subtract each row's own mean from that row.

    Rows / columns whose norm is zero are left unchanged (their norm is
    treated as 1 to avoid a division by zero). Unrecognized action names
    are ignored.

    Args:
        matrix: 2-D NumPy array.
        actions: Iterable of action names.
        mean: Optional value subtracted by ``'center'``.

    Returns:
        The transformed matrix (the input array itself is not modified).
    """
    def length_normalize(m):
        norms = np.sqrt(np.sum(m ** 2, axis=1))
        norms[norms == 0] = 1
        return m / norms[:, np.newaxis]

    def mean_center(m):
        # The previous code computed ``m - None`` (TypeError) when no mean
        # was supplied; fall back to the column-wise mean instead.
        center = np.mean(m, axis=0) if mean is None else mean
        return m - center

    def length_normalize_dimensionwise(m):
        norms = np.sqrt(np.sum(m ** 2, axis=0))
        norms[norms == 0] = 1
        return m / norms

    def mean_center_embeddingwise(m):
        avg = np.mean(m, axis=1)
        return m - avg[:, np.newaxis]

    steps = {
        'unit': length_normalize,
        'center': mean_center,
        'unitdim': length_normalize_dimensionwise,
        'centeremb': mean_center_embeddingwise,
    }

    for action in actions:
        step = steps.get(action)
        if step is not None:
            matrix = step(matrix)

    return matrix
|
|
|
def load_constants(self, path):
    """Copy the character inventories from ``diac_utils`` onto the instance.

    Sets ``self.numbers``, ``self.letter_list`` and ``self.diacritic_list``
    from ``du.NUMBERS``, ``du.LETTER_LIST`` and ``du.DIACRITICS_SHORT``.

    Args:
        path: Not read by this method; kept so existing callers that pass a
            constants path keep working.
    """
    self.numbers = du.NUMBERS
    self.letter_list = du.LETTER_LIST
    self.diacritic_list = du.DIACRITICS_SHORT
|
|
|
def split_word_on_characters_with_diacritics(self, word: str):
    """Forward ``word`` to ``diac_utils.split_word_on_characters_with_diacritics``.

    Args:
        word: The word to split.

    Returns:
        Whatever the ``diac_utils`` helper returns for ``word``.
    """
    pieces = du.split_word_on_characters_with_diacritics(word)
    return pieces
|
|
|
def load_mapping_v3(self, dtype, file_ext=None):
    """Read a mapping file into a nested ``sent -> seg -> token -> [chars]`` dict.

    The file lives at ``<base_path>/<dtype>/<dtype><file_ext>`` and holds one
    ``sent_idx,seg_idx,t_idx,c_idx`` record of integers per line.

    Args:
        dtype: Dataset split name; used both as directory and file stem.
        file_ext: Suffix appended to ``dtype``. Defaults to
            ``-<test_stride>-<test_window>.map``.

    Returns:
        dict: ``mapping[sent_idx][seg_idx][t_idx]`` is the list of ``c_idx``
        values in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line is not four comma-separated integers.
    """
    if file_ext is None:
        file_ext = f"-{self.test_stride}-{self.test_window}.map"
    f_name = os.path.join(self.base_path, dtype, dtype + file_ext)

    mapping = {}
    with open(f_name, 'r') as fin:
        for line in fin:
            # Blank lines (e.g. a stray empty line at the end of the file)
            # carry no record; skipping them avoids int('') failures.
            if not line.strip():
                continue
            sent_idx, seg_idx, t_idx, c_idx = map(int, line.split(','))
            tokens = mapping.setdefault(sent_idx, {}).setdefault(seg_idx, {})
            tokens.setdefault(t_idx, []).append(c_idx)
    return mapping
|
|
|
def load_mapping_v3_from_list(self, mapping_list):
    """Build the nested ``sent -> seg -> token -> [chars]`` dict from text lines.

    Args:
        mapping_list: Iterable of ``sent_idx,seg_idx,t_idx,c_idx`` strings
            (four comma-separated integers each).

    Returns:
        dict: ``mapping[sent_idx][seg_idx][t_idx]`` is the list of ``c_idx``
        values in input order.

    Raises:
        ValueError: If a non-blank entry is not four comma-separated integers.
    """
    mapping = {}
    for line in mapping_list:
        # Empty / whitespace-only entries (e.g. from splitting text that ends
        # with a newline) carry no record; skip them instead of failing.
        if not line.strip():
            continue
        sent_idx, seg_idx, t_idx, c_idx = map(int, line.split(','))
        tokens = mapping.setdefault(sent_idx, {}).setdefault(seg_idx, {})
        tokens.setdefault(t_idx, []).append(c_idx)
    return mapping
|
|
|
def load_embeddings(self, embs_path, limit=-1):
    """Load word vectors from a text file with an ``n d`` header line.

    Every following line is expected to hold a word and its ``d`` float
    components separated by whitespace. Row 0 of the returned matrix is an
    all-zero vector paired with ``self.special_tokens[0]``.

    In debug mode (``self.debug`` truthy) nothing is read from disk: an
    all-zero matrix with 300 columns is returned together with the special
    tokens followed by 200 placeholder words.

    Args:
        embs_path: Path of the embeddings file.
        limit: Maximum number of vectors to read; ``<= 0`` means all ``n``
            announced by the header. Values larger than ``n`` are clamped.

    Returns:
        tuple: ``(embeddings, words)`` where ``embeddings`` has one row per
        entry of ``words``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the header or a vector line cannot be parsed.
    """
    if self.debug:
        return np.zeros((200+len(self.special_tokens),300)), self.special_tokens + ["c"] * 200

    words = [self.special_tokens[0]]
    print(f"[INFO] Reading Embeddings from {embs_path}")
    with open(embs_path, encoding='utf-8', mode='r') as fin:
        n, d = map(int, fin.readline().split())
        # Never allocate (or announce to tqdm) more rows than the header promises.
        limit = n if limit <= 0 else min(limit, n)
        embeddings = np.zeros((limit + 1, d))
        for i, line in tqdm(enumerate(fin), total=limit):
            if i >= limit:
                break
            tokens = line.rstrip().split()
            if not tokens:
                # Blank line: nothing to store.
                continue
            words.append(tokens[0])
            # Index by the vocabulary position so rows stay aligned with
            # `words` even when blank lines were skipped.
            embeddings[len(words) - 1] = list(map(float, tokens[1:]))

    # Drop rows that were allocated but never filled (short file / blank
    # lines) so the matrix and the vocabulary have the same length.
    if len(words) < embeddings.shape[0]:
        embeddings = embeddings[:len(words)]
    return embeddings, words
|
|
|
def load_file_clean(self, dtype, strip=False):
    """Read ``<base_path>/<dtype>/<dtype>.txt`` and preprocess every line.

    Args:
        dtype: Dataset split name; used both as directory and file stem.
        strip: When true, ``strip_tashkeel`` is applied to each line after
            ``self.preprocess``.

    Returns:
        list[str]: One processed string per line of the file.

    Raises:
        OSError: If the file cannot be opened.
    """
    if strip:
        def clean(raw):
            return strip_tashkeel(self.preprocess(raw))
    else:
        clean = self.preprocess

    f_name = os.path.join(self.base_path, dtype, dtype + ".txt")
    # newline='\n' keeps the reader from treating '\r' as a line break.
    with open(f_name, 'r', encoding="utf-8", newline='\n') as fin:
        original_lines = [clean(raw) for raw in fin.readlines()]
    return original_lines
|
|
|
def preprocess(self, line):
    """Tokenize ``line`` with ``pyarabic``'s ``tokenize`` and re-join with single spaces.

    Args:
        line: Raw text line.

    Returns:
        str: The tokens of ``line`` separated by one space each.
    """
    pieces = tokenize(line)
    return ' '.join(pieces)
|
|
|
def pad_and_truncate_sequence(self, tokens, max_len, pad=None):
    """Return ``tokens`` cut or right-padded to exactly ``max_len`` items.

    Args:
        tokens: List to adjust; it is not modified.
        max_len: Target length.
        pad: Filler value. Defaults to the index of ``"<pad>"`` in
            ``self.special_tokens``.

    Returns:
        list: A new list of length ``max_len`` (for non-negative ``max_len``).
    """
    filler = self.special_tokens.index("<pad>") if pad is None else pad
    shortfall = max_len - len(tokens)
    if shortfall <= 0:
        # Already long enough: keep only the leading max_len items.
        return tokens[:max_len]
    return tokens + [filler] * shortfall
|
|
|
def stats(self, freq, percentile=90, name="stats"):
    """Print a one-row summary table for the values in ``freq``.

    The row holds mean, standard deviation, minimum, maximum and the
    requested percentile.

    Args:
        freq: Iterable of sortable numeric values.
        percentile: Percentile to report (passed to ``np.percentile``).
        name: Label shown in the "Dataset" column.
    """
    header = ["Dataset", "Mean", "Std", "Min", "Max", f"{percentile}th Percentile"]
    values = np.array(sorted(freq))
    row = [
        name,
        values.mean(),
        values.std(),
        values.min(),
        values.max(),
        np.percentile(values, percentile),
    ]
    table = PrettyTable(header)
    table.add_row(row)
    print(table)
|
|
|
def create_gt_mask(self, lines, prob, idx, seed=1111):
    """Draw a random 0/1 mask per character and write it to the test folder.

    For every token of every line, one Bernoulli(``prob``) digit is drawn
    per character of the token; the digits of a token are concatenated and
    tokens are separated by a single space. The result is written, one line
    per input line, to ``<base_path>/test/test_gt_mask_<prob>_<idx>.txt``.

    Args:
        lines: Iterable of text lines.
        prob: Probability of drawing a 1.
        idx: Identifier embedded in the output file name.
        seed: Seed passed to ``np.random.seed`` (this reseeds NumPy's global
            random state).
    """
    np.random.seed(seed)

    gt_masks = []
    for line in lines:
        # One draw per token, in token order, keeps the random stream stable.
        token_masks = [
            ''.join(map(str, np.random.binomial(1, prob, len(token))))
            for token in tokenize(line.strip())
        ]
        gt_masks.append(" ".join(token_masks))

    mask_path = os.path.join(self.base_path, "test", f"test_gt_mask_{prob}_{idx}.txt")
    with open(mask_path, 'w') as fout:
        fout.write('\n'.join(gt_masks))
|
|
|
def create_gt_labels(self, lines):
    """Build one flat label list per input line.

    Each line is tokenized; every token is split with
    ``split_word_on_characters_with_diacritics`` and passed to
    ``du.create_label_for_word``, whose second return value is appended to
    the line's labels. A single ``0`` is inserted between consecutive tokens.

    Args:
        lines: Iterable of text lines.

    Returns:
        list[list]: The label list of every line, in input order.
    """
    gt_labels = []
    for line in lines:
        tokens = tokenize(line.strip())
        final_pos = len(tokens) - 1
        line_labels = []
        for pos, word in enumerate(tokens):
            pieces = self.split_word_on_characters_with_diacritics(word)
            _first, flat_labels, _third = du.create_label_for_word(pieces)
            line_labels.extend(flat_labels)
            if pos < final_pos:
                # Separator label between two tokens.
                line_labels.append(0)
        gt_labels.append(line_labels)
    return gt_labels
|
|
|
def get_ce(self, diac_word_y, e_idx=None, return_idx=False):
    """Find the last entry before ``e_idx`` that differs from ``[0,0,0]``.

    Args:
        diac_word_y: Indexable sequence of per-character labels.
        e_idx: Exclusive upper bound of the search; defaults to
            ``len(diac_word_y)``.
        return_idx: Return the position instead of the entry itself.

    Returns:
        The matching entry (or its index). When every inspected entry equals
        ``[0,0,0]`` the entry/index at ``e_idx - 1`` is returned; for
        ``e_idx == 0`` that is index ``-1``, i.e. the last element.
    """
    end = len(diac_word_y) if e_idx is None else e_idx
    fallback = end - 1
    # Walk backwards from the position just before `end`.
    for pos in range(fallback, -1, -1):
        if diac_word_y[pos] != [0,0,0]:
            return pos if return_idx else diac_word_y[pos]
    return fallback if return_idx else diac_word_y[fallback]
|
|
|
def create_decoder_input(self, diac_code_y, prob=0):
    """Build the 8-wide per-character input derived from the target labels.

    The output has the shape of ``diac_code_y`` with the last axis replaced
    by 8. When ``self.markov_signal`` is false it is all zeros. Otherwise,
    for every word:

    * position 0 receives the carry-over vector: initially the one-hot of
      class 5 followed by two zeros, afterwards the encoding of the entry
      ``self.get_ce`` selected for the previous word;
    * position ``c + 1`` receives the encoding of the label at position
      ``c`` (6-way one-hot of its first component followed by its remaining
      components), stopping at the first label whose first component equals
      ``self.pad_target_val``.

    Args:
        diac_code_y: Rectangular nested list of labels; each label is a list
            whose first item indexes a 6-way one-hot (assumes 3 items per
            label so that 6 + 2 fills the 8 slots - TODO confirm).
        prob: Unused.

    Returns:
        list: The per-word arrays of the output, one per entry of
        ``diac_code_y``.
    """
    lead_shape = np.array(diac_code_y).shape[:-1]
    diac_code_x = np.zeros((*lead_shape, 8))
    if not self.markov_signal:
        return list(diac_code_x)

    one_hot = np.eye(6)
    carry = list(one_hot[-1]) + [0,0]
    for w_idx, word in enumerate(diac_code_y):
        diac_code_x[w_idx, 0, :] = carry
        # Shift by one: the label at `pos` feeds position `pos + 1`.
        for pos, label in enumerate(word[:-1]):
            if label[0] == self.pad_target_val:
                break
            diac_code_x[w_idx, pos+1, :] = list(one_hot[label[0]]) + label[1:]
        # NOTE(review): SOURCE lost its indentation; this carry-over update is
        # reconstructed as running once per word, after the scan, with `pos`
        # left at the break position (or the last scanned index) - confirm.
        ce = self.get_ce(word, pos)
        carry = list(one_hot[ce[0]]) + ce[1:]
    return list(diac_code_x)