"""Load the trained Transformer checkpoint and expose ``translate``.

Importing this module builds the model from the hyperparameters below and
loads its weights from ``models/arabic2english.pt`` onto ``device``.
"""

import os
import sys

import torch

# The project modules are not installed as a package; make them importable
# relative to the current working directory.
sys.path.append(os.path.abspath("src/train/"))
sys.path.append(os.path.abspath("src/data_processing/"))

from transformer import Transformer
from data_processing import SRC, TRG, arTokenizer, engTokenizer

device = "cpu"

# Define model hyperparameters
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
max_len = 230
dropout = 0.4
embedding_size = 256

# Markers that open and close a decoded target sequence.
START_TOKEN = "بداية"
END_TOKEN = "نهاية"

# Define vocabulary sizes and padding index
# NOTE(review): the padding index is looked up under an empty string. This
# looks like a special token (e.g. "<pad>") that was stripped from the
# literal -- confirm against the training script before changing it.
src_pad_idx = SRC.vocab.stoi[""]
src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)

# Initialize model with specified hyperparameters
model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    dropout,
    max_len,
    device=device,
).to(device)

# Load the saved model parameters
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source.
model.load_state_dict(torch.load("models/arabic2english.pt", map_location=device))


def translate(sentence, srcField, targetField):
    """Greedily decode a translation of ``sentence`` with the loaded model.

    The sentence is tokenized with ``engTokenizer`` and numericalized with
    the module-level ``SRC`` field; target tokens are looked up in the
    module-level ``TRG`` vocabulary. Decoding starts from ``START_TOKEN``
    and appends the highest-scoring token at each step until ``END_TOKEN``
    is produced or ``max_len`` steps have run.

    NOTE(review): the checkpoint is named ``arabic2english`` but the source
    tokenizer is the English one and the start/end markers are Arabic, so
    the effective direction appears to be English -> Arabic -- confirm.

    Args:
        sentence (str): Input sentence to be translated.
        srcField: Accepted for backward compatibility; currently ignored in
            favour of the module-level ``SRC`` field.
        targetField: Accepted for backward compatibility; currently ignored
            in favour of the module-level ``TRG`` field.

    Returns:
        str: The decoded target tokens joined by single spaces, without the
        start and end markers.
    """
    model.eval()  # Set model to evaluation mode

    # The original implementation overwrote both parameters with the
    # module-level fields; that behaviour is preserved here.
    src_field = SRC
    target_field = TRG

    processed_sentence = src_field.process([engTokenizer(sentence)]).to(device)

    trg = [START_TOKEN]

    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for _ in range(max_len):
            trg_tensor = (
                torch.tensor([target_field.vocab.stoi[word] for word in trg])
                .unsqueeze(1)
                .to(device)
            )
            outputs = model(processed_sentence, trg_tensor)

            # Token with the highest score at the last decoded position.
            pred_token = target_field.vocab.itos[outputs.argmax(2)[-1:].item()]

            # NOTE(review): the comparison is against an empty string; the
            # original comment says this excludes unknown tokens, so the
            # literal may have been a stripped "<unk>" -- confirm. A skipped
            # token leaves the decoder input unchanged for the next step.
            if pred_token != "":
                trg.append(pred_token)

            if pred_token == END_TOKEN:  # Stop translation at end token
                break

    # Drop the start marker, and the end marker only when one was actually
    # produced; an unconditional [1:-1] would discard the last real token
    # whenever decoding stops at max_len.
    translated = trg[1:]
    if translated and translated[-1] == END_TOKEN:
        translated.pop()
    return " ".join(translated)