# File size: 3,461 Bytes
# d7e4f1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from model.layers import *
# from layers import *

import pathlib
from pathlib import Path


# Hyper-parameters shared by the Arabic and English models. The layer shapes
# below are built from them, so they presumably must match the settings the
# checkpoints in weights/ were trained with.
CONTEXT_SIZE = 5      # number of previous character ids fed to the model
n_hidden = 100        # width of every hidden Linear/BatchNorm layer
n_embed = 10          # embedding size per character (input is CONTEXT_SIZE*n_embed)
ACTIVATION = 'relu'   # name passed to every Activation layer

# Index -> character tables for each model's output classes. Index 0 ('.') is
# the boundary token: generation seeds the context with 0 and stops once 0 is
# sampled.
# NOTE(review): en_itos has no 'x' — presumably absent from the training data.
# Do not "fix" it; the table must keep matching the checkpoint's output layout.
ar_itos = {0: '.', 1: 'ء', 2: 'آ', 3: 'أ', 4: 'ؤ', 5: 'إ', 6: 'ئ', 7: 'ا', 8: 'ب', 9: 'ة', 10: 'ت', 11: 'ث', 12: 'ج', 13: 'ح', 14: 'خ', 15: 'د', 16: 'ذ', 17: 'ر', 18: 'ز', 19: 'س', 20: 'ش', 21: 'ص', 22: 'ض', 23: 'ط', 24: 'ظ', 25: 'ع', 26: 'غ', 27: 'ف', 28: 'ق', 29: 'ك', 30: 'ل', 31: 'م', 32: 'ن', 33: 'ه', 34: 'و', 35: 'ى', 36: 'ي'}
en_itos = {0: '.', 1: '-', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: 'i', 11: 'j', 12: 'k', 13: 'l', 14: 'm', 15: 'n', 16: 'o', 17: 'p', 18: 'q', 19: 'r', 20: 's', 21: 't', 22: 'u', 23: 'v', 24: 'w', 25: 'y', 26: 'z'}

# Output vocabulary sizes, derived from the tables so the two cannot drift apart.
EN_VOCAB_SIZE = len(en_itos)
AR_VOCAB_SIZE = len(ar_itos)
def _build_layers(vocab_size):
    """Return a fresh layer stack whose final Linear has ``vocab_size`` outputs.

    Layout: (Linear -> BatchNorm -> Activation) three times, then a final
    Linear. The order matters: checkpoint tensors are assigned to layers by
    list position when the weights are loaded.
    """
    layers = []
    fan_in = CONTEXT_SIZE * n_embed  # flattened embeddings of the context window
    for _ in range(3):
        layers += [Linear(fan_in, n_hidden), BatchNorm(n_hidden), Activation(ACTIVATION)]
        fan_in = n_hidden
    layers.append(Linear(n_hidden, vocab_size))
    return layers


# Separate layer objects per language; the two models share only the architecture.
arabic_layers = _build_layers(AR_VOCAB_SIZE)
english_layers = _build_layers(EN_VOCAB_SIZE)



def _load_checkpoint(layers, params, bn_conf):
    """Copy a checkpoint's per-layer tensors into ``layers`` in place.

    ``params[0]`` is the embedding table and is not consumed here. Layer ``i``
    receives ``params[i + 1]`` via ``set_parameters``. Each BatchNorm layer,
    in order of appearance, also receives the next ``bn_conf`` entry via
    ``set_mean_std``.

    Raises:
        ValueError: if ``params`` has fewer than ``len(layers) + 1`` entries.
    """
    if len(params) < len(layers) + 1:
        raise ValueError(
            f"checkpoint has {len(params)} parameter entries, "
            f"expected at least {len(layers) + 1}"
        )
    bn_index = 0
    for i, layer in enumerate(layers):
        layer.set_parameters(params[i + 1])
        # Matched by exact class name (not isinstance), so subclasses are not
        # treated as BatchNorm.
        if type(layer).__name__ == "BatchNorm":
            layer.set_mean_std(bn_conf[bn_index])
            bn_index += 1


