import torch
import torch.nn as nn

from models.encoder import Encoder
from models.decoder import Decoder


class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one position at a time.

    The encoder is run once over the source batch; the decoder is then called
    once per target position (starting at position 1), optionally being fed
    the ground-truth token instead of its own prediction (teacher forcing).

    Args:
        encoder: Module called as ``encoder(src)`` that returns a pair
            ``(encoder_outputs, hidden)``.
        decoder: Module called as ``decoder(step_input, hidden,
            encoder_outputs)`` that returns ``(output, hidden)`` and exposes
            an ``output_dim`` attribute (used as the target vocabulary size).
        device: Device on which the output buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Run the encoder once, then decode step by step.

        Args:
            src: Source batch, ``[batch_size, src_len]``.
            trg: Target batch, ``[batch_size, trg_len]``.
            teacher_forcing_ratio: Probability, per decoding step, of feeding
                the ground-truth token ``trg[:, t]`` to the next step instead
                of the decoder's own arg-max prediction. The draw is made once
                per step and applies to the whole batch.

        Returns:
            Tensor of shape ``[batch_size, trg_len, trg_vocab_size]`` holding
            the decoder output for every target position. Position 0 is never
            written and therefore stays all zeros.
        """
        batch_size, trg_len = trg.shape[0], trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        # Allocate directly on the target device rather than creating the
        # buffer on the default device and copying it over afterwards.
        outputs = torch.zeros(
            trg_len, batch_size, trg_vocab_size, device=self.device
        )

        encoder_outputs, hidden = self.encoder(src)

        # Position 0 of trg seeds the decoder; presumably this is the
        # start-of-sequence token -- confirm against the data pipeline.
        step_input = trg[:, 0]

        for t in range(1, trg_len):
            output, hidden = self.decoder(step_input, hidden, encoder_outputs)
            outputs[t] = output

            # One random draw per step, shared by every sequence in the batch.
            use_teacher = torch.rand(1).item() < teacher_forcing_ratio
            if use_teacher:
                step_input = trg[:, t]
            else:
                # Greedy choice: highest-scoring vocabulary entry per sequence.
                step_input = output.argmax(1)

        # Time-major buffer -> batch-major result.
        return outputs.permute(1, 0, 2)