import os |
import pickle |
import numpy as np |
from tqdm import tqdm |
from prettytable import PrettyTable |
from pyarabic.araby import tokenize, strip_tashkeel |
import diac_utils as du |
class DatasetUtils: |
def __init__(self, config):
    """Configure dataset handling from a nested ``config`` mapping.

    Sections read from ``config``:
      * ``paths``: ``base``, ``constants`` and optionally ``word-embs``;
      * ``sentence-break``: ``delimeters``, ``stride``, ``window`` and
        optionally ``val-stride`` (defaults to ``stride``);
      * ``predictor``: ``stride``, ``window``, ``gt-signal-prob`` and, when
        that probability is > 0, ``seed-idx``;
      * ``train``: ``max-word-len``, ``max-sent-len``, ``max-token-count`` and
        optionally ``markov-signal`` (default False) / ``batch-first``
        (default True);
      * ``debug``, plus ``loader.wembs-limit`` when embeddings are configured.

    Some attributes only exist when their branch runs:
      * ``s_idx`` and ``gt_mask`` are set when ``gt-signal-prob`` > 0.
        ``gt_mask`` holds the raw lines (newlines kept) of
        ``<base>/test/test_gt_mask_<prob>_<seed-idx>.txt``.
      * ``pad_val``, ``embeddings``, ``vocab`` and ``w2idx`` are set when
        ``paths.word-embs`` is present and not blank.
    """
    paths_cfg = config["paths"]
    self.base_path = paths_cfg["base"]
    self.special_tokens = ['<pad>', '<unk>', '<num>', '<punc>']
    split_cfg = config["sentence-break"]
    self.delimeters = split_cfg["delimeters"]
    self.load_constants(paths_cfg["constants"])
    self.debug = config["debug"]
    # Validation falls back to the training stride unless "val-stride" is set.
    self.stride = split_cfg["stride"]
    self.window = split_cfg["window"]
    self.val_stride = split_cfg.get("val-stride", self.stride)
    predictor_cfg = config["predictor"]
    self.test_stride = predictor_cfg["stride"]
    self.test_window = predictor_cfg["window"]
    train_cfg = config["train"]
    self.max_word_len = train_cfg["max-word-len"]
    self.max_sent_len = train_cfg["max-sent-len"]
    self.max_token_count = train_cfg["max-token-count"]
    # NOTE(review): -100 presumably matches the loss's ignore index — confirm.
    self.pad_target_val = -100
    self.pad_char_id = du.DIAC_PAD_IDX
    self.markov_signal = train_cfg.get('markov-signal', False)
    self.batch_first = train_cfg.get('batch-first', True)
    self.gt_prob = predictor_cfg["gt-signal-prob"]
    if self.gt_prob > 0:
        self.s_idx = predictor_cfg["seed-idx"]
        mask_name = f"test_gt_mask_{self.gt_prob}_{self.s_idx}.txt"
        with open(os.path.join(self.base_path, "test", mask_name), 'r') as mask_file:
            self.gt_mask = mask_file.readlines()
    wants_embeddings = "word-embs" in paths_cfg and paths_cfg["word-embs"].strip() != ""
    if wants_embeddings:
        self.pad_val = self.special_tokens.index("<pad>")
        self.embeddings, self.vocab = self.load_embeddings(paths_cfg["word-embs"], config["loader"]["wembs-limit"])
        self.embeddings = self.normalize(self.embeddings, ["unit", "centeremb", "unit"])
        self.w2idx = {token: position for position, token in enumerate(self.vocab)}
