# Scraped page header (not code): author nullHawk, commit 9a41f63 "init" (verified).
import torch
import torch.nn as nn
from models.encoder import Encoder
from models.decoder import Decoder
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder one step at a time.

    At every target position after the first, the decoder receives either the
    ground-truth target token (teacher forcing) or its own argmax prediction
    from the previous step.
    """

    def __init__(self, encoder, decoder, device):
        """Store the sub-modules and the device for the output buffer.

        Args:
            encoder: Callable invoked as ``encoder(src)``; must return a pair
                ``(encoder_outputs, hidden)``.
            decoder: Callable invoked as
                ``decoder(input, hidden, encoder_outputs)``; must return
                ``(output, hidden)`` and expose an ``output_dim`` attribute
                equal to the size of ``output``'s last dimension.
            device: Device on which ``forward`` allocates its output tensor.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Run the encoder, then decode ``trg_len - 1`` steps.

        Args:
            src: Source batch passed unchanged to the encoder (presumably
                ``[batch_size, src_len]`` token ids; not checked here).
            trg: Target batch indexed as ``[batch_size, trg_len]``. ``trg[:, 0]``
                is the first decoder input (assumed to be the start token).
            teacher_forcing_ratio: Per-step probability of feeding the
                ground-truth token ``trg[:, t]`` instead of the model's own
                argmax prediction. Use 0 for free-running decoding.

        Returns:
            Tensor of shape ``[batch_size, trg_len, decoder.output_dim]``.
            Position 0 is never predicted and stays all zeros, so callers
            typically compare ``outputs[:, 1:]`` against ``trg[:, 1:]``.
        """
        batch_size, trg_len = trg.shape[0], trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        # Allocate directly on the target device rather than building on CPU
        # and copying. Time-major layout so each step writes one slice.
        outputs = torch.zeros(
            trg_len, batch_size, trg_vocab_size, device=self.device
        )

        encoder_outputs, hidden = self.encoder(src)
        dec_input = trg[:, 0]  # first token, assumed <start>

        for t in range(1, trg_len):
            # output is assumed to be [batch_size, output_dim]
            output, hidden = self.decoder(dec_input, hidden, encoder_outputs)
            outputs[t] = output

            # Draw from torch's RNG (so torch.manual_seed controls the
            # schedule); one coin flip per step, shared across the batch.
            teacher_force = bool(torch.rand(1) < teacher_forcing_ratio)
            dec_input = trg[:, t] if teacher_force else output.argmax(1)

        # Back to batch-major: [batch_size, trg_len, output_dim].
        return outputs.permute(1, 0, 2)