from Models.seq2seq_model import Seq2seq_no_attention
from Models.seq2seqAttention_model import Seq2seq_with_attention
from Models.Transformer_model import NMT_Transformer
from Models.ModelArgs import ModelArgs


def get_model(params: ModelArgs, vocab_size):
    """Build the NMT model selected by params.model_type.

    Recognized values (case-insensitive): 's2s' for plain seq2seq and
    's2sattention' for seq2seq with attention; any other value falls
    through to the Transformer.
    """
    model_type = params.model_type.lower()
    if model_type == 's2s':
        model = Seq2seq_no_attention(vocab_size=vocab_size,
                                     dim_embed=params.dim_embed,
                                     dim_model=params.dim_model,
                                     dim_feedforward=params.dim_feedforward,
                                     num_layers=params.num_layers,
                                     dropout_probability=params.dropout)
    elif model_type == 's2sattention':
        model = Seq2seq_with_attention(vocab_size=vocab_size,
                                       dim_embed=params.dim_embed,
                                       dim_model=params.dim_model,
                                       dim_feedforward=params.dim_feedforward,
                                       num_layers=params.num_layers,
                                       dropout_probability=params.dropout)
    else:
        # The Transformer additionally takes maxlen (e.g. for positional
        # encodings), unlike the two recurrent variants above.
        model = NMT_Transformer(vocab_size=vocab_size,
                                dim_embed=params.dim_embed,
                                dim_model=params.dim_model,
                                dim_feedforward=params.dim_feedforward,
                                num_layers=params.num_layers,
                                dropout_probability=params.dropout,
                                maxlen=params.maxlen)
    return model
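

# Minimal usage sketch, assuming ModelArgs accepts the fields read above
# (model_type, dim_embed, dim_model, dim_feedforward, num_layers, dropout,
# maxlen) as keyword arguments; the actual ModelArgs constructor lives in
# Models/ModelArgs.py and may differ. The hyperparameter values here are
# illustrative only.
if __name__ == '__main__':
    args = ModelArgs(model_type='transformer',
                     dim_embed=256,
                     dim_model=512,
                     dim_feedforward=2048,
                     num_layers=6,
                     dropout=0.1,
                     maxlen=128)
    model = get_model(args, vocab_size=32000)
    print(type(model).__name__)  # expected: NMT_Transformer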