File size: 1,963 Bytes
c412427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from Models.seq2seq_model import Seq2seq_no_attention
from Models.seq2seqAttention_model import Seq2seq_with_attention
from Models.Transformer_model import NMT_Transformer
from Models.ModelArgs import ModelArgs


def get_model(params: ModelArgs, vocab_size):
    """Build the translation model selected by ``params.model_type``.

    The comparison is case-insensitive:
      * ``'s2s'``          -> ``Seq2seq_no_attention``
      * ``'s2sattention'`` -> ``Seq2seq_with_attention``
      * anything else      -> ``NMT_Transformer`` (the fallback; note that a
        misspelled type name silently lands here too)

    Args:
        params: Configuration object. This function reads ``model_type``,
            ``dim_embed``, ``dim_model``, ``dim_feedforward``, ``num_layers``
            and ``dropout`` from it, plus ``maxlen`` for the Transformer only.
        vocab_size: Passed through unchanged to the model constructor.

    Returns:
        The newly constructed model instance.
    """
    kind = params.model_type.lower()

    # Every architecture takes the same core hyper-parameters; note that
    # ``params.dropout`` is forwarded under the name ``dropout_probability``.
    shared = dict(
        vocab_size=vocab_size,
        dim_embed=params.dim_embed,
        dim_model=params.dim_model,
        dim_feedforward=params.dim_feedforward,
        num_layers=params.num_layers,
        dropout_probability=params.dropout,
    )

    if kind == 's2s':
        return Seq2seq_no_attention(**shared)
    if kind == 's2sattention':
        return Seq2seq_with_attention(**shared)

    # Default branch: only the Transformer additionally needs ``maxlen``.
    return NMT_Transformer(**shared, maxlen=params.maxlen)