# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT."""
import warnings

# Suppress all warnings (including DeprecationWarning and FutureWarning) so the
# estimator output stays readable.
warnings.filterwarnings("ignore")
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.training.initialize import initialize_megatron
from megatron.core.transformer.spec_utils import import_module
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
)

from moe_mem_estimator.gpt_model import GPTModel
from moe_mem_estimator.base import (
    is_pipeline_first_stage,
    is_pipeline_last_stage,
    set_global_config,
    set_pipeline_model_parallel_rank,
)
from moe_mem_estimator.layers import MLASelfAttention, MoELayer
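
# Illustrative invocation (the script is driven by the standard Megatron-LM
# argument parser through initialize_megatron; the script name and flag values
# below are placeholders, not a prescribed command line):
#
#   python estimate.py \
#       --num-layers 32 --hidden-size 4096 --num-attention-heads 32 \
#       --seq-length 4096 --max-position-embeddings 4096 --micro-batch-size 1 \
#       --tensor-model-parallel-size 2 --pipeline-model-parallel-size 4 \
#       --num-experts 64 --use-distributed-optimizer
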
def _calculate_rank_memory(config, args, input_shape, pp_rank=0, pp_size=1):
"""
Calculates the memory for a single pipeline parallel rank, containing the detailed logic.
"""
# Build the model for the current rank
set_global_config(config)
pre_process = (pp_rank == 0)
post_process = (pp_rank == pp_size - 1)
use_te = True
if hasattr(config, 'spec') and config.spec is not None:
transformer_layer_spec = import_module(config.spec)
else:
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
config.num_moe_experts, config.moe_grouped_gemm, config.qk_layernorm,
config.multi_latent_attention, config.fp8
)
else:
transformer_layer_spec = get_gpt_layer_local_spec(
config.num_moe_experts, config.moe_grouped_gemm, config.qk_layernorm,
config.multi_latent_attention
)
model = GPTModel(
config=config,
transformer_layer_spec=transformer_layer_spec,
vocab_size=args.padded_vocab_size,
max_sequence_length=args.max_position_embeddings,
pre_process=pre_process,
post_process=post_process,
fp16_lm_cross_entropy=getattr(config, 'fp16_lm_cross_entropy', False),
parallel_output=True,
        # Fall back to Megatron's untie flag if a tie_word_embeddings arg is absent.
        share_embeddings_and_output_weights=getattr(
            args, "tie_word_embeddings", not getattr(args, "untie_embeddings_and_output_weights", False)
        ),
position_embedding_type="rope",
rotary_percent=getattr(args, 'rotary_percent', 1.0),
rotary_base=getattr(args, 'rotary_base', 10000),
rope_scaling=getattr(config, 'use_rope_scaling', False),
)
# --- Start of detailed memory calculation logic ---
num_parameter_this_shard = model.num_parameter()
num_activation = model.num_activation(input_shape)
output_shape = model.mock_forward(input_shape)
    # Shared experts are replicated like dense weights, so exclude them here.
    num_parameter_this_shard_sparse = sum(
        layer.mlp.num_parameter()
        - (layer.mlp.shared_experts.num_parameter()
           if getattr(layer.mlp, "shared_experts", None) is not None else 0)
        for layer in model.decoder.layers.modules
        if isinstance(layer.mlp, MoELayer)
    )
num_activation_this_shard_mlp = sum(
m.mlp.num_activation() for m in model.decoder.layers.modules
)
num_microbatch_this_pp_rank = pp_size - pp_rank
    # Interleaved (virtual) pipeline schedule changes the in-flight microbatch count.
    num_layers_per_vpp_stage = getattr(
        config, "num_layers_per_virtual_pipeline_stage",
        getattr(args, "num_layers_per_virtual_pipeline_stage", None),
    )
    if num_layers_per_vpp_stage is not None:
        layers_this_pprank = len(model.decoder.layers.modules)
        vpp_size = layers_this_pprank // num_layers_per_vpp_stage
        if vpp_size > 0:
            num_microbatch_this_pp_rank = (pp_size * (vpp_size - 1) + (pp_size - pp_rank) * 2 - 1) / vpp_size
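            # Illustrative numbers: pp_size=4, pp_rank=0, vpp_size=2 gives
            # (4 * (2 - 1) + (4 - 0) * 2 - 1) / 2 = 5.5 in-flight microbatches,
            # versus pp_size - pp_rank = 4 without interleaving.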
    # Pipeline parallelism keeps multiple microbatches in flight, so the
    # per-microbatch activations (everything except the post-processing / loss
    # part) are scaled by the number of in-flight microbatches. Recomputation,
    # handled below, reduces what actually has to be stored.
    num_activation = (num_activation - model.num_act_post) * num_microbatch_this_pp_rank + model.num_act_post
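    # Illustrative numbers (not from a real run): 10e9 activation elements per
    # microbatch, of which 1e9 are post-processing, with 4 in-flight microbatches
    # gives (10e9 - 1e9) * 4 + 1e9 = 37e9 elements.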
    if config.recompute_granularity == "full":
        # Full recomputation: only layer inputs plus the pre/post-processing
        # activations are kept; recomputed layers are re-run during the backward
        # pass instead of storing their internal activations.
        recompute_num_layers = config.recompute_num_layers
        num_layers = model.num_layers
        common_act = (
            model.num_act_pre
            + model.num_act_between_layers * num_layers * num_microbatch_this_pp_rank
        )
if config.recompute_method == "block":
num_layers_with_loss = num_layers - recompute_num_layers
if num_layers_with_loss == 0:
peak1 = common_act + model.num_act_post
peak2 = common_act + model.num_act_per_layer
recomputed_activation = max(peak1, peak2)
else:
recomputed_activation = (
common_act
+ model.num_act_post
+ model.num_act_per_layer
* num_layers_with_loss
* num_microbatch_this_pp_rank
)
elif config.recompute_method == "uniform":
peak1 = common_act + model.num_act_post
peak2 = (
common_act
+ model.num_act_per_layer
* recompute_num_layers
* num_microbatch_this_pp_rank
)
recomputed_activation = max(peak1, peak2)
if isinstance(model.decoder.layers.modules[0].self_attention, MLASelfAttention):
recomputed_activation += model.decoder.layers.modules[0].self_attention.core_attention.num_activation()
num_activation = recomputed_activation
    elif config.recompute_granularity == "selective":
        # Selective recomputation (the Transformer Engine default) is already
        # reflected in GPTModel.num_activation(), so only the in-flight
        # microbatch scaling applied above is needed.
        pass
    # Context parallelism splits the sequence dimension, so activations outside
    # the MoE MLP are divided across CP ranks.
    if config.context_parallel_size > 1:
        num_activation = (
            num_activation - num_activation_this_shard_mlp
        ) / config.context_parallel_size + num_activation_this_shard_mlp
    # Bytes per parameter for weights, gradients, and Adam optimizer states
    # (fp16/bf16 params with fp32 main grads and fp32 optimizer states).
    if args.use_distributed_optimizer:
        base_optim_bytes = 6    # fp16/bf16 param (2) + fp32 main grad (4), kept on every rank
        world_optim_bytes = 12  # fp32 master weight + momentum + variance, sharded over DP*CP
    else:
        base_optim_bytes = 18   # all of the above kept on every rank
        world_optim_bytes = 0
    num_bytes_per_parameter = base_optim_bytes + (
        world_optim_bytes / (args.data_parallel_size * config.context_parallel_size)
    )
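    # Illustrative numbers: with data_parallel_size=8 and context_parallel_size=1,
    # the distributed optimizer gives 6 + 12 / 8 = 7.5 bytes per dense parameter
    # instead of 18 bytes without it.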
# Handle MoE optimizer state sharding if applicable
if num_parameter_this_shard_sparse > 0 and config.expert_model_parallel_size > 1:
        moe_dp_size = (
            args.data_parallel_size * config.tensor_model_parallel_size
            // (config.expert_model_parallel_size * args.expert_tensor_parallel_size)
        )
        num_bytes_per_parameter_moe = base_optim_bytes + (world_optim_bytes / moe_dp_size)
weight_and_optimizer_memory = (
(num_parameter_this_shard - num_parameter_this_shard_sparse) * num_bytes_per_parameter +
num_parameter_this_shard_sparse * num_bytes_per_parameter_moe
) / NUM_BYTES_IN_GIGABYTE
else:
weight_and_optimizer_memory = (num_parameter_this_shard * num_bytes_per_parameter) / NUM_BYTES_IN_GIGABYTE
    activation_memory = num_activation * 2 / NUM_BYTES_IN_GIGABYTE  # 2 bytes per element (fp16/bf16)
total_memory = weight_and_optimizer_memory + activation_memory
report = {
"pp_rank": pp_rank,
"parameters_b": num_parameter_this_shard / 1e9,
"activation_b": num_activation / 1e9, # Renamed from _gb to _b
"weight_optimizer_gb": round(weight_and_optimizer_memory, 2),
"activation_gb": round(activation_memory, 2),
"total_gb": round(total_memory, 2),
"details": model.dump(),
"model_breakdown": str(model)
}
print(model)
return report, output_shape
def estimate_from_config(config, args):
    """
    Estimate memory usage from an explicit config and args rather than global
    state, iterating over pipeline-parallel ranks and feeding each stage's
    output shape into the next stage.
    """
reports = []
input_shape = [args.micro_batch_size, args.seq_length]
pp_size = config.pipeline_model_parallel_size
if pp_size > 1:
for pp_rank in range(pp_size):
set_pipeline_model_parallel_rank(pp_rank)
report_for_rank, new_input_shape = _calculate_rank_memory(config, args, input_shape, pp_rank, pp_size)
reports.append(report_for_rank)
input_shape = new_input_shape # Pass output shape to the next stage
else:
report_for_rank, _ = _calculate_rank_memory(config, args, input_shape, 0, 1)
reports.append(report_for_rank)
return reports
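

# A minimal usage sketch (hypothetical helper, not part of the estimator's
# existing API): reduce the per-rank reports returned by estimate_from_config
# to a single peak-memory figure.
def summarize_reports(reports):
    """Print per-rank totals and return the peak total memory in GB."""
    peak_gb = 0.0
    for report in reports:
        print(
            f"pp_rank={report['pp_rank']}: "
            f"weight+optimizer={report['weight_optimizer_gb']} GB, "
            f"activation={report['activation_gb']} GB, "
            f"total={report['total_gb']} GB"
        )
        peak_gb = max(peak_gb, report["total_gb"])
    return peak_gb
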
def model_provider() -> GPTModel:
    """Build a mock GPTModel for the current pipeline stage from the global args."""
    args = get_args()
use_te = args.transformer_impl == "transformer_engine"
    # Experimental: load arguments from a YAML config.
if args.yaml_cfg is not None:
config = core_transformer_config_from_yaml(args, "language_model")
else:
config = core_transformer_config_from_args(args)
assert not args.use_legacy_models
if args.spec is not None:
transformer_layer_spec = import_module(args.spec)
else:
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
args.num_experts,
args.moe_grouped_gemm,
args.qk_layernorm,
args.multi_latent_attention,
args.fp8,
)
else:
transformer_layer_spec = get_gpt_layer_local_spec(
args.num_experts,
args.moe_grouped_gemm,
args.qk_layernorm,
args.multi_latent_attention,
)
set_global_config(config)
pre_process = is_pipeline_first_stage()
post_process = is_pipeline_last_stage()
# TODO fp8
model = GPTModel(
config=config,
transformer_layer_spec=transformer_layer_spec,
vocab_size=args.padded_vocab_size,
max_sequence_length=args.max_position_embeddings,
pre_process=pre_process,
post_process=post_process,
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
parallel_output=True,
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent,
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling,
)
return model
NUM_BYTES_IN_MEGABYTE = 1024 * 1024
NUM_BYTES_IN_GIGABYTE = 1024 * 1024 * 1024
def report_memory_usage():
    """Print a theoretical memory report for every pipeline-parallel rank."""
    args = get_args()
if args.yaml_cfg is not None:
config = core_transformer_config_from_yaml(args, "language_model")
else:
config = core_transformer_config_from_args(args)
input_shape = [args.micro_batch_size, args.seq_length]
if config.pipeline_model_parallel_size > 1:
for pp_rank in range(config.pipeline_model_parallel_size):
set_pipeline_model_parallel_rank(pp_rank)
print(f"\n----------[Pipeline_Parallelism_Rank={pp_rank}]----------")
input_shape = report_memory_usage_one_pp_rank(
input_shape, pp_rank, config.pipeline_model_parallel_size
)
else:
report_memory_usage_one_pp_rank(input_shape)
def report_memory_usage_one_pp_rank(
input_shape: list[int], pp_rank=0, pp_size=1
) -> list[int]:
args = get_args()
print(f"{input_shape=}")
model: GPTModel = model_provider()
num_parameter_this_shard = model.num_parameter()
num_activation = model.num_activation(input_shape)
output_shape = model.mock_forward(input_shape)
    num_parameter_this_shard_sparse = 0
    for layer in model.decoder.layers.modules:
        if isinstance(layer.mlp, MoELayer):
            num_parameter_this_shard_sparse += layer.mlp.num_parameter()
            # Shared experts are replicated like dense weights, so exclude them.
            if getattr(layer.mlp, "shared_experts", None) is not None:
                num_parameter_this_shard_sparse -= (
                    layer.mlp.shared_experts.num_parameter()
                )
    num_activation_this_shard_mlp = sum(
        m.mlp.num_activation() for m in model.decoder.layers.modules
    )
    num_microbatch_this_pp_rank = pp_size - pp_rank
    # Interleaved (virtual) pipeline schedule changes the in-flight microbatch count.
    if args.num_layers_per_virtual_pipeline_stage is not None:
        layers_this_pprank = len(model.decoder.layers.modules)
        vpp_size = layers_this_pprank // args.num_layers_per_virtual_pipeline_stage
        num_microbatch_this_pp_rank = (
            pp_size * (vpp_size - 1) + (pp_size - pp_rank) * 2 - 1
        ) / vpp_size
    print(model)
    print(
        f"Number of parameters on each GPU in billions: "
        f"{num_parameter_this_shard / 10**9: .2f}, of which the MoE MLP part is"
        f"{num_parameter_this_shard_sparse / 10**9: .2f}"
    )
# recompute
if args.recompute_granularity == "full":
recompute_num_layers = args.recompute_num_layers
num_layers = model.num_layers
common_act = (
model.num_act_pre
+ model.num_act_between_layers * num_layers * num_microbatch_this_pp_rank
) # recompute with pipeline parallel
    info = (
        "With this recompute setting, activation memory peaks when "
    )
if args.recompute_method == "block":
num_layers_with_loss = num_layers - recompute_num_layers
if num_layers_with_loss == 0:
peak1 = common_act + model.num_act_post
peak2 = common_act + model.num_act_per_layer
                if peak1 > peak2:
                    info += "calculating the loss"
                else:
                    info += "back-propagating the loss"
num_activation = max(peak1, peak2)
            else:
                info += (
                    f"calculating the loss with {num_layers_with_loss} non-recomputed layers"
                )
num_activation = (
common_act
+ model.num_act_post
+ model.num_act_per_layer
* num_layers_with_loss
* num_microbatch_this_pp_rank
)
elif args.recompute_method == "uniform":
peak1 = common_act + model.num_act_post
peak2 = (
common_act
+ model.num_act_per_layer
* recompute_num_layers
* num_microbatch_this_pp_rank
)
        if peak1 > peak2:
            info += "calculating the loss"
        else:
            info += f"back-propagating the loss while recomputing every {recompute_num_layers} layers"
num_activation = max(peak1, peak2)
if isinstance(
model.decoder.layers.modules[0].self_attention, MLASelfAttention
        ):  # MLA recomputation peaks during the backward pass
num_activation += model.decoder.layers.modules[
0
].self_attention.core_attention.num_activation()
print(info)
else:
num_activation = (
num_activation - model.num_act_post
) * num_microbatch_this_pp_rank + model.num_act_post
    # Context parallelism: activations outside the MoE MLP are split across CP ranks.
    num_activation = (
        num_activation - num_activation_this_shard_mlp
    ) / args.context_parallel_size + num_activation_this_shard_mlp
    if pp_size == 1:
        print(
            f"Number of activation elements on each GPU in billions: "
            f"{num_activation / 10**9: .2f}, of which the MoE MLP part is"
            f"{num_activation_this_shard_mlp / 10**9: .2f}"
        )
    else:
        print(
            f"Number of activation elements per microbatch on each GPU in billions: "
            f"{num_activation / 10**9: .2f}, of which the MoE MLP part is"
            f"{num_activation_this_shard_mlp / 10**9: .2f}"
            f", {num_microbatch_this_pp_rank=}"
        )
    # 18 bytes/param without the distributed optimizer; with it, the 12 bytes of
    # fp32 optimizer state (master weight, momentum, variance) are sharded over DP*CP.
    num_bytes_per_parameter = (
        18
        if not args.use_distributed_optimizer
        else 6 + (12 / args.data_parallel_size / args.context_parallel_size)
    )
if args.expert_model_parallel_size * args.expert_tensor_parallel_size > 1:
num_bytes_per_parameter_dense = num_bytes_per_parameter
num_bytes_per_parameter_moe = (
18
if not args.use_distributed_optimizer
else 6
+ (
12
/ (
args.data_parallel_size
* args.context_parallel_size
* args.tensor_model_parallel_size
/ args.expert_model_parallel_size
/ args.expert_tensor_parallel_size
)
)
)
print(f"{num_bytes_per_parameter_dense=} {num_bytes_per_parameter_moe=}")
weight_and_optimizer_memory = (
(num_parameter_this_shard - num_parameter_this_shard_sparse)
* num_bytes_per_parameter_dense
+ num_parameter_this_shard_sparse * num_bytes_per_parameter_moe
) / NUM_BYTES_IN_GIGABYTE
else:
print(f"{num_bytes_per_parameter=}")
weight_and_optimizer_memory = (
num_parameter_this_shard * num_bytes_per_parameter / NUM_BYTES_IN_GIGABYTE
)
    activation_memory = num_activation * 2 / NUM_BYTES_IN_GIGABYTE  # 2 bytes per element (fp16/bf16)
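    # Illustrative numbers: 3e9 activation elements at 2 bytes each is roughly
    # 6e9 / 2**30 ≈ 5.6 GB of activation memory.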
total_memory = weight_and_optimizer_memory + activation_memory
    print(
        f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} GB, "
        f"activation={activation_memory:.2f} GB, total={total_memory:.2f} GB\n"
    )
    return output_shape
if __name__ == "__main__":
initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True)
    report_memory_usage()