File size: 2,464 Bytes
d36d50b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7c4b94
d36d50b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os

from pyarabic.araby import tokenize, strip_tashkeel

import numpy as np
import torch as T
from torch.utils.data import Dataset

from data_utils import DatasetUtils
import diac_utils as du

class DataRetriever(Dataset):
    """Map-style dataset that converts raw text lines into index sequences.

    Each item is built lazily from one entry of ``lines``: the line is
    tokenized with ``pyarabic.araby.tokenize``, every token is labelled via
    the ``diac_utils`` helpers, and the per-word results are passed through
    ``DatasetUtils.pad_and_truncate_sequence`` using the limits stored on
    ``data_utils`` (``max_word_len`` per word, ``max_sent_len`` per line).
    """

    # Length of the innermost padding vector used for padded ``diac_x`` rows.
    # NOTE(review): presumably the per-character feature width produced by
    # ``DatasetUtils.create_decoder_input`` -- confirm against that helper.
    _DIAC_X_PAD_WIDTH = 8

    def __init__(self, data_utils: "DatasetUtils", lines: list):
        """Store the helper object and the raw lines.

        Args:
            data_utils: Provides ``w2idx``, ``max_word_len``, ``max_sent_len``,
                ``pad_val``, ``pad_target_val``, ``pad_and_truncate_sequence``
                and ``create_decoder_input``.
            lines: Sequence of text lines; one dataset item per entry.
        """
        # The previous ``super(DataRetriever).__init__()`` only re-initialised
        # a throwaway unbound ``super`` object, so the next ``__init__`` in the
        # MRO was never called.
        super().__init__()

        self.data_utils = data_utils
        self.lines = lines

    def preprocess(self, data, dtype=T.long):
        """Return a list with one tensor of ``dtype`` per element of ``data``.

        Each element is converted with ``np.array`` first and then wrapped in
        ``torch.tensor``.
        """
        return [T.tensor(np.array(x), dtype=dtype) for x in data]

    def __len__(self):
        """Return the number of stored lines."""
        return len(self.lines)

    def __getitem__(self, idx):
        """Build the sample for line ``idx``.

        Returns:
            A 3-tuple ``(inputs, diac_y, [0])`` where ``inputs`` is the list
            ``[word_x, char_x, diac_x]`` of long tensors, ``diac_y`` is a long
            tensor of targets, and the trailing ``[0]`` is a constant list
            (its consumer is not visible in this file).
        """
        word_x, char_x, diac_x, diac_y = self.create_sentence(idx)
        return self.preprocess((word_x, char_x, diac_x)), T.tensor(diac_y, dtype=T.long), [0]

    def create_sentence(self, idx):
        """Tokenize line ``idx`` and build padded input / target sequences.

        Returns:
            ``(word_x, char_x, diac_x, diac_y)`` as nested Python lists:

            * ``word_x`` -- one ``w2idx`` id per token; tokens missing from
              ``w2idx`` (after ``strip_tashkeel``) map to ``w2idx["<pad>"]``.
            * ``char_x`` -- per-token character sequence from
              ``du.create_label_for_word``, padded with ``pad_val``.
            * ``diac_x`` -- output of ``create_decoder_input`` applied to the
              per-token 3-head labels.
            * ``diac_y`` -- per-token labels padded with ``pad_target_val``.
        """
        utils = self.data_utils
        max_slen = utils.max_sent_len
        max_wlen = utils.max_word_len
        p_val = utils.pad_val
        pt_val = utils.pad_target_val

        line = self.lines[idx]
        tokens = tokenize(line.strip())

        word_x = []
        char_x = []
        diac_y = []
        diac_y_3head = []

        for word in tokens:
            word = du.strip_unknown_tashkeel(word)
            word_chars = du.split_word_on_characters_with_diacritics(word)
            cx, cy, cy_3head = du.create_label_for_word(word_chars)

            # Vocabulary lookup is done on the undiacritized form; the
            # "<pad>" id doubles as the fallback for unknown words.
            word_strip = strip_tashkeel(word)
            w2idx = utils.w2idx
            word_x.append(w2idx[word_strip] if word_strip in w2idx else w2idx["<pad>"])

            char_x.append(utils.pad_and_truncate_sequence(cx, max_wlen))
            diac_y.append(utils.pad_and_truncate_sequence(cy, max_wlen, pad=pt_val))
            diac_y_3head.append(
                utils.pad_and_truncate_sequence(cy_3head, max_wlen, pad=[pt_val] * 3)
            )

        diac_x = utils.create_decoder_input(diac_y_3head)

        # Sentence-level pass: each pad element mirrors the shape of one
        # word-level row of the corresponding sequence.
        word_x = utils.pad_and_truncate_sequence(word_x, max_slen)
        char_x = utils.pad_and_truncate_sequence(char_x, max_slen, pad=[p_val] * max_wlen)
        diac_x = utils.pad_and_truncate_sequence(
            diac_x, max_slen, pad=[[p_val] * self._DIAC_X_PAD_WIDTH] * max_wlen
        )
        diac_y = utils.pad_and_truncate_sequence(diac_y, max_slen, pad=[pt_val] * max_wlen)

        return word_x, char_x, diac_x, diac_y