def load_file(self, path):
    """Unpickle the object stored at ``path`` and return it converted to a list.

    Security: ``pickle.load`` can execute arbitrary code while loading. Only
    call this on files this project produced itself, never on untrusted input.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    return list(payload)
def normalize(self, matrix, actions, mean=None):
    """Apply a sequence of normalisation steps to an embedding matrix.

    Rows are embeddings and columns are dimensions. Per-row steps reduce over
    axis 1, per-dimension steps over axis 0.

    Args:
        matrix: 2-D NumPy array.
        actions: iterable of step names, applied in order:
            'unit'      - scale each row to unit L2 norm (all-zero rows kept);
            'center'    - subtract ``mean``, or the per-column mean of the
                          current matrix when ``mean`` is None;
            'unitdim'   - scale each column to unit L2 norm (zero columns kept);
            'centeremb' - subtract each row's own mean.
            Unrecognised names (e.g. 'none') are skipped.
        mean: optional vector subtracted by 'center'.

    Returns:
        The normalised matrix. Each step builds a new array, so the input is
        never modified in place.
    """
    def length_normalize(matrix):
        norms = np.sqrt(np.sum(matrix**2, axis=1))
        norms[norms == 0] = 1  # avoid 0/0 for all-zero rows such as <pad>
        return matrix / norms[:, np.newaxis]

    def mean_center(matrix):
        # Without a caller-supplied mean, `matrix - None` would raise TypeError;
        # centre on the per-dimension mean instead.
        avg = np.mean(matrix, axis=0) if mean is None else mean
        return matrix - avg

    def length_normalize_dimensionwise(matrix):
        norms = np.sqrt(np.sum(matrix**2, axis=0))
        norms[norms == 0] = 1
        return matrix / norms

    def mean_center_embeddingwise(matrix):
        avg = np.mean(matrix, axis=1)
        return matrix - avg[:, np.newaxis]

    steps = {
        'unit': length_normalize,
        'center': mean_center,
        'unitdim': length_normalize_dimensionwise,
        'centeremb': mean_center_embeddingwise,
    }
    for action in actions:
        step = steps.get(action)
        if step is not None:
            matrix = step(matrix)
    return matrix
def load_constants(self, path):
    """Bind the character/diacritic constants from ``diac_utils`` onto the instance.

    ``path`` is accepted for interface compatibility but is not read. All
    values come from the imported ``du`` module.
    """
    constants = du
    self.numbers = constants.NUMBERS
    self.letter_list = constants.LETTER_LIST
    self.diacritic_list = constants.DIACRITICS_SHORT
def split_word_on_characters_with_diacritics(self, word: str):
    """Delegate to ``du.split_word_on_characters_with_diacritics`` for ``word``.

    Exposed as a method so callers can reach the helper through the instance.
    """
    splitter = du.split_word_on_characters_with_diacritics
    return splitter(word)
def load_mapping_v3(self, dtype, file_ext=None):
    """Load a ``sent,seg,t,c`` index-mapping file into nested dicts.

    Reads ``<base_path>/<dtype>/<dtype><file_ext>``. Each non-blank line holds
    four comma-separated integers ``sent_idx,seg_idx,t_idx,c_idx``. These are
    grouped as ``mapping[sent_idx][seg_idx][t_idx] -> [c_idx, ...]``, with the
    ``c_idx`` values kept in file order.

    Args:
        dtype: dataset split name; used as both sub-directory and file stem.
        file_ext: file suffix; defaults to ``-<test_stride>-<test_window>.map``.

    Returns:
        dict[int, dict[int, dict[int, list[int]]]]
    """
    if file_ext is None:
        file_ext = f"-{self.test_stride}-{self.test_window}.map"
    f_name = os.path.join(self.base_path, dtype, dtype + file_ext)
    mapping = {}
    with open(f_name, 'r') as fin:
        for line in fin:
            # A stray blank line (e.g. a doubled trailing newline) would
            # otherwise make int('') raise ValueError.
            if not line.strip():
                continue
            sent_idx, seg_idx, t_idx, c_idx = map(int, line.split(','))
            seg_map = mapping.setdefault(sent_idx, {})
            token_map = seg_map.setdefault(seg_idx, {})
            token_map.setdefault(t_idx, []).append(c_idx)
    return mapping
def load_mapping_v3_from_list(self, mapping_list):
    """Group in-memory ``"sent,seg,t,c"`` records into nested dicts.

    Each entry holds four comma-separated integers. They are grouped as
    ``mapping[sent_idx][seg_idx][t_idx] -> [c_idx, ...]``, with ``c_idx``
    values kept in input order.

    Blank or whitespace-only entries, such as the trailing empty string left
    by splitting newline-terminated text, are skipped instead of raising
    ``ValueError`` from ``int('')``.

    Returns:
        dict[int, dict[int, dict[int, list[int]]]]
    """
    mapping = {}
    for line in mapping_list:
        if not line.strip():
            continue
        sent_idx, seg_idx, t_idx, c_idx = map(int, line.split(','))
        seg_map = mapping.setdefault(sent_idx, {})
        token_map = seg_map.setdefault(seg_idx, {})
        token_map.setdefault(t_idx, []).append(c_idx)
    return mapping
def load_embeddings(self, embs_path, limit=-1):
    """Load word vectors from a text file with a ``<count> <dim>`` header line.

    Every following line is a word followed by ``dim`` floats. Row 0 of the
    returned matrix is an all-zero vector for ``self.special_tokens[0]``.
    Only that first special token is prepended; the others are not added.

    Args:
        embs_path: path to the embeddings file (read as UTF-8).
        limit: maximum number of vectors to read; ``<= 0`` means all of them.

    Returns:
        (embeddings, words), where ``embeddings`` has exactly one row per
        entry of ``words``. In debug mode the file is not opened. Instead, a
        zero matrix with ``200 + len(self.special_tokens)`` rows and 300
        columns is returned, together with the special tokens followed by 200
        placeholder words.
    """
    if self.debug:
        return np.zeros((200+len(self.special_tokens),300)), self.special_tokens + ["c"] * 200
    words = [self.special_tokens[0]]
    print(f"[INFO] Reading Embeddings from {embs_path}")
    with open(embs_path, encoding='utf-8', mode='r') as fin:
        n, d = map(int, fin.readline().split())
        # Clamp to the header count: a limit larger than the file used to
        # allocate trailing rows that no word ever filled.
        limit = n if limit <= 0 else min(limit, n)
        embeddings = np.zeros((limit + 1, d))  # +1 for the zero row at index 0
        for i, line in tqdm(enumerate(fin), total=limit):
            if i >= limit:
                break
            tokens = line.rstrip().split()
            words.append(tokens[0])
            embeddings[i + 1] = list(map(float, tokens[1:]))
    # Keep rows and words aligned even if the file holds fewer vectors than
    # its header declares.
    return embeddings[:len(words)], words
def load_file_clean(self, dtype, strip=False):
    """Read ``<base_path>/<dtype>/<dtype>.txt`` and return its preprocessed lines.

    The file is opened as UTF-8 with ``newline='\\n'``, so only ``\\n`` ends
    a line. Each line is passed through ``self.preprocess``.

    Args:
        dtype: dataset split name; used as both sub-directory and file stem.
        strip: when true, also remove diacritics from each preprocessed line
            with ``strip_tashkeel``.

    Returns:
        list with one processed string per line of the file.
    """
    source = os.path.join(self.base_path, dtype, dtype + ".txt")
    with open(source, 'r', encoding="utf-8", newline='\n') as handle:
        raw_lines = handle.readlines()
    cleaned = []
    for raw in raw_lines:
        text = self.preprocess(raw)
        cleaned.append(strip_tashkeel(text) if strip else text)
    return cleaned
def preprocess(self, line):
    """Tokenize ``line`` with pyarabic's ``tokenize`` and re-join the tokens with single spaces."""
    pieces = tokenize(line)
    return " ".join(piece for piece in pieces)
