"""Sequence-to-sequence model wrapper: encoder + step-wise decoder with teacher forcing."""
import random

import torch
import torch.nn as nn

from models.encoder import Encoder
from models.decoder import Decoder
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with teacher forcing.

    Args:
        encoder: module mapping ``src`` -> ``(encoder_outputs, hidden)``.
        decoder: module mapping ``(prev_token, hidden, encoder_outputs)`` ->
            ``(logits, hidden)``; must expose ``output_dim`` (target vocab size).
        device: device on which the output logits buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Run ``trg_len - 1`` decode steps seeded by the first target token.

        Args:
            src: [batch_size, src_len] source token ids.
            trg: [batch_size, trg_len] target token ids; ``trg[:, 0]`` is <start>.
            teacher_forcing_ratio: probability of feeding the ground-truth token
                (rather than the model's own greedy prediction) at each step.

        Returns:
            [batch_size, trg_len, output_dim] logits. Index 0 along the time
            axis is left as zeros — it corresponds to <start>, never predicted.
        """
        batch_size, trg_len = trg.shape
        trg_vocab_size = self.decoder.output_dim

        # Allocate directly on the target device instead of CPU-alloc + .to().
        # Time-major layout makes the per-step outputs[t] = ... assignment cheap.
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size,
                              device=self.device)

        encoder_outputs, hidden = self.encoder(src)
        dec_input = trg[:, 0]  # first token is <start>

        for t in range(1, trg_len):
            output, hidden = self.decoder(dec_input, hidden, encoder_outputs)
            outputs[t] = output
            # random.random() draws a plain Python float; the original built a
            # fresh CPU tensor (torch.rand(1)) and compared it to a float on
            # every step.
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            # Ground truth with probability teacher_forcing_ratio, otherwise
            # the model's own greedy prediction. ('dec_input' also avoids
            # shadowing the builtin 'input'.)
            dec_input = trg[:, t] if teacher_force else top1

        # Return batch-major to match the [batch, time, vocab] convention
        # of src/trg.
        return outputs.permute(1, 0, 2)