import torch
import torch.nn as nn

from models.attention import Attention


class Decoder(nn.Module):
    """Single-step attention decoder: embeds the previous target token,
    attends over the encoder outputs, and runs one GRU step to predict
    the next token."""

    def __init__(self, output_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.attention = Attention(hidden_dim)
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        self.rnn = nn.GRU(
            embedding_dim + hidden_dim,
            hidden_dim,
            num_layers=n_layers,
            dropout=dropout if n_layers > 1 else 0,
        )
        # Predicts from the GRU output concatenated with the attention context.
        self.fc_out = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        # input: [batch_size]
        # hidden: [n_layers, batch_size, hidden_dim]
        # encoder_outputs: [src_len, batch_size, hidden_dim]

        input = input.unsqueeze(0)
        # input: [1, batch_size]

        embedded = self.dropout(self.embedding(input))
        # embedded: [1, batch_size, embedding_dim]

        # Attention weights over source positions, from the top decoder layer.
        a = self.attention(hidden[-1], encoder_outputs)
        # a: [src_len, batch_size]

        a = a.permute(1, 0).unsqueeze(1)
        # a: [batch_size, 1, src_len]

        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # encoder_outputs: [batch_size, src_len, hidden_dim]

        # Weighted sum of encoder states: the attention context vector.
        weighted = torch.bmm(a, encoder_outputs)
        # weighted: [batch_size, 1, hidden_dim]

        weighted = weighted.permute(1, 0, 2)
        # weighted: [1, batch_size, hidden_dim]

        rnn_input = torch.cat((embedded, weighted), dim=2)
        # rnn_input: [1, batch_size, embedding_dim + hidden_dim]

        output, hidden = self.rnn(rnn_input, hidden)
        # output: [1, batch_size, hidden_dim]
        # hidden: [n_layers, batch_size, hidden_dim]

        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted), dim=1))
        # prediction: [batch_size, output_dim]

        return prediction, hidden
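

# --- Illustrative smoke test (a sketch, not part of the model code) --------
# Assumes models.attention.Attention(hidden_dim) accepts
# (hidden[-1]: [batch_size, hidden_dim],
#  encoder_outputs: [src_len, batch_size, hidden_dim])
# and returns weights of shape [src_len, batch_size], matching the shape
# comments in forward(). All dimensions below are arbitrary example values.
if __name__ == "__main__":
    output_dim, embedding_dim, hidden_dim = 12, 8, 16
    n_layers, src_len, batch_size = 2, 7, 4

    decoder = Decoder(output_dim, embedding_dim, hidden_dim, n_layers, dropout=0.5)

    input = torch.randint(0, output_dim, (batch_size,))             # previous target tokens
    hidden = torch.zeros(n_layers, batch_size, hidden_dim)          # initial decoder state
    encoder_outputs = torch.randn(src_len, batch_size, hidden_dim)  # fake encoder states

    prediction, hidden = decoder(input, hidden, encoder_outputs)
    assert prediction.shape == (batch_size, output_dim)
    assert hidden.shape == (n_layers, batch_size, hidden_dim)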