|
from typing import List |
|
|
|
import torch as T |
|
import numpy as np |
|
|
|
from pyarabic.araby import ( |
|
tokenize, |
|
strip_tashkeel, |
|
strip_tatweel, |
|
DIACRITICS |
|
) |
|
|
|
# Haraka codes consumed by shakkel_char(); any other value (e.g. 0) makes
# shakkel_char emit no haraka.
SEPARATE_DIACRITICS = dict(
    FATHA=1,
    KASRA=2,
    DAMMA=3,
    SUKUN=4,
)
|
|
|
# Flat diacritic class id -> (haraka, tanween, shadda) triple.
#   haraka : 0 = none; 1..4 are the SEPARATE_DIACRITICS codes
#            (FATHA, KASRA, DAMMA, SUKUN) that shakkel_char() renders.
#   tanween: 1 when the haraka is written in its tanween form.
#   shadda : 1 when a shadda is present.
# create_labels() maps a triple to an id with HARAKAT_MAP.index(), and
# flat_2_3head()/flat2_3head() decode ids by plain indexing, so the row
# order *is* the label vocabulary -- do not reorder rows.
HARAKAT_MAP = [
    (0,0,0),  # 0: no diacritic
    (1,0,0),  # 1: haraka 1
    (1,1,0),  # 2: haraka 1 + tanween
    (2,0,0),  # 3: haraka 2
    (2,1,0),  # 4: haraka 2 + tanween
    (3,0,0),  # 5: haraka 3
    (3,1,0),  # 6: haraka 3 + tanween
    (4,0,0),  # 7: sukun (never combined with tanween/shadda, see create_labels)
    (0,0,1),  # 8: shadda only
    (1,0,1),  # 9: shadda + haraka 1
    (1,1,1),  # 10: shadda + haraka 1 + tanween
    (2,0,1),  # 11: shadda + haraka 2
    (2,1,1),  # 12: shadda + haraka 2 + tanween
    (3,0,1),  # 13: shadda + haraka 3
    (3,1,1),  # 14: shadda + haraka 3 + tanween
    (0,0,0),  # 15: duplicate of row 0; HARAKAT_MAP[-1] (DIAC_PAD_IDX) therefore
              #     yields (0, 0, 0), while .index((0,0,0)) still returns 0.
]

# Id that diac_ids_of_line() inserts between words. Used as a list index it
# selects the last HARAKAT_MAP row, (0, 0, 0).
DIAC_PAD_IDX = -1
|
|
|
# NOTE(review): in the copy under review the non-ASCII literals of this block
# were mis-encoded (UTF-8 bytes decoded with a Thai code page), so the lists
# held the wrong characters. They are spelled below with code points and
# \u escapes so the source is pure ASCII and cannot be re-mangled.

SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']

# Arabic letters U+0621 (hamza) .. U+063A (ghain) and U+0641 (feh) .. U+064A
# (yeh). Tatweel (U+0640) is not part of the list.
LETTER_LIST = (
    SPECIAL_TOKENS
    + [chr(cp) for cp in range(0x0621, 0x063B)]
    + [chr(cp) for cp in range(0x0641, 0x064B)]
)

_FATHATAN, _DAMMATAN, _KASRATAN = "\u064B", "\u064C", "\u064D"
_FATHA, _DAMMA, _KASRA = "\u064E", "\u064F", "\u0650"
_SHADDA, _SUKUN = "\u0651", "\u0652"

# Index layout relied on by create_labels():
#   0 = ' ' (no diacritic); (1, 2), (3, 4), (5, 6) = (haraka, its tanween);
#   7 = sukun; 8 = shadda.
# NOTE(review): the haraka order (fatha, kasra, damma) was reconstructed so
# that create_labels' remap (indices 1/3/5 -> codes 1/2/3) agrees with
# SEPARATE_DIACRITICS / shakkel_char. Verify against the original file and
# any trained checkpoints before relying on these class ids.
DIACRITICS_SHORT = [' ', _FATHA, _FATHATAN, _KASRA, _KASRATAN,
                    _DAMMA, _DAMMATAN, _SUKUN, _SHADDA]

# Every single diacritic, followed by each haraka/tanween combined with shadda.
# NOTE(review): haraka-then-shadda order of the combined marks is assumed.
CLASSES_LIST = DIACRITICS_SHORT + [mark + _SHADDA for mark in DIACRITICS_SHORT[1:7]]

NUMBERS = list("0123456789")

DELIMITERS = ["\u060C", "\u061B", ",", ";", "\u00AB", "\u00BB", "{", "}", "(", ")",
              "[", "]", ".", "*", "-", ":", "?", "!", "\u061F"]
