Upload 9 files
- Configurations/s2s_model_config.json +9 -0
- Configurations/s2sattention_model_config.json +9 -0
- Configurations/transformer_model_config.json +9 -0
- Models/AutoModel.py +31 -0
- Models/ModelArgs.py +57 -0
- Models/Transformer_model.py +122 -0
- Models/__init__.py +0 -0
- Models/seq2seqAttention_model.py +123 -0
- Models/seq2seq_model.py +96 -0
Configurations/s2s_model_config.json
ADDED
@@ -0,0 +1,9 @@
{
    "dim_embed": 256,
    "dim_model": 256,
    "dim_feedforward": 1024,
    "num_layers": 4,
    "dropout": 0.1,
    "maxlen": 512,
    "flash_attention": false
}
Configurations/s2sattention_model_config.json
ADDED
@@ -0,0 +1,9 @@
{
    "dim_embed": 256,
    "dim_model": 256,
    "dim_feedforward": 1024,
    "num_layers": 4,
    "dropout": 0.1,
    "maxlen": 512,
    "flash_attention": false
}
Configurations/transformer_model_config.json
ADDED
@@ -0,0 +1,9 @@
{
    "dim_embed": 256,
    "dim_model": 256,
    "dim_feedforward": 1024,
    "num_layers": 4,
    "dropout": 0.1,
    "maxlen": 512,
    "flash_attention": false
}
Models/AutoModel.py
ADDED
@@ -0,0 +1,31 @@
from Models.seq2seq_model import Seq2seq_no_attention
from Models.seq2seqAttention_model import Seq2seq_with_attention
from Models.Transformer_model import NMT_Transformer
from Models.ModelArgs import ModelArgs


def get_model(params: ModelArgs, vocab_size):

    if params.model_type.lower() == 's2s':
        model = Seq2seq_no_attention(vocab_size=vocab_size,
                                     dim_embed=params.dim_embed,
                                     dim_model=params.dim_model,
                                     dim_feedforward=params.dim_feedforward,
                                     num_layers=params.num_layers,
                                     dropout_probability=params.dropout)

    elif params.model_type.lower() == 's2sattention':
        model = Seq2seq_with_attention(vocab_size=vocab_size,
                                       dim_embed=params.dim_embed,
                                       dim_model=params.dim_model,
                                       dim_feedforward=params.dim_feedforward,
                                       num_layers=params.num_layers,
                                       dropout_probability=params.dropout)

    else:
        model = NMT_Transformer(vocab_size=vocab_size,
                                dim_embed=params.dim_embed,
                                dim_model=params.dim_model,
                                dim_feedforward=params.dim_feedforward,
                                num_layers=params.num_layers,
                                dropout_probability=params.dropout,
                                maxlen=params.maxlen)
    return model
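A minimal usage sketch of the two pieces above (editor's illustration, not part of the commit): ModelArgs parses one of the JSON files from Configurations/ and get_model dispatches on model_type. The vocab_size value is an assumption standing in for a real tokenizer, and the snippet assumes it is run from the repository root.

import torch
from Models.ModelArgs import ModelArgs
from Models.AutoModel import get_model

# Parse the transformer config shipped in this commit and build the model.
params = ModelArgs(model_type='transformer',
                   config_path='Configurations/transformer_model_config.json')
model = get_model(params, vocab_size=32000)   # vocab_size is illustrative
print(params)
print(sum(p.numel() for p in model.parameters()), 'parameters')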
Models/ModelArgs.py
ADDED
@@ -0,0 +1,57 @@
import json

class ModelArgs:
    """
    A class to parse and store model configuration from a JSON file.
    """
    def __init__(self, model_type:str, config_path:str):
        """
        Initialize ModelArgs with configuration from a JSON file.

        Args:
            model_type (str): One of 's2s', 's2sattention' or 'transformer'.
            config_path (str): Path to the JSON configuration file.

        Raises:
            AssertionError: If the JSON content is invalid or has missing keys.
        """
        # Load JSON file
        with open(config_path, 'r') as file:
            config = json.load(file)

        # Validate and assign attributes
        self.model_type = model_type.lower()
        assert self.model_type in ['s2s', 's2sattention', 'transformer'], \
            "Supported model_type values are ['s2s', 's2sAttention', 'transformer']."

        self.dim_embed = config.get("dim_embed")
        assert isinstance(self.dim_embed, int), "dim_embed must be an integer."

        self.dim_model = config.get("dim_model")
        assert isinstance(self.dim_model, int), "dim_model must be an integer."

        self.dim_feedforward = config.get("dim_feedforward")
        assert isinstance(self.dim_feedforward, int), "dim_feedforward must be an integer."

        self.num_layers = config.get("num_layers")
        assert isinstance(self.num_layers, int), "num_layers must be an integer."

        self.dropout = config.get("dropout")
        assert isinstance(self.dropout, float), "dropout must be a float."

        self.maxlen = config.get("maxlen")
        assert isinstance(self.maxlen, int), "maxlen must be an integer."

        self.flash_attention = config.get("flash_attention")
        assert isinstance(self.flash_attention, bool), "flash_attention must be a boolean."

    def __repr__(self):
        return (f"ModelArgs(\n" +
                f"model_type={self.model_type},\n" +
                f"dim_embed={self.dim_embed},\n" +
                f"dim_model={self.dim_model},\n" +
                f"dim_feedforward={self.dim_feedforward},\n" +
                f"num_layers={self.num_layers},\n" +
                f"dropout={self.dropout},\n" +
                f"maxlen={self.maxlen},\n" +
                f"flash_attention={self.flash_attention}\n" +
                ")")
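A sketch of the validation behaviour (editor's illustration): every field is checked with an assert, so a config entry with the wrong type fails with the corresponding message. The temporary file and its contents below are hypothetical.

import json
import tempfile
from Models.ModelArgs import ModelArgs

# Hypothetical broken config: dropout given as a string instead of a float.
bad = {"dim_embed": 256, "dim_model": 256, "dim_feedforward": 1024,
       "num_layers": 4, "dropout": "0.1", "maxlen": 512, "flash_attention": False}
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(bad, f)

try:
    ModelArgs(model_type='s2s', config_path=f.name)
except AssertionError as err:
    print(err)   # -> dropout must be a float.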
Models/Transformer_model.py
ADDED
@@ -0,0 +1,122 @@
import torch
from torch import nn


class NMT_Transformer(nn.Module):
    def __init__(self, vocab_size:int, dim_embed:int,
                 dim_model:int, dim_feedforward:int, num_layers:int,
                 dropout_probability:float, maxlen:int):
        super().__init__()

        self.embed_shared_src_trg_cls = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dim_embed)
        self.positonal_shared_src_trg = nn.Embedding(num_embeddings=maxlen, embedding_dim=dim_embed)

        # self.trg_embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dim_embed)
        # self.trg_pos = nn.Embedding(num_embeddings=maxlen, embedding_dim=dim_embed)

        self.dropout = nn.Dropout(dropout_probability)

        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model, nhead=8,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout_probability,
                                                   batch_first=True, norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, enable_nested_tensor=False)

        decoder_layer = nn.TransformerDecoderLayer(d_model=dim_model, nhead=8,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout_probability,
                                                   batch_first=True, norm_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        self.classifier = nn.Linear(dim_model, vocab_size)
        ## weight sharing between classifier and embed_shared_src_trg_cls
        self.classifier.weight = self.embed_shared_src_trg_cls.weight

        self.maxlen = maxlen
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(self, source, target, pad_tokenId):
        # target = <sos> + text + <eos>
        # source = text
        B, Ts = source.shape
        B, Tt = target.shape
        device = source.device
        ## Encoder Path
        src_poses = self.positonal_shared_src_trg(torch.arange(0, Ts).to(device).unsqueeze(0).repeat(B, 1))
        src_embedings = self.dropout(self.embed_shared_src_trg_cls(source) + src_poses)

        src_pad_mask = source == pad_tokenId
        memory = self.transformer_encoder(src=src_embedings, mask=None, src_key_padding_mask=src_pad_mask, is_causal=False)
        ## Decoder Path
        trg_poses = self.positonal_shared_src_trg(torch.arange(0, Tt).to(device).unsqueeze(0).repeat(B, 1))
        trg_embedings = self.dropout(self.embed_shared_src_trg_cls(target) + trg_poses)

        trg_pad_mask = target == pad_tokenId
        tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(Tt, dtype=bool).to(device)
        decoder_out = self.transformer_decoder.forward(tgt=trg_embedings,
                                                       memory=memory,
                                                       tgt_mask=tgt_mask,
                                                       memory_mask=None,
                                                       tgt_key_padding_mask=trg_pad_mask,
                                                       memory_key_padding_mask=None)
        ## Classifier Path
        logits = self.classifier(decoder_out)
        loss = None
        if Tt > 1:
            # for the model logits we need all tokens except the last one
            flat_logits = logits[:,:-1,:].reshape(-1, logits.size(-1))
            # for the targets we need all tokens except the first one
            flat_targets = target[:,1:].reshape(-1)
            loss = nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=pad_tokenId)
        return logits, loss


    @torch.no_grad
    def greedy_decode_fast(self, source_tensor:torch.Tensor, sos_tokenId: int, eos_tokenId:int, pad_tokenId, max_tries=50):
        self.eval()
        source_tensor = source_tensor.unsqueeze(0)
        B, Ts = source_tensor.shape
        device = source_tensor.device
        target_tensor = torch.tensor([sos_tokenId]).unsqueeze(0).to(device)

        ## Encoder Path
        src_poses = self.positonal_shared_src_trg(torch.arange(0, Ts).to(device).unsqueeze(0).repeat(B, 1))
        src_embedings = self.embed_shared_src_trg_cls(source_tensor) + src_poses
        src_pad_mask = source_tensor == pad_tokenId
        context = self.transformer_encoder(src=src_embedings, mask=None, src_key_padding_mask=src_pad_mask, is_causal=False)

        for i in range(max_tries):
            ## Decoder Path
            trg_poses = self.positonal_shared_src_trg(torch.arange(0, i+1).to(device).unsqueeze(0).repeat(B, 1))
            trg_embedings = self.embed_shared_src_trg_cls(target_tensor) + trg_poses

            trg_pad_mask = target_tensor == pad_tokenId
            tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(i+1, dtype=bool).to(device)
            decoder_out = self.transformer_decoder.forward(tgt=trg_embedings,
                                                           memory=context,
                                                           tgt_mask=tgt_mask,
                                                           memory_mask=None,
                                                           tgt_key_padding_mask=trg_pad_mask,
                                                           memory_key_padding_mask=None)
            ## Classifier Path
            logits = self.classifier(decoder_out)
            # Greedy decoding
            top1 = logits[:,-1,:].argmax(dim=-1, keepdim=True)
            # Append the predicted token
            target_tensor = torch.cat([target_tensor, top1], dim=1)

            # Stop if <EOS> is predicted
            if top1.item() == eos_tokenId:
                break
        return target_tensor.squeeze(0).tolist()
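The loss in forward is next-token prediction with teacher forcing: position t of the logits predicts position t+1 of the target, so the last logit and the first target token are dropped before flattening (hence the Tt > 1 guard). A standalone sketch of that shift with made-up shapes (editor's illustration, not part of the commit):

import torch
from torch import nn

# Toy shapes: B=2 sequences, Tt=5 target positions, vocabulary of 11, pad id 0.
B, Tt, vocab_size, pad_tokenId = 2, 5, 11, 0
logits = torch.randn(B, Tt, vocab_size)          # decoder output for <sos> w1 w2 w3 <eos>
target = torch.randint(1, vocab_size, (B, Tt))   # the corresponding target ids

# Drop the last prediction and the first target token, then flatten.
flat_logits = logits[:, :-1, :].reshape(-1, vocab_size)
flat_targets = target[:, 1:].reshape(-1)
loss = nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=pad_tokenId)
print(loss.item())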
Models/__init__.py
ADDED
File without changes
Models/seq2seqAttention_model.py
ADDED
@@ -0,0 +1,123 @@
import torch
from torch import nn
import random


class Encoder(nn.Module):
    def __init__(self, vocab_size, dim_embed, dim_hidden, dim_feedforward, num_layers, dropout_probability=0.1):
        super().__init__()

        self.embd_layer = nn.Embedding(vocab_size, dim_embed)
        self.dropout = nn.Dropout(dropout_probability)
        self.rnn = nn.GRU(dim_embed, dim_hidden, num_layers, batch_first=True, dropout=dropout_probability, bidirectional=True)

        self.hidden_map = nn.Sequential(nn.Linear(dim_hidden*2, dim_feedforward),
                                        nn.ReLU(),
                                        nn.Linear(dim_feedforward, dim_hidden),
                                        nn.Dropout(dropout_probability))

        self.output_map = nn.Sequential(nn.Linear(dim_hidden*2, dim_feedforward),
                                        nn.ReLU(),
                                        nn.Linear(dim_feedforward, dim_hidden),
                                        nn.Dropout(dropout_probability))

    def forward(self, x):
        embds = self.dropout(self.embd_layer(x))
        context, hidden = self.rnn(embds)
        last_hidden = torch.cat([hidden[-2,:,:], hidden[-1,:,:]], dim=-1)
        to_decoder_hidden = self.hidden_map(last_hidden)
        to_decoder_output = self.output_map(context)
        return to_decoder_output, to_decoder_hidden


class Attention(nn.Module):
    def __init__(self, input_dims):
        super().__init__()

        self.fc_energy = nn.Linear(input_dims*2, input_dims)
        self.alpha = nn.Linear(input_dims, 1, bias=False)

    def forward(self,
                encoder_output,  # (B,T,encoder_hidden)
                decoder_hidden): # (B,decoder_hidden)
        ## encoder_hidden == decoder_hidden == input_dims

        seq_len = encoder_output.size(1)
        decoder_hidden = decoder_hidden.unsqueeze(1).repeat(1, seq_len, 1) ## (B,T,input_dims)
        energy = self.fc_energy(torch.cat((decoder_hidden, encoder_output), dim=-1))
        alphas = self.alpha(energy).squeeze(-1)

        return torch.softmax(alphas, dim=-1)


class Decoder(nn.Module):
    def __init__(self, vocab_size, dim_embed, dim_hidden, attention, num_layers, dropout_probability):
        super().__init__()
        self.attention = attention
        self.embd_layer = nn.Embedding(vocab_size, dim_embed)
        self.rnn = nn.GRU(dim_hidden + dim_embed, dim_hidden, batch_first=True, num_layers=num_layers, dropout=dropout_probability)

    def forward(self, x, encoder_output, hidden_t_1):
        ## hidden_t_1 shape: (num_layers,B,dim_hidden)
        ## encoder_output shape: (B,T,dim_hidden)
        ## x shape: (B,1) one token

        embds = self.embd_layer(x) ## (B,1,dim_embed)
        alphas = self.attention(encoder_output, hidden_t_1[-1]).unsqueeze(1) ## (B,1,T)
        attention = torch.bmm(alphas, encoder_output) ## (B,1,dim_hidden)
        rnn_input = torch.cat((embds, attention), dim=-1) ## (B,1,dim_hidden + dim_embed)

        output, hidden_t = self.rnn(rnn_input, hidden_t_1)

        return output, hidden_t, alphas.squeeze(1) ## alphas are returned for visualization

class Seq2seq_with_attention(nn.Module):
    def __init__(self, vocab_size:int, dim_embed:int, dim_model:int, dim_feedforward:int, num_layers:int, dropout_probability:float):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.encoder = Encoder(vocab_size, dim_embed, dim_model, dim_feedforward, num_layers, dropout_probability)
        self.attention = Attention(dim_model)
        self.decoder = Decoder(vocab_size, dim_embed, dim_model, self.attention, num_layers, dropout_probability)
        self.classifier = nn.Linear(dim_model, vocab_size)

        ## weight sharing between the classifier and the encoder/decoder embeddings
        self.encoder.embd_layer.weight = self.classifier.weight
        self.decoder.embd_layer.weight = self.classifier.weight

    def forward(self, source, target, pad_tokenId):
        # target = <s> text </s>
        # teacher_force_ratio = 0.5
        B, T = target.size()
        total_logits = torch.zeros(B, T, self.vocab_size, device=source.device)
        context, hidden = self.encoder(source)
        hidden = hidden.unsqueeze(0).repeat(self.num_layers,1,1) # (numlayer, B, dim_model)
        for step in range(T):
            step_token = target[:, [step]]
            out, hidden, alphas = self.decoder(step_token, context, hidden)
            logits = self.classifier(out).squeeze(1)
            total_logits[:, step] = logits
        loss = None
        if T > 1:
            flat_logits = total_logits[:,:-1,:].reshape(-1, total_logits.size(-1))
            flat_targets = target[:,1:].reshape(-1)
            loss = nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=pad_tokenId)
        return total_logits, loss

    @torch.no_grad
    def greedy_decode_fast(self, source:torch.Tensor, sos_tokenId: int, eos_tokenId:int, pad_tokenId, max_tries=50):
        self.eval()
        targets_hat = [sos_tokenId]
        context, hidden = self.encoder(source.unsqueeze(0))
        hidden = hidden.unsqueeze(0).repeat(self.num_layers,1,1) # (numlayer, B, dim_model)
        for step in range(max_tries):
            x = torch.tensor([targets_hat[step]]).unsqueeze(0).to(source.device)
            out, hidden, alphas = self.decoder(x, context, hidden)
            logits = self.classifier(out)
            top1 = logits.argmax(-1)
            targets_hat.append(top1.item())
            if top1 == eos_tokenId:
                return targets_hat
        return targets_hat
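A shape-check sketch for the Attention module above (editor's illustration with made-up sizes): it scores every encoder position against the current decoder hidden state and returns one softmax-normalised weight per source token.

import torch
from Models.seq2seqAttention_model import Attention

B, T, H = 2, 7, 16                     # batch, source length, hidden size (illustrative)
attn = Attention(input_dims=H)
encoder_output = torch.randn(B, T, H)  # (B, T, H)
decoder_hidden = torch.randn(B, H)     # (B, H)

alphas = attn(encoder_output, decoder_hidden)
print(alphas.shape)        # torch.Size([2, 7])
print(alphas.sum(dim=-1))  # each row sums to 1 (softmax over source positions)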
Models/seq2seq_model.py
ADDED
@@ -0,0 +1,96 @@
import torch
from torch import nn
import random


class Encoder(nn.Module):
    def __init__(self, vocab_size, dim_embed, dim_hidden, dim_feedforward, num_layers, dropout_probability=0.1):
        super().__init__()

        self.embd_layer = nn.Embedding(vocab_size, dim_embed)
        self.dropout = nn.Dropout(dropout_probability)
        self.rnn = nn.GRU(dim_embed, dim_hidden, num_layers=num_layers,
                          dropout=dropout_probability, batch_first=True,
                          bidirectional=True)
        self.ff = nn.Sequential(nn.Linear(dim_hidden*2, dim_feedforward),
                                nn.ReLU(),
                                nn.Linear(dim_feedforward, dim_hidden),
                                nn.Dropout(dropout_probability))

    def forward(self, x):
        embds = self.dropout(self.embd_layer(x))
        output, hidden = self.rnn(embds)
        ## hidden[-2,:,:]: hidden state for the forward direction of the last layer.
        ## hidden[-1,:,:]: hidden state for the backward direction of the last layer.
        last_hidden = torch.cat([hidden[-2,:,:], hidden[-1,:,:]], dim=-1)
        projected_hidden = self.ff(last_hidden)
        return projected_hidden


class Decoder(nn.Module):
    def __init__(self, vocab_size, dim_embed, dim_hidden, num_layers, dropout_probability=0.1):
        super().__init__()

        self.embd_layer = nn.Embedding(vocab_size, dim_embed)
        self.dropout = nn.Dropout(dropout_probability)
        self.rnn = nn.GRU(dim_embed, dim_hidden, num_layers=num_layers,
                          dropout=dropout_probability, batch_first=True)
        self.ffw = nn.Linear(dim_hidden, dim_hidden)

    def forward(self, x, hidden_t_1):
        embds = self.dropout(self.embd_layer(x))
        output, hidden_t = self.rnn(embds, hidden_t_1)
        out = self.ffw(hidden_t[-1])
        return out, hidden_t


class Seq2seq_no_attention(nn.Module):
    def __init__(self, vocab_size:int, dim_embed:int, dim_model:int, dim_feedforward:int, num_layers:int, dropout_probability:float):
        super(Seq2seq_no_attention, self).__init__()
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.encoder = Encoder(vocab_size, dim_embed, dim_model, dim_feedforward, num_layers, dropout_probability)
        self.decoder = Decoder(vocab_size, dim_embed, dim_model, num_layers, dropout_probability)
        self.classifier = nn.Linear(dim_model, vocab_size)
        ## weight sharing between the classifier and the encoder/decoder embeddings
        self.encoder.embd_layer.weight = self.classifier.weight
        self.decoder.embd_layer.weight = self.classifier.weight

    def forward(self, source, target, pad_tokenId):
        # target = <s> text </s>
        # teacher_force_ratio = 0.5
        B, T = target.size()
        total_logits = torch.zeros(B, T, self.vocab_size, device=source.device) # (B,T,vocab_size)

        context = self.encoder(source) # (B, dim_model)
        ## We pass the encoder hidden state to each layer of the decoder (inspired by the Attention Is All You Need paper)
        context = context.unsqueeze(0).repeat(self.num_layers,1,1) # (numlayer, B, dim_model)
        for step in range(T):
            step_token = target[:, [step]]
            out, context = self.decoder(step_token, context)
            logits = self.classifier(out).squeeze(1)
            total_logits[:, step] = logits
        loss = None
        if T > 1:
            flat_logits = total_logits[:,:-1,:].reshape(-1, total_logits.size(-1))
            flat_targets = target[:,1:].reshape(-1)
            loss = nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=pad_tokenId)
        return total_logits, loss


    @torch.no_grad
    def greedy_decode_fast(self, source:torch.Tensor, sos_tokenId: int, eos_tokenId:int, pad_tokenId, max_tries=50):
        self.eval()
        targets_hat = [sos_tokenId]
        context = self.encoder(source.unsqueeze(0))
        context = context.unsqueeze(0).repeat(self.num_layers,1,1)
        for step in range(max_tries):
            x = torch.tensor([targets_hat[step]]).unsqueeze(0).to(source.device)
            out, context = self.decoder(x, context)
            logits = self.classifier(out)
            top1 = logits.argmax(-1)
            targets_hat.append(top1.item())
            if top1 == eos_tokenId:
                return targets_hat
        return targets_hat
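A minimal end-to-end sketch for the attention-free model (editor's illustration; ids and sizes are made up, and the weight tying in __init__ requires dim_embed == dim_model, as in the shipped configs):

import torch
from Models.seq2seq_model import Seq2seq_no_attention

vocab_size, pad_id, sos_id, eos_id = 100, 0, 1, 2   # hypothetical special-token ids
model = Seq2seq_no_attention(vocab_size=vocab_size, dim_embed=32, dim_model=32,
                             dim_feedforward=64, num_layers=2, dropout_probability=0.1)

source = torch.randint(3, vocab_size, (4, 9))   # (B=4, Ts=9) source token ids
target = torch.randint(3, vocab_size, (4, 7))   # (B=4, Tt=7), normally <s> ... </s>
logits, loss = model(source, target, pad_tokenId=pad_id)
print(logits.shape, loss.item())                # torch.Size([4, 7, 100]) and a scalar loss

# Greedy decoding takes a single unbatched source sequence.
ids = model.greedy_decode_fast(source[0], sos_tokenId=sos_id, eos_tokenId=eos_id,
                               pad_tokenId=pad_id, max_tries=10)
print(ids)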