# bn_multi_tribe_mt / src / seq2seqtrainer.py
# Last commit by MasumBhuiyan (e13f31a):
#   "Updated Data Processing module. Seq2Seq model added (unverified)"
# File size: 812 bytes
from pipes import models, utils, const
from pipes.data import Dataset
def main(
    input_lang: str = "gr",
    output_lang: str = "bn",
    *,
    embedding_dim: int = 256,
    hidden_units: int = 64,
):
    """Prepare the parallel dataset and train a Seq2Seq model on it.

    Builds a ``Dataset`` for the (input, output) language pair, runs its
    ``pack()`` and ``process()`` steps, then constructs, builds and runs a
    ``models.Seq2Seq`` using the per-language vocabulary sizes and the
    ``"train"`` / ``"val"`` splits from ``Dataset.get_dict()``.

    Args:
        input_lang: Language code of the encoder (source) side.
        output_lang: Language code of the decoder (target) side.
        embedding_dim: Embedding size passed to ``Seq2Seq``.
        hidden_units: Hidden-unit count passed to ``Seq2Seq``.

    Returns:
        The ``Seq2Seq`` instance after ``run()`` has been called on it.
    """
    dataset_object = Dataset([input_lang, output_lang])
    dataset_object.pack()
    dataset_object.process()

    # Expected shape (from the lookups below): {lang: {"vocab_size", "train", "val"}}.
    dataset_dict = dataset_object.get_dict()
    source = dataset_dict[input_lang]
    target = dataset_dict[output_lang]

    seq2seq = models.Seq2Seq(
        input_vocab_size=source["vocab_size"],
        output_vocab_size=target["vocab_size"],
        embedding_dim=embedding_dim,
        hidden_units=hidden_units,
    )
    seq2seq.build()
    # NOTE(review): only decoder *inputs* are passed; presumably run() derives
    # the shifted decoder targets itself -- confirm in models.Seq2Seq.
    seq2seq.run(
        encoder_input_data=source["train"],
        decoder_input_data=target["train"],
        val_encoder_input_data=source["val"],
        val_decoder_input_data=target["val"],
    )
    return seq2seq


if __name__ == "__main__":
    main()