File size: 8,168 Bytes
d36d50b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb42b73
 
d36d50b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb42b73
d36d50b
 
bb42b73
d36d50b
bb42b73
d36d50b
 
 
 
 
 
 
bb42b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d36d50b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb42b73
 
 
 
 
 
 
 
 
 
 
 
d36d50b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb42b73
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
from typing import List

import torch as T
import numpy as np

from pyarabic.araby import (
    tokenize,
    strip_tashkeel,
    strip_tatweel,
    DIACRITICS
)

# Integer codes for the haraka slot (slot 0) of the 3-head label; produced by
# create_labels (through its remap_dict) and decoded by shakkel_char.
SEPARATE_DIACRITICS = {
    "FATHA": 1,
    "KASRA": 2,
    "DAMMA": 3,
    "SUKUN": 4
}

# Flat class id -> (haraka, tanween, shadda). create_labels returns the index
# of its 3-head tuple in this list; flat2_3head / flat_2_3head expand ids back
# into the three heads with it.
HARAKAT_MAP = [
    #^ (haraka, tanween, shadda)
    (0,0,0), #< No diacs on char
    (1,0,0),
    (1,1,0), #< Tanween on 2nd slot
    (2,0,0),
    (2,1,0),
    (3,0,0),
    (3,1,0),
    (4,0,0),
    (0,0,1), #< shadda on 3rd slot
    (1,0,1),
    (1,1,1),
    (2,0,1),
    (2,1,1),
    (3,0,1),
    (3,1,1),
    (0,0,0), #< Padding == -1 (also for spaces)
]

# Separator id inserted between words by diac_ids_of_line; used as a list
# index it selects the last HARAKAT_MAP entry, i.e. "no diacritics".
DIAC_PAD_IDX = -1

# Diacritic code points, written as escapes: they are combining marks, which
# are invisible in most editors and easily mangled by re-encoding.
_FATHA = "\u064e"
_FATHATAN = "\u064b"
_DAMMA = "\u064f"
_DAMMATAN = "\u064c"
_KASRA = "\u0650"
_KASRATAN = "\u064d"
_SUKUN = "\u0652"
_SHADDA = "\u0651"

SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']
# Special tokens followed by the 36 Arabic letters hamza (U+0621) .. ghain
# (U+063A) and feh (U+0641) .. yeh (U+064A); tatweel (U+0640) is excluded.
LETTER_LIST = (
    SPECIAL_TOKENS
    + [chr(cp) for cp in range(0x0621, 0x063B)]
    + [chr(cp) for cp in range(0x0641, 0x064B)]
)
# Flat diacritic classes: none, each haraka followed by its tanween form,
# sukun, shadda alone, then shadda combined with each haraka/tanween.
CLASSES_LIST = [
    ' ',
    _FATHA, _FATHATAN,
    _DAMMA, _DAMMATAN,
    _KASRA, _KASRATAN,
    _SUKUN,
    _SHADDA,
    _SHADDA + _FATHA, _SHADDA + _FATHATAN,
    _SHADDA + _DAMMA, _SHADDA + _DAMMATAN,
    _SHADDA + _KASRA, _SHADDA + _KASRATAN,
]
# Single marks recognised when splitting words. create_labels relies on this
# index layout: 1/3/5 are fatha/kasra/damma, 2/4/6 their tanween forms
# (index - 1 gives the base haraka), 7 is sukun and 8 is shadda.
DIACRITICS_SHORT = [' ', _FATHA, _FATHATAN, _KASRA, _KASRATAN, _DAMMA, _DAMMATAN, _SUKUN, _SHADDA]
NUMBERS = list("0123456789")
# Arabic comma (U+060C), Arabic semicolon (U+061B), guillemets (U+00AB/U+00BB),
# ASCII punctuation, and the Arabic question mark (U+061F).
DELIMITERS = ["\u060c", "\u061b", ",", ";", "\u00ab", "\u00bb", "{", "}", "(", ")", "[", "]", ".", "*", "-", ":", "?", "!", "\u061f"]

# pyarabic diacritics that DIACRITICS_SHORT does not cover (order is set-derived).
UNKNOWN_DIACRITICS = list(set(DIACRITICS) - set(DIACRITICS_SHORT))

