# Committed by bkhmsi in d36d50b ("initialized repo"); original file size 7.13 kB.
from typing import List
import torch as T
import numpy as np
from pyarabic.araby import (
tokenize,
strip_tashkeel,
strip_tatweel,
DIACRITICS
)
# Haraka ids used by shakkel_char and in the first slot of the HARAKAT_MAP
# triples (where 0 stands for "no haraka").
SEPARATE_DIACRITICS = dict(
    FATHA=1,
    KASRA=2,
    DAMMA=3,
    SUKUN=4,
)
# Flat diacritic class id -> (haraka, tanween, shadda).
#   haraka:  0 none, 1 fatha, 2 kasra, 3 damma, 4 sukun
#   tanween: 1 when the haraka is in its tanween form
#   shadda:  1 when the letter also carries a shadda
# The final row is what index -1 selects: padding (also used for spaces).
HARAKAT_MAP = [
    (0, 0, 0),  # 0: no diacritic on the char
    (1, 0, 0),  # 1: fatha
    (1, 1, 0),  # 2: fathatan
    (2, 0, 0),  # 3: kasra
    (2, 1, 0),  # 4: kasratan
    (3, 0, 0),  # 5: damma
    (3, 1, 0),  # 6: dammatan
    (4, 0, 0),  # 7: sukun
    (0, 0, 1),  # 8: shadda alone
    (1, 0, 1),  # 9: shadda + fatha
    (1, 1, 1),  # 10: shadda + fathatan
    (2, 0, 1),  # 11: shadda + kasra
    (2, 1, 1),  # 12: shadda + kasratan
    (3, 0, 1),  # 13: shadda + damma
    (3, 1, 1),  # 14: shadda + dammatan
    (0, 0, 0),  # 15 (== -1): padding, also for spaces
]
SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']

# The Arabic literals below are written as code points / escapes so the file is
# pure ASCII; a previous copy had every Arabic literal mis-decoded (UTF-8 read
# as a Thai code page), which silently corrupted these tables.

# Arabic letters U+0621 (hamza) .. U+063A (ghain), then U+0641 (feh) .. U+064A (yeh).
# U+063B..U+0640 (extension letters and tatweel) are not included.
LETTER_LIST = (
    SPECIAL_TOKENS
    + [chr(cp) for cp in range(0x0621, 0x063B)]
    + [chr(cp) for cp in range(0x0641, 0x064B)]
)

# Output classes: none, the six harakat/tanween forms, sukun, shadda, and
# shadda (stored first) combined with each haraka/tanween form.
# NOTE(review): restored from the mis-encoded copy; the damma/kasra/kasratan
# slots (3, 5, 6, 11, 13, 14) follow the haraka-then-tanween pattern of the
# unambiguous entries -- confirm against the upstream file.
CLASSES_LIST = [
    ' ',
    '\u064e',        # fatha
    '\u064b',        # fathatan
    '\u064f',        # damma
    '\u064c',        # dammatan
    '\u0650',        # kasra
    '\u064d',        # kasratan
    '\u0652',        # sukun
    '\u0651',        # shadda
    '\u0651\u064e',  # shadda + fatha
    '\u0651\u064b',  # shadda + fathatan
    '\u0651\u064f',  # shadda + damma
    '\u0651\u064c',  # shadda + dammatan
    '\u0651\u0650',  # shadda + kasra
    '\u0651\u064d',  # shadda + kasratan
]

# Index layout is relied on by create_labels: tanween at 2/4/6 (each right
# after its haraka at 1/3/5), sukun at 7, shadda at 8.
DIACRITICS_SHORT = [
    ' ',
    '\u064e',  # 1 fatha
    '\u064b',  # 2 fathatan
    '\u0650',  # 3 kasra
    '\u064d',  # 4 kasratan
    '\u064f',  # 5 damma
    '\u064c',  # 6 dammatan
    '\u0652',  # 7 sukun
    '\u0651',  # 8 shadda
]

NUMBERS = list("0123456789")

