{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "bm4emgbs_ckl"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0'\n",
    "import torch\n",
    "import math,os,requests, random\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "import inspect\n",
    "from torch.cuda.amp import autocast, GradScaler\n",
    "from time import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "vGoI0xQRU-kY"
   },
   "outputs": [],
   "source": [
    "# Seed every RNG used below so the model comparisons in this notebook are reproducible\n",
    "SEED = 1337\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Always a torch.device, on both the CUDA and the CPU branch\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "batch_size = 512\n",
    "block_size = 256"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_GdXc1kJneyH"
   },
   "source": [
    "## Text input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "8HSIgCNFE2BB"
   },
   "outputs": [],
   "source": [
    "# Directory holding cano.txt; override with the DATA_DIR environment variable\n",
    "DATA_DIR = os.environ.get('DATA_DIR', '/home/datta0/')\n",
    "path = os.path.join(DATA_DIR,'cano.txt')\n",
    "if not os.path.isfile(path):\n",
    "    # Bounded timeout so a stalled server cannot hang the notebook indefinitely\n",
    "    response = requests.get(\"https://sherlock-holm.es/stories/plain-text/cano.txt\", timeout=60)\n",
    "    if response.status_code == 200:\n",
    "        # Save the content to a local file, creating the directory if needed\n",
    "        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)\n",
    "        with open(path, \"w\", encoding=\"utf-8\") as file:\n",
    "            file.write(response.text)\n",
    "        print(\"File downloaded successfully.\")\n",
    "    else:\n",
    "        print(f\"Failed to download file. 
Status code: {response.status_code}\")\n",
    "\n",
    "# Read with the same encoding the file was written with; the platform default\n",
    "# (e.g. cp1252 on Windows) would garble the non-ASCII characters in the text\n",
    "with open(path, 'r', encoding='utf-8') as f:\n",
    "    total_text = f.read()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1evQKNOvnnfd"
   },
   "source": [
    "## Process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "brkGH7_lHQdF"
   },
   "outputs": [],
   "source": [
    "# Character-level vocabulary: every distinct character of the corpus, sorted\n",
    "all_characters = sorted(set(total_text))\n",
    "char_to_idx = {char:idx for idx, char in enumerate(all_characters)}\n",
    "idx_to_char = {idx:char for idx, char in enumerate(all_characters)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'\\n': 0,\n", " ' ': 1,\n", " '!': 2,\n", " '\"': 3,\n", " '&': 4,\n", " \"'\": 5,\n",
       " '(': 6,\n", " ')': 7,\n", " '*': 8,\n", " ',': 9,\n", " '-': 10,\n", " '.': 11,\n",
       " '0': 12,\n", " '1': 13,\n", " '2': 14,\n", " '3': 15,\n", " '4': 16,\n",
       " '5': 17,\n", " '6': 18,\n", " '7': 19,\n", " '8': 20,\n", " '9': 21,\n",
       " ':': 22,\n", " ';': 23,\n", " '?': 24,\n",
       " 'A': 25,\n", " 'B': 26,\n", " 'C': 27,\n", " 'D': 28,\n", " 'E': 29,\n", " 'F': 30,\n",
       " 'G': 31,\n", " 'H': 32,\n", " 'I': 33,\n", " 'J': 34,\n", " 'K': 35,\n", " 'L': 36,\n",
       " 'M': 37,\n", " 'N': 38,\n", " 'O': 39,\n", " 'P': 40,\n", " 'Q': 41,\n", " 'R': 42,\n",
       " 'S': 43,\n", " 'T': 44,\n", " 'U': 45,\n", " 'V': 46,\n", " 'W': 47,\n", " 'X': 48,\n",
       " 'Y': 49,\n", " 'Z': 50,\n", " '[': 51,\n", " ']': 52,\n", " '`': 53,\n",
       " 'a': 54,\n", " 'b': 55,\n", " 'c': 56,\n", " 'd': 57,\n", " 'e': 58,\n", " 'f': 59,\n",
       " 'g': 60,\n", " 'h': 61,\n", " 'i': 62,\n", " 'j': 63,\n", " 'k': 64,\n", " 'l': 65,\n",
       " 'm': 66,\n", " 'n': 67,\n", " 'o': 68,\n", " 'p': 69,\n", " 'q': 70,\n", " 'r': 71,\n",
       " 's': 72,\n", " 't': 73,\n", " 'u': 74,\n", " 'v': 75,\n", " 'w': 76,\n", " 'x': 77,\n",
       " 'y': 78,\n", " 'z': 79,\n", " '£': 80,\n", " '°': 81,\n", " 'ß': 82,\n", " 'à': 83,\n",
       " 'â': 84,\n", " 'è': 85,\n", " 'é': 86,\n", " 'ê': 87,\n", " 'î': 88,\n", " 'ñ': 89,\n",
       " 'ô': 90,\n", " 
'ö': 91,\n",
       " 'û': 92,\n",
       " 'ü': 93,\n",
       " '’': 94}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "char_to_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "BDDriRiOOJof"
   },
   "outputs": [],
   "source": [
    "def encode(text):\n",
    "    # Characters -> vocabulary indices (characters outside the vocabulary map to None)\n",
    "    return [char_to_idx.get(x) for x in text]\n",
    "def decode(indices):\n",
    "    # Indices -> list of characters (not joined into a string)\n",
    "    return [idx_to_char.get(idx) for idx in indices]\n",
    "\n",
    "def batch_encode(batch):\n",
    "    # encode() applied to each string of the batch\n",
    "    return [encode(text) for text in batch]\n",
    "\n",
    "def batch_decode(batch):\n",
    "    # decode() applied to each index sequence of the batch\n",
    "    return [decode(indices) for indices in batch]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_text_len = len(total_text)\n",
    "\n",
    "# 90% / 10% train / validation split of the character stream\n",
    "train_len = int(0.9*total_text_len)\n",
    "encoded_train_text = np.array(encode(total_text[:train_len]))\n",
    "val_len = total_text_len - train_len\n",
    "encoded_val_text = np.array(encode(total_text[train_len:]))\n",
    "\n",
    "def get_batch(encoded_text, batch_size,total_len):\n",
    "    # `batch_size` random windows of `block_size` tokens; y is x shifted one position ahead\n",
    "    ix = torch.randint(total_len - block_size, (batch_size,)) # get random starting indices\n",
    "    x = torch.stack([torch.from_numpy((encoded_text[i:i+block_size]).astype(np.int64)) for i in ix])\n",
    "    y = torch.stack([torch.from_numpy((encoded_text[i+1:i+1+block_size]).astype(np.int64)) for i in ix])\n",
    "    if torch.device(device).type == 'cuda':\n",
    "        # Pinned host memory allows an asynchronous (non_blocking) host-to-GPU copy\n",
    "        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)\n",
    "    else:\n",
    "        # pin_memory() requires a CUDA device, so use plain copies otherwise\n",
    "        x, y = x.to(device), y.to(device)\n",
    "    return x,y\n",
    "\n",
    "def get_data(split,batch_size=512):\n",
    "    # Random (inputs, targets) batch from the train split, or from the val split for any other `split`\n",
    "    if split=='train':\n",
    "        return get_batch(encoded_train_text, batch_size, train_len)\n",
    "    else:\n",
    "        return get_batch(encoded_val_text, batch_size, val_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor([[65, 65, 1, 66, 58, 1, 73, 61, 54, 73, 1, 72, 61, 58, 1, 62, 72, 0,\n",
       " 62, 67, 1, 36, 68, 
67, 57, 68, 67, 9, 1, 55, 74, 73, 1, 54, 72, 1,\n", " 76, 58, 1, 61, 54, 75, 58, 1, 54, 73, 1, 69, 71, 58, 72, 58, 67, 73,\n", " 1, 67, 68, 1, 69, 68, 72, 72, 62, 55, 65, 58, 1, 66, 58, 54, 67, 72,\n", " 1, 68, 59, 1, 73, 58, 65, 65, 62, 67, 60, 0, 76, 61, 58, 71, 58, 9,\n", " 1, 76, 58, 1, 56, 54, 67, 1, 68, 67, 65, 78, 1, 73, 54, 64, 58, 1,\n", " 73, 61, 58, 1, 68, 55, 75, 62, 68, 74, 72, 1, 72, 73, 58, 69, 72, 9,\n", " 1, 58, 54, 73, 1, 68, 74, 71, 1, 57, 62, 67, 67, 58, 71, 9, 1, 54,\n", " 67, 57, 0, 69, 68, 72, 72, 58, 72, 72, 1, 68, 74, 71, 1, 72, 68, 74,\n", " 65, 72, 1, 62, 67, 1, 69, 54, 73, 62, 58, 67, 56, 58, 11, 1, 36, 54,\n", " 73, 58, 71, 1, 62, 67, 1, 73, 61, 58, 1, 58, 75, 58, 67, 62, 67, 60,\n", " 1, 33, 1, 76, 62, 65, 65, 1, 72, 73, 71, 68, 65, 65, 0, 57, 68, 76,\n", " 67, 1, 54, 67, 57, 1, 61, 54, 75, 58, 1, 54, 1, 76, 68, 71, 57, 1,\n", " 76, 62, 73, 61, 1, 59, 71, 62, 58, 67, 57, 1, 36, 58, 72, 73, 71, 54,\n", " 57, 58, 1, 54]], device='cuda:0'),\n", " tensor([[65, 1, 66, 58, 1, 73, 61, 54, 73, 1, 72, 61, 58, 1, 62, 72, 0, 62,\n", " 67, 1, 36, 68, 67, 57, 68, 67, 9, 1, 55, 74, 73, 1, 54, 72, 1, 76,\n", " 58, 1, 61, 54, 75, 58, 1, 54, 73, 1, 69, 71, 58, 72, 58, 67, 73, 1,\n", " 67, 68, 1, 69, 68, 72, 72, 62, 55, 65, 58, 1, 66, 58, 54, 67, 72, 1,\n", " 68, 59, 1, 73, 58, 65, 65, 62, 67, 60, 0, 76, 61, 58, 71, 58, 9, 1,\n", " 76, 58, 1, 56, 54, 67, 1, 68, 67, 65, 78, 1, 73, 54, 64, 58, 1, 73,\n", " 61, 58, 1, 68, 55, 75, 62, 68, 74, 72, 1, 72, 73, 58, 69, 72, 9, 1,\n", " 58, 54, 73, 1, 68, 74, 71, 1, 57, 62, 67, 67, 58, 71, 9, 1, 54, 67,\n", " 57, 0, 69, 68, 72, 72, 58, 72, 72, 1, 68, 74, 71, 1, 72, 68, 74, 65,\n", " 72, 1, 62, 67, 1, 69, 54, 73, 62, 58, 67, 56, 58, 11, 1, 36, 54, 73,\n", " 58, 71, 1, 62, 67, 1, 73, 61, 58, 1, 58, 75, 58, 67, 62, 67, 60, 1,\n", " 33, 1, 76, 62, 65, 65, 1, 72, 73, 71, 68, 65, 65, 0, 57, 68, 76, 67,\n", " 1, 54, 67, 57, 1, 61, 54, 75, 58, 1, 54, 1, 76, 68, 71, 57, 1, 76,\n", " 62, 73, 61, 1, 59, 71, 62, 58, 67, 57, 
1, 36, 58, 72, 73, 71, 54, 57,\n", " 58, 1, 54, 73]], device='cuda:0')),\n", " [['l',\n", " 'l',\n", " ' ',\n", " 'm',\n", " 'e',\n", " ' ',\n", " 't',\n", " 'h',\n", " 'a',\n", " 't',\n", " ' ',\n", " 's',\n", " 'h',\n", " 'e',\n", " ' ',\n", " 'i',\n", " 's',\n", " '\\n',\n", " 'i',\n", " 'n',\n", " ' ',\n", " 'L',\n", " 'o',\n", " 'n',\n", " 'd',\n", " 'o',\n", " 'n',\n", " ',',\n", " ' ',\n", " 'b',\n", " 'u',\n", " 't',\n", " ' ',\n", " 'a',\n", " 's',\n", " ' ',\n", " 'w',\n", " 'e',\n", " ' ',\n", " 'h',\n", " 'a',\n", " 'v',\n", " 'e',\n", " ' ',\n", " 'a',\n", " 't',\n", " ' ',\n", " 'p',\n", " 'r',\n", " 'e',\n", " 's',\n", " 'e',\n", " 'n',\n", " 't',\n", " ' ',\n", " 'n',\n", " 'o',\n", " ' ',\n", " 'p',\n", " 'o',\n", " 's',\n", " 's',\n", " 'i',\n", " 'b',\n", " 'l',\n", " 'e',\n", " ' ',\n", " 'm',\n", " 'e',\n", " 'a',\n", " 'n',\n", " 's',\n", " ' ',\n", " 'o',\n", " 'f',\n", " ' ',\n", " 't',\n", " 'e',\n", " 'l',\n", " 'l',\n", " 'i',\n", " 'n',\n", " 'g',\n", " '\\n',\n", " 'w',\n", " 'h',\n", " 'e',\n", " 'r',\n", " 'e',\n", " ',',\n", " ' ',\n", " 'w',\n", " 'e',\n", " ' ',\n", " 'c',\n", " 'a',\n", " 'n',\n", " ' ',\n", " 'o',\n", " 'n',\n", " 'l',\n", " 'y',\n", " ' ',\n", " 't',\n", " 'a',\n", " 'k',\n", " 'e',\n", " ' ',\n", " 't',\n", " 'h',\n", " 'e',\n", " ' ',\n", " 'o',\n", " 'b',\n", " 'v',\n", " 'i',\n", " 'o',\n", " 'u',\n", " 's',\n", " ' ',\n", " 's',\n", " 't',\n", " 'e',\n", " 'p',\n", " 's',\n", " ',',\n", " ' ',\n", " 'e',\n", " 'a',\n", " 't',\n", " ' ',\n", " 'o',\n", " 'u',\n", " 'r',\n", " ' ',\n", " 'd',\n", " 'i',\n", " 'n',\n", " 'n',\n", " 'e',\n", " 'r',\n", " ',',\n", " ' ',\n", " 'a',\n", " 'n',\n", " 'd',\n", " '\\n',\n", " 'p',\n", " 'o',\n", " 's',\n", " 's',\n", " 'e',\n", " 's',\n", " 's',\n", " ' ',\n", " 'o',\n", " 'u',\n", " 'r',\n", " ' ',\n", " 's',\n", " 'o',\n", " 'u',\n", " 'l',\n", " 's',\n", " ' ',\n", " 'i',\n", " 'n',\n", " ' ',\n", " 'p',\n", " 'a',\n", " 't',\n", " 'i',\n", " 'e',\n", " 'n',\n", 
" 'c',\n", " 'e',\n", " '.',\n", " ' ',\n", " 'L',\n", " 'a',\n", " 't',\n", " 'e',\n", " 'r',\n", " ' ',\n",
       " 'i',\n", " 'n',\n", " ' ',\n", " 't',\n", " 'h',\n", " 'e',\n", " ' ',\n", " 'e',\n", " 'v',\n", " 'e',\n",
       " 'n',\n", " 'i',\n", " 'n',\n", " 'g',\n", " ' ',\n", " 'I',\n", " ' ',\n", " 'w',\n", " 'i',\n", " 'l',\n",
       " 'l',\n", " ' ',\n", " 's',\n", " 't',\n", " 'r',\n", " 'o',\n", " 'l',\n", " 'l',\n", " '\\n',\n", " 'd',\n",
       " 'o',\n", " 'w',\n", " 'n',\n", " ' ',\n", " 'a',\n", " 'n',\n", " 'd',\n", " ' ',\n", " 'h',\n", " 'a',\n",
       " 'v',\n", " 'e',\n", " ' ',\n", " 'a',\n", " ' ',\n", " 'w',\n", " 'o',\n", " 'r',\n", " 'd',\n", " ' ',\n",
       " 'w',\n", " 'i',\n", " 't',\n", " 'h',\n", " ' ',\n", " 'f',\n", " 'r',\n", " 'i',\n", " 'e',\n", " 'n',\n",
       " 'd',\n", " ' ',\n", " 'L',\n", " 'e',\n", " 's',\n", " 't',\n", " 'r',\n", " 'a',\n", " 'd',\n", " 'e',\n",
       " ' ',\n", " 'a']])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = get_data('train',1)\n",
    "b = batch_decode(a[0].tolist())\n",
    "a,b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "cklZZJgOb-qa"
   },
   "outputs": [],
   "source": [
    "def repeat_kv(hidden_states, repeat_times):\n",
    "    # Repeat each key/value head `repeat_times` times along the head axis so K/V line up with\n",
    "    # the query heads in grouped-query attention:\n",
    "    # (batch, n_kv_heads, seq_len, head_dim) -> (batch, n_kv_heads * repeat_times, seq_len, head_dim)\n",
    "    if repeat_times == 1:\n",
    "        return hidden_states\n",
    "    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape # shape of k / v after the head split\n",
    "    # New repeat axis right after the head axis, so the copies of one KV head stay adjacent\n",
    "    hidden_states = hidden_states[:,:,None,:,:].expand(batch, n_kv_heads, repeat_times, seq_len, head_dim)\n",
    "    return hidden_states.reshape(batch, n_kv_heads*repeat_times, seq_len, head_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rqzKT_7lns4W"
   },
   "source": [
    "## Model Architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "XTjOd3VztBQz"
   },
   "outputs": [],
   "source": [
    "class RMSNorm(nn.Module):\n",
    "    # Root-mean-square norm: x * rsqrt(mean(x^2) + eps), scaled by a learned per-feature weight\n",
    "    def __init__(self, hidden_size, eps=1e-6):\n",
    "        super().__init__()\n",
    "        self.weight = nn.Parameter(torch.ones(hidden_size))\n",
    "        self.variance_epsilon = eps\n",
    "\n",
    "    def forward(self, hidden_states):\n",
    "        # Normalise in float32 for stability, then cast back to the caller's dtype\n",
    "        input_dtype = hidden_states.dtype\n",
    "        hidden_states = hidden_states.to(torch.float32)\n",
    "        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n",
    "        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n",
    "        return self.weight * hidden_states.to(input_dtype)\n",
    "\n",
    "class PositionalEncoding(nn.Module):\n",
    "    # Fixed sinusoidal positional encoding added to the input.\n",
    "    # NOTE(review): NanoLlama below does not use this module.\n",
    "    def __init__(self, hidden_size, max_seq_len):\n",
    "        super().__init__()\n",
    "        encoding = torch.zeros(max_seq_len, hidden_size)\n",
    "        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size))\n",
    "        encoding[:, 0::2] = torch.sin(position * div_term)\n",
    "        # An odd hidden_size has one fewer cosine column than there are frequencies\n",
    "        encoding[:, 1::2] = torch.cos(position * div_term[:hidden_size // 2])\n",
    "        # Non-persistent buffer: follows module.to(device) but is not saved in the state_dict\n",
    "        self.register_buffer('encoding', encoding.unsqueeze(0), persistent=False) # (1, max_seq_len, hidden_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: (batch, seq_len, hidden_size) with seq_len <= max_seq_len\n",
    "        return x + self.encoding[:, :x.size(1)].to(x.device)"
   ]
  },
  {
   "attachments": {
    "image.png": {
     "image/png": ""
    }
   },
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image.png](attachment:image.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "0BTGrXZWQIkt"
   },
   "outputs": [],
   "source": [
    "class Attention(nn.Module):\n",
    "    # Causal self-attention with grouped-query attention: n_kv_heads key/value heads are\n",
    "    # shared by n_attn_heads query heads.\n",
    "    def __init__(self,n_attn_heads,n_kv_heads,hidden_size,max_len=256):\n",
    "        super().__init__()\n",
    "\n",
    "        assert hidden_size%n_attn_heads==0\n",
    "        assert n_attn_heads%n_kv_heads==0\n",
    "\n",
    "        self.head_dim = hidden_size // n_attn_heads\n",
    "        kv_size = n_kv_heads * self.head_dim\n",
    "        self.hidden_size = hidden_size\n",
    "        self.n_attn_heads = n_attn_heads\n",
    "        self.n_kv_heads = n_kv_heads\n",
    "\n",
    "        self.q = nn.Linear(hidden_size, hidden_size, bias=False) #WQ\n",
    "        self.k = nn.Linear(hidden_size, kv_size, bias=False) #WK\n",
    "        self.v = nn.Linear(hidden_size, kv_size, bias=False) #WV\n",
    "\n",
    "        # Lower-triangular causal mask, shape (1, 1, max_len, max_len)\n",
    "        self.register_buffer('tril',torch.tril(torch.ones(max_len,max_len)).view(1,1,max_len,max_len))\n",
    "\n",
    "    def forward(self, x, echo = False):\n",
    "        # x: (batch_size, seq_len, hidden_size) -> output of the same shape.\n",
    "        # `echo` is accepted for compatibility but unused.\n",
    "        batch_size, seq_len, _ = x.shape\n",
    "        assert seq_len <= self.tril.size(-1), 'sequence length exceeds max_len'\n",
    "\n",
    "        # Pass the inputs through the Q/K/V projections\n",
    "        q = self.q(x) #Q\n",
    "        k = self.k(x) #K\n",
    "        v = self.v(x) #V\n",
    "\n",
    "        # Split into heads: (batch_size, n_heads, seq_len, head_dim)\n",
    "        q = q.view(batch_size, seq_len, self.n_attn_heads, self.head_dim).transpose(1, 2)\n",
    "        k = k.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)\n",
    "        v = v.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)\n",
    "\n",
    "        # Each K/V head serves n_attn_heads // n_kv_heads query heads, so repeat the\n",
    "        # K/V heads until they match the number of query heads\n",
    "        k = repeat_kv(k, self.n_attn_heads//self.n_kv_heads)\n",
    "        v = repeat_kv(v, self.n_attn_heads//self.n_kv_heads)\n",
    "\n",
    "        # Dot products are taken per head over head_dim features, so the scale is\n",
    "        # 1/sqrt(head_dim); 1/sqrt(hidden_size) would shrink the logits by an extra\n",
    "        # sqrt(n_attn_heads) and flatten the softmax\n",
    "        attention = (q @ k.transpose(-2,-1)) * (1.0/math.sqrt(self.head_dim))\n",
    "        # Causal mask: position i may only attend to positions <= i\n",
    "        attention = attention.masked_fill(self.tril[:,:,:seq_len,:seq_len]==0, float('-inf'))\n",
    "        probs = nn.functional.softmax(attention,dim=-1)\n",
    "        y = probs@v\n",
    "        # Merge heads back: (batch_size, seq_len, n_attn_heads * head_dim)\n",
    "        y = y.transpose(1,2).contiguous().reshape(batch_size, seq_len, -1)\n",
    "\n",
    "        return y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "tqlQ_dDoIbwQ"
   },
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self, hidden_size, intermediate_size,):\n",
    "        super().__init__()\n",
    "\n",
    "        self.hidden_size = hidden_size\n",
    "        self.intermediate_size = intermediate_size\n",
    "\n",
    "        self.up = nn.Linear(hidden_size, intermediate_size, bias=False)\n",
    "        self.gate = nn.Linear(hidden_size, intermediate_size, bias=False)\n",
    "        self.down = nn.Linear(intermediate_size, hidden_size, bias=False)\n",
    "        self.act_fn = nn.GELU()\n",
    "\n",
    "    def forward(self,x):\n",
    "\n",
    "        up = self.up(x)\n",
    "        gate = self.gate(x)\n",
    "\n",
    "        # note that * in torch is element wise multiplication. 
