import torch
from torch import nn
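
## GRU-based sequence-to-sequence model with attention.
## A bidirectional GRU encoder projects its outputs and final hidden state down to the
## decoder size; a unidirectional GRU decoder attends over the encoder outputs at every
## step. Embedding and classifier weights are tied, so the model expects dim_embed == dim_model.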


class Encoder(nn.Module):
    def __init__(self, vocab_size, dim_embed, dim_hidden, dim_feedforward, num_layers, dropout_probability=0.1):
        super().__init__()

        self.embd_layer = nn.Embedding(vocab_size, dim_embed)
        self.dropout = nn.Dropout(dropout_probability)
        self.rnn = nn.GRU(dim_embed, dim_hidden, num_layers, batch_first=True, dropout=dropout_probability, bidirectional=True)

        ## project the concatenated bidirectional states (2*dim_hidden) down to dim_hidden
        ## so they match what the decoder expects
        self.hidden_map = nn.Sequential(nn.Linear(dim_hidden*2, dim_feedforward),
                                        nn.ReLU(),
                                        nn.Linear(dim_feedforward, dim_hidden),
                                        nn.Dropout(dropout_probability))

        self.output_map = nn.Sequential(nn.Linear(dim_hidden*2, dim_feedforward),
                                        nn.ReLU(),
                                        nn.Linear(dim_feedforward, dim_hidden),
                                        nn.Dropout(dropout_probability))

    def forward(self, x):
        embds = self.dropout(self.embd_layer(x)) ## (B,T,dim_embed)
        context, hidden = self.rnn(embds) ## (B,T,2*dim_hidden), (2*num_layers,B,dim_hidden)
        last_hidden = torch.cat([hidden[-2,:,:], hidden[-1,:,:]], dim=-1) ## forward+backward states of the top layer: (B,2*dim_hidden)
        to_decoder_hidden = self.hidden_map(last_hidden) ## (B,dim_hidden)
        to_decoder_output = self.output_map(context) ## (B,T,dim_hidden)
        return to_decoder_output, to_decoder_hidden


class Attention(nn.Module):
    def __init__(self, input_dims):
        super().__init__()

        self.fc_energy = nn.Linear(input_dims*2, input_dims)
        self.alpha = nn.Linear(input_dims, 1, bias=False)
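        ## concat-based scoring with no intermediate nonlinearity:
        ##   e_{t,i} = alpha( fc_energy([decoder_hidden_t ; encoder_output_i]) )
        ##   weights_t = softmax_i(e_{t,i})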

    def forward(self,
                encoder_output,   # (B,T,encoder_hidden)
                decoder_hidden):  # (B,decoder_hidden)
        ## here encoder_hidden = decoder_hidden = input_dims

        seq_len = encoder_output.size(1)
        decoder_hidden = decoder_hidden.unsqueeze(1).repeat(1, seq_len, 1) ## (B,T,input_dims)
        energy = self.fc_energy(torch.cat((decoder_hidden, encoder_output), dim=-1)) ## (B,T,input_dims)
        alphas = self.alpha(energy).squeeze(-1) ## (B,T) unnormalized scores

        return torch.softmax(alphas, dim=-1) ## (B,T) attention weights over source positions



class Decoder(nn.Module):
    def __init__(self, vocab_size, dim_embed, dim_hidden, attention, num_layers, dropout_probability):
        super().__init__()
        self.attention = attention
        self.embd_layer = nn.Embedding(vocab_size, dim_embed)
        self.rnn = nn.GRU(dim_hidden + dim_embed, dim_hidden, batch_first=True, num_layers=num_layers, dropout=dropout_probability)

    def forward(self, x, encoder_output, hidden_t_1):
        ## hidden_t_1 shape: (num_layers,B,dim_hidden)
        ## encoder_output shape : (B,T,dim_hidden)
        ## x shape: (B,1) one token

        embds = self.embd_layer(x) ## (B,1,dim_embed)
        alphas = self.attention(encoder_output, hidden_t_1[-1]).unsqueeze(1) ## (B,1,T)
        attention = torch.bmm(alphas, encoder_output) ## (B,1,dim_hidden) weighted context vector
        rnn_input = torch.cat((embds, attention), dim=-1) ## (B,1,dim_hidden + dim_embed)

        output, hidden_t = self.rnn(rnn_input, hidden_t_1)
        
        return output, hidden_t, alphas.squeeze(1) ## attention weights returned for visualization

class Seq2seq_with_attention(nn.Module):
    def __init__(self, vocab_size:int, dim_embed:int, dim_model:int, dim_feedforward:int, num_layers:int, dropout_probability:float):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.encoder = Encoder(vocab_size, dim_embed, dim_model, dim_feedforward, num_layers, dropout_probability)
        self.attention = Attention(dim_model)
        self.decoder = Decoder(vocab_size, dim_embed, dim_model, self.attention, num_layers, dropout_probability)
        self.classifier = nn.Linear(dim_model, vocab_size)

        ## weight tying: the source/target embeddings share the classifier weights
        ## (this requires dim_embed == dim_model so the shapes match)
        self.encoder.embd_layer.weight = self.classifier.weight
        self.decoder.embd_layer.weight = self.classifier.weight

    def forward(self, source, target, pad_tokenId):
        # target = <s> text </s>
        # teacher forcing is always used here: the ground-truth token is fed at every step
        B, T = target.size()
        total_logits = torch.zeros(B, T, self.vocab_size, device=source.device)
        context, hidden = self.encoder(source)
        hidden = hidden.unsqueeze(0).repeat(self.num_layers,1,1) # (numlayer, B, dim_model)
        for step in range(T):
            step_token = target[:, [step]]
            out, hidden, alphas = self.decoder(step_token, context, hidden)
            logits = self.classifier(out).squeeze(1)
            total_logits[:, step] = logits
        loss = None
        if T > 1:
            flat_logits = total_logits[:,:-1,:].reshape(-1, total_logits.size(-1))
            flat_targets = target[:,1:].reshape(-1)
            loss = nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=pad_tokenId)
        return total_logits, loss
    
    @torch.no_grad()
    def greedy_decode_fast(self, source:torch.Tensor, sos_tokenId: int, eos_tokenId:int, pad_tokenId, max_tries=50):
        self.eval()
        targets_hat = [sos_tokenId]
        context, hidden = self.encoder(source.unsqueeze(0))
        hidden = hidden.unsqueeze(0).repeat(self.num_layers,1,1) # (numlayer, B, dim_model)
        for step in range(max_tries):
            x = torch.tensor([targets_hat[step]]).unsqueeze(0).to(source.device) ## (1,1) most recently generated token
            out, hidden, alphas = self.decoder(x, context, hidden)
            logits = self.classifier(out)
            next_token = logits.argmax(-1).item()
            targets_hat.append(next_token)
            if next_token == eos_tokenId:
                return targets_hat
        return targets_hat
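

## ---------------------------------------------------------------------------
## Minimal usage sketch. The hyperparameters, special-token ids, and dummy data
## below are illustrative assumptions, not values from the original project.
## Note that the weight tying in Seq2seq_with_attention assumes dim_embed == dim_model.
if __name__ == "__main__":
    torch.manual_seed(0)

    vocab_size, dim_embed, dim_model = 100, 64, 64   ## hypothetical sizes
    model = Seq2seq_with_attention(vocab_size=vocab_size,
                                   dim_embed=dim_embed,
                                   dim_model=dim_model,
                                   dim_feedforward=128,
                                   num_layers=2,
                                   dropout_probability=0.1)

    pad_id, sos_id, eos_id = 0, 1, 2                 ## hypothetical special tokens
    source = torch.randint(3, vocab_size, (4, 10))   ## (B=4, T_src=10)
    target = torch.randint(3, vocab_size, (4, 12))   ## (B=4, T_trg=12)
    target[:, 0] = sos_id

    ## training-style forward pass: per-step logits plus the shifted cross-entropy loss
    logits, loss = model(source, target, pad_tokenId=pad_id)
    print(logits.shape, loss.item())                 ## torch.Size([4, 12, 100]) and a scalar loss

    ## greedy decoding of a single (unbatched) source sequence
    decoded = model.greedy_decode_fast(source[0], sos_id, eos_id, pad_id, max_tries=20)
    print(decoded)                                   ## list of token ids starting with sos_id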