TheDemond committed (verified)
Commit c412427 · 1 Parent(s): e4cbbb9

Upload 9 files
Configurations/s2s_model_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+ "dim_embed": 256,
+ "dim_model": 256,
+ "dim_feedforward": 1024,
+ "num_layers": 4,
+ "dropout": 0.1,
+ "maxlen": 512,
+ "flash_attention": false
+ }
Configurations/s2sattention_model_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+ "dim_embed": 256,
+ "dim_model": 256,
+ "dim_feedforward": 1024,
+ "num_layers": 4,
+ "dropout": 0.1,
+ "maxlen": 512,
+ "flash_attention": false
+ }
Configurations/transformer_model_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+ "dim_embed": 256,
+ "dim_model": 256,
+ "dim_feedforward": 1024,
+ "num_layers": 4,
+ "dropout": 0.1,
+ "maxlen": 512,
+ "flash_attention": false
+ }
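All three configuration files in this commit currently carry identical values; `maxlen` is only consumed by the transformer model, and `flash_attention` is parsed by `ModelArgs` but not used by any of the model classes below. A minimal sketch of reading one of them with the standard library, assuming the repository root as the working directory:

import json

# Path is one of the files added in this commit, chosen here only as an example.
with open("Configurations/transformer_model_config.json", "r") as f:
    cfg = json.load(f)

print(cfg["dim_model"], cfg["num_layers"], cfg["dropout"])  # 256 4 0.1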
Models/AutoModel.py ADDED
@@ -0,0 +1,31 @@
+ from Models.seq2seq_model import Seq2seq_no_attention
+ from Models.seq2seqAttention_model import Seq2seq_with_attention
+ from Models.Transformer_model import NMT_Transformer
+ from Models.ModelArgs import ModelArgs
+
+
+ def get_model(params: ModelArgs, vocab_size):
+     if params.model_type.lower() == 's2s':
+         model = Seq2seq_no_attention(vocab_size=vocab_size,
+                                      dim_embed=params.dim_embed,
+                                      dim_model=params.dim_model,
+                                      dim_feedforward=params.dim_feedforward,
+                                      num_layers=params.num_layers,
+                                      dropout_probability=params.dropout)
+     elif params.model_type.lower() == 's2sattention':
+         model = Seq2seq_with_attention(vocab_size=vocab_size,
+                                        dim_embed=params.dim_embed,
+                                        dim_model=params.dim_model,
+                                        dim_feedforward=params.dim_feedforward,
+                                        num_layers=params.num_layers,
+                                        dropout_probability=params.dropout)
+     else:
+         model = NMT_Transformer(vocab_size=vocab_size,
+                                 dim_embed=params.dim_embed,
+                                 dim_model=params.dim_model,
+                                 dim_feedforward=params.dim_feedforward,
+                                 num_layers=params.num_layers,
+                                 dropout_probability=params.dropout,
+                                 maxlen=params.maxlen)
+     return model
+
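A minimal usage sketch for `get_model`, assuming the repository root as the working directory and a placeholder vocabulary size of 32000; note that any `model_type` other than 's2s' or 's2sattention' falls through to the `NMT_Transformer` branch:

from Models.ModelArgs import ModelArgs
from Models.AutoModel import get_model

# 32000 is a placeholder; use the actual tokenizer vocabulary size.
params = ModelArgs("transformer", "Configurations/transformer_model_config.json")
model = get_model(params, vocab_size=32000)
print(type(model).__name__, sum(p.numel() for p in model.parameters()))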
Models/ModelArgs.py ADDED
@@ -0,0 +1,57 @@
+ import json
+
+ class ModelArgs:
+     """
+     A class to parse and store model configuration from a JSON file.
+     """
+     def __init__(self, model_type:str, config_path:str):
+         """
+         Initialize ModelArgs with configuration from a JSON file.
+
+         Args:
+             model_type (str): One of 's2s', 's2sattention', or 'transformer'.
+             config_path (str): Path to the JSON configuration file.
+
+         Raises:
+             AssertionError: If the JSON content is invalid or has missing keys.
+         """
+         # Load JSON file
+         with open(config_path, 'r') as file:
+             config = json.load(file)
+
+         # Validate and assign attributes
+         self.model_type = model_type.lower()
+         assert self.model_type in ['s2s', 's2sattention', 'transformer'], \
+             "Supported model_type values are ['s2s', 's2sattention', 'transformer']."
+
+         self.dim_embed = config.get("dim_embed")
+         assert isinstance(self.dim_embed, int), "dim_embed must be an integer."
+
+         self.dim_model = config.get("dim_model")
+         assert isinstance(self.dim_model, int), "dim_model must be an integer."
+
+         self.dim_feedforward = config.get("dim_feedforward")
+         assert isinstance(self.dim_feedforward, int), "dim_feedforward must be an integer."
+
+         self.num_layers = config.get("num_layers")
+         assert isinstance(self.num_layers, int), "num_layers must be an integer."
+
+         self.dropout = config.get("dropout")
+         assert isinstance(self.dropout, float), "dropout must be a float."
+
+         self.maxlen = config.get("maxlen")
+         assert isinstance(self.maxlen, int), "maxlen must be an integer."
+
+         self.flash_attention = config.get("flash_attention")
+         assert isinstance(self.flash_attention, bool), "flash_attention must be a boolean."
+
+     def __repr__(self):
+         return (f"ModelArgs(\n" +
+                 f"model_type={self.model_type},\n" +
+                 f"dim_embed={self.dim_embed},\n" +
+                 f"dim_model={self.dim_model},\n" +
+                 f"dim_feedforward={self.dim_feedforward},\n" +
+                 f"num_layers={self.num_layers},\n" +
+                 f"dropout={self.dropout},\n" +
+                 f"maxlen={self.maxlen},\n" +
+                 f"flash_attention={self.flash_attention}\n" +
+                 ")")
Models/Transformer_model.py ADDED
@@ -0,0 +1,122 @@
+ import torch
+ from torch import nn
+
+
+ class NMT_Transformer(nn.Module):
+     def __init__(self, vocab_size:int, dim_embed:int,
+                  dim_model:int, dim_feedforward:int, num_layers:int,
+                  dropout_probability:float, maxlen:int):
+         super().__init__()
+
+         self.embed_shared_src_trg_cls = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dim_embed)
+         self.positonal_shared_src_trg = nn.Embedding(num_embeddings=maxlen, embedding_dim=dim_embed)
+
+         # self.trg_embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dim_embed)
+         # self.trg_pos = nn.Embedding(num_embeddings=maxlen, embedding_dim=dim_embed)
+
+         self.dropout = nn.Dropout(dropout_probability)
+
+         encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model, nhead=8,
+                                                    dim_feedforward=dim_feedforward,
+                                                    dropout=dropout_probability,
+                                                    batch_first=True, norm_first=True)
+         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, enable_nested_tensor=False)
+
+         decoder_layer = nn.TransformerDecoderLayer(d_model=dim_model, nhead=8,
+                                                    dim_feedforward=dim_feedforward,
+                                                    dropout=dropout_probability,
+                                                    batch_first=True, norm_first=True)
+         self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+         self.classifier = nn.Linear(dim_model, vocab_size)
+         ## weight sharing between classifier and embed_shared_src_trg_cls
+         self.classifier.weight = self.embed_shared_src_trg_cls.weight
+
+         self.maxlen = maxlen
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+         elif isinstance(module, nn.LayerNorm):
+             torch.nn.init.ones_(module.weight)
+             torch.nn.init.zeros_(module.bias)
+
+     def forward(self, source, target, pad_tokenId):
+         # target = <sos> + text + <eos>
+         # source = text
+         B, Ts = source.shape
+         B, Tt = target.shape
+         device = source.device
+         ## Encoder Path
+         src_poses = self.positonal_shared_src_trg(torch.arange(0, Ts).to(device).unsqueeze(0).repeat(B, 1))
+         src_embedings = self.dropout(self.embed_shared_src_trg_cls(source) + src_poses)
+
+         src_pad_mask = source == pad_tokenId
+         memory = self.transformer_encoder(src=src_embedings, mask=None, src_key_padding_mask=src_pad_mask, is_causal=False)
+         ## Decoder Path
+         trg_poses = self.positonal_shared_src_trg(torch.arange(0, Tt).to(device).unsqueeze(0).repeat(B, 1))
+         trg_embedings = self.dropout(self.embed_shared_src_trg_cls(target) + trg_poses)
+
+         trg_pad_mask = target == pad_tokenId
+         tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(Tt, dtype=bool).to(device)
+         decoder_out = self.transformer_decoder.forward(tgt=trg_embedings,
+                                                        memory=memory,
+                                                        tgt_mask=tgt_mask,
+                                                        memory_mask=None,
+                                                        tgt_key_padding_mask=trg_pad_mask,
+                                                        memory_key_padding_mask=None)
+         ## Classifier Path
+         logits = self.classifier(decoder_out)
+         loss = None
+         if Tt > 1:
+             # for the model logits we need all tokens except the last one
+             flat_logits = logits[:,:-1,:].reshape(-1, logits.size(-1))
+             # for the targets we need all tokens except the first one
+             flat_targets = target[:,1:].reshape(-1)
+             loss = nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=pad_tokenId)
+         return logits, loss
+
+
+     @torch.no_grad
+     def greedy_decode_fast(self, source_tensor:torch.Tensor, sos_tokenId: int, eos_tokenId:int, pad_tokenId, max_tries=50):
+         self.eval()
+         source_tensor = source_tensor.unsqueeze(0)
+         B, Ts = source_tensor.shape
+         device = source_tensor.device
+         target_tensor = torch.tensor([sos_tokenId]).unsqueeze(0).to(device)
+
+         ## Encoder Path
+         src_poses = self.positonal_shared_src_trg(torch.arange(0, Ts).to(device).unsqueeze(0).repeat(B, 1))
+         src_embedings = self.embed_shared_src_trg_cls(source_tensor) + src_poses
+         src_pad_mask = source_tensor == pad_tokenId
+         context = self.transformer_encoder(src=src_embedings, mask=None, src_key_padding_mask=src_pad_mask, is_causal=False)
+
+         for i in range(max_tries):
+             ## Decoder Path
+             trg_poses = self.positonal_shared_src_trg(torch.arange(0, i+1).to(device).unsqueeze(0).repeat(B, 1))
+             trg_embedings = self.embed_shared_src_trg_cls(target_tensor) + trg_poses
+
+             trg_pad_mask = target_tensor == pad_tokenId
+             tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(i+1, dtype=bool).to(device)
+             decoder_out = self.transformer_decoder.forward(tgt=trg_embedings,
+                                                            memory=context,
+                                                            tgt_mask=tgt_mask,
+                                                            memory_mask=None,
+                                                            tgt_key_padding_mask=trg_pad_mask,
+                                                            memory_key_padding_mask=None)
+             ## Classifier Path
+             logits = self.classifier(decoder_out)
+             # Greedy decoding
+             top1 = logits[:,-1,:].argmax(dim=-1, keepdim=True)
+             # Append predicted token
+             target_tensor = torch.cat([target_tensor, top1], dim=1)
+
+             # Stop if <EOS> is predicted
+             if top1.item() == eos_tokenId:
+                 break
+         return target_tensor.squeeze(0).tolist()
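A smoke-test sketch for `NMT_Transformer`, with hyperparameters mirroring the transformer config; the vocabulary size and the special-token ids (`pad=0`, `sos=1`, `eos=2`) are placeholders. `forward` returns per-position logits plus a cross-entropy loss shifted by one position, while `greedy_decode_fast` expects an unbatched 1-D source tensor:

import torch
from Models.Transformer_model import NMT_Transformer

# Values mirror Configurations/transformer_model_config.json; vocab size and ids are placeholders.
model = NMT_Transformer(vocab_size=1000, dim_embed=256, dim_model=256,
                        dim_feedforward=1024, num_layers=4,
                        dropout_probability=0.1, maxlen=512)
pad, sos, eos = 0, 1, 2
source = torch.randint(3, 1000, (2, 7))        # (B, Ts)
target = torch.randint(3, 1000, (2, 5))        # (B, Tt), assumed to start with <sos>
logits, loss = model(source, target, pad_tokenId=pad)
print(logits.shape, loss.item())               # torch.Size([2, 5, 1000]) ...

hyp = model.greedy_decode_fast(source[0], sos_tokenId=sos, eos_tokenId=eos,
                               pad_tokenId=pad, max_tries=20)
print(hyp)                                     # list of token ids starting with <sos>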
Models/__init__.py ADDED
File without changes
Models/seq2seqAttention_model.py ADDED
@@ -0,0 +1,123 @@
+ import torch
+ from torch import nn
+ import random
+
+
+ class Encoder(nn.Module):
+     def __init__(self, vocab_size, dim_embed, dim_hidden, dim_feedforward, num_layers, dropout_probability=0.1):
+         super().__init__()
+
+         self.embd_layer = nn.Embedding(vocab_size, dim_embed)
+         self.dropout = nn.Dropout(dropout_probability)
+         self.rnn = nn.GRU(dim_embed, dim_hidden, num_layers, batch_first=True, dropout=dropout_probability, bidirectional=True)
+
+         self.hidden_map = nn.Sequential(nn.Linear(dim_hidden*2, dim_feedforward),
+                                         nn.ReLU(),
+                                         nn.Linear(dim_feedforward, dim_hidden),
+                                         nn.Dropout(dropout_probability))
+
+         self.output_map = nn.Sequential(nn.Linear(dim_hidden*2, dim_feedforward),
+                                         nn.ReLU(),
+                                         nn.Linear(dim_feedforward, dim_hidden),
+                                         nn.Dropout(dropout_probability))
+
+     def forward(self, x):
+         embds = self.dropout(self.embd_layer(x))
+         context, hidden = self.rnn(embds)
+         last_hidden = torch.cat([hidden[-2,:,:], hidden[-1,:,:]], dim=-1)
+         to_decoder_hidden = self.hidden_map(last_hidden)
+         to_decoder_output = self.output_map(context)
+         return to_decoder_output, to_decoder_hidden
+
+
+ class Attention(nn.Module):
+     def __init__(self, input_dims):
+         super().__init__()
+
+         self.fc_energy = nn.Linear(input_dims*2, input_dims)
+         self.alpha = nn.Linear(input_dims, 1, bias=False)
+
+     def forward(self,
+                 encoder_output,  # (B,T,encoder_hidden)
+                 decoder_hidden): # (B,decoder_hidden)
+         ## encoder_hidden = decoder_hidden = input_dims
+
+         seq_len = encoder_output.size(1)
+         decoder_hidden = decoder_hidden.unsqueeze(1).repeat(1, seq_len, 1)  ## (B,T,input_dims)
+         energy = self.fc_energy(torch.cat((decoder_hidden, encoder_output), dim=-1))
+         alphas = self.alpha(energy).squeeze(-1)
+
+         return torch.softmax(alphas, dim=-1)
+
+
+
+ class Decoder(nn.Module):
+     def __init__(self, vocab_size, dim_embed, dim_hidden, attention, num_layers, dropout_probability):
+         super().__init__()
+         self.attention = attention
+         self.embd_layer = nn.Embedding(vocab_size, dim_embed)
+         self.rnn = nn.GRU(dim_hidden + dim_embed, dim_hidden, batch_first=True, num_layers=num_layers, dropout=dropout_probability)
+
+     def forward(self, x, encoder_output, hidden_t_1):
+         ## hidden_t_1 shape: (num_layers,B,dim_hidden)
+         ## encoder_output shape: (B,T,dim_hidden)
+         ## x shape: (B,1) one token
+
+         embds = self.embd_layer(x)  ## (B,1,dim_embed)
+         alphas = self.attention(encoder_output, hidden_t_1[-1]).unsqueeze(1)  ## (B,1,T)
+         attention = torch.bmm(alphas, encoder_output)  ## (B,1,dim_hidden)
+         rnn_input = torch.cat((embds, attention), dim=-1)  ## (B,1,dim_hidden + dim_embed)
+
+         output, hidden_t = self.rnn(rnn_input, hidden_t_1)
+
+         return output, hidden_t, alphas.squeeze(1)  ## alphas are returned for visualization
+
+ class Seq2seq_with_attention(nn.Module):
+     def __init__(self, vocab_size:int, dim_embed:int, dim_model:int, dim_feedforward:int, num_layers:int, dropout_probability:float):
+         super().__init__()
+
+         self.vocab_size = vocab_size
+         self.num_layers = num_layers
+         self.encoder = Encoder(vocab_size, dim_embed, dim_model, dim_feedforward, num_layers, dropout_probability)
+         self.attention = Attention(dim_model)
+         self.decoder = Decoder(vocab_size, dim_embed, dim_model, self.attention, num_layers, dropout_probability)
+         self.classifier = nn.Linear(dim_model, vocab_size)
+
+         ## weight sharing between the classifier and the encoder/decoder embeddings
+         self.encoder.embd_layer.weight = self.classifier.weight
+         self.decoder.embd_layer.weight = self.classifier.weight
+
+     def forward(self, source, target, pad_tokenId):
+         # target = <s> text </s>
+         # teacher_force_ratio = 0.5
+         B, T = target.size()
+         total_logits = torch.zeros(B, T, self.vocab_size, device=source.device)
+         context, hidden = self.encoder(source)
+         hidden = hidden.unsqueeze(0).repeat(self.num_layers,1,1)  # (num_layers, B, dim_model)
+         for step in range(T):
+             step_token = target[:, [step]]
+             out, hidden, alphas = self.decoder(step_token, context, hidden)
+             logits = self.classifier(out).squeeze(1)
+             total_logits[:, step] = logits
+         loss = None
+         if T > 1:
+             flat_logits = total_logits[:,:-1,:].reshape(-1, total_logits.size(-1))
+             flat_targets = target[:,1:].reshape(-1)
+             loss = nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=pad_tokenId)
+         return total_logits, loss
+
+     @torch.no_grad
+     def greedy_decode_fast(self, source:torch.Tensor, sos_tokenId: int, eos_tokenId:int, pad_tokenId, max_tries=50):
+         self.eval()
+         targets_hat = [sos_tokenId]
+         context, hidden = self.encoder(source.unsqueeze(0))
+         hidden = hidden.unsqueeze(0).repeat(self.num_layers,1,1)  # (num_layers, B, dim_model)
+         for step in range(max_tries):
+             x = torch.tensor([targets_hat[step]]).unsqueeze(0).to(source.device)
+             out, hidden, alphas = self.decoder(x, context, hidden)
+             logits = self.classifier(out)
+             top1 = logits.argmax(-1)
+             targets_hat.append(top1.item())
+             if top1 == eos_tokenId:
+                 return targets_hat
+         return targets_hat
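An analogous sketch for `Seq2seq_with_attention`, under the same placeholder vocabulary and token-id assumptions; the forward pass teacher-forces the full target sequence, and the attention weights (`alphas`) are available per step if the decoder is called directly for visualization:

import torch
from Models.seq2seqAttention_model import Seq2seq_with_attention

model = Seq2seq_with_attention(vocab_size=1000, dim_embed=256, dim_model=256,
                               dim_feedforward=1024, num_layers=4,
                               dropout_probability=0.1)
pad, sos, eos = 0, 1, 2                          # placeholder special-token ids
source = torch.randint(3, 1000, (2, 7))          # (B, Ts)
target = torch.randint(3, 1000, (2, 5))          # (B, Tt)
logits, loss = model(source, target, pad_tokenId=pad)
print(logits.shape)                              # torch.Size([2, 5, 1000])
print(model.greedy_decode_fast(source[0], sos_tokenId=sos, eos_tokenId=eos,
                               pad_tokenId=pad, max_tries=20))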
Models/seq2seq_model.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ from torch import nn
+ import random
+
+
+ class Encoder(nn.Module):
+     def __init__(self, vocab_size, dim_embed, dim_hidden, dim_feedforward, num_layers, dropout_probability=0.1):
+         super().__init__()
+
+         self.embd_layer = nn.Embedding(vocab_size, dim_embed)
+         self.dropout = nn.Dropout(dropout_probability)
+         self.rnn = nn.GRU(dim_embed, dim_hidden, num_layers=num_layers,
+                           dropout=dropout_probability, batch_first=True,
+                           bidirectional=True)
+         self.ff = nn.Sequential(nn.Linear(dim_hidden*2, dim_feedforward),
+                                 nn.ReLU(),
+                                 nn.Linear(dim_feedforward, dim_hidden),
+                                 nn.Dropout(dropout_probability))
+
+     def forward(self, x):
+         embds = self.dropout(self.embd_layer(x))
+         output, hidden = self.rnn(embds)
+         ## hidden[-2,:,:]: hidden state for the forward direction of the last layer.
+         ## hidden[-1,:,:]: hidden state for the backward direction of the last layer.
+         last_hidden = torch.cat([hidden[-2,:,:], hidden[-1,:,:]], dim=-1)
+         projected_hidden = self.ff(last_hidden)
+         return projected_hidden
+
+
+
+ class Decoder(nn.Module):
+     def __init__(self, vocab_size, dim_embed, dim_hidden, num_layers, dropout_probability=0.1):
+         super().__init__()
+
+         self.embd_layer = nn.Embedding(vocab_size, dim_embed)
+         self.dropout = nn.Dropout(dropout_probability)
+         self.rnn = nn.GRU(dim_embed, dim_hidden, num_layers=num_layers,
+                           dropout=dropout_probability, batch_first=True)
+         self.ffw = nn.Linear(dim_hidden, dim_hidden)
+
+     def forward(self, x, hidden_t_1):
+         embds = self.dropout(self.embd_layer(x))
+         output, hidden_t = self.rnn(embds, hidden_t_1)
+         out = self.ffw(hidden_t[-1])
+         return out, hidden_t
+
+
+ class Seq2seq_no_attention(nn.Module):
+     def __init__(self, vocab_size:int, dim_embed:int, dim_model:int, dim_feedforward:int, num_layers:int, dropout_probability:float):
+         super(Seq2seq_no_attention, self).__init__()
+         self.vocab_size = vocab_size
+         self.num_layers = num_layers
+         self.encoder = Encoder(vocab_size, dim_embed, dim_model, dim_feedforward, num_layers, dropout_probability)
+         self.decoder = Decoder(vocab_size, dim_embed, dim_model, num_layers, dropout_probability)
+         self.classifier = nn.Linear(dim_model, vocab_size)
+         ## weight sharing between the classifier and the encoder/decoder embeddings
+         self.encoder.embd_layer.weight = self.classifier.weight
+         self.decoder.embd_layer.weight = self.classifier.weight
+
+     def forward(self, source, target, pad_tokenId):
+         # target = <s> text </s>
+         # teacher_force_ratio = 0.5
+         B, T = target.size()
+         total_logits = torch.zeros(B, T, self.vocab_size, device=source.device)  # (B,T,vocab_size)
+
+         context = self.encoder(source)  # (B, dim_model)
+         ## Feed the encoder context as the initial hidden state of every decoder layer (inspired by the "Attention Is All You Need" paper)
+         context = context.unsqueeze(0).repeat(self.num_layers,1,1)  # (num_layers, B, dim_model)
+         for step in range(T):
+             step_token = target[:, [step]]
+             out, context = self.decoder(step_token, context)
+             logits = self.classifier(out).squeeze(1)
+             total_logits[:, step] = logits
+         loss = None
+         if T > 1:
+             flat_logits = total_logits[:,:-1,:].reshape(-1, total_logits.size(-1))
+             flat_targets = target[:,1:].reshape(-1)
+             loss = nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=pad_tokenId)
+         return total_logits, loss
+
+
+     @torch.no_grad
+     def greedy_decode_fast(self, source:torch.Tensor, sos_tokenId: int, eos_tokenId:int, pad_tokenId, max_tries=50):
+         self.eval()
+         targets_hat = [sos_tokenId]
+         context = self.encoder(source.unsqueeze(0))
+         context = context.unsqueeze(0).repeat(self.num_layers,1,1)
+         for step in range(max_tries):
+             x = torch.tensor([targets_hat[step]]).unsqueeze(0).to(source.device)
+             out, context = self.decoder(x, context)
+             logits = self.classifier(out)
+             top1 = logits.argmax(-1)
+             targets_hat.append(top1.item())
+             if top1 == eos_tokenId:
+                 return targets_hat
+         return targets_hat
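The attention-free baseline follows the same call pattern; its encoder compresses the whole source into a single vector that seeds every decoder layer. A brief sketch under the same placeholder assumptions:

import torch
from Models.seq2seq_model import Seq2seq_no_attention

model = Seq2seq_no_attention(vocab_size=1000, dim_embed=256, dim_model=256,
                             dim_feedforward=1024, num_layers=4,
                             dropout_probability=0.1)
pad = 0                                          # placeholder padding id
logits, loss = model(torch.randint(3, 1000, (2, 7)),
                     torch.randint(3, 1000, (2, 5)), pad_tokenId=pad)
print(logits.shape, loss is not None)            # torch.Size([2, 5, 1000]) True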