import torch
from torch import nn
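
# Encoder-decoder Transformer for neural machine translation (NMT), built on
# torch.nn.TransformerEncoder / torch.nn.TransformerDecoder. A single embedding
# table is shared between the source embedding, the target embedding, and the
# output classifier (weight tying); positions use a learned positional
# embedding of size `maxlen`.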


class NMT_Transformer(nn.Module):
    def __init__(self, vocab_size: int, dim_embed: int, dim_model: int,
                 dim_feedforward: int, num_layers: int,
                 dropout_probability: float, maxlen: int):
        super().__init__()

        # Source and target share one token-embedding table and one learned
        # positional-embedding table (positions 0..maxlen-1).
        self.embed_shared_src_trg_cls = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dim_embed)
        self.positonal_shared_src_trg = nn.Embedding(num_embeddings=maxlen, embedding_dim=dim_embed)

        self.dropout = nn.Dropout(dropout_probability)

        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model, nhead=8,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout_probability,
                                                   batch_first=True, norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, enable_nested_tensor=False)

        decoder_layer = nn.TransformerDecoderLayer(d_model=dim_model, nhead=8,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout_probability,
                                                   batch_first=True, norm_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        
        self.classifier = nn.Linear(dim_model, vocab_size)
        ## Weight tying: the classifier shares its weight matrix with the token embedding.
        ## This requires dim_embed == dim_model so that the shapes match.
        self.classifier.weight = self.embed_shared_src_trg_cls.weight

        self.maxlen = maxlen
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)
    
    def forward(self, source, target, pad_tokenId):
        # target = <sos> + text + <eos>
        # source = text
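        # Shapes: source is (B, Ts) and target is (B, Tt) token ids.
        # Returns logits of shape (B, Tt, vocab_size) and a cross-entropy loss
        # (or None when the target has a single token).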
        B, Ts = source.shape
        B, Tt = target.shape
        device = source.device
        ## Encoder Path
        src_poses = self.positonal_shared_src_trg(torch.arange(0, Ts).to(device).unsqueeze(0).repeat(B, 1))
        src_embedings = self.dropout(self.embed_shared_src_trg_cls(source) + src_poses)

        src_pad_mask = source == pad_tokenId
        memory = self.transformer_encoder(src=src_embedings, mask=None, src_key_padding_mask=src_pad_mask, is_causal=False)
        ## Decoder Path
        trg_poses = self.positonal_shared_src_trg(torch.arange(0, Tt).to(device).unsqueeze(0).repeat(B, 1))
        trg_embedings = self.dropout(self.embed_shared_src_trg_cls(target) + trg_poses)
        
        trg_pad_mask = target == pad_tokenId
        # Boolean causal mask: True above the diagonal blocks attention to future
        # positions, matching the boolean key-padding masks.
        tgt_mask = torch.triu(torch.ones(Tt, Tt, dtype=torch.bool, device=device), diagonal=1)
        decoder_out = self.transformer_decoder(tgt=trg_embedings,
                                               memory=memory,
                                               tgt_mask=tgt_mask,
                                               memory_mask=None,
                                               tgt_key_padding_mask=trg_pad_mask,
                                               memory_key_padding_mask=None)
        ## Classifier Path
        logits = self.classifier(decoder_out)
        loss = None
        if Tt > 1:
            # Teacher forcing: logits at positions 0..Tt-2 (drop the last one) ...
            flat_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
            # ... are scored against targets at positions 1..Tt-1 (drop the first, <sos>).
            flat_targets = target[:, 1:].reshape(-1)
            loss = nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=pad_tokenId)
        return logits, loss
    

    @torch.no_grad()
    def greedy_decode_fast(self, source_tensor: torch.Tensor, sos_tokenId: int,
                           eos_tokenId: int, pad_tokenId: int, max_tries: int = 50):
        self.eval()
        source_tensor = source_tensor.unsqueeze(0)
        B, Ts = source_tensor.shape
        device = source_tensor.device
        target_tensor = torch.tensor([sos_tokenId]).unsqueeze(0).to(device)

        ## Encoder Path
        src_poses = self.positonal_shared_src_trg(torch.arange(0, Ts).to(device).unsqueeze(0).repeat(B, 1))
        src_embedings = self.embed_shared_src_trg_cls(source_tensor) + src_poses
        src_pad_mask = source_tensor == pad_tokenId
        context = self.transformer_encoder(src=src_embedings, mask=None, src_key_padding_mask=src_pad_mask, is_causal=False)
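
        # The encoder output ("context") is computed once and reused at every decoding step below.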

        for i in range(max_tries):
            ## Decoder Path
            trg_poses = self.positonal_shared_src_trg(torch.arange(0, i+1).to(device).unsqueeze(0).repeat(B, 1))
            trg_embedings = self.embed_shared_src_trg_cls(target_tensor) + trg_poses
            
            trg_pad_mask = target_tensor == pad_tokenId
            # Boolean causal mask over the i+1 tokens generated so far.
            tgt_mask = torch.triu(torch.ones(i + 1, i + 1, dtype=torch.bool, device=device), diagonal=1)
            decoder_out = self.transformer_decoder(tgt=trg_embedings,
                                                   memory=context,
                                                   tgt_mask=tgt_mask,
                                                   memory_mask=None,
                                                   tgt_key_padding_mask=trg_pad_mask,
                                                   memory_key_padding_mask=None)
            ## Classifier Path
            logits = self.classifier(decoder_out)
            # Greedy decoding
            top1 = logits[:,-1,:].argmax(dim=-1, keepdim=True)
            # Append predicted token
            target_tensor = torch.cat([target_tensor, top1], dim=1)
            
            # Stop once <eos> is predicted
            if top1.item() == eos_tokenId:
                break
        return target_tensor.squeeze(0).tolist()
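

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): shows a training-style forward pass
# and greedy decoding on random data. The hyperparameters and special-token ids
# (PAD/SOS/EOS) below are assumptions for the example, not values required by
# this module; note that dim_embed must equal dim_model for the weight tying
# above to work.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    PAD, SOS, EOS = 0, 1, 2  # hypothetical special-token ids
    model = NMT_Transformer(vocab_size=1000, dim_embed=128, dim_model=128,
                            dim_feedforward=512, num_layers=2,
                            dropout_probability=0.1, maxlen=64)

    # Training-style forward pass on random token ids (batch of 4).
    src = torch.randint(3, 1000, (4, 10))
    trg = torch.randint(3, 1000, (4, 12))
    trg[:, 0] = SOS
    logits, loss = model(src, trg, pad_tokenId=PAD)
    print(logits.shape, loss.item())  # torch.Size([4, 12, 1000]) and a scalar loss

    # Greedy decoding for a single (1-D) source sequence.
    tokens = model.greedy_decode_fast(src[0], sos_tokenId=SOS, eos_tokenId=EOS,
                                      pad_tokenId=PAD, max_tries=20)
    print(tokens)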