|
|
|
# Diacritics known to pyarabic but absent from DIACRITICS_SHORT. Sorted so the
# list has the same order in every process (iterating a set of str depends on
# hash randomization); the visible code only tests membership in it.
UNKNOWN_DIACRITICS = sorted(set(DIACRITICS) - set(DIACRITICS_SHORT))
|
|
|
def shakkel_char(diac: int, tanween: bool, shadda: bool) -> str:
    """Build the combining marks to place after a base letter.

    Args:
        diac: haraka code from SEPARATE_DIACRITICS (FATHA, KASRA, DAMMA,
            SUKUN). Any other value contributes no haraka.
        tanween: use the tanween form of fatha / kasra / damma; ignored for
            sukun.
        shadda: emit a shadda (U+0651) first, unless ``diac`` is SUKUN.

    Returns:
        Zero to two characters: the optional shadda followed by the haraka.
    """
    sukun_code = SEPARATE_DIACRITICS["SUKUN"]
    prefix = "\u0651" if (shadda and diac != sukun_code) else ""

    # (code, plain form, tanween form), checked in the same order as before.
    vowels = (
        (SEPARATE_DIACRITICS["FATHA"], "\u064E", "\u064B"),
        (SEPARATE_DIACRITICS["KASRA"], "\u0650", "\u064D"),
        (SEPARATE_DIACRITICS["DAMMA"], "\u064F", "\u064C"),
    )
    for code, plain, doubled in vowels:
        if diac == code:
            return prefix + (doubled if tanween else plain)

    if diac == sukun_code:
        return prefix + "\u0652"
    return prefix
|
|
|
def diac_ids_of_line(line: str):
    """Return the flat diacritic id of every character in ``line``.

    The line is split with pyarabic's ``tokenize``. Each word is grouped with
    ``split_word_on_characters_with_diacritics`` and labelled with
    ``create_label_for_word``. ``DIAC_PAD_IDX`` is inserted between
    consecutive words (not before the first or after the last).

    Returns:
        1-D integer ``np.ndarray``. A line without tokens yields an empty
        array that is still integer-typed, so it can be concatenated with or
        used like the non-empty results.
    """
    diacs = []
    for position, word in enumerate(tokenize(line)):
        if position:
            diacs.append(DIAC_PAD_IDX)  # word separator
        _char_ids, word_diacs, _diacs_h3 = create_label_for_word(
            split_word_on_characters_with_diacritics(word)
        )
        diacs.extend(word_diacs)
    # Pin the dtype: np.array([]) would otherwise default to float64.
    return np.array(diacs, dtype=int)
|
|
|
def strip_unknown_tashkeel(word: str, *, enabled: bool = False):
    """Optionally remove diacritics listed in UNKNOWN_DIACRITICS from ``word``.

    By default this returns ``word`` unchanged. Pass ``enabled=True`` to drop
    every character contained in UNKNOWN_DIACRITICS.

    NOTE(review): the default is kept as a no-op to preserve existing
    behaviour; confirm whether callers actually want filtering on.
    """
    if not enabled:
        return word
    return ''.join(c for c in word if c not in UNKNOWN_DIACRITICS)
|
|
|
def create_gt_labels(lines):
    """Return one ground-truth diacritic id array per line.

    Element ``i`` of the result is ``diac_ids_of_line(lines[i])``; the
    output is always a list in input order.
    """
    return [diac_ids_of_line(text) for text in lines]
|
|
|
def split_word_on_characters_with_diacritics(word: str):
    """Group ``word`` into base characters, each followed by its diacritics.

    A new group starts at every character that is not in DIACRITICS_SHORT.
    Characters that are in DIACRITICS_SHORT (including ' ') are appended to
    the current group. Diacritics appearing before the first base character
    form a group of their own.

    Returns:
        List[List[str]], e.g. ``"ab"`` -> ``[['a'], ['b']]``.
        An empty ``word`` yields ``[]``.
    """
    if not word:
        return []
    diacritic_set = set(DIACRITICS_SHORT)  # O(1) membership inside the loop
    groups = []
    for ch in word:
        if groups and ch in diacritic_set:
            groups[-1].append(ch)
        else:
            groups.append([ch])
    return groups
|
|
|
|
|
def load_lines(path: str, *, strip: bool):
    r"""Read ``path`` as UTF-8 and return its lines with spaces normalised.

    Every line goes through ``normalize_spaces``; when ``strip`` is true the
    result is additionally passed through pyarabic's ``strip_tashkeel``.
    The file is opened with ``newline='\n'``, so only '\n' ends a line.
    """
    with open(path, 'r', encoding="utf-8", newline='\n') as handle:
        lines = [normalize_spaces(raw) for raw in handle]
    if strip:
        lines = [strip_tashkeel(text) for text in lines]
    return lines
|
|
|
def normalize_spaces(line: str):
    """Return the pyarabic tokens of ``line`` joined by single spaces.

    Leading and trailing whitespace is stripped before tokenising.
    """
    tokens = tokenize(line.strip())
    return " ".join(tokens)
|
|
|
|
|
def char_type(char: str):
    """Return the LETTER_LIST index used as the input id of ``char``.

    Characters in LETTER_LIST map to their own index; digits map to
    '<num>', characters in DELIMITERS to '<punc>', anything else to '<unk>'.
    """
    if char in LETTER_LIST:
        key = char
    elif char in NUMBERS:
        key = '<num>'
    elif char in DELIMITERS:
        key = '<punc>'
    else:
        key = '<unk>'
    return LETTER_LIST.index(key)