parent_path = Path(__file__).parent
# NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
arabic_dict = torch.load(parent_path / 'weights' / 'ar_dataset_weights.pt')
english_dict = torch.load(parent_path / 'weights' / 'en_dataset_weights.pt')

## Weights: index 0 is the embedding table, the rest are per-layer entries.
arabic_params = arabic_dict['params']
english_params = english_dict['params']

## Batch norm means and stds, one entry per BatchNorm layer in layer order.
arabic_bn_conf = arabic_dict['bn_conf']
english_bn_conf = english_dict['bn_conf']

# Embedding tables, indexed by character id in forward().
arabic_embedding = arabic_params[0]
english_embedding = english_params[0]

## Load weights into the layer objects.
_load_checkpoint(arabic_layers, arabic_params, arabic_bn_conf)
_load_checkpoint(english_layers, english_params, english_bn_conf)

def forward(x_batch, is_training, lang):
    """Run the character model for ``lang`` and return its raw output scores.

    Args:
        x_batch: integer tensor of character ids. Assumed shape is
            (batch, CONTEXT_SIZE) — TODO confirm. Each row's embeddings are
            looked up and flattened into one feature vector.
        is_training: passed unchanged to every layer call.
        lang: ``'ar'`` for the Arabic model, ``'en'`` for the English model.

    Returns:
        The last layer's output, one row per input row. generate_name applies
        softmax over dim 1 to it.

    Raises:
        ValueError: if ``lang`` is neither ``'ar'`` nor ``'en'``.
    """
    if lang == 'ar':
        embedding, layers = arabic_embedding, arabic_layers
    elif lang == 'en':
        embedding, layers = english_embedding, english_layers
    else:
        # Previously fell through and failed with UnboundLocalError.
        raise ValueError(f"unsupported language {lang!r}; expected 'ar' or 'en'")

    emb = embedding[x_batch]
    # Concatenate the embeddings of each row's context into one vector.
    x = emb.view(emb.shape[0], -1)
    for layer in layers:
        x = layer(x, is_training)
    return x


def generate_name(lang, max_length=None):
    """Sample one name, character by character, from the ``lang`` model.

    The context starts as CONTEXT_SIZE boundary ids (0). Each step samples the
    next id from the softmax of the model output and slides the window forward.
    Sampling stops when id 0 (the '.' boundary) is drawn.

    Args:
        lang: ``'ar'`` or ``'en'``.
        max_length: optional cap on the number of characters. When reached,
            the name is returned as sampled so far. ``None`` (the default)
            keeps the original unbounded behaviour.

    Returns:
        The sampled name, without the terminating '.'.

    Raises:
        ValueError: if ``lang`` is neither ``'ar'`` nor ``'en'``.
    """
    if lang not in ('ar', 'en'):
        raise ValueError(f"unsupported language {lang!r}; expected 'ar' or 'en'")
    itos = ar_itos if lang == 'ar' else en_itos

    context = [0] * CONTEXT_SIZE
    chars = []
    # Inference only: skip autograd bookkeeping for each sampling step.
    with torch.no_grad():
        while max_length is None or len(chars) < max_length:
            logits = forward(torch.tensor([context]), False, lang)
            probs = torch.softmax(logits, dim=1)
            next_ch = torch.multinomial(probs, num_samples=1, replacement=True).item()
            if next_ch == 0:
                break
            chars.append(itos[next_ch])
            context = context[1:] + [next_ch]

    return ''.join(chars)

def generate_names(n, lang):
    """Return a list of ``n`` names sampled independently with generate_name.

    A non-positive ``n`` yields an empty list.
    """
    return [generate_name(lang) for _ in range(n)]


if __name__ == '__main__':
    # No script behaviour: running this file directly only performs the
    # module-level weight loading above. Import it and call generate_name /
    # generate_names to sample names.
    pass