DELIMITERS = [
    "\u060c",  # Arabic comma
    "\u061b",  # Arabic semicolon
    ",", ";",
    "\u00ab", "\u00bb",  # guillemets
    "{", "}", "(", ")", "[", "]", ".", "*", "-", ":", "?", "!",
    "\u061f",  # Arabic question mark
]
# pyarabic diacritics that DIACRITICS_SHORT does not cover. Sorted so the list
# order is the same on every run (iterating a set of str depends on hash
# randomisation).
UNKNOWN_DIACRITICS = sorted(set(DIACRITICS).difference(DIACRITICS_SHORT))
def shakkel_char(diac: int, tanween: bool, shadda: bool) -> str:
    """Render one (haraka, tanween, shadda) triple as Arabic diacritic marks.

    Args:
        diac: haraka id from SEPARATE_DIACRITICS; any other value (e.g. 0)
            produces no haraka mark.
        tanween: when truthy, fatha/kasra/damma use their tanween form.
            Sukun and "no haraka" ignore it.
        shadda: when truthy, a shadda is emitted first -- except with sukun.

    Returns:
        The shadda (if any) followed by the haraka mark; "" if neither applies.
    """
    is_sukun = diac == SEPARATE_DIACRITICS["SUKUN"]
    prefix = "\u0651" if shadda and not is_sukun else ""

    if diac == SEPARATE_DIACRITICS["FATHA"]:
        mark = "\u064B" if tanween else "\u064E"
    elif diac == SEPARATE_DIACRITICS["KASRA"]:
        mark = "\u064D" if tanween else "\u0650"
    elif diac == SEPARATE_DIACRITICS["DAMMA"]:
        mark = "\u064C" if tanween else "\u064F"
    elif is_sukun:
        mark = "\u0652"
    else:
        mark = ""

    return prefix + mark
def diac_ids_of_line(line: str):
    """Return the flat diacritic class id of every character group in ``line``.

    The line is split with pyarabic's ``tokenize``. The ids of consecutive
    words are separated by -1, the padding index of HARAKAT_MAP.

    Returns:
        1-D integer ``np.ndarray``. It is empty, but still integer-typed, when
        ``line`` yields no tokens.
    """
    diacs = []
    for word_idx, word in enumerate(tokenize(line)):
        if word_idx:
            diacs.append(-1)  # separator between words
        word_chars = split_word_on_characters_with_diacritics(word)
        _, word_diacs, _ = create_label_for_word(word_chars)
        diacs.extend(word_diacs)
    # dtype=int is NumPy's default for lists of Python ints, so non-empty
    # results are unchanged; it also stops an empty line becoming float64.
    return np.array(diacs, dtype=int)
def strip_unknown_tashkeel(word: str, *, enabled: bool = False):
    """Optionally remove diacritics listed in UNKNOWN_DIACRITICS from ``word``.

    Stripping is disabled by default, so ``word`` is returned unchanged.
    Pass ``enabled=True`` to drop every character found in UNKNOWN_DIACRITICS.
    """
    #! FIXME! warnings.warn("Stripping unknown tashkeel is disabled.")
    if not enabled:
        return word
    return ''.join(c for c in word if c not in UNKNOWN_DIACRITICS)
def split_word_on_characters_with_diacritics(word: str):
    '''
    Group a word into characters, each followed by its diacritics.

    Every character that is not in DIACRITICS_SHORT starts a new group. Every
    character that is in DIACRITICS_SHORT is appended to the current group.
    Diacritics that precede the first base character form a group of their own.

    Returns: List[List[char: "letter or diacritic"]]; [] for an empty word.
    '''
    #! FIXME! DIACRITICS_SHORT is missing a lot of less common diacritics ...
    #! which are then treated as letters during splitting.
    groups: List[List[str]] = []
    for ch in word:
        if groups and ch in DIACRITICS_SHORT:
            groups[-1].append(ch)
        else:
            groups.append([ch])
    return groups
def char_type(char: str):
    """Map a character to its index in LETTER_LIST.

    Known letters (and special tokens) map to their own index. Digits in
    NUMBERS map to '<num>' and characters in DELIMITERS map to '<punc>'.
    Anything else maps to '<unk>'.
    """
    if char in LETTER_LIST:
        return LETTER_LIST.index(char)
    for group, token in ((NUMBERS, '<num>'), (DELIMITERS, '<punc>')):
        if char in group:
            return LETTER_LIST.index(token)
    return LETTER_LIST.index('<unk>')
