# NOTE(review): the lines above this module originally contained web-snapshot
# residue (a "Spaces: Runtime error" page header, file size, commit d7e4f1f,
# and a line-number gutter) that was not valid Python; removed so the file parses.
import torch
from model.layers import *
# from layers import *
import pathlib
from pathlib import Path
CONTEXT_SIZE = 5
n_hidden = 100
n_embed = 10
EN_VOCAB_SIZE = 27
AR_VOCAB_SIZE = 37
ACTIVATION = 'relu'
ar_itos = {0: '.', 1: 'ء', 2: 'آ', 3: 'أ', 4: 'ؤ', 5: 'إ', 6: 'ئ', 7: 'ا', 8: 'ب', 9: 'ة', 10: 'ت', 11: 'ث', 12: 'ج', 13: 'ح', 14: 'خ', 15: 'د', 16: 'ذ', 17: 'ر', 18: 'ز', 19: 'س', 20: 'ش', 21: 'ص', 22: 'ض', 23: 'ط', 24: 'ظ', 25: 'ع', 26: 'غ', 27: 'ف', 28: 'ق', 29: 'ك', 30: 'ل', 31: 'م', 32: 'ن', 33: 'ه', 34: 'و', 35: 'ى', 36: 'ي'}
en_itos= {0: '.', 1: '-', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: 'i', 11: 'j', 12: 'k', 13: 'l', 14: 'm', 15: 'n', 16: 'o', 17: 'p', 18: 'q', 19: 'r', 20: 's', 21: 't', 22: 'u', 23: 'v', 24: 'w', 25: 'y', 26: 'z'}
# Arabic model: three (Linear -> BatchNorm -> Activation) hidden blocks,
# then a linear projection to the Arabic vocabulary logits.
arabic_layers = []
_ar_fan_in = CONTEXT_SIZE * n_embed
for _ in range(3):
    arabic_layers += [Linear(_ar_fan_in, n_hidden), BatchNorm(n_hidden), Activation(ACTIVATION)]
    _ar_fan_in = n_hidden
arabic_layers.append(Linear(n_hidden, AR_VOCAB_SIZE))
# English model: identical topology to the Arabic one, except the final
# projection targets the English vocabulary size.
english_layers = []
_en_fan_in = CONTEXT_SIZE * n_embed
for _ in range(3):
    english_layers += [Linear(_en_fan_in, n_hidden), BatchNorm(n_hidden), Activation(ACTIVATION)]
    _en_fan_in = n_hidden
english_layers.append(Linear(n_hidden, EN_VOCAB_SIZE))
# Locate and deserialize the trained checkpoints relative to this file.
# NOTE(review): torch.load uses pickle under the hood — only load trusted
# checkpoint files shipped with the repo.
parent_path = Path(__file__).parent
arabic_dict = torch.load(parent_path / 'weights' / 'ar_dataset_weights.pt')
english_dict = torch.load(parent_path / 'weights' / 'en_dataset_weights.pt')
# Parameter lists: entry 0 is the embedding table, the rest are layer params.
arabic_params = arabic_dict['params']
english_params = english_dict['params']
# Saved BatchNorm running statistics, one entry per BatchNorm layer.
arabic_bn_conf = arabic_dict['bn_conf']
english_bn_conf = english_dict['bn_conf']
# Character embedding tables (presumably vocab_size x n_embed — TODO confirm).
arabic_embedding = arabic_params[0]
english_embedding = english_params[0]
## Restore trained parameters into both layer stacks.
def _restore(layers, params, bn_conf):
    """Copy saved params into each layer; BatchNorm layers also get their
    saved running mean/std. params[0] is the embedding, so layer i uses
    params[i + 1]."""
    bn_idx = 0
    for offset, layer in enumerate(layers):
        layer.set_parameters(params[offset + 1])
        if type(layer).__name__ == "BatchNorm":
            layer.set_mean_std(bn_conf[bn_idx])
            bn_idx += 1
_restore(arabic_layers, arabic_params, arabic_bn_conf)
_restore(english_layers, english_params, english_bn_conf)
def forward(x_batch, is_training, lang):
    """Run a batch of context indices through the language-specific MLP.

    Args:
        x_batch: integer tensor of shape (batch, CONTEXT_SIZE) of character
            indices (presumably LongTensor — confirm with callers).
        is_training: forwarded to every layer; controls BatchNorm behaviour.
        lang: 'ar' for the Arabic model, 'en' for the English one.

    Returns:
        Logits tensor of shape (batch, vocab_size) for the chosen language.

    Raises:
        ValueError: if lang is neither 'ar' nor 'en'. (Previously an unknown
            lang surfaced as a confusing UnboundLocalError on `embedding`.)
    """
    if lang == 'ar':
        embedding, layers = arabic_embedding, arabic_layers
    elif lang == 'en':
        embedding, layers = english_embedding, english_layers
    else:
        raise ValueError(f"unknown lang {lang!r}; expected 'ar' or 'en'")
    x_batch = embedding[x_batch]            # (batch, CONTEXT_SIZE, n_embed)
    x = x_batch.view(x_batch.shape[0], -1)  # flatten the context window
    for layer in layers:
        x = layer(x, is_training)
    return x
def generate_name(lang):
    """Sample a single name from the model for the given language.

    Starts from a context of CONTEXT_SIZE end-tokens (index 0), repeatedly
    samples the next character from the softmaxed logits, and stops when the
    end token (0) is drawn. The terminator is not included in the result.

    Args:
        lang: 'ar' or 'en' (validated inside forward()).

    Returns:
        The generated name as a str (may be empty if 0 is drawn immediately).
    """
    # Fix: the original round-tripped the context through
    # tensor -> clone().detach().squeeze() -> tolist() every iteration;
    # keeping it as a plain list of ints is equivalent and simpler.
    itos = ar_itos if lang == 'ar' else en_itos
    context = [0] * CONTEXT_SIZE
    chars = []
    while True:
        logits = forward(torch.tensor(context).unsqueeze(0), False, lang)
        probs = torch.softmax(logits, dim=1)
        next_ch = torch.multinomial(probs, num_samples=1, replacement=True).item()
        if next_ch == 0:
            # End token sampled — the original appended '.' then stripped it
            # with w[:-1]; breaking before appending gives the same result.
            break
        chars.append(itos[next_ch])
        context = context[1:] + [next_ch]  # slide the context window
    return ''.join(chars)
def generate_names(n, lang):
    """Sample n names for the given language ('ar' or 'en')."""
    return [generate_name(lang) for _ in range(n)]
if __name__ == '__main__':
    # Fix: removed a stray trailing '|' (web-snapshot residue) that made this
    # line a syntax error. No CLI demo yet; the module is meant to be imported
    # for forward() / generate_names().
    pass