The two operands need to be of same size.\n",
    "        # Gated linear unit (GeGLU; Llama uses the SiLU variant, SwiGLU): the activated gate\n",
    "        # projection scales the up projection element-wise. Activating the product itself,\n",
    "        # act(up * gate), would not be a gated unit.\n",
    "        return self.down(self.act_fn(gate) * up)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "jHuI7TvQIjsQ"
   },
   "outputs": [],
   "source": [
    "class TransformerBlock(nn.Module):\n",
    "    # One decoder layer: causal self-attention followed by the gated MLP.\n",
    "    #   normalise (forward arg): LayerNorm each sub-layer's input (pre-norm)\n",
    "    #   residual:                add each sub-layer's input back onto its output\n",
    "    def __init__(self, hidden_size, n_attn_heads, n_kv_heads, intermediate_size,max_len, residual=True):\n",
    "        super().__init__()\n",
    "\n",
    "        self.attn = Attention(n_attn_heads,n_kv_heads,hidden_size,max_len)\n",
    "        self.mlp = MLP(hidden_size, intermediate_size)\n",
    "        self.residual = residual\n",
    "        # Separate norms for the two sub-layers (like Llama's input / post-attention norms)\n",
    "        # so each has its own gain and bias instead of sharing one LayerNorm\n",
    "        self.norm = nn.LayerNorm(hidden_size) # before attention\n",
    "        self.mlp_norm = nn.LayerNorm(hidden_size) # before the MLP\n",
    "\n",
    "    def forward(self, x, normalise):\n",
    "        attn_in = self.norm(x) if normalise else x\n",
    "        attn_out = self.attn(attn_in)\n",
    "        if self.residual:\n",
    "            attn_out = x + attn_out\n",
    "\n",
    "        mlp_in = self.mlp_norm(attn_out) if normalise else attn_out\n",
    "        mlp_out = self.mlp(mlp_in)\n",
    "        if self.residual:\n",
    "            mlp_out = attn_out + mlp_out\n",
    "\n",
    "        return mlp_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "vachyBjHIrb2"
   },
   "outputs": [],
   "source": [
    "class NanoLlama(nn.Module):\n",
    "    # Token embedding -> n_layers TransformerBlocks -> optional final LayerNorm.\n",
    "    # NOTE(review): no positional encoding is added to the embeddings here, so presumably\n",
    "    # the causal mask is the model's only source of token-order information.\n",
    "\n",
    "    def __init__(self,n_layers, vocab_size, hidden_size, n_attn_heads, n_kv_heads, intermediate_size,max_len, residual, normalise=True):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, hidden_size)\n",
    "        self.n_layers = n_layers\n",
    "        self.layers = nn.ModuleList(\n",
    "            [TransformerBlock(hidden_size, n_attn_heads, n_kv_heads, intermediate_size, max_len, residual) for _ in range(n_layers)]\n",
    "        )\n",
    "        self.normalise = normalise\n",
    "        self.norm = nn.LayerNorm(hidden_size)\n",
    "\n",
    "    def forward(self,x):\n",
    "        # x: token ids -> hidden states, one vector of size hidden_size per token\n",
    "        x = self.embedding(x)\n",
    "        for layer in self.layers:\n",
    "            x = layer(x, self.normalise)\n",
    "\n",
    "        if self.normalise:\n",
    "            x = self.norm(x)\n",
    "\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15, 
"metadata": {
    "id": "bBehoFFvIpJM"
   },
   "outputs": [],
   "source": [
    "class NanoLlamaForCausalLM(nn.Module):\n",
    "    # NanoLlama decoder followed by a bias-free linear head projecting onto the vocabulary.\n",
    "\n",
    "    def __init__(self,n_layers, vocab_size, hidden_size, n_attn_heads, n_kv_heads, intermediate_size,max_len=256, residual=True, normalise=True):\n",
    "        super().__init__()\n",
    "\n",
    "        self.model = NanoLlama(n_layers, vocab_size, hidden_size, n_attn_heads, n_kv_heads, intermediate_size,max_len, residual,normalise)\n",
    "        self.lm_head = nn.Linear(hidden_size,vocab_size, bias=False)\n",
    "        # Hyper-parameters kept as attributes for later inspection\n",
    "        self.max_len = max_len\n",
    "        self.n_layers = n_layers\n",
    "        self.n_attn_heads = n_attn_heads\n",
    "        self.n_kv_heads = n_kv_heads\n",
    "        self.hidden_dim = hidden_size\n",
    "\n",
    "        # Kaiming-uniform init for every nn.Linear weight (attention, MLP and lm_head);\n",
    "        # embeddings and norms keep their PyTorch default initialisation\n",
    "        linear_layers = (module for module in self.modules() if isinstance(module, nn.Linear))\n",
    "        for linear in linear_layers:\n",
    "            nn.init.kaiming_uniform_(linear.weight, mode='fan_in', nonlinearity='relu')\n",
    "\n",
    "    def forward(self,input_ids,targets=None):\n",
    "        # Returns (logits, loss). With targets: logits for every position and the mean\n",
    "        # cross-entropy against targets (target id -1 is ignored).\n",
    "        # Without targets: logits for the final position only, and loss is None.\n",
    "        x = self.model(input_ids)\n",
    "\n",
    "        if targets is not None:\n",
    "            logits = self.lm_head(x)\n",
    "            flat_logits = logits.view(-1, logits.size(-1))\n",
    "            loss = nn.functional.cross_entropy(flat_logits, targets.view(-1), ignore_index=-1)\n",
    "        else:\n",
    "            # we only need to pass the last token's outputs through lm_head. 
Rest we ignore.\n",
    "            logits = self.lm_head(x[:, [-1], :])\n",
    "            loss = None\n",
    "\n",
    "        return logits,loss\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def generate(self, input_ids, max_new_tokens=20, temperature=1.0, sample = False):\n",
    "        # Autoregressively append exactly `max_new_tokens` tokens to `input_ids`.\n",
    "        #   input_ids:      (batch, seq_len) tensor of token ids (new tokens are concatenated on dim 1)\n",
    "        #   max_new_tokens: number of tokens to generate; must be > 0\n",
    "        #   temperature:    divides the logits before the softmax when sampling; must be > 0\n",
    "        #   sample:         False -> greedy argmax decoding, True -> multinomial sampling\n",
    "        # Returns prompt + generated ids as a numpy array.\n",
    "        model_device = self.model.embedding.weight.device\n",
    "        if input_ids.device != model_device:\n",
    "            input_ids = input_ids.to(model_device)\n",
    "\n",
    "        assert max_new_tokens>0\n",
    "        assert temperature>0\n",
    "\n",
    "        for _ in range(max_new_tokens):\n",
    "            # The causal mask only covers max_len positions, so condition on at most the\n",
    "            # last max_len tokens\n",
    "            context = input_ids[:, -self.max_len:]\n",
    "            logits, _ = self.forward(context)\n",
    "            final_token_logits = logits[:, -1, :] # (batch, vocab_size)\n",
    "            if sample:\n",
    "                # Sample from the temperature-scaled distribution\n",
    "                probs = nn.functional.softmax(final_token_logits / temperature, dim=-1)\n",
    "                next_token = torch.multinomial(probs, num_samples=1)\n",
    "            else:\n",
    "                # Greedy; keepdim=True keeps shape (batch, 1) so it can be concatenated\n",
    "                next_token = torch.argmax(final_token_logits, dim=-1, keepdim=True)\n",
    "            input_ids = torch.cat((input_ids, next_token), dim=1)\n",
    "\n",
    "        return input_ids.cpu().numpy()\n",
    "    \n",
    "    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):\n",
    "        # start with all of the candidate parameters\n",
    "        param_dict = {pn: p for pn, p in self.named_parameters()}\n",
    "        # filter out those that do not require grad\n",
    "        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}\n",
    "        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.\n",
    "        # i.e. 
all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.\n",
    "        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\n",
    "        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\n",
    "        optim_groups = [\n",
    "            {'params': decay_params, 'weight_decay': weight_decay},\n",
    "            {'params': nodecay_params, 'weight_decay': 0.0}\n",
    "        ]\n",
    "        num_decay_params = sum(p.numel() for p in decay_params)\n",
    "        num_nodecay_params = sum(p.numel() for p in nodecay_params)\n",
    "        print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n",
    "        print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n",
    "        # Create AdamW optimizer and use the fused version if it is available\n",
    "        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n",
    "        use_fused = fused_available and device_type == 'cuda'\n",
    "        extra_args = dict(fused=True) if use_fused else dict()\n",
    "        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)\n",
    "        print(f\"using fused AdamW: {use_fused}\")\n",
    "\n",
    "        return optimizer\n",
    "    \n",
    "    def get_num_params(self):\n",
    "        # Total number of parameters in the model\n",
    "        n_params = sum(p.numel() for p in self.parameters())\n",
    "        return n_params\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Uf4a_Fmnnxti"
   },
   "source": [
    "## Training Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "nrs3UvfBU47e"
   },
   "outputs": [],
   "source": [
    "def get_lr(it, warmup_iters, lr_decay_iters, min_lr, learning_rate):\n",
    "    # Learning rate for iteration `it`: linear warmup, then cosine decay down to min_lr.\n",
    "    # 1) linear warmup; (it + 1) so the very first step does not use a learning rate of 0\n",
    "    if it < warmup_iters:\n",
    "        return learning_rate * (it + 1) / (warmup_iters + 1)\n",
    "    # 2) if it > lr_decay_iters, return min learning rate (also covers an empty decay\n",
    "    #    window, which would otherwise divide by zero below)\n",
    "    if it > lr_decay_iters or lr_decay_iters <= warmup_iters:\n",
    "        return min_lr\n",
    "    # 3) in between, use cosine decay down to min learning rate\n",
    "    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)\n",
    "    assert 0 <= decay_ratio <= 1\n",
    "    coeff = 0.5 * (1.0 + math.cos(math.pi * 
decay_ratio)) # coeff ranges 0..1\n",
    "    return min_lr + coeff * (learning_rate - min_lr)\n",
    "\n",
    "@torch.no_grad()\n",
    "def run_model_and_get_loss(model, steps=100, batch_size=512):\n",
    "    # Mean loss over `steps` random batches of each split, computed in eval mode;\n",
    "    # the model is switched back to train mode afterwards. Returns (train_loss, val_loss).\n",
    "    model.eval()\n",
    "    phase_wise_loss = {}\n",
    "    for phase in ['train', 'val']:\n",
    "        losses = []\n",
    "        for _ in range(steps):\n",
    "            input_ids, labels = get_data(phase, batch_size)\n",
    "            _, loss = model(input_ids, labels)\n",
    "            losses.append(loss.item())\n",
    "        phase_wise_loss[phase] = np.mean(losses)\n",
    "    model.train()\n",
    "\n",
    "    return phase_wise_loss['train'], phase_wise_loss['val']\n",
    "\n",
    "def train_model(model, optimizer, num_iters, device, accumulation_steps=1, eval_steps=100, lr_decay=False, batch_size=512, max_grad_norm=-1,train_dtype=torch.float32, learning_rate=1e-3, min_lr=1e-4, warmup_iters=100):\n",
    "    # Train `model` for `num_iters` optimizer steps on random batches from get_data('train'),\n",
    "    # printing mean train/val loss every `eval_steps` steps and plotting both curves at the end.\n",
    "    #   accumulation_steps: micro-batches (each a fresh batch) whose gradients are summed per step\n",
    "    #   lr_decay:           warmup + cosine decay via get_lr; otherwise a constant learning_rate\n",
    "    #   max_grad_norm:      clip the global gradient norm to this value; -1 disables clipping\n",
    "    #   train_dtype:        autocast dtype; torch.float32 disables autocast and loss scaling\n",
    "    # NOTE: every optimizer param group's lr is overwritten each step, so the lr the optimizer\n",
    "    # was constructed with is ignored.\n",
    "    train_losses = []\n",
    "    val_losses = []\n",
    "    eval_iters = [] # iteration number of each evaluation, used as the plot's x-axis\n",
    "    model = model.to(device)\n",
    "\n",
    "    lr_decay_iters = num_iters\n",
    "    use_amp = train_dtype != torch.float32\n",
    "    # Autocast on the device actually used instead of hard-coding 'cuda'\n",
    "    device_type = torch.device(device).type\n",
    "\n",
    "    scaler = GradScaler(enabled=use_amp) # loss scaling guards low-precision gradients against underflow; off for FP32\n",
    "\n",
    "    inputs, labels = get_data('train', batch_size)\n",
    "\n",
    "    for iter_num in tqdm(range(num_iters),'training'):\n",
    "\n",
    "        lr = get_lr(iter_num, warmup_iters, lr_decay_iters, min_lr, learning_rate) if lr_decay else learning_rate\n",
    "        for param_group in optimizer.param_groups:\n",
    "            param_group['lr'] = lr\n",
    "\n",
    "        for _ in range(accumulation_steps):\n",
    "            with torch.autocast(device_type=device_type, dtype=train_dtype, enabled=use_amp):\n",
    "                _, loss = model(inputs, labels)\n",
    "                loss = loss / accumulation_steps # average the accumulated gradients\n",
    "            # Draw the next batch for every micro-step so the accumulated gradients come\n",
    "            # from distinct batches\n",
    "            inputs, labels = get_data('train', batch_size)\n",
    "            scaler.scale(loss).backward()\n",
    "\n",
    "        if max_grad_norm != -1:\n",
    "            scaler.unscale_(optimizer) # clip the true, unscaled gradients\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n",
    "\n",
    "        scaler.step(optimizer)\n",
    "        scaler.update()\n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "\n",
    "        if (iter_num + 1) % eval_steps == 0:\n",
    "            with torch.autocast(device_type=device_type, dtype=train_dtype, enabled=use_amp):\n",
    "                train_loss, val_loss = run_model_and_get_loss(model, batch_size=batch_size)\n",
    "            train_losses.append(train_loss)\n",
    "            val_losses.append(val_loss)\n",
    "            eval_iters.append(iter_num + 1)\n",
    "\n",
    "            print(f'Iteration: {iter_num + 1}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}')\n",
    "\n",
    "    plt.plot(eval_iters, train_losses, label='train_loss')\n",
    "    plt.plot(eval_iters, val_losses, label='val_loss')\n",
    "    plt.xlabel('Iterations')\n",
    "    plt.ylabel('Loss')\n",
    "    plt.legend()\n",
    "    plt.title('Training and Validation Loss')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Smol models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "rJAiS4QlW401",
    "outputId": "a0e5dbef-7365-4cc5-eadd-ced2a4c6e918"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "184832"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=256,max_len=block_size, residual=False, normalise=False)\n",
    "nano_llama.get_num_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "CPrW_hHcNGU9"
   },
   "outputs": [],
   "source": [
    "nano_llama = nano_llama.to(device)\n",
    "nano_llama = torch.compile(nano_llama)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "id": "vgmf-3nL83db"
   },
   "outputs": [],
   "source": [
    "optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 522,
     "referenced_widgets": [
      "0de38f17934040f3a760965e3ab869de",
      "e178916466d64e568a8abfcb797ec0cd",
      "cafff4a16c9b429eb4d024ef73149d56", 
"30a9f865abc94f89ae5c2e9fd5352214", "9b59ce324223413fb6b65cda2e075966", "6270f7fbf2974a22b2eadd6d162e4864", "96a704c174184da38c51dfc011ee0479", "5103b8685eda4259ac8cebba368b846c", "e4d1332c6299421d8086aeda0d8cc952", "dadcbf8da464400f90d19585974e4aee", "b79782127a264d2f92d0110773d614e7" ] }, "id": "-mOH_3UeUT1U", "outputId": "77f49585-4efa-4091-ebae-9379c759ebcf" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c90612b17bff46d091c3084c4686f96f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "train_model(nano_llama, optimizer, num_iters=1000,device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Residual" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![alt text](image.png)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "184832" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=256,max_len=block_size, residual=True, normalise=False)\n", "nano_llama.to(device)\n", "nano_llama = torch.compile(nano_llama)\n", "nl_param_count = nano_llama.get_num_params()\n", "nl_param_count#, nano_llama" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f6578bace72e4c8983d2f9a21e361edf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/500 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)\n", "train_model(nano_llama, optimizer, num_iters=500,device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"### Normalise" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "eE9QXDMU6Nu8", "outputId": "cb0c5b8f-a332-40c9-868f-d5c443fbe926" }, "outputs": [ { "data": { "text/plain": [ "184832" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=256,max_len=block_size, residual=False, normalise=True)\n", "nano_llama.to(device)\n", "nano_llama = torch.compile(nano_llama)\n", "nl_param_count = nano_llama.get_num_params()\n", "nl_param_count#, nano_llama" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b86f54f97b7c4c30a0c200c2c8a898de", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/500 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)\n", "train_model(nano_llama, optimizer, num_iters=500,device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Both (residual + normalise)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "184832" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=256,max_len=block_size, residual=True, normalise=True)\n", "nano_llama.to(device)\n", "nano_llama = torch.compile(nano_llama)\n", "nl_param_count = nano_llama.get_num_params()\n", "nl_param_count#, nano_llama" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"5b6826fb85254310a1c4f6c2797513b7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/500 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)\n", "train_model(nano_llama, optimizer, num_iters=500,device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Max grad norm (gradient clipping)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "184832" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=256,max_len=block_size, residual=True, normalise=True)\n", "nano_llama.to(device)\n", "nano_llama = torch.compile(nano_llama)\n", "nl_param_count = nano_llama.get_num_params()\n", "nl_param_count#, nano_llama" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1091239b764c43e19171313a1430d375", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)\n", "train_model(nano_llama, optimizer, num_iters=1000,device=device,max_grad_norm=1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Does a higher intermediate size matter?" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"094e260f7a1d454394eee03bfe67e170", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a5382b461cc04dae90733648bda2697f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7523d47cd5a5432b91e9d110bf9d7b8c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9a95d281bf064f009c7321473686291e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for intermediate_size in [32,64,128,256,512]:\n", " nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=intermediate_size,max_len=block_size, residual=True, normalise=True)\n", " nano_llama.to(device)\n", " nano_llama = torch.compile(nano_llama)\n", " nl_param_count = nano_llama.get_num_params()\n", " nl_param_count#, nano_llama\n", " optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)\n", " train_model(nano_llama, optimizer, num_iters=1000,device=device,max_grad_norm=1.0)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b2a0a38bd1c24eaaa90d5d4aee91b15b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a509fb51d6cd4572b5a1b724eeb71c7c", "version_major": 2, 
"version_minor": 0 }, "text/plain": [ "training: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f31e96c5b1404df5ae6a04232ead0984", "version_major": 2, "version_minor": 0 }, "text/plain": [ "training: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for intermediate_size in [1024,2048,4096]:\n", " nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=intermediate_size,max_len=block_size, residual=True, normalise=True)\n", " nano_llama.to(device)\n", " nano_llama = torch.compile(nano_llama)\n", " nl_param_count = nano_llama.get_num_params()\n", " nl_param_count#, nano_llama\n", " optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)\n", " train_model(nano_llama, optimizer, num_iters=1000,device=device,max_grad_norm=1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Onto Big Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Dgv8Nza3BARd", "outputId": "f9c43ea5-439f-4dd7-8565-e09c257d64ce" }, "outputs": [], "source": [ "nano_llama_big = NanoLlamaForCausalLM(n_layers=12, vocab_size=128, hidden_size=512, n_attn_heads=16, n_kv_heads=8, intermediate_size=2048, residual=True, normalise=True)\n", "nl_param_count = nano_llama_big.get_num_params()\n", "nano_llama_big = torch.compile(nano_llama_big)\n", "nano_llama_big.to(device)\n", "print(f'Param count {nl_param_count} aka {nl_param_count/(10**6)} million params')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 522, "referenced_widgets": [ "803e8c5bb65c4c77887c5cc16e51ab9c", "60193095d4964169a1575d697fed6ad6", "4435c469978c4b0d9d726c4686363c62", "5db8580eca95451e9fadfd6e0a10ec4f", 
"296f3088d0844528a45bf36e21961348", "a89c511b42964247a67e8c1348abb792", "e9ff09f027564edbb43801c6dca01eb6", "1197b685eacc4c2daf56e75e8dcf088b", "7388c6af930a40b887f13b2aef17b6b5", "0ce9c707c0ac4ebd8fac23c8f318bfd8", "36ce14d7573a4d8cb8379808e0d5552b" ] },
 "id": "jQC0rXWXxllE", "outputId": "e7afe0d4-89b3-4a52-9886-ed9683dafb63" },
 "outputs": [],
 "source": [
  "# `device` is either torch.device('cuda') or the string 'cpu' (see the config cell); normalise it to 'cuda'/'cpu'.\n",
  "device_type = torch.device(device).type\n",
  "torch.cuda.empty_cache()\n",
  "# NOTE(review): positional args presumably (weight_decay, learning_rate, betas, device_type) -- confirm against configure_optimizers.\n",
  "optimizer = nano_llama_big.configure_optimizers(0.1, 0.001, (0.9, 0.99), device_type)\n",
  "# Only the training loop runs under bf16 autocast; building the optimizer does not need it.\n",
  "with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):\n",
  "    train_model(nano_llama_big, optimizer, num_iters=2000, device=device, eval_steps=200, lr_decay=True, batch_size=64, max_grad_norm=1.0, train_dtype=torch.bfloat16)"
 ] },
{ "cell_type": "code", "execution_count": null,
 "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "D87ILOfwxoZ7", "outputId": "00d8feab-a473-4d60-dda0-f6cfa48fdac4" },
 "outputs": [],
 "source": [
  "text = '\\n looking for clues. '\n",
  "encoded_text = encode(text)\n",
  "tensor_input = torch.tensor([encoded_text]).to(device).reshape(1,-1) # to adjust for the lack of batch\n",
  "nano_llama_big.eval()  # inference mode for any train-only layers (e.g. dropout)\n",
  "with torch.no_grad():\n",
  "    out_tokens = nano_llama_big.generate(tensor_input, max_new_tokens=200, temperature=1)\n",
  "decoded_text = decode(out_tokens[0])\n",
  "print(''.join(decoded_text[len(text):]))"
 ] },
{ "cell_type": "markdown", "metadata": {},
 "source": [
  "### Baseline: untrained model with the same architecture\n",
  "\n",
  "Generate from the same prompt with randomly initialised weights, for comparison with the trained output above."
 ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [],
 "source": [
  "random_nano_llama_big = NanoLlamaForCausalLM(n_layers=12, vocab_size=128, hidden_size=512, n_attn_heads=16, n_kv_heads=8, intermediate_size=2048, max_len=block_size, residual=True, normalise=True)\n",
  "random_nano_llama_big.to(device)\n",
  "random_nano_llama_big.eval();"
 ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [],
 "source": [
  "# Reuses `text` / `tensor_input` from the trained-model generation cell above.\n",
  "with torch.no_grad():\n",
  "    random_out_tokens = random_nano_llama_big.generate(tensor_input, max_new_tokens=200, temperature=1)\n",
  "decoded_text2 = decode(random_out_tokens[0])\n",
  "print(''.join(decoded_text2[len(text):]))"
 ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Load a local checkpoint with `transformers`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [],
 "source": [
  "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
  "\n",
  "# Machine-specific default location; override with the SPT_PATH environment variable.\n",
  "spt_path = os.environ.get('SPT_PATH', '/home/datta0/spt')\n",
  "# trust_remote_code=True executes Python code shipped with the checkpoint -- only use it with trusted sources.\n",
  "model = AutoModelForCausalLM.from_pretrained(spt_path, trust_remote_code = True)\n",
  "tokenizer = AutoTokenizer.from_pretrained(spt_path, trust_remote_code = True)"
 ] }
],
"metadata": {
 "accelerator": "GPU",
 "colab": { "gpuType": "T4", "provenance": [] },
 "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" },
 "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" },
 "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } }
},
"nbformat": 4,
"nbformat_minor": 4
}