Commit 6543d58 (parent: ccfa333)
"Updated trainer"

Changed files:
- src/pipes/const.py  +2 -0
- src/pipes/data.py   +32 -0
- src/pipes/models.py +0 -32
src/pipes/const.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
data_dir: str = "E:/bn_multi_tribe_mt/data/"
|
2 |
langs: list[str] = ['bn', 'en', 'gr']
|
3 |
MAX_SEQ_LEN = 30
|
|
|
|
|
|
1 |
data_dir: str = "E:/bn_multi_tribe_mt/data/"
|
2 |
langs: list[str] = ['bn', 'en', 'gr']
|
3 |
MAX_SEQ_LEN = 30
|
4 |
+
BATCH_SIZE = 64
|
5 |
+
BUFFER_SIZE = 10000
|
src/pipes/data.py
CHANGED
@@ -3,6 +3,7 @@ import const
|
|
3 |
import utils
|
4 |
import string
|
5 |
|
|
|
6 |
class SequenceLoader:
|
7 |
def __init__(self):
|
8 |
self.sequence_dict = None
|
@@ -38,6 +39,12 @@ class SequenceLoader:
|
|
38 |
self.lang = lang
|
39 |
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
def remove_punctuation_from_seq(seq):
|
42 |
english_punctuations = string.punctuation
|
43 |
bangla_punctuations = "৷-–—’‘৳…।"
|
@@ -157,6 +164,29 @@ class Dataset:
|
|
157 |
seq_processor.pad()
|
158 |
self.dataset_dict = seq_processor.get_dict()
|
159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
def get_dict(self):
|
161 |
return self.dataset_dict
|
162 |
|
@@ -167,4 +197,6 @@ if __name__ == "__main__":
|
|
167 |
dataset_dict = dataset_object.get_dict()
|
168 |
utils.save_dict("{}/dataset.txt".format(const.data_dir), dataset_dict)
|
169 |
dataset_object.process()
|
|
|
|
|
170 |
print(utils.load_dict("{}/dataset.txt".format(const.data_dir)))
|
|
|
3 |
import utils
|
4 |
import string
|
5 |
|
6 |
+
|
7 |
class SequenceLoader:
|
8 |
def __init__(self):
|
9 |
self.sequence_dict = None
|
|
|
39 |
self.lang = lang
|
40 |
|
41 |
|
42 |
+
def serialize(src_seq, tar_seq):
    """Split a target batch into teacher-forcing input/output pairs.

    The decoder input drops the final token of each target sequence and the
    decoder output drops the first, so position t of the output is the token
    the model should predict after seeing positions 0..t of the input.
    `.to_tensor()` densifies the ragged slices into padded tensors.
    """
    decoder_in = tar_seq[:, :-1].to_tensor()
    decoder_out = tar_seq[:, 1:].to_tensor()
    return (src_seq, decoder_in), decoder_out
|
46 |
+
|
47 |
+
|
48 |
def remove_punctuation_from_seq(seq):
|
49 |
english_punctuations = string.punctuation
|
50 |
bangla_punctuations = "৷-–—’‘৳…।"
|
|
|
164 |
seq_processor.pad()
|
165 |
self.dataset_dict = seq_processor.get_dict()
|
166 |
|
167 |
+
def pull(self):
    """Build batched, shuffled train/val tf.data pipelines.

    Reads train/val sequence tensors for the source language
    (self.langs[0]) and target language (self.langs[1]) out of
    self.dataset_dict, then shuffles, batches, and maps `serialize`
    to produce ((encoder_in, decoder_in), decoder_out) examples.

    Returns:
        (train_ds, val_ds): two tf.data.Dataset pipelines.
    """
    src_lang_train_seqs = self.dataset_dict[self.langs[0]]["train"]
    tar_lang_train_seqs = self.dataset_dict[self.langs[1]]["train"]

    src_lang_val_seqs = self.dataset_dict[self.langs[0]]["val"]
    tar_lang_val_seqs = self.dataset_dict[self.langs[1]]["val"]

    train_ds = (tf.data.Dataset
                .from_tensor_slices((src_lang_train_seqs, tar_lang_train_seqs))
                .shuffle(const.BUFFER_SIZE)
                .batch(const.BATCH_SIZE))

    # BUG FIX: from_tensor_slices takes ONE structure argument; the original
    # passed src and tar as two positional args (tar landed in an unrelated
    # parameter slot). Wrap them in a tuple, mirroring the train pipeline.
    val_ds = (tf.data.Dataset
              .from_tensor_slices((src_lang_val_seqs, tar_lang_val_seqs))
              .shuffle(const.BUFFER_SIZE)
              .batch(const.BATCH_SIZE))

    train_ds = train_ds.map(serialize, tf.data.AUTOTUNE)
    val_ds = val_ds.map(serialize, tf.data.AUTOTUNE)

    # BUG FIX: the original returned undefined names `trainset, valset`,
    # which raises NameError; the locals are train_ds / val_ds.
    return train_ds, val_ds
|
188 |
+
|
189 |
+
def get_dict(self):
    """Return the processed dataset dictionary built by process().

    BUG FIX: the original stacked @staticmethod on a method that takes
    `self` and reads `self.dataset_dict` — a static method receives no
    instance, so the decorator is wrong and has been removed.
    """
    return self.dataset_dict
|
192 |
|
|
|
197 |
dataset_dict = dataset_object.get_dict()
|
198 |
utils.save_dict("{}/dataset.txt".format(const.data_dir), dataset_dict)
|
199 |
dataset_object.process()
|
200 |
+
trainset, valset = dataset_object.pull()
|
201 |
+
|
202 |
print(utils.load_dict("{}/dataset.txt".format(const.data_dir)))
|
src/pipes/models.py
CHANGED
@@ -34,38 +34,6 @@ class Seq2Seq:
|
|
34 |
outputs = self.output_layer(decoder_outputs)
|
35 |
self.model = tf.keras.Model([encoder_inputs, decoder_inputs], outputs)
|
36 |
|
37 |
-
def run(self, encoder_input_data, decoder_input_data, val_encoder_input_data, val_decoder_input_data):
|
38 |
-
self.model.compile(
|
39 |
-
optimizer=self.optimizer,
|
40 |
-
loss=self.loss,
|
41 |
-
metrics=self.metrics
|
42 |
-
)
|
43 |
-
|
44 |
-
decoder_target_data = [[sentence[1:] + [0]] for sentence in decoder_input_data]
|
45 |
-
val_decoder_target_data = [[sentence[1:] + [0]] for sentence in val_decoder_input_data]
|
46 |
-
|
47 |
-
self.model.fit(
|
48 |
-
([encoder_input_data, decoder_input_data]),
|
49 |
-
decoder_target_data,
|
50 |
-
batch_size=self.batch_size,
|
51 |
-
epochs=self.epochs,
|
52 |
-
validation_data=([val_encoder_input_data, val_decoder_input_data], val_decoder_target_data)
|
53 |
-
)
|
54 |
-
|
55 |
def get(self):
|
56 |
return self.model
|
57 |
|
58 |
-
def set_epochs(self, epochs):
|
59 |
-
self.epochs = epochs
|
60 |
-
|
61 |
-
def set_batch_size(self, batch_size):
|
62 |
-
self.batch_size = batch_size
|
63 |
-
|
64 |
-
def set_loss(self, loss):
|
65 |
-
self.loss = loss
|
66 |
-
|
67 |
-
def set_optimizer(self, optimizer):
|
68 |
-
self.optimizer = optimizer
|
69 |
-
|
70 |
-
def set_metric(self, metrics):
|
71 |
-
self.metrics = metrics
|
|
|
34 |
outputs = self.output_layer(decoder_outputs)
|
35 |
self.model = tf.keras.Model([encoder_inputs, decoder_inputs], outputs)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
def get(self):
    """Expose the assembled Keras seq2seq model for callers."""
    return self.model
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|