# NOTE(review): the three lines that were here ("Spaces:" / "Sleeping" / "Sleeping")
# were page-scrape residue from a Hugging Face Spaces status banner, not code;
# converted to this comment so the module parses.
# Inference-time setup: import the project code, rebuild the Transformer with the
# training-time hyperparameters, and restore the saved checkpoint weights.
import os
import sys
import torch

# Make the training and preprocessing packages importable as top-level modules.
sys.path.append(os.path.abspath("src/train/"))
sys.path.append(os.path.abspath("src/data_processing/"))

from transformer import Transformer
from data_processing import SRC, TRG, arTokenizer, engTokenizer

# Inference runs on CPU only.
device = "cpu"

# Model hyperparameters — these must match the values the checkpoint was trained with.
embedding_size = 256
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.4
max_len = 230  # hard cap on generated sequence length

# Vocabulary sizes and the source-side padding index come from the saved fields.
src_pad_idx = SRC.vocab.stoi["<pad>"]
src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)

# Instantiate the model with the same architecture as at training time.
model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    dropout,
    max_len,
    device=device,
).to(device)

# Restore the trained weights onto the CPU.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
model.load_state_dict(torch.load("models/arabic2english.pt", map_location=device))
def translate(sentence, srcField, targetField):
    """Translate an English sentence into Arabic with the loaded Transformer.

    Greedy decoding: starting from the Arabic start token "بداية", the model
    repeatedly predicts the next token until it emits the end token "نهاية"
    or ``max_len`` tokens have been generated.

    Args:
        sentence (str): English sentence to translate.
        srcField: Ignored — immediately overwritten with the module-level SRC
            field (kept in the signature for backward compatibility).
        targetField: Ignored — immediately overwritten with the module-level
            TRG field.

    Returns:
        str: The Arabic translation, with the start/end tokens removed.
    """
    model.eval()  # disable dropout for deterministic inference
    # NOTE(review): the two field parameters were always overwritten here; the
    # function has only ever decoded English (SRC/engTokenizer) -> Arabic (TRG).
    srcTokenizer = engTokenizer
    srcField = SRC
    targetField = TRG

    # Numericalize and batch the single input sentence (shape: src_len x 1
    # per the torchtext Field convention — TODO confirm against SRC's setup).
    processed_sentence = srcField.process([srcTokenizer(sentence)]).to(device)

    trg = ["بداية"]  # decoded tokens so far, seeded with the start token
    with torch.no_grad():  # no gradients needed while generating
        for _ in range(max_len):
            trg_tensor = (
                torch.tensor([targetField.vocab.stoi[word] for word in trg])
                .unsqueeze(1)
                .to(device)
            )
            outputs = model(processed_sentence, trg_tensor)
            # Greedy pick: highest-scoring token at the last decoder position.
            pred_token = targetField.vocab.itos[outputs.argmax(2)[-1].item()]
            if pred_token != "<unk>":  # never emit unknown tokens
                trg.append(pred_token)
            if pred_token == "نهاية":  # end-of-sentence token: stop decoding
                break

    # Drop the start token, and drop the end token only when it was actually
    # generated. The previous unconditional [1:-1] slice lost a real word
    # whenever decoding hit max_len without producing the end token.
    words = trg[1:]
    if words and words[-1] == "نهاية":
        words = words[:-1]
    return " ".join(words)