import torch
from torch import nn
import random
class Encoder(nn.Module):
    def __init__(self, vocab_size, dim_embed, dim_hidden, dim_feedforward, num_layers, dropout_probability=0.1):
        super().__init__()
        self.embd_layer = nn.Embedding(vocab_size, dim_embed)
        self.dropout = nn.Dropout(dropout_probability)
        self.rnn = nn.GRU(dim_embed, dim_hidden, num_layers, batch_first=True, dropout=dropout_probability, bidirectional=True)
        self.hidden_map = nn.Sequential(nn.Linear(dim_hidden*2, dim_feedforward),
                                        nn.ReLU(),
                                        nn.Linear(dim_feedforward, dim_hidden),
                                        nn.Dropout(dropout_probability))
        self.output_map = nn.Sequential(nn.Linear(dim_hidden*2, dim_feedforward),
                                        nn.ReLU(),
                                        nn.Linear(dim_feedforward, dim_hidden),
                                        nn.Dropout(dropout_probability))

    def forward(self, x):
        embds = self.dropout(self.embd_layer(x))      ## (B,T,dim_embed)
        context, hidden = self.rnn(embds)             ## context: (B,T,dim_hidden*2), hidden: (num_layers*2,B,dim_hidden)
        ## concatenate the last layer's forward and backward hidden states
        last_hidden = torch.cat([hidden[-2,:,:], hidden[-1,:,:]], dim=-1)  ## (B,dim_hidden*2)
        to_decoder_hidden = self.hidden_map(last_hidden)   ## (B,dim_hidden)
        to_decoder_output = self.output_map(context)       ## (B,T,dim_hidden)
        return to_decoder_output, to_decoder_hidden
class Attention(nn.Module):
    def __init__(self, input_dims):
        super().__init__()
        self.fc_energy = nn.Linear(input_dims*2, input_dims)
        self.alpha = nn.Linear(input_dims, 1, bias=False)

    def forward(self,
                encoder_output,   # (B,T,encoder_hidden)
                decoder_hidden):  # (B,decoder_hidden)
        ## encoder_hidden == decoder_hidden == input_dims
        seq_len = encoder_output.size(1)
        decoder_hidden = decoder_hidden.unsqueeze(1).repeat(1, seq_len, 1)  ## (B,T,input_dims)
        energy = self.fc_energy(torch.cat((decoder_hidden, encoder_output), dim=-1))  ## (B,T,input_dims)
        alphas = self.alpha(energy).squeeze(-1)  ## (B,T)
        return torch.softmax(alphas, dim=-1)     ## attention weights over the source positions
class Decoder(nn.Module):
    def __init__(self, vocab_size, dim_embed, dim_hidden, attention, num_layers, dropout_probability):
        super().__init__()
        self.attention = attention
        self.embd_layer = nn.Embedding(vocab_size, dim_embed)
        self.rnn = nn.GRU(dim_hidden + dim_embed, dim_hidden, batch_first=True, num_layers=num_layers, dropout=dropout_probability)

    def forward(self, x, encoder_output, hidden_t_1):
        ## hidden_t_1 shape: (num_layers,B,dim_hidden)
        ## encoder_output shape: (B,T,dim_hidden)
        ## x shape: (B,1) one token
        embds = self.embd_layer(x)  ## (B,1,dim_embed)
        alphas = self.attention(encoder_output, hidden_t_1[-1]).unsqueeze(1)  ## (B,1,T)
        attention = torch.bmm(alphas, encoder_output)       ## (B,1,dim_hidden), weighted sum of encoder states
        rnn_input = torch.cat((embds, attention), dim=-1)   ## (B,1,dim_hidden + dim_embed)
        output, hidden_t = self.rnn(rnn_input, hidden_t_1)
        return output, hidden_t, alphas.squeeze(1)  ## alphas are returned for visualization
class Seq2seq_with_attention(nn.Module):
    def __init__(self, vocab_size:int, dim_embed:int, dim_model:int, dim_feedforward:int, num_layers:int, dropout_probability:float):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.encoder = Encoder(vocab_size, dim_embed, dim_model, dim_feedforward, num_layers, dropout_probability)
        self.attention = Attention(dim_model)
        self.decoder = Decoder(vocab_size, dim_embed, dim_model, self.attention, num_layers, dropout_probability)
        self.classifier = nn.Linear(dim_model, vocab_size)
        ## weight tying: source embedding, target embedding, and classifier share one matrix
        ## (requires dim_embed == dim_model)
        self.encoder.embd_layer.weight = self.classifier.weight
        self.decoder.embd_layer.weight = self.classifier.weight

    def forward(self, source, target, pad_tokenId):
        # target = <s> text </s>
        # teacher_force_ratio = 0.5
        B, T = target.size()
        total_logits = torch.zeros(B, T, self.vocab_size, device=source.device)
        context, hidden = self.encoder(source)
        hidden = hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)  # (num_layers,B,dim_model)
        for step in range(T):
            step_token = target[:, [step]]  # (B,1)
            out, hidden, alphas = self.decoder(step_token, context, hidden)
            logits = self.classifier(out).squeeze(1)  # (B,vocab_size)
            total_logits[:, step] = logits
        loss = None
        if T > 1:
            ## shift by one: the logits at step t predict the target token at step t+1
            flat_logits = total_logits[:, :-1, :].reshape(-1, total_logits.size(-1))
            flat_targets = target[:, 1:].reshape(-1)
            loss = nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=pad_tokenId)
        return total_logits, loss
    @torch.no_grad()
    def greedy_decode_fast(self, source:torch.Tensor, sos_tokenId:int, eos_tokenId:int, pad_tokenId, max_tries=50):
        self.eval()
        targets_hat = [sos_tokenId]
        context, hidden = self.encoder(source.unsqueeze(0))
        hidden = hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)  # (num_layers,B,dim_model)
        for step in range(max_tries):
            x = torch.tensor([targets_hat[step]]).unsqueeze(0).to(source.device)  # (1,1)
            out, hidden, alphas = self.decoder(x, context, hidden)
            logits = self.classifier(out)
            top1 = logits.argmax(-1)
            targets_hat.append(top1.item())
            if top1 == eos_tokenId:
                return targets_hat
        return targets_hat
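
if __name__ == "__main__":
    ## Minimal usage sketch, not part of the original module: the vocabulary size,
    ## dimensions, batch/sequence lengths, and special-token ids below are illustrative
    ## assumptions. Note that the weight tying above requires dim_embed == dim_model.
    vocab_size, pad_id, sos_id, eos_id = 1000, 0, 1, 2
    model = Seq2seq_with_attention(vocab_size=vocab_size, dim_embed=64, dim_model=64,
                                   dim_feedforward=128, num_layers=2, dropout_probability=0.1)
    source = torch.randint(3, vocab_size, (4, 12))   ## (B, T_src)
    target = torch.randint(3, vocab_size, (4, 10))   ## (B, T_trg), assumed to be <s> ... </s>
    logits, loss = model(source, target, pad_tokenId=pad_id)
    print(logits.shape, loss.item())                 ## torch.Size([4, 10, 1000]) and a scalar loss
    print(model.greedy_decode_fast(source[0], sos_id, eos_id, pad_id, max_tries=20))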