def create_labels(char_w_diac: List[str]):
    """Encode one character group into its character id and diacritic labels.

    Args:
        char_w_diac: one group as produced by
            split_word_on_characters_with_diacritics -- the base character
            followed by its diacritics. A plain string with the same layout
            also works.

    Returns:
        ``(char_idx, diacritic_index, diac_h3)``:
        - ``char_idx`` is char_type of the first element;
        - ``diac_h3`` is ``[haraka, tanween, shadda]``;
        - ``diacritic_index`` is the position of ``diac_h3`` in HARAKAT_MAP.

    Raises:
        IndexError: ``char_w_diac`` is empty.
        AssertionError: sukun combined with tanween or shadda. Under ``-O`` the
            HARAKAT_MAP lookup raises ValueError instead.
    """
    # DIACRITICS_SHORT index of a haraka/sukun -> haraka id
    # (0 none, 1 fatha, 2 kasra, 3 damma, 4 sukun).
    remap_dict = {0: 0, 1: 1, 3: 2, 5: 3, 7: 4}
    shadda_mark = DIACRITICS_SHORT[8]

    # Distinct marks in order of first appearance. The previous set() made the
    # outcome depend on str hash randomisation whenever marks had to be
    # dropped or two harakat conflicted.
    marks = list(dict.fromkeys(char_w_diac[1:]))
    if len(marks) > 2:
        # Too many marks: keep the first non-shadda mark, plus the shadda if
        # present (positional slicing could drop the shadda).
        others = [m for m in marks if m != shadda_mark]
        marks = ([shadda_mark] if shadda_mark in marks else []) + others[:1]

    char_idx = char_type(char_w_diac[0])

    diac_h3 = [0, 0, 0]
    for diac in marks:
        if diac not in DIACRITICS_SHORT:
            continue  # unknown mark: contributes nothing
        diac_idx = DIACRITICS_SHORT.index(diac)
        if diac_idx in (2, 4, 6):  # tanween: its haraka sits one index lower
            diac_h3[0] = remap_dict[diac_idx - 1]
            diac_h3[1] = 1
        elif diac_idx == 8:  # shadda
            diac_h3[2] = 1
        else:  # haraka or sukun
            diac_h3[0] = remap_dict[diac_idx]

    assert not (diac_h3[0] == 4 and (diac_h3[1] or diac_h3[2])), \
        f"sukun cannot combine with tanween or shadda: {char_w_diac!r}"
    diacritic_index = HARAKAT_MAP.index(tuple(diac_h3))
    return char_idx, diacritic_index, diac_h3
def create_label_for_word(split_word: List[List[str]]):
    """Label every character group of a word with create_labels.

    Args:
        split_word: output of split_word_on_characters_with_diacritics.

    Returns:
        Three parallel lists with one entry per group: character ids, flat
        diacritic ids (HARAKAT_MAP indices), and ``[haraka, tanween, shadda]``
        triples.

    Raises:
        ValueError: defensive -- a group produced no character id.
    """
    word_char_indices = []
    word_diac_indices = []
    word_diac_indices_h3 = []
    for char_w_diac in split_word:
        char_idx, diac_idx, diac_h3 = create_labels(char_w_diac)
        if char_idx is None:
            raise ValueError(
                f"no character id for {char_w_diac!r} in word {split_word!r}"
            )
        word_char_indices.append(char_idx)
        word_diac_indices.append(diac_idx)
        word_diac_indices_h3.append(diac_h3)
    return word_char_indices, word_diac_indices, word_diac_indices_h3
def flat_2_3head(output: T.Tensor):
    '''
    Split a batch of flat diacritic ids into haraka / tanween / shadda lists.

    output: 3-D [b, ts, tw] ids (batch, words, characters per word). Every
        value must be a valid HARAKAT_MAP index after int(); -1 selects the
        padding row. Anything with ``.shape`` and ``.tolist()`` works
        (torch tensors, numpy arrays).

    Returns: (haraka, tanween, shadda), each a nested Python list with the
        same [b][ts][tw] layout as ``output``.
    Raises: ValueError if ``output`` is not 3-dimensional.
    '''
    # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
    # 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD
    if len(output.shape) != 3:
        raise ValueError(f"expected a 3-D input, got shape {tuple(output.shape)}")

    haraka, tanween, shadda = [], [], []
    # Convert once instead of indexing element by element; for a torch tensor
    # each output[b, w, c] builds and converts a separate 0-d tensor.
    for sentence in output.tolist():
        h_s, t_s, s_s = [], [], []
        for word in sentence:
            triples = [HARAKAT_MAP[int(c)] for c in word]
            h_s.append([trip[0] for trip in triples])
            t_s.append([trip[1] for trip in triples])
            s_s.append([trip[2] for trip in triples])
        haraka.append(h_s)
        tanween.append(t_s)
        shadda.append(s_s)
    return haraka, tanween, shadda
def flat2_3head(diac_idx):
    '''
    Split a 1-D sequence of flat diacritic ids into three numpy arrays.

    diac_idx: [tw] ids, each a valid HARAKAT_MAP index (-1 = padding).
    Returns: (haraka, tanween, shadda) numpy arrays of length tw.
    '''
    # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
    # 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD
    triples = [HARAKAT_MAP[d] for d in diac_idx]
    # Transpose the triples into three columns; keep three empty columns when
    # there is no input so the unpacking below still works.
    columns = list(zip(*triples)) or [(), (), ()]
    haraka, tanween, shadda = (np.array(col) for col in columns)
    return haraka, tanween, shadda