Spaces:
Sleeping
Sleeping
File size: 2,430 Bytes
9a4dd2c b1c38c2 8e41ab0 9a4dd2c 8e41ab0 9a4dd2c b1c38c2 9a4dd2c b1c38c2 9a4dd2c b1c38c2 9a4dd2c b1c38c2 9a4dd2c b1c38c2 8e41ab0 b1c38c2 9a4dd2c b1c38c2 8e41ab0 b1c38c2 9a4dd2c b1c38c2 9a4dd2c b1c38c2 9a4dd2c 8e41ab0 b1c38c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import os
import sys
import torch
sys.path.append(os.path.abspath("src/train/"))
sys.path.append(os.path.abspath("src/data_processing/"))
from transformer import Transformer
from data_processing import SRC, TRG, arTokenizer, engTokenizer
device = "cpu"  # inference device; the checkpoint is also mapped here on load
# Define model hyperparameters
# NOTE(review): these must match the values the checkpoint was trained with —
# load_state_dict below will fail on a shape mismatch otherwise.
num_heads = 8  # attention heads per layer
num_encoder_layers = 3
num_decoder_layers = 3
max_len = 230  # maximum sequence length; also caps the greedy-decoding loop in translate()
dropout = 0.4
embedding_size = 256  # token embedding dimension
# Define vocabulary sizes and padding index
src_pad_idx = SRC.vocab.stoi["<pad>"]  # index of the "<pad>" token in the source vocab
src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)
# Initialize model with specified hyperparameters
# (positional argument order follows the Transformer constructor in src/train/transformer.py)
model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    dropout,
    max_len,
    device=device,
).to(device)
# Load the saved model parameters; map_location keeps CPU-only hosts working
# even if the checkpoint was saved from a GPU run
model.load_state_dict(torch.load("models/arabic2english.pt", map_location=device))
def translate(sentence, srcField, targetField):
    """
    Translate an English sentence to Arabic with the loaded Transformer model.

    The input is tokenized with ``engTokenizer`` and decoded greedily against
    the ``TRG`` (Arabic) vocabulary, starting from the "بداية" start token and
    stopping at the "نهاية" end token (or after ``max_len`` steps).

    Args:
        sentence (str): Input English sentence to be translated.
        srcField: Ignored — the module-level SRC field is always used.
            Kept in the signature for backward compatibility.
        targetField: Ignored — the module-level TRG field is always used.
            Kept in the signature for backward compatibility.

    Returns:
        str: Translated Arabic sentence (start/end tokens stripped).
    """
    model.eval()  # Set model to evaluation mode (disables dropout)
    srcTokenizer = engTokenizer  # Source side is English
    srcField = SRC  # Parameters are intentionally overridden; see docstring
    targetField = TRG
    processed_sentence = srcField.process([srcTokenizer(sentence)]).to(
        device
    )  # Numericalize + batch the single input sentence
    trg = ["بداية"]  # Initialize target sentence with start token
    # Greedy decoding: feed the tokens generated so far, append the argmax token.
    # no_grad avoids building autograd graphs during pure inference.
    with torch.no_grad():
        for _ in range(max_len):
            trg_tensor = (
                torch.tensor([targetField.vocab.stoi[word] for word in trg])
                .unsqueeze(1)
                .to(device)
            )
            outputs = model(processed_sentence, trg_tensor)  # (trg_len, 1, vocab)
            # Predicted token = argmax over the vocab at the last time step
            pred_token = targetField.vocab.itos[outputs.argmax(2)[-1:].item()]
            if pred_token != "<unk>":  # Exclude unknown tokens
                trg.append(pred_token)
            if pred_token == "نهاية":  # Stop translation at end token
                break
    # Drop the start token; drop the end token only if it was actually emitted
    # (previously trg[1:-1] discarded the last real word when decoding hit
    # max_len without producing the end token).
    words = [word for word in trg if word != "<unk>"][1:]
    if words and words[-1] == "نهاية":
        words = words[:-1]
    return " ".join(words)  # Return translated sentence
|