import torch
from model.layers import Linear, BatchNorm, Activation
# from layers import Linear, BatchNorm, Activation
from pathlib import Path
CONTEXT_SIZE = 5      # characters of context fed to the model
n_hidden = 100        # units per hidden layer
n_embed = 10          # embedding dimension per character
EN_VOCAB_SIZE = 27    # '.', '-', and a-z minus 'x' (see en_itos below)
AR_VOCAB_SIZE = 37    # '.' plus the Arabic letters in ar_itos below
ACTIVATION = 'relu'
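# Index-to-character decoders for sampling; index 0 is the '.' token that pads
# the initial context and marks end-of-name. Note the English map has no 'x',
# which is why EN_VOCAB_SIZE is 27.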
ar_itos = {0: '.', 1: 'ء', 2: 'آ', 3: 'أ', 4: 'ؤ', 5: 'إ', 6: 'ئ', 7: 'ا', 8: 'ب', 9: 'ة', 10: 'ت', 11: 'ث', 12: 'ج', 13: 'ح', 14: 'خ', 15: 'د', 16: 'ذ', 17: 'ر', 18: 'ز', 19: 'س', 20: 'ش', 21: 'ص', 22: 'ض', 23: 'ط', 24: 'ظ', 25: 'ع', 26: 'غ', 27: 'ف', 28: 'ق', 29: 'ك', 30: 'ل', 31: 'م', 32: 'ن', 33: 'ه', 34: 'و', 35: 'ى', 36: 'ي'}
en_itos = {0: '.', 1: '-', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: 'i', 11: 'j', 12: 'k', 13: 'l', 14: 'm', 15: 'n', 16: 'o', 17: 'p', 18: 'q', 19: 'r', 20: 's', 21: 't', 22: 'u', 23: 'v', 24: 'w', 25: 'y', 26: 'z'}
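# Each language gets an identical MLP head over the flattened context
# embedding: three Linear -> BatchNorm -> Activation blocks, then a final
# Linear projecting to that language's vocabulary logits.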
arabic_layers = [
    Linear(CONTEXT_SIZE * n_embed, n_hidden), BatchNorm(n_hidden), Activation(ACTIVATION),
    Linear(n_hidden, n_hidden), BatchNorm(n_hidden), Activation(ACTIVATION),
    Linear(n_hidden, n_hidden), BatchNorm(n_hidden), Activation(ACTIVATION),
    Linear(n_hidden, AR_VOCAB_SIZE),
]
english_layers = [
    Linear(CONTEXT_SIZE * n_embed, n_hidden), BatchNorm(n_hidden), Activation(ACTIVATION),
    Linear(n_hidden, n_hidden), BatchNorm(n_hidden), Activation(ACTIVATION),
    Linear(n_hidden, n_hidden), BatchNorm(n_hidden), Activation(ACTIVATION),
    Linear(n_hidden, EN_VOCAB_SIZE),
]
parent_path = Path(__file__).parent
arabic_dict = torch.load(parent_path / 'weights/ar_dataset_weights.pt')
english_dict = torch.load(parent_path / 'weights/en_dataset_weights.pt')
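# Each checkpoint appears to be a dict with 'params' (the embedding table
# first, then one entry per layer) and 'bn_conf' (running mean/std for each
# BatchNorm, in order); the loading code below relies on that layout.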
## Weights
arabic_params = arabic_dict['params']
english_params = english_dict['params']
## Batch norm means and stds
arabic_bn_conf = arabic_dict['bn_conf']
english_bn_conf = english_dict['bn_conf']
# Load embeddings
arabic_embedding = arabic_params[0]
english_embedding = english_params[0]
## Load weights
def load_layer_state(layers, params, bn_conf):
    # params[0] is the embedding table, so layer i takes params[i + 1];
    # bn_conf holds the running (mean, std) of each BatchNorm, in order.
    j = 0
    for i, l in enumerate(layers):
        l.set_parameters(params[i + 1])
        if isinstance(l, BatchNorm):
            l.set_mean_std(bn_conf[j])
            j += 1

load_layer_state(arabic_layers, arabic_params, arabic_bn_conf)
load_layer_state(english_layers, english_params, english_bn_conf)
def forward(x_batch, is_training, lang):
    """Embed a batch of index contexts, flatten, and run the language's MLP."""
    if lang == 'ar':
        embedding, layers = arabic_embedding, arabic_layers
    elif lang == 'en':
        embedding, layers = english_embedding, english_layers
    else:
        raise ValueError(f"unsupported language: {lang!r}")
    x = embedding[x_batch]          # (batch, CONTEXT_SIZE, n_embed)
    x = x.view(x.shape[0], -1)      # (batch, CONTEXT_SIZE * n_embed)
    for layer in layers:
        x = layer(x, is_training)
    return x
def generate_name(lang):
    """Sample one name character-by-character until the '.' terminator."""
    itos = ar_itos if lang == 'ar' else en_itos
    w = ''
    context = [0] * CONTEXT_SIZE                  # start from '.' padding
    while True:
        logits = forward(torch.tensor([context]), False, lang)
        p = torch.softmax(logits, dim=1)
        next_ch = torch.multinomial(p, num_samples=1).item()
        if next_ch == 0:                          # sampled end-of-name token
            break
        w += itos[next_ch]
        context = context[1:] + [next_ch]         # slide the context window
    return w
def generate_names(n, lang):
    """Sample n independent names for the given language."""
    return [generate_name(lang) for _ in range(n)]
if __name__ == '__main__':
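    # Minimal smoke test: print a few sampled names per language (a sketch that
    # assumes the weight files above were found and that model.layers matches
    # the set_parameters / set_mean_std / __call__ API used in this file).
    for name in generate_names(5, 'en'):
        print(name)
    for name in generate_names(5, 'ar'):
        print(name)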