def shakkel_char(diac: int, tanween: bool, shadda: bool) -> str:
    '''
    Build the diacritic marks for one character from its 3-head labels.

    diac: haraka code from SEPARATE_DIACRITICS (FATHA/KASRA/DAMMA/SUKUN);
          any other value contributes no haraka mark.
    tanween: when truthy, FATHA/KASRA/DAMMA are emitted in their tanween form.
    shadda: when truthy, a shadda (U+0651) is emitted first -- unless diac is
            SUKUN, in which case it is dropped.
    Returns: the marks to append after the base letter ("" when none apply).
    '''
    # Shadda goes before the haraka, and is never combined with sukun.
    marks = "\u0651" if shadda and diac != SEPARATE_DIACRITICS["SUKUN"] else ""

    # (code name, plain haraka, tanween form)
    for code_name, plain, nunated in (
        ("FATHA", "\u064E", "\u064B"),
        ("KASRA", "\u0650", "\u064D"),
        ("DAMMA", "\u064F", "\u064C"),
    ):
        if diac == SEPARATE_DIACRITICS[code_name]:
            return marks + (nunated if tanween else plain)

    if diac == SEPARATE_DIACRITICS["SUKUN"]:
        return marks + "\u0652"
    return marks

def diac_ids_of_line(line: str):
    '''
    Flat diacritic class ids for every character group of a line.

    Each token from pyarabic's tokenize() contributes one id per character
    group (see create_label_for_word). Consecutive tokens are separated by a
    single DIAC_PAD_IDX; there is no separator after the last token.
    Returns: 1-D numpy array (empty when the line has no tokens).
    '''
    flat = []
    for position, token in enumerate(tokenize(line)):
        groups = split_word_on_characters_with_diacritics(token)
        _chars, token_ids, _h3 = create_label_for_word(groups)
        if position:  # separator only *between* tokens
            flat.append(DIAC_PAD_IDX)
        flat.extend(token_ids)
    return np.array(flat)

def strip_unknown_tashkeel(word: str):
    '''
    Pass-through: return *word* unchanged.

    Stripping of diacritics outside DIACRITICS_SHORT is currently disabled.
    Re-enabling it means dropping every character of *word* that appears in
    UNKNOWN_DIACRITICS.
    '''
    #! FIXME! Stripping unknown tashkeel is disabled (see docstring).
    return word

def create_gt_labels(lines):
    '''
    Ground-truth flat diacritic ids for each line.

    lines: iterable of text lines.
    Returns: list with one diac_ids_of_line() array per input line.
    '''
    return [diac_ids_of_line(text) for text in lines]

def split_word_on_characters_with_diacritics(word: str):
    '''
    Group a word into per-character lists ``[base_char, *marks]``.

    Every character that is not in DIACRITICS_SHORT starts a new group. The
    characters that are in it (the short diacritics, plus ' ') are appended to
    the current group. Marks that precede the first base character form a
    group of their own.

    Returns: List[List[str]]; an empty word yields ``[]``.
    '''
    chars_w_diac = []
    i_start = 0
    for i_c, c in enumerate(word):
        #! FIXME! DIACRITICS_SHORT is missing a lot of less common diacritics ...
        #! which are then treated as letters during splitting.
        # Only close the current group when it is non-empty; this avoids
        # emitting (and later discarding) an empty leading group.
        if c not in DIACRITICS_SHORT and i_c > i_start:
            chars_w_diac.append(list(word[i_start:i_c]))
            i_start = i_c
    if i_start < len(word):
        chars_w_diac.append(list(word[i_start:]))
    return chars_w_diac


def load_lines(path: str, *, strip: bool):
    '''
    Read a UTF-8 text file and return its whitespace-normalized lines.

    Each line is passed through normalize_spaces(). When *strip* is true,
    the result is additionally passed through pyarabic's strip_tashkeel().
    Only '\\n' separates lines (newline='\\n'); a trailing '\\r' is removed
    by the strip() inside normalize_spaces().
    '''
    with open(path, 'r', encoding="utf-8", newline='\n') as handle:
        raw_lines = handle.readlines()
    cleaned = [normalize_spaces(raw) for raw in raw_lines]
    if strip:
        cleaned = [strip_tashkeel(text) for text in cleaned]
    return cleaned

def normalize_spaces(line: str):
    '''Strip *line* and rejoin its pyarabic tokens with single spaces.'''
    tokens = tokenize(line.strip())
    return ' '.join(tokens)


def char_type(char: str):
    '''
    Map a base character to its index in LETTER_LIST.

    Known letters map to their own index; digits map to '<num>', characters
    in DELIMITERS to '<punc>', and anything else to '<unk>'.
    '''
    if char in LETTER_LIST:
        return LETTER_LIST.index(char)
    if char in NUMBERS:
        category = '<num>'
    elif char in DELIMITERS:
        category = '<punc>'
    else:
        category = '<unk>'
    return LETTER_LIST.index(category)

