{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "LzOshnYOpqP_" }, "source": [ "# NanoGPT\n", "\n", "Training a decoder-only model (NanoGPT) to genereate text in Shakespear stype" ] }, { "cell_type": "markdown", "metadata": { "id": "J93p7rk7qK-P" }, "source": [ "## Install Modules" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8MnYOQ4xcKXa", "outputId": "0b15787f-f2a8-4ba6-d744-a60d08f9d152" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/1.2 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.2/1.2 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m17.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h" ] } ], "source": [ "# Tiktoken for tokenization\n", "!pip install tiktoken --quiet" ] }, { "cell_type": "markdown", "metadata": { "id": "uHuCVeKYqWvo" }, "source": [ "## Import Modules" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "igAl5bXSqZWo" }, "outputs": [], "source": [ "# Standard Library Imports\n", "import os\n", "import math\n", "import time\n", "import inspect\n", "from dataclasses import dataclass\n", "\n", "# Third-Party Imports\n", "import tiktoken\n", "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F" ] }, { "cell_type": "markdown", "metadata": { "id": "6QNqCz2fqSX0" }, "source": [ "## Transformer Achitecture" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "jgXGx_-YqVNH" }, "outputs": [], "source": [ "class CausalSelfAttention(nn.Module):\n", "\n", " def __init__(self, config):\n", " super().__init__()\n", " assert config.n_embd % config.n_head == 0\n", " # key, query, value projections for all heads, but in a batch\n", " self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)\n", " # output projection\n", " self.c_proj = nn.Linear(config.n_embd, config.n_embd)\n", " self.c_proj.NANGPT_SCALE_INIT = 1\n", " # regularization\n", " self.n_head = config.n_head\n", " self.n_embd = config.n_embd\n", " self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))\n", "\n", " def forward(self, x):\n", " B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)\n", " # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n", " # nh is \"number of heads\", hs is \"head size\", and C (number of channels) = nh * hs\n", " # e.g. 
{ "cell_type": "markdown", "metadata": { "id": "6QNqCz2fqSX0" }, "source": [ "## Transformer Architecture" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "id": "jgXGx_-YqVNH" }, "outputs": [], "source": [
"class CausalSelfAttention(nn.Module):\n",
"\n",
"    def __init__(self, config):\n",
"        super().__init__()\n",
"        assert config.n_embd % config.n_head == 0\n",
"        # key, query, value projections for all heads, but in a batch\n",
"        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)\n",
"        # output projection\n",
"        self.c_proj = nn.Linear(config.n_embd, config.n_embd)\n",
"        self.c_proj.NANOGPT_SCALE_INIT = 1\n",
"        # regularization\n",
"        self.n_head = config.n_head\n",
"        self.n_embd = config.n_embd\n",
"        self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))\n",
"\n",
"    def forward(self, x):\n",
"        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)\n",
"        # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n",
"        # nh is \"number of heads\", hs is \"head size\", and C (number of channels) = nh * hs\n",
"        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer\n",
"        qkv = self.c_attn(x)\n",
"        q, k, v = qkv.split(self.n_embd, dim=2)\n",
"        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n",
"        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n",
"        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n",
"\n",
"        # Equivalent manual attention, kept for reference:\n",
"        # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n",
"        # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))\n",
"        # att = F.softmax(att, dim=-1)\n",
"        # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n",
"\n",
"        y = F.scaled_dot_product_attention(q, k, v, is_causal = True) # Flash attention\n",
"\n",
"        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side\n",
"        # output projection\n",
"        y = self.c_proj(y)\n",
"        return y\n",
"\n",
"\n",
"class MLP(nn.Module):\n",
"\n",
"    def __init__(self, config):\n",
"        super().__init__()\n",
"        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)\n",
"        self.gelu = nn.GELU(approximate='tanh')\n",
"        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)\n",
"        self.c_proj.NANOGPT_SCALE_INIT = 1\n",
"\n",
"    def forward(self, x):\n",
"        x = self.c_fc(x)\n",
"        x = self.gelu(x)\n",
"        x = self.c_proj(x)\n",
"        return x\n",
"\n",
"class Block(nn.Module):\n",
"\n",
"    def __init__(self, config):\n",
"        super().__init__()\n",
"        self.ln_1 = nn.LayerNorm(config.n_embd)\n",
"        self.attn = CausalSelfAttention(config)\n",
"        self.ln_2 = nn.LayerNorm(config.n_embd)\n",
"        self.mlp = MLP(config)\n",
"\n",
"    def forward(self, x):\n",
"        x = x + self.attn(self.ln_1(x))\n",
"        x = x + self.mlp(self.ln_2(x))\n",
"        return x\n",
"\n",
"class GPT(nn.Module):\n",
"\n",
"    def __init__(self, config):\n",
"        super().__init__()\n",
"        self.config = config\n",
"        self.gradient_checkpointing = True\n",
"\n",
"        self.transformer = nn.ModuleDict(dict(\n",
"            wte = nn.Embedding(config.vocab_size, config.n_embd),\n",
"            wpe = nn.Embedding(config.block_size, config.n_embd),\n",
"            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),\n",
"            ln_f = nn.LayerNorm(config.n_embd),\n",
"        ))\n",
"        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
"\n",
"        # weight sharing between the token embedding and the output head\n",
"        self.transformer.wte.weight = self.lm_head.weight\n",
"\n",
"        # weight initialization\n",
"        self.apply(self._init_weights)\n",
"\n",
"    def _init_weights(self, module):\n",
"        if isinstance(module, nn.Linear):\n",
"            std = 0.02\n",
"            if hasattr(module, 'NANOGPT_SCALE_INIT'):\n",
"                std *= (2 * self.config.n_layer) ** -0.5\n",
"            torch.nn.init.normal_(module.weight, mean = 0.0, std = std)\n",
"            if module.bias is not None:\n",
"                torch.nn.init.zeros_(module.bias)\n",
"        elif isinstance(module, nn.Embedding):\n",
"            torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)\n",
"\n",
"    def forward(self, idx, targets=None):\n",
"        # idx is of shape (B, T)\n",
"        B, T = idx.size()\n",
"        assert T <= self.config.block_size, f\"Cannot forward sequence of length {T}, block size is only {self.config.block_size}\"\n",
"        # forward the token and position embeddings\n",
"        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)\n",
"        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)\n",
"        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)\n",
"        x = tok_emb + pos_emb\n",
"        # forward the transformer blocks, with gradient checkpointing during training to save memory\n",
"        if self.gradient_checkpointing and self.training:\n",
"            for block in self.transformer.h:\n",
"                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)\n",
"        else:\n",
"            for block in self.transformer.h:\n",
"                x = block(x)\n",
"\n",
"        x = self.transformer.ln_f(x)\n",
"        logits = self.lm_head(x)\n",
"\n",
"        loss = None\n",
"        if targets is not None:\n",
"            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))\n",
"        return logits, loss\n",
"\n",
"    @classmethod\n",
"    def from_pretrained(cls, model_type):\n",
"        \"\"\"Loads pretrained GPT-2 model weights from huggingface\"\"\"\n",
"        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}\n",
"        from transformers import GPT2LMHeadModel\n",
"        print(\"loading weights from pretrained gpt: %s\" % model_type)\n",
"\n",
"        # n_layer, n_head and n_embd are determined from model_type\n",
"        config_args = {\n",
"            'gpt2':        dict(n_layer=12, n_head=12, n_embd=768),  # 124M params\n",
"            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n",
"            'gpt2-large':  dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n",
"            'gpt2-xl':     dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n",
"        }[model_type]\n",
"        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints\n",
"        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints\n",
"        # create a from-scratch initialized minGPT model\n",
"        config = GPTConfig(**config_args)\n",
"        model = GPT(config)\n",
"        sd = model.state_dict()\n",
"        sd_keys = sd.keys()\n",
"        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param\n",
"\n",
"        # init a huggingface/transformers model\n",
"        model_hf = GPT2LMHeadModel.from_pretrained(model_type)\n",
"        sd_hf = model_hf.state_dict()\n",
"\n",
"        # copy while ensuring all of the parameters are aligned and match in names and shapes\n",
"        sd_keys_hf = sd_hf.keys()\n",
"        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer\n",
"        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)\n",
"        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']\n",
"        # basically the openai checkpoints use a \"Conv1D\" module, but we only want to use a vanilla Linear\n",
"        # this means that we have to transpose these weights when we import them\n",
"        assert len(sd_keys_hf) == len(sd_keys), f\"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}\"\n",
"        for k in sd_keys_hf:\n",
"            if any(k.endswith(w) for w in transposed):\n",
"                # special treatment for the Conv1D weights we need to transpose\n",
"                assert sd_hf[k].shape[::-1] == sd[k].shape\n",
"                with torch.no_grad():\n",
"                    sd[k].copy_(sd_hf[k].t())\n",
"            else:\n",
"                # vanilla copy over the other parameters\n",
"                assert sd_hf[k].shape == sd[k].shape\n",
"                with torch.no_grad():\n",
"                    sd[k].copy_(sd_hf[k])\n",
"\n",
"        return model\n",
"\n",
"    def configure_optimizers(self, weight_decay, learning_rate, device_type):\n",
"        # start with all of the candidate parameters (that require grad)\n",
"        param_dict = {pn: p for pn, p in self.named_parameters()}\n",
"        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}\n",
"        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.\n",
"        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.\n",
"        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\n",
"        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\n",
"        optim_groups = [\n",
"            {'params': decay_params, 'weight_decay': weight_decay},\n",
"            {'params': nodecay_params, 'weight_decay': 0.0}\n",
"        ]\n",
"        num_decay_params = sum(p.numel() for p in decay_params)\n",
"        num_nodecay_params = sum(p.numel() for p in nodecay_params)\n",
"\n",
"        print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n",
"        print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n",
"        # Create AdamW optimizer and use the fused version if it is available\n",
"        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n",
"        use_fused = fused_available and device_type == \"cuda\"\n",
"\n",
"        print(f\"using fused AdamW: {use_fused}\")\n",
"        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)\n",
"        return optimizer" ] },
{ "cell_type": "markdown", "metadata": { "id": "C3Xev9Ycq7Qe" }, "source": [ "## Configuration Parameters" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "id": "27h8bNasq-Xe" }, "outputs": [], "source": [
"@dataclass\n",
"class GPTConfig:\n",
"    block_size: int = 512   # max sequence length\n",
"    vocab_size: int = 50304 # GPT-2 vocabulary (50,257 = 50,000 BPE merges + 256 byte tokens + <|endoftext|>) padded up to a multiple of 128\n",
"    n_layer: int = 8        # number of layers\n",
"    n_head: int = 8         # number of heads\n",
"    n_embd: int = 384       # embedding dimension" ] },
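{ "cell_type": "markdown", "metadata": {}, "source": [ "## Parameter Count Check\n", "\n", "Illustrative cell added during editing (not executed in the recorded run): a by-hand count of the parameters implied by `GPTConfig`. With the token embedding tied to `lm_head`, the total should match the 33,709,824 parameters reported in the Save Model section." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Illustrative check (added during editing; not executed in the recorded run):\n",
"# count the parameters implied by GPTConfig by hand.\n",
"cfg = GPTConfig()\n",
"n, L = cfg.n_embd, cfg.n_layer\n",
"# per block: 12*n^2 weights (c_attn 3n^2, attn c_proj n^2, c_fc 4n^2, mlp c_proj 4n^2)\n",
"# plus 13*n for biases and the two LayerNorms\n",
"per_block = 12 * n * n + 13 * n\n",
"# token table (tied with lm_head, counted once) + position table + blocks + final LayerNorm\n",
"total = cfg.vocab_size * n + cfg.block_size * n + L * per_block + 2 * n\n",
"print(f\"expected parameters: {total:,}\")  # 33,709,824 for the defaults above" ] },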
{ "cell_type": "markdown", "metadata": { "id": "GHpQZ_avrMyd" }, "source": [ "## DataLoader" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "id": "NcXgug8hrPO-" }, "outputs": [], "source": [
"class DataLoaderLite:\n",
"    def __init__(self, B, T):\n",
"        self.B = B\n",
"        self.T = T\n",
"\n",
"        # Path to the training text; update this if your copy lives elsewhere (e.g. on Drive)\n",
"        input_path = 'input.txt'\n",
"        try:\n",
"            with open(input_path, 'r', encoding='utf-8') as f:\n",
"                text = f.read()\n",
"            print(f\"Successfully loaded text from {input_path}\")\n",
"        except Exception as e:\n",
"            print(f\"Error loading file: {e}\")\n",
"            raise\n",
"\n",
"        enc = tiktoken.get_encoding('gpt2')\n",
"        tokens = enc.encode(text)\n",
"        self.tokens = torch.tensor(tokens)\n",
"        print(f'Loaded {len(self.tokens):,} tokens')\n",
"        print(f'1 epoch = {len(self.tokens) // (B * T):,} batches')\n",
"        print(f'Input text size: {len(text):,} characters')\n",
"\n",
"        self.current_position = 0\n",
"\n",
"    def next_batch(self):\n",
"        B, T = self.B, self.T\n",
"        buf = self.tokens[self.current_position: self.current_position + B * T + 1]\n",
"        x = (buf[:-1]).view(B, T) # inputs\n",
"        y = (buf[1:]).view(B, T) # targets\n",
"        # advance the position in the tensor\n",
"        self.current_position += B*T\n",
"        # if loading the next batch would be out of bounds, reset\n",
"        if self.current_position + (B * T + 1) > len(self.tokens):\n",
"            self.current_position = 0\n",
"        return x, y" ] },
{ "cell_type": "markdown", "metadata": { "id": "y0eEHPNrrz6b" }, "source": [ "## Device Configuration" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "id": "w_OS901rr3SO", "colab": { "base_uri": "https://localhost:8080/" }, "outputId":
"0c386cef-8cc3-4964-916a-178cd5668959" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "CUDA available: True\n", "CUDA device: Tesla T4\n", "Total GPU memory: 15.84 GB\n", "Using device: cuda\n" ] } ], "source": [ "# Add CUDA check and memory info\n", "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", "if torch.cuda.is_available():\n", " print(f\"CUDA device: {torch.cuda.get_device_name(0)}\")\n", " print(f\"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\")\n", "\n", "# Modify the device selection\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "p74YZyRe_3_Z" }, "source": [ "## Utilities" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "q1Q1T5az_7Pb" }, "outputs": [], "source": [ "def get_lr(it):\n", " if it < warmup_steps:\n", " return max_lr * (it + 1) / warmup_steps\n", " if it > max_steps:\n", " return min_lr\n", " decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)\n", " assert 0 <= decay_ratio <=1\n", " coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))\n", " return min_lr + coeff * (max_lr - min_lr)" ] }, { "cell_type": "markdown", "metadata": { "id": "veEWTXgVsIVz" }, "source": [ "## Model Training" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NjuBy6HGb9Ta", "outputId": "78a7feb8-c362-4343-a3d1-8a1cbf744491" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Successfully loaded text from input.txt\n", "Loaded 338,025 tokens\n", "1 epoch = 330 batches\n", "Input text size: 1,115,394 characters\n", "num decayed parameter tensors: 34, with 33,669,120 parameters\n", "num non-decayed parameter tensors: 66, with 40,704 parameters\n", "using fused AdamW: True\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:1604: UserWarning: Tesla T4 does not support bfloat16 compilation natively, skipping\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:1604: UserWarning: Tesla T4 does not support bfloat16 compilation natively, skipping\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:1604: UserWarning: Tesla T4 does not support bfloat16 compilation natively, skipping\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:1604: UserWarning: Tesla T4 does not support bfloat16 compilation natively, skipping\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:632: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. 
Refer to docs for more details on the differences between the two variants.\n", " return fn(*args, **kwargs)\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "GPU Memory: 0.68GB / 1.55GB\n", "step 0 | loss: 10.9470 | lr: 6.00e-05 | dt: 23182.17ms | tok/sec: 176.69 | norm: 5.03\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 100 | loss: 5.9936 | lr: 6.00e-05 | dt: 687.63ms | tok/sec: 5956.71 | norm: 0.90\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 200 | loss: 6.1430 | lr: 6.00e-05 | dt: 680.82ms | tok/sec: 6016.26 | norm: 0.85\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 300 | loss: 5.7334 | lr: 6.00e-05 | dt: 684.30ms | tok/sec: 5985.65 | norm: 1.32\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 400 | loss: 5.5456 | lr: 6.00e-05 | dt: 680.26ms | tok/sec: 6021.24 | norm: 1.12\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 500 | loss: 4.7860 | lr: 6.00e-05 | dt: 682.16ms | tok/sec: 6004.47 | norm: 1.53\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 600 | loss: 5.4306 | lr: 6.00e-05 | dt: 683.22ms | tok/sec: 5995.14 | norm: 1.23\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 700 | loss: 5.3188 | lr: 6.00e-05 | dt: 683.26ms | tok/sec: 5994.79 | norm: 1.21\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 800 | loss: 5.1400 | lr: 6.00e-05 | dt: 686.78ms | tok/sec: 5964.07 | norm: 1.84\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 900 | loss: 4.9057 | lr: 6.00e-05 | dt: 680.29ms | tok/sec: 6020.97 | norm: 1.59\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 1,000 | loss: 4.8926 | lr: 6.00e-05 | dt: 684.95ms | tok/sec: 5980.04 | norm: 1.48\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 1,100 | loss: 4.7693 | lr: 6.00e-05 | dt: 685.51ms | tok/sec: 5975.12 | norm: 2.60\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 1,200 | loss: 5.0689 | lr: 6.00e-05 | dt: 682.40ms | tok/sec: 6002.35 | norm: 1.57\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 1,300 | loss: 4.5868 | lr: 6.00e-05 | dt: 682.93ms | tok/sec: 5997.68 | norm: 1.86\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 1,400 | loss: 4.8036 | lr: 6.00e-05 | dt: 682.67ms | tok/sec: 5999.99 | norm: 1.92\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 1,500 | loss: 4.5739 | lr: 6.00e-05 | dt: 682.74ms | tok/sec: 5999.33 | norm: 1.89\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 1,600 | loss: 4.3767 | lr: 6.00e-05 | dt: 681.53ms | tok/sec: 6009.97 | norm: 2.48\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 1,700 | loss: 4.7170 | lr: 6.00e-05 | dt: 682.21ms | tok/sec: 6004.05 | norm: 2.88\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 1,800 | loss: 4.5040 | lr: 6.00e-05 | dt: 681.62ms | tok/sec: 6009.23 | norm: 2.25\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 1,900 | loss: 4.4489 | lr: 6.00e-05 | dt: 686.54ms | tok/sec: 5966.18 | norm: 2.01\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 2,000 | loss: 3.9155 | lr: 6.00e-05 | dt: 683.89ms | tok/sec: 5989.27 | norm: 2.26\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 2,100 | loss: 4.1443 | lr: 6.00e-05 | dt: 681.80ms | tok/sec: 6007.60 | norm: 2.33\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 2,200 | loss: 4.6614 | lr: 6.00e-05 | dt: 681.72ms | tok/sec: 6008.37 | norm: 2.44\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 2,300 | loss: 3.6722 | lr: 6.00e-05 | dt: 684.94ms | tok/sec: 5980.11 | norm: 2.75\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 2,400 | loss: 4.0311 | lr: 6.00e-05 | dt: 684.12ms | tok/sec: 5987.25 | norm: 2.90\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 2,500 | loss: 4.1794 | lr: 6.00e-05 | 
dt: 686.70ms | tok/sec: 5964.77 | norm: 2.67\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 2,600 | loss: 3.7992 | lr: 6.00e-05 | dt: 681.41ms | tok/sec: 6011.09 | norm: 3.10\n", " \n", "GPU Memory: 0.67GB / 1.77GB\n", "step 2,700 | loss: 4.5840 | lr: 6.00e-05 | dt: 682.54ms | tok/sec: 6001.09 | norm: 3.52\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 2,800 | loss: 3.6126 | lr: 6.00e-05 | dt: 682.40ms | tok/sec: 6002.35 | norm: 2.92\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 2,900 | loss: 3.6538 | lr: 6.00e-05 | dt: 681.88ms | tok/sec: 6006.93 | norm: 3.04\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 3,000 | loss: 3.4119 | lr: 6.00e-05 | dt: 684.70ms | tok/sec: 5982.20 | norm: 2.96\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 3,100 | loss: 3.4817 | lr: 6.00e-05 | dt: 686.39ms | tok/sec: 5967.42 | norm: 3.22\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 3,200 | loss: 3.7793 | lr: 6.00e-05 | dt: 682.57ms | tok/sec: 6000.86 | norm: 4.25\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 3,300 | loss: 3.6155 | lr: 6.00e-05 | dt: 685.45ms | tok/sec: 5975.64 | norm: 4.07\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 3,400 | loss: 3.7451 | lr: 6.00e-05 | dt: 686.79ms | tok/sec: 5963.94 | norm: 3.34\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 3,500 | loss: 3.5069 | lr: 6.00e-05 | dt: 686.91ms | tok/sec: 5962.97 | norm: 4.46\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 3,600 | loss: 3.5325 | lr: 6.00e-05 | dt: 683.46ms | tok/sec: 5993.08 | norm: 4.59\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 3,700 | loss: 3.6000 | lr: 6.00e-05 | dt: 683.37ms | tok/sec: 5993.87 | norm: 4.69\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 3,800 | loss: 2.7920 | lr: 6.00e-05 | dt: 681.15ms | tok/sec: 6013.37 | norm: 3.43\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 3,900 | loss: 3.5265 | lr: 6.00e-05 | dt: 684.23ms | tok/sec: 5986.29 | norm: 3.96\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 4,000 | loss: 3.5139 | lr: 6.00e-05 | dt: 680.86ms | tok/sec: 6015.93 | norm: 4.81\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 4,100 | loss: 3.4394 | lr: 6.00e-05 | dt: 680.96ms | tok/sec: 6015.00 | norm: 4.55\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 4,200 | loss: 3.1918 | lr: 6.00e-05 | dt: 682.53ms | tok/sec: 6001.21 | norm: 5.20\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 4,300 | loss: 3.2290 | lr: 6.00e-05 | dt: 684.99ms | tok/sec: 5979.67 | norm: 4.89\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 4,400 | loss: 3.0279 | lr: 6.00e-05 | dt: 678.78ms | tok/sec: 6034.36 | norm: 4.00\n", " \n", "GPU Memory: 0.67GB / 1.77GB\n", "step 4,500 | loss: 3.3566 | lr: 6.00e-05 | dt: 680.69ms | tok/sec: 6017.44 | norm: 4.92\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 4,600 | loss: 2.9366 | lr: 6.00e-05 | dt: 683.88ms | tok/sec: 5989.36 | norm: 4.65\n", " \n", "GPU Memory: 0.67GB / 1.77GB\n", "step 4,700 | loss: 3.1994 | lr: 6.00e-05 | dt: 679.04ms | tok/sec: 6032.08 | norm: 7.23\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 4,800 | loss: 2.9273 | lr: 6.00e-05 | dt: 681.97ms | tok/sec: 6006.12 | norm: 3.98\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 4,900 | loss: 2.8992 | lr: 6.00e-05 | dt: 686.39ms | tok/sec: 5967.48 | norm: 5.38\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 5,000 | loss: 2.9706 | lr: 6.00e-05 | dt: 685.26ms | tok/sec: 5977.31 | norm: 4.36\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 5,100 | loss: 2.7330 | lr: 6.00e-05 | dt: 682.23ms | tok/sec: 6003.81 | norm: 4.92\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 
5,200 | loss: 2.8838 | lr: 6.00e-05 | dt: 683.10ms | tok/sec: 5996.15 | norm: 6.76\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 5,300 | loss: 2.3963 | lr: 6.00e-05 | dt: 685.79ms | tok/sec: 5972.66 | norm: 4.91\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 5,400 | loss: 2.4825 | lr: 6.00e-05 | dt: 684.18ms | tok/sec: 5986.69 | norm: 4.63\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 5,500 | loss: 3.0586 | lr: 6.00e-05 | dt: 685.38ms | tok/sec: 5976.21 | norm: 5.94\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 5,600 | loss: 2.2882 | lr: 6.00e-05 | dt: 684.35ms | tok/sec: 5985.27 | norm: 5.28\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 5,700 | loss: 2.3943 | lr: 6.00e-05 | dt: 681.74ms | tok/sec: 6008.12 | norm: 5.19\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 5,800 | loss: 2.5011 | lr: 6.00e-05 | dt: 686.10ms | tok/sec: 5969.99 | norm: 5.41\n", " \n", "GPU Memory: 0.67GB / 1.77GB\n", "step 5,900 | loss: 2.3386 | lr: 6.00e-05 | dt: 683.09ms | tok/sec: 5996.32 | norm: 5.78\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 6,000 | loss: 2.6910 | lr: 6.00e-05 | dt: 685.64ms | tok/sec: 5973.95 | norm: 5.91\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 6,100 | loss: 1.9940 | lr: 6.00e-05 | dt: 682.09ms | tok/sec: 6005.04 | norm: 4.81\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 6,200 | loss: 2.1706 | lr: 6.00e-05 | dt: 687.09ms | tok/sec: 5961.39 | norm: 6.18\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 6,300 | loss: 1.8759 | lr: 6.00e-05 | dt: 686.62ms | tok/sec: 5965.47 | norm: 4.45\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 6,400 | loss: 1.8825 | lr: 6.00e-05 | dt: 686.07ms | tok/sec: 5970.20 | norm: 5.26\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 6,500 | loss: 2.1047 | lr: 6.00e-05 | dt: 687.22ms | tok/sec: 5960.26 | norm: 5.17\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 6,600 | loss: 2.2490 | lr: 6.00e-05 | dt: 683.49ms | tok/sec: 5992.75 | norm: 6.13\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 6,700 | loss: 2.0222 | lr: 6.00e-05 | dt: 684.05ms | tok/sec: 5987.87 | norm: 4.83\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 6,800 | loss: 1.7948 | lr: 6.00e-05 | dt: 687.00ms | tok/sec: 5962.19 | norm: 5.72\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 6,900 | loss: 1.9430 | lr: 6.00e-05 | dt: 685.17ms | tok/sec: 5978.09 | norm: 6.76\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 7,000 | loss: 1.9375 | lr: 6.00e-05 | dt: 685.81ms | tok/sec: 5972.46 | norm: 5.95\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 7,100 | loss: 1.4104 | lr: 6.00e-05 | dt: 686.53ms | tok/sec: 5966.27 | norm: 5.39\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 7,200 | loss: 1.7128 | lr: 6.00e-05 | dt: 682.73ms | tok/sec: 5999.47 | norm: 4.76\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 7,300 | loss: 1.7015 | lr: 6.00e-05 | dt: 686.85ms | tok/sec: 5963.47 | norm: 4.74\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 7,400 | loss: 1.6215 | lr: 6.00e-05 | dt: 683.93ms | tok/sec: 5988.95 | norm: 6.07\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 7,500 | loss: 1.5474 | lr: 6.00e-05 | dt: 684.24ms | tok/sec: 5986.21 | norm: 5.89\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 7,600 | loss: 1.5799 | lr: 6.00e-05 | dt: 684.59ms | tok/sec: 5983.12 | norm: 5.08\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 7,700 | loss: 1.4209 | lr: 6.00e-05 | dt: 685.21ms | tok/sec: 5977.75 | norm: 4.84\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 7,800 | loss: 1.4405 | lr: 6.00e-05 | dt: 686.67ms | tok/sec: 5965.03 | norm: 4.81\n", " \n", 
"GPU Memory: 0.68GB / 1.77GB\n", "step 7,900 | loss: 1.1260 | lr: 6.00e-05 | dt: 685.51ms | tok/sec: 5975.15 | norm: 6.15\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 8,000 | loss: 1.6376 | lr: 6.00e-05 | dt: 685.51ms | tok/sec: 5975.14 | norm: 7.62\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 8,100 | loss: 1.2116 | lr: 6.00e-05 | dt: 684.44ms | tok/sec: 5984.46 | norm: 5.31\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 8,200 | loss: 1.2855 | lr: 6.00e-05 | dt: 686.78ms | tok/sec: 5964.02 | norm: 6.41\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 8,300 | loss: 1.2305 | lr: 6.00e-05 | dt: 686.39ms | tok/sec: 5967.49 | norm: 4.97\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 8,400 | loss: 1.1149 | lr: 6.00e-05 | dt: 685.69ms | tok/sec: 5973.56 | norm: 5.16\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 8,500 | loss: 1.4075 | lr: 6.00e-05 | dt: 685.20ms | tok/sec: 5977.81 | norm: 6.99\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 8,600 | loss: 0.8826 | lr: 6.00e-05 | dt: 682.54ms | tok/sec: 6001.15 | norm: 4.75\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 8,700 | loss: 0.9010 | lr: 6.00e-05 | dt: 684.27ms | tok/sec: 5985.97 | norm: 5.06\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 8,800 | loss: 1.2441 | lr: 6.00e-05 | dt: 687.10ms | tok/sec: 5961.29 | norm: 5.49\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 8,900 | loss: 0.8399 | lr: 6.00e-05 | dt: 683.32ms | tok/sec: 5994.27 | norm: 7.54\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 9,000 | loss: 0.7800 | lr: 6.00e-05 | dt: 686.11ms | tok/sec: 5969.91 | norm: 4.41\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 9,100 | loss: 0.8157 | lr: 6.00e-05 | dt: 685.21ms | tok/sec: 5977.69 | norm: 4.66\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 9,200 | loss: 0.7936 | lr: 6.00e-05 | dt: 684.71ms | tok/sec: 5982.08 | norm: 5.31\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 9,300 | loss: 1.0805 | lr: 6.00e-05 | dt: 684.98ms | tok/sec: 5979.70 | norm: 5.38\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 9,400 | loss: 0.5698 | lr: 6.00e-05 | dt: 682.32ms | tok/sec: 6003.03 | norm: 4.22\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 9,500 | loss: 0.6732 | lr: 6.00e-05 | dt: 683.20ms | tok/sec: 5995.35 | norm: 5.23\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 9,600 | loss: 0.4544 | lr: 6.00e-05 | dt: 685.71ms | tok/sec: 5973.34 | norm: 3.60\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 9,700 | loss: 0.4766 | lr: 6.00e-05 | dt: 682.04ms | tok/sec: 6005.54 | norm: 4.36\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 9,800 | loss: 0.6707 | lr: 6.00e-05 | dt: 685.59ms | tok/sec: 5974.41 | norm: 4.26\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 9,900 | loss: 0.6953 | lr: 6.00e-05 | dt: 683.24ms | tok/sec: 5994.98 | norm: 5.27\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 10,000 | loss: 0.5863 | lr: 6.00e-05 | dt: 684.74ms | tok/sec: 5981.86 | norm: 3.94\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 10,100 | loss: 0.5372 | lr: 6.00e-05 | dt: 687.72ms | tok/sec: 5955.88 | norm: 3.74\n", " \n", "GPU Memory: 0.67GB / 1.77GB\n", "step 10,200 | loss: 0.6054 | lr: 6.00e-05 | dt: 685.72ms | tok/sec: 5973.31 | norm: 5.71\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 10,300 | loss: 0.5850 | lr: 6.00e-05 | dt: 686.01ms | tok/sec: 5970.77 | norm: 4.36\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 10,400 | loss: 0.3319 | lr: 6.00e-05 | dt: 684.77ms | tok/sec: 5981.53 | norm: 4.68\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 10,500 | loss: 0.4140 | lr: 6.00e-05 | dt: 684.41ms 
| tok/sec: 5984.70 | norm: 3.21\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 10,600 | loss: 0.4008 | lr: 6.00e-05 | dt: 683.34ms | tok/sec: 5994.10 | norm: 3.58\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 10,700 | loss: 0.3951 | lr: 6.00e-05 | dt: 685.49ms | tok/sec: 5975.26 | norm: 3.81\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 10,800 | loss: 0.3022 | lr: 6.00e-05 | dt: 687.40ms | tok/sec: 5958.64 | norm: 3.06\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 10,900 | loss: 0.4287 | lr: 6.00e-05 | dt: 686.75ms | tok/sec: 5964.31 | norm: 3.60\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 11,000 | loss: 0.2447 | lr: 6.00e-05 | dt: 687.35ms | tok/sec: 5959.12 | norm: 3.35\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 11,100 | loss: 0.2773 | lr: 6.00e-05 | dt: 688.83ms | tok/sec: 5946.35 | norm: 2.71\n", " \n", "GPU Memory: 0.67GB / 1.77GB\n", "step 11,200 | loss: 0.2839 | lr: 6.00e-05 | dt: 687.56ms | tok/sec: 5957.31 | norm: 3.90\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 11,300 | loss: 0.3481 | lr: 6.00e-05 | dt: 684.68ms | tok/sec: 5982.32 | norm: 3.68\n", " \n", "GPU Memory: 0.78GB / 1.77GB\n", "step 11,400 | loss: 0.1913 | lr: 6.00e-05 | dt: 685.73ms | tok/sec: 5973.18 | norm: 2.93\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 11,500 | loss: 0.2605 | lr: 6.00e-05 | dt: 685.74ms | tok/sec: 5973.11 | norm: 2.96\n", " \n", "GPU Memory: 0.68GB / 1.77GB\n", "step 11,600 | loss: 0.2029 | lr: 6.00e-05 | dt: 689.04ms | tok/sec: 5944.49 | norm: 2.84\n", " \n", "\n", "Reached target loss! Final loss: 0.0889 at step 11,663\n", "Model saved to gpt_model.pt\n" ] } ], "source": [ "# SEED\n", "torch.manual_seed(1337)\n", "if torch.cuda.is_available():\n", " torch.cuda.manual_seed(1337)\n", "\n", "# STOP\n", "num_return_sequences = 5\n", "max_length = 30\n", "\n", "# Mixed precision and model\n", "torch.set_float32_matmul_precision('high')\n", "model = GPT(GPTConfig())\n", "model.to(device)\n", "model = torch.compile(model)\n", "\n", "# CODE UPDATE HERE\n", "max_lr = 6e-4\n", "min_lr = max_lr * 0.1\n", "warmup_steps = 10\n", "max_steps = 50\n", "\n", "# Dataloader\n", "train_loader = DataLoaderLite(B = 4, T = 256)\n", "\n", "# Optimizer\n", "optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device)\n", "step = 0\n", "\n", "# Add memory optimization settings\n", "torch.backends.cuda.matmul.allow_tf32 = True\n", "torch.backends.cudnn.allow_tf32 = True\n", "torch.set_float32_matmul_precision('high')\n", "\n", "# In the training loop, add gradient accumulation\n", "gradient_accumulation_steps = 4 # Accumulate gradients over 4 steps\n", "\n", "while True:\n", " t0 = time.time()\n", "\n", " # Reset gradients at the start of accumulation\n", " optimizer.zero_grad()\n", "\n", " # Accumulate gradients\n", " for _ in range(gradient_accumulation_steps):\n", " x, y = train_loader.next_batch()\n", " x, y = x.to(device), y.to(device)\n", "\n", " with torch.autocast(device_type=device, dtype=torch.bfloat16):\n", " logits, loss = model(x, y)\n", " # Scale loss by accumulation steps\n", " loss = loss / gradient_accumulation_steps\n", "\n", " loss.backward()\n", "\n", " # Clip gradients and update weights once per accumulation\n", " norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", "\n", " lr = get_lr(step)\n", " for param_group in optimizer.param_groups:\n", " param_group['lr'] = lr\n", "\n", " optimizer.step()\n", "\n", " if device == 'cuda' and step % 100 == 0: # Print memory every 100 steps\n", " 
torch.cuda.synchronize()\n", " print(f\"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.memory_reserved() / 1e9:.2f}GB\")\n", "\n", " t1 = time.time()\n", " dt = (t1 - t0) * 1000\n", " tokens_per_sec = (train_loader.B * train_loader.T * gradient_accumulation_steps) / (t1 - t0)\n", "\n", " if step % 100 == 0: # Print details every 100 steps\n", " print(f'step {step:,} | loss: {loss.item()*gradient_accumulation_steps:.4f} | lr: {lr:.2e} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec:.2f} | norm: {norm:.2f}')\n", " print(\" \")\n", "\n", " actual_loss = loss.item() * gradient_accumulation_steps\n", " if actual_loss < 0.09:\n", " print(f'\\nReached target loss! Final loss: {actual_loss:.4f} at step {step:,}')\n", " save_path = 'gpt_model.pt'\n", " torch.save(model.state_dict(), save_path)\n", " print(f\"Model saved to {save_path}\")\n", " break\n", "\n", " step += 1" ] }, { "cell_type": "markdown", "metadata": { "id": "ZmZ4Yk9LQehd" }, "source": [ "## Save Model" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ByO3iW55Qehd", "outputId": "81185c36-fd16-416e-e439-64788940776f" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Total model parameters: 33,709,824\n", "Model saved to nano_gpt_model.pt\n" ] } ], "source": [ "# Print total model parameters\n", "total_params = sum(p.numel() for p in model.parameters())\n", "print(f\"Total model parameters: {total_params:,}\")\n", "\n", "# Save the model\n", "save_path = 'nano_gpt_model.pt'\n", "torch.save(model.state_dict(), save_path)\n", "print(f\"Model saved to {save_path}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Tj9Rs-dysuOg" }, "source": [ "## Inference" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "CE2_CV1TcttD", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "420c8d5e-7c65-4f6c-bc1e-467350ea6468" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Generating text samples...\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n", " warnings.warn(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", "Generated text:\n", "\n", "Once upon a time to not;\n", "More slaughter'd, sweet Rivers, I receive my children, and title with pardon hither\n", "That one stuff'd with a conquest; and teeth, of my? 
Why, in life thee,\n", "Which now not joy of foe, thought o'n slaughter bed,\n", "And, is mine own soul me, not so heavy in every day:\n", "The tyrant from one curst my death lies;\n", "For the ground is nothing henceforth fell executioner come\n" ] } ], "source": [ "# Text generation\n", "print(\"\\nGenerating text samples...\")\n", "enc = tiktoken.get_encoding('gpt2')\n", "context = \"Once upon a time\"\n", "x = torch.tensor([enc.encode(context)], dtype=torch.long, device=device)\n", "\n", "max_length = 100 # Generate 100 tokens\n", "torch.manual_seed(42)\n", "if torch.cuda.is_available():\n", " torch.cuda.manual_seed(42)\n", "\n", "while x.size(1) < max_length:\n", " with torch.no_grad():\n", " with torch.autocast(device_type=device, dtype=torch.bfloat16):\n", " logits = model(x)[0]\n", " logits = logits[:, -1, :]\n", " probs = F.softmax(logits, dim=-1)\n", " topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)\n", " ix = torch.multinomial(topk_probs, num_samples=1)\n", " xcol = torch.gather(topk_indices, -1, ix)\n", " x = torch.cat([x, xcol], dim=1)\n", "\n", "print(\"\\nGenerated text:\")\n", "tokens = x[0].tolist() # Take first sequence\n", "decoded = enc.decode(tokens)\n", "print(f\"\\n{decoded}\")" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }