from Models.seq2seq_model import Seq2seq_no_attention
from Models.seq2seqAttention_model import Seq2seq_with_attention
from Models.Transformer_model import NMT_Transformer
from Models.ModelArgs import ModelArgs


def get_model(params: ModelArgs, vocab_size):
    """Build the translation model selected by ``params.model_type``.

    ``params.model_type`` is compared case-insensitively:

    * ``'s2s'``          -> ``Seq2seq_no_attention``
    * ``'s2sattention'`` -> ``Seq2seq_with_attention``
    * any other value    -> ``NMT_Transformer``, which additionally receives
      ``maxlen=params.maxlen``.

    Args:
        params: Hyper-parameter container. Reads ``model_type``,
            ``dim_embed``, ``dim_model``, ``dim_feedforward``,
            ``num_layers``, ``dropout`` and, for the transformer only,
            ``maxlen``.
        vocab_size: Passed through unchanged as ``vocab_size`` to the
            selected model's constructor.

    Returns:
        The newly constructed model instance.
    """
    # Normalise once; every branch compares against the lower-cased value.
    model_type = params.model_type.lower()

    # Constructor arguments shared by all three architectures, collected in
    # one place so the calls below cannot drift out of sync.
    common_kwargs = dict(
        vocab_size=vocab_size,
        dim_embed=params.dim_embed,
        dim_model=params.dim_model,
        dim_feedforward=params.dim_feedforward,
        num_layers=params.num_layers,
        dropout_probability=params.dropout,
    )

    if model_type == 's2s':
        return Seq2seq_no_attention(**common_kwargs)
    if model_type == 's2sattention':
        return Seq2seq_with_attention(**common_kwargs)

    # NOTE(review): any unrecognised model_type (including typos) silently
    # falls back to the transformer; this is preserved for compatibility
    # with existing configs — consider validating the value explicitly.
    return NMT_Transformer(**common_kwargs, maxlen=params.maxlen)