# Spaces: Sleeping
import torch.nn as nn

from utils.config import config
class Encoder(nn.Module):
    """Unidirectional GRU encoder over embedded token indices.

    Token indices are embedded, passed through dropout, and fed to a
    (possibly multi-layer) GRU that runs in time-major layout.

    Args:
        input_dim: Number of rows in the embedding table (passed to
            ``nn.Embedding`` as ``num_embeddings``).
        embedding_dim: Size of each embedding vector; also the GRU input size.
        hidden_dim: Size of the GRU hidden state.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout probability applied to the embeddings and, when
            ``n_layers > 1``, between GRU layers.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            # Inter-layer dropout only exists between stacked layers, so it is
            # switched off for a single-layer GRU.
            dropout=dropout if n_layers > 1 else 0,
            bidirectional=False,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of token-index sequences.

        Args:
            src: Integer tensor of shape ``[batch_size, src_len]``.

        Returns:
            A tuple ``(outputs, hidden)`` where ``outputs`` has shape
            ``[src_len, batch_size, hidden_dim]`` and ``hidden`` has shape
            ``[n_layers, batch_size, hidden_dim]`` (the GRU is unidirectional).
        """
        # embedded: [batch_size, src_len, embedding_dim]
        embedded = self.dropout(self.embedding(src))
        # The GRU is built without batch_first, so it expects time-major
        # input: swap the batch and sequence axes before the call.
        outputs, hidden = self.rnn(embedded.permute(1, 0, 2))
        return outputs, hidden