abhishek4607 committed on
Commit e97f4e2 · verified · 1 Parent(s): 30208da

Upload 16 files

.gitattributes ADDED
@@ -0,0 +1 @@
loss_eval.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
__pycache__/
.DS_Store

data/
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Saqib Azim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md ADDED
@@ -0,0 +1,136 @@
# GPT-2 Implementation in PyTorch

This project reproduces the GPT-2 model in PyTorch and trains it from scratch on the FineWeb-Edu dataset, a high-quality subset of the FineWeb dataset tailored for educational content. The goal is to offer a simplified, easy-to-understand PyTorch implementation. Note that this code is intended primarily for educational purposes and is not optimized for speed or production deployment.

### Key Features
- **Simplified PyTorch Implementation:** Designed to be accessible and well-commented for ease of understanding.
- **Customizable Training:** Hyperparameters are configurable via the command line and can be easily modified.
- **Multi-GPU Training Support:** Training can be performed on multiple GPUs using PyTorch Distributed Data Parallel (DDP).


## Repository Structure
- `src/train.py`: Script to train the GPT-2 model with customizable configurations.
- `src/model.py`: Contains the GPT-2 model implementation, including embedding layers, transformer blocks, and output layers.
- `src/dataloader.py`: Handles data loading and batching for the model during training.
- `src/prepare_dataset.py`: Downloads and preprocesses the FineWeb-Edu dataset. Run this script before starting the training process.
- `requirements.txt`: Python dependencies required to run the project.


## Getting Started

### Prerequisites
Ensure you have the following dependencies installed:

- numpy
- pytorch
- tiktoken
- transformers (from Hugging Face)

You can install all dependencies with:
```bash
pip install -r requirements.txt
```

## Dataset

The GPT-2 model was originally trained on the WebText dataset (not publicly released). For this project, we use the FineWebEdu-10B dataset, a specialized educational subset of the FineWeb dataset. It contains approximately 10 billion tokens focused on high-quality educational content.

To download and prepare the dataset:
```bash
python prepare_dataset.py
```

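Once the shards are in place, `src/dataloader.py` streams them as fixed-size `(B, T)` token batches, with the targets being the same token stream shifted by one position. Below is a minimal, self-contained illustration of that slicing (the toy shard here is just a stand-in for a real `.npy` token shard):

```python
import numpy as np
import torch

B, T = 4, 32                                    # batch size and sequence length (illustrative values)
shard = np.arange(B * T + 1, dtype=np.int32)    # stand-in for one token shard loaded from disk
tokens = torch.tensor(shard, dtype=torch.long)

buf = tokens[: B * T + 1]       # take B*T + 1 tokens from the current shard position
x = buf[:-1].view(B, T)         # model inputs
y = buf[1:].view(B, T)          # targets: the same stream shifted left by one token
```
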
### Running the Training Script
You can start training the GPT-2 model using the commands below; different training and model configuration hyperparameters can be set through the command line.

- Single-GPU Training:
```bash
python train.py --num_epochs=5
```

- Multi-GPU Training (uses PyTorch DDP):
```bash
torchrun --standalone --nproc_per_node=4 train.py  # adjust the number of GPUs as per availability
```

For more details on the training process and customizing hyperparameters, refer to the `src/train.py` script.

Training was performed from scratch using multiple GPUs with PyTorch's DDP framework.
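
One detail worth understanding when tuning the batch-size flags is how `src/train.py` derives the number of gradient-accumulation steps from them. A small sketch of that arithmetic, using the script's default values and assuming the 4-GPU `torchrun` launch shown above:

```python
# Mirrors the gradient-accumulation arithmetic in src/train.py (defaults shown).
total_batch_size = 524288      # 2**19 tokens per weight update
mini_batch_size = 32           # sequences per forward/backward pass on one GPU
context_length = 1024          # tokens per sequence
ddp_world_size = 4             # e.g. torchrun --nproc_per_node=4 (assumed here)

assert total_batch_size % (mini_batch_size * context_length * ddp_world_size) == 0
grad_accum_steps = total_batch_size // (mini_batch_size * context_length * ddp_world_size)
print(grad_accum_steps)        # 4 mini-batches are accumulated before each optimizer step
```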

### Inference

After training the model, you can generate text based on custom prompts. Use the `src/inference.py` script to interact with the trained model and generate creative continuations.

Run the inference script from the command line with the following syntax:
```bash
python3 inference.py --prompt="I am an AI and robotics enthusiast, I want to" --max_tokens=50 --num_seq=5
```

This command will output 5 unique text sequences, each starting with the provided prompt and continuing for up to 50 tokens.
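
Generation follows the top-k sampling procedure implemented in `src/inference.py` (and reused during training in `src/train.py`): take the softmax over the logits at the last position, keep the 50 most likely tokens, and sample one of them. A condensed sketch of that step:

```python
import torch
import torch.nn.functional as F

def sample_top_k(logits, k=50, generator=None):
    """Pick the next token for each sequence from the k most likely candidates."""
    probs = F.softmax(logits[:, -1, :], dim=-1)               # (num_seq, vocab_size)
    topk_probs, topk_indices = torch.topk(probs, k, dim=-1)   # (num_seq, k)
    ix = torch.multinomial(topk_probs, num_samples=1, generator=generator)  # sample within top-k
    return torch.gather(topk_indices, -1, ix)                 # (num_seq, 1) next-token ids
```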

### Model Architecture
The GPT-2 model consists of the following components:

- **Token Embedding Layer:** Encodes input tokens into dense vectors.
- **Positional Embedding Layer:** Adds positional information to the token embeddings.
- **Transformer Blocks:** Each block includes layer normalization, multi-headed causal self-attention, and an MLP, with residual connections around both.
- **Output Head:** Predicts the next token in the sequence based on the preceding context.

The model is trained to predict the next token in a sequence, enabling coherent text generation. Tokenization uses OpenAI's `tiktoken` library with the GPT-2 encoding, which has a vocabulary of 50,257 tokens (same as GPT-2).

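As a quick illustration of how these pieces fit together, here is a minimal sketch (assuming it is run from the `src/` directory, where `model.py` lives) that tokenizes a prompt with the GPT-2 BPE vocabulary and runs one forward pass. The model here is randomly initialized, so the predicted continuation will be gibberish; it only demonstrates the shapes involved.

```python
import tiktoken
import torch
from model import GPT, GPTConfig   # assumes the working directory is src/

enc = tiktoken.get_encoding("gpt2")
print(enc.n_vocab)                   # 50257 = 50,000 BPE merges + 256 byte tokens + <|endoftext|>

tokens = enc.encode("Hello, I am a language model")
x = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)   # (1, T)

model = GPT(GPTConfig())             # randomly initialized 124M-parameter configuration
logits, loss = model(x)              # logits: (1, T, vocab_size); loss is None without targets
next_token = logits[0, -1].argmax()  # greedy pick of the next token
print(enc.decode([next_token.item()]))
```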

### Results

The GPT-2 model was trained for roughly 95,365 steps (5 epochs) using two NVIDIA A100 GPUs. Training took approximately 46 hours.

![Training loss and HellaSwag evaluation](./assets/loss_eval.png)

To generate from the trained model, we provide an input prompt sequence and ask the model to generate the next N tokens. Here are some samples of text generated by the trained model:

- **Prompt text:** "Hello, I am a language model"
- **Model output:**
```
- Hello, I am a language modeler. I use the API, in whatever language I require it to write out. On first, I define a model for

- Hello, I am a language model expert and need help with building these model. The project is designed in C++ and the Python library is used. The project

- Hello, I am a language model developer at Google Cloud. It has great features on most platforms which makes it one of most popular. It also integrates with third
```

- **Prompt text:** "I am a machine learning and robotics enthusiast, and I want to"
- **Model output:**
```
- I am a machine learning and robotics enthusiast, and I want to share my excitement about this work as soon as possible.
The purpose of this project was to help the engineers and programmers understand how the HURD and AVR circuits work and how

- I am a machine learning and robotics enthusiast, and I want to try and train a new machine learning-based system such as a deep learning algorithm that is completely new to me.

- I am a machine learning and robotics enthusiast, and I want to help you by helping you improve your Python programming skills.To understand the concept of machine learning, you must understand the concept of a machine learning model. Machine learning models

- I am a machine learning and robotics enthusiast, and I want to be a part of the team.<|endoftext|>In your next project, you need to gather some interesting information from your team team. This data will help form a map that you can use to

- I am a machine learning and robotics enthusiast, and I want to create a new, more sophisticated machine learning-based library for programming languages. To start, I am interested in the machine learning (ML) capabilities of new AI methods and techniques.
```


## Potential Future Work

1. **Dataset Shuffling:** The current training code does not shuffle the dataset after each epoch. Implementing dataset shuffling between epochs could improve the model's ability to generalize and prevent overfitting to the order of the training data (a minimal sketch follows this list).

2. **Extended Training:** Experiment with training the model for more epochs to potentially improve performance. Monitor validation loss to determine the optimal number of epochs and implement early stopping if necessary.

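A hypothetical sketch of per-epoch shard shuffling (not implemented in this repository): permute the order in which `DataLoaderLite` visits its shards with an epoch-dependent seed, so each epoch streams the data in a different order.

```python
import numpy as np

def shuffled_shards(shard_filepaths, epoch, seed=1337):
    """Return the shard paths in a new order that depends on the epoch number."""
    rng = np.random.default_rng(seed + epoch)
    order = rng.permutation(len(shard_filepaths))
    return [shard_filepaths[i] for i in order]

# Possible usage at the start of each epoch, before the training loop touches the loader:
# train_loader.shard_filepaths = shuffled_shards(train_loader.shard_filepaths, epoch)
# train_loader.reset()
```
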

## References
- [Language Models are Unsupervised Multitask Learners (GPT-2 paper)](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
- [Language Models are Few-Shot Learners (GPT-3 paper)](https://arxiv.org/abs/2005.14165)
- [FineWebEdu-10B Dataset](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
- [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135)
- [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- [HellaSwag: Can a Machine Really Finish Your Sentence?](https://arxiv.org/abs/1905.07830)
- Andrej Karpathy's video tutorial on GPT


## Acknowledgments
This implementation is inspired by Andrej Karpathy’s tutorial and his approach to making complex AI concepts more accessible.
__init__.py ADDED
File without changes
dataloader.cpython-311.pyc ADDED
Binary file (4.39 kB).
 
dataloader.py ADDED
@@ -0,0 +1,49 @@
import os
import numpy as np
import torch

script_dir = os.path.dirname(__file__)


class DataLoaderLite:
    """ A simple dataloader for the FineWebEdu-10B dataset """

    def __init__(self, B, T, process_rank, num_processes, split='train'):
        super().__init__()
        self.B, self.T = B, T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        # get the shard filenames
        data_root = os.path.join(script_dir, "../data/edu_fineweb10B")
        shard_filenames = os.listdir(data_root)
        shard_filenames = sorted([filename for filename in shard_filenames if split in filename])
        self.shard_filepaths = [os.path.join(data_root, filename) for filename in shard_filenames]
        assert len(self.shard_filepaths) > 0, f'no shards found for split {split}'
        master_process = process_rank == 0
        if master_process:
            print(f'found {len(self.shard_filepaths)} shards for split {split}')
        self.reset()

    def load_tokens(self, filepath):
        tokens = torch.tensor(np.load(filepath).astype(np.int32), dtype=torch.long)
        return tokens

    def reset(self):
        # state, init at shard 0
        self.curr_shard = 0
        self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
        self.curr_pos = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        batch = self.tokens[self.curr_pos : self.curr_pos + B*T + 1]
        x_batch = batch[:-1].view(B, T)
        y_batch = batch[1:].view(B, T)
        self.curr_pos += B * T * self.num_processes
        if self.curr_pos + (B * T + 1) > len(self.tokens):
            self.curr_shard = (self.curr_shard + 1) % len(self.shard_filepaths)
            self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
            self.curr_pos = self.B * self.T * self.process_rank
        return x_batch, y_batch
hellaswag_eval.cpython-311.pyc ADDED
Binary file (12.8 kB).
 
hellaswag_eval.py ADDED
@@ -0,0 +1,197 @@
"""
Downloads and evaluates HellaSwag in Python.
https://github.com/rowanz/hellaswag

Example HellaSwag json item:

{"ind": 24, "activity_label": "Roof shingle removal", "ctx_a": "A man is sitting on a roof.", "ctx_b": "he", "ctx": "A man is sitting on a roof. he", "split": "val", "split_type": "indomain", "label": 3, "endings": ["is using wrap to wrap a pair of skis.", "is ripping level tiles off.", "is holding a rubik's cube.", "starts pulling up roofing on a roof."], "source_id": "activitynet~v_-JhWjGDPHMY"}

ind: dataset ID
activity_label: The ActivityNet or WikiHow label for this example
context: There are two formats. The full context is in ctx. When the context ends in an (incomplete) noun phrase, like for ActivityNet, this incomplete noun phrase is in ctx_b, and the context up until then is in ctx_a. This can be useful for models such as BERT that need the last sentence to be complete. However, it's never required. If ctx_b is nonempty, then ctx is the same thing as ctx_a, followed by a space, then ctx_b.
endings: a list of 4 endings. The correct index is given by label (0, 1, 2, or 3)
split: train, val, or test.
split_type: indomain if the activity label is seen during training, else zeroshot
source_id: Which video or WikiHow article this example came from

gpt2 (124M)
- eleuther harness reports acc 28.92%, acc_norm 31.14% (multiple choice style)
- this script: 10042 acc: 0.2859 acc_norm: 0.2955 (completion style)

gpt2-xl (1558M)
- eleuther harness reports acc 40.04%, acc_norm 50.89% (multiple choice style)
- this script: 10042 acc: 0.3842 acc_norm: 0.4893 (completion style)

The validation set of HellaSwag has a total of 10,042 examples.
"""

import os
import json
import requests
import tiktoken
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import GPT2LMHeadModel

# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "hellaswag")

def download_file(url: str, fname: str, chunk_size=1024):
    """Helper function to download a file from a given url"""
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as file, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = file.write(data)
            bar.update(size)

hellaswags = {
    "train": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
    "val": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
    "test": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl",
}

enc = tiktoken.get_encoding("gpt2")

def download(split):
    """Downloads HellaSwag to DATA_CACHE_DIR"""
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)
    data_url = hellaswags[split]
    data_filename = os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)

def render_example(example):
    """
    Given the example as a dictionary, render it as three torch tensors:
    - tokens (the tokens of context + completion, of size 4xN, as there are always 4 candidates)
    - mask (is 1 in the region of the candidate completion, where we evaluate likelihoods)
    - label (the index of the correct completion, which we hope has the highest likelihood)
    """
    ctx = example["ctx"]
    label = example["label"]
    endings = example["endings"]
    # data needed to reproduce this eval on the C side
    data = {
        "label": label,
        "ctx_tokens": None,
        "ending_tokens": [],
    }
    # gather up all the tokens
    ctx_tokens = enc.encode(ctx)
    data["ctx_tokens"] = ctx_tokens
    tok_rows = []
    mask_rows = []
    for end in endings:
        end_tokens = enc.encode(" " + end)  # note: prepending " " because GPT-2 tokenizer
        tok_rows.append(ctx_tokens + end_tokens)
        mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens))
        data["ending_tokens"].append(end_tokens)

    # have to be careful during the collation because the number of tokens in each row can differ
    max_len = max(len(row) for row in tok_rows)
    tokens = torch.zeros((4, max_len), dtype=torch.long)
    mask = torch.zeros((4, max_len), dtype=torch.long)
    for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):
        tokens[i, :len(tok_row)] = torch.tensor(tok_row)
        mask[i, :len(mask_row)] = torch.tensor(mask_row)
    return data, tokens, mask, label

def iterate_examples(split):
    # there are 10,042 examples in total in val
    download(split)
    with open(os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl"), "r") as f:
        for line in f:
            example = json.loads(line)
            yield example

@torch.no_grad()
def evaluate(model_type, device):
    torch.set_float32_matmul_precision('high')  # use tf32
    model = GPT2LMHeadModel.from_pretrained(model_type)
    model.to(device)
    # model = torch.compile(model)  # optionally torch compile the model
    num_correct_norm = 0
    num_correct = 0
    num_total = 0
    for example in iterate_examples("val"):
        data, tokens, mask, label = render_example(example)
        tokens = tokens.to(device)
        mask = mask.to(device)

        # get the logits
        logits = model(tokens).logits
        # evaluate the autoregressive loss at all positions
        shift_logits = (logits[..., :-1, :]).contiguous()
        shift_tokens = (tokens[..., 1:]).contiguous()
        flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_shift_tokens = shift_tokens.view(-1)
        shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
        shift_losses = shift_losses.view(tokens.size(0), -1)
        # now get the average loss just for the completion region (where mask == 1), in each row
        shift_mask = (mask[..., 1:]).contiguous()  # we must shift mask, so we start at the last prompt token
        masked_shift_losses = shift_losses * shift_mask
        # sum and divide by the number of 1s in the mask
        sum_loss = masked_shift_losses.sum(dim=1)
        avg_loss = sum_loss / shift_mask.sum(dim=1)
        # now we have a loss for each of the 4 completions
        # the one with the lowest loss should be the most likely
        pred = sum_loss.argmin().item()
        pred_norm = avg_loss.argmin().item()

        # accumulate stats
        num_total += 1
        num_correct += int(pred == label)
        num_correct_norm += int(pred_norm == label)
        print(f"{num_total} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}")

        # debug: pretty print a few examples, and the losses in each case
        if num_total < 10:
            print("---")
            print(f"Context:\n {example['ctx']}")
            print("Endings:")
            for i, end in enumerate(example["endings"]):
                print(f"{i} (loss: {avg_loss[i].item():.4f}) {end}")
            print(f"predicted: {pred_norm}, actual: {label}")


def get_most_likely_row(tokens, mask, logits):
    """
    helper function for HellaSwag eval. Takes tokens, mask, and logits,
    returns the index of the completion with the lowest loss
    """
    # evaluate the autoregressive loss at all positions
    shift_logits = (logits[..., :-1, :]).contiguous()
    shift_tokens = (tokens[..., 1:]).contiguous()
    flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
    flat_shift_tokens = shift_tokens.view(-1)
    shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
    shift_losses = shift_losses.view(tokens.size(0), -1)
    # now get the average loss just for the completion region (where mask == 1), in each row
    shift_mask = (mask[..., 1:]).contiguous()  # we must shift mask, so we start at the last prompt token
    masked_shift_losses = shift_losses * shift_mask
    # sum and divide by the number of 1s in the mask
    sum_loss = masked_shift_losses.sum(dim=1)
    avg_loss = sum_loss / shift_mask.sum(dim=1)
    # now we have a loss for each of the 4 completions
    # the one with the lowest loss should be the most likely
    pred_norm = avg_loss.argmin().item()
    return pred_norm


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_type", type=str, default="gpt2", help="the model type to use")
    parser.add_argument("-d", "--device", type=str, default="cuda", help="the device to use")
    args = parser.parse_args()
    evaluate(args.model_type, args.device)
inference.py ADDED
@@ -0,0 +1,93 @@
import numpy as np
import torch
import torch.nn.functional as F
import tiktoken
from dataclasses import dataclass

from model import GPT


class GPT2Inference:
    """ To generate text sequences using a trained GPT2 model """

    def __init__(self, model, token_encoder, device):
        self.model = model
        self.token_encoder = token_encoder
        self.device = device
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    def generate_sequences(self, prompt, num_seq=5, max_tokens=50):
        self.model.eval()
        tokens = self.token_encoder.encode(prompt)
        tokens = torch.tensor(tokens, dtype=torch.long)  # (n,) n: current sequence length
        tokens = tokens.unsqueeze(0).repeat(num_seq, 1)  # (1, n) --> (num_seq, n)
        gen_tokens = tokens.to(self.device)
        # create a different rng generator so as not to impact the global rng state used for training
        sample_rng = torch.Generator(device=self.device).manual_seed(42)

        # generate new tokens one token at a time until the sequence length becomes 'max_tokens'
        while gen_tokens.shape[-1] <= max_tokens:
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(gen_tokens)  # (num_seq, n, vocab_size)
                logits = logits[:, -1, :]  # (num_seq, vocab_size)
                probs = F.softmax(logits, dim=-1)  # (num_seq, vocab_size)
                # take top-k 50 probs
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)  # (num_seq, 50), (num_seq, 50)
                # sample a token from top-50 probabilities
                ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng)  # (num_seq, 1)
                next_tok = torch.gather(topk_indices, -1, ix)  # (num_seq, 1)
                gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)
        # decode generated tokens and print generated text
        for i in range(num_seq):
            tokens = gen_tokens[i, :max_tokens].tolist()
            gen_text = self.token_encoder.decode(tokens)
            print(f"> sample {i}: {gen_text}")


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default="Hello, I am a language model,")
    parser.add_argument('--num_seq', type=int, default=5)
    parser.add_argument('--max_tokens', type=int, default=50)
    args = parser.parse_args()
    return args


@dataclass
class GPTConfig:
    context_length: int = 1024  # max context / sequence length
    vocab_size: int = 50257  # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    num_layers: int = 12
    embd_size: int = 768  # embedding dim
    num_heads: int = 12


def inference(args=None):
    if args is None:
        args = parse_args()

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = 'mps'  # for Apple MacBook GPUs
    print(f'using device: {device}')

    model_path = './logs/model_95364.pt'
    checkpoint = torch.load(model_path, weights_only=False)
    print(f"loaded model from: {model_path}")
    # print(checkpoint['model'].keys())

    model = GPT(config=checkpoint['config'])
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    token_encoder = tiktoken.get_encoding('gpt2')
    generator = GPT2Inference(model, token_encoder, device)

    generator.generate_sequences(args.prompt, args.num_seq, args.max_tokens)


if __name__ == '__main__':
    inference()
log.txt ADDED
File without changes
loss_eval.png ADDED
Git LFS Details

  • SHA256: b45c30768b24e7ed672c36d40755eaa7606c2c21b95b68e75bf7ac485552bfda
  • Pointer size: 131 Bytes
  • Size of remote file: 847 kB
model.cpython-311.pyc ADDED
Binary file (16.6 kB).
 
model.py ADDED
@@ -0,0 +1,201 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import inspect


@dataclass
class GPTConfig:
    context_length: int = 1024  # max context / sequence length
    vocab_size: int = 50257  # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    num_layers: int = 12
    embd_size: int = 768  # embedding dim
    num_heads: int = 12


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 'embd_size' sized vector divided into 'num_heads' heads
        assert config.embd_size % config.num_heads == 0, "embedding dim should be divisible by number of heads"
        self.num_heads = config.num_heads
        self.embd_size = config.embd_size
        # batched key, query, and value projections for all heads
        self.c_attn = nn.Linear(config.embd_size, 3 * config.embd_size)
        self.c_proj = nn.Linear(config.embd_size, config.embd_size)
        self.c_proj.SCALE_INIT = 1.0
        # not really a bias, more of a mask, but following OpenAI/HF naming convention
        # self.register_buffer("bias", torch.tril(torch.ones(config.context_length, config.context_length)).view(1, 1, config.context_length, config.context_length))

    def forward(self, x):
        B, T, C = x.shape
        # calculate query, key, values for all heads in a batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels
        qkv = self.c_attn(x)  # (B, T, 3C)
        q, k, v = qkv.split(self.embd_size, dim=-1)  # (B,T,C), (B,T,C), (B,T,C)
        q = q.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2)  # (B,nh,T,hs)
        k = k.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2)  # (B,nh,T,hs)
        v = v.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2)  # (B,nh,T,hs)
        # attn = q @ k.transpose(-2, -1) / np.sqrt(k.shape[-1])  # (B,nh,T,hs) @ (B,nh,hs,T) --> (B,nh,T,T)
        # attn = attn.masked_fill(self.bias[:,:,:T,:T] == 0, float("-inf"))
        # attn = F.softmax(attn, dim=-1)
        # out = attn @ v  # (B,nh,T,T) @ (B,nh,T,hs) --> (B,nh,T,hs)
        # flash-attention (significantly faster, but logically the same as the 4 lines above)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B,nh,T,hs)
        out = out.transpose(1, 2).contiguous().view(B, T, C)  # (B,nh,T,hs) --> (B,T,nh,hs) --> (B,T,C=nh*hs)
        out = self.c_proj(out)  # (B,T,C) --> (B,T,C)
        return out


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.embd_size, 4 * config.embd_size)
        self.gelu = nn.GELU(approximate='tanh')  # approximate='tanh' used to try to reproduce the gpt2 paper
        self.c_proj = nn.Linear(4 * config.embd_size, config.embd_size)
        self.c_proj.SCALE_INIT = 1.0

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    """ Transformer decoder block (pre-LayerNorm, causal self-attention + MLP) """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.embd_size)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.embd_size)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(self.config.vocab_size, self.config.embd_size),
            wpe = nn.Embedding(self.config.context_length, self.config.embd_size),
            h = nn.ModuleList([Block(self.config) for _ in range(self.config.num_layers)]),
            ln_f = nn.LayerNorm(self.config.embd_size)
        ))
        # language modeling head
        self.lm_head = nn.Linear(self.config.embd_size, self.config.vocab_size, bias=False)
        # weight sharing scheme (saves 768*50257 ≈ 40M params; fewer params, more efficient)
        self.transformer.wte.weight = self.lm_head.weight
        # init params (iterates over all submodules and applies _init_weights)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'SCALE_INIT'):
                std /= (2 * self.config.num_layers)**0.5
            torch.nn.init.normal_(module.weight, mean=0, std=std)  # as per openai gpt-2 source code
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.config.context_length, f'sequence length {T} should be <= {self.config.context_length}'
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # (T,)
        pos_embd = self.transformer.wpe(pos)  # (T, embd_size)
        tok_embd = self.transformer.wte(idx)  # (B, T, embd_size)
        x = pos_embd + tok_embd  # (B, T, embd_size)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)  # (B, T, embd_size)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """ Loads pretrained GPT2 model weights from huggingface """
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print(f"loading weights from pretrained gpt: {model_type}")

        config_args = {
            'gpt2': dict(num_layers=12, num_heads=12, embd_size=768),  # 124M params
            'gpt2-medium': dict(num_layers=24, num_heads=16, embd_size=1024),  # 350M params
            'gpt2-large': dict(num_layers=36, num_heads=20, embd_size=1280),  # 774M params
            'gpt2-xl': dict(num_layers=48, num_heads=25, embd_size=1600),  # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257
        config_args['context_length'] = 1024

        # create a from-scratch minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]

        # init a huggingface transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']

        assert len(sd_keys) == len(sd_keys_hf), f"mismatched keys {len(sd_keys)} != {len(sd_keys_hf)}"

        # copy while ensuring all parameters are aligned in names and shape
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # need to transpose Conv1D weights
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].T)
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])
        return model

    def configure_optimizers(self, weight_decay, lr, device_type, master_process):
        """
        Sets up AdamW with selective weight decay (a regularization tool: by decaying the weights,
        we force the network to spread work across more weights instead of letting any single weight dominate)
        """
        # start with all of the candidate params (that require gradient)
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}

        # create optim groups: any parameters that are 2D will be weight decayed, otherwise no.
        # i.e., all weight tensors in matmuls + embeddings will decay, whereas biases and layernorms won't be decayed
        decay_params = [p for pn, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for pn, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:
            print(f'num decay parameter tensors: {len(decay_params)} with {num_decay_params:,} parameters')
            print(f'num nodecay parameter tensors: {len(nodecay_params)} with {num_nodecay_params:,} parameters')

        # use the fused version of the AdamW optimizer (faster than the non-fused version)
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        if master_process:
            print(f'using fused AdamW optimizer: {use_fused}')
        optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer
prepare_dataset.py ADDED
@@ -0,0 +1,75 @@
import multiprocessing as mp
from datasets import load_dataset, DownloadConfig
import backoff
import os
from pathlib import Path
import numpy as np
import tiktoken

# Function to process individual dataset items
def process_data(item):
    """
    Process a single dataset item.
    Replace this with your actual processing logic (e.g., tokenization).
    """
    # Example: Tokenize text using tiktoken (adjust based on your needs)
    encoder = tiktoken.get_encoding('gpt2')
    text = item.get('text', '')  # Assuming the dataset has a 'text' field
    tokens = encoder.encode(text)
    return tokens

@backoff.on_exception(backoff.expo, Exception, max_tries=5)
def fetch_data(item):
    """
    Wrapper for process_data with exponential backoff for retries.
    """
    return process_data(item)

def main():
    """
    Main function to load and process the FineWeb-Edu dataset.
    """
    # Configuration
    remote_name = "sample-10BT"  # Dataset configuration name
    output_dir = "./data"  # Directory to save processed data
    os.makedirs(output_dir, exist_ok=True)

    # Set up download config to handle rate limits and caching
    download_config = DownloadConfig(
        max_retries=5,
        num_proc=4,  # Limit to 4 processes to avoid HTTP 429
        cache_dir=Path.home() / ".cache" / "huggingface" / "datasets"
    )

    try:
        # Load dataset with caching
        print("Loading dataset...")
        dataset = load_dataset(
            'HuggingFaceFW/fineweb-edu',
            name=remote_name,
            split='train',
            download_mode="reuse_dataset_if_exists",
            download_config=download_config
        )
        print(f"Dataset loaded with {len(dataset)} items.")

        # Limit the number of processes to avoid overwhelming the Hugging Face Hub
        nprocs = min(mp.cpu_count(), 4)
        print(f"Using {nprocs} processes for multiprocessing.")

        # Process dataset using multiprocessing
        with mp.Pool(nprocs) as pool:
            results = pool.map(fetch_data, dataset)

        # Save processed results (example: save as numpy arrays)
        # token lists have different lengths, so store them as an object array
        output_path = os.path.join(output_dir, "processed_fineweb_edu.npy")
        np.save(output_path, np.array(results, dtype=object), allow_pickle=True)
        print(f"Processed dataset saved to {output_path}")

    except Exception as e:
        print(f"Error loading or processing dataset: {e}")
        raise

if __name__ == '__main__':
    mp.freeze_support()  # Required for Windows compatibility with executables
    main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
numpy
torch
tiktoken
transformers
train.py ADDED
@@ -0,0 +1,392 @@
import os
import math
import numpy as np
import time
from dataclasses import dataclass
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# import code; code.interact(local=locals())

from model import GPT
from dataloader import DataLoaderLite
from hellaswag_eval import render_example, iterate_examples, get_most_likely_row

torch.set_float32_matmul_precision('high')  # enable TF32 precision

# set torch compile to True (if it doesn't throw any errors) to speed up training
use_torch_compile = False


class Trainer:
    def __init__(
        self,
        model,
        optimizer,
        train_loader,
        val_loader,
        token_encoder,
        eval_freq,
        grad_accum_steps,
        ddp,
        ddp_rank,
        ddp_world_size,
        device,
        logpath
    ):
        self.ddp = ddp
        self.ddp_rank = ddp_rank
        self.master_process = ddp_rank == 0
        self.ddp_world_size = ddp_world_size

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.token_encoder = token_encoder

        self.eval_freq = eval_freq
        self.grad_accum_steps = grad_accum_steps
        self.device = device
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
        self.logpath = logpath


    def train(
        self,
        max_steps,
        warmup_steps,
        max_lr,
        min_lr
    ):
        for step in range(max_steps):
            t0 = time.time()
            self.is_last_step = (step == max_steps - 1)

            # evaluate validation loss
            if step % self.eval_freq == 0 or self.is_last_step:
                self.evaluate_validation(step)

            # evaluate model performance on HellaSwag every once in a while
            if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
                self.evaluate_hellaswag(step)

            # generate sequences from the model every once in a while
            if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
                self.generate_sequences(num_seq=5, max_tokens=32)

            # training loop starts here
            self.model.train()  # sets model to train mode
            self.optimizer.zero_grad()  # resets all gradients
            batch_loss = 0.0

            for mini_step in range(self.grad_accum_steps):
                inp, tar = self.train_loader.next_batch()
                inp, tar = inp.to(self.device), tar.to(self.device)

                # FORWARD PASS !!!
                # autocast to bfloat16 for faster compute and memory efficiency
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(inp, tar)

                # loss is scaled to account for gradient accumulation, because the gradients just add
                # on each successive backward() call. Addition of gradients corresponds to a SUM in the objective,
                # but we want a MEAN instead of a SUM
                loss /= self.grad_accum_steps
                batch_loss += loss.detach()

                if self.ddp:
                    # in the final mini_step, sync and average all gradients across all processes.
                    # alternatively, the 'no_sync()' context manager can be used.
                    self.model.require_backward_grad_sync = (mini_step == self.grad_accum_steps - 1)

                # each process accumulates gradients separately when 'require_backward_grad_sync'=False
                # in the final 'mini_step', 'require_backward_grad_sync' becomes True, therefore
                # gradients are averaged across all processes and shared among them by loss.backward()
                loss.backward()

            if self.ddp:
                # 'batch_loss' is outside of the DDP container, so we need to perform 'all_reduce' to
                # average 'batch_loss' across all processes of all ranks. The 'batch_loss' tensor exists on all GPUs.
                # 'all_reduce' averages and deposits the result on all the processes
                dist.all_reduce(batch_loss, op=dist.ReduceOp.AVG)

            # once gradients are computed, clip the global l2-norm of the gradient at 1.0
            norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)  # monitor/print 'norm'

            # determine learning rate with decay
            lr = self.estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr)
            # set learning rate for this iteration
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            self.optimizer.step()
            if self.device_type == 'cuda':
                torch.cuda.synchronize()  # wait for the GPU to finish work

            dt = (time.time() - t0) * 1000.0  # in ms
            tokens_processed = self.train_loader.B * self.train_loader.T * self.grad_accum_steps * self.ddp_world_size
            tokens_per_sec = tokens_processed / (dt / 1000.0)  # dt is in ms, so convert back to seconds

            if self.master_process:
                print(f'step {step:4d} | loss: {batch_loss.item():.6f} | lr: {lr:.2e} | norm: {norm:.4f} | dt: {dt:.4f}ms | tok/sec: {tokens_per_sec:.4f}')
                with open(self.logpath, 'a') as f:
                    f.write(f'{step} train {batch_loss.item():.6f}\n')


    def evaluate_validation(self, step):
        self.model.eval()  # sets model to eval mode
        self.val_loader.reset()
        # evaluate the model on the validation set
        with torch.no_grad():
            val_loss_accum = 0.0
            val_steps = 20
            for _ in range(val_steps):
                inp, tar = self.val_loader.next_batch()
                inp, tar = inp.to(self.device), tar.to(self.device)
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(inp, tar)
                loss /= val_steps
                val_loss_accum += loss.detach()

        if self.ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if self.master_process:
            print(f'Val loss: {val_loss_accum.item():.4f}')
            with open(self.logpath, 'a') as f:
                f.write(f'{step} val {val_loss_accum.item():.4f}\n')

            if step > 0 and (step % 10000 == 0 or self.is_last_step):
                raw_model = self.model.module if self.ddp else self.model
                logdir = os.path.dirname(self.logpath)
                ckpt_path = os.path.join(logdir, f'model_{step:05d}.pt')
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'config': raw_model.config,
                    'step': step,
                    'val_loss': val_loss_accum.item()
                }  # add optimizer.state_dict(), rng_seeds, etc. if resuming training
                torch.save(checkpoint, ckpt_path)


    def evaluate_hellaswag(self, step):
        """
        Construct a batch of 4 candidate sequences for each HellaSwag example and perform token
        completion using our model.
        """
        n_total = 0
        n_correct_norm = 0
        for i, example in enumerate(iterate_examples('val')):
            # only process examples where i % ddp_world_size == ddp_rank
            if i % self.ddp_world_size != self.ddp_rank:
                continue
            # render the example into tokens and labels
            _, tokens, mask, label = render_example(example)  # (4,N), (4,N), (4,N)
            tokens, mask = tokens.to(self.device), mask.to(self.device)
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(tokens)
                pred_norm = get_most_likely_row(tokens, mask, logits)
            n_total += 1
            n_correct_norm += int(pred_norm == label)
        # reduce the stats across all processes
        if self.ddp:
            n_total = torch.tensor(n_total, device=self.device, dtype=torch.long)
            n_correct_norm = torch.tensor(n_correct_norm, device=self.device, dtype=torch.long)
            dist.all_reduce(n_total, op=dist.ReduceOp.SUM)
            dist.all_reduce(n_correct_norm, op=dist.ReduceOp.SUM)
            n_total = n_total.item()
            n_correct_norm = n_correct_norm.item()
        acc_norm = n_correct_norm / n_total
        if self.master_process:
            print(f'HellaSwag accuracy: {n_correct_norm}/{n_total}={acc_norm:.4f}')
            with open(self.logpath, 'a') as f:
                f.write(f'{step} hellaswag {acc_norm:.4f}\n')


    def generate_sequences(self, num_seq=4, max_tokens=32):
        self.model.eval()
        tokens = self.token_encoder.encode("Hello, I am a language model")
        tokens = torch.tensor(tokens, dtype=torch.long)  # (n,) n: current sequence length
        tokens = tokens.unsqueeze(0).repeat(num_seq, 1)  # (1, n) --> (num_seq, n)
        gen_tokens = tokens.to(self.device)
        # create a different rng generator so as not to impact the global rng state used for training
        sample_rng = torch.Generator(device=self.device)
        # adding 'ddp_rank' to the seed to generate different tokens on different rank processes
        sample_rng.manual_seed(42 + self.ddp_rank)
        # generate new tokens one token at a time until the sequence length becomes 'max_tokens'
        while gen_tokens.shape[-1] <= max_tokens:
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(gen_tokens)  # (num_seq, n, vocab_size)
                logits = logits[:, -1, :]  # (num_seq, vocab_size)
                probs = F.softmax(logits, dim=-1)  # (num_seq, vocab_size)
                # take top-k 50 probs
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)  # (num_seq, 50), (num_seq, 50)
                # sample a token from top-50 probabilities
                ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng)  # (num_seq, 1)
                next_tok = torch.gather(topk_indices, -1, ix)  # (num_seq, 1)
                gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)
        # decode generated tokens and print generated text
        for i in range(num_seq):
            tokens = gen_tokens[i, :max_tokens].tolist()
            gen_text = self.token_encoder.decode(tokens)
            print(f"> rank {self.ddp_rank} sample {i}: {gen_text}")


    def estimate_lr(self, step, warmup_steps, max_steps, max_lr, min_lr):
        """
        Learning rate scheduler: cosine-decay learning schedule with warmup
        """
        # 1) linear warmup for 'warmup_steps' steps
        if step < warmup_steps:
            return max_lr * (step + 1) / warmup_steps
        # 2) if step > max_steps, return min lr
        if step > max_steps:
            return min_lr
        # 3) in between, use cosine decay down to min lr
        decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff starts at 1 and goes to 0
        return min_lr + coeff * (max_lr - min_lr)


@dataclass
class GPTConfig:
    context_length: int = 1024  # max context / sequence length
    vocab_size: int = 50257  # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    num_layers: int = 12
    embd_size: int = 768  # embedding dim
    num_heads: int = 12


def get_args():
    import argparse
    parser = argparse.ArgumentParser(description="Hyperparameter Configuration")
    parser.add_argument("--total_batch_size", type=int, default=524288, help="number of tokens processed for each weight update")  # = 2^19 tokens per step update (~0.5M tokens used in the openai gpt3 paper)
    parser.add_argument("--mini_batch_size", type=int, default=32, help="setting mini_batch_size is just a performance optimization; bigger gpu, bigger mini_batch_size")
    parser.add_argument("--context_length", type=int, default=1024)  # max sequence length (can also try 2048)
    parser.add_argument("--num_layers", type=int, default=12)
    parser.add_argument("--embd_size", type=int, default=768)
    parser.add_argument("--num_heads", type=int, default=12)
    parser.add_argument("--max_lr", type=float, default=1e-3)
    parser.add_argument("--min_lr", type=float, default=1e-3 * 0.1)
    parser.add_argument("--warmup_steps", type=int, default=715)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--num_epochs", type=int, default=5)
    parser.add_argument("--steps_per_epoch", type=int, default=19073)  # 10^10 / 2^19 ~ 19073 for 1 epoch on FineWebEdu-sample10BT
    parser.add_argument("--eval_freq", type=int, default=250)
    # parser.add_argument("--use_torch_compile", action='store_true')  # default False
    parser.add_argument("--seed", type=int, default=1337, help="Random seed for reproducibility")
    parser.add_argument("--logdir", type=str, default="./logs/")
    return parser.parse_args()


def main():
    args = get_args()

    # Print the hyperparameters
    print("Hyperparameter Configuration:")
    for key, value in vars(args).items():
        print(f"{key}: {value}")

    # create the logs directory if it doesn't exist
    os.makedirs(args.logdir, exist_ok=True)
    logpath = os.path.join(args.logdir, 'log.txt')
    with open(logpath, 'w') as f:
        pass

    # set up DDP (distributed data parallel)
    # the 'torchrun' command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
    # RANK and LOCAL_RANK are the same for (single node, multi-GPU) settings, and may differ for
    # (multi-node, multi-GPU) settings.
    ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run or not?
    if ddp:
        # use of DDP requires CUDA
        assert torch.cuda.is_available(), 'use of DDP requires CUDA'
        dist.init_process_group(backend='nccl')
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        device = f'cuda:{ddp_local_rank}'
        torch.cuda.set_device(device)
        # the master process (arbitrarily set to rank 0) will do printing, logging, checkpointing, etc.
        master_process = ddp_rank == 0
    else:
        # not using ddp
        ddp_rank = 0
        ddp_local_rank = 0
        ddp_world_size = 1
        master_process = True  # ddp_rank == 0
        device = 'cpu'
        if torch.cuda.is_available():
            device = 'cuda'
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = 'mps'  # for Apple MacBook GPUs
        print(f'using device: {device}')

    device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    # setting seed for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)  # sets seed for random number generation on CPU
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)  # sets seed for random number generation on GPU
        torch.cuda.manual_seed_all(args.seed)  # sets seed for all GPUs

    assert args.total_batch_size % (args.mini_batch_size * args.context_length * ddp_world_size) == 0, 'ensure total_batch_size is divisible by B*T*ddp_world_size'
    grad_accum_steps = args.total_batch_size // (args.mini_batch_size * args.context_length * ddp_world_size)
    if master_process:
        print(f'desired batch size (number of tokens): {args.total_batch_size}')
        print(f'gradient accumulation steps: {grad_accum_steps}')
        print(f'GPU: {ddp_rank}, {ddp_local_rank}')

    train_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='train')
    val_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='val')

    # create the GPT model. each ddp process will create its own instance of the model, but since the seed is fixed,
    # they will all create the same identical model
    gpt_config = GPTConfig(vocab_size=50304,  # 50304 (a nice number with lots of powers of 2) used instead of 50257 (an awkward, odd number)
                           context_length=args.context_length,
                           num_layers=args.num_layers,
                           num_heads=args.num_heads,
                           embd_size=args.embd_size
                           )
    model = GPT(config=gpt_config)
    # model = GPT.from_pretrained('gpt2')  # init from OpenAI GPT-2
    model.to(device)  # move model to device
    if use_torch_compile:
        # use torch compile almost always unless debugging (requires compilation time, but makes training faster)
        # speedup comes from reducing python overhead and GPU read/writes
        model = torch.compile(model)

    if ddp:
        # wraps the model in a DDP container (the forward pass is unchanged, but after the backward pass,
        # gradients computed across all processes are averaged by DDP using 'AllReduce' and shared across
        # all processes so that each process has the same gradients)
        model = DDP(model, device_ids=[ddp_local_rank])

    raw_model = model.module if ddp else model
    optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, lr=args.max_lr, device_type=device_type, master_process=master_process)
    token_encoder = tiktoken.get_encoding('gpt2')

    start_time = time.time()
    # init the trainer object
    trainer = Trainer(model, optimizer, train_loader, val_loader, token_encoder, args.eval_freq, grad_accum_steps,
                      ddp, ddp_rank, ddp_world_size, device, logpath)

    max_steps = args.steps_per_epoch * args.num_epochs
    trainer.train(max_steps, args.warmup_steps, args.max_lr, args.min_lr)

    dt = (time.time() - start_time) / (60 * 60)
    print(f"Total training time: {dt:.4f}hr")

    if ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()