# Source: bn_multi_tribe_mt/src/seq2seq_trainer.py
# Commit c5dc1d4 by MasumBhuiyan: "Seq2Seq model implemented"
from pipes import utils
from pipes import const
from pipes import models
from pipes.data import Dataset
import tensorflow as tf
def main(
    input_lang="gr",
    output_lang="bn",
    *,
    embedding_dim=256,
    hidden_units=512,
    epochs=10,
    steps_per_epoch=100,
    validation_steps=20,
    patience=3,
):
    """Train a Seq2Seq translation model for one language pair.

    Builds the dataset for ``input_lang`` -> ``output_lang`` with
    ``pipes.data.Dataset``. It then constructs ``models.Seq2Seq``, sized from
    each language's ``vocab_size`` entry in the dataset dict, and fits it
    with Adam, sparse categorical cross-entropy and early stopping.

    Args:
        input_lang: Language code of the source side (default ``"gr"``).
        output_lang: Language code of the target side (default ``"bn"``).
        embedding_dim: Embedding size passed to ``models.Seq2Seq``.
        hidden_units: Hidden-unit count passed to ``models.Seq2Seq``.
        epochs: Maximum number of training epochs.
        steps_per_epoch: Training batches per epoch. This is required because
            the training dataset is repeated indefinitely.
        validation_steps: Validation batches evaluated at the end of each epoch.
        patience: Epochs without ``val_loss`` improvement before early stopping.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    dataset_object = Dataset([input_lang, output_lang])
    # Same pack -> process -> pull sequence as the original script; what
    # each stage does is defined in pipes.data.
    dataset_object.pack()
    dataset_object.process()
    train_ds, val_ds = dataset_object.pull()
    dataset_dict = dataset_object.get_dict()

    model_object = models.Seq2Seq(
        input_vocab_size=dataset_dict[input_lang]["vocab_size"],
        output_vocab_size=dataset_dict[output_lang]["vocab_size"],
        embedding_dim=embedding_dim,
        hidden_units=hidden_units,
    )
    model_object.build()
    model = model_object.get()

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        # NOTE(review): from_logits defaults to False. That assumes the
        # model's final layer already applies softmax. Confirm against
        # models.Seq2Seq, and pass from_logits=True if it emits raw logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        # List only training-side metrics here. Keras reports "val_accuracy"
        # automatically when validation_data is supplied. "val_accuracy" is
        # not itself a registered metric, so listing it makes Keras raise a
        # ValueError.
        metrics=["accuracy"],
    )
    return model.fit(
        # repeat() makes the training set endless, so steps_per_epoch is
        # what defines the length of an epoch.
        train_ds.repeat(),
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_data=val_ds,
        # EarlyStopping monitors "val_loss" by default. That value is
        # available because validation_data is passed above.
        callbacks=[tf.keras.callbacks.EarlyStopping(patience=patience)],
    )


if __name__ == "__main__":
    main()