In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import math,os,requests, random
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import inspect
from torch.cuda.amp import autocast, GradScaler
from time import time

In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
batch_size = 512
block_size = 256

## Text input

In [3]:
path = os.path.join('/home/datta0/','cano.txt')
if not os.path.isfile(path):
 response = requests.get("https://sherlock-holm.es/stories/plain-text/cano.txt")
 if response.status_code == 200:
 # Save the content to a local file
 with open(path, "w", encoding="utf-8") as file:
 file.write(response.text)
 print("File downloaded successfully.")
 else:
 print(f"Failed to download file. Status code: {response.status_code}")

with open(path, 'r') as f:
 total_text = f.read()


## Process data

In [4]:
all_characters = sorted(list(set([x for x in total_text])))
char_to_idx = {char:idx for idx, char in enumerate(all_characters)}
idx_to_char = {idx:char for idx, char in enumerate(all_characters)}

In [5]:
char_to_idx

{'\n': 0,
 ' ': 1,
 '!': 2,
 '"': 3,
 '&': 4,
 "'": 5,
 '(': 6,
 ')': 7,
 '*': 8,
 ',': 9,
 '-': 10,
 '.': 11,
 '0': 12,
 '1': 13,
 '2': 14,
 '3': 15,
 '4': 16,
 '5': 17,
 '6': 18,
 '7': 19,
 '8': 20,
 '9': 21,
 ':': 22,
 ';': 23,
 '?': 24,
 'A': 25,
 'B': 26,
 'C': 27,
 'D': 28,
 'E': 29,
 'F': 30,
 'G': 31,
 'H': 32,
 'I': 33,
 'J': 34,
 'K': 35,
 'L': 36,
 'M': 37,
 'N': 38,
 'O': 39,
 'P': 40,
 'Q': 41,
 'R': 42,
 'S': 43,
 'T': 44,
 'U': 45,
 'V': 46,
 'W': 47,
 'X': 48,
 'Y': 49,
 'Z': 50,
 '[': 51,
 ']': 52,
 '`': 53,
 'a': 54,
 'b': 55,
 'c': 56,
 'd': 57,
 'e': 58,
 'f': 59,
 'g': 60,
 'h': 61,
 'i': 62,
 'j': 63,
 'k': 64,
 'l': 65,
 'm': 66,
 'n': 67,
 'o': 68,
 'p': 69,
 'q': 70,
 'r': 71,
 's': 72,
 't': 73,
 'u': 74,
 'v': 75,
 'w': 76,
 'x': 77,
 'y': 78,
 'z': 79,
 '£': 80,
 '°': 81,
 'ß': 82,
 'à': 83,
 'â': 84,
 'è': 85,
 'é': 86,
 'ê': 87,
 'î': 88,
 'ñ': 89,
 'ô': 90,
 'ö': 91,
 'û': 92,
 'ü': 93,
 '’': 94}

In [6]:
def encode(text):
 return [char_to_idx.get(x) for x in text]
def decode(indices):
 return [idx_to_char.get(idx) for idx in indices]

def batch_encode(batch):
 batch_encoded = []
 for text in batch:
 batch_encoded.append(encode(text))
 return batch_encoded

def batch_decode(batch):
 batch_decoded = []
 for indices in batch:
 batch_decoded.append(decode(indices))
 return batch_decoded

In [7]:
total_text_len = len(total_text)

train_len = int(0.9*total_text_len)
encoded_train_text = np.array(encode(total_text[:train_len]))
val_len = total_text_len - train_len
encoded_val_text = np.array(encode(total_text[train_len:]))