def pad_and_truncate_sequence(self, tokens, max_len, pad=None):
    """Return ``tokens`` right-padded or truncated to exactly ``max_len`` items.

    Args:
        tokens: list of token ids (list concatenation is used for padding).
        max_len: target length.
        pad: padding value; defaults to the index of ``'<pad>'`` in
            ``self.special_tokens``.

    Returns:
        A new list; the input list is not modified.
    """
    fill = self.special_tokens.index("<pad>") if pad is None else pad
    missing = max_len - len(tokens)
    if missing <= 0:
        return tokens[:max_len]
    return tokens + [fill] * missing
def stats(self, freq, percentile=90, name="stats"):
    """Print a one-row PrettyTable summarising the numbers in ``freq``.

    Columns: dataset label ``name``, mean, standard deviation, min, max, and
    the ``percentile``-th percentile.
    """
    values = np.array(sorted(freq))
    header = ["Dataset", "Mean", "Std", "Min", "Max", f"{percentile}th Percentile"]
    summary = [name, values.mean(), values.std(), values.min(), values.max(), np.percentile(values, percentile)]
    table = PrettyTable(header)
    table.add_row(summary)
    print(table)
def create_gt_mask(self, lines, prob, idx, seed=1111):
    """Write a random 0/1 mask file with one flag per character of every token.

    Each line is stripped and tokenized with ``tokenize``. Every token becomes
    a string of ``len(token)`` flags drawn i.i.d. from Bernoulli(``prob``).
    Tokens are separated by single spaces and lines by ``'\\n'``, with no
    trailing newline. The output path is
    ``<base_path>/test/test_gt_mask_<prob>_<idx>.txt``, the same name
    ``__init__`` reads when ``gt-signal-prob`` > 0.

    Args:
        lines: iterable of text lines.
        prob: probability that a flag is 1.
        idx: index embedded in the file name.
        seed: seed for the draws; the same seed reproduces the same mask.
    """
    # A private legacy RandomState seeded identically yields the same draws
    # that np.random.seed(seed) + np.random.binomial did, without resetting
    # the process-wide NumPy RNG for unrelated code.
    rng = np.random.RandomState(seed)
    gt_masks = []
    for line in lines:
        tokens = tokenize(line.strip())
        token_masks = [''.join(map(str, rng.binomial(1, prob, len(token)))) for token in tokens]
        gt_masks.append(' '.join(token_masks))
    subpath = f"test_gt_mask_{prob}_{idx}.txt"
    mask_path = os.path.join(self.base_path, "test", subpath)
    with open(mask_path, 'w') as fout:
        fout.write('\n'.join(gt_masks))
