Shilpaj commited on
Commit
5bdbe00
·
1 Parent(s): 4f130bd

Feat: Project code

Browse files
Files changed (6) hide show
  1. .gitignore +2 -1
  2. README.md +118 -0
  3. S12Trained.ipynb +1091 -0
  4. assets/LLMfromScratch2.png +0 -0
  5. input.txt +0 -0
  6. requirements.txt +5 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  __pycache__/
2
- references/
 
 
1
  __pycache__/
2
+ references/
3
+ nano_gpt_model.pt
README.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pre-Training
2
+
3
+ This section focuses on Embeddings and Pre-training.
4
+
5
+ ![LLM Training Steps](./assets/LLMfromScratch2.png)
6
+
7
+ In this project, a GPT (decoder-only) model is trained on Shakespeare data. The model architecture follows the original GPT design with multi-head self-attention and feed-forward layers. Key specifications include:
8
+
9
+ - 8 transformer layers
10
+ - 8 attention heads
11
+ - 384 embedding dimensions
12
+ - 512 context window size
13
+ - ~50k vocabulary size
14
+
15
+ The model is trained using cross-entropy loss and AdamW optimizer with weight decay. Training is done on Shakespeare's works to learn the language patterns and writing style. The trained model can generate Shakespeare-style text given a prompt.
16
+
17
+
18
+
19
+ ### Project Structure
20
+
21
+ ```
22
+ .
23
+ ├── assets # Images for README
24
+ ├── nano_gpt_model.pt # Trained model
25
+ ├── S12Trained.ipynb # Notebook for training
26
+ ├── input.txt # Shakespeare data
27
+ ├── README.md # This file
28
+ └── requirements.txt # Dependencies
29
+ ```
30
+
31
+ ### Install Dependencies
32
+
33
+ ```
34
+ pip install -r requirements.txt
35
+ ```
36
+
37
+ ### Run the Notebook
38
+
39
+ ```
40
+ jupyter notebook S12Trained.ipynb
41
+ ```
42
+
43
+ ### Training Logs
44
+
45
+ Training logs for few steps are shown below:
46
+
47
+ ```bash
48
+ GPU Memory: 0.68GB / 1.77GB
49
+ step 10,000 | loss: 0.5863 | lr: 6.00e-05 | dt: 684.74ms | tok/sec: 5981.86 | norm: 3.94
50
+
51
+ GPU Memory: 0.68GB / 1.77GB
52
+ step 10,100 | loss: 0.5372 | lr: 6.00e-05 | dt: 687.72ms | tok/sec: 5955.88 | norm: 3.74
53
+
54
+ GPU Memory: 0.67GB / 1.77GB
55
+ step 10,200 | loss: 0.6054 | lr: 6.00e-05 | dt: 685.72ms | tok/sec: 5973.31 | norm: 5.71
56
+
57
+ GPU Memory: 0.68GB / 1.77GB
58
+ step 10,300 | loss: 0.5850 | lr: 6.00e-05 | dt: 686.01ms | tok/sec: 5970.77 | norm: 4.36
59
+
60
+ GPU Memory: 0.68GB / 1.77GB
61
+ step 10,400 | loss: 0.3319 | lr: 6.00e-05 | dt: 684.77ms | tok/sec: 5981.53 | norm: 4.68
62
+
63
+ GPU Memory: 0.68GB / 1.77GB
64
+ step 10,500 | loss: 0.4140 | lr: 6.00e-05 | dt: 684.41ms | tok/sec: 5984.70 | norm: 3.21
65
+
66
+ GPU Memory: 0.68GB / 1.77GB
67
+ step 10,600 | loss: 0.4008 | lr: 6.00e-05 | dt: 683.34ms | tok/sec: 5994.10 | norm: 3.58
68
+
69
+ GPU Memory: 0.68GB / 1.77GB
70
+ step 10,700 | loss: 0.3951 | lr: 6.00e-05 | dt: 685.49ms | tok/sec: 5975.26 | norm: 3.81
71
+
72
+ GPU Memory: 0.68GB / 1.77GB
73
+ step 10,800 | loss: 0.3022 | lr: 6.00e-05 | dt: 687.40ms | tok/sec: 5958.64 | norm: 3.06
74
+
75
+ GPU Memory: 0.68GB / 1.77GB
76
+ step 10,900 | loss: 0.4287 | lr: 6.00e-05 | dt: 686.75ms | tok/sec: 5964.31 | norm: 3.60
77
+
78
+ GPU Memory: 0.68GB / 1.77GB
79
+ step 11,000 | loss: 0.2447 | lr: 6.00e-05 | dt: 687.35ms | tok/sec: 5959.12 | norm: 3.35
80
+
81
+ GPU Memory: 0.68GB / 1.77GB
82
+ step 11,100 | loss: 0.2773 | lr: 6.00e-05 | dt: 688.83ms | tok/sec: 5946.35 | norm: 2.71
83
+
84
+ GPU Memory: 0.67GB / 1.77GB
85
+ step 11,200 | loss: 0.2839 | lr: 6.00e-05 | dt: 687.56ms | tok/sec: 5957.31 | norm: 3.90
86
+
87
+ GPU Memory: 0.68GB / 1.77GB
88
+ step 11,300 | loss: 0.3481 | lr: 6.00e-05 | dt: 684.68ms | tok/sec: 5982.32 | norm: 3.68
89
+
90
+ GPU Memory: 0.78GB / 1.77GB
91
+ step 11,400 | loss: 0.1913 | lr: 6.00e-05 | dt: 685.73ms | tok/sec: 5973.18 | norm: 2.93
92
+
93
+ GPU Memory: 0.68GB / 1.77GB
94
+ step 11,500 | loss: 0.2605 | lr: 6.00e-05 | dt: 685.74ms | tok/sec: 5973.11 | norm: 2.96
95
+
96
+ GPU Memory: 0.68GB / 1.77GB
97
+ step 11,600 | loss: 0.2029 | lr: 6.00e-05 | dt: 689.04ms | tok/sec: 5944.49 | norm: 2.84
98
+
99
+
100
+ Reached target loss! Final loss: 0.0889 at step 11,663
101
+ Model saved to gpt_model.pt
102
+ ```
103
+
104
+ ### Model Output
105
+
106
+ ```bash
107
+ Once upon a time to not;
108
+ More slaughter'd, sweet Rivers, I receive my children, and title with pardon hither
109
+ That one stuff'd with a conquest; and teeth, of my? Why, in life thee,
110
+ Which now not joy of foe, thought o'n slaughter bed,
111
+ And, is mine own soul me, not so heavy in every day:
112
+ The tyrant from one curst my death lies;
113
+ For the ground is nothing henceforth fell executioner come
114
+ ```
115
+
116
+ ### Try it out
117
+
118
+ App Link
S12Trained.ipynb ADDED
@@ -0,0 +1,1091 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "LzOshnYOpqP_"
7
+ },
8
+ "source": [
9
+ "# NanoGPT\n",
10
+ "\n",
11
+ "Training a decoder-only model (NanoGPT) to genereate text in Shakespear stype"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {
17
+ "id": "J93p7rk7qK-P"
18
+ },
19
+ "source": [
20
+ "## Install Modules"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 1,
26
+ "metadata": {
27
+ "colab": {
28
+ "base_uri": "https://localhost:8080/"
29
+ },
30
+ "id": "8MnYOQ4xcKXa",
31
+ "outputId": "0b15787f-f2a8-4ba6-d744-a60d08f9d152"
32
+ },
33
+ "outputs": [
34
+ {
35
+ "output_type": "stream",
36
+ "name": "stdout",
37
+ "text": [
38
+ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/1.2 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.2/1.2 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m17.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
39
+ "\u001b[?25h"
40
+ ]
41
+ }
42
+ ],
43
+ "source": [
44
+ "# Tiktoken for tokenization\n",
45
+ "!pip install tiktoken --quiet"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "metadata": {
51
+ "id": "uHuCVeKYqWvo"
52
+ },
53
+ "source": [
54
+ "## Import Modules"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 2,
60
+ "metadata": {
61
+ "id": "igAl5bXSqZWo"
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "# Standard Library Imports\n",
66
+ "import os\n",
67
+ "import math\n",
68
+ "import time\n",
69
+ "import inspect\n",
70
+ "from dataclasses import dataclass\n",
71
+ "\n",
72
+ "# Third-Party Imports\n",
73
+ "import tiktoken\n",
74
+ "import torch\n",
75
+ "import torch.nn as nn\n",
76
+ "from torch.nn import functional as F"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "metadata": {
82
+ "id": "6QNqCz2fqSX0"
83
+ },
84
+ "source": [
85
+ "## Transformer Achitecture"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": 3,
91
+ "metadata": {
92
+ "id": "jgXGx_-YqVNH"
93
+ },
94
+ "outputs": [],
95
+ "source": [
96
+ "class CausalSelfAttention(nn.Module):\n",
97
+ "\n",
98
+ " def __init__(self, config):\n",
99
+ " super().__init__()\n",
100
+ " assert config.n_embd % config.n_head == 0\n",
101
+ " # key, query, value projections for all heads, but in a batch\n",
102
+ " self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)\n",
103
+ " # output projection\n",
104
+ " self.c_proj = nn.Linear(config.n_embd, config.n_embd)\n",
105
+ " self.c_proj.NANGPT_SCALE_INIT = 1\n",
106
+ " # regularization\n",
107
+ " self.n_head = config.n_head\n",
108
+ " self.n_embd = config.n_embd\n",
109
+ " self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))\n",
110
+ "\n",
111
+ " def forward(self, x):\n",
112
+ " B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)\n",
113
+ " # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n",
114
+ " # nh is \"number of heads\", hs is \"head size\", and C (number of channels) = nh * hs\n",
115
+ " # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer\n",
116
+ " qkv = self.c_attn(x)\n",
117
+ " q, k, v = qkv.split(self.n_embd, dim=2)\n",
118
+ " k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n",
119
+ " q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n",
120
+ " v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n",
121
+ "\n",
122
+ " # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n",
123
+ " # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))\n",
124
+ " # att = F.softmax(att, dim=-1)\n",
125
+ " # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n",
126
+ "\n",
127
+ " y = F.scaled_dot_product_attention(q, k, v, is_causal = True) # Flash attention\n",
128
+ "\n",
129
+ " y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side\n",
130
+ " # output projection\n",
131
+ " y = self.c_proj(y)\n",
132
+ " return y\n",
133
+ "\n",
134
+ "\n",
135
+ "class MLP(nn.Module):\n",
136
+ "\n",
137
+ " def __init__(self, config):\n",
138
+ " super().__init__()\n",
139
+ " self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)\n",
140
+ " self.gelu = nn.GELU(approximate='tanh')\n",
141
+ " self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)\n",
142
+ " self.c_proj.NANOGPT_SCALE_INIT = 1\n",
143
+ "\n",
144
+ " def forward(self, x):\n",
145
+ " x = self.c_fc(x)\n",
146
+ " x = self.gelu(x)\n",
147
+ " x = self.c_proj(x)\n",
148
+ " return x\n",
149
+ "\n",
150
+ "class Block(nn.Module):\n",
151
+ "\n",
152
+ " def __init__(self, config):\n",
153
+ " super().__init__()\n",
154
+ " self.ln_1 = nn.LayerNorm(config.n_embd)\n",
155
+ " self.attn = CausalSelfAttention(config)\n",
156
+ " self.ln_2 = nn.LayerNorm(config.n_embd)\n",
157
+ " self.mlp = MLP(config)\n",
158
+ "\n",
159
+ " def forward(self, x):\n",
160
+ " x = x + self.attn(self.ln_1(x))\n",
161
+ " x = x + self.mlp(self.ln_2(x))\n",
162
+ " return x\n",
163
+ "\n",
164
+ "class GPT(nn.Module):\n",
165
+ "\n",
166
+ " def __init__(self, config):\n",
167
+ " super().__init__()\n",
168
+ " self.config = config\n",
169
+ " self.gradient_checkpointing = True\n",
170
+ "\n",
171
+ " self.transformer = nn.ModuleDict(dict(\n",
172
+ " wte = nn.Embedding(config.vocab_size, config.n_embd),\n",
173
+ " wpe = nn.Embedding(config.block_size, config.n_embd),\n",
174
+ " h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),\n",
175
+ " ln_f = nn.LayerNorm(config.n_embd),\n",
176
+ " ))\n",
177
+ " self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
178
+ "\n",
179
+ " # weight sharing\n",
180
+ " self.transformer.wte.weight = self.lm_head.weight\n",
181
+ "\n",
182
+ " # weight initialization\n",
183
+ " self.apply(self._init_weights)\n",
184
+ "\n",
185
+ " def _init_weights(self, module):\n",
186
+ " if isinstance(module, nn.Linear):\n",
187
+ " std = 0.02\n",
188
+ " if hasattr(module, 'NANGPT_SCALE_INIT'):\n",
189
+ " std *= (2 * self.config.n_layer) ** -0.5\n",
190
+ " torch.nn.init.normal_(module.weight, mean = 0.0, std = std)\n",
191
+ " if module.bias is not None:\n",
192
+ " torch.nn.init.zeros_(module.bias)\n",
193
+ " elif isinstance(module, nn.Embedding):\n",
194
+ " torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)\n",
195
+ "\n",
196
+ "\n",
197
+ "\n",
198
+ " def forward(self, idx, targets=None):\n",
199
+ " # idx is of shape (B, T)\n",
200
+ " B, T = idx.size()\n",
201
+ " assert T <= self.config.block_size, f\"Cannot forward sequence of length {T}, block size is only {self.config.block_size}\"\n",
202
+ " # forward the token and posisition embeddings\n",
203
+ " # pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)\n",
204
+ " # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)\n",
205
+ " # tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)\n",
206
+ " pos = torch.arange(0, T, dtype=torch.long, device=idx.device)\n",
207
+ " pos_emb = self.transformer.wpe(pos)\n",
208
+ " tok_emb = self.transformer.wte(idx)\n",
209
+ " x = tok_emb + pos_emb\n",
210
+ " # forward the blocks of the transformer\n",
211
+ " for block in self.transformer.h:\n",
212
+ " x = block(x)\n",
213
+ " # Modify the transformer blocks section to use gradient checkpointing\n",
214
+ " if self.gradient_checkpointing and self.training:\n",
215
+ " for block in self.transformer.h:\n",
216
+ " x = torch.utils.checkpoint.checkpoint(block, x)\n",
217
+ " else:\n",
218
+ " for block in self.transformer.h:\n",
219
+ " x = block(x)\n",
220
+ "\n",
221
+ " x = self.transformer.ln_f(x)\n",
222
+ " logits = self.lm_head(x)\n",
223
+ "\n",
224
+ " loss = None\n",
225
+ " if targets is not None:\n",
226
+ " loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))\n",
227
+ " return logits, loss\n",
228
+ "\n",
229
+ " @classmethod\n",
230
+ " def from_pretrained(cls, model_type):\n",
231
+ " \"\"\"Loads pretrained GPT-2 model weights from huggingface\"\"\"\n",
232
+ " assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}\n",
233
+ " from transformers import GPT2LMHeadModel\n",
234
+ " print(\"loading weights from pretrained gpt: %s\" % model_type)\n",
235
+ "\n",
236
+ " # n_layer, n_head and n_embd are determined from model_type\n",
237
+ " config_args = {\n",
238
+ " 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n",
239
+ " 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n",
240
+ " 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n",
241
+ " 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n",
242
+ " }[model_type]\n",
243
+ " config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints\n",
244
+ " config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints\n",
245
+ " # create a from-scratch initialized minGPT model\n",
246
+ " config = GPTConfig(**config_args)\n",
247
+ " model = GPT(config)\n",
248
+ " sd = model.state_dict()\n",
249
+ " sd_keys = sd.keys()\n",
250
+ " sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param\n",
251
+ "\n",
252
+ " # init a huggingface/transformers model\n",
253
+ " model_hf = GPT2LMHeadModel.from_pretrained(model_type)\n",
254
+ " sd_hf = model_hf.state_dict()\n",
255
+ "\n",
256
+ " # copy while ensuring all of the parameters are aligned and match in names and shapes\n",
257
+ " sd_keys_hf = sd_hf.keys()\n",
258
+ " sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer\n",
259
+ " sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)\n",
260
+ " transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']\n",
261
+ " # basically the openai checkpoints use a \"Conv1D\" module, but we only want to use a vanilla Linear\n",
262
+ " # this means that we have to transpose these weights when we import them\n",
263
+ " assert len(sd_keys_hf) == len(sd_keys), f\"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}\"\n",
264
+ " for k in sd_keys_hf:\n",
265
+ " if any(k.endswith(w) for w in transposed):\n",
266
+ " # special treatment for the Conv1D weights we need to transpose\n",
267
+ " assert sd_hf[k].shape[::-1] == sd[k].shape\n",
268
+ " with torch.no_grad():\n",
269
+ " sd[k].copy_(sd_hf[k].t())\n",
270
+ " else:\n",
271
+ " # vanilla copy over the other parameters\n",
272
+ " assert sd_hf[k].shape == sd[k].shape\n",
273
+ " with torch.no_grad():\n",
274
+ " sd[k].copy_(sd_hf[k])\n",
275
+ "\n",
276
+ " return model\n",
277
+ "\n",
278
+ " def configure_optimizers(self, weight_decay, learning_rate, device_type):\n",
279
+ " # start with all of the candidate parameters (that require grad)\n",
280
+ " param_dict = {pn: p for pn, p in self.named_parameters()}\n",
281
+ " param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}\n",
282
+ " # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.\n",
283
+ " # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.\n",
284
+ " decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\n",
285
+ " nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\n",
286
+ " optim_groups = [\n",
287
+ " {'params': decay_params, 'weight_decay': weight_decay},\n",
288
+ " {'params': nodecay_params, 'weight_decay': 0.0}\n",
289
+ " ]\n",
290
+ " num_decay_params = sum(p.numel() for p in decay_params)\n",
291
+ " num_nodecay_params = sum(p.numel() for p in nodecay_params)\n",
292
+ "\n",
293
+ " print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n",
294
+ " print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n",
295
+ " # Create AdamW optimizer and use the fused version if it is available\n",
296
+ " fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n",
297
+ " use_fused = fused_available and device_type == \"cuda\"\n",
298
+ "\n",
299
+ " print(f\"using fused AdamW: {use_fused}\")\n",
300
+ " optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)\n",
301
+ " return optimizer"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "markdown",
306
+ "metadata": {
307
+ "id": "C3Xev9Ycq7Qe"
308
+ },
309
+ "source": [
310
+ "## Configuration Parameters"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": 4,
316
+ "metadata": {
317
+ "id": "27h8bNasq-Xe"
318
+ },
319
+ "outputs": [],
320
+ "source": [
321
+ "@dataclass\n",
322
+ "class GPTConfig:\n",
323
+ " block_size: int = 512 # max sequence length\n",
324
+ " vocab_size: int = 50304 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token\n",
325
+ " n_layer: int = 8 # number of layers\n",
326
+ " n_head: int = 8 # number of heads\n",
327
+ " n_embd: int = 384 # embedding dimension 768"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "markdown",
332
+ "metadata": {
333
+ "id": "GHpQZ_avrMyd"
334
+ },
335
+ "source": [
336
+ "## DataLoader"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": 5,
342
+ "metadata": {
343
+ "id": "NcXgug8hrPO-"
344
+ },
345
+ "outputs": [],
346
+ "source": [
347
+ "class DataLoaderLite:\n",
348
+ " def __init__(self, B, T):\n",
349
+ " self.B = B\n",
350
+ " self.T = T\n",
351
+ "\n",
352
+ " # Modify path to your input file location in Drive\n",
353
+ " input_path = 'input.txt' # Update this path\n",
354
+ " try:\n",
355
+ " with open(input_path, 'r', encoding='utf-8') as f:\n",
356
+ " text = f.read()\n",
357
+ " print(f\"Successfully loaded text from {input_path}\")\n",
358
+ " except Exception as e:\n",
359
+ " print(f\"Error loading file: {e}\")\n",
360
+ " raise\n",
361
+ "\n",
362
+ " enc = tiktoken.get_encoding('gpt2')\n",
363
+ " tokens = enc.encode(text)\n",
364
+ " self.tokens = torch.tensor(tokens)\n",
365
+ " print(f'Loaded {len(self.tokens):,} tokens')\n",
366
+ " print(f'1 epoch = {len(self.tokens) // (B * T):,} batches')\n",
367
+ " print(f'Input text size: {len(text):,} characters')\n",
368
+ "\n",
369
+ " self.current_position = 0\n",
370
+ "\n",
371
+ " def next_batch(self):\n",
372
+ " B, T = self.B, self.T\n",
373
+ " buf = self.tokens[self.current_position: self.current_position + B * T + 1]\n",
374
+ " x = (buf[:-1]).view(B, T) # inputs\n",
375
+ " y = (buf[1:]).view(B, T) # targets\n",
376
+ " # advance the position in the tensor\n",
377
+ " self.current_position += B*T\n",
378
+ " # if loading the next batch would be out of bounds, reset\n",
379
+ " if self.current_position + (B * T + 1) > len(self.tokens):\n",
380
+ " self.current_position = 0\n",
381
+ " return x, y"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "markdown",
386
+ "metadata": {
387
+ "id": "y0eEHPNrrz6b"
388
+ },
389
+ "source": [
390
+ "## Device Configutration"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "code",
395
+ "execution_count": 6,
396
+ "metadata": {
397
+ "id": "w_OS901rr3SO",
398
+ "colab": {
399
+ "base_uri": "https://localhost:8080/"
400
+ },
401
+ "outputId": "0c386cef-8cc3-4964-916a-178cd5668959"
402
+ },
403
+ "outputs": [
404
+ {
405
+ "output_type": "stream",
406
+ "name": "stdout",
407
+ "text": [
408
+ "CUDA available: True\n",
409
+ "CUDA device: Tesla T4\n",
410
+ "Total GPU memory: 15.84 GB\n",
411
+ "Using device: cuda\n"
412
+ ]
413
+ }
414
+ ],
415
+ "source": [
416
+ "# Add CUDA check and memory info\n",
417
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
418
+ "if torch.cuda.is_available():\n",
419
+ " print(f\"CUDA device: {torch.cuda.get_device_name(0)}\")\n",
420
+ " print(f\"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\")\n",
421
+ "\n",
422
+ "# Modify the device selection\n",
423
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
424
+ "print(f\"Using device: {device}\")"
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "markdown",
429
+ "metadata": {
430
+ "id": "p74YZyRe_3_Z"
431
+ },
432
+ "source": [
433
+ "## Utilities"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": 7,
439
+ "metadata": {
440
+ "id": "q1Q1T5az_7Pb"
441
+ },
442
+ "outputs": [],
443
+ "source": [
444
+ "def get_lr(it):\n",
445
+ " if it < warmup_steps:\n",
446
+ " return max_lr * (it + 1) / warmup_steps\n",
447
+ " if it > max_steps:\n",
448
+ " return min_lr\n",
449
+ " decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)\n",
450
+ " assert 0 <= decay_ratio <=1\n",
451
+ " coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))\n",
452
+ " return min_lr + coeff * (max_lr - min_lr)"
453
+ ]
454
+ },
455
+ {
456
+ "cell_type": "markdown",
457
+ "metadata": {
458
+ "id": "veEWTXgVsIVz"
459
+ },
460
+ "source": [
461
+ "## Model Training"
462
+ ]
463
+ },
464
+ {
465
+ "cell_type": "code",
466
+ "execution_count": 9,
467
+ "metadata": {
468
+ "colab": {
469
+ "base_uri": "https://localhost:8080/"
470
+ },
471
+ "id": "NjuBy6HGb9Ta",
472
+ "outputId": "78a7feb8-c362-4343-a3d1-8a1cbf744491"
473
+ },
474
+ "outputs": [
475
+ {
476
+ "output_type": "stream",
477
+ "name": "stdout",
478
+ "text": [
479
+ "Successfully loaded text from input.txt\n",
480
+ "Loaded 338,025 tokens\n",
481
+ "1 epoch = 330 batches\n",
482
+ "Input text size: 1,115,394 characters\n",
483
+ "num decayed parameter tensors: 34, with 33,669,120 parameters\n",
484
+ "num non-decayed parameter tensors: 66, with 40,704 parameters\n",
485
+ "using fused AdamW: True\n"
486
+ ]
487
+ },
488
+ {
489
+ "output_type": "stream",
490
+ "name": "stderr",
491
+ "text": [
492
+ "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:1604: UserWarning: Tesla T4 does not support bfloat16 compilation natively, skipping\n",
493
+ " warnings.warn(\n",
494
+ "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:1604: UserWarning: Tesla T4 does not support bfloat16 compilation natively, skipping\n",
495
+ " warnings.warn(\n",
496
+ "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:1604: UserWarning: Tesla T4 does not support bfloat16 compilation natively, skipping\n",
497
+ " warnings.warn(\n",
498
+ "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:1604: UserWarning: Tesla T4 does not support bfloat16 compilation natively, skipping\n",
499
+ " warnings.warn(\n",
500
+ "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:632: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
501
+ " return fn(*args, **kwargs)\n"
502
+ ]
503
+ },
504
+ {
505
+ "output_type": "stream",
506
+ "name": "stdout",
507
+ "text": [
508
+ "GPU Memory: 0.68GB / 1.55GB\n",
509
+ "step 0 | loss: 10.9470 | lr: 6.00e-05 | dt: 23182.17ms | tok/sec: 176.69 | norm: 5.03\n",
510
+ " \n",
511
+ "GPU Memory: 0.78GB / 1.77GB\n",
512
+ "step 100 | loss: 5.9936 | lr: 6.00e-05 | dt: 687.63ms | tok/sec: 5956.71 | norm: 0.90\n",
513
+ " \n",
514
+ "GPU Memory: 0.78GB / 1.77GB\n",
515
+ "step 200 | loss: 6.1430 | lr: 6.00e-05 | dt: 680.82ms | tok/sec: 6016.26 | norm: 0.85\n",
516
+ " \n",
517
+ "GPU Memory: 0.68GB / 1.77GB\n",
518
+ "step 300 | loss: 5.7334 | lr: 6.00e-05 | dt: 684.30ms | tok/sec: 5985.65 | norm: 1.32\n",
519
+ " \n",
520
+ "GPU Memory: 0.68GB / 1.77GB\n",
521
+ "step 400 | loss: 5.5456 | lr: 6.00e-05 | dt: 680.26ms | tok/sec: 6021.24 | norm: 1.12\n",
522
+ " \n",
523
+ "GPU Memory: 0.68GB / 1.77GB\n",
524
+ "step 500 | loss: 4.7860 | lr: 6.00e-05 | dt: 682.16ms | tok/sec: 6004.47 | norm: 1.53\n",
525
+ " \n",
526
+ "GPU Memory: 0.68GB / 1.77GB\n",
527
+ "step 600 | loss: 5.4306 | lr: 6.00e-05 | dt: 683.22ms | tok/sec: 5995.14 | norm: 1.23\n",
528
+ " \n",
529
+ "GPU Memory: 0.78GB / 1.77GB\n",
530
+ "step 700 | loss: 5.3188 | lr: 6.00e-05 | dt: 683.26ms | tok/sec: 5994.79 | norm: 1.21\n",
531
+ " \n",
532
+ "GPU Memory: 0.68GB / 1.77GB\n",
533
+ "step 800 | loss: 5.1400 | lr: 6.00e-05 | dt: 686.78ms | tok/sec: 5964.07 | norm: 1.84\n",
534
+ " \n",
535
+ "GPU Memory: 0.68GB / 1.77GB\n",
536
+ "step 900 | loss: 4.9057 | lr: 6.00e-05 | dt: 680.29ms | tok/sec: 6020.97 | norm: 1.59\n",
537
+ " \n",
538
+ "GPU Memory: 0.78GB / 1.77GB\n",
539
+ "step 1,000 | loss: 4.8926 | lr: 6.00e-05 | dt: 684.95ms | tok/sec: 5980.04 | norm: 1.48\n",
540
+ " \n",
541
+ "GPU Memory: 0.68GB / 1.77GB\n",
542
+ "step 1,100 | loss: 4.7693 | lr: 6.00e-05 | dt: 685.51ms | tok/sec: 5975.12 | norm: 2.60\n",
543
+ " \n",
544
+ "GPU Memory: 0.78GB / 1.77GB\n",
545
+ "step 1,200 | loss: 5.0689 | lr: 6.00e-05 | dt: 682.40ms | tok/sec: 6002.35 | norm: 1.57\n",
546
+ " \n",
547
+ "GPU Memory: 0.68GB / 1.77GB\n",
548
+ "step 1,300 | loss: 4.5868 | lr: 6.00e-05 | dt: 682.93ms | tok/sec: 5997.68 | norm: 1.86\n",
549
+ " \n",
550
+ "GPU Memory: 0.68GB / 1.77GB\n",
551
+ "step 1,400 | loss: 4.8036 | lr: 6.00e-05 | dt: 682.67ms | tok/sec: 5999.99 | norm: 1.92\n",
552
+ " \n",
553
+ "GPU Memory: 0.68GB / 1.77GB\n",
554
+ "step 1,500 | loss: 4.5739 | lr: 6.00e-05 | dt: 682.74ms | tok/sec: 5999.33 | norm: 1.89\n",
555
+ " \n",
556
+ "GPU Memory: 0.68GB / 1.77GB\n",
557
+ "step 1,600 | loss: 4.3767 | lr: 6.00e-05 | dt: 681.53ms | tok/sec: 6009.97 | norm: 2.48\n",
558
+ " \n",
559
+ "GPU Memory: 0.68GB / 1.77GB\n",
560
+ "step 1,700 | loss: 4.7170 | lr: 6.00e-05 | dt: 682.21ms | tok/sec: 6004.05 | norm: 2.88\n",
561
+ " \n",
562
+ "GPU Memory: 0.78GB / 1.77GB\n",
563
+ "step 1,800 | loss: 4.5040 | lr: 6.00e-05 | dt: 681.62ms | tok/sec: 6009.23 | norm: 2.25\n",
564
+ " \n",
565
+ "GPU Memory: 0.68GB / 1.77GB\n",
566
+ "step 1,900 | loss: 4.4489 | lr: 6.00e-05 | dt: 686.54ms | tok/sec: 5966.18 | norm: 2.01\n",
567
+ " \n",
568
+ "GPU Memory: 0.68GB / 1.77GB\n",
569
+ "step 2,000 | loss: 3.9155 | lr: 6.00e-05 | dt: 683.89ms | tok/sec: 5989.27 | norm: 2.26\n",
570
+ " \n",
571
+ "GPU Memory: 0.68GB / 1.77GB\n",
572
+ "step 2,100 | loss: 4.1443 | lr: 6.00e-05 | dt: 681.80ms | tok/sec: 6007.60 | norm: 2.33\n",
573
+ " \n",
574
+ "GPU Memory: 0.78GB / 1.77GB\n",
575
+ "step 2,200 | loss: 4.6614 | lr: 6.00e-05 | dt: 681.72ms | tok/sec: 6008.37 | norm: 2.44\n",
576
+ " \n",
577
+ "GPU Memory: 0.68GB / 1.77GB\n",
578
+ "step 2,300 | loss: 3.6722 | lr: 6.00e-05 | dt: 684.94ms | tok/sec: 5980.11 | norm: 2.75\n",
579
+ " \n",
580
+ "GPU Memory: 0.78GB / 1.77GB\n",
581
+ "step 2,400 | loss: 4.0311 | lr: 6.00e-05 | dt: 684.12ms | tok/sec: 5987.25 | norm: 2.90\n",
582
+ " \n",
583
+ "GPU Memory: 0.68GB / 1.77GB\n",
584
+ "step 2,500 | loss: 4.1794 | lr: 6.00e-05 | dt: 686.70ms | tok/sec: 5964.77 | norm: 2.67\n",
585
+ " \n",
586
+ "GPU Memory: 0.68GB / 1.77GB\n",
587
+ "step 2,600 | loss: 3.7992 | lr: 6.00e-05 | dt: 681.41ms | tok/sec: 6011.09 | norm: 3.10\n",
588
+ " \n",
589
+ "GPU Memory: 0.67GB / 1.77GB\n",
590
+ "step 2,700 | loss: 4.5840 | lr: 6.00e-05 | dt: 682.54ms | tok/sec: 6001.09 | norm: 3.52\n",
591
+ " \n",
592
+ "GPU Memory: 0.68GB / 1.77GB\n",
593
+ "step 2,800 | loss: 3.6126 | lr: 6.00e-05 | dt: 682.40ms | tok/sec: 6002.35 | norm: 2.92\n",
594
+ " \n",
595
+ "GPU Memory: 0.68GB / 1.77GB\n",
596
+ "step 2,900 | loss: 3.6538 | lr: 6.00e-05 | dt: 681.88ms | tok/sec: 6006.93 | norm: 3.04\n",
597
+ " \n",
598
+ "GPU Memory: 0.68GB / 1.77GB\n",
599
+ "step 3,000 | loss: 3.4119 | lr: 6.00e-05 | dt: 684.70ms | tok/sec: 5982.20 | norm: 2.96\n",
600
+ " \n",
601
+ "GPU Memory: 0.68GB / 1.77GB\n",
602
+ "step 3,100 | loss: 3.4817 | lr: 6.00e-05 | dt: 686.39ms | tok/sec: 5967.42 | norm: 3.22\n",
603
+ " \n",
604
+ "GPU Memory: 0.68GB / 1.77GB\n",
605
+ "step 3,200 | loss: 3.7793 | lr: 6.00e-05 | dt: 682.57ms | tok/sec: 6000.86 | norm: 4.25\n",
606
+ " \n",
607
+ "GPU Memory: 0.68GB / 1.77GB\n",
608
+ "step 3,300 | loss: 3.6155 | lr: 6.00e-05 | dt: 685.45ms | tok/sec: 5975.64 | norm: 4.07\n",
609
+ " \n",
610
+ "GPU Memory: 0.78GB / 1.77GB\n",
611
+ "step 3,400 | loss: 3.7451 | lr: 6.00e-05 | dt: 686.79ms | tok/sec: 5963.94 | norm: 3.34\n",
612
+ " \n",
613
+ "GPU Memory: 0.68GB / 1.77GB\n",
614
+ "step 3,500 | loss: 3.5069 | lr: 6.00e-05 | dt: 686.91ms | tok/sec: 5962.97 | norm: 4.46\n",
615
+ " \n",
616
+ "GPU Memory: 0.68GB / 1.77GB\n",
617
+ "step 3,600 | loss: 3.5325 | lr: 6.00e-05 | dt: 683.46ms | tok/sec: 5993.08 | norm: 4.59\n",
618
+ " \n",
619
+ "GPU Memory: 0.68GB / 1.77GB\n",
620
+ "step 3,700 | loss: 3.6000 | lr: 6.00e-05 | dt: 683.37ms | tok/sec: 5993.87 | norm: 4.69\n",
621
+ " \n",
622
+ "GPU Memory: 0.68GB / 1.77GB\n",
623
+ "step 3,800 | loss: 2.7920 | lr: 6.00e-05 | dt: 681.15ms | tok/sec: 6013.37 | norm: 3.43\n",
624
+ " \n",
625
+ "GPU Memory: 0.68GB / 1.77GB\n",
626
+ "step 3,900 | loss: 3.5265 | lr: 6.00e-05 | dt: 684.23ms | tok/sec: 5986.29 | norm: 3.96\n",
627
+ " \n",
628
+ "GPU Memory: 0.68GB / 1.77GB\n",
629
+ "step 4,000 | loss: 3.5139 | lr: 6.00e-05 | dt: 680.86ms | tok/sec: 6015.93 | norm: 4.81\n",
630
+ " \n",
631
+ "GPU Memory: 0.68GB / 1.77GB\n",
632
+ "step 4,100 | loss: 3.4394 | lr: 6.00e-05 | dt: 680.96ms | tok/sec: 6015.00 | norm: 4.55\n",
633
+ " \n",
634
+ "GPU Memory: 0.78GB / 1.77GB\n",
635
+ "step 4,200 | loss: 3.1918 | lr: 6.00e-05 | dt: 682.53ms | tok/sec: 6001.21 | norm: 5.20\n",
636
+ " \n",
637
+ "GPU Memory: 0.68GB / 1.77GB\n",
638
+ "step 4,300 | loss: 3.2290 | lr: 6.00e-05 | dt: 684.99ms | tok/sec: 5979.67 | norm: 4.89\n",
639
+ " \n",
640
+ "GPU Memory: 0.68GB / 1.77GB\n",
641
+ "step 4,400 | loss: 3.0279 | lr: 6.00e-05 | dt: 678.78ms | tok/sec: 6034.36 | norm: 4.00\n",
642
+ " \n",
643
+ "GPU Memory: 0.67GB / 1.77GB\n",
644
+ "step 4,500 | loss: 3.3566 | lr: 6.00e-05 | dt: 680.69ms | tok/sec: 6017.44 | norm: 4.92\n",
645
+ " \n",
646
+ "GPU Memory: 0.68GB / 1.77GB\n",
647
+ "step 4,600 | loss: 2.9366 | lr: 6.00e-05 | dt: 683.88ms | tok/sec: 5989.36 | norm: 4.65\n",
648
+ " \n",
649
+ "GPU Memory: 0.67GB / 1.77GB\n",
650
+ "step 4,700 | loss: 3.1994 | lr: 6.00e-05 | dt: 679.04ms | tok/sec: 6032.08 | norm: 7.23\n",
651
+ " \n",
652
+ "GPU Memory: 0.68GB / 1.77GB\n",
653
+ "step 4,800 | loss: 2.9273 | lr: 6.00e-05 | dt: 681.97ms | tok/sec: 6006.12 | norm: 3.98\n",
654
+ " \n",
655
+ "GPU Memory: 0.68GB / 1.77GB\n",
656
+ "step 4,900 | loss: 2.8992 | lr: 6.00e-05 | dt: 686.39ms | tok/sec: 5967.48 | norm: 5.38\n",
657
+ " \n",
658
+ "GPU Memory: 0.78GB / 1.77GB\n",
659
+ "step 5,000 | loss: 2.9706 | lr: 6.00e-05 | dt: 685.26ms | tok/sec: 5977.31 | norm: 4.36\n",
660
+ " \n",
661
+ "GPU Memory: 0.68GB / 1.77GB\n",
662
+ "step 5,100 | loss: 2.7330 | lr: 6.00e-05 | dt: 682.23ms | tok/sec: 6003.81 | norm: 4.92\n",
663
+ " \n",
664
+ "GPU Memory: 0.68GB / 1.77GB\n",
665
+ "step 5,200 | loss: 2.8838 | lr: 6.00e-05 | dt: 683.10ms | tok/sec: 5996.15 | norm: 6.76\n",
666
+ " \n",
667
+ "GPU Memory: 0.68GB / 1.77GB\n",
668
+ "step 5,300 | loss: 2.3963 | lr: 6.00e-05 | dt: 685.79ms | tok/sec: 5972.66 | norm: 4.91\n",
669
+ " \n",
670
+ "GPU Memory: 0.78GB / 1.77GB\n",
671
+ "step 5,400 | loss: 2.4825 | lr: 6.00e-05 | dt: 684.18ms | tok/sec: 5986.69 | norm: 4.63\n",
672
+ " \n",
673
+ "GPU Memory: 0.68GB / 1.77GB\n",
674
+ "step 5,500 | loss: 3.0586 | lr: 6.00e-05 | dt: 685.38ms | tok/sec: 5976.21 | norm: 5.94\n",
675
+ " \n",
676
+ "GPU Memory: 0.68GB / 1.77GB\n",
677
+ "step 5,600 | loss: 2.2882 | lr: 6.00e-05 | dt: 684.35ms | tok/sec: 5985.27 | norm: 5.28\n",
678
+ " \n",
679
+ "GPU Memory: 0.78GB / 1.77GB\n",
680
+ "step 5,700 | loss: 2.3943 | lr: 6.00e-05 | dt: 681.74ms | tok/sec: 6008.12 | norm: 5.19\n",
681
+ " \n",
682
+ "GPU Memory: 0.68GB / 1.77GB\n",
683
+ "step 5,800 | loss: 2.5011 | lr: 6.00e-05 | dt: 686.10ms | tok/sec: 5969.99 | norm: 5.41\n",
684
+ " \n",
685
+ "GPU Memory: 0.67GB / 1.77GB\n",
686
+ "step 5,900 | loss: 2.3386 | lr: 6.00e-05 | dt: 683.09ms | tok/sec: 5996.32 | norm: 5.78\n",
687
+ " \n",
688
+ "GPU Memory: 0.68GB / 1.77GB\n",
689
+ "step 6,000 | loss: 2.6910 | lr: 6.00e-05 | dt: 685.64ms | tok/sec: 5973.95 | norm: 5.91\n",
690
+ " \n",
691
+ "GPU Memory: 0.68GB / 1.77GB\n",
692
+ "step 6,100 | loss: 1.9940 | lr: 6.00e-05 | dt: 682.09ms | tok/sec: 6005.04 | norm: 4.81\n",
693
+ " \n",
694
+ "GPU Memory: 0.68GB / 1.77GB\n",
695
+ "step 6,200 | loss: 2.1706 | lr: 6.00e-05 | dt: 687.09ms | tok/sec: 5961.39 | norm: 6.18\n",
696
+ " \n",
697
+ "GPU Memory: 0.78GB / 1.77GB\n",
698
+ "step 6,300 | loss: 1.8759 | lr: 6.00e-05 | dt: 686.62ms | tok/sec: 5965.47 | norm: 4.45\n",
699
+ " \n",
700
+ "GPU Memory: 0.68GB / 1.77GB\n",
701
+ "step 6,400 | loss: 1.8825 | lr: 6.00e-05 | dt: 686.07ms | tok/sec: 5970.20 | norm: 5.26\n",
702
+ " \n",
703
+ "GPU Memory: 0.68GB / 1.77GB\n",
704
+ "step 6,500 | loss: 2.1047 | lr: 6.00e-05 | dt: 687.22ms | tok/sec: 5960.26 | norm: 5.17\n",
705
+ " \n",
706
+ "GPU Memory: 0.68GB / 1.77GB\n",
707
+ "step 6,600 | loss: 2.2490 | lr: 6.00e-05 | dt: 683.49ms | tok/sec: 5992.75 | norm: 6.13\n",
708
+ " \n",
709
+ "GPU Memory: 0.68GB / 1.77GB\n",
710
+ "step 6,700 | loss: 2.0222 | lr: 6.00e-05 | dt: 684.05ms | tok/sec: 5987.87 | norm: 4.83\n",
711
+ " \n",
712
+ "GPU Memory: 0.68GB / 1.77GB\n",
713
+ "step 6,800 | loss: 1.7948 | lr: 6.00e-05 | dt: 687.00ms | tok/sec: 5962.19 | norm: 5.72\n",
714
+ " \n",
715
+ "GPU Memory: 0.68GB / 1.77GB\n",
716
+ "step 6,900 | loss: 1.9430 | lr: 6.00e-05 | dt: 685.17ms | tok/sec: 5978.09 | norm: 6.76\n",
717
+ " \n",
718
+ "GPU Memory: 0.68GB / 1.77GB\n",
719
+ "step 7,000 | loss: 1.9375 | lr: 6.00e-05 | dt: 685.81ms | tok/sec: 5972.46 | norm: 5.95\n",
720
+ " \n",
721
+ "GPU Memory: 0.68GB / 1.77GB\n",
722
+ "step 7,100 | loss: 1.4104 | lr: 6.00e-05 | dt: 686.53ms | tok/sec: 5966.27 | norm: 5.39\n",
723
+ " \n",
724
+ "GPU Memory: 0.68GB / 1.77GB\n",
725
+ "step 7,200 | loss: 1.7128 | lr: 6.00e-05 | dt: 682.73ms | tok/sec: 5999.47 | norm: 4.76\n",
726
+ " \n",
727
+ "GPU Memory: 0.68GB / 1.77GB\n",
728
+ "step 7,300 | loss: 1.7015 | lr: 6.00e-05 | dt: 686.85ms | tok/sec: 5963.47 | norm: 4.74\n",
729
+ " \n",
730
+ "GPU Memory: 0.68GB / 1.77GB\n",
731
+ "step 7,400 | loss: 1.6215 | lr: 6.00e-05 | dt: 683.93ms | tok/sec: 5988.95 | norm: 6.07\n",
732
+ " \n",
733
+ "GPU Memory: 0.68GB / 1.77GB\n",
734
+ "step 7,500 | loss: 1.5474 | lr: 6.00e-05 | dt: 684.24ms | tok/sec: 5986.21 | norm: 5.89\n",
735
+ " \n",
736
+ "GPU Memory: 0.78GB / 1.77GB\n",
737
+ "step 7,600 | loss: 1.5799 | lr: 6.00e-05 | dt: 684.59ms | tok/sec: 5983.12 | norm: 5.08\n",
738
+ " \n",
739
+ "GPU Memory: 0.68GB / 1.77GB\n",
740
+ "step 7,700 | loss: 1.4209 | lr: 6.00e-05 | dt: 685.21ms | tok/sec: 5977.75 | norm: 4.84\n",
741
+ " \n",
742
+ "GPU Memory: 0.68GB / 1.77GB\n",
743
+ "step 7,800 | loss: 1.4405 | lr: 6.00e-05 | dt: 686.67ms | tok/sec: 5965.03 | norm: 4.81\n",
744
+ " \n",
745
+ "GPU Memory: 0.68GB / 1.77GB\n",
746
+ "step 7,900 | loss: 1.1260 | lr: 6.00e-05 | dt: 685.51ms | tok/sec: 5975.15 | norm: 6.15\n",
747
+ " \n",
748
+ "GPU Memory: 0.78GB / 1.77GB\n",
749
+ "step 8,000 | loss: 1.6376 | lr: 6.00e-05 | dt: 685.51ms | tok/sec: 5975.14 | norm: 7.62\n",
750
+ " \n",
751
+ "GPU Memory: 0.68GB / 1.77GB\n",
752
+ "step 8,100 | loss: 1.2116 | lr: 6.00e-05 | dt: 684.44ms | tok/sec: 5984.46 | norm: 5.31\n",
753
+ " \n",
754
+ "GPU Memory: 0.68GB / 1.77GB\n",
755
+ "step 8,200 | loss: 1.2855 | lr: 6.00e-05 | dt: 686.78ms | tok/sec: 5964.02 | norm: 6.41\n",
756
+ " \n",
757
+ "GPU Memory: 0.68GB / 1.77GB\n",
758
+ "step 8,300 | loss: 1.2305 | lr: 6.00e-05 | dt: 686.39ms | tok/sec: 5967.49 | norm: 4.97\n",
759
+ " \n",
760
+ "GPU Memory: 0.68GB / 1.77GB\n",
761
+ "step 8,400 | loss: 1.1149 | lr: 6.00e-05 | dt: 685.69ms | tok/sec: 5973.56 | norm: 5.16\n",
762
+ " \n",
763
+ "GPU Memory: 0.78GB / 1.77GB\n",
764
+ "step 8,500 | loss: 1.4075 | lr: 6.00e-05 | dt: 685.20ms | tok/sec: 5977.81 | norm: 6.99\n",
765
+ " \n",
766
+ "GPU Memory: 0.78GB / 1.77GB\n",
767
+ "step 8,600 | loss: 0.8826 | lr: 6.00e-05 | dt: 682.54ms | tok/sec: 6001.15 | norm: 4.75\n",
768
+ " \n",
769
+ "GPU Memory: 0.78GB / 1.77GB\n",
770
+ "step 8,700 | loss: 0.9010 | lr: 6.00e-05 | dt: 684.27ms | tok/sec: 5985.97 | norm: 5.06\n",
771
+ " \n",
772
+ "GPU Memory: 0.68GB / 1.77GB\n",
773
+ "step 8,800 | loss: 1.2441 | lr: 6.00e-05 | dt: 687.10ms | tok/sec: 5961.29 | norm: 5.49\n",
774
+ " \n",
775
+ "GPU Memory: 0.68GB / 1.77GB\n",
776
+ "step 8,900 | loss: 0.8399 | lr: 6.00e-05 | dt: 683.32ms | tok/sec: 5994.27 | norm: 7.54\n",
777
+ " \n",
778
+ "GPU Memory: 0.68GB / 1.77GB\n",
779
+ "step 9,000 | loss: 0.7800 | lr: 6.00e-05 | dt: 686.11ms | tok/sec: 5969.91 | norm: 4.41\n",
780
+ " \n",
781
+ "GPU Memory: 0.68GB / 1.77GB\n",
782
+ "step 9,100 | loss: 0.8157 | lr: 6.00e-05 | dt: 685.21ms | tok/sec: 5977.69 | norm: 4.66\n",
783
+ " \n",
784
+ "GPU Memory: 0.68GB / 1.77GB\n",
785
+ "step 9,200 | loss: 0.7936 | lr: 6.00e-05 | dt: 684.71ms | tok/sec: 5982.08 | norm: 5.31\n",
786
+ " \n",
787
+ "GPU Memory: 0.68GB / 1.77GB\n",
788
+ "step 9,300 | loss: 1.0805 | lr: 6.00e-05 | dt: 684.98ms | tok/sec: 5979.70 | norm: 5.38\n",
789
+ " \n",
790
+ "GPU Memory: 0.68GB / 1.77GB\n",
791
+ "step 9,400 | loss: 0.5698 | lr: 6.00e-05 | dt: 682.32ms | tok/sec: 6003.03 | norm: 4.22\n",
792
+ " \n",
793
+ "GPU Memory: 0.78GB / 1.77GB\n",
794
+ "step 9,500 | loss: 0.6732 | lr: 6.00e-05 | dt: 683.20ms | tok/sec: 5995.35 | norm: 5.23\n",
795
+ " \n",
796
+ "GPU Memory: 0.78GB / 1.77GB\n",
797
+ "step 9,600 | loss: 0.4544 | lr: 6.00e-05 | dt: 685.71ms | tok/sec: 5973.34 | norm: 3.60\n",
798
+ " \n",
799
+ "GPU Memory: 0.68GB / 1.77GB\n",
800
+ "step 9,700 | loss: 0.4766 | lr: 6.00e-05 | dt: 682.04ms | tok/sec: 6005.54 | norm: 4.36\n",
801
+ " \n",
802
+ "GPU Memory: 0.68GB / 1.77GB\n",
803
+ "step 9,800 | loss: 0.6707 | lr: 6.00e-05 | dt: 685.59ms | tok/sec: 5974.41 | norm: 4.26\n",
804
+ " \n",
805
+ "GPU Memory: 0.68GB / 1.77GB\n",
806
+ "step 9,900 | loss: 0.6953 | lr: 6.00e-05 | dt: 683.24ms | tok/sec: 5994.98 | norm: 5.27\n",
807
+ " \n",
808
+ "GPU Memory: 0.68GB / 1.77GB\n",
809
+ "step 10,000 | loss: 0.5863 | lr: 6.00e-05 | dt: 684.74ms | tok/sec: 5981.86 | norm: 3.94\n",
810
+ " \n",
811
+ "GPU Memory: 0.68GB / 1.77GB\n",
812
+ "step 10,100 | loss: 0.5372 | lr: 6.00e-05 | dt: 687.72ms | tok/sec: 5955.88 | norm: 3.74\n",
813
+ " \n",
814
+ "GPU Memory: 0.67GB / 1.77GB\n",
815
+ "step 10,200 | loss: 0.6054 | lr: 6.00e-05 | dt: 685.72ms | tok/sec: 5973.31 | norm: 5.71\n",
816
+ " \n",
817
+ "GPU Memory: 0.68GB / 1.77GB\n",
818
+ "step 10,300 | loss: 0.5850 | lr: 6.00e-05 | dt: 686.01ms | tok/sec: 5970.77 | norm: 4.36\n",
819
+ " \n",
820
+ "GPU Memory: 0.68GB / 1.77GB\n",
821
+ "step 10,400 | loss: 0.3319 | lr: 6.00e-05 | dt: 684.77ms | tok/sec: 5981.53 | norm: 4.68\n",
822
+ " \n",
823
+ "GPU Memory: 0.68GB / 1.77GB\n",
824
+ "step 10,500 | loss: 0.4140 | lr: 6.00e-05 | dt: 684.41ms | tok/sec: 5984.70 | norm: 3.21\n",
825
+ " \n",
826
+ "GPU Memory: 0.68GB / 1.77GB\n",
827
+ "step 10,600 | loss: 0.4008 | lr: 6.00e-05 | dt: 683.34ms | tok/sec: 5994.10 | norm: 3.58\n",
828
+ " \n",
829
+ "GPU Memory: 0.68GB / 1.77GB\n",
830
+ "step 10,700 | loss: 0.3951 | lr: 6.00e-05 | dt: 685.49ms | tok/sec: 5975.26 | norm: 3.81\n",
831
+ " \n",
832
+ "GPU Memory: 0.68GB / 1.77GB\n",
833
+ "step 10,800 | loss: 0.3022 | lr: 6.00e-05 | dt: 687.40ms | tok/sec: 5958.64 | norm: 3.06\n",
834
+ " \n",
835
+ "GPU Memory: 0.68GB / 1.77GB\n",
836
+ "step 10,900 | loss: 0.4287 | lr: 6.00e-05 | dt: 686.75ms | tok/sec: 5964.31 | norm: 3.60\n",
837
+ " \n",
838
+ "GPU Memory: 0.68GB / 1.77GB\n",
839
+ "step 11,000 | loss: 0.2447 | lr: 6.00e-05 | dt: 687.35ms | tok/sec: 5959.12 | norm: 3.35\n",
840
+ " \n",
841
+ "GPU Memory: 0.68GB / 1.77GB\n",
842
+ "step 11,100 | loss: 0.2773 | lr: 6.00e-05 | dt: 688.83ms | tok/sec: 5946.35 | norm: 2.71\n",
843
+ " \n",
844
+ "GPU Memory: 0.67GB / 1.77GB\n",
845
+ "step 11,200 | loss: 0.2839 | lr: 6.00e-05 | dt: 687.56ms | tok/sec: 5957.31 | norm: 3.90\n",
846
+ " \n",
847
+ "GPU Memory: 0.68GB / 1.77GB\n",
848
+ "step 11,300 | loss: 0.3481 | lr: 6.00e-05 | dt: 684.68ms | tok/sec: 5982.32 | norm: 3.68\n",
849
+ " \n",
850
+ "GPU Memory: 0.78GB / 1.77GB\n",
851
+ "step 11,400 | loss: 0.1913 | lr: 6.00e-05 | dt: 685.73ms | tok/sec: 5973.18 | norm: 2.93\n",
852
+ " \n",
853
+ "GPU Memory: 0.68GB / 1.77GB\n",
854
+ "step 11,500 | loss: 0.2605 | lr: 6.00e-05 | dt: 685.74ms | tok/sec: 5973.11 | norm: 2.96\n",
855
+ " \n",
856
+ "GPU Memory: 0.68GB / 1.77GB\n",
857
+ "step 11,600 | loss: 0.2029 | lr: 6.00e-05 | dt: 689.04ms | tok/sec: 5944.49 | norm: 2.84\n",
858
+ " \n",
859
+ "\n",
860
+ "Reached target loss! Final loss: 0.0889 at step 11,663\n",
861
+ "Model saved to gpt_model.pt\n"
862
+ ]
863
+ }
864
+ ],
865
+ "source": [
866
+ "# SEED\n",
867
+ "torch.manual_seed(1337)\n",
868
+ "if torch.cuda.is_available():\n",
869
+ " torch.cuda.manual_seed(1337)\n",
870
+ "\n",
871
+ "# STOP\n",
872
+ "num_return_sequences = 5\n",
873
+ "max_length = 30\n",
874
+ "\n",
875
+ "# Mixed precision and model\n",
876
+ "torch.set_float32_matmul_precision('high')\n",
877
+ "model = GPT(GPTConfig())\n",
878
+ "model.to(device)\n",
879
+ "model = torch.compile(model)\n",
880
+ "\n",
881
+ "# CODE UPDATE HERE\n",
882
+ "max_lr = 6e-4\n",
883
+ "min_lr = max_lr * 0.1\n",
884
+ "warmup_steps = 10\n",
885
+ "max_steps = 50\n",
886
+ "\n",
887
+ "# Dataloader\n",
888
+ "train_loader = DataLoaderLite(B = 4, T = 256)\n",
889
+ "\n",
890
+ "# Optimizer\n",
891
+ "optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device)\n",
892
+ "step = 0\n",
893
+ "\n",
894
+ "# Add memory optimization settings\n",
895
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
896
+ "torch.backends.cudnn.allow_tf32 = True\n",
897
+ "torch.set_float32_matmul_precision('high')\n",
898
+ "\n",
899
+ "# In the training loop, add gradient accumulation\n",
900
+ "gradient_accumulation_steps = 4 # Accumulate gradients over 4 steps\n",
901
+ "\n",
902
+ "while True:\n",
903
+ " t0 = time.time()\n",
904
+ "\n",
905
+ " # Reset gradients at the start of accumulation\n",
906
+ " optimizer.zero_grad()\n",
907
+ "\n",
908
+ " # Accumulate gradients\n",
909
+ " for _ in range(gradient_accumulation_steps):\n",
910
+ " x, y = train_loader.next_batch()\n",
911
+ " x, y = x.to(device), y.to(device)\n",
912
+ "\n",
913
+ " with torch.autocast(device_type=device, dtype=torch.bfloat16):\n",
914
+ " logits, loss = model(x, y)\n",
915
+ " # Scale loss by accumulation steps\n",
916
+ " loss = loss / gradient_accumulation_steps\n",
917
+ "\n",
918
+ " loss.backward()\n",
919
+ "\n",
920
+ " # Clip gradients and update weights once per accumulation\n",
921
+ " norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
922
+ "\n",
923
+ " lr = get_lr(step)\n",
924
+ " for param_group in optimizer.param_groups:\n",
925
+ " param_group['lr'] = lr\n",
926
+ "\n",
927
+ " optimizer.step()\n",
928
+ "\n",
929
+ " if device == 'cuda' and step % 100 == 0: # Print memory every 100 steps\n",
930
+ " torch.cuda.synchronize()\n",
931
+ " print(f\"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB\")\n",
932
+ "\n",
933
+ " t1 = time.time()\n",
934
+ " dt = (t1 - t0) * 1000\n",
935
+ " tokens_per_sec = (train_loader.B * train_loader.T * gradient_accumulation_steps) / (t1 - t0)\n",
936
+ "\n",
937
+ " if step % 100 == 0: # Print details every 100 steps\n",
938
+ " print(f'step {step:,} | loss: {loss.item()*gradient_accumulation_steps:.4f} | lr: {lr:.2e} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec:.2f} | norm: {norm:.2f}')\n",
939
+ " print(\" \")\n",
940
+ "\n",
941
+ " actual_loss = loss.item() * gradient_accumulation_steps\n",
942
+ " if actual_loss < 0.09:\n",
943
+ " print(f'\\nReached target loss! Final loss: {actual_loss:.4f} at step {step:,}')\n",
944
+ " save_path = 'gpt_model.pt'\n",
945
+ " torch.save(model.state_dict(), save_path)\n",
946
+ " print(f\"Model saved to {save_path}\")\n",
947
+ " break\n",
948
+ "\n",
949
+ " step += 1"
950
+ ]
951
+ },
952
+ {
953
+ "cell_type": "markdown",
954
+ "metadata": {
955
+ "id": "ZmZ4Yk9LQehd"
956
+ },
957
+ "source": [
958
+ "## Save Model"
959
+ ]
960
+ },
961
+ {
962
+ "cell_type": "code",
963
+ "execution_count": 10,
964
+ "metadata": {
965
+ "colab": {
966
+ "base_uri": "https://localhost:8080/"
967
+ },
968
+ "id": "ByO3iW55Qehd",
969
+ "outputId": "81185c36-fd16-416e-e439-64788940776f"
970
+ },
971
+ "outputs": [
972
+ {
973
+ "output_type": "stream",
974
+ "name": "stdout",
975
+ "text": [
976
+ "Total model parameters: 33,709,824\n",
977
+ "Model saved to nano_gpt_model.pt\n"
978
+ ]
979
+ }
980
+ ],
981
+ "source": [
982
+ "# Print total model parameters\n",
983
+ "total_params = sum(p.numel() for p in model.parameters())\n",
984
+ "print(f\"Total model parameters: {total_params:,}\")\n",
985
+ "\n",
986
+ "# Save the model\n",
987
+ "save_path = 'nano_gpt_model.pt'\n",
988
+ "torch.save(model.state_dict(), save_path)\n",
989
+ "print(f\"Model saved to {save_path}\")"
990
+ ]
991
+ },
992
+ {
993
+ "cell_type": "markdown",
994
+ "metadata": {
995
+ "id": "Tj9Rs-dysuOg"
996
+ },
997
+ "source": [
998
+ "## Inference"
999
+ ]
1000
+ },
1001
+ {
1002
+ "cell_type": "code",
1003
+ "execution_count": 11,
1004
+ "metadata": {
1005
+ "id": "CE2_CV1TcttD",
1006
+ "colab": {
1007
+ "base_uri": "https://localhost:8080/"
1008
+ },
1009
+ "outputId": "420c8d5e-7c65-4f6c-bc1e-467350ea6468"
1010
+ },
1011
+ "outputs": [
1012
+ {
1013
+ "output_type": "stream",
1014
+ "name": "stdout",
1015
+ "text": [
1016
+ "\n",
1017
+ "Generating text samples...\n"
1018
+ ]
1019
+ },
1020
+ {
1021
+ "output_type": "stream",
1022
+ "name": "stderr",
1023
+ "text": [
1024
+ "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
1025
+ " warnings.warn(\n"
1026
+ ]
1027
+ },
1028
+ {
1029
+ "output_type": "stream",
1030
+ "name": "stdout",
1031
+ "text": [
1032
+ "\n",
1033
+ "Generated text:\n",
1034
+ "\n",
1035
+ "Once upon a time to not;\n",
1036
+ "More slaughter'd, sweet Rivers, I receive my children, and title with pardon hither\n",
1037
+ "That one stuff'd with a conquest; and teeth, of my? Why, in life thee,\n",
1038
+ "Which now not joy of foe, thought o'n slaughter bed,\n",
1039
+ "And, is mine own soul me, not so heavy in every day:\n",
1040
+ "The tyrant from one curst my death lies;\n",
1041
+ "For the ground is nothing henceforth fell executioner come\n"
1042
+ ]
1043
+ }
1044
+ ],
1045
+ "source": [
1046
+ "# Text generation\n",
1047
+ "print(\"\\nGenerating text samples...\")\n",
1048
+ "enc = tiktoken.get_encoding('gpt2')\n",
1049
+ "context = \"Once upon a time\"\n",
1050
+ "x = torch.tensor([enc.encode(context)], dtype=torch.long, device=device)\n",
1051
+ "\n",
1052
+ "max_length = 100 # Generate 100 tokens\n",
1053
+ "torch.manual_seed(42)\n",
1054
+ "if torch.cuda.is_available():\n",
1055
+ " torch.cuda.manual_seed(42)\n",
1056
+ "\n",
1057
+ "while x.size(1) < max_length:\n",
1058
+ " with torch.no_grad():\n",
1059
+ " with torch.autocast(device_type=device, dtype=torch.bfloat16):\n",
1060
+ " logits = model(x)[0]\n",
1061
+ " logits = logits[:, -1, :]\n",
1062
+ " probs = F.softmax(logits, dim=-1)\n",
1063
+ " topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)\n",
1064
+ " ix = torch.multinomial(topk_probs, num_samples=1)\n",
1065
+ " xcol = torch.gather(topk_indices, -1, ix)\n",
1066
+ " x = torch.cat([x, xcol], dim=1)\n",
1067
+ "\n",
1068
+ "print(\"\\nGenerated text:\")\n",
1069
+ "tokens = x[0].tolist() # Take first sequence\n",
1070
+ "decoded = enc.decode(tokens)\n",
1071
+ "print(f\"\\n{decoded}\")"
1072
+ ]
1073
+ }
1074
+ ],
1075
+ "metadata": {
1076
+ "accelerator": "GPU",
1077
+ "colab": {
1078
+ "gpuType": "T4",
1079
+ "provenance": []
1080
+ },
1081
+ "kernelspec": {
1082
+ "display_name": "Python 3",
1083
+ "name": "python3"
1084
+ },
1085
+ "language_info": {
1086
+ "name": "python"
1087
+ }
1088
+ },
1089
+ "nbformat": 4,
1090
+ "nbformat_minor": 0
1091
+ }
assets/LLMfromScratch2.png ADDED
input.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ gradio>=3.50.0
3
+ tiktoken>=0.5.1
4
+ numpy>=1.24.0
5
+ jupyter>=1.0.0