def get_batch(encoded_text, batch_size,total_len):
 ix = torch.randint(total_len - block_size, (batch_size,)) # get random starting indices 
 x = torch.stack([torch.from_numpy((encoded_text[i:i+block_size]).astype(np.int64)) for i in ix])
 y = torch.stack([torch.from_numpy((encoded_text[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
 x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
 return x,y

def get_data(split,batch_size=512):
 if split=='train':
 return get_batch(encoded_train_text, batch_size, train_len)
 else:
 return get_batch(encoded_val_text, batch_size, val_len)

In [8]:
a = get_data('train',1)
b = batch_decode(a[0].tolist())
a,b

((tensor([[65, 65, 1, 66, 58, 1, 73, 61, 54, 73, 1, 72, 61, 58, 1, 62, 72, 0,
 62, 67, 1, 36, 68, 67, 57, 68, 67, 9, 1, 55, 74, 73, 1, 54, 72, 1,
 76, 58, 1, 61, 54, 75, 58, 1, 54, 73, 1, 69, 71, 58, 72, 58, 67, 73,
 1, 67, 68, 1, 69, 68, 72, 72, 62, 55, 65, 58, 1, 66, 58, 54, 67, 72,
 1, 68, 59, 1, 73, 58, 65, 65, 62, 67, 60, 0, 76, 61, 58, 71, 58, 9,
 1, 76, 58, 1, 56, 54, 67, 1, 68, 67, 65, 78, 1, 73, 54, 64, 58, 1,
 73, 61, 58, 1, 68, 55, 75, 62, 68, 74, 72, 1, 72, 73, 58, 69, 72, 9,
 1, 58, 54, 73, 1, 68, 74, 71, 1, 57, 62, 67, 67, 58, 71, 9, 1, 54,
 67, 57, 0, 69, 68, 72, 72, 58, 72, 72, 1, 68, 74, 71, 1, 72, 68, 74,
 65, 72, 1, 62, 67, 1, 69, 54, 73, 62, 58, 67, 56, 58, 11, 1, 36, 54,
 73, 58, 71, 1, 62, 67, 1, 73, 61, 58, 1, 58, 75, 58, 67, 62, 67, 60,
 1, 33, 1, 76, 62, 65, 65, 1, 72, 73, 71, 68, 65, 65, 0, 57, 68, 76,
 67, 1, 54, 67, 57, 1, 61, 54, 75, 58, 1, 54, 1, 76, 68, 71, 57, 1,
 76, 62, 73, 61, 1, 59, 71, 62, 58, 67, 57, 1, 36, 58, 72, 73, 71, 54,
 57, 58, 1, 54]], dev

In [9]:
def repeat_kv(hidden_states, repeat_times):
 if repeat_times == 1:
 return hidden_states
 batch, n_kv_heads, seq_len, head_dim = hidden_states.shape #Shape of q aka Wq@x
 hidden_states = hidden_states[:,:,None,:,:].expand(batch, n_kv_heads, repeat_times, seq_len, head_dim) #
 return hidden_states.reshape(batch, n_kv_heads*repeat_times, seq_len, head_dim)

## Model Architecture

In [10]:
class RMSNorm(nn.Module):
 def __init__(self, hidden_size, eps=1e-6):
 super().__init__()
 self.weight = nn.Parameter(torch.ones(hidden_size))
 self.variance_epsilon = eps

 def forward(self, hidden_states):
 input_dtype = hidden_states.dtype
 hidden_states = hidden_states.to(torch.float32)
 variance = hidden_states.pow(2).mean(-1, keepdim=True)
 hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 return self.weight * hidden_states.to(input_dtype)

class PositionalEncoding(nn.Module):
 def __init__(self, hidden_size, max_seq_len):
 super().__init__()
 self.encoding = torch.zeros(max_seq_len, hidden_size)
 position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
 div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size))
 self.encoding[:, 0::2] = torch.sin(position * div_term)
 self.encoding[:, 1::2] = torch.cos(position * div_term)
 self.encoding = self.encoding.unsqueeze(0)

 def forward(self, x):
 if self.encoding.device != x.device:
 self.encoding = self.encoding.to(x.device)
 return x + self.encoding[:, :x.size(1)].detach()

![image.png](attachment:image.png)

In [11]:
class Attention(nn.Module):
 def __init__(self,n_attn_heads,n_kv_heads,hidden_size,max_len=256):
 super().__init__()

 assert hidden_size%n_attn_heads==0
 assert n_attn_heads%n_kv_heads==0

 self.head_dim = hidden_size // n_attn_heads
 kv_size = n_kv_heads * self.head_dim
 self.hidden_size = hidden_size
 self.n_attn_heads = n_attn_heads
 self.n_kv_heads = n_kv_heads

 self.q = nn.Linear(hidden_size, hidden_size, bias=False) #WQ
 self.k = nn.Linear(hidden_size, kv_size, bias=False) #WK
 self.v = nn.Linear(hidden_size, kv_size, bias=False) #WV

 self.register_buffer('tril',torch.tril(torch.ones(max_len,max_len)).view(1,1,max_len,max_len))

 def forward(self, x, echo = False):

 batch_size, seq_len, hidden_dim = x.shape

 #Pass the inputs through QKV matrices
 q = self.q(x) #Q
 k = self.k(x) #K
 v = self.v(x) #V

 q = q.view(batch_size, seq_len, self.n_attn_heads, self.head_dim).transpose(1, 2)
 k = k.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
 v = v.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

 # If n_attn_heads!=n_kv_heads, we need to repeat the same computation on same n_kv_heads for n_attn_heads/n_kv_heads times
 # So we just repeat n_kv_heads to match n_attn_heads size
 k = repeat_kv(k, self.n_attn_heads//self.n_kv_heads)
 v = repeat_kv(v, self.n_attn_heads//self.n_kv_heads)

 attention = (q @ k.transpose(-2,-1)) * (1.0/math.sqrt(self.hidden_size))
 attention = attention.masked_fill(self.tril[:,:,:seq_len,:seq_len]==0, float('-inf'))
 probs = nn.functional.softmax(attention,dim=-1)
 y = probs@v
 y = y.transpose(1,2).contiguous().reshape(batch_size, seq_len, -1)


 return y


In [12]:
class MLP(nn.Module):
 def __init__(self, hidden_size, intermediate_size,):
 super().__init__()

 self.hidden_size = hidden_size
 self.intermediate_size = intermediate_size

 self.up = nn.Linear(hidden_size, intermediate_size, bias=False)
 self.gate = nn.Linear(hidden_size, intermediate_size, bias=False)
 self.down = nn.Linear(intermediate_size, hidden_size, bias=False)
 self.act_fn = nn.GELU()

 def forward(self,x):

 up = self.up(x)
 gate = self.gate(x)

 # note that * in torch is element wise multiplication. The two operands need to be of same size.
 return self.down(self.act_fn(up * gate))


In [13]:
class TransformerBlock(nn.Module):
 def __init__(self, hidden_size, n_attn_heads, n_kv_heads, intermediate_size,max_len, residual=True):
 super().__init__()

 self.attn = Attention(n_attn_heads,n_kv_heads,hidden_size,max_len)
 self.mlp = MLP(hidden_size, intermediate_size)
 self.residual = residual
 self.norm = nn.LayerNorm(hidden_size)

 def forward(self, x, normalise):

 if normalise:
 normalised_x = self.norm(x)
 attn_out = self.attn(normalised_x)
 else:
 attn_out = self.attn(x)

 if self.residual:
 attn_out = x + attn_out

 if normalise:
 normalised_x = self.norm(attn_out)
 mlp_out = self.mlp(normalised_x)
 else:
 mlp_out = self.mlp(attn_out)

 if self.residual:
 mlp_out = attn_out + mlp_out

 return mlp_out

In [14]:
class NanoLlama(nn.Module):

 def __init__(self,n_layers, vocab_size, hidden_size, n_attn_heads, n_kv_heads, intermediate_size,max_len, residual, normalise=True):
 super().__init__()
 self.embedding = nn.Embedding(vocab_size, hidden_size)
 self.n_layers = n_layers
 self.layers = nn.ModuleList(
 [TransformerBlock(hidden_size, n_attn_heads, n_kv_heads, intermediate_size, max_len, residual) for _ in range(n_layers)]
 )
 self.normalise = normalise
 self.norm = nn.LayerNorm(hidden_size)

 def forward(self,x):

 x = self.embedding(x)
 for layer in self.layers:
 x = layer(x, self.normalise)

 if self.normalise:
 x = self.norm(x)

 return x


In [15]:
class NanoLlamaForCausalLM(nn.Module):

 def __init__(self,n_layers, vocab_size, hidden_size, n_attn_heads, n_kv_heads, intermediate_size,max_len=256, residual=True, normalise=True):
 super().__init__()

 self.model = NanoLlama(n_layers, vocab_size, hidden_size, n_attn_heads, n_kv_heads, intermediate_size,max_len, residual,normalise)
 self.lm_head = nn.Linear(hidden_size,vocab_size, bias=False)
 self.max_len = max_len
 self.n_layers = n_layers
 self.n_attn_heads = n_attn_heads
 self.n_kv_heads = n_kv_heads
 self.hidden_dim = hidden_size


 # Apply Kaiming uniform initialization to the weights of the linear layers
 for m in self.modules():
 if isinstance(m, nn.Linear):
 nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')

 def forward(self,input_ids,targets=None):
 x = self.model(input_ids)

 if targets is not None:
 logits = self.lm_head(x)
 loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
 else:
 # we only need to pass the last token's outputs through lm_head. Rest we ignore.
 logits = self.lm_head(x[:, [-1], :])
 loss = None

 return logits,loss

 @torch.no_grad()
 def generate(self, input_ids, max_new_tokens=20, temperature=1.0, sample = False):

 if input_ids.device!=self.model.embedding.weight.device:
 input_ids = input_ids.to(self.model.embedding.weight.device)

 assert max_new_tokens>0
 assert temperature>0

 tokens_generated = 0
 while True:
 logits,loss = self.forward(input_ids)
 final_token_logits = logits[:,-1,:]
 if not sample:
 next_token = torch.argmax(final_token_logits) # Return the token with max prob

 #Sample from multinomial distribution with probabilities calculated from logits
 final_token_logits = final_token_logits/temperature # scale by temperature
 probs = nn.functional.softmax(final_token_logits, dim=-1)
 next_token = torch.multinomial(probs, num_samples=1)

 tokens_generated += 1
 input_ids = torch.cat((input_ids, next_token), dim=1) # Add next token ID to input_ids for generating further tokens
 # print(f'input ids shape {input_ids.shape[-1]}, {max_new_tokens}')
 if input_ids.shape[-1]>=min(self.max_len,max_new_tokens)-2:
 break

 del logits # delete the logits to save memory

 return input_ids.cpu().numpy()
 
 def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
 # start with all of the candidate parameters
 param_dict = {pn: p for pn, p in self.named_parameters()}
 # filter out those that do not require grad
 param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
 # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
 # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
 decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
 nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
 optim_groups = [
 {'params': decay_params, 'weight_decay': weight_decay},
 {'params': nodecay_params, 'weight_decay': 0.0}
 ]
 num_decay_params = sum(p.numel() for p in decay_params)
 num_nodecay_params = sum(p.numel() for p in nodecay_params)
 print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
 print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
 # Create AdamW optimizer and use the fused version if it is available
 fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
 use_fused = fused_available and device_type == 'cuda'
 extra_args = dict(fused=True) if use_fused else dict()
 optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
 print(f"using fused AdamW: {use_fused}")

 return optimizer
 
 def get_num_params(self):
 n_params = sum(p.numel() for p in self.parameters())
 return n_params


## Training Setup

In [16]:
def get_lr(it, warmup_iters, lr_decay_iters, min_lr, learning_rate):
 if it < warmup_iters:
 return learning_rate * it / warmup_iters
 # 2) if it > lr_decay_iters, return min learning rate
 if it > lr_decay_iters:
 return min_lr
 # 3) in between, use cosine decay down to min learning rate
 decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
 assert 0 <= decay_ratio <= 1
 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
 return min_lr + coeff * (learning_rate - min_lr)

@torch.no_grad()
def run_model_and_get_loss(model, steps=100):

 model.eval()
 phase_wise_loss = {}
 for phase in ['train', 'val']:
 losses = []
 for _ in range(steps):
 input_ids, labels = get_data(phase)
 _, loss = model(input_ids, labels)
 losses.append(loss.item())
 phase_wise_loss[phase] = np.mean(losses)
 model.train()

 return phase_wise_loss['train'], phase_wise_loss['val']

def train_model(model, optimizer, num_iters, device, accumulation_steps=1, eval_steps=100, lr_decay=False, batch_size=512, max_grad_norm=-1,train_dtype=torch.float32):
 
 train_losses = []
 val_losses = []
 model = model.to(device)
 
 warmup_iters = 100
 lr_decay_iters = num_iters
 min_lr = 1e-4
 start_lr = 1e-3

 scaler = GradScaler(enabled=(train_dtype!=torch.float32)) # to make sure grads are in FP32 even for BF/FP16 trainig

 inputs, labels = get_data('train', batch_size)

 for iter_num in tqdm(range(num_iters),'training'):

 start_time = time()

 lr = get_lr(iter_num, warmup_iters, lr_decay_iters, min_lr, start_lr) if lr_decay else start_lr
 for param_group in optimizer.param_groups:
 param_group['lr'] = lr
 
 for _ in range(accumulation_steps):
 with torch.autocast(dtype=train_dtype, device_type='cuda'):
 _, loss = model(inputs, labels)
 loss = loss / accumulation_steps

 scaler.scale(loss).backward()
 
 if max_grad_norm != -1:
 scaler.unscale_(optimizer)
 torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

 inputs, labels = get_data('train',batch_size)

 scaler.step(optimizer)
 scaler.update()
 optimizer.zero_grad(set_to_none=True)

 end_time = time()
 
 if (iter_num + 1) % eval_steps == 0:
 with torch.autocast(dtype=train_dtype, device_type='cuda'):
 train_loss, val_loss = run_model_and_get_loss(model)
 train_losses.append(train_loss)
 val_losses.append(val_loss)

 print(f'Iteration: {iter_num + 1}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}')
 
 plt.plot(train_losses, label='train_loss')
 plt.plot(val_losses, label='val_loss')
 plt.xlabel('Iterations')
 plt.ylabel('Loss')
 plt.legend()
 plt.title('Training and Validation Loss')
 plt.show()

## Smol models

In [17]:
nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=256,max_len=block_size, residual=False, normalise=False)
nano_llama.get_num_params()

184832

In [18]:
nano_llama = nano_llama.to(device)
nano_llama = torch.compile(nano_llama)

In [19]:
optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

In [20]:
train_model(nano_llama, optimizer, num_iters=1000,device=device)

training: 0%| | 0/1000 [00:00

### Residual

![alt text](image.png)

In [21]:
nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=256,max_len=block_size, residual=True, normalise=False)
nano_llama.to(device)
nano_llama = torch.compile(nano_llama)
nl_param_count = nano_llama.get_num_params()
nl_param_count#, nano_llama

184832

In [22]:
optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)
train_model(nano_llama, optimizer, num_iters=500,device=device)

training: 0%| | 0/500 [00:00

### Nomralise

In [23]:
nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=256,max_len=block_size, residual=False, normalise=True)
nano_llama.to(device)
nano_llama = torch.compile(nano_llama)
nl_param_count = nano_llama.get_num_params()
nl_param_count#, nano_llama

184832

In [24]:
optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)
train_model(nano_llama, optimizer, num_iters=500,device=device)

training: 0%| | 0/500 [00:00

### Both

In [25]:
nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=256,max_len=block_size, residual=True, normalise=True)
nano_llama.to(device)
nano_llama = torch.compile(nano_llama)
nl_param_count = nano_llama.get_num_params()
nl_param_count#, nano_llama

184832

In [26]:
optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)
train_model(nano_llama, optimizer, num_iters=500,device=device)

training: 0%| | 0/500 [00:00

### max grad norm

In [27]:
nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=256,max_len=block_size, residual=True, normalise=True)
nano_llama.to(device)
nano_llama = torch.compile(nano_llama)
nl_param_count = nano_llama.get_num_params()
nl_param_count#, nano_llama

184832

In [28]:
optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)
train_model(nano_llama, optimizer, num_iters=1000,device=device,max_grad_norm=1.0)

training: 0%| | 0/1000 [00:00

### Does Higher intermediate size matter

In [29]:
for intermediate_size in [32,64,128,256,512]:
 nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=intermediate_size,max_len=block_size, residual=True, normalise=True)
 nano_llama.to(device)
 nano_llama = torch.compile(nano_llama)
 nl_param_count = nano_llama.get_num_params()
 nl_param_count#, nano_llama
 optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)
 train_model(nano_llama, optimizer, num_iters=1000,device=device,max_grad_norm=1.0)

training: 0%| | 0/1000 [00:00

training: 0%| | 0/1000 [00:00

training: 0%| | 0/1000 [00:00

training: 0%| | 0/1000 [00:00

training: 0%| | 0/1000 [00:00

In [30]:
for intermediate_size in [1024,2048,4096]:
 nano_llama = NanoLlamaForCausalLM(n_layers=3, vocab_size=len(all_characters)+1, hidden_size=64, n_attn_heads=4, n_kv_heads=2, intermediate_size=intermediate_size,max_len=block_size, residual=True, normalise=True)
 nano_llama.to(device)
 nano_llama = torch.compile(nano_llama)
 nl_param_count = nano_llama.get_num_params()
 nl_param_count#, nano_llama
 optimizer = torch.optim.AdamW(nano_llama.parameters(), lr=0.0001)
 train_model(nano_llama, optimizer, num_iters=1000,device=device,max_grad_norm=1.0)

training: 0%| | 0/1000 [00:00

training: 0%| | 0/1000 [00:00

training: 0%| | 0/1000 [00:00

## Onto Big Model

In [None]:
nano_llama_big = NanoLlamaForCausalLM(n_layers=12, vocab_size=128, hidden_size=512, n_attn_heads=16, n_kv_heads=8, intermediate_size=2048, residual=True, normalise=True)
nl_param_count = nano_llama_big.get_num_params()
nano_llama_big = torch.compile(nano_llama_big)
nano_llama_big.to(device)
print(f'Param count {nl_param_count} aka {nl_param_count/(10**6)} million params')

In [None]:
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
 torch.cuda.empty_cache()
 optimizer = nano_llama_big.configure_optimizers(0.1,0.001,(0.9,0.99),'cuda')
 train_model(nano_llama_big, optimizer, num_iters=2000,device=device, eval_steps = 200,lr_decay=True, batch_size = 64, max_grad_norm=1.0, train_dtype=torch.bfloat16)

In [None]:
text = '\n looking for clues. '
encoded_text = encode(text)
tensor_input = torch.tensor([encoded_text]).to(device).reshape(1,-1) # to adjust for the lack of batch
out_tokens = nano_llama_big.generate(tensor_input, max_new_tokens=200,temperature=1)
decoded_text = decode(out_tokens[0])
print(''.join(decoded_text[len(text):]))

In [None]:
random_nano_llama_big = NanoLlamaForCausalLM(n_layers=12, vocab_size=128, hidden_size=512, n_attn_heads=16, n_kv_heads=8, intermediate_size=2048, residual=True, normalise=True)

In [None]:
decoded_text2

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained('/home/datta0/spt', trust_remote_code = True)
tokenizer = AutoTokenizer.from_pretrained('/home/datta0/spt', trust_remote_code = True)