# bn_multi_tribe_mt / src / seq2seq_trainer.py
# Author: MasumBhuiyan
# Seq2Seq model implemented (commit c5dc1d4)
from pipes import utils
from pipes import const
from pipes import models
from pipes.data import Dataset
import tensorflow as tf
if __name__ == "__main__":
    # Train a Seq2Seq translation model from the input language ('gr')
    # to the output language ('bn') using the project's data pipeline.
    input_lang = 'gr'
    output_lang = 'bn'

    # Build and preprocess the paired dataset, then split into train/val.
    dataset_object = Dataset([input_lang, output_lang])
    dataset_object.pack()
    dataset_object.process()
    train_ds, val_ds = dataset_object.pull()
    dataset_dict = dataset_object.get_dict()

    # Construct the encoder-decoder model sized to each language's vocabulary.
    model_object = models.Seq2Seq(
        input_vocab_size=dataset_dict[input_lang]["vocab_size"],
        output_vocab_size=dataset_dict[output_lang]["vocab_size"],
        embedding_dim=256,
        hidden_units=512
    )
    model_object.build()
    model = model_object.get()

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        # NOTE(review): assumes the model's final layer applies softmax;
        # if it emits raw logits, this needs from_logits=True — TODO confirm
        # against models.Seq2Seq.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        # Bug fix: 'val_accuracy' is not a valid metric identifier — Keras
        # automatically reports the validation counterpart ('val_accuracy')
        # of each metric when validation_data is given. Passing it to
        # compile() raises ValueError at compile time.
        metrics=['accuracy'],
    )

    history = model.fit(
        # .repeat() makes the training stream infinite so steps_per_epoch
        # defines the epoch length.
        train_ds.repeat(),
        epochs=10,
        steps_per_epoch=100,
        validation_data=val_ds,
        # val_ds is not repeated, so it must yield at least this many batches.
        validation_steps=20,
        # Stop early if val_loss (EarlyStopping's default monitor) fails to
        # improve for 3 consecutive epochs.
        callbacks=[tf.keras.callbacks.EarlyStopping(patience=3)]
    )