def create_gt_labels(self, lines):
    """Build a flat label list per line from its diacritized tokens.

    Each line is stripped and tokenized. For every token, the second value
    returned by ``du.create_label_for_word`` (applied to the token's
    character split) is appended. A single ``0`` is inserted between
    consecutive tokens; presumably 0 is the "no diacritic" label for the
    space — confirm.

    Returns:
        list with one label list per input line.
    """
    all_labels = []
    for line in lines:
        words = tokenize(line.strip())
        line_labels = []
        for position, word in enumerate(words):
            if position:
                line_labels.append(0)
            chars = self.split_word_on_characters_with_diacritics(word)
            _, flat_labels, _ = du.create_label_for_word(chars)
            line_labels.extend(flat_labels)
        all_labels.append(line_labels)
    return all_labels
def get_ce(self, diac_word_y, e_idx=None, return_idx=False):
    """Return the last entry before ``e_idx`` that is not ``[0, 0, 0]``.

    The scan runs backwards from ``e_idx - 1``; ``e_idx`` defaults to
    ``len(diac_word_y)``. If every scanned entry equals ``[0, 0, 0]``, the
    entry at ``e_idx - 1`` is returned instead. For ``e_idx == 0`` that is
    index ``-1``, i.e. the last element of the whole list.

    Entries are compared with ``!=`` against a Python list, so they must be
    lists; a NumPy row would make the truth test ambiguous.

    Returns:
        The selected entry, or its index when ``return_idx`` is true.
    """
    stop = len(diac_word_y) if e_idx is None else e_idx
    hit = next(
        (pos for pos in range(stop - 1, -1, -1) if diac_word_y[pos] != [0, 0, 0]),
        stop - 1,
    )
    if return_idx:
        return hit
    return diac_word_y[hit]
def create_decoder_input(self, diac_code_y, prob=0):
    """Build decoder input features from per-character diacritic targets.

    Each feature row is 8 wide: a 6-way one-hot (``np.eye(6)``) of a
    target's first field, followed by the target's remaining two fields.

    If ``self.markov_signal`` is false, every feature is zero. Otherwise,
    for each word:
      * position 0 holds the features of ``self.get_ce(word, c_idx)`` for the
        previous word, where ``c_idx`` is the index at which that word's
        character loop stopped. The first word instead gets one-hot index 5
        followed by ``[0, 0]``.
      * position ``c + 1`` holds the features of target ``c`` of the same
        word. The targets are thus shifted right by one, and the shift stops
        at the first target whose first field equals ``self.pad_target_val``.

    Args:
        diac_code_y: per-word, per-character targets. Assumed to be a
            rectangular nested list of shape (words, chars, 3), whose entries
            are Python lists ``[idx, a, b]`` with ``idx`` in 0..5 or
            ``self.pad_target_val`` — TODO confirm against callers.
            ``haraka + char[1:]`` relies on list concatenation.
        prob: not used by this method.

    Returns:
        list over the first axis of the zero-initialised feature array (one
        (chars, 8) array per word for 3-D input).
    """
    diac_code_x = np.zeros((*np.array(diac_code_y).shape[:-1], 8))
    if not self.markov_signal:
        return list(diac_code_x)
    # Seed for the first word's position 0: one-hot index 5, then two zeros.
    prev_ce = list(np.eye(6)[-1]) + [0,0]
    for w_idx, word in enumerate(diac_code_y):
        diac_code_x[w_idx, 0, :] = prev_ce
        # The last target is skipped: its shifted slot (c_idx+1) would be out of range.
        for c_idx, char in enumerate(word[:-1]):
            if char[0] == self.pad_target_val:
                break
            haraka = list(np.eye(6)[char[0]])
            diac_code_x[w_idx, c_idx+1, :] = haraka + char[1:]
        # NOTE(review): the indentation of the next two lines was reconstructed
        # from a whitespace-mangled copy and placed at word level (once per
        # word) — confirm against the original.
        # NOTE(review): if word[:-1] is empty, c_idx is undefined for the first
        # word (NameError) or stale from the previous word. If the first target
        # is padding, c_idx == 0 and get_ce returns word[-1]; np.eye(6)[...]
        # raises IndexError when that holds pad_target_val (-100).
        ce = self.get_ce(diac_code_y[w_idx], c_idx)
        prev_ce = list(np.eye(6)[ce[0]]) + ce[1:]
    return list(diac_code_x)