vmal committed
Commit ca9f11d · 1 Parent(s): 45f018f

decoder only transformer learning

.gitignore ADDED
@@ -0,0 +1,169 @@
1
+ # Initially taken from Github's Python gitignore file
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # tests and logs
12
+ tests/fixtures/cached_*_text.txt
13
+ logs/
14
+ lightning_logs/
15
+ lang_code_data/
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # celery beat schedule file
92
+ celerybeat-schedule
93
+
94
+ # SageMath parsed files
95
+ *.sage.py
96
+
97
+ # Environments
98
+ .env
99
+ .venv
100
+ env/
101
+ venv/
102
+ ENV/
103
+ env.bak/
104
+ venv.bak/
105
+
106
+ # Spyder project settings
107
+ .spyderproject
108
+ .spyproject
109
+
110
+ # Rope project settings
111
+ .ropeproject
112
+
113
+ # mkdocs documentation
114
+ /site
115
+
116
+ # mypy
117
+ .mypy_cache/
118
+ .dmypy.json
119
+ dmypy.json
120
+
121
+ # Pyre type checker
122
+ .pyre/
123
+
124
+ # vscode
125
+ .vs
126
+ .vscode
127
+
128
+ # Pycharm
129
+ .idea
130
+
131
+ # TF code
132
+ tensorflow_code
133
+
134
+ # Models
135
+ proc_data
136
+
137
+ # examples
138
+ runs
139
+ /runs_old
140
+ /wandb
141
+ /examples/runs
142
+ /examples/**/*.args
143
+ /examples/rag/sweep
144
+
145
+ # data
146
+ /data
147
+ serialization_dir
148
+
149
+ # emacs
150
+ *.*~
151
+ debug.env
152
+
153
+ # vim
154
+ .*.swp
155
+
156
+ #ctags
157
+ tags
158
+
159
+ # pre-commit
160
+ .pre-commit*
161
+
162
+ # .lock
163
+ *.lock
164
+
165
+ # DS_Store (MacOS)
166
+ .DS_Store
167
+
168
+ # ruff
169
+ .ruff_cache
README.md CHANGED
@@ -1,3 +1 @@
1
- ---
2
- license: unlicense
3
- ---
 
1
+ # decoder only transformers
learning.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch # create tensors and provides helper functions
2
+ import torch.nn as nn # for nn.Module(), nn.Embedding() and nn.Linear()
3
+ import torch.nn.functional as F # gives us the softmax() and argmax()
4
+ from torch.optim import Adam # Adam optimizer, stochastic gradient descent
5
+ from torch.utils.data import TensorDataset, DataLoader # for storing data loader
6
+
7
+ # first, create a dict that maps vocabulary tokens to id numbers
8
+ token_to_id = ({
9
+ 'what': 0,
10
+ 'is': 1,
11
+ 'your': 2,
12
+ 'name': 3,
13
+ 'gpt': 4,
14
+ 'my': 5,
15
+ '<EOS>': 10, # END OF SEQUENCE
16
+ '<PAD>': 11, # PADDING
17
+ })
18
+
19
+ ## create the dict that maps the ids to tokens, for interpreting the model output.
20
+ id_to_token = dict(map(reversed, token_to_id.items()))
21
+ VOCAB_SIZE = len(token_to_id)
22
+ SEQ_LEN = 6
23
+ D_MODEL = 2
24
+ # Because we use a decoder-only transformer, each input contains
25
+ # the question, followed by the <EOS> token, followed by the response 'gpt'.
26
+ # This is because all of the tokens will be used as inputs to the decoder-only
27
+ # transformer during training.
28
+ # This technique is called teacher forcing;
29
+ # it helps the neural network train faster.
30
+
31
+ inputs = torch.tensor([
32
+ [
33
+ token_to_id['what'], token_to_id['is'],
34
+ token_to_id['your'], token_to_id['name'],
35
+ token_to_id['<EOS>'], token_to_id['gpt'],
36
+ ],
37
+ [
38
+ token_to_id['gpt'], token_to_id['is'],
39
+ token_to_id['my'], token_to_id['<EOS>'],
40
+ token_to_id['name'], token_to_id['<EOS>'],
41
+ ]
42
+ ])
43
+ # each input row has SEQ_LEN = 6 tokens, so it lines up position by position with its label row below
44
+
45
+ # Because we are using a decoder-only transformer, the outputs, or
46
+ # the predictions, are the input questions (minus the first word) followed by
47
+ # <EOS> gpt <EOS>. The first <EOS> means we are done processing the input question
48
+ # and the second means we are done generating the output.
49
+ labels = torch.tensor([
50
+ [
51
+ token_to_id['is'],
52
+ token_to_id['your'],
53
+ token_to_id['name'],
54
+ token_to_id['<EOS>'],
55
+ token_to_id['gpt'],
56
+ token_to_id['<EOS>'],
57
+ ],
58
+ [
59
+ token_to_id['is'],
60
+ token_to_id['my'],
61
+ token_to_id['<EOS>'],
62
+ token_to_id['name'],
63
+ token_to_id['<EOS>'],
64
+ token_to_id['<PAD>'],
65
+ ]
66
+ ])
67
+
68
+ dataset = TensorDataset(inputs, labels)
69
+ dataloader = DataLoader(dataset=dataset)
70
+
71
+ print(f'Shape of the input: {inputs.shape}')
72
+ print(f'Shape of the labels: {labels.shape}')
73
+
74
+ x = inputs.unsqueeze(0)
75
+ y = labels.unsqueeze(0)
76
+
77
+ print(f'Batch input: {x.shape}')
78
+ print(f'Batch labels: {y.shape}')
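The comments above describe teacher forcing: the whole question-plus-answer sequence is fed to the decoder at once, and the label at each position is simply the next token of that same sequence. A minimal sketch that makes the alignment visible, using only the token_to_id, id_to_token, inputs and labels defined in learning.py above:

```python
# Print each input token next to the token the model is trained to predict at that position.
for inp, lab in zip(inputs, labels):
    for current_id, next_id in zip(inp.tolist(), lab.tolist()):
        print(f"{id_to_token[current_id]:>6} -> {id_to_token[next_id]}")
    print('---')
```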
model.py ADDED
@@ -0,0 +1,448 @@
1
+ '''
2
+ Components of a decoder-only transformer:
3
+ -> Word Embedding
4
+ -> Position Encoding
5
+ -> Masked Self-Attention
6
+ -> Residual Connections
7
+ -> A fully connected layer
8
+ -> Classification Head
9
+ '''
10
+ import torch
11
+ import torch.nn as nn
12
+ import math
13
+
14
+ class WordEmbeddings(nn.Module):
15
+
16
+ def __init__(self, d_model: int, vocab_size: int):
17
+ ## d_model: The dimension of the transformer, which is also the number of embedding values per token.
18
+ ## vocab_size: Get the size of the underlying vocabulary
19
+
20
+ super().__init__()
21
+ self.d_model = d_model
22
+ self.vocab_size = vocab_size
23
+ self.embedding = nn.Embedding(num_embeddings=vocab_size,
24
+ embedding_dim=d_model)
25
+
26
+ def forward(self, x):
27
+ # (batch, seq_len) --> (batch, seq_len, d_model)
28
+ # multiply by sqrt(d_model) to scale the embeddings
29
+ return self.embedding(x) * math.sqrt(self.d_model)
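+ ## For example, with vocab_size=12 and d_model=2, a batch of token ids with shape
+ ## (1, 6) comes back as a (1, 6, 2) tensor of embeddings, each scaled by sqrt(2) ≈ 1.414.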
30
+
31
+ class PositionEncoding(nn.Module):
32
+ '''Ref: https://github.com/StatQuest/decoder_transformer_from_scratch/blob/main/decoder_transformers_with_pytorch_and_lightning_v2.ipynb
33
+ '''
34
+
35
+ def __init__(self, d_model: int, seq_len: int, dropout: float):
36
+ ## d_model = The dimension of the transformer, which is also the number of embedding values per token.
37
+ ## In the transformer I used in the StatQuest: Transformer Neural Networks Clearly Explained!!!
38
+ ## d_model=2, so that's what we'll use as a default for now.
39
+ ## However, in "Attention Is All You Need" d_model=512
40
+ ## seq_len = maximum number of tokens we allow as input.
41
+ ## Since we are precomputing the position encoding values and storing them in a lookup table
42
+ ## we can use d_model and seq_len to determine the number of rows and columns in that
43
+ ## lookup table.
44
+ ##
45
+ ## In this simple example, we are only using short phrases, so we are using
46
+ ## seq_len=6 as the default setting.
47
+ ## However, in The Annotated Transformer, they set the default value for seq_len to 5000
48
+
49
+ ## We call the super's init because by creating our own __init__() method, we overwrite the one
50
+ ## we inherited from nn.Module. So we have to explicitly call nn.Module's __init__(), otherwise it
51
+ ## won't get initialized. NOTE: If we didn't write our own __init__(), then we would not have
52
+ ## to call super().__init__(). Alternatively, if we didn't want to access any of nn.Module's methods,
53
+ ## we wouldn't have to call it then either.
54
+ super().__init__()
55
+
56
+ self.d_model = d_model
57
+ self.seq_len = seq_len
58
+ self.dropout = nn.Dropout(dropout)
59
+
60
+ ## Now we create a lookup table, pe, of position encoding values and initialize all of them to 0.
61
+ ## To do this, we will make a matrix of 0s that has seq_len rows and d_model columns.
62
+ ## for example...
63
+ ## torch.zeros(3, 2)
64
+ ## ...returns a matrix of 0s with 3 rows and 2 columns...
65
+ ## tensor([[0., 0.],
66
+ ## [0., 0.],
67
+ ## [0., 0.]])
68
+ pe = torch.zeros(seq_len, d_model)
69
+
70
+ ## Now we create a sequence of numbers for each position that a token can have in the input (or output).
71
+ ## For example, if the input tokens were "I'm happy today!", then "I'm" would get the first
72
+ ## position, 0, "happy" would get the second position, 1, and "today!" would get the third position, 2.
73
+ ## NOTE: Since we are going to be doing math with these position indices to create the
74
+ ## positional encoding for each one, we need them to be floats rather than ints.
75
+ ##
76
+ ## NOTE: Two ways to create floats are...
77
+ ##
78
+ ## torch.arange(start=0, end=3, step=1, dtype=torch.float)
79
+ ##
80
+ ## ...and...
81
+ ##
82
+ ## torch.arange(start=0, end=3, step=1).float()
83
+ ##
84
+ ## ...but the latter is just as clear and requires less typing.
85
+ ##
86
+ ## Lastly, .unsqueeze(1) converts the single list of numbers that torch.arange creates into a matrix with
87
+ ## one row for each index, and all of the indices in a single column. So if "seq_len" = 3, then we
88
+ ## would create a matrix with 3 rows and 1 column like this...
89
+ ##
90
+ ## torch.arange(start=0, end=3, step=1, dtype=torch.float).unsqueeze(1)
91
+ ##
92
+ ## ...returns...
93
+ ##
94
+ ## tensor([[0.],
95
+ ## [1.],
96
+ ## [2.]])
97
+ position = torch.arange(start=0, end=seq_len, step=1).float().unsqueeze(1)
98
+
99
+
100
+ ## Here is where we start doing the math to determine the y-axis coordinates on the
101
+ ## sine and cosine curves.
102
+ ##
103
+ ## The positional encoding equations used in "Attention is all you need" are...
104
+ ##
105
+ ## PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
106
+ ## PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
107
+ ##
108
+ ## ...and we see, within the sin() and cos() functions, we divide "pos" by some number that depends
109
+ ## on the index (i) and total number of PE values we want per token (d_model).
110
+ ##
111
+ ## NOTE: When the index, i, is 0 then we are calculating the y-axis coordinates on the **first pair**
112
+ ## of sine and cosine curves. When i=1, then we are calculating the y-axis coordinates on the
113
+ ## **second pair** of sine and cosine curves. etc. etc.
114
+ ##
115
+ ## Now, pretty much everyone calculates the term we use to divide "pos" by first, and they do it with
116
+ ## code that looks like this...
117
+ ##
118
+ ## div_term = torch.exp(torch.arange(start=0, end=d_model, step=2).float() * -(math.log(10000.0) / d_model))
119
+ ##
120
+ ## Now, at least to me, it's not obvious that div_term = 1/(10000^(2i/d_model)) for a few reasons:
121
+ ##
122
+ ## 1) div_term wraps everything in a call to torch.exp()
123
+ ## 2) It uses log()
124
+ ## 3) The order of the terms is different
125
+ ##
126
+ ## The reason for these differences is, presumably, trying to prevent underflow (getting too close to 0).
127
+ ## So, to show that div_term = 1/(10000^(2i/d_model))...
128
+ ##
129
+ ## 1) Swap out math.log() for torch.log() (doing this requires converting 10000.0 to a tensor, which is my
130
+ ## guess for why they used math.log() instead of torch.log())...
131
+ ##
132
+ ## torch.exp(torch.arange(start=0, end=d_model, step=2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
133
+ ##
134
+ ## 2) Rearrange the terms...
135
+ ##
136
+ ## torch.exp(-1 * (torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=d_model, step=2).float() / d_model))
137
+ ##
138
+ ## 3) Pull out the -1 with exp(-1 * x) = 1/exp(x)
139
+ ##
140
+ ## 1/torch.exp(torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=d_model, step=2).float() / d_model)
141
+ ##
142
+ ## 4) Use exp(a * b) = exp(a)^b to pull out the 2i/d_model term...
143
+ ##
144
+ ## 1/torch.exp(torch.log(torch.tensor(10000.0)))^(torch.arange(start=0, end=d_model, step=2).float() / d_model)
145
+ ##
146
+ ## 5) Use exp(log(x)) = x to get the original form of the denominator...
147
+ ##
148
+ ## 1/(torch.tensor(10000.0)^(torch.arange(start=0, end=d_model, step=2).float() / d_model))
149
+ ##
150
+ ## 6) Bam.
151
+ ##
152
+ ## So, that being said, I don't think underflow is actually that big an issue. In fact, some coder at Hugging Face
153
+ ## also doesn't think so, and their code for positional encoding in DistilBERT (a streamlined version of BERT, which
154
+ ## is a transformer model)
155
+ ## calculates the values directly - using the form of the equation found in original Attention is all you need
156
+ ## manuscript. See...
157
+ ## https://github.com/huggingface/transformers/blob/455c6390938a5c737fa63e78396cedae41e4e87e/src/transformers/modeling_distilbert.py#L53
158
+ ## So I think we can simplify the code, but I'm also writing all these comments to show that it is equivalent to what
159
+ ## you'll see in the wild...
160
+ ##
161
+ ## Now let's create an index for the embedding positions to simplify the code a little more...
162
+ embedding_index = torch.arange(start=0, end=d_model, step=2).float()
163
+ ## NOTE: Setting step=2 results in the same sequence numbers that we would get if we multiplied i by 2.
164
+ ## So we can save ourselves a little math by just setting step=2.
165
+
166
+ ## And now, finally, let's create div_term...
167
+ div_term = 1/torch.tensor(10000.0)**(embedding_index / d_model)
168
+
169
+ ## Now we calculate the actual positional encoding values. Remember 'pe' was initialized as a matrix of 0s
170
+ ## with seq_len (max number of input tokens) rows and d_model (number of embedding values per token) columns.
171
+ pe[:, 0::2] = torch.sin(position * div_term) ## every other column, starting with the 1st, has sin() values
172
+ pe[:, 1::2] = torch.cos(position * div_term) ## every other column, starting with the 2nd, has cos() values
173
+ ## NOTE: If the notation for indexing 'pe[]' looks cryptic to you, read on...
174
+ ##
175
+ ## First, let's look at the general indexing notation:
176
+ ##
177
+ ## For each row or column in matrix we can select elements in that
178
+ ## row or column with the following indexes...
179
+ ##
180
+ ## i:j:k = select elements between i and j with stepsize = k.
181
+ ##
182
+ ## ...where...
183
+ ##
184
+ ## i defaults to 0
185
+ ## j defaults to the number of elements in the row, column or whatever.
186
+ ## k defaults to 1
187
+ ##
188
+ ## Now that we have looked at the general notation, let's look at specific
189
+ ## examples so that we can understand it.
190
+ ##
191
+ ## We'll start with: pe[:, 0::2]
192
+ ##
193
+ ## The stuff that comes before the comma (in this case ':') refers to the rows we want to select.
194
+ ## The ':' before the comma means "select all rows" because we are not providing specific
195
+ ## values for i, j and k and, instead, just using the default values.
196
+ ##
197
+ ## The stuff after the comma refers to the columns we want to select.
198
+ ## In this case, we have '0::2', and that means we start with
199
+ ## the first column (column = 0) and go to the end (using the default value for j)
200
+ ## and we set the stepsize to 2, which means we skip every other column.
201
+ ##
202
+ ## Now to understand pe[:, 1::2]
203
+ ##
204
+ ## Again, the stuff before the comma refers to the rows, and, just like before
205
+ ## we use default values for i,j and k, so we select all rows.
206
+ ##
207
+ ## The stuff that comes after the comma refers to the columns.
208
+ ## In this case, we start with the 2nd column (column = 1), and go to the end
209
+ ## (using the default value for 'j') and we set the stepsize to 2, which
210
+ ## means we skip every other column.
211
+ ##
212
+ ## NOTE: using this ':' based notation is called "indexing" and also called "slicing"
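+ ## For example, with d_model=2 there is a single sine/cosine pair (embedding_index = 0,
+ ## div_term = 1), so the token at position 1 gets the encoding [sin(1), cos(1)] ≈ [0.841, 0.540]
+ ## and the token at position 2 gets [sin(2), cos(2)] ≈ [0.909, -0.416].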
213
+ ## Add a batch dimension to the positional encoding
214
+ pe = pe.unsqueeze(0) # (1, seq_len, d_model)
215
+ ## Now we "register" 'pe'.
216
+ self.register_buffer('pe', pe) ## "register_buffer()" ensures that
217
+ ## 'pe' will be moved to wherever the model gets
218
+ ## moved to. So if the model is moved to a GPU, then,
219
+ ## even though we don't need to optimize 'pe', it will
220
+ ## also be moved to that GPU. This, in turn, means
221
+ ## that accessing 'pe' will be relatively fast compared
222
+ ## to having a GPU have to get the data from a CPU.
223
+
224
+
225
+ def forward(self, word_embeddings):
226
+ ## Because this class, PositionEncoding, inherits from nn.Module, the forward() method
227
+ ## is called by default when we use a PositionEncoding() object.
228
+ ## In other words, after we create a PositionEncoding() object, pe = PositionEncoding(),
229
+ ## then pe(word_embeddings) will call forward() and so this is where
230
+ ## we will add the position encoding values to the word embedding values
231
+ ## (batch, seq_len, d_model)
232
+ x = word_embeddings + (self.pe[:,:word_embeddings.shape[1], :]).requires_grad_(False)
233
+
234
+ return self.dropout(x)
235
+
236
+ class LayerNormalization(nn.Module):
237
+
238
+ def __init__(self, features: int, eps:float=10**-6) -> None:
239
+ super().__init__()
240
+ self.eps = eps
241
+ self.alpha = nn.Parameter(torch.ones(features)) # alpha is a learnable parameter
242
+ self.bias = nn.Parameter(torch.zeros(features)) # bias is a learnable parameter
243
+
244
+ def forward(self, x):
245
+ # x: (batch, seq_len, hidden_size)
246
+ # Keep the dimension for broadcasting
247
+ mean = x.mean(dim = -1, keepdim = True) # (batch, seq_len, 1)
248
+ # Keep the dimension for broadcasting
249
+ std = x.std(dim = -1, keepdim = True) # (batch, seq_len, 1)
250
+ # eps is to prevent dividing by zero or when std is very small
251
+ return self.alpha * (x - mean) / (std + self.eps) + self.bias
252
+
253
+ class MultiHeadAttentionBlock(nn.Module):
254
+
255
+ def __init__(self, d_model: int, h: int, dropout: float) -> None:
256
+ super().__init__()
257
+ # Make sure d_model is divisible by h
258
+ assert d_model % h == 0, "d_model is not divisible by h"
259
+
260
+ self.d_model = d_model # Embedding vector size
261
+ self.h = h # Number of heads
262
+
263
+ self.d_k = d_model // h # Dimension of vector seen by each head
264
+ self.w_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False) # Wq
265
+ self.w_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False) # Wk
266
+ self.w_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False) # Wv
267
+ self.w_o = nn.Linear(in_features=d_model, out_features=d_model, bias=False) # Wo
268
+ self.dropout = nn.Dropout(dropout)
269
+
270
+ @staticmethod
271
+ def attention(query, key, value, mask, dropout: nn.Dropout):
272
+ d_k = query.shape[-1]
273
+ ## (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
274
+ ## Compute attention scores; the equation is (q @ k^T)/sqrt(d_k)
275
+ attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
276
+ if mask is not None:
277
+ ## Here we are masking out things we don't want to pay attention to,
278
+ ## like tokens that come after the current token.
279
+ ## We can also use masking to block out the <PAD> token,
280
+ ## which is used when we have a batch of input sequences
281
+ ## and they are not all the exact same length. Because the batch is passed
282
+ ## in as a matrix, each input sequence has to have the same length, so we
283
+ ## add <PAD> to the shorter sequences so that they are all as long as the
284
+ ## longest sequence.
285
+ ##
286
+ ## We replace <PAD>, or tokens that come after the current token
287
+ ## with a very large negative number so that the SoftMax() function
288
+ ## will give all masked elements an output value (or "probability") of 0.
289
+ ## Write a very low value (indicating -inf) to the positions where mask == 0
290
+ attention_scores.masked_fill_(mask == 0, -1e9)
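+ ## For example, the causal mask for three tokens (see causal_mask() in shakespeare_config.py) is
+ ## [[ True, False, False],
+ ##  [ True,  True, False],
+ ##  [ True,  True,  True]]
+ ## so every token can attend to itself and to earlier tokens, but not to later ones.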
291
+
292
+ ## Apply softmax to determine what percent of each token's value to
293
+ ## use in the final attention values.
294
+ ## (batch, h, seq_len, seq_len)
295
+ attention_scores = attention_scores.softmax(dim=-1)
296
+
297
+ if dropout is not None:
298
+ attention_scores = dropout(attention_scores)
299
+
300
+ ## (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
301
+ ## return attention scores which can be used for visualization
302
+ return (attention_scores @ value), attention_scores
303
+
304
+ def forward(self, q, k, v, mask):
305
+ query = self.w_q(q) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
306
+ key = self.w_k(k) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
307
+ value = self.w_v(v) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
308
+
309
+ # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
310
+ query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
311
+ key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
312
+ value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
313
+
314
+ # Calculate attention
315
+ x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
316
+
317
+ # Combine all the heads together
318
+ # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
319
+ x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
320
+
321
+ # Multiply by Wo
322
+ # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
323
+ return self.w_o(x)
324
+
325
+ class ResidualConnection(nn.Module):
326
+
327
+ def __init__(self, features: int, dropout: float) -> None:
328
+ super().__init__()
329
+ self.dropout = nn.Dropout(dropout)
330
+ self.norm = LayerNormalization(features)
331
+
332
+ def forward(self, x, sublayer):
333
+ return x + self.dropout(sublayer(self.norm(x)))
334
+
335
+ class FeedForwardBlock(nn.Module):
336
+
337
+ def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
338
+ super().__init__()
339
+ self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
340
+ self.dropout = nn.Dropout(dropout)
341
+ self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2
342
+
343
+ def forward(self, x):
344
+ # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
345
+ return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
346
+
347
+ class DecoderBlock(nn.Module):
348
+
349
+ def __init__(self,
350
+ features: int,
351
+ self_attention_block: MultiHeadAttentionBlock,
352
+ feed_forward_block: FeedForwardBlock,
353
+ dropout: float) -> None:
354
+ super().__init__()
355
+ self.self_attention_block = self_attention_block
356
+ self.feed_forward_block = feed_forward_block
357
+ self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])
358
+
359
+ def forward(self, x, mask):
360
+ x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, mask))
361
+ x = self.residual_connections[1](x, self.feed_forward_block)
362
+ return x
363
+
364
+ class Decoder(nn.Module):
365
+
366
+ def __init__(self, features: int, layers: nn.ModuleList) -> None:
367
+ super().__init__()
368
+ self.layers = layers
369
+ self.norm = LayerNormalization(features)
370
+
371
+ def forward(self, x, mask):
372
+ for layer in self.layers:
373
+ x = layer(x, mask)
374
+ return self.norm(x)
375
+
376
+ class ProjectionLayer(nn.Module):
377
+
378
+ def __init__(self, d_model, vocab_size):
379
+ super().__init__()
380
+ self.proj = nn.Linear(d_model, vocab_size)
381
+
382
+ def forward(self, x):
383
+ # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
384
+ return self.proj(x)
385
+
386
+ class DecoderOnlyTransformer(nn.Module):
387
+ def __init__(self,
388
+ word_embedding: WordEmbeddings,
389
+ position_embedding: PositionEncoding,
390
+ decoder: Decoder,
391
+ projection_layer: ProjectionLayer):
392
+ super().__init__()
393
+ self.word_embedding = word_embedding
394
+ self.position_embedding = position_embedding
395
+ self.decoder = decoder
396
+ self.projection_layer = projection_layer
397
+
398
+ def decode(self, x: torch.Tensor, mask: torch.Tensor):
399
+ # x shape (batch, seq_len)
400
+ x = self.word_embedding(x)
401
+ x = self.position_embedding(x)
402
+ # x shape (batch, seq_len, d_model)
403
+ return self.decoder(x, mask)
404
+
405
+ def project(self, x):
406
+ # (batch, seq_len, vocab_size)
407
+ return self.projection_layer(x)
408
+
409
+ def build_transformer(vocab_size: int,
410
+ seq_len: int,
411
+ d_model: int=512,
412
+ N: int=6,
413
+ h: int=8,
414
+ dropout: float=0.1,
415
+ d_ff: int=2048) -> DecoderOnlyTransformer:
416
+ # Create the embedding layers
417
+ word_embedding = WordEmbeddings(d_model, vocab_size)
418
+
419
+ # Create the positional encoding layers
420
+ position_encoding = PositionEncoding(d_model, seq_len, dropout)
421
+
422
+ # Create the decoder blocks
423
+ decoder_blocks = []
424
+ for _ in range(N):
425
+ multi_head_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
426
+ feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
427
+ decoder_block = DecoderBlock(d_model, multi_head_self_attention_block, feed_forward_block, dropout)
428
+ decoder_blocks.append(decoder_block)
429
+
430
+ # Create the decoder
431
+ decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
432
+
433
+ # Create the projection layer
434
+ projection_layer = ProjectionLayer(d_model, vocab_size)
435
+
436
+ # Create the transformer
437
+ transformer = DecoderOnlyTransformer(word_embedding,
438
+ position_encoding,
439
+ decoder,
440
+ projection_layer)
441
+
442
+ # Initialize the parameters
443
+ for p in transformer.parameters():
444
+ if p.dim() > 1:
445
+ nn.init.xavier_uniform_(p)
446
+
447
+ return transformer
448
+
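build_transformer() at the bottom of model.py wires all of these blocks together. A quick shape-level smoke test (a sketch only; the sizes below are illustrative and much smaller than the Shakespeare configuration):

```python
import torch
from model import build_transformer

vocab_size, seq_len = 12, 6
model = build_transformer(vocab_size=vocab_size, seq_len=seq_len,
                          d_model=32, N=2, h=4, d_ff=64)
model.eval()

x = torch.randint(0, vocab_size, (1, seq_len))                       # (batch, seq_len)
mask = torch.triu(torch.ones(1, seq_len, seq_len), diagonal=1) == 0  # causal mask
logits = model.project(model.decode(x, mask))
print(logits.shape)  # torch.Size([1, 6, 12]) -> (batch, seq_len, vocab_size)
```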
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ transformers>=4.34
2
+ numpy
3
+ tiktoken
4
+ torch
5
+ torchdata
6
+ accelerate
7
+ evaluate
8
+ rouge_score
9
+ loralib
10
+ peft
11
+ datasets
12
+ torchmetrics
shakespeare/data/input.txt ADDED
The diff for this file is too large to render. See raw diff
 
shakespeare/data/test.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff705d970c50caa6f8b25a9f35b2c168a8dae8d81ef0963347e63327402c9d60
3
+ size 69612
shakespeare/data/train.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c3897b6613e06aa21e514d86d67214e3570fcfe20ccb8b983f4cb676e1d6c56
3
+ size 584160
shakespeare/weights/tmodel_26000.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9c59034983f6e1048d7d2d3699da8a6b4c25ead392c7c43c7eac6bdfe5a9b74
3
+ size 846393125
shakespeare_config.py ADDED
@@ -0,0 +1,57 @@
1
+ import os
2
+ import torch
3
+
4
+ from pathlib import Path
5
+ from transformers import GPT2Tokenizer
6
+
7
+
8
+ def get_config():
9
+ return {
10
+ "batch_size": 8,
11
+ "num_epochs": 600000,
12
+ "lr": 10**-4,
13
+ "seq_len": 350,
14
+ "d_model": 512,
15
+ "vocab_size": 50304,
16
+ "datasource": 'shakespeare',
17
+ "model_folder": "weights",
18
+ "model_basename": "tmodel_",
19
+ "preload": "latest",
20
+ "tokenizer_file": "tokenizer.json",
21
+ "experiment_name": "runs/tmodel"
22
+ }
23
+
24
+ current_directory = os.path.dirname(os.path.abspath(__file__))
25
+
26
+ def get_weights_file_path(config, epoch: str):
27
+ model_folder = f"{current_directory}/{config['datasource']}/{config['model_folder']}"
28
+ # Create the folder and subfolders if they don't exist
29
+ Path(model_folder).mkdir(parents=True, exist_ok=True)
30
+ model_filename = f"{config['model_basename']}{epoch}.pt"
31
+ return model_folder + '/' + model_filename
32
+
33
+ def get_data_folder_path(config):
34
+ model_folder = f"{current_directory}/{config['datasource']}/data"
35
+ Path(model_folder).mkdir(parents=True, exist_ok=True)
36
+ return model_folder
37
+
38
+ # Find the latest weights file in the weights folder
39
+ def latest_weights_file_path(config):
40
+ model_folder = f"{current_directory}/{config['datasource']}/{config['model_folder']}"
41
+ model_filename = f"{config['model_basename']}*"
42
+ weights_files = list(Path(model_folder).glob(model_filename))
43
+ if len(weights_files) == 0:
44
+ return None
45
+ weights_files.sort(key=lambda p: int(p.stem.split('_')[-1])) # sort numerically so e.g. epoch 26000 outranks epoch 9000
46
+ return str(weights_files[-1])
47
+
48
+ def get_gpt2_tokenizer(config):
49
+ tokenizer:GPT2Tokenizer = GPT2Tokenizer.from_pretrained(
50
+ pretrained_model_name_or_path="openai-community/gpt2",
51
+ model_max_length=config['seq_len'],
52
+ pad_token='[PAD]')
53
+ return tokenizer
54
+
55
+ def causal_mask(size):
56
+ mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
57
+ return mask==0
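causal_mask() is what keeps the decoder autoregressive: position i may only attend to positions 0 through i. For a tiny size, the boolean tensor it returns looks like this:

```python
from shakespeare_config import causal_mask

print(causal_mask(3))
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])
```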
shakespeare_data.py ADDED
@@ -0,0 +1,49 @@
1
+ '''Ref: https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare/prepare.py
2
+ '''
3
+ import os
4
+ import requests
5
+ import numpy as np
6
+ from transformers import GPT2Tokenizer
7
+ from shakespeare_config import get_data_folder_path, get_config, get_gpt2_tokenizer
8
+ from pathlib import Path
9
+
10
+ if __name__=='__main__':
11
+ config=get_config()
12
+ data_folder_path = get_data_folder_path(config=config)
13
+ # download the tiny shakespeare dataset
14
+ input_file_path = os.path.join(data_folder_path, 'input.txt')
15
+ tokenizer:GPT2Tokenizer = get_gpt2_tokenizer(config=config)
16
+
17
+ print(tokenizer.model_max_length)
18
+
19
+ if not Path(input_file_path).exists():
20
+ data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
21
+ with open(input_file_path, 'w', encoding='utf-8') as f:
22
+ f.write(requests.get(data_url).text)
23
+
24
+ data=''
25
+ with open(input_file_path, 'r', encoding='utf-8') as f:
26
+ for line in f.readlines():
27
+ if len(line.rstrip())>0:
28
+ data += ' ' + line
29
+
30
+ print(f'length of dataset in characters: {len(data):,}')
31
+ n = len(data)
32
+ train_split = int(n*0.9)
33
+ train_data = data[:train_split]
34
+ test_data = data[train_split:]
35
+
36
+ train_ids = tokenizer.encode(train_data)
37
+ test_ids = tokenizer.encode(test_data)
38
+ print(f"train has {len(train_ids):,} tokens")
39
+ print(f"test has {len(test_ids):,} tokens")
40
+
41
+ # export to bin files
42
+ train_ids = np.array(train_ids, dtype=np.uint16)
43
+ test_ids = np.array(test_ids, dtype=np.uint16)
44
+ train_ids.tofile(os.path.join(data_folder_path, 'train.bin'))
45
+ test_ids.tofile(os.path.join(data_folder_path, 'test.bin'))
46
+ # train has 292,080 tokens
47
+ # test has 34,806 tokens
48
+ print(tokenizer.convert_ids_to_tokens(tokenizer.eos_token_id))
49
+ print(tokenizer.convert_ids_to_tokens(tokenizer.pad_token_id))
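The exported .bin files are flat arrays of uint16 GPT-2 token ids, which is why train.py can read them back with np.memmap. A small read-back sketch, assuming it is run from the repository root after shakespeare_data.py has been executed:

```python
import numpy as np
from shakespeare_config import get_config, get_gpt2_tokenizer

tokenizer = get_gpt2_tokenizer(config=get_config())
data = np.memmap('shakespeare/data/train.bin', dtype=np.uint16, mode='r')
print(f"{len(data):,} tokens")               # ~292,080 for the train split
print(tokenizer.decode(data[:20].tolist()))  # the opening words of the corpus
```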
shakespeare_inference.py ADDED
@@ -0,0 +1,115 @@
1
+ from shakespeare_config import (get_config,
2
+ latest_weights_file_path,
3
+ get_gpt2_tokenizer,
4
+ causal_mask,
5
+ current_directory)
6
+ import torch
7
+ import warnings
8
+ import heapq
9
+ from train import build_transformer
10
+
11
+ def predict_with_greedy_search(start_str:str)-> None:
12
+ config:dict=get_config()
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ print(f"Using device: {device}")
15
+
16
+ tokenizer = get_gpt2_tokenizer(config=config)
17
+ model = build_transformer(vocab_size=config['vocab_size'],
18
+ seq_len=config['seq_len'],
19
+ d_model=config['d_model']).to(device)
20
+ # load the pretrained weights
21
+ model_filename = latest_weights_file_path(config)
22
+ state = torch.load(model_filename, map_location=device)
23
+ model.load_state_dict(state['model_state_dict'])
24
+ model.eval()
25
+
26
+ output = start_str
27
+ with torch.no_grad():
28
+ start_tokens = tokenizer.encode(start_str)
29
+ print(start_tokens)
30
+ input = torch.tensor(data=start_tokens, dtype=torch.int64).unsqueeze(dim=0).to(device)
31
+ # print(input)
32
+ while input.size(1) <= config['seq_len']:
33
+ # use a mask, otherwise the model may generate repetitive words in the prediction
34
+ mask = causal_mask(input.size(1)).to(device)
35
+ out = model.decode(input,mask)
36
+ prob = model.project(out[:, -1])
37
+ _, next_word = torch.max(prob, dim=1)
38
+ input = torch.cat(
39
+ [
40
+ input,
41
+ torch.empty(1,1).type_as(input).fill_(next_word.item()).to(device)
42
+ ],
43
+ dim=1
44
+ )
45
+ output += tokenizer.decode(next_word.item())
46
+
47
+ print(f'Model output: {output}')
48
+
49
+
50
+ def predict_with_beam_search(start_str: str,
51
+ beam_width: int = 3) -> None:
52
+ config: dict = get_config()
53
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+ print(f"Using device: {device}")
55
+
56
+ tokenizer = get_gpt2_tokenizer(config=config)
57
+ model = build_transformer(vocab_size=config['vocab_size'],
58
+ seq_len=config['seq_len'],
59
+ d_model=config['d_model']).to(device)
60
+
61
+ # Load the pretrained weights
62
+ model_filename = latest_weights_file_path(config)
63
+ state = torch.load(model_filename, map_location=device)
64
+ model.load_state_dict(state['model_state_dict'])
65
+ model.eval()
66
+
67
+ # Initial input
68
+ start_tokens = tokenizer.encode(start_str)
69
+ input = torch.tensor(data=start_tokens, dtype=torch.int64).unsqueeze(dim=0).to(device) # (1, seq_len)
70
+
71
+ # Beam search variables
72
+ beams = [(0, input, [])] # Each beam is a tuple of (score, sequence, tokens_generated)
73
+
74
+ for _ in range(config['seq_len']):
75
+ all_candidates = []
76
+
77
+ # Process each beam
78
+ for score, seq, tokens in beams:
79
+ # use a mask, otherwise the model may generate repetitive words in the prediction
80
+ mask = causal_mask(seq.size(1)).to(device)
81
+ out = model.decode(seq, mask)
82
+ prob = torch.softmax(model.project(out[:, -1]), dim=-1) # convert logits to probabilities before taking the log
83
+
84
+ # Get the top k predictions
85
+ top_k_probabilities, top_k_indices = torch.topk(prob, beam_width, dim=1)
86
+
87
+ # Generate new beams for each of the top k tokens
88
+ for i in range(beam_width):
89
+ new_token = top_k_indices[0, i].item()
90
+ new_score = score - torch.log(top_k_probabilities[0, i]).item() # We negate because we want to maximize
91
+ new_seq = torch.cat([seq, torch.tensor([[new_token]], device=device)], dim=1)
92
+ new_tokens = tokens + [new_token]
93
+ all_candidates.append((new_score, new_seq, new_tokens))
94
+
95
+ # Sort all candidates based on their score and keep the top `beam_width` beams
96
+ beams = heapq.nsmallest(beam_width, all_candidates, key=lambda x: x[0])
97
+
98
+ # Optionally, stop early once every beam has reached the maximum sequence length
99
+ if all(beam[1].shape[1] >= config['seq_len'] for beam in beams):
100
+ break
101
+
102
+ # Retrieve the best beam (the one with the lowest cumulative negative log-probability)
103
+ best_beam = beams[0]
104
+ best_tokens = best_beam[2]
105
+
106
+ # Decode the final sequence
107
+ output = tokenizer.decode(best_tokens, skip_special_tokens=True)
108
+ print(f'Model output: {output}')
109
+
110
+ if __name__ == '__main__':
111
+ warnings.filterwarnings("ignore")
112
+ start_str = 'Now sadder, that you come so'
113
+ predict_with_greedy_search(start_str=start_str)
114
+ print('--'*100)
115
+ predict_with_beam_search(start_str=start_str)
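predict_with_beam_search() scores each candidate sequence by the sum of the negative log-probabilities of its tokens and keeps the beam_width lowest-scoring beams. A toy illustration of that scoring rule with made-up per-token probabilities:

```python
import math

candidates = {
    'the king is dead': [0.4, 0.5, 0.9, 0.3],  # made-up per-token probabilities
    'the king is here': [0.4, 0.5, 0.9, 0.2],
}
scores = {text: sum(-math.log(p) for p in probs) for text, probs in candidates.items()}
print(min(scores, key=scores.get))  # 'the king is dead' wins: lower total -log(p)
```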
train.py ADDED
@@ -0,0 +1,234 @@
1
+ import torchmetrics.classification
2
+ from torchmetrics.text import ROUGEScore
3
+ from model import build_transformer
4
+ from shakespeare_config import (get_config,
5
+ get_data_folder_path,
6
+ get_weights_file_path,
7
+ latest_weights_file_path,
8
+ current_directory,
9
+ causal_mask,
10
+ get_gpt2_tokenizer)
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.utils.tensorboard import SummaryWriter
15
+ import torchmetrics
16
+ import numpy as np
17
+ import warnings
18
+ import os
19
+ from pathlib import Path
20
+
21
+ def get_model(config):
22
+ model = build_transformer(vocab_size=config['vocab_size'],
23
+ seq_len=config['seq_len'],
24
+ d_model=config['d_model'])
25
+ return model
26
+
27
+ def get_batch(split, data_dir, block_size, batch_size, device='gpu', device_type='cuda'):
28
+ # We recreate np.memmap every batch to avoid a memory leak, as per
29
+ # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
30
+ if split == 'train':
31
+ data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
32
+ else:
33
+ data = np.memmap(os.path.join(data_dir, 'test.bin'), dtype=np.uint16, mode='r')
34
+ ix = torch.randint(len(data) - block_size, (batch_size,))
35
+ x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
36
+ y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
37
+ # if device_type == 'cuda':
38
+ # # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
39
+ # x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
40
+ # else:
41
+ # x, y = x.to(device), y.to(device)
42
+ return x, y
43
+
44
+ def greedy_decode(model,
45
+ input,
46
+ mask,
47
+ tokenizer,
48
+ max_len,
49
+ device):
50
+ while True:
51
+ if input.size(1) == max_len:
52
+ break
53
+
54
+ out = model.decode(input, mask)
55
+ prob = model.project(out[:, -1])
56
+ _, next_word = torch.max(prob, dim=1)
57
+ input = torch.cat(
58
+ [input, torch.empty(1,1).type_as(input).fill_(next_word.item()).to(device)],
59
+ dim=1
60
+ )
61
+ if next_word == tokenizer.eos_token_id:
62
+ break
63
+ return input.squeeze(0)
64
+
65
+ def run_validation(model,
66
+ x,
67
+ y,
68
+ tokenizer,
69
+ max_len,
70
+ device,
71
+ print_msg,
72
+ global_step,
73
+ writer,
74
+ rouge:ROUGEScore):
75
+
76
+ model.eval()
77
+ source_texts = []
78
+ expected = []
79
+ predicted = []
80
+
81
+ with torch.no_grad():
82
+ decoder_input = x.to(device) # (b, seq)
83
+ mask = causal_mask(x.size(1)).to(device) # (1, seq_len, seq_len)
84
+
85
+ # check that batch size is 1
86
+ assert decoder_input.size(0)==1, "batch size must be 1 for validation"
87
+
88
+ model_out = greedy_decode(model,
89
+ decoder_input,
90
+ mask,
91
+ tokenizer,
92
+ max_len,
93
+ device)
94
+
95
+ source_text = tokenizer.decode(x[0])
96
+ target_text = tokenizer.decode(y[0])
97
+ model_out_text = tokenizer.decode(model_out.detach().cpu().numpy())
98
+
99
+ source_texts.append(source_text)
100
+ expected.append(target_text)
101
+ predicted.append(model_out_text)
102
+
103
+ # Print the source, target and model output
104
+ print_msg('-'*100)
105
+ print_msg(f"{f'SOURCE: ':>12}{source_text}")
106
+ print_msg(f"{f'TARGET: ':>12}{target_text}")
107
+ print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")
108
+
109
+ rouge_score = rouge(predicted, expected)
110
+ print_msg(f"{f'ROUGE-1 Score: ':>12}{rouge_score['rouge1_fmeasure'].item()}")
111
+ print_msg(f"{f'ROUGE-2 Score: ':>12}{rouge_score['rouge2_fmeasure'].item()}")
112
+ print_msg(f"{f'ROUGE-L Score: ':>12}{rouge_score['rougeL_fmeasure'].item()}")
113
+ print_msg('-'*100)
114
+
115
+ if writer:
116
+ writer.add_scalar('validation ROUGE/ROUGE-1', rouge_score["rouge1_fmeasure"].item(), global_step)
117
+ writer.add_scalar('validation ROUGE/ROUGE-2', rouge_score["rouge2_fmeasure"].item(), global_step)
118
+ writer.add_scalar('validation ROUGE/ROUGE-L', rouge_score["rougeL_fmeasure"].item(), global_step)
119
+ writer.add_scalar('validation ROUGE/ROUGE-Lsum', rouge_score["rougeLsum_fmeasure"].item(), global_step)
120
+ writer.flush()
121
+
122
+ def train_model(config):
123
+ # define the device
124
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
125
+ print("Using device:", device)
126
+
127
+ if (device == 'cuda'):
128
+ print(f"Device name: {torch.cuda.get_device_name(device)}")
129
+ print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")
130
+ elif (device == 'mps'):
131
+ print("Device name: <mps>")
132
+ else:
133
+ print("It's cpu")
134
+
135
+ device = torch.device(device)
136
+
137
+ # make sure the weights folder exists
138
+ Path(f"{current_directory}/{config['datasource']}/{config['model_folder']}").mkdir(parents=True, exist_ok=True)
139
+
140
+ tokenizer = get_gpt2_tokenizer(config=config)
141
+ model = get_model(config).to(device)
142
+ # tensorboard
143
+ writer = SummaryWriter(f"{current_directory}/{config['experiment_name']}")
144
+
145
+ optimizer = torch.optim.Adam(model.parameters(),
146
+ lr=config['lr'],
147
+ eps=1e-9)
148
+ rouge:ROUGEScore = ROUGEScore()
149
+
150
+ # if the user specified a model to preload before training, load it
151
+ initial_epoch = 0
152
+ global_step = 0
153
+ preload = config['preload']
154
+ model_filename = (latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None)
155
+ if model_filename:
156
+ print(f'Preloading model {model_filename}')
157
+ state = torch.load(model_filename, map_location=device)
158
+ model.load_state_dict(state['model_state_dict'])
159
+ initial_epoch = state['epoch'] + 1
160
+ optimizer.load_state_dict(state['optimizer_state_dict'])
161
+ global_step = state['global_step']
162
+ else:
163
+ print('No model to preload, starting from scratch')
164
+
165
+ loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.convert_tokens_to_ids('[PAD]'), label_smoothing=0.1).to(device)
166
+ for epoch in range(initial_epoch, config['num_epochs']):
167
+ torch.cuda.empty_cache()
168
+ model.train()
169
+ #batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
170
+ X, y = get_batch(split='train',
171
+ data_dir=get_data_folder_path(config=config),
172
+ block_size=config['seq_len'],
173
+ batch_size=config['batch_size'])
174
+ print(f'length of the batch: {len(X)}, shape: {X.shape}')
175
+
176
+ decoder_input = X.to(device) # (b, seq_len)
177
+ decoder_mask = causal_mask(config['seq_len']).to(device) # (1, seq_len, seq_len)
178
+
179
+ # run the tensors through the decoder and the projection layer
180
+ decoder_output = model.decode(decoder_input, decoder_mask) # (b, seq, d_model)
181
+ proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)
182
+
183
+ # compare the output with the label
184
+ label = y.to(device) #(b, seq_len)
185
+
186
+ # compute the loss using a simple cross entropy
187
+ loss = loss_fn(proj_output.view(-1, config['vocab_size']),
188
+ label.view(-1))
189
+ #batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
190
+ print(f"loss: {loss.item():6.3f}")
191
+
192
+ # log the loss
193
+ writer.add_scalar('train loss', loss.item(), global_step)
194
+ writer.flush()
195
+
196
+ # backpropagate the loss
197
+ loss.backward()
198
+
199
+ # update the weights
200
+ optimizer.step()
201
+ optimizer.zero_grad(set_to_none=True)
202
+
203
+ global_step += 1
204
+
205
+ # run validation at the end of every epoch
206
+ X_val, y_val = get_batch(split='val',
207
+ data_dir=get_data_folder_path(config=config),
208
+ block_size=config['seq_len'],
209
+ batch_size=1)
210
+ run_validation(model,
211
+ X_val,
212
+ y_val,
213
+ tokenizer,
214
+ config['seq_len'],
215
+ device,
216
+ lambda msg: print(msg),
217
+ global_step,
218
+ writer,
219
+ rouge)
220
+
221
+ if epoch%1000==0 or epoch >= (config['num_epochs']-1):
222
+ # save a checkpoint every 1000 epochs (and at the final epoch)
223
+ model_filename = get_weights_file_path(config, f"{epoch:02d}")
224
+ torch.save({
225
+ 'epoch': epoch,
226
+ 'model_state_dict': model.state_dict(),
227
+ 'optimizer_state_dict': optimizer.state_dict(),
228
+ 'global_step': global_step
229
+ }, model_filename)
230
+
231
+ if __name__ == '__main__':
232
+ warnings.filterwarnings("ignore")
233
+ config = get_config()
234
+ train_model(config)
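get_batch() samples random windows from the memmapped token stream, with y offset one token ahead of x, so the cross-entropy loss above is exactly next-token prediction. A quick sanity check of that alignment (a sketch, assuming the .bin files from shakespeare_data.py exist):

```python
from shakespeare_config import get_config, get_data_folder_path
from train import get_batch

config = get_config()
x, y = get_batch('train', get_data_folder_path(config), block_size=8, batch_size=2)
print(x.shape, y.shape)               # torch.Size([2, 8]) torch.Size([2, 8])
assert (x[0, 1:] == y[0, :-1]).all()  # y is x shifted left by one token
```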