|
|
|
def create_labels(char_w_diac: List[str]):
    """Label one character group from split_word_on_characters_with_diacritics.

    Args:
        char_w_diac: sequence whose first item is the base character and
            whose remaining items are the diacritics attached to it.

    Returns:
        ``(char_idx, diacritic_index, diac_h3)``: ``char_idx`` is
        ``char_type(base)``, ``diac_h3`` is the ``[haraka, tanween, shadda]``
        triple and ``diacritic_index`` is that triple's row in HARAKAT_MAP.

    Raises:
        AssertionError: if sukun ends up combined with tanween or shadda.
        ValueError: if the triple has no HARAKAT_MAP row.
        IndexError: if ``char_w_diac`` is empty.
    """
    # DIACRITICS_SHORT index -> haraka code (1, 3, 5 are the plain harakat,
    # 7 is sukun, 0 is ' ').
    remap_dict = {0: 0, 1: 1, 3: 2, 5: 3, 7: 4}
    shadda = DIACRITICS_SHORT[8]

    base = char_w_diac[0]
    char_idx = char_type(base)

    # Deduplicate keeping first-seen order so labels are reproducible across
    # runs (iteration order of a set of str depends on hash randomization).
    marks = list(dict.fromkeys(char_w_diac[1:]))
    if len(marks) > 2:
        # Malformed group: keep the first non-shadda mark plus the shadda, if
        # any, so a present shadda is never the one that gets dropped.
        others = [m for m in marks if m != shadda]
        marks = others[:1] + ([shadda] if shadda in marks else [])

    diac_h3 = [0, 0, 0]
    for mark in marks:
        if mark not in DIACRITICS_SHORT:
            continue
        idx = DIACRITICS_SHORT.index(mark)
        if idx in (2, 4, 6):  # tanween form of the haraka at idx - 1
            diac_h3[0] = remap_dict[idx - 1]
            diac_h3[1] = 1
        elif idx == 8:  # shadda
            diac_h3[2] = 1
        else:  # plain haraka, sukun or ' '
            diac_h3[0] = remap_dict[idx]

    # Haraka code 4 (sukun) cannot carry tanween or shadda.
    assert not (diac_h3[0] == 4 and (diac_h3[1] or diac_h3[2])), (base, marks)
    diacritic_index = HARAKAT_MAP.index(tuple(diac_h3))
    return char_idx, diacritic_index, diac_h3
|
|
|
def create_label_for_word(split_word: List[List[str]]):
    """Label every character group of one word.

    Args:
        split_word: groups as returned by
            split_word_on_characters_with_diacritics.

    Returns:
        Three parallel lists: character ids, flat diacritic ids (HARAKAT_MAP
        rows) and ``[haraka, tanween, shadda]`` triples.

    Raises:
        ValueError: if a group yields no character id.
    """
    word_char_indices = []
    word_diac_indices = []
    word_diac_indices_h3 = []
    for char_w_diac in split_word:
        char_idx, diac_idx, diac_h3 = create_labels(char_w_diac)
        if char_idx is None:
            # Defensive: char_type as written always falls back to '<unk>'.
            # Report the whole word in the error instead of printing it.
            raise ValueError(
                f"no character id for group {char_w_diac!r} in word {split_word!r}"
            )
        word_char_indices.append(char_idx)
        word_diac_indices.append(diac_idx)
        word_diac_indices_h3.append(diac_h3)
    return word_char_indices, word_diac_indices, word_diac_indices_h3
|
|
|
|
|
def flat_2_3head(output: T.Tensor):
    """Decode flat diacritic ids into separate haraka / tanween / shadda heads.

    Args:
        output: ids indexing HARAKAT_MAP, typically of shape ``[b, ts, tw]``.
            Float ids are truncated toward zero. Negative ids wrap like list
            indices, so DIAC_PAD_IDX (-1) selects the last HARAKAT_MAP row.

    Returns:
        ``(haraka, tanween, shadda)`` as nested Python lists of ints with the
        same nesting as ``output`` (``[b][ts][tw]`` for 3-D input).
    """
    ids = T.as_tensor(output).long()
    table = T.tensor(HARAKAT_MAP, dtype=T.long, device=ids.device)
    # One vectorised lookup instead of reading every element individually
    # (each scalar read is a separate, possibly device-synchronising access).
    heads = table[ids]  # shape: ids.shape + (3,)
    return heads[..., 0].tolist(), heads[..., 1].tolist(), heads[..., 2].tolist()
|
|
|
def flat2_3head(diac_idx):
    """Split a 1-D sequence of flat diacritic ids into three numpy arrays.

    Each id is used directly as a HARAKAT_MAP index (negative ids wrap like
    list indices).

    Returns:
        ``(haraka, tanween, shadda)`` numpy arrays with one entry per id.
    """
    triples = [HARAKAT_MAP[d] for d in diac_idx]
    return tuple(np.array([triple[k] for triple in triples]) for k in range(3))