Upload 16 files
- .gitattributes +1 -0
- .gitignore +4 -0
- LICENSE +21 -0
- README.md +136 -0
- __init__.py +0 -0
- dataloader.cpython-311.pyc +0 -0
- dataloader.py +49 -0
- hellaswag_eval.cpython-311.pyc +0 -0
- hellaswag_eval.py +197 -0
- inference.py +93 -0
- log.txt +0 -0
- loss_eval.png +3 -0
- model.cpython-311.pyc +0 -0
- model.py +201 -0
- prepare_dataset.py +75 -0
- requirements.txt +4 -0
- train.py +392 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
loss_eval.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,4 @@
__pycache__/
.DS_Store

data/
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Saqib Azim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,136 @@
# GPT-2 Implementation in PyTorch

This project reproduces the GPT-2 model in PyTorch and trains it from scratch on the FineWeb-Edu dataset, a high-quality subset of the FineWeb dataset tailored for educational content. The goal is to offer a simplified, easy-to-understand PyTorch implementation. Note that this code is intended primarily for educational purposes and is not optimized for speed or production deployment.

### Key Features
- **Simplified PyTorch Implementation:** Designed to be accessible and well-commented for ease of understanding.
- **Customizable Training:** Hyperparameters are configurable via the command line and can be easily modified.
- **Multi-GPU Training Support:** Training can be performed across multiple GPUs using PyTorch Distributed Data Parallel (DDP).


## Repository Structure
- `src/train.py`: Script to train the GPT-2 model with customizable configurations.
- `src/model.py`: Contains the GPT-2 model implementation, including embedding layers, transformer blocks, and output layers.
- `src/dataloader.py`: Handles data loading and batching for the model during training.
- `src/prepare_dataset.py`: Downloads and preprocesses the FineWebEdu dataset. Run this script before starting the training process.
- `requirements.txt`: Python dependencies required to run the project.


## Getting Started

### Prerequisites
Ensure you have the following dependencies installed:

- numpy
- torch (PyTorch)
- tiktoken
- transformers (from Hugging Face)

You can install all dependencies with:
```bash
pip install -r requirements.txt
```

## Dataset

The GPT-2 model was originally trained on the WebText dataset (not publicly released). For this project, we use the FineWebEdu-10B dataset, a specialized educational subset of the FineWeb dataset. It contains approximately 10 billion tokens focused on high-quality educational content.

To download and prepare the dataset:
```bash
python prepare_dataset.py
```

### Running the Training Script
You can start training the GPT-2 model using the commands below. Training and model configuration hyperparameters can be adjusted from the command line.

- Single-GPU Training:
```bash
python train.py --num_epochs=5
```

- Multi-GPU Training (uses PyTorch DDP):
```bash
torchrun --standalone --nproc_per_node=4 train.py  # adjust number of GPUs as per availability
```
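
For example, any of the defaults defined in `get_args()` in `train.py` (batch sizes, learning rates, number of epochs, evaluation frequency, and so on) can be overridden directly on the command line. A hypothetical run might look like:
```bash
# illustrative only: overrides a few of the argparse defaults defined in train.py
python train.py --num_epochs=3 --mini_batch_size=16 --max_lr=6e-4 --eval_freq=500
```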

For more details on the training process and customizing hyperparameters, refer to the `src/train.py` script.

Training was performed from scratch using multiple GPUs with PyTorch's DDP framework.


### Inference

After training the model, you can generate text based on custom prompts. Use the `src/inference.py` script to interact with the trained model and generate creative continuations.

Run the inference script from the command line with the following syntax:
```bash
python3 inference.py --prompt="I am an AI and robotics enthusiast, I want to" --max_tokens=50 --num_seq=5
```

This command will output 5 unique text sequences, each starting with the provided prompt and continuing for up to 50 tokens.


### Model Architecture
The GPT-2 model consists of the following components:

- **Token Embedding Layer:** Encodes input tokens to dense vectors.
- **Positional Embedding Layer:** Adds positional information to the token embeddings.
- **Transformer Blocks:** Each block includes layer normalization, multi-headed self-attention, and an MLP with residual connections.
- **Output Head:** Predicts the next token in the sequence based on the preceding context.

The model is trained to predict the next token in a sequence, enabling coherent text generation. For tokenization, I use OpenAI's `tiktoken` library with the GPT-2 encoding, which has a vocabulary of 50,257 tokens (same as GPT-2).
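
As a quick illustration (not part of the training scripts themselves), this is how the same GPT-2 encoding used in `train.py`, `inference.py`, and `hellaswag_eval.py` behaves:
```python
import tiktoken

enc = tiktoken.get_encoding("gpt2")   # GPT-2 BPE tokenizer, 50,257-token vocabulary
tokens = enc.encode("Hello, I am a language model")
print(tokens)                         # list of integer token ids fed to the model
print(enc.decode(tokens))             # round-trips back to the original string
print(enc.n_vocab)                    # 50257
```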


### Results

The GPT-2 model was trained for roughly 95,365 steps (5 epochs) using two NVIDIA A100 GPUs; at 2^19 (~0.5M) tokens per step, one pass over the ~10B-token dataset is about 19,073 steps. Training took approximately 46 hours.

![Training and validation loss curves](loss_eval.png)

To generate text from the trained model, we provide an input prompt and ask the model to generate the next N tokens. Here are some samples of text generated by the trained model:

- **Prompt text:** "Hello, I am a language model"
- **Model output:**
```
- Hello, I am a language modeler. I use the API, in whatever language I require it to write out. On first, I define a model for

- Hello, I am a language model expert and need help with building these model. The project is designed in C++ and the Python library is used. The project

- Hello, I am a language model developer at Google Cloud. It has great features on most platforms which makes it one of most popular. It also integrates with third
```

- **Prompt text:** "I am a machine learning and robotics enthusiast, and I want to"
- **Model output:**
```
- I am a machine learning and robotics enthusiast, and I want to share my excitement about this work as soon as possible.
The purpose of this project was to help the engineers and programmers understand how the HURD and AVR circuits work and how

- I am a machine learning and robotics enthusiast, and I want to try and train a new machine learning-based system such as a deep learning algorithm that is completely new to me.

- I am a machine learning and robotics enthusiast, and I want to help you by helping you improve your Python programming skills.To understand the concept of machine learning, you must understand the concept of a machine learning model. Machine learning models

- I am a machine learning and robotics enthusiast, and I want to be a part of the team.<|endoftext|>In your next project, you need to gather some interesting information from your team team. This data will help form a map that you can use to

- I am a machine learning and robotics enthusiast, and I want to create a new, more sophisticated machine learning-based library for programming languages. To start, I am interested in the machine learning (ML) capabilities of new AI methods and techniques.
```


## Potential Future Work

1. **Dataset Shuffling:** The current training code does not shuffle the dataset after each epoch. Implementing dataset shuffling between epochs could improve the model's ability to generalize and prevent overfitting to the order of the training data (a minimal sketch of one possible approach follows this list).

2. **Extended Training:** Experiment with training the model for more epochs to potentially improve performance. Monitor validation loss to determine the optimal number of epochs and implement early stopping if necessary.
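
Below is a minimal, hypothetical sketch (not part of the current code) of how shard-level shuffling could be added to `DataLoaderLite`; seeding with the epoch index keeps every DDP rank looking at the same shard order:
```python
import random

def reshuffle(self, epoch, base_seed=1337):
    # hypothetical DataLoaderLite method: shuffle the order in which shards are
    # visited at the start of each epoch; the shared seed keeps all DDP ranks in sync
    rng = random.Random(base_seed + epoch)
    rng.shuffle(self.shard_filepaths)
    self.curr_shard = 0
    self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
    self.curr_pos = self.B * self.T * self.process_rank
```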


## References
- [Language Models are Unsupervised Multitask Learners (GPT-2 Paper)](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
- [GPT-3 Paper: Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165)
- [FineWebEdu-10B Dataset](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
- [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135)
- [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- [HellaSwag: Can a Machine Really Finish Your Sentence?](https://arxiv.org/abs/1905.07830)
- Andrej Karpathy's Video Tutorial on GPT


## Acknowledgments
This implementation is inspired by Andrej Karpathy's tutorial and his approach to making complex AI concepts more accessible.
__init__.py
ADDED
File without changes
dataloader.cpython-311.pyc
ADDED
Binary file (4.39 kB).
dataloader.py
ADDED
@@ -0,0 +1,49 @@
import os
import numpy as np
import torch

script_dir = os.path.dirname(__file__)


class DataLoaderLite:
    """ A simple dataloader for the FineWebEdu-10B dataset """

    def __init__(self, B, T, process_rank, num_processes, split='train'):
        super().__init__()
        self.B, self.T = B, T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        # get the shard filenames
        data_root = os.path.join(script_dir, "../data/edu_fineweb10B")
        shard_filenames = os.listdir(data_root)
        shard_filenames = sorted([filename for filename in shard_filenames if split in filename])
        self.shard_filepaths = [os.path.join(data_root, filename) for filename in shard_filenames]
        assert len(self.shard_filepaths) > 0, f'no shards found for split {split}'
        master_process = process_rank == 0
        if master_process:
            print(f'found {len(self.shard_filepaths)} shards for split {split}')
        self.reset()

    def load_tokens(self, filepath):
        tokens = torch.tensor(np.load(filepath).astype(np.int32), dtype=torch.long)
        return tokens

    def reset(self):
        # state, init at shard 0
        self.curr_shard = 0
        self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
        self.curr_pos = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        batch = self.tokens[self.curr_pos : self.curr_pos + B*T + 1]
        x_batch = batch[:-1].view(B, T)
        y_batch = batch[1:].view(B, T)
        self.curr_pos += B * T * self.num_processes
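        # each DDP rank reads its own interleaved stripe of the shard, so the cursor
        # advances past the batches consumed by all ranks this step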
        if self.curr_pos + (B * T + 1) > len(self.tokens):
            self.curr_shard = (self.curr_shard + 1) % len(self.shard_filepaths)
            self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
            self.curr_pos = self.B * self.T * self.process_rank
        return x_batch, y_batch
hellaswag_eval.cpython-311.pyc
ADDED
Binary file (12.8 kB).
hellaswag_eval.py
ADDED
@@ -0,0 +1,197 @@
"""
Downloads and evaluates HellaSwag in Python.
https://github.com/rowanz/hellaswag

Example HellaSwag json item:

{"ind": 24, "activity_label": "Roof shingle removal", "ctx_a": "A man is sitting on a roof.", "ctx_b": "he", "ctx": "A man is sitting on a roof. he", "split": "val", "split_type": "indomain", "label": 3, "endings": ["is using wrap to wrap a pair of skis.", "is ripping level tiles off.", "is holding a rubik's cube.", "starts pulling up roofing on a roof."], "source_id": "activitynet~v_-JhWjGDPHMY"}

ind: dataset ID
activity_label: The ActivityNet or WikiHow label for this example
context: There are two formats. The full context is in ctx. When the context ends in an (incomplete) noun phrase, like for ActivityNet, this incomplete noun phrase is in ctx_b, and the context up until then is in ctx_a. This can be useful for models such as BERT that need the last sentence to be complete. However, it's never required. If ctx_b is nonempty, then ctx is the same thing as ctx_a, followed by a space, then ctx_b.
endings: a list of 4 endings. The correct index is given by label (0,1,2, or 3)
split: train, val, or test.
split_type: indomain if the activity label is seen during training, else zeroshot
source_id: Which video or WikiHow article this example came from

gpt2 (124M)
- eleuther harness reports acc 28.92%, acc_norm 31.14% (multiple choice style)
- this script: 10042 acc: 0.2859 acc_norm: 0.2955 (completion style)

gpt2-xl (1558M)
- eleuther harness reports acc 40.04%, acc_norm 50.89% (multiple choice style)
- this script: 10042 acc: 0.3842 acc_norm: 0.4893 (completion style)

The validation set of HellaSwag has a total of 10,042 examples.
"""

import os
import json
import requests
import tiktoken
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import GPT2LMHeadModel

# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "hellaswag")

def download_file(url: str, fname: str, chunk_size=1024):
    """Helper function to download a file from a given url"""
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as file, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = file.write(data)
            bar.update(size)

hellaswags = {
    "train": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
    "val": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
    "test": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl",
}

enc = tiktoken.get_encoding("gpt2")

def download(split):
    """Downloads HellaSwag to DATA_CACHE_DIR"""
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)
    data_url = hellaswags[split]
    data_filename = os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)

def render_example(example):
    """
    Given the example as a dictionary, render it as three torch tensors:
    - tokens (the tokens of context + completion, of size 4xN, as there are always 4 candidates)
    - mask (is 1 in the region of the candidate completion, where we evaluate likelihoods)
    - label (the index of the correct completion, which we hope has the highest likelihood)
    """
    ctx = example["ctx"]
    label = example["label"]
    endings = example["endings"]
    # data needed to reproduce this eval on the C size
    data = {
        "label": label,
        "ctx_tokens": None,
        "ending_tokens": [],
    }
    # gather up all the tokens
    ctx_tokens = enc.encode(ctx)
    data["ctx_tokens"] = ctx_tokens
    tok_rows = []
    mask_rows = []
    for end in endings:
        end_tokens = enc.encode(" " + end) # note: prepending " " because GPT-2 tokenizer
        tok_rows.append(ctx_tokens + end_tokens)
        mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens))
        data["ending_tokens"].append(end_tokens)

    # have to be careful during the collation because the number of tokens in each row can differ
    max_len = max(len(row) for row in tok_rows)
    tokens = torch.zeros((4, max_len), dtype=torch.long)
    mask = torch.zeros((4, max_len), dtype=torch.long)
    for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):
        tokens[i, :len(tok_row)] = torch.tensor(tok_row)
        mask[i, :len(mask_row)] = torch.tensor(mask_row)
    return data, tokens, mask, label

def iterate_examples(split):
    # there are 10,042 examples in total in val
    download(split)
    with open(os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl"), "r") as f:
        for line in f:
            example = json.loads(line)
            yield example

@torch.no_grad()
def evaluate(model_type, device):
    torch.set_float32_matmul_precision('high') # use tf32
    model = GPT2LMHeadModel.from_pretrained(model_type)
    model.to(device)
    # model = torch.compile(model) # optionally torch compile the model
    num_correct_norm = 0
    num_correct = 0
    num_total = 0
    for example in iterate_examples("val"):
        data, tokens, mask, label = render_example(example)
        tokens = tokens.to(device)
        mask = mask.to(device)

        # get the logits
        logits = model(tokens).logits
        # evaluate the autoregressive loss at all positions
        shift_logits = (logits[..., :-1, :]).contiguous()
        shift_tokens = (tokens[..., 1:]).contiguous()
        flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_shift_tokens = shift_tokens.view(-1)
        shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
        shift_losses = shift_losses.view(tokens.size(0), -1)
        # now get the average loss just for the completion region (where mask == 1), in each row
        shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
        masked_shift_losses = shift_losses * shift_mask
        # sum and divide by the number of 1s in the mask
        sum_loss = masked_shift_losses.sum(dim=1)
        avg_loss = sum_loss / shift_mask.sum(dim=1)
        # now we have a loss for each of the 4 completions
        # the one with the lowest loss should be the most likely
        pred = sum_loss.argmin().item()
        pred_norm = avg_loss.argmin().item()

        # accumulate stats
        num_total += 1
        num_correct += int(pred == label)
        num_correct_norm += int(pred_norm == label)
        print(f"{num_total} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}")

        # debug: pretty print a few examples, and the losses in each case
        if num_total < 10:
            print("---")
            print(f"Context:\n {example['ctx']}")
            print("Endings:")
            for i, end in enumerate(example["endings"]):
                print(f"{i} (loss: {avg_loss[i].item():.4f}) {end}")
            print(f"predicted: {pred_norm}, actual: {label}")


def get_most_likely_row(tokens, mask, logits):
    """
    helper function for HellaSwag eval. Takes tokens, mask, and logits,
    returns the index of the completion with the lowest loss
    """
    # evaluate the autoregressive loss at all positions
    shift_logits = (logits[..., :-1, :]).contiguous()
    shift_tokens = (tokens[..., 1:]).contiguous()
    flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
    flat_shift_tokens = shift_tokens.view(-1)
    shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
    shift_losses = shift_losses.view(tokens.size(0), -1)
    # now get the average loss just for the completion region (where mask == 1), in each row
    shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
    masked_shift_losses = shift_losses * shift_mask
    # sum and divide by the number of 1s in the mask
    sum_loss = masked_shift_losses.sum(dim=1)
    avg_loss = sum_loss / shift_mask.sum(dim=1)
    # now we have a loss for each of the 4 completions
    # the one with the lowest loss should be the most likely
    pred_norm = avg_loss.argmin().item()
    return pred_norm


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_type", type=str, default="gpt2", help="the model type to use")
    parser.add_argument("-d", "--device", type=str, default="cuda", help="the device to use")
    args = parser.parse_args()
    evaluate(args.model_type, args.device)
inference.py
ADDED
@@ -0,0 +1,93 @@
import numpy as np
import torch
import torch.nn.functional as F
import tiktoken
from dataclasses import dataclass

from model import GPT


class GPT2Inference:
    """ To generate text sequences using a trained GPT2 model """

    def __init__(self, model, token_encoder, device):
        self.model = model
        self.token_encoder = token_encoder
        self.device = device
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    def generate_sequences(self, prompt, num_seq=5, max_tokens=50):
        self.model.eval()
        tokens = self.token_encoder.encode(prompt)
        tokens = torch.tensor(tokens, dtype=torch.long) # (n,) n : current sequence length
        tokens = tokens.unsqueeze(0).repeat(num_seq, 1) # (1,n) --> (num_seq, n)
        gen_tokens = tokens.to(self.device)
        # create a different rng generator so as not to impact the global rng state used for training
        sample_rng = torch.Generator(device=self.device).manual_seed(42)

        # generate new tokens one token at a time until the sequence length becomes 'max_tokens'
        while gen_tokens.shape[-1] <= max_tokens:
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(gen_tokens) # (num_seq, n, vocab_size)
                logits = logits[:, -1, :] # (num_seq, vocab_size)
                probs = F.softmax(logits, dim=-1) # (num_seq, vocab_size)
                # take top-k 50 probs
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1) # (num_seq, 50), (num_seq, 50)
                # sample a token from top-50 probabilities
                ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng) # (num_seq, 1)
                next_tok = torch.gather(topk_indices, -1, ix) # (num_seq, 1)
                gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)
        # decode generated tokens and print generated text
        for i in range(num_seq):
            tokens = gen_tokens[i, :max_tokens].tolist()
            gen_text = self.token_encoder.decode(tokens)
            print(f"> sample {i}: {gen_text}")


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default="Hello, I am a language model,")
    parser.add_argument('--num_seq', type=int, default=5)
    parser.add_argument('--max_tokens', type=int, default=50)
    args = parser.parse_args()
    return args


@dataclass
class GPTConfig:
    context_length: int = 1024 # max context / sequence length
    vocab_size: int = 50257 # number of tokens: 50000 BPE merges + 256 bytes tokens + 1 <endoftext> token
    num_layers: int = 12
    embd_size: int = 768 # embedding dim
    num_heads: int = 12


def inference(args=None):
    if args is None:
        args = parse_args()

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = 'mps' # for apple macbook GPUs
    print(f'using device: {device}')

    model_path = './logs/model_95364.pt'
    checkpoint = torch.load(model_path, weights_only=False)
    print(f"loaded model from: {model_path}")
    # print(checkpoint['model'].keys())

    model = GPT(config=checkpoint['config'])
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    token_encoder = tiktoken.get_encoding('gpt2')
    generator = GPT2Inference(model, token_encoder, device)

    generator.generate_sequences(args.prompt, args.num_seq, args.max_tokens)


if __name__ == '__main__':
    inference()
log.txt
ADDED
File without changes
loss_eval.png
ADDED
Image file (tracked with Git LFS).
model.cpython-311.pyc
ADDED
Binary file (16.6 kB).
model.py
ADDED
@@ -0,0 +1,201 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import inspect


@dataclass
class GPTConfig:
    context_length: int = 1024 # max context / sequence length
    vocab_size: int = 50257 # number of tokens: 50000 BPE merges + 256 bytes tokens + 1 <endoftext> token
    num_layers: int = 12
    embd_size: int = 768 # embedding dim
    num_heads: int = 12


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 'embd_size' sized vector divided into 'num_heads' heads
        assert config.embd_size % config.num_heads == 0, "embedding dim should be divisible by number of heads"
        self.num_heads = config.num_heads
        self.embd_size = config.embd_size
        # batched key, query, and value projections for all heads
        self.c_attn = nn.Linear(config.embd_size, 3 * config.embd_size)
        self.c_proj = nn.Linear(config.embd_size, config.embd_size)
        self.c_proj.SCALE_INIT = 1.0
        # not really a bias, more of a mask, but following OpenAI/HF naming convention
        # self.register_buffer("bias", torch.tril(torch.ones(config.context_length, config.context_length)).view(1, 1, config.context_length, config.context_length))

    def forward(self, x):
        B, T, C = x.shape
        # calculate query, key, values for all heads in a batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels
        qkv = self.c_attn(x) # (B, T, 3C)
        q, k, v = qkv.split(self.embd_size, dim=-1) # (B,T,C), (B,T,C), (B,T,C)
        q = q.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2) # (B,nh,T,hs)
        k = k.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2) # (B,nh,T,hs)
        v = v.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2) # (B,nh,T,hs)
        # attn = q @ k.transpose(-2, -1) / np.sqrt(k.shape[-1]) # (B,nh,T,hs) @ (B,nh,hs,T) --> (B,nh,T,T)
        # attn = attn.masked_fill(self.bias[:,:,:T,:T] == 0, float("-inf"))
        # attn = F.softmax(attn, dim=-1)
        # out = attn @ v # (B,nh,T,T) @ (B,nh,T,hs) --> (B,nh,T,hs)
        # flash-attention paper (significantly faster, but logically the same as above 4 lines)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True) # (B,nh,T,hs)
        out = out.transpose(1, 2).contiguous().view(B, T, C) # (B,nh,T,hs) --> (B,T,nh,hs) --> (B,T,C=nh*hs)
        out = self.c_proj(out) # (B,T,C) --> (B,T,C)
        return out


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.embd_size, 4 * config.embd_size)
        self.gelu = nn.GELU(approximate='tanh') # approximate='tanh' used to try to reproduce gpt2 paper
        self.c_proj = nn.Linear(4 * config.embd_size, config.embd_size)
        self.c_proj.SCALE_INIT = 1.0

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    """ Transformer decoder block (pre-LayerNorm, as in GPT-2) """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.embd_size)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.embd_size)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(self.config.vocab_size, self.config.embd_size),
            wpe = nn.Embedding(self.config.context_length, self.config.embd_size),
            h = nn.ModuleList([Block(self.config) for _ in range(self.config.num_layers)]),
            ln_f = nn.LayerNorm(self.config.embd_size)
        ))
        # language modeling head
        self.lm_head = nn.Linear(self.config.embd_size, self.config.vocab_size, bias=False)
        # weight sharing scheme (saves 768*50257 ≈ 38.6M params, fewer params, more efficient)
        self.transformer.wte.weight = self.lm_head.weight
        # init params (iterates over all submodules and applies _init_weights)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'SCALE_INIT'):
                std /= (2 * self.config.num_layers)**0.5
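                # each Block adds two residual contributions (attention and MLP), so scaling the
                # residual projections by 1/sqrt(2 * num_layers) keeps the residual-stream
                # activations from growing with depth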
            torch.nn.init.normal_(module.weight, mean=0, std=std) # as per openai gpt-2 source code
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.config.context_length, f'sequence length {T} should be <= {self.config.context_length}'
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # (T,)
        pos_embd = self.transformer.wpe(pos) # (T, embd_size)
        tok_embd = self.transformer.wte(idx) # (B, T, embd_size)
        x = pos_embd + tok_embd # (B, T, embd_size)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x) # (B, T, embd_size)
        logits = self.lm_head(x) # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """ Loads pretrained GPT2 model weights from huggingface """
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print(f"loading weights from pretrained gpt: {model_type}")

        config_args = {
            'gpt2': dict(num_layers=12, num_heads=12, embd_size=768), # 124M params
            'gpt2-medium': dict(num_layers=24, num_heads=16, embd_size=1024), # 350M params
            'gpt2-large': dict(num_layers=36, num_heads=20, embd_size=1280), # 774M params
            'gpt2-xl': dict(num_layers=48, num_heads=25, embd_size=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257
        config_args['context_length'] = 1024

        # create a from-scratch minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]

        # init a huggingface transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']

        assert len(sd_keys) == len(sd_keys_hf), f"mismatched keys {len(sd_keys)} != {len(sd_keys_hf)}"

        # copy while ensuring all parameters are aligned in names and shape
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # need to transpose Conv1D weights
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].T)
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])
        return model

    def configure_optimizers(self, weight_decay, lr, device_type, master_process):
        """
        Essentially implements weight decay (a regularization tool: by decaying the weights, we
        force the optimizer to spread signal across more weights, not allowing any single weight to dominate)
        """
        # start with all of the candidate params (that require gradient)
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}

        # create optim groups: any parameters that are 2D will be weight decayed, otherwise no.
        # i.e., all weight tensors in matmuls + embeddings will decay, whereas biases and layernorms won't be decayed
        decay_params = [p for pn, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for pn, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:
            print(f'num decay parameter tensors: {len(decay_params)} with {num_decay_params:,} parameters')
            print(f'num nodecay parameter tensors: {len(nodecay_params)} with {num_nodecay_params:,} parameters')

        # use fused version of AdamW optimizer (faster than non-fused version)
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        if master_process:
            print(f'using fused AdamW optimizer: {use_fused}')
        optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer
prepare_dataset.py
ADDED
@@ -0,0 +1,75 @@
import multiprocessing as mp
from datasets import load_dataset, DownloadConfig
import backoff
import os
from pathlib import Path
import numpy as np
import tiktoken

# Function to process individual dataset items
def process_data(item):
    """
    Process a single dataset item.
    Replace this with your actual processing logic (e.g., tokenization).
    """
    # Example: Tokenize text using tiktoken (adjust based on your needs)
    encoder = tiktoken.get_encoding('gpt2')
    text = item.get('text', '') # Assuming dataset has a 'text' field
    tokens = encoder.encode(text)
    return tokens

@backoff.on_exception(backoff.expo, Exception, max_tries=5)
def fetch_data(item):
    """
    Wrapper for process_data with exponential backoff for retries.
    """
    return process_data(item)

def main():
    """
    Main function to load and process the FineWeb-Edu dataset.
    """
    # Configuration
    remote_name = "sample-10BT" # Dataset configuration name
    output_dir = "./data" # Directory to save processed data
    os.makedirs(output_dir, exist_ok=True)

    # Set up download config to handle rate limits and caching
    download_config = DownloadConfig(
        max_retries=5,
        num_proc=4, # Limit to 4 processes to avoid HTTP 429
        cache_dir=Path.home() / ".cache" / "huggingface" / "datasets"
    )

    try:
        # Load dataset with caching
        print("Loading dataset...")
        dataset = load_dataset(
            'HuggingFaceFW/fineweb-edu',
            name=remote_name,
            split='train',
            download_mode="reuse_dataset_if_exists",
            download_config=download_config
        )
        print(f"Dataset loaded with {len(dataset)} items.")

        # Limit number of processes to avoid overwhelming Hugging Face Hub
        nprocs = min(mp.cpu_count(), 4)
        print(f"Using {nprocs} processes for multiprocessing.")

        # Process dataset using multiprocessing
        with mp.Pool(nprocs) as pool:
            results = pool.map(fetch_data, dataset)

        # Save processed results (example: save as numpy arrays)
        output_path = os.path.join(output_dir, "processed_fineweb_edu.npy")
        np.save(output_path, results)
        print(f"Processed dataset saved to {output_path}")

    except Exception as e:
        print(f"Error loading or processing dataset: {e}")
        raise

if __name__ == '__main__':
    mp.freeze_support() # Required for Windows compatibility with executables
    main()
requirements.txt
ADDED
@@ -0,0 +1,4 @@
numpy
torch
tiktoken
transformers
train.py
ADDED
@@ -0,0 +1,392 @@
import os
import math
import numpy as np
import time
from dataclasses import dataclass
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# import code; code.interact(local=locals())

from model import GPT
from dataloader import DataLoaderLite
from hellaswag_eval import render_example, iterate_examples, get_most_likely_row

torch.set_float32_matmul_precision('high') # enable TF32 precision

# set torch compile to True (if it doesn't throw any error) to speed up training
use_torch_compile = False


class Trainer:
    def __init__(
        self,
        model,
        optimizer,
        train_loader,
        val_loader,
        token_encoder,
        eval_freq,
        grad_accum_steps,
        ddp,
        ddp_rank,
        ddp_world_size,
        device,
        logpath
    ):
        self.ddp = ddp
        self.ddp_rank = ddp_rank
        self.master_process = ddp_rank == 0
        self.ddp_world_size = ddp_world_size

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.token_encoder = token_encoder

        self.eval_freq = eval_freq
        self.grad_accum_steps = grad_accum_steps
        self.device = device
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
        self.logpath = logpath


    def train(
        self,
        max_steps,
        warmup_steps,
        max_lr,
        min_lr
    ):
        for step in range(max_steps):
            t0 = time.time()
            self.is_last_step = (step == max_steps - 1)

            # evaluate validation loss
            if step % self.eval_freq == 0 or self.is_last_step:
                self.evaluate_validation(step)

            # evaluate model performance on HellaSwag every once in a while
            if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
                self.evaluate_hellaswag(step)

            # generate sequences from the model every once in a while
            if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
                self.generate_sequences(num_seq=5, max_tokens=32)

            # training loop starts here
            self.model.train() # sets model to train mode
            self.optimizer.zero_grad() # resets all gradients
            batch_loss = 0.0

            for mini_step in range(self.grad_accum_steps):
                inp, tar = self.train_loader.next_batch()
                inp, tar = inp.to(self.device), tar.to(self.device)

                # FORWARD PASS !!!
                # autocast to bfloat16 for faster compute and memory efficiency
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(inp, tar)

                # loss is scaled to account for gradient accumulation, because the gradients just add
                # on each successive backward() call. Addition of gradients corresponds to SUM in the objective,
                # but we want MEAN instead of a SUM
                loss /= self.grad_accum_steps
                batch_loss += loss.detach()

                if self.ddp:
                    # in the final mini_step, sync and avg all gradients across all processes
                    # (can use the 'no_sync()' context manager alternatively)
                    self.model.require_backward_grad_sync = (mini_step == self.grad_accum_steps - 1)

                # each process accumulates gradients separately when 'require_backward_grad_sync'=False
                # in the final 'mini_step', 'require_backward_grad_sync' becomes True, therefore
                # gradients are averaged across all processes and shared among them by loss.backward()
                loss.backward()

            if self.ddp:
                # 'batch_loss' is outside of DDP container, so need to perform 'all_reduce' to
                # average out 'batch_loss' across all processes of all ranks. 'batch_loss' tensor exists on all GPUs.
                # 'all_reduce' averages and deposits the result on all the processes
                dist.all_reduce(batch_loss, op=dist.ReduceOp.AVG)

            # once gradients are computed, clip the global l2-norm of the gradient at 1.0
            norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) # monitor/print 'norm'

            # determine learning rate with decay
            lr = self.estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr)
            # set learning rate for this iteration
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            self.optimizer.step()
            if self.device_type == 'cuda':
                torch.cuda.synchronize() # wait for the GPU to finish work

            dt = (time.time() - t0) * 1000.0 # in ms
            tokens_processed = self.train_loader.B * self.train_loader.T * self.grad_accum_steps * self.ddp_world_size
            tokens_per_sec = tokens_processed / (dt / 1000.0) # dt is in ms, so convert to seconds

            if self.master_process:
                print(f'step {step:4d} | loss: {batch_loss.item():.6f} | lr: {lr:.2e} | norm: {norm:.4f} | dt: {dt:.4f}ms | tok/sec: {tokens_per_sec:.4f}')
                with open(self.logpath, 'a') as f:
                    f.write(f'{step} train {batch_loss.item():.6f}\n')


    def evaluate_validation(self, step):
        self.model.eval() # sets model to eval mode
        self.val_loader.reset()
        # evaluate the model on validation set
        with torch.no_grad():
            val_loss_accum = 0.0
            val_steps = 20
            for _ in range(val_steps):
                inp, tar = self.val_loader.next_batch()
                inp, tar = inp.to(self.device), tar.to(self.device)
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(inp, tar)
                loss /= val_steps
                val_loss_accum += loss.detach()

        if self.ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if self.master_process:
            print(f'Val loss: {val_loss_accum.item():.4f}')
            with open(self.logpath, 'a') as f:
                f.write(f'{step} val {val_loss_accum.item():.4f}\n')

            if step > 0 and (step % 10000 == 0 or self.is_last_step):
                raw_model = self.model.module if self.ddp else self.model
                logdir = os.path.dirname(self.logpath)
                ckpt_path = os.path.join(logdir, f'model_{step:05d}.pt')
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'config': raw_model.config,
                    'step': step,
                    'val_loss': val_loss_accum.item()
                } # add optimizer.state_dict(), rng_seeds, etc. if resuming training
                torch.save(checkpoint, ckpt_path)


    def evaluate_hellaswag(self, step):
        """
        Construct a batch of 4 sequences and perform token completion using
        our model.
        """
        n_total = 0
        n_correct_norm = 0
        for i, example in enumerate(iterate_examples('val')):
            # only process examples where i % ddp_world_size == ddp_rank
            if i % self.ddp_world_size != self.ddp_rank:
                continue
            # render the example into tokens and labels
            _, tokens, mask, label = render_example(example) # (4,N), (4,N), (4,N)
            tokens, mask = tokens.to(self.device), mask.to(self.device)
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(tokens)
                pred_norm = get_most_likely_row(tokens, mask, logits)
            n_total += 1
            n_correct_norm += int(pred_norm == label)
        # reduce the stats across all processes
        if self.ddp:
            n_total = torch.tensor(n_total, device=self.device, dtype=torch.long)
            n_correct_norm = torch.tensor(n_correct_norm, device=self.device, dtype=torch.long)
            dist.all_reduce(n_total, op=dist.ReduceOp.SUM)
            dist.all_reduce(n_correct_norm, op=dist.ReduceOp.SUM)
            n_total = n_total.item()
            n_correct_norm = n_correct_norm.item()
        acc_norm = n_correct_norm / n_total
        if self.master_process:
            print(f'HellaSwag accuracy: {n_correct_norm}/{n_total}={acc_norm:.4f}')
            with open(self.logpath, 'a') as f:
                f.write(f'{step} hellaswag {acc_norm:.4f}\n')


    def generate_sequences(self, num_seq=4, max_tokens=32):
        self.model.eval()
        tokens = self.token_encoder.encode("Hello, I am a language model")
        tokens = torch.tensor(tokens, dtype=torch.long) # (n,) n : current sequence length
        tokens = tokens.unsqueeze(0).repeat(num_seq, 1) # (1,n) --> (num_seq, n)
        gen_tokens = tokens.to(self.device)
        # create a different rng generator so as not to impact the global rng state used for training
        sample_rng = torch.Generator(device=self.device)
        # adding 'ddp_rank' in seeding to generate different tokens for different rank processes
        sample_rng.manual_seed(42 + self.ddp_rank)
        # generate new tokens one token at a time until the sequence length becomes 'max_tokens'
        while gen_tokens.shape[-1] <= max_tokens:
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(gen_tokens) # (num_seq, n, vocab_size)
                logits = logits[:, -1, :] # (num_seq, vocab_size)
                probs = F.softmax(logits, dim=-1) # (num_seq, vocab_size)
                # take top-k 50 probs
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1) # (num_seq, 50), (num_seq, 50)
                # sample a token from top-50 probabilities
                ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng) # (num_seq, 1)
                next_tok = torch.gather(topk_indices, -1, ix) # (num_seq, 1)
                gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)
        # decode generated tokens and print generated text
        for i in range(num_seq):
            tokens = gen_tokens[i, :max_tokens].tolist()
            gen_text = self.token_encoder.decode(tokens)
            print(f"> rank {self.ddp_rank} sample {i}: {gen_text}")


    def estimate_lr(self, step, warmup_steps, max_steps, max_lr, min_lr):
        """
        Learning rate scheduler: cosine-decay learning schedule with warmup
        """
        # 1) linear warmup for 'warmup_steps' steps
        if step < warmup_steps:
            return max_lr * (step+1) / warmup_steps
        # 2) if step > max_steps, return min lr
        if step > max_steps:
            return min_lr
        # 3) in between, use cosine decay down to min lr
        decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
        return min_lr + coeff * (max_lr - min_lr)


@dataclass
class GPTConfig:
    context_length: int = 1024 # max context / sequence length
    vocab_size: int = 50257 # number of tokens: 50000 BPE merges + 256 bytes tokens + 1 <endoftext> token
    num_layers: int = 12
    embd_size: int = 768 # embedding dim
    num_heads: int = 12


def get_args():
    import argparse
    parser = argparse.ArgumentParser(description="Hyperparameter Configuration")
    parser.add_argument("--total_batch_size", type=int, default=524288, help="number of tokens processed for each weight update") # =2^19 tokens/step update (~0.5M tokens used in openai gpt3 paper)
    parser.add_argument("--mini_batch_size", type=int, default=32, help="setting of mini_batch_size is just a performance optimization. bigger gpu, bigger mini_batch_size")
    parser.add_argument("--context_length", type=int, default=1024) # max sequence length (can also try 2048)
    parser.add_argument("--num_layers", type=int, default=12)
    parser.add_argument("--embd_size", type=int, default=768)
    parser.add_argument("--num_heads", type=int, default=12)
    parser.add_argument("--max_lr", type=float, default=1e-3)
    parser.add_argument("--min_lr", type=float, default=1e-3 * 0.1)
    parser.add_argument("--warmup_steps", type=int, default=715)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--num_epochs", type=int, default=5)
    parser.add_argument("--steps_per_epoch", type=int, default=19073) # 10^10 / 2^19 ~ 19073 for 1 epoch on FineWebEdu-sample10BT
    parser.add_argument("--eval_freq", type=int, default=250)
    # parser.add_argument("--use_torch_compile", action='store_true') # default False
    parser.add_argument("--seed", type=int, default=1337, help="Random seed for reproducibility")
    parser.add_argument("--logdir", type=str, default="./logs/")
    return parser.parse_args()


def main():
    args = get_args()

    # Print the hyperparameters
    print("Hyperparameter Configuration:")
    for key, value in vars(args).items():
        print(f"{key}: {value}")

    # create the logs directory if it doesn't exist
    os.makedirs(args.logdir, exist_ok=True)
    logpath = os.path.join(args.logdir, 'log.txt')
    with open(logpath, 'w') as f:
        pass

    # set up DDP (distributed data parallel)
    # 'torchrun' command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
    # RANK and LOCAL_RANK are the same for (single node, multi-GPU) settings, and may differ for
    # (multi-node, multi-GPU) settings.
    ddp = int(os.environ.get('RANK', -1)) != -1 # if this is a ddp run or not
    if ddp:
        # use of ddp requires CUDA
        assert torch.cuda.is_available(), 'use of DDP requires CUDA'
        dist.init_process_group(backend='nccl')
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        device = f'cuda:{ddp_local_rank}'
        torch.cuda.set_device(device)
        # master process (arbitrarily set to 0) will do printing, logging, checkpointing, etc.
        master_process = ddp_rank == 0
    else:
        # not using ddp
        ddp_rank = 0
        ddp_local_rank = 0
        ddp_world_size = 1
        master_process = True # ddp_rank == 0
        device = 'cpu'
        if torch.cuda.is_available():
            device = 'cuda'
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = 'mps' # for apple macbook GPUs
        print(f'using device: {device}')

    device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    # setting seed for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed) # sets seed for random number generation on CPU
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed) # sets seed for random number generation on GPU
        torch.cuda.manual_seed_all(args.seed) # sets seed for all GPUs

    assert args.total_batch_size % (args.mini_batch_size * args.context_length * ddp_world_size) == 0, 'ensure total_batch_size divisible by B*T*ddp_world_size'
    grad_accum_steps = args.total_batch_size // (args.mini_batch_size * args.context_length * ddp_world_size)
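    # e.g. with the defaults on a single GPU: 524288 // (32 * 1024 * 1) = 16 gradient-accumulation
    # mini-steps per weight update (with 2 GPUs this becomes 8)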
    if master_process:
        print(f'desired batch size (number of tokens): {args.total_batch_size}')
        print(f'gradient accumulation steps: {grad_accum_steps}')
    print(f'GPU: {ddp_rank}, {ddp_local_rank}')

    train_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='train')
    val_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='val')

    # create GPT model. each ddp process will create its own instance of the model but since the seed is fixed,
    # they will create the same identical model
    gpt_config = GPTConfig(vocab_size=50304, # 50304 (nice number, lots of powers of 2) used instead of 50257 (bad, odd number)
                           context_length=args.context_length,
                           num_layers=args.num_layers,
                           num_heads=args.num_heads,
                           embd_size=args.embd_size
                           )
    model = GPT(config=gpt_config)
    # model = GPT.from_pretrained('gpt2') # init from OpenAI GPT-2
    model.to(device) # move model to device
    if use_torch_compile:
        # use torch compile almost always unless debugging (requires compilation time, but makes training faster)
        # speedup comes from reducing python overhead and GPU read/write
        model = torch.compile(model)

    if ddp:
        # wraps the model in a DDP container (forward pass is unchanged, but after the backward pass,
        # gradients computed on each process are averaged by DDP using 'AllReduce' and shared across
        # all processes so that every process ends up with the same gradients)
        model = DDP(model, device_ids=[ddp_local_rank])

    raw_model = model.module if ddp else model
    optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, lr=args.max_lr, device_type=device_type, master_process=master_process)
    token_encoder = tiktoken.get_encoding('gpt2')

    start_time = time.time()
    # init the trainer object
    trainer = Trainer(model, optimizer, train_loader, val_loader, token_encoder, args.eval_freq, grad_accum_steps,
                      ddp, ddp_rank, ddp_world_size, device, logpath)

    max_steps = args.steps_per_epoch * args.num_epochs
    trainer.train(max_steps, args.warmup_steps, args.max_lr, args.min_lr)

    dt = (time.time() - start_time) / (60*60)
    print(f"Total training time: {dt:.4f}hr")

    if ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()