If you see mistakes or want to suggest changes, please create an issue on GitHub. Diagrams and text are licensed under Creative Commons Attribution CC-BY 4.0 with the source available on GitHub, unless noted otherwise. The figures that have been reused from other sources don't fall under this license and can be recognized by a note in their caption: "Figure from …". For attribution in academic contexts, please cite this work using the BibTeX citation.
The performance of a large language model (LLM) depends heavily on the quality and size of the LLMs.
Table of contents
🔭 Ultra-Guide to Scaling LLM training
Scaling Models and Hardware
Experiment setup
| | 1B (1) | 7B | 70B | 340B (2) | 400B (3) |
|---|---|---|---|---|---|
| N Layers | 24 | 32 | 80 | 96 | 126 |
| N Heads | 32 | 32 | 64 | 96 | 128 |
| Dimension | 2048 | 4096 | 8192 | 18432 | 16384 |
Distribution Methods
No Parallelism
A brief overview of memory usage in Transformers
To see how much GPU memory is in use before any model is loaded, run the following one-liner and then check the GPU memory with nvidia-smi:
import torch; torch.ones((1, 1)).to("cuda")
| Model parameters | Memory requirements |
|---|---|
| 1B | 16 GB |
| 7B | 112 GB |
| 70B | 1120 GB |
| | 6480 GB |
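The totals in this table work out to roughly 16 bytes of memory per parameter. The sketch below reproduces that arithmetic. The per-parameter breakdown (bf16 weights and gradients, fp32 master weights, two fp32 Adam moments) is an assumption, since the text here only gives the totals, and `estimate_training_memory_gb` is a hypothetical helper name. Activations are not included.

```python
# Assumed per-parameter byte costs for mixed-precision training with Adam.
# This breakdown is an assumption; the table only lists the totals.
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam momentum": 4,
    "fp32 Adam variance": 4,
}


def estimate_training_memory_gb(num_params: float) -> float:
    """Memory for weights, gradients and optimizer states, in GB (1 GB = 1e9 bytes)."""
    total_bytes = num_params * sum(BYTES_PER_PARAM.values())
    return total_bytes / 1e9


for name, num_params in [("1B", 1e9), ("7B", 7e9), ("70B", 70e9)]:
    print(f"{name}: {estimate_training_memory_gb(num_params):.0f} GB")
```

Under the same assumption, the 6480 GB entry would correspond to about 405 billion parameters.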
Activation recomputation
strategy, since it requires a forward pass through each layer, essentially adding a full forward pass during the backward pass. This strategy saves the most memory but is the most expensive in terms of compute: it increases the compute cost by up to 30-40%, which is very noticeable.
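To make the trade-off concrete, here is a minimal sketch of per-layer recomputation using PyTorch's `torch.utils.checkpoint`. The tiny stack of linear layers stands in for Transformer blocks, and all names are illustrative assumptions rather than the setup used in this guide. It runs on CPU and only checks that gradients are unchanged when activations are recomputed instead of stored.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# A small stack of layers standing in for Transformer blocks (illustrative only).
layers = torch.nn.ModuleList(
    [torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU()) for _ in range(4)]
)
x0 = torch.randn(8, 16)


def forward(x: torch.Tensor, recompute: bool) -> torch.Tensor:
    for layer in layers:
        if recompute:
            # Keep only the layer input; the layer's internal activations are
            # rebuilt by an extra forward pass during the backward pass.
            x = checkpoint(layer, x, use_reentrant=False)
        else:
            x = layer(x)
    return x.sum()


def grads(recompute: bool):
    layers.zero_grad()
    x = x0.clone().requires_grad_(True)
    forward(x, recompute).backward()
    return x.grad.clone(), layers[0][0].weight.grad.clone()


gx_plain, gw_plain = grads(recompute=False)
gx_ckpt, gw_ckpt = grads(recompute=True)

# Recomputation changes memory and compute cost, not the result.
torch.testing.assert_close(gx_plain, gx_ckpt)
torch.testing.assert_close(gw_plain, gw_ckpt)
```

Wrapping every layer this way is the most aggressive choice: each layer runs forward twice per training step, which is where the extra compute comes from.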
The micro batch size (MBS) is the batch size for each forward pass on a single node (the number of samples flowing through the model in one forward pass). We'll refer to the overall batch size between each optimizer step as the global batch size (GBS). If we do one optimizer step every 8 forward/backward passes, the global batch size will be 8 times the micro batch size. The global batch size thus corresponds to what we've called up to now just batch size for simplicity (we now make the terms more precise to avoid ambiguity).
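The relationship between the two batch sizes can be checked with a small pure-Python sketch. The one-parameter model, the data and the helper `mse_grad` are all made up for illustration. With equal-sized micro-batches, averaging the gradients of 8 micro-batches gives the same gradient as a single pass over the global batch.

```python
def mse_grad(w: float, xs: list, ys: list) -> float:
    """Gradient of the mean squared error for the toy model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


micro_batch_size = 4
grad_accum_steps = 8  # one optimizer step every 8 forward/backward passes
global_batch_size = micro_batch_size * grad_accum_steps

# Toy data for one optimizer step (hypothetical values).
xs = [float(i) for i in range(global_batch_size)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5

# Gradient computed on the whole global batch at once.
full_grad = mse_grad(w, xs, ys)

# Gradient accumulated over micro-batches, averaged across accumulation steps.
accum_grad = 0.0
for step in range(grad_accum_steps):
    lo = step * micro_batch_size
    hi = lo + micro_batch_size
    accum_grad += mse_grad(w, xs[lo:hi], ys[lo:hi]) / grad_accum_steps

print(global_batch_size)
assert abs(accum_grad - full_grad) < 1e-9
```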
Data Parallelism
Tensor Parallelism
import os

import torch
import torch.distributed as dist
from torch.nn.functional import gelu


def split_tensor(data, dim):
    # Helper not shown in the listing; assumed to return this rank's
    # equal-sized shard of `data` along `dim`.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    chunks = torch.split(data, data.size(dim) // world_size, dim=dim)
    return chunks[rank].contiguous()


def example_gelu():
    X = torch.randn(4, 2, device="cuda", dtype=torch.float32)
    W = torch.randn(2, 2, device="cuda", dtype=torch.float32)
    W_0, W_1 = W.chunk(2, dim=1)
    # Column linear: GeLU is element-wise, so it can be applied per shard
    y_col_1 = torch.cat([gelu(X @ W_0), gelu(X @ W_1)], dim=1)
    y_col_2 = gelu(torch.cat([X @ W_0, X @ W_1], dim=1))
    # All match
    torch.testing.assert_close(y_col_1, y_col_2, rtol=1e-5, atol=1e-5)
    # Row linear: GeLU is nonlinear, so it cannot be applied before the sum
    X_0, X_1 = X.chunk(2, dim=1)
    W_0, W_1 = W.chunk(2, dim=0)
    y_row_1 = gelu(X_0 @ W_0) + gelu(X_1 @ W_1)
    y_row_2 = gelu(X_0 @ W_0 + X_1 @ W_1)
    # Mismatch: the two results differ
    assert not torch.allclose(y_row_1, y_row_2, rtol=1e-5, atol=1e-5)


def column_linear_forward(X, local_W, group):
    # Each rank computes its own slice of the output columns; no communication
    Y_local = X @ local_W.t()
    return Y_local


def column_linear_backward(local_grad_Y, X, local_W, group):
    local_grad_X = local_grad_Y @ local_W
    # Each rank only holds a partial sum of grad_X, so sum across ranks
    dist.all_reduce(local_grad_X, group=group)
    grad_W = local_grad_Y.t() @ X
    return local_grad_X, grad_W


def row_linear_forward(local_X, local_W, group):
    Y_local = local_X @ local_W.t()
    # Sum the partial outputs from all ranks (in place)
    dist.all_reduce(Y_local, group=group)
    Y = Y_local
    return Y


def row_linear_backward(grad_Y, X, local_W, group):
    # X is this rank's shard of the input; no communication needed
    local_grad_X = grad_Y @ local_W
    grad_W = grad_Y.t() @ X
    return local_grad_X, grad_W


def example_column_row_linear():
    # torchrun --nproc_per_node=2 tp_all_reduce.py
    group = dist.group.WORLD
    X_ref = torch.arange(4 * 2, device="cuda", dtype=torch.float32, requires_grad=True).reshape(4, 2)
    W_ref_layer1 = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2) * 10
    W_ref_layer2 = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2)
    X_ref.retain_grad()
    W_ref_layer1.retain_grad()
    W_ref_layer2.retain_grad()
    # Communication is not tracked by autograd
    with torch.no_grad():
        dist.broadcast(X_ref, src=0, group=group)
        dist.broadcast(W_ref_layer1, src=0, group=group)
        dist.broadcast(W_ref_layer2, src=0, group=group)
    X = X_ref.clone()
    W_layer1 = W_ref_layer1.clone()
    W_layer2 = W_ref_layer2.clone()

    # Forward
    Y_ref_linear1 = X_ref @ W_ref_layer1.t()
    Y_ref_linear1.retain_grad()
    # We will transpose for matrix multiplication. As a result, we need to split row-wise
    Y_local_linear1 = column_linear_forward(X, split_tensor(W_layer1, dim=0), group)
    torch.testing.assert_close(Y_local_linear1, split_tensor(Y_ref_linear1, dim=1), rtol=1e-5, atol=1e-5)
    Y_local_linear2 = row_linear_forward(Y_local_linear1, split_tensor(W_ref_layer2, dim=1), group)
    Y_ref_linear2 = Y_ref_linear1 @ W_ref_layer2.t()
    torch.testing.assert_close(Y_local_linear2, Y_ref_linear2, rtol=1e-5, atol=1e-5)

    # Backward
    Y_ref_linear2.sum().backward()
    grad_Y = torch.ones_like(Y_ref_linear2)
    grad_X_linear2, grad_W_linear2 = row_linear_backward(grad_Y, Y_local_linear1, split_tensor(W_layer2, dim=1), group)
    torch.testing.assert_close(grad_X_linear2, split_tensor(Y_ref_linear1.grad, dim=1), rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(grad_W_linear2, split_tensor(W_ref_layer2.grad, dim=1), rtol=1e-5, atol=1e-5)
    grad_X, grad_W = column_linear_backward(grad_X_linear2, X, split_tensor(W_layer1, dim=0), group)
    torch.testing.assert_close(grad_X, X_ref.grad, rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(grad_W, split_tensor(W_ref_layer1.grad, dim=0), rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
    dist.init_process_group("nccl", rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    example_column_row_linear()
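The same column-then-row pattern can be checked without GPUs or `torchrun` by simulating the ranks in a single CPU process. In this sketch the loop over `rank` stands in for separate devices and a plain sum stands in for the all-reduce. The shapes and variable names are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
world_size = 2  # number of simulated tensor-parallel ranks

X = torch.randn(4, 8, dtype=torch.float64)
W1 = torch.randn(16, 8, dtype=torch.float64)  # first linear:  Y1 = X @ W1.T
W2 = torch.randn(8, 16, dtype=torch.float64)  # second linear: Y2 = Y1 @ W2.T

# Unsharded reference forward and backward (with grad_Y2 = ones)
Y_ref = (X @ W1.t()) @ W2.t()
grad_Y = torch.ones_like(Y_ref)
grad_X_ref = (grad_Y @ W2) @ W1

# Column-parallel: shard W1 along its output dim (dim 0).
# Row-parallel: shard W2 along its input dim (dim 1).
W1_shards = W1.chunk(world_size, dim=0)
W2_shards = W2.chunk(world_size, dim=1)

partial_Y, partial_grad_X = [], []
for rank in range(world_size):
    # Forward: no communication between the two layers
    Y1_local = X @ W1_shards[rank].t()
    partial_Y.append(Y1_local @ W2_shards[rank].t())
    # Backward: each rank gets only its slice of grad_Y1 ...
    grad_Y1_local = grad_Y @ W2_shards[rank]
    # ... and therefore only a partial sum of grad_X
    partial_grad_X.append(grad_Y1_local @ W1_shards[rank])

# Summing over ranks plays the role of all_reduce(SUM)
Y_tp = torch.stack(partial_Y).sum(dim=0)
grad_X_tp = torch.stack(partial_grad_X).sum(dim=0)

torch.testing.assert_close(Y_tp, Y_ref)
torch.testing.assert_close(grad_X_tp, grad_X_ref)
```

Each direction needs exactly one reduction: after the row-linear layer in the forward pass, and after the column-linear layer in the backward pass.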
Sequence Parallelism
Context Parallelism
Pipeline Parallelism
Overlapping computation and communication
II – Architecture
Choosing the right dimensions
Positional Embeddings (Learned, RoPE, ALiBi)
Attention (MHA, MQA, GQA)
Optimized Operations
Flash Attention 1, 2 and 3
Fused Kernels
III – Training Recipe
Batch Size
Initialization + rescaling activations inside the model
Numerical Precision
Long Context Training
Ring Attention
RoPE scaling / YaRN
Conclusion and looking forward
Citation
Penedo et al., "The FineWeb Datasets: Decanting the Web for the Finest Text Data at Scale", 2024.
@misc{penedo2024finewebdatasetsdecantingweb,
      title={The FineWeb Datasets: Decanting the Web for the Finest Text Data at Scale},
      author={Guilherme Penedo and Hynek Kydlíček and Loubna Ben allal and Anton Lozhkov and Margaret Mitchell and Colin Raffel and Leandro Von Werra and Thomas Wolf},
      year={2024},
      eprint={2406.17557},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2406.17557},
}