def create_labels(char_w_diac: str):
    '''
    Map one ``[base_char, *marks]`` group to its label indices.

    char_w_diac: sequence whose first element is the base character and whose
        remaining elements are diacritic characters, as produced by
        split_word_on_characters_with_diacritics().

    Marks are de-duplicated in order of first appearance. If more than two
    distinct marks remain, only the first non-shadda mark is kept, plus the
    shadda if one is present. Of two harakat that survive, the later one wins
    the haraka slot.

    Returns: (char_idx, diacritic_index, diac_h3), where char_idx is
        char_type(base_char), diac_h3 is the [haraka, tanween, shadda] list
        and diacritic_index is the position of tuple(diac_h3) in HARAKAT_MAP.
    Raises: AssertionError if sukun is combined with tanween or shadda
        (ValueError from HARAKAT_MAP.index when assertions are disabled).
    '''
    # DIACRITICS_SHORT index of a plain haraka -> SEPARATE_DIACRITICS code.
    remap_dict = {0: 0, 1: 1, 3: 2, 5: 3, 7: 4}
    shadda = DIACRITICS_SHORT[8]

    char_idx = char_type(char_w_diac[0])

    # Ordered de-duplication. A ``set`` here made the kept marks, and which
    # haraka won, depend on string hash randomization (PYTHONHASHSEED).
    marks = list(dict.fromkeys(char_w_diac[1:]))
    if len(marks) > 2:
        kept = [m for m in marks if m != shadda][:1]
        if shadda in marks:
            kept.append(shadda)
        marks = kept

    diac_h3 = [0, 0, 0]
    for diac in marks:
        if diac not in DIACRITICS_SHORT:
            continue
        diac_idx = DIACRITICS_SHORT.index(diac)
        if diac_idx in (2, 4, 6):  # tanween: base haraka is at index - 1
            diac_h3[0] = remap_dict[diac_idx - 1]
            diac_h3[1] = 1
        elif diac_idx == 8:  # shadda
            diac_h3[2] = 1
        else:  # plain haraka or sukun
            diac_h3[0] = remap_dict[diac_idx]
    assert not (diac_h3[0] == 4 and (diac_h3[1] or diac_h3[2]))
    diacritic_index = HARAKAT_MAP.index(tuple(diac_h3))
    return char_idx, diacritic_index, diac_h3

def create_label_for_word(split_word: List[List[str]]):
    '''
    Label every character group of a word via create_labels().

    split_word: output of split_word_on_characters_with_diacritics().
    Returns: (char_indices, flat_diac_indices, diac_indices_h3), three lists
        with one entry per group.
    Raises: ValueError if create_labels() yields no character index.
    '''
    word_char_indices = []
    word_diac_indices = []
    word_diac_indices_h3 = []
    for char_w_diac in split_word:
        char_idx, diac_idx, diac_h3 = create_labels(char_w_diac)
        # Defensive: char_type() always falls back to '<unk>', so this only
        # fires if the labelling helpers are changed to return None.
        if char_idx is None:
            raise ValueError(
                f"no character index for group {char_w_diac!r} in word {split_word!r}"
            )
        word_char_indices.append(char_idx)
        word_diac_indices.append(diac_idx)
        word_diac_indices_h3.append(diac_h3)
    return word_char_indices, word_diac_indices, word_diac_indices_h3


def flat_2_3head(output: T.Tensor):
    '''
    Expand a 3-D batch of flat diacritic ids into haraka/tanween/shadda heads.

    output: [b tw tc] -- must be 3-D; it is indexed as
        output[batch, word, char], and every element is converted with int()
        and used as an index into HARAKAT_MAP.
    Returns: (haraka, tanween, shadda), each a nested Python list shaped
        [batch][word][char].
    '''
    # 0, 1,  2, 3,  4, 5,  6, 7, 8,  9,     10,  11,   12,  13,   14
    # 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD
    n_batch, n_words, n_chars = output.shape

    haraka, tanween, shadda = [], [], []
    for b_idx in range(n_batch):
        triples = [
            [HARAKAT_MAP[int(output[b_idx, w_idx, c_idx])] for c_idx in range(n_chars)]
            for w_idx in range(n_words)
        ]
        haraka.append([[t[0] for t in word] for word in triples])
        tanween.append([[t[1] for t in word] for word in triples])
        shadda.append([[t[2] for t in word] for word in triples])
    return haraka, tanween, shadda

def flat2_3head(diac_idx):
    '''
    Expand a 1-D sequence of flat diacritic ids into three numpy arrays.

    diac_idx: [tw] iterable of HARAKAT_MAP indices (-1 selects the padding
        entry).
    Returns: (haraka, tanween, shadda) as 1-D numpy arrays of equal length.
    '''
    # 0, 1,  2, 3,  4, 5,  6, 7, 8,  9,     10,  11,   12,  13,   14
    # 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD
    triples = [HARAKAT_MAP[d] for d in diac_idx]
    haraka = np.array([t[0] for t in triples])
    tanween = np.array([t[1] for t in triples])
    shadda = np.array([t[2] for t in triples])
    return haraka, tanween, shadda