Yan Bai
commited on
Commit
·
55e1701
1
Parent(s):
9eb3690
add
Browse files- .gitignore +1 -0
- Dockerfile +23 -0
- __init__.py +1 -0
- app.py +1 -0
- estimate.py +499 -0
- moe_mem_estimator/__init__.py +0 -0
- moe_mem_estimator/base.py +211 -0
- moe_mem_estimator/gpt_model.py +151 -0
- moe_mem_estimator/layers.py +1813 -0
- webui/index.html +163 -0
- webui/main.py +211 -0
- webui/model-configs/qwen3-14b.json +30 -0
- webui/model-configs/qwen3-235b-a22b.json +38 -0
- webui/model-configs/qwen3-30b-a3b.json +38 -0
- webui/model-configs/qwen3-32b.json +30 -0
- webui/model-configs/qwen3-8b.json +30 -0
- webui/requirements.txt +3 -0
- webui/script.js +715 -0
- webui/style.css +383 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__/
|
Dockerfile
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
|
2 |
+
|
3 |
+
# 安装额外依赖(如果基础镜像已包含部分依赖,pip 会自动跳过)
|
4 |
+
RUN pip install --no-cache-dir \
|
5 |
+
fastapi \
|
6 |
+
uvicorn[standard] \
|
7 |
+
mbridge \
|
8 |
+
termcolor \
|
9 |
+
ipdb
|
10 |
+
# 添加 Megatron-LM core_v0.12.2
|
11 |
+
RUN git clone -b core_v0.12.2 --depth 1 https://github.com/NVIDIA/Megatron-LM.git /opt/Megatron-LM
|
12 |
+
|
13 |
+
# 复制代码至工作目录
|
14 |
+
WORKDIR /app
|
15 |
+
COPY . /app
|
16 |
+
|
17 |
+
# HF Spaces 默认通过 $PORT 注入端口
|
18 |
+
ENV PYTHONPATH=/opt/Megatron-LM:$PYTHONPATH
|
19 |
+
ENV PORT=7860
|
20 |
+
EXPOSE 7860
|
21 |
+
|
22 |
+
# 启动 FastAPI 服务
|
23 |
+
CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port $PORT"]
|
__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
app.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from webui.main import app
|
estimate.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
"""Pretrain GPT."""
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
6 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
7 |
+
warnings.filterwarnings("ignore")
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
from functools import partial
|
11 |
+
from contextlib import nullcontext
|
12 |
+
import inspect
|
13 |
+
|
14 |
+
from typing import Union
|
15 |
+
from megatron.training import get_args
|
16 |
+
from megatron.training import print_rank_0
|
17 |
+
from megatron.training import get_timers
|
18 |
+
from megatron.training import get_tokenizer
|
19 |
+
from megatron.core import mpu
|
20 |
+
from megatron.core.enums import ModelType
|
21 |
+
from megatron.core.datasets.blended_megatron_dataset_builder import (
|
22 |
+
BlendedMegatronDatasetBuilder,
|
23 |
+
)
|
24 |
+
from megatron.core.datasets.utils import get_blend_from_list
|
25 |
+
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
|
26 |
+
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
|
27 |
+
import megatron.legacy.model
|
28 |
+
from megatron.training import pretrain
|
29 |
+
from megatron.core.utils import StragglerDetector
|
30 |
+
from megatron.core.transformer.spec_utils import import_module
|
31 |
+
from megatron.training.utils import (
|
32 |
+
get_batch_on_this_cp_rank,
|
33 |
+
get_batch_on_this_tp_rank,
|
34 |
+
)
|
35 |
+
from megatron.training.arguments import core_transformer_config_from_args
|
36 |
+
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
|
37 |
+
from megatron.core.models.gpt.gpt_layer_specs import (
|
38 |
+
get_gpt_layer_local_spec,
|
39 |
+
get_gpt_layer_with_transformer_engine_spec,
|
40 |
+
)
|
41 |
+
from megatron.training.initialize import initialize_megatron
|
42 |
+
from moe_mem_estimator.gpt_model import GPTModel
|
43 |
+
from moe_mem_estimator.base import (
|
44 |
+
is_pipeline_first_stage,
|
45 |
+
is_pipeline_last_stage,
|
46 |
+
set_global_config,
|
47 |
+
set_pipeline_model_parallel_rank,
|
48 |
+
)
|
49 |
+
from moe_mem_estimator.layers import MLASelfAttention, MoELayer
|
50 |
+
|
51 |
+
|
52 |
+
def _calculate_rank_memory(config, args, input_shape, pp_rank=0, pp_size=1):
|
53 |
+
"""
|
54 |
+
Calculates the memory for a single pipeline parallel rank, containing the detailed logic.
|
55 |
+
"""
|
56 |
+
# Build the model for the current rank
|
57 |
+
set_global_config(config)
|
58 |
+
pre_process = (pp_rank == 0)
|
59 |
+
post_process = (pp_rank == pp_size - 1)
|
60 |
+
|
61 |
+
use_te = True
|
62 |
+
if hasattr(config, 'spec') and config.spec is not None:
|
63 |
+
transformer_layer_spec = import_module(config.spec)
|
64 |
+
else:
|
65 |
+
if use_te:
|
66 |
+
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
|
67 |
+
config.num_moe_experts, config.moe_grouped_gemm, config.qk_layernorm,
|
68 |
+
config.multi_latent_attention, config.fp8
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
transformer_layer_spec = get_gpt_layer_local_spec(
|
72 |
+
config.num_moe_experts, config.moe_grouped_gemm, config.qk_layernorm,
|
73 |
+
config.multi_latent_attention
|
74 |
+
)
|
75 |
+
|
76 |
+
model = GPTModel(
|
77 |
+
config=config,
|
78 |
+
transformer_layer_spec=transformer_layer_spec,
|
79 |
+
vocab_size=args.padded_vocab_size,
|
80 |
+
max_sequence_length=args.max_position_embeddings,
|
81 |
+
pre_process=pre_process,
|
82 |
+
post_process=post_process,
|
83 |
+
fp16_lm_cross_entropy=getattr(config, 'fp16_lm_cross_entropy', False),
|
84 |
+
parallel_output=True,
|
85 |
+
share_embeddings_and_output_weights=args.tie_word_embeddings,
|
86 |
+
position_embedding_type="rope",
|
87 |
+
rotary_percent=getattr(args, 'rotary_percent', 1.0),
|
88 |
+
rotary_base=getattr(args, 'rotary_base', 10000),
|
89 |
+
rope_scaling=getattr(config, 'use_rope_scaling', False),
|
90 |
+
)
|
91 |
+
|
92 |
+
# --- Start of detailed memory calculation logic ---
|
93 |
+
num_parameter_this_shard = model.num_parameter()
|
94 |
+
num_activation = model.num_activation(input_shape)
|
95 |
+
output_shape = model.mock_forward(input_shape)
|
96 |
+
|
97 |
+
num_parameter_this_shard_sparse = sum(
|
98 |
+
layer.mlp.num_parameter() for layer in model.decoder.layers.modules
|
99 |
+
if isinstance(layer.mlp, MoELayer)
|
100 |
+
)
|
101 |
+
num_activation_this_shard_mlp = sum(
|
102 |
+
m.mlp.num_activation() for m in model.decoder.layers.modules
|
103 |
+
)
|
104 |
+
|
105 |
+
num_microbatch_this_pp_rank = pp_size - pp_rank
|
106 |
+
if config.num_layers_per_virtual_pipeline_stage is not None:
|
107 |
+
layers_this_pprank = len(model.decoder.layers.modules)
|
108 |
+
vpp_size = layers_this_pprank // config.num_layers_per_virtual_pipeline_stage
|
109 |
+
if vpp_size > 0:
|
110 |
+
num_microbatch_this_pp_rank = (pp_size * (vpp_size - 1) + (pp_size - pp_rank) * 2 - 1) / vpp_size
|
111 |
+
|
112 |
+
# Activation Recomputation
|
113 |
+
# The base activation number is for one microbatch. With pipeline parallelism,
|
114 |
+
# the total activation is multiplied by the number of microbatches in flight.
|
115 |
+
# Recomputation reduces this by re-calculating activations during the backward pass
|
116 |
+
# instead of storing them.
|
117 |
+
|
118 |
+
# This is the activation memory without any recomputation.
|
119 |
+
num_activation = (num_activation - model.num_act_post) * num_microbatch_this_pp_rank + model.num_act_post
|
120 |
+
|
121 |
+
if config.recompute_granularity == "full":
|
122 |
+
# This logic is transplanted from the more detailed `report_memory_usage_one_pp_rank`
|
123 |
+
recompute_num_layers = config.recompute_num_layers
|
124 |
+
num_layers = model.num_layers
|
125 |
+
# Activations of a model with recompute enabled.
|
126 |
+
# The activation of a layer is an input to the next layer.
|
127 |
+
# So, the total activation is the sum of the activations of all layers,
|
128 |
+
# plus the activation of the embedding layer.
|
129 |
+
# The activation of a layer is stored only if it is not recomputed.
|
130 |
+
common_act = (
|
131 |
+
model.num_act_pre
|
132 |
+
+ model.num_act_between_layers * num_layers * num_microbatch_this_pp_rank
|
133 |
+
)
|
134 |
+
if config.recompute_method == "block":
|
135 |
+
num_layers_with_loss = num_layers - recompute_num_layers
|
136 |
+
if num_layers_with_loss == 0:
|
137 |
+
peak1 = common_act + model.num_act_post
|
138 |
+
peak2 = common_act + model.num_act_per_layer
|
139 |
+
recomputed_activation = max(peak1, peak2)
|
140 |
+
else:
|
141 |
+
recomputed_activation = (
|
142 |
+
common_act
|
143 |
+
+ model.num_act_post
|
144 |
+
+ model.num_act_per_layer
|
145 |
+
* num_layers_with_loss
|
146 |
+
* num_microbatch_this_pp_rank
|
147 |
+
)
|
148 |
+
elif config.recompute_method == "uniform":
|
149 |
+
peak1 = common_act + model.num_act_post
|
150 |
+
peak2 = (
|
151 |
+
common_act
|
152 |
+
+ model.num_act_per_layer
|
153 |
+
* recompute_num_layers
|
154 |
+
* num_microbatch_this_pp_rank
|
155 |
+
)
|
156 |
+
recomputed_activation = max(peak1, peak2)
|
157 |
+
|
158 |
+
if isinstance(model.decoder.layers.modules[0].self_attention, MLASelfAttention):
|
159 |
+
recomputed_activation += model.decoder.layers.modules[0].self_attention.core_attention.num_activation()
|
160 |
+
|
161 |
+
num_activation = recomputed_activation
|
162 |
+
|
163 |
+
elif config.recompute_granularity == "selective":
|
164 |
+
# Selective recomputation is the default in Megatron-LM and is handled
|
165 |
+
# by Transformer Engine. The base `num_activation` calculation from `GPTModel`
|
166 |
+
# already reflects this. We just need to scale it by the number of in-flight microbatches.
|
167 |
+
# This is already the case, so we do nothing here.
|
168 |
+
pass
|
169 |
+
|
170 |
+
|
171 |
+
# Context Parallelism
|
172 |
+
if config.context_parallel_size > 1:
|
173 |
+
num_activation = (num_activation - num_activation_this_shard_mlp) / config.context_parallel_size + num_activation_this_shard_mlp
|
174 |
+
|
175 |
+
# Calculate bytes per parameter for optimizer states
|
176 |
+
if args.use_distributed_optimizer:
|
177 |
+
base_optim_bytes = 6 # FP16 weight, FP32 master weight
|
178 |
+
world_optim_bytes = 12 # FP32 grad, FP32 momentum, FP32 variance
|
179 |
+
else:
|
180 |
+
base_optim_bytes = 18 # All states on each GPU
|
181 |
+
world_optim_bytes = 0
|
182 |
+
|
183 |
+
num_bytes_per_parameter = base_optim_bytes + (world_optim_bytes / (args.data_parallel_size * config.context_parallel_size))
|
184 |
+
|
185 |
+
# Handle MoE optimizer state sharding if applicable
|
186 |
+
if num_parameter_this_shard_sparse > 0 and config.expert_model_parallel_size > 1:
|
187 |
+
moe_dp_size = args.data_parallel_size * config.tensor_model_parallel_size // (config.expert_model_parallel_size * args.expert_tensor_parallel_size)
|
188 |
+
num_bytes_per_parameter_moe = base_optim_bytes + (world_optim_bytes / moe_dp_size)
|
189 |
+
|
190 |
+
weight_and_optimizer_memory = (
|
191 |
+
(num_parameter_this_shard - num_parameter_this_shard_sparse) * num_bytes_per_parameter +
|
192 |
+
num_parameter_this_shard_sparse * num_bytes_per_parameter_moe
|
193 |
+
) / NUM_BYTES_IN_GIGABYTE
|
194 |
+
else:
|
195 |
+
weight_and_optimizer_memory = (num_parameter_this_shard * num_bytes_per_parameter) / NUM_BYTES_IN_GIGABYTE
|
196 |
+
|
197 |
+
activation_memory = num_activation * 2 / NUM_BYTES_IN_GIGABYTE # Use GIGABYTE
|
198 |
+
total_memory = weight_and_optimizer_memory + activation_memory
|
199 |
+
|
200 |
+
report = {
|
201 |
+
"pp_rank": pp_rank,
|
202 |
+
"parameters_b": num_parameter_this_shard / 1e9,
|
203 |
+
"activation_b": num_activation / 1e9, # Renamed from _gb to _b
|
204 |
+
"weight_optimizer_gb": round(weight_and_optimizer_memory, 2),
|
205 |
+
"activation_gb": round(activation_memory, 2),
|
206 |
+
"total_gb": round(total_memory, 2),
|
207 |
+
"details": model.dump(),
|
208 |
+
"model_breakdown": str(model)
|
209 |
+
}
|
210 |
+
print(model)
|
211 |
+
|
212 |
+
return report, output_shape
|
213 |
+
|
214 |
+
|
215 |
+
def estimate_from_config(config, args):
|
216 |
+
"""
|
217 |
+
Estimate memory usage from a given config and args, instead of global state.
|
218 |
+
This version iterates over pipeline parallel ranks for accurate estimation.
|
219 |
+
"""
|
220 |
+
reports = []
|
221 |
+
input_shape = [args.micro_batch_size, args.seq_length]
|
222 |
+
pp_size = config.pipeline_model_parallel_size
|
223 |
+
|
224 |
+
if pp_size > 1:
|
225 |
+
for pp_rank in range(pp_size):
|
226 |
+
set_pipeline_model_parallel_rank(pp_rank)
|
227 |
+
report_for_rank, new_input_shape = _calculate_rank_memory(config, args, input_shape, pp_rank, pp_size)
|
228 |
+
reports.append(report_for_rank)
|
229 |
+
input_shape = new_input_shape # Pass output shape to the next stage
|
230 |
+
else:
|
231 |
+
report_for_rank, _ = _calculate_rank_memory(config, args, input_shape, 0, 1)
|
232 |
+
reports.append(report_for_rank)
|
233 |
+
|
234 |
+
return reports
|
235 |
+
|
236 |
+
|
237 |
+
def model_provider() -> GPTModel:
|
238 |
+
args = get_args()
|
239 |
+
use_te = args.transformer_impl == "transformer_engine"
|
240 |
+
|
241 |
+
# Experimental loading arguments from yaml
|
242 |
+
if args.yaml_cfg is not None:
|
243 |
+
config = core_transformer_config_from_yaml(args, "language_model")
|
244 |
+
else:
|
245 |
+
config = core_transformer_config_from_args(args)
|
246 |
+
assert not args.use_legacy_models
|
247 |
+
|
248 |
+
if args.spec is not None:
|
249 |
+
transformer_layer_spec = import_module(args.spec)
|
250 |
+
else:
|
251 |
+
if use_te:
|
252 |
+
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
|
253 |
+
args.num_experts,
|
254 |
+
args.moe_grouped_gemm,
|
255 |
+
args.qk_layernorm,
|
256 |
+
args.multi_latent_attention,
|
257 |
+
args.fp8,
|
258 |
+
)
|
259 |
+
else:
|
260 |
+
transformer_layer_spec = get_gpt_layer_local_spec(
|
261 |
+
args.num_experts,
|
262 |
+
args.moe_grouped_gemm,
|
263 |
+
args.qk_layernorm,
|
264 |
+
args.multi_latent_attention,
|
265 |
+
)
|
266 |
+
set_global_config(config)
|
267 |
+
pre_process = is_pipeline_first_stage()
|
268 |
+
post_process = is_pipeline_last_stage()
|
269 |
+
# TODO fp8
|
270 |
+
model = GPTModel(
|
271 |
+
config=config,
|
272 |
+
transformer_layer_spec=transformer_layer_spec,
|
273 |
+
vocab_size=args.padded_vocab_size,
|
274 |
+
max_sequence_length=args.max_position_embeddings,
|
275 |
+
pre_process=pre_process,
|
276 |
+
post_process=post_process,
|
277 |
+
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
|
278 |
+
parallel_output=True,
|
279 |
+
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
|
280 |
+
position_embedding_type=args.position_embedding_type,
|
281 |
+
rotary_percent=args.rotary_percent,
|
282 |
+
rotary_base=args.rotary_base,
|
283 |
+
rope_scaling=args.use_rope_scaling,
|
284 |
+
)
|
285 |
+
|
286 |
+
return model
|
287 |
+
|
288 |
+
|
289 |
+
NUM_BYTES_IN_MEGABYTE = 1024 * 1024
|
290 |
+
NUM_BYTES_IN_GIGABYTE = 1024 * 1024 * 1024
|
291 |
+
|
292 |
+
def report_memory_usage():
|
293 |
+
args = get_args()
|
294 |
+
if args.yaml_cfg is not None:
|
295 |
+
config = core_transformer_config_from_yaml(args, "language_model")
|
296 |
+
else:
|
297 |
+
config = core_transformer_config_from_args(args)
|
298 |
+
|
299 |
+
input_shape = [args.micro_batch_size, args.seq_length]
|
300 |
+
|
301 |
+
if config.pipeline_model_parallel_size > 1:
|
302 |
+
for pp_rank in range(config.pipeline_model_parallel_size):
|
303 |
+
set_pipeline_model_parallel_rank(pp_rank)
|
304 |
+
print(f"\n----------[Pipeline_Parallelism_Rank={pp_rank}]----------")
|
305 |
+
input_shape = report_memory_usage_one_pp_rank(
|
306 |
+
input_shape, pp_rank, config.pipeline_model_parallel_size
|
307 |
+
)
|
308 |
+
else:
|
309 |
+
report_memory_usage_one_pp_rank(input_shape)
|
310 |
+
|
311 |
+
|
312 |
+
def report_memory_usage_one_pp_rank(
|
313 |
+
input_shape: list[int], pp_rank=0, pp_size=1
|
314 |
+
) -> list[int]:
|
315 |
+
args = get_args()
|
316 |
+
|
317 |
+
print(f"{input_shape=}")
|
318 |
+
model: GPTModel = model_provider()
|
319 |
+
num_parameter_this_shard = model.num_parameter()
|
320 |
+
num_activation = model.num_activation(input_shape)
|
321 |
+
output_shape = model.mock_forward(input_shape)
|
322 |
+
|
323 |
+
num_parameter_this_shard_sparse = 0
|
324 |
+
for layer in model.decoder.layers.modules:
|
325 |
+
if isinstance(layer.mlp, MoELayer):
|
326 |
+
num_parameter_this_shard_sparse += layer.mlp.num_parameter()
|
327 |
+
if (
|
328 |
+
"shared_experts" in layer.mlp.__dir__()
|
329 |
+
and layer.mlp.shared_experts is not None
|
330 |
+
):
|
331 |
+
num_parameter_this_shard_sparse -= (
|
332 |
+
layer.mlp.shared_experts.num_parameter()
|
333 |
+
)
|
334 |
+
num_activation_this_shard_mlp = sum(
|
335 |
+
[m.mlp.num_activation() for m in model.decoder.layers.modules]
|
336 |
+
)
|
337 |
+
num_microbatch_this_pp_rank = pp_size - pp_rank
|
338 |
+
# vpp
|
339 |
+
if args.num_layers_per_virtual_pipeline_stage is not None:
|
340 |
+
layers_this_pprank = model.decoder.layers.modules.__len__()
|
341 |
+
vpp_size = layers_this_pprank // args.num_layers_per_virtual_pipeline_stage
|
342 |
+
num_microbatch_this_pp_rank = (
|
343 |
+
pp_size * (vpp_size - 1) + (pp_size - pp_rank) * 2 - 1
|
344 |
+
) / vpp_size
|
345 |
+
|
346 |
+
num_parameter_this_shard_sparse = 0
|
347 |
+
for layer in model.decoder.layers.modules:
|
348 |
+
if isinstance(layer.mlp, MoELayer):
|
349 |
+
num_parameter_this_shard_sparse += layer.mlp.num_parameter()
|
350 |
+
if (
|
351 |
+
"shared_experts" in layer.mlp.__dir__()
|
352 |
+
and layer.mlp.shared_experts is not None
|
353 |
+
):
|
354 |
+
num_parameter_this_shard_sparse -= (
|
355 |
+
layer.mlp.shared_experts.num_parameter()
|
356 |
+
)
|
357 |
+
num_microbatch_this_pp_rank = pp_size - pp_rank
|
358 |
+
# vpp
|
359 |
+
if args.num_layers_per_virtual_pipeline_stage is not None:
|
360 |
+
layers_this_pprank = model.decoder.layers.modules.__len__()
|
361 |
+
vpp_size = layers_this_pprank // args.num_layers_per_virtual_pipeline_stage
|
362 |
+
num_microbatch_this_pp_rank = (
|
363 |
+
pp_size * (vpp_size - 1) + (pp_size - pp_rank) * 2 - 1
|
364 |
+
) / vpp_size
|
365 |
+
model.__repr__()
|
366 |
+
print(model)
|
367 |
+
print(
|
368 |
+
f"Number of parameters in every GPU in billions: "
|
369 |
+
f"{num_parameter_this_shard / 10**9: .2f} where mlp part is {num_parameter_this_shard_sparse / 10**9: .2f}"
|
370 |
+
)
|
371 |
+
# recompute
|
372 |
+
if args.recompute_granularity == "full":
|
373 |
+
recompute_num_layers = args.recompute_num_layers
|
374 |
+
num_layers = model.num_layers
|
375 |
+
common_act = (
|
376 |
+
model.num_act_pre
|
377 |
+
+ model.num_act_between_layers * num_layers * num_microbatch_this_pp_rank
|
378 |
+
) # recompute with pipeline parallel
|
379 |
+
info = (
|
380 |
+
"With this recomputing setting, the number of activation achieve peak when "
|
381 |
+
)
|
382 |
+
if args.recompute_method == "block":
|
383 |
+
num_layers_with_loss = num_layers - recompute_num_layers
|
384 |
+
if num_layers_with_loss == 0:
|
385 |
+
peak1 = common_act + model.num_act_post
|
386 |
+
peak2 = common_act + model.num_act_per_layer
|
387 |
+
if peak1 > peak2:
|
388 |
+
info += "calculating loss"
|
389 |
+
else:
|
390 |
+
info += "back-propogating loss"
|
391 |
+
num_activation = max(peak1, peak2)
|
392 |
+
else:
|
393 |
+
info += (
|
394 |
+
f"calculating loss with {num_layers_with_loss} non-recompute layers"
|
395 |
+
)
|
396 |
+
num_activation = (
|
397 |
+
common_act
|
398 |
+
+ model.num_act_post
|
399 |
+
+ model.num_act_per_layer
|
400 |
+
* num_layers_with_loss
|
401 |
+
* num_microbatch_this_pp_rank
|
402 |
+
)
|
403 |
+
elif args.recompute_method == "uniform":
|
404 |
+
peak1 = common_act + model.num_act_post
|
405 |
+
peak2 = (
|
406 |
+
common_act
|
407 |
+
+ model.num_act_per_layer
|
408 |
+
* recompute_num_layers
|
409 |
+
* num_microbatch_this_pp_rank
|
410 |
+
)
|
411 |
+
if peak1 > peak2:
|
412 |
+
info += "calculating loss"
|
413 |
+
else:
|
414 |
+
info += f"back-propogating loss recomputing every {recompute_num_layers} layers"
|
415 |
+
num_activation = max(peak1, peak2)
|
416 |
+
if isinstance(
|
417 |
+
model.decoder.layers.modules[0].self_attention, MLASelfAttention
|
418 |
+
): # MLA recompute achieve peak at backward
|
419 |
+
num_activation += model.decoder.layers.modules[
|
420 |
+
0
|
421 |
+
].self_attention.core_attention.num_activation()
|
422 |
+
print(info)
|
423 |
+
|
424 |
+
else:
|
425 |
+
num_activation = (
|
426 |
+
num_activation - model.num_act_post
|
427 |
+
) * num_microbatch_this_pp_rank + model.num_act_post
|
428 |
+
|
429 |
+
# CP
|
430 |
+
num_activation = (
|
431 |
+
num_activation - num_activation_this_shard_mlp
|
432 |
+
) / args.context_parallel_size + num_activation_this_shard_mlp
|
433 |
+
if pp_size == 1:
|
434 |
+
print(
|
435 |
+
f"Number of activation in every GPU in billions: "
|
436 |
+
f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
|
437 |
+
)
|
438 |
+
else:
|
439 |
+
print(
|
440 |
+
f"Number of activation per microbatch in every GPU in billions: "
|
441 |
+
f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
|
442 |
+
f", {num_microbatch_this_pp_rank=}"
|
443 |
+
)
|
444 |
+
num_bytes_per_parameter = (
|
445 |
+
18
|
446 |
+
if not args.use_distributed_optimizer
|
447 |
+
else 6 + (12 / args.data_parallel_size / args.context_parallel_size)
|
448 |
+
)
|
449 |
+
if args.expert_model_parallel_size * args.expert_tensor_parallel_size > 1:
|
450 |
+
num_bytes_per_parameter_dense = num_bytes_per_parameter
|
451 |
+
num_bytes_per_parameter_moe = (
|
452 |
+
18
|
453 |
+
if not args.use_distributed_optimizer
|
454 |
+
else 6
|
455 |
+
+ (
|
456 |
+
12
|
457 |
+
/ (
|
458 |
+
args.data_parallel_size
|
459 |
+
* args.context_parallel_size
|
460 |
+
* args.tensor_model_parallel_size
|
461 |
+
/ args.expert_model_parallel_size
|
462 |
+
/ args.expert_tensor_parallel_size
|
463 |
+
)
|
464 |
+
)
|
465 |
+
)
|
466 |
+
print(f"{num_bytes_per_parameter_dense=} {num_bytes_per_parameter_moe=}")
|
467 |
+
|
468 |
+
weight_and_optimizer_memory = (
|
469 |
+
(num_parameter_this_shard - num_parameter_this_shard_sparse)
|
470 |
+
* num_bytes_per_parameter_dense
|
471 |
+
+ num_parameter_this_shard_sparse * num_bytes_per_parameter_moe
|
472 |
+
) / NUM_BYTES_IN_GIGABYTE
|
473 |
+
else:
|
474 |
+
print(f"{num_bytes_per_parameter=}")
|
475 |
+
weight_and_optimizer_memory = (
|
476 |
+
num_parameter_this_shard * num_bytes_per_parameter / NUM_BYTES_IN_GIGABYTE
|
477 |
+
)
|
478 |
+
|
479 |
+
activation_memory = num_activation * 2 / NUM_BYTES_IN_GIGABYTE # only support fp16
|
480 |
+
total_memory = weight_and_optimizer_memory + activation_memory
|
481 |
+
print(
|
482 |
+
f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory/1024:.2f} GB, "
|
483 |
+
f"activation={activation_memory/1024:.2f} GB, total={total_memory/1024:.2f} GB\n"
|
484 |
+
)
|
485 |
+
|
486 |
+
# import ipdb
|
487 |
+
|
488 |
+
# ipdb.set_trace()
|
489 |
+
return output_shape
|
490 |
+
pass
|
491 |
+
|
492 |
+
|
493 |
+
if __name__ == "__main__":
|
494 |
+
initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True)
|
495 |
+
|
496 |
+
import ipdb
|
497 |
+
|
498 |
+
with ipdb.launch_ipdb_on_exception():
|
499 |
+
report_memory_usage()
|
moe_mem_estimator/__init__.py
ADDED
File without changes
|
moe_mem_estimator/base.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC
|
2 |
+
|
3 |
+
from megatron.core.transformer.transformer_config import TransformerConfig
|
4 |
+
from torch.nn.modules.module import _addindent
|
5 |
+
from termcolor import colored
|
6 |
+
|
7 |
+
|
8 |
+
def prehook_save_input_shape(func):
|
9 |
+
def wrapper(self, *input_shapes, **kw_input_shapes):
|
10 |
+
if len(input_shapes) + len(kw_input_shapes) == 0:
|
11 |
+
if "_input_shape" in self.__dict__:
|
12 |
+
return func(self, *self._input_shape, **self._kw_input_shapes)
|
13 |
+
else:
|
14 |
+
return 0
|
15 |
+
self._input_shape = input_shapes
|
16 |
+
self._kw_input_shapes = kw_input_shapes
|
17 |
+
return func(self, *self._input_shape, **self._kw_input_shapes)
|
18 |
+
|
19 |
+
return wrapper
|
20 |
+
|
21 |
+
|
22 |
+
class MetaBase(type):
|
23 |
+
def __new__(cls, name, bases, attrs):
|
24 |
+
if "num_activation" in attrs:
|
25 |
+
attrs["num_activation"] = prehook_save_input_shape(attrs["num_activation"])
|
26 |
+
|
27 |
+
return super().__new__(cls, name, bases, attrs)
|
28 |
+
|
29 |
+
|
30 |
+
class MemEstimator(metaclass=MetaBase):
|
31 |
+
def __init__(self, *args, **kwargs):
|
32 |
+
self._modules = {}
|
33 |
+
pass
|
34 |
+
|
35 |
+
def __repr__(self):
|
36 |
+
# We treat the extra repr like the sub-module, one item per line
|
37 |
+
extra_lines = []
|
38 |
+
# extra_repr = self.extra_repr()
|
39 |
+
# # empty string will be split into list ['']
|
40 |
+
# if extra_repr:
|
41 |
+
# extra_lines = extra_repr.split("\n")
|
42 |
+
child_lines = []
|
43 |
+
for key, module in self._modules.items():
|
44 |
+
mod_str = repr(module)
|
45 |
+
mod_str = _addindent(mod_str, 2)
|
46 |
+
child_lines.append("(" + key + "): " + mod_str)
|
47 |
+
lines = extra_lines + child_lines
|
48 |
+
|
49 |
+
stat = (
|
50 |
+
"\t/* n_params="
|
51 |
+
+ colored(f"{self.num_parameter()/1024/1024:.2f}M", "red")
|
52 |
+
+ "\tn_act="
|
53 |
+
+ colored(f"{self.num_activation()/1024/1024:.2f}M", "green")
|
54 |
+
+ " */"
|
55 |
+
)
|
56 |
+
main_str = self._get_name() + stat + " ("
|
57 |
+
if lines:
|
58 |
+
# simple one-liner info, which most builtin Modules will use
|
59 |
+
if len(extra_lines) == 1 and not child_lines:
|
60 |
+
main_str += extra_lines[0]
|
61 |
+
else:
|
62 |
+
main_str += "\n " + "\n ".join(lines) + "\n"
|
63 |
+
|
64 |
+
main_str += ")"
|
65 |
+
return main_str
|
66 |
+
return f"{self.__class__.__name__} n_param={self.num_parameter()}"
|
67 |
+
|
68 |
+
def dump(self):
|
69 |
+
ret = {}
|
70 |
+
ret['name'] = self._get_name()
|
71 |
+
ret['n_params'] = self.num_parameter()
|
72 |
+
ret['n_act'] = self.num_activation()
|
73 |
+
modules = {}
|
74 |
+
for key, module in self._modules.items():
|
75 |
+
modules[key] = module.dump()
|
76 |
+
if len(modules)>0:
|
77 |
+
ret['modules'] = modules
|
78 |
+
return ret
|
79 |
+
|
80 |
+
|
81 |
+
def _get_name(self):
|
82 |
+
return self.__class__.__name__
|
83 |
+
|
84 |
+
def num_parameter(self):
|
85 |
+
"""
|
86 |
+
Calculate number of the model parameters
|
87 |
+
"""
|
88 |
+
raise NotImplemented
|
89 |
+
|
90 |
+
def num_activation(self, input_shape: list[int]):
|
91 |
+
"""
|
92 |
+
Calculate number of the activation with given input_shape.
|
93 |
+
Args:
|
94 |
+
input shape
|
95 |
+
"""
|
96 |
+
raise NotImplemented
|
97 |
+
|
98 |
+
def mock_forward(self, input_shape: list[int]):
|
99 |
+
"""
|
100 |
+
Mock the forward.
|
101 |
+
Args:
|
102 |
+
input shape
|
103 |
+
return:
|
104 |
+
output shape
|
105 |
+
"""
|
106 |
+
raise NotImplemented
|
107 |
+
|
108 |
+
def __setattr__(self, name: str, value) -> None:
|
109 |
+
if isinstance(value, MemEstimator):
|
110 |
+
modules = self.__dict__.get("_modules")
|
111 |
+
modules[name] = value
|
112 |
+
else:
|
113 |
+
pass
|
114 |
+
return super().__setattr__(name, value)
|
115 |
+
|
116 |
+
def __delattr__(self, name):
|
117 |
+
modules = self.__dict__.get("_modules")
|
118 |
+
if name in modules:
|
119 |
+
del modules[name]
|
120 |
+
return super().__delattr__(name)
|
121 |
+
|
122 |
+
|
123 |
+
_global_config: TransformerConfig = None
|
124 |
+
|
125 |
+
|
126 |
+
def set_global_config(cfg):
|
127 |
+
global _global_config
|
128 |
+
_global_config = cfg
|
129 |
+
|
130 |
+
|
131 |
+
def get_tensor_model_parallel_world_size():
|
132 |
+
global _global_config
|
133 |
+
return _global_config.tensor_model_parallel_size
|
134 |
+
|
135 |
+
|
136 |
+
def get_tensor_model_parallel_rank():
|
137 |
+
return 0
|
138 |
+
|
139 |
+
|
140 |
+
def get_expert_tensor_parallel_world_size():
|
141 |
+
global _global_config
|
142 |
+
return _global_config.expert_tensor_parallel_size
|
143 |
+
|
144 |
+
|
145 |
+
def get_expert_tensor_parallel_rank():
|
146 |
+
return 0
|
147 |
+
|
148 |
+
|
149 |
+
_pp_rank = 0
|
150 |
+
|
151 |
+
|
152 |
+
def set_pipeline_model_parallel_rank(rank):
|
153 |
+
global _pp_rank
|
154 |
+
_pp_rank = rank
|
155 |
+
|
156 |
+
|
157 |
+
def get_pipeline_model_parallel_rank():
|
158 |
+
global _pp_rank
|
159 |
+
return _pp_rank
|
160 |
+
|
161 |
+
|
162 |
+
def get_virtual_pipeline_model_parallel_rank():
|
163 |
+
return 0
|
164 |
+
|
165 |
+
|
166 |
+
def get_pipeline_model_parallel_world_size():
|
167 |
+
global _global_config
|
168 |
+
return _global_config.pipeline_model_parallel_size
|
169 |
+
|
170 |
+
|
171 |
+
def get_expert_model_parallel_rank():
|
172 |
+
return 0
|
173 |
+
|
174 |
+
|
175 |
+
def get_expert_model_parallel_world_size():
|
176 |
+
global _global_config
|
177 |
+
return _global_config.expert_model_parallel_size
|
178 |
+
|
179 |
+
|
180 |
+
def get_virtual_pipeline_model_parallel_world_size():
|
181 |
+
global _global_config
|
182 |
+
return _global_config.virtual_pipeline_model_parallel_size
|
183 |
+
|
184 |
+
|
185 |
+
def is_pipeline_first_stage(ignore_virtual=False):
|
186 |
+
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
|
187 |
+
if not ignore_virtual:
|
188 |
+
if (
|
189 |
+
get_virtual_pipeline_model_parallel_world_size() is not None
|
190 |
+
and get_virtual_pipeline_model_parallel_rank() != 0
|
191 |
+
):
|
192 |
+
return False
|
193 |
+
return get_pipeline_model_parallel_rank() == 0
|
194 |
+
|
195 |
+
|
196 |
+
def is_pipeline_last_stage(ignore_virtual=False):
|
197 |
+
"""Return True if in the last pipeline-model-parallel stage, False otherwise."""
|
198 |
+
return get_pipeline_model_parallel_rank() == (
|
199 |
+
get_pipeline_model_parallel_world_size() - 1
|
200 |
+
)
|
201 |
+
|
202 |
+
|
203 |
+
def cum_mul(l: list):
|
204 |
+
try:
|
205 |
+
ret = 1
|
206 |
+
for one in l:
|
207 |
+
ret *= one
|
208 |
+
return ret
|
209 |
+
except:
|
210 |
+
return 0
|
211 |
+
__import__('ipdb').set_trace()
|
moe_mem_estimator/gpt_model.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base import (
|
2 |
+
MemEstimator,
|
3 |
+
set_global_config,
|
4 |
+
get_tensor_model_parallel_world_size,
|
5 |
+
get_tensor_model_parallel_rank,
|
6 |
+
cum_mul,
|
7 |
+
)
|
8 |
+
|
9 |
+
from megatron.core.transformer.spec_utils import ModuleSpec
|
10 |
+
from typing import Dict, Literal, Optional, Union
|
11 |
+
from megatron.core.transformer.transformer_config import TransformerConfig
|
12 |
+
from megatron.core.model_parallel_config import ModelParallelConfig
|
13 |
+
from megatron.core.tensor_parallel.utils import VocabUtility
|
14 |
+
from megatron.core.transformer.transformer_block import (
|
15 |
+
TransformerBlockSubmodules,
|
16 |
+
_get_block_submodules,
|
17 |
+
)
|
18 |
+
from megatron.core.transformer.enums import ModelType
|
19 |
+
from .layers import LanguageModelEmbedding, TransformerBlock, ColumnParallelLinear
|
20 |
+
|
21 |
+
|
22 |
+
class GPTModel(MemEstimator):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
config: TransformerConfig,
|
26 |
+
transformer_layer_spec: ModuleSpec,
|
27 |
+
vocab_size: int,
|
28 |
+
max_sequence_length: int,
|
29 |
+
pre_process: bool = True,
|
30 |
+
post_process: bool = True,
|
31 |
+
fp16_lm_cross_entropy: bool = False,
|
32 |
+
parallel_output: bool = True,
|
33 |
+
share_embeddings_and_output_weights: bool = False,
|
34 |
+
position_embedding_type: Literal[
|
35 |
+
"learned_absolute", "rope", "none"
|
36 |
+
] = "learned_absolute",
|
37 |
+
rotary_percent: float = 1.0,
|
38 |
+
rotary_base: int = 10000,
|
39 |
+
rope_scaling: bool = False,
|
40 |
+
seq_len_interpolation_factor: Optional[float] = None,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
|
44 |
+
self.config = config
|
45 |
+
config.use_cpu_initialization = True
|
46 |
+
|
47 |
+
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
|
48 |
+
self.vocab_size = vocab_size
|
49 |
+
self.max_sequence_length = max_sequence_length
|
50 |
+
self.pre_process = pre_process
|
51 |
+
self.post_process = post_process
|
52 |
+
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
|
53 |
+
self.parallel_output = parallel_output
|
54 |
+
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
|
55 |
+
self.position_embedding_type = position_embedding_type
|
56 |
+
|
57 |
+
# megatron core pipelining currently depends on model type
|
58 |
+
# TODO: remove this dependency ?
|
59 |
+
self.model_type = ModelType.encoder_or_decoder
|
60 |
+
|
61 |
+
# These 4 attributes are needed for TensorRT-LLM export.
|
62 |
+
self.max_position_embeddings = max_sequence_length
|
63 |
+
self.rotary_percent = rotary_percent
|
64 |
+
self.rotary_base = rotary_base
|
65 |
+
self.rotary_scaling = rope_scaling
|
66 |
+
|
67 |
+
if self.pre_process:
|
68 |
+
self.embedding = LanguageModelEmbedding(
|
69 |
+
config=self.config,
|
70 |
+
vocab_size=self.vocab_size,
|
71 |
+
max_sequence_length=self.max_sequence_length,
|
72 |
+
position_embedding_type=position_embedding_type,
|
73 |
+
)
|
74 |
+
|
75 |
+
# remove RotaryEmbedding
|
76 |
+
|
77 |
+
# Transformer.
|
78 |
+
self.decoder = TransformerBlock(
|
79 |
+
config=self.config,
|
80 |
+
spec=transformer_layer_spec,
|
81 |
+
pre_process=self.pre_process,
|
82 |
+
post_process=self.post_process,
|
83 |
+
)
|
84 |
+
|
85 |
+
# Output
|
86 |
+
if post_process:
|
87 |
+
if self.config.defer_embedding_wgrad_compute:
|
88 |
+
self.embedding_activation_buffer = []
|
89 |
+
self.grad_output_buffer = []
|
90 |
+
else:
|
91 |
+
self.embedding_activation_buffer = None
|
92 |
+
self.grad_output_buffer = None
|
93 |
+
|
94 |
+
self.output_layer = ColumnParallelLinear(
|
95 |
+
config.hidden_size,
|
96 |
+
self.vocab_size,
|
97 |
+
config=config,
|
98 |
+
init_method=config.init_method,
|
99 |
+
bias=False,
|
100 |
+
skip_bias_add=False,
|
101 |
+
gather_output=not self.parallel_output,
|
102 |
+
skip_weight_param_allocation=self.pre_process
|
103 |
+
and self.share_embeddings_and_output_weights,
|
104 |
+
embedding_activation_buffer=self.embedding_activation_buffer,
|
105 |
+
grad_output_buffer=self.grad_output_buffer,
|
106 |
+
)
|
107 |
+
|
108 |
+
def num_parameter(self):
|
109 |
+
ret = 0
|
110 |
+
if self.pre_process:
|
111 |
+
ret += self.embedding.num_parameter()
|
112 |
+
ret += self.decoder.num_parameter()
|
113 |
+
if self.post_process:
|
114 |
+
ret += self.output_layer.num_parameter()
|
115 |
+
return ret
|
116 |
+
|
117 |
+
def num_activation(self, input_shape: list[int]):
|
118 |
+
self._inited = True
|
119 |
+
ret = 0
|
120 |
+
|
121 |
+
self.num_act_pre = 0
|
122 |
+
self.num_act_post = 0
|
123 |
+
self.num_act_per_layer = 0
|
124 |
+
self.num_act_between_layers = 0
|
125 |
+
self.num_layers = self.decoder.layers.modules.__len__()
|
126 |
+
|
127 |
+
if self.pre_process:
|
128 |
+
self.num_act_pre = self.embedding.num_activation(input_shape)
|
129 |
+
ret += self.num_act_pre
|
130 |
+
input_shape = self.embedding.mock_forward(input_shape)
|
131 |
+
ret += self.decoder.num_activation(input_shape)
|
132 |
+
self.num_act_per_layer = self.decoder.layers.modules[0].num_activation()
|
133 |
+
input_shape = self.decoder.mock_forward(input_shape)
|
134 |
+
self.num_act_between_layers = cum_mul(input_shape)
|
135 |
+
|
136 |
+
if self.post_process:
|
137 |
+
self.num_act_post = self.output_layer.num_activation(input_shape)
|
138 |
+
softmax_activation = (
|
139 |
+
self.output_layer.num_activation(input_shape) * 2
|
140 |
+
) # due to softmax is calculate in fp32
|
141 |
+
self.num_act_post += softmax_activation
|
142 |
+
ret += self.num_act_post
|
143 |
+
return ret
|
144 |
+
|
145 |
+
def mock_forward(self, input_shape: list[int]):
|
146 |
+
if self.pre_process:
|
147 |
+
input_shape = self.embedding.mock_forward(input_shape)
|
148 |
+
input_shape = self.decoder.mock_forward(input_shape)
|
149 |
+
if self.post_process:
|
150 |
+
input_shape = self.output_layer.mock_forward(input_shape)
|
151 |
+
return input_shape
|
moe_mem_estimator/layers.py
ADDED
@@ -0,0 +1,1813 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base import (
|
2 |
+
MemEstimator,
|
3 |
+
set_global_config,
|
4 |
+
get_tensor_model_parallel_world_size,
|
5 |
+
get_tensor_model_parallel_rank,
|
6 |
+
cum_mul,
|
7 |
+
get_expert_tensor_parallel_world_size,
|
8 |
+
get_expert_tensor_parallel_rank,
|
9 |
+
get_pipeline_model_parallel_world_size,
|
10 |
+
get_pipeline_model_parallel_rank,
|
11 |
+
get_expert_model_parallel_rank,
|
12 |
+
get_expert_model_parallel_world_size,
|
13 |
+
is_pipeline_first_stage,
|
14 |
+
is_pipeline_last_stage,
|
15 |
+
_addindent,
|
16 |
+
colored,
|
17 |
+
)
|
18 |
+
|
19 |
+
from megatron.core.transformer.spec_utils import ModuleSpec
|
20 |
+
from typing import Dict, Literal, Optional, Union
|
21 |
+
from megatron.core.transformer.transformer_config import (
|
22 |
+
TransformerConfig,
|
23 |
+
MLATransformerConfig,
|
24 |
+
)
|
25 |
+
from megatron.core.model_parallel_config import ModelParallelConfig
|
26 |
+
from megatron.core.tensor_parallel.utils import VocabUtility
|
27 |
+
from megatron.core.transformer.transformer_block import (
|
28 |
+
TransformerBlockSubmodules,
|
29 |
+
)
|
30 |
+
from megatron.core.models.common.embeddings import (
|
31 |
+
_yarn_get_mscale,
|
32 |
+
apply_rotary_pos_emb,
|
33 |
+
)
|
34 |
+
from megatron.core.extensions.transformer_engine import (
|
35 |
+
_get_extra_te_kwargs,
|
36 |
+
get_expert_parallel_rng_tracker_name,
|
37 |
+
condition_init_method,
|
38 |
+
)
|
39 |
+
from megatron.core.transformer.enums import AttnMaskType
|
40 |
+
from megatron.core.transformer.mlp import MLPSubmodules
|
41 |
+
from megatron.core.utils import divide
|
42 |
+
from megatron.core.transformer.spec_utils import import_module
|
43 |
+
from megatron.core.transformer import transformer_layer
|
44 |
+
import types, math
|
45 |
+
import warnings
|
46 |
+
from copy import deepcopy
|
47 |
+
|
48 |
+
|
49 |
+
class LanguageModelEmbedding(MemEstimator):
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
config: TransformerConfig,
|
53 |
+
vocab_size: int,
|
54 |
+
max_sequence_length: int,
|
55 |
+
position_embedding_type: Literal[
|
56 |
+
"learned_absolute", "rope", "none"
|
57 |
+
] = "learned_absolute",
|
58 |
+
num_tokentypes: int = 0,
|
59 |
+
):
|
60 |
+
super().__init__()
|
61 |
+
|
62 |
+
self.config: TransformerConfig = config
|
63 |
+
self.vocab_size: int = vocab_size
|
64 |
+
self.max_sequence_length: int = max_sequence_length
|
65 |
+
self.add_position_embedding: bool = (
|
66 |
+
position_embedding_type == "learned_absolute"
|
67 |
+
)
|
68 |
+
self.num_tokentypes = num_tokentypes
|
69 |
+
self.reduce_scatter_embeddings = (
|
70 |
+
(not self.add_position_embedding)
|
71 |
+
and self.num_tokentypes <= 0
|
72 |
+
and self.config.sequence_parallel
|
73 |
+
)
|
74 |
+
# Word embeddings (parallel).
|
75 |
+
self.word_embeddings = VocabParallelEmbedding(
|
76 |
+
num_embeddings=self.vocab_size,
|
77 |
+
embedding_dim=self.config.hidden_size,
|
78 |
+
init_method=self.config.init_method,
|
79 |
+
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
|
80 |
+
config=self.config,
|
81 |
+
)
|
82 |
+
|
83 |
+
# TODO if self.add_position_embedding:
|
84 |
+
|
85 |
+
# TODO if self.num_tokentypes > 0:
|
86 |
+
|
87 |
+
self.embedding_dropout = Dropout(self.config.hidden_dropout)
|
88 |
+
|
89 |
+
def num_parameter(self):
|
90 |
+
ret = self.word_embeddings.num_parameter()
|
91 |
+
ret += self.embedding_dropout.num_parameter()
|
92 |
+
return ret
|
93 |
+
|
94 |
+
def num_activation(self, input_shape: list[int]):
|
95 |
+
ret = self.word_embeddings.num_activation(input_shape)
|
96 |
+
input_shape = self.word_embeddings.mock_forward(input_shape)
|
97 |
+
ret += self.embedding_dropout.num_activation(input_shape)
|
98 |
+
return ret
|
99 |
+
|
100 |
+
def mock_forward(self, input_shape: list[int]):
|
101 |
+
input_shape = self.word_embeddings.mock_forward(input_shape)
|
102 |
+
return input_shape
|
103 |
+
|
104 |
+
|
105 |
+
class VocabParallelEmbedding(MemEstimator):
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
num_embeddings: int,
|
109 |
+
embedding_dim: int,
|
110 |
+
*,
|
111 |
+
init_method,
|
112 |
+
reduce_scatter_embeddings: bool = False,
|
113 |
+
config: ModelParallelConfig,
|
114 |
+
):
|
115 |
+
super().__init__()
|
116 |
+
# Keep the input dimensions.
|
117 |
+
self.num_embeddings = num_embeddings
|
118 |
+
self.embedding_dim = embedding_dim
|
119 |
+
self.reduce_scatter_embeddings = reduce_scatter_embeddings
|
120 |
+
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
|
121 |
+
# Divide the weight matrix along the vocaburaly dimension.
|
122 |
+
(self.vocab_start_index, self.vocab_end_index) = (
|
123 |
+
VocabUtility.vocab_range_from_global_vocab_size(
|
124 |
+
self.num_embeddings,
|
125 |
+
get_tensor_model_parallel_rank(),
|
126 |
+
self.tensor_model_parallel_size,
|
127 |
+
)
|
128 |
+
)
|
129 |
+
self.num_embeddings_per_partition = (
|
130 |
+
self.vocab_end_index - self.vocab_start_index
|
131 |
+
)
|
132 |
+
self.deterministic_mode = config.deterministic_mode
|
133 |
+
self.weight = (self.num_embeddings_per_partition, self.embedding_dim)
|
134 |
+
|
135 |
+
def num_parameter(self):
|
136 |
+
return self.weight[0] * self.weight[1]
|
137 |
+
|
138 |
+
def num_activation(self, input_shape: list[int]):
|
139 |
+
return cum_mul(input_shape) * self.weight[1]
|
140 |
+
|
141 |
+
def mock_forward(self, input_shape: list[int]):
|
142 |
+
return input_shape + [self.weight[1]]
|
143 |
+
|
144 |
+
|
145 |
+
class Dropout(MemEstimator):
|
146 |
+
def __init__(self, p=0, *args, **kwargs):
|
147 |
+
super().__init__()
|
148 |
+
self.p = p
|
149 |
+
|
150 |
+
def num_parameter(self):
|
151 |
+
return 0
|
152 |
+
|
153 |
+
def num_activation(self, input_shape: list[int]):
|
154 |
+
if self.p == 0:
|
155 |
+
return 0
|
156 |
+
return cum_mul(input_shape[:])
|
157 |
+
|
158 |
+
def mock_forward(self, input_shape: list[int]):
|
159 |
+
return input_shape
|
160 |
+
|
161 |
+
|
162 |
+
class ColumnParallelLinear(MemEstimator):
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
input_size,
|
166 |
+
output_size,
|
167 |
+
*,
|
168 |
+
config: ModelParallelConfig,
|
169 |
+
init_method,
|
170 |
+
bias=True,
|
171 |
+
gather_output=False,
|
172 |
+
stride=1,
|
173 |
+
keep_master_weight_for_test=False,
|
174 |
+
skip_bias_add=False,
|
175 |
+
skip_weight_param_allocation: bool = False,
|
176 |
+
embedding_activation_buffer=None,
|
177 |
+
grad_output_buffer=None,
|
178 |
+
is_expert: bool = False,
|
179 |
+
tp_comm_buffer_name: str = None, # Not used
|
180 |
+
disable_grad_reduce: bool = False,
|
181 |
+
is_mla: bool = False,
|
182 |
+
):
|
183 |
+
super().__init__()
|
184 |
+
|
185 |
+
if is_mla and config.sequence_parallel:
|
186 |
+
tp_size = get_tensor_model_parallel_world_size()
|
187 |
+
output_size = divide(output_size, tp_size)
|
188 |
+
parallel_mode = None
|
189 |
+
tp_size = 1
|
190 |
+
tp_group = None
|
191 |
+
# Keep input parameters
|
192 |
+
self.input_size = input_size
|
193 |
+
self.output_size = output_size
|
194 |
+
self.gather_output = gather_output
|
195 |
+
# Divide the weight matrix along the last dimension.
|
196 |
+
self.skip_bias_add = skip_bias_add
|
197 |
+
self.is_expert = is_expert
|
198 |
+
self.expert_parallel = config.expert_model_parallel_size > 1
|
199 |
+
self.embedding_activation_buffer = embedding_activation_buffer
|
200 |
+
self.grad_output_buffer = grad_output_buffer
|
201 |
+
self.config = config
|
202 |
+
self.disable_grad_reduce = disable_grad_reduce
|
203 |
+
|
204 |
+
if is_expert:
|
205 |
+
world_size = get_expert_tensor_parallel_world_size()
|
206 |
+
rank = get_expert_tensor_parallel_rank()
|
207 |
+
else:
|
208 |
+
world_size = get_tensor_model_parallel_world_size()
|
209 |
+
rank = get_tensor_model_parallel_rank()
|
210 |
+
|
211 |
+
self.output_size_per_partition = divide(output_size, world_size)
|
212 |
+
|
213 |
+
# Parameters.
|
214 |
+
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
215 |
+
# we allocate the transpose.
|
216 |
+
# Initialize weight.
|
217 |
+
if not skip_weight_param_allocation:
|
218 |
+
self.weight = (self.output_size_per_partition, self.input_size)
|
219 |
+
else:
|
220 |
+
self.weight = (self.output_size_per_partition, self.input_size)
|
221 |
+
|
222 |
+
|
223 |
+
if bias:
|
224 |
+
self.bias = [self.output_size_per_partition]
|
225 |
+
else:
|
226 |
+
self.bias = None
|
227 |
+
|
228 |
+
self.sequence_parallel = config.sequence_parallel
|
229 |
+
if self.sequence_parallel and world_size <= 1:
|
230 |
+
warnings.warn(
|
231 |
+
"`sequence_parallel` is set to `True`, but tensor model parallel size "
|
232 |
+
f"is {world_size}. Disabling sequence parallel."
|
233 |
+
)
|
234 |
+
self.sequence_parallel = False
|
235 |
+
|
236 |
+
self.allreduce_dgrad = (
|
237 |
+
world_size > 1
|
238 |
+
and not self.sequence_parallel
|
239 |
+
and not self.disable_grad_reduce
|
240 |
+
)
|
241 |
+
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
|
242 |
+
|
243 |
+
def num_parameter(self):
|
244 |
+
ret = cum_mul(self.weight)
|
245 |
+
if self.bias is not None:
|
246 |
+
ret += self.bias[0]
|
247 |
+
return ret
|
248 |
+
|
249 |
+
def num_activation(self, input_shape: list[int]):
|
250 |
+
return cum_mul(input_shape[:-1]) * self.weight[0]
|
251 |
+
|
252 |
+
def mock_forward(self, input_shape: list[int]):
|
253 |
+
assert self.weight[-1] == input_shape[-1]
|
254 |
+
return input_shape[:-1] + [self.weight[0]]
|
255 |
+
|
256 |
+
|
257 |
+
class RowParallelLinear(MemEstimator):
|
258 |
+
def __init__(
|
259 |
+
self,
|
260 |
+
input_size: int,
|
261 |
+
output_size: int,
|
262 |
+
*,
|
263 |
+
config: ModelParallelConfig,
|
264 |
+
init_method,
|
265 |
+
bias: bool,
|
266 |
+
input_is_parallel: bool,
|
267 |
+
skip_bias_add: bool,
|
268 |
+
stride: int = 1,
|
269 |
+
keep_master_weight_for_test: bool = False,
|
270 |
+
is_expert: bool = False,
|
271 |
+
tp_comm_buffer_name: str = None, # Not used
|
272 |
+
):
|
273 |
+
super().__init__()
|
274 |
+
|
275 |
+
# Keep input parameters
|
276 |
+
self.input_size = input_size
|
277 |
+
self.output_size = output_size
|
278 |
+
self.input_is_parallel = input_is_parallel
|
279 |
+
self.skip_bias_add = skip_bias_add
|
280 |
+
self.config = config
|
281 |
+
self.is_expert = is_expert
|
282 |
+
self.expert_parallel = config.expert_model_parallel_size > 1
|
283 |
+
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
|
284 |
+
self.sequence_parallel = config.sequence_parallel
|
285 |
+
if self.sequence_parallel and not self.input_is_parallel:
|
286 |
+
raise RuntimeError(
|
287 |
+
"To enable `sequence_parallel`, `input_is_parallel` must be `True`"
|
288 |
+
)
|
289 |
+
|
290 |
+
# Divide the weight matrix along the last dimension.
|
291 |
+
if self.is_expert:
|
292 |
+
world_size = get_expert_tensor_parallel_world_size()
|
293 |
+
rank = get_expert_tensor_parallel_rank()
|
294 |
+
else:
|
295 |
+
world_size = get_tensor_model_parallel_world_size()
|
296 |
+
rank = get_tensor_model_parallel_rank()
|
297 |
+
|
298 |
+
self.input_size_per_partition = divide(input_size, world_size)
|
299 |
+
|
300 |
+
self.weight = (self.output_size, self.input_size_per_partition)
|
301 |
+
if bias:
|
302 |
+
self.bias = [self.output_size]
|
303 |
+
else:
|
304 |
+
self.bias = None
|
305 |
+
|
306 |
+
def num_parameter(self):
|
307 |
+
ret = cum_mul(self.weight)
|
308 |
+
if self.bias is not None:
|
309 |
+
ret += self.bias[0]
|
310 |
+
return ret
|
311 |
+
|
312 |
+
def num_activation(self, input_shape: list[int]):
|
313 |
+
return cum_mul(input_shape[:-1]) * self.weight[1]
|
314 |
+
|
315 |
+
def mock_forward(self, input_shape: list[int]):
|
316 |
+
assert self.weight[0] == input_shape[-1]
|
317 |
+
return input_shape[:-1] + [self.weight[1]]
|
318 |
+
|
319 |
+
|
320 |
+
class RMSNorm(MemEstimator):
|
321 |
+
def __init__(self, hidden_size: int, *args, **kwargs):
|
322 |
+
super().__init__()
|
323 |
+
self.weight = hidden_size
|
324 |
+
|
325 |
+
def num_parameter(self):
|
326 |
+
return self.weight
|
327 |
+
|
328 |
+
def num_activation(self, input_shape: list[int]):
|
329 |
+
return cum_mul(input_shape[:])
|
330 |
+
|
331 |
+
def mock_forward(self, input_shape: list[int]):
|
332 |
+
return input_shape
|
333 |
+
|
334 |
+
|
335 |
+
class GetBiasDropoutAdd(MemEstimator):
|
336 |
+
def __init__(self, *args, **kwargs):
|
337 |
+
super().__init__()
|
338 |
+
|
339 |
+
def num_parameter(self):
|
340 |
+
return 0
|
341 |
+
|
342 |
+
def num_activation(self, input_shape: list[int]):
|
343 |
+
return cum_mul(input_shape[:])
|
344 |
+
|
345 |
+
def mock_forward(self, input_shape: list[int]):
|
346 |
+
return input_shape
|
347 |
+
|
348 |
+
|
349 |
+
get_bias_dropout_add = GetBiasDropoutAdd()
|
350 |
+
|
351 |
+
|
352 |
+
class MLP(MemEstimator):
|
353 |
+
|
354 |
+
def __init__(
|
355 |
+
self,
|
356 |
+
config: TransformerConfig,
|
357 |
+
submodules,
|
358 |
+
is_expert: bool = False,
|
359 |
+
input_size: int = None,
|
360 |
+
):
|
361 |
+
super().__init__()
|
362 |
+
|
363 |
+
self.config: TransformerConfig = config
|
364 |
+
|
365 |
+
self.input_size = input_size if input_size != None else self.config.hidden_size
|
366 |
+
|
367 |
+
# If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
368 |
+
ffn_hidden_size = self.config.ffn_hidden_size
|
369 |
+
if self.config.gated_linear_unit:
|
370 |
+
ffn_hidden_size *= 2
|
371 |
+
|
372 |
+
self.linear_fc1 = build_module(
|
373 |
+
submodules.linear_fc1,
|
374 |
+
self.input_size,
|
375 |
+
ffn_hidden_size,
|
376 |
+
config=self.config,
|
377 |
+
init_method=self.config.init_method,
|
378 |
+
gather_output=False,
|
379 |
+
bias=self.config.add_bias_linear,
|
380 |
+
skip_bias_add=True,
|
381 |
+
is_expert=is_expert,
|
382 |
+
tp_comm_buffer_name="fc1",
|
383 |
+
)
|
384 |
+
|
385 |
+
self.activation_func = self.config.activation_func
|
386 |
+
|
387 |
+
self.linear_fc2 = build_module(
|
388 |
+
submodules.linear_fc2,
|
389 |
+
self.config.ffn_hidden_size,
|
390 |
+
self.config.hidden_size,
|
391 |
+
config=self.config,
|
392 |
+
init_method=self.config.output_layer_init_method,
|
393 |
+
bias=self.config.add_bias_linear,
|
394 |
+
input_is_parallel=True,
|
395 |
+
skip_bias_add=True,
|
396 |
+
is_expert=is_expert,
|
397 |
+
tp_comm_buffer_name="fc2",
|
398 |
+
)
|
399 |
+
|
400 |
+
def num_parameter(self):
|
401 |
+
return self.linear_fc1.num_parameter() + self.linear_fc2.num_parameter()
|
402 |
+
|
403 |
+
def num_activation(self, input_shape: list[int]):
|
404 |
+
result = 0
|
405 |
+
result += self.linear_fc1.num_activation(input_shape)
|
406 |
+
intermediate_shape = self.linear_fc1.mock_forward(input_shape)
|
407 |
+
result += cum_mul(intermediate_shape) / 2 # activation layer
|
408 |
+
self.linear_fc2.num_activation(intermediate_shape)
|
409 |
+
|
410 |
+
return result
|
411 |
+
|
412 |
+
def mock_forward(self, input_shape: list[int]):
|
413 |
+
intermediate_shape = self.linear_fc1.mock_forward(input_shape)
|
414 |
+
output_shape = self.linear_fc2.mock_forward(intermediate_shape)
|
415 |
+
return output_shape
|
416 |
+
|
417 |
+
|
418 |
+
class ModuleList(MemEstimator):
|
419 |
+
def __init__(self, modules: list[MemEstimator] = None):
|
420 |
+
super().__init__()
|
421 |
+
if modules is None:
|
422 |
+
modules = []
|
423 |
+
self.modules = modules
|
424 |
+
|
425 |
+
def __repr__(self):
|
426 |
+
"""Return a custom repr for ModuleList that compresses repeated module representations."""
|
427 |
+
list_of_reprs = [repr(item) for item in self.modules]
|
428 |
+
if len(list_of_reprs) == 0:
|
429 |
+
return self._get_name() + "()"
|
430 |
+
|
431 |
+
start_end_indices = [[0, 0]]
|
432 |
+
repeated_blocks = [list_of_reprs[0]]
|
433 |
+
for i, r in enumerate(list_of_reprs[1:], 1):
|
434 |
+
if r == repeated_blocks[-1]:
|
435 |
+
start_end_indices[-1][1] += 1
|
436 |
+
continue
|
437 |
+
|
438 |
+
start_end_indices.append([i, i])
|
439 |
+
repeated_blocks.append(r)
|
440 |
+
|
441 |
+
lines = []
|
442 |
+
stat = (
|
443 |
+
"\t/* n_params="
|
444 |
+
+ colored(f"{self.num_parameter()/1024/1024:.2f}M", "red")
|
445 |
+
+ "\tn_act="
|
446 |
+
+ colored(f"{self.num_activation()/1024/1024:.2f}M", "green")
|
447 |
+
+ " */"
|
448 |
+
)
|
449 |
+
main_str = self._get_name() + stat + " ("
|
450 |
+
for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
|
451 |
+
local_repr = f"({start_id}): {b}" # default repr
|
452 |
+
|
453 |
+
if start_id != end_id:
|
454 |
+
n = end_id - start_id + 1
|
455 |
+
local_repr = f"({start_id}-{end_id}): {n} x {b}"
|
456 |
+
|
457 |
+
local_repr = _addindent(local_repr, 2)
|
458 |
+
lines.append(local_repr)
|
459 |
+
|
460 |
+
main_str += "\n " + "\n ".join(lines) + "\n"
|
461 |
+
main_str += ")"
|
462 |
+
return main_str
|
463 |
+
|
464 |
+
def dump(self):
|
465 |
+
list_of_reprs = [repr(item) for item in self.modules]
|
466 |
+
if len(list_of_reprs) == 0:
|
467 |
+
return self._get_name() + "()"
|
468 |
+
list_of_dumps = [item.dump() for item in self.modules]
|
469 |
+
|
470 |
+
start_end_indices = [[0, 0]]
|
471 |
+
repeated_blocks = [list_of_reprs[0]]
|
472 |
+
repeated_blocks_dump = [list_of_dumps[0]]
|
473 |
+
for i, r in enumerate(list_of_reprs[1:], 1):
|
474 |
+
if r == repeated_blocks[-1]:
|
475 |
+
start_end_indices[-1][1] += 1
|
476 |
+
continue
|
477 |
+
|
478 |
+
start_end_indices.append([i, i])
|
479 |
+
repeated_blocks.append(r)
|
480 |
+
repeated_blocks_dump(list_of_dumps[i])
|
481 |
+
modules = {}
|
482 |
+
for (start_id, end_id), b in zip(start_end_indices, repeated_blocks_dump):
|
483 |
+
key = f"({start_id})"
|
484 |
+
if start_id != end_id:
|
485 |
+
n = end_id - start_id + 1
|
486 |
+
key = f"({start_id}-{end_id}) {n} layers"
|
487 |
+
modules[key] = b
|
488 |
+
|
489 |
+
ret = {}
|
490 |
+
ret["name"] = self._get_name()
|
491 |
+
ret["n_params"] = self.num_parameter()
|
492 |
+
ret["n_act"] = self.num_activation()
|
493 |
+
if len(modules) > 0:
|
494 |
+
ret["modules"] = modules
|
495 |
+
return ret
|
496 |
+
|
497 |
+
def append(self, m: MemEstimator):
|
498 |
+
self.modules.append(m)
|
499 |
+
|
500 |
+
def __len__(
|
501 |
+
self,
|
502 |
+
):
|
503 |
+
return self.modules.__len__()
|
504 |
+
|
505 |
+
def num_parameter(self):
|
506 |
+
return sum([x.num_parameter() for x in self.modules])
|
507 |
+
|
508 |
+
def num_activation(self, input_shape: list[int]):
|
509 |
+
result = 0
|
510 |
+
for m in self.modules:
|
511 |
+
result += m.num_activation(input_shape)
|
512 |
+
input_shape = m.mock_forward(input_shape)
|
513 |
+
|
514 |
+
return result
|
515 |
+
|
516 |
+
def mock_forward(self, input_shape: list[int]):
|
517 |
+
for m in self.modules:
|
518 |
+
result += m.num_activation(input_shape)
|
519 |
+
input_shape = m.mock_forward(input_shape)
|
520 |
+
return input_shape
|
521 |
+
|
522 |
+
|
523 |
+
class SequentialMLP(MemEstimator):
|
524 |
+
def __init__(self, num_local_experts, config: TransformerConfig, submodules):
|
525 |
+
super().__init__()
|
526 |
+
self.config = config
|
527 |
+
self.add_bias = config.add_bias_linear
|
528 |
+
self.moe_extended_tp = config.moe_extended_tp
|
529 |
+
self.num_local_experts = num_local_experts
|
530 |
+
self.local_experts = ModuleList()
|
531 |
+
for _ in range(self.num_local_experts):
|
532 |
+
expert = MLP(self.config, submodules, is_expert=True)
|
533 |
+
self.local_experts.append(expert)
|
534 |
+
|
535 |
+
def num_parameter(self):
|
536 |
+
return self.local_experts.num_parameter()
|
537 |
+
|
538 |
+
def num_activation(self, input_shape: list[int], tokens_per_expert=None):
|
539 |
+
# assume all the inputs are routed equally
|
540 |
+
all_tokens = input_shape[1]
|
541 |
+
result = 0
|
542 |
+
for m in self.local_experts.modules:
|
543 |
+
result += m.num_activation(
|
544 |
+
input_shape[:1]
|
545 |
+
+ [all_tokens // self.num_local_experts]
|
546 |
+
+ input_shape[2:]
|
547 |
+
)
|
548 |
+
return result
|
549 |
+
|
550 |
+
def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
|
551 |
+
# assume all the inputs are routed to the first expert
|
552 |
+
input_shape = self.local_experts.modules[0].mock_forward(input_shape)
|
553 |
+
return input_shape
|
554 |
+
|
555 |
+
|
556 |
+
class TEGroupedMLP(MemEstimator):
|
557 |
+
"""An efficient implementation of the Experts layer using TE's GroupedLinear.
|
558 |
+
|
559 |
+
Executes multiple experts in parallel to maximize computational efficiency.
|
560 |
+
"""
|
561 |
+
|
562 |
+
def __init__(self, num_local_experts, config: TransformerConfig, submodules):
|
563 |
+
super().__init__()
|
564 |
+
self.config = config
|
565 |
+
self.moe_extended_tp = config.moe_extended_tp
|
566 |
+
self.num_local_experts = num_local_experts
|
567 |
+
self.input_size = self.config.hidden_size
|
568 |
+
|
569 |
+
# Double the output width with gated linear unit, see https://arxiv.org/pdf/2002.05202.pdf
|
570 |
+
ffn_hidden_size = self.config.moe_ffn_hidden_size
|
571 |
+
if self.config.gated_linear_unit:
|
572 |
+
ffn_hidden_size *= 2
|
573 |
+
|
574 |
+
self.linear_fc1 = build_module(
|
575 |
+
submodules.linear_fc1,
|
576 |
+
self.num_local_experts,
|
577 |
+
self.input_size,
|
578 |
+
ffn_hidden_size,
|
579 |
+
config=self.config,
|
580 |
+
init_method=self.config.init_method,
|
581 |
+
bias=self.config.add_bias_linear,
|
582 |
+
skip_bias_add=True,
|
583 |
+
is_expert=True,
|
584 |
+
tp_comm_buffer_name="fc1",
|
585 |
+
)
|
586 |
+
|
587 |
+
self.activation_func = self.config.activation_func
|
588 |
+
|
589 |
+
self.linear_fc2 = build_module(
|
590 |
+
submodules.linear_fc2,
|
591 |
+
self.num_local_experts,
|
592 |
+
self.config.moe_ffn_hidden_size,
|
593 |
+
self.config.hidden_size,
|
594 |
+
config=self.config,
|
595 |
+
init_method=self.config.output_layer_init_method,
|
596 |
+
bias=self.config.add_bias_linear,
|
597 |
+
skip_bias_add=True,
|
598 |
+
is_expert=True,
|
599 |
+
tp_comm_buffer_name="fc2",
|
600 |
+
)
|
601 |
+
# TODO if self.config.fp8:
|
602 |
+
|
603 |
+
def num_parameter(self):
|
604 |
+
ret = self.linear_fc1.num_parameter()
|
605 |
+
ret += self.linear_fc2.num_parameter()
|
606 |
+
return ret
|
607 |
+
|
608 |
+
def num_activation(self, input_shape: list[int], tokens_per_expert=None):
|
609 |
+
ret = 0
|
610 |
+
ret += self.linear_fc1.num_activation(input_shape)
|
611 |
+
input_shape = self.linear_fc1.mock_forward(input_shape)
|
612 |
+
|
613 |
+
# activation
|
614 |
+
ret += cum_mul(input_shape) / 2 # swiglu or gelu
|
615 |
+
input_shape = deepcopy(input_shape)
|
616 |
+
input_shape[-1] //= 2
|
617 |
+
|
618 |
+
self.linear_fc2.num_activation(input_shape)
|
619 |
+
return ret
|
620 |
+
|
621 |
+
def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
|
622 |
+
# assume all the inputs are routed to the first expert
|
623 |
+
input_shape = self.local_experts.modules[0].mock_forward(input_shape)
|
624 |
+
return input_shape
|
625 |
+
|
626 |
+
|
627 |
+
class TEGroupedLinear(MemEstimator):
|
628 |
+
def __init__(
|
629 |
+
self,
|
630 |
+
num_gemms: int,
|
631 |
+
input_size: int,
|
632 |
+
output_size: int,
|
633 |
+
*,
|
634 |
+
parallel_mode: str,
|
635 |
+
config: ModelParallelConfig,
|
636 |
+
init_method,
|
637 |
+
bias: bool,
|
638 |
+
skip_bias_add: bool,
|
639 |
+
is_expert: bool = False,
|
640 |
+
tp_comm_buffer_name: str = None,
|
641 |
+
):
|
642 |
+
super().__init__()
|
643 |
+
self.config = config
|
644 |
+
|
645 |
+
# TE returns a zero length Tensor when bias=False and
|
646 |
+
# return_bias=True, but we prefer None. So in that case we
|
647 |
+
# tell TE to not return the bias, and return None
|
648 |
+
# ourselves. This way our forward always returns two values
|
649 |
+
# and we don't have to deal with the zero length Tensor.
|
650 |
+
self.te_return_bias = skip_bias_add and bias
|
651 |
+
self.is_first_microbatch = True
|
652 |
+
self.disable_parameter_transpose_cache = (
|
653 |
+
self.config.disable_parameter_transpose_cache
|
654 |
+
)
|
655 |
+
|
656 |
+
extra_kwargs = _get_extra_te_kwargs(config)
|
657 |
+
extra_kwargs["ub_name"] = tp_comm_buffer_name
|
658 |
+
|
659 |
+
self.expert_parallel = self.config.expert_model_parallel_size > 1
|
660 |
+
if self.expert_parallel:
|
661 |
+
extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name()
|
662 |
+
|
663 |
+
# For MoE models, the comms between TP and EP group is explicitly handled by
|
664 |
+
# MoE token dispatcher. So we disable comms by making TE agnostic of model parallel.
|
665 |
+
self.explicit_expert_comm = is_expert and (
|
666 |
+
config.tensor_model_parallel_size > 1 or self.expert_parallel
|
667 |
+
)
|
668 |
+
if is_expert:
|
669 |
+
tp_size = get_expert_tensor_parallel_world_size()
|
670 |
+
else:
|
671 |
+
tp_size = get_tensor_model_parallel_world_size()
|
672 |
+
if self.explicit_expert_comm:
|
673 |
+
if parallel_mode == "column":
|
674 |
+
output_size = divide(output_size, tp_size)
|
675 |
+
elif parallel_mode == "row":
|
676 |
+
input_size = divide(input_size, tp_size)
|
677 |
+
parallel_mode = None
|
678 |
+
tp_size = 1
|
679 |
+
assert not bias, "bias is not considered for now"
|
680 |
+
|
681 |
+
self.num_gemms = num_gemms
|
682 |
+
self.input_size = input_size
|
683 |
+
self.output_size = output_size
|
684 |
+
|
685 |
+
def num_parameter(self):
|
686 |
+
ret = self.num_gemms * self.input_size * self.output_size
|
687 |
+
return ret
|
688 |
+
|
689 |
+
def num_activation(self, input_shape: list[int], tokens_per_expert=None):
|
690 |
+
ret = cum_mul(self.mock_forward(input_shape))
|
691 |
+
return ret
|
692 |
+
|
693 |
+
def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
|
694 |
+
return input_shape[:-1] + [self.output_size]
|
695 |
+
|
696 |
+
|
697 |
+
class TEColumnParallelGroupedLinear(TEGroupedLinear):
|
698 |
+
def __init__(
|
699 |
+
self,
|
700 |
+
num_gemms: int,
|
701 |
+
input_size: int,
|
702 |
+
output_size: int,
|
703 |
+
*,
|
704 |
+
config: ModelParallelConfig,
|
705 |
+
init_method,
|
706 |
+
bias: bool,
|
707 |
+
skip_bias_add: bool,
|
708 |
+
is_expert: bool,
|
709 |
+
tp_comm_buffer_name: str = None,
|
710 |
+
):
|
711 |
+
super().__init__(
|
712 |
+
num_gemms=num_gemms,
|
713 |
+
input_size=input_size,
|
714 |
+
output_size=output_size,
|
715 |
+
parallel_mode="column",
|
716 |
+
config=config,
|
717 |
+
init_method=condition_init_method(config, init_method),
|
718 |
+
bias=bias,
|
719 |
+
skip_bias_add=skip_bias_add,
|
720 |
+
is_expert=is_expert,
|
721 |
+
tp_comm_buffer_name=tp_comm_buffer_name,
|
722 |
+
)
|
723 |
+
|
724 |
+
|
725 |
+
class TERowParallelGroupedLinear(TEGroupedLinear):
|
726 |
+
def __init__(
|
727 |
+
self,
|
728 |
+
num_gemms: int,
|
729 |
+
input_size: int,
|
730 |
+
output_size: int,
|
731 |
+
*,
|
732 |
+
config: ModelParallelConfig,
|
733 |
+
init_method,
|
734 |
+
bias: bool,
|
735 |
+
skip_bias_add: bool,
|
736 |
+
is_expert: bool,
|
737 |
+
tp_comm_buffer_name: str = None,
|
738 |
+
):
|
739 |
+
|
740 |
+
super().__init__(
|
741 |
+
num_gemms=num_gemms,
|
742 |
+
input_size=input_size,
|
743 |
+
output_size=output_size,
|
744 |
+
parallel_mode="row",
|
745 |
+
config=config,
|
746 |
+
init_method=condition_init_method(config, init_method),
|
747 |
+
bias=bias,
|
748 |
+
skip_bias_add=skip_bias_add,
|
749 |
+
is_expert=is_expert,
|
750 |
+
tp_comm_buffer_name=tp_comm_buffer_name,
|
751 |
+
)
|
752 |
+
|
753 |
+
|
754 |
+
class SharedExpertMLP(MLP):
|
755 |
+
"""
|
756 |
+
MLP layer for Shared Experts.
|
757 |
+
"""
|
758 |
+
|
759 |
+
def __init__(self, config: TransformerConfig, spec: ModuleSpec):
|
760 |
+
config = deepcopy(config)
|
761 |
+
assert (
|
762 |
+
config.add_bias_linear == False
|
763 |
+
), "bias is not supported in the shared experts, "
|
764 |
+
"please set '--disable-bias-linear' instead."
|
765 |
+
|
766 |
+
config.ffn_hidden_size = config.moe_shared_expert_intermediate_size
|
767 |
+
super().__init__(config=config, submodules=spec.submodules)
|
768 |
+
|
769 |
+
self.use_shared_expert_gate = spec.params.get("gate", False)
|
770 |
+
if self.use_shared_expert_gate:
|
771 |
+
assert False, "use_shared_expert_gate is not Implemented"
|
772 |
+
# self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size)))
|
773 |
+
# if config.perform_initialization:
|
774 |
+
# if get_cuda_rng_tracker().is_initialized():
|
775 |
+
# with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
|
776 |
+
# config.init_method(self.gate_weight)
|
777 |
+
# else:
|
778 |
+
# config.init_method(self.gate_weight)
|
779 |
+
# self.gate_weight.data = self.gate_weight.data.to(dtype=config.params_dtype)
|
780 |
+
# setattr(self.gate_weight, 'sequence_parallel', self.config.sequence_parallel)
|
781 |
+
else:
|
782 |
+
self.gate_weight = None
|
783 |
+
|
784 |
+
|
785 |
+
class TransformerBlock(MemEstimator):
|
786 |
+
"""Transformer class."""
|
787 |
+
|
788 |
+
def __init__(
|
789 |
+
self,
|
790 |
+
config: TransformerConfig,
|
791 |
+
spec: Union[TransformerBlockSubmodules, ModuleSpec],
|
792 |
+
post_layer_norm: bool = True,
|
793 |
+
pre_process: bool = True,
|
794 |
+
post_process: bool = True,
|
795 |
+
):
|
796 |
+
super().__init__()
|
797 |
+
self.config = config
|
798 |
+
|
799 |
+
self.submodules = _get_block_submodules(config, spec)
|
800 |
+
self.post_layer_norm = post_layer_norm
|
801 |
+
self.pre_process = pre_process
|
802 |
+
self.post_process = post_process
|
803 |
+
self.cuda_graphs = {}
|
804 |
+
self.current_microbatch = -1
|
805 |
+
self.input_tensor = None
|
806 |
+
self.checkpoint_core_attention = (
|
807 |
+
self.config.recompute_granularity == "selective"
|
808 |
+
)
|
809 |
+
|
810 |
+
self._build_layers()
|
811 |
+
self.num_layers_per_pipeline_rank = len(self.layers)
|
812 |
+
self.tp_only_amax_red = config.tp_only_amax_red
|
813 |
+
|
814 |
+
def _build_layers(self):
|
815 |
+
def build_layer(layer_spec, layer_number):
|
816 |
+
return build_module(
|
817 |
+
layer_spec, config=self.config, layer_number=layer_number
|
818 |
+
)
|
819 |
+
|
820 |
+
# offset is implicit in TransformerLayer
|
821 |
+
self.layers = ModuleList(
|
822 |
+
[
|
823 |
+
build_layer(layer_spec, i + 1)
|
824 |
+
for i, layer_spec in enumerate(self.submodules.layer_specs)
|
825 |
+
]
|
826 |
+
)
|
827 |
+
|
828 |
+
if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
|
829 |
+
self.final_layernorm = build_module(
|
830 |
+
self.submodules.layer_norm,
|
831 |
+
config=self.config,
|
832 |
+
hidden_size=self.config.hidden_size,
|
833 |
+
eps=self.config.layernorm_epsilon,
|
834 |
+
)
|
835 |
+
else:
|
836 |
+
self.final_layernorm = None # Either this or nn.Identity
|
837 |
+
|
838 |
+
def num_parameter(self):
|
839 |
+
ret = self.layers.num_parameter()
|
840 |
+
if self.final_layernorm is not None:
|
841 |
+
ret += self.final_layernorm.num_parameter()
|
842 |
+
|
843 |
+
return ret
|
844 |
+
|
845 |
+
def num_activation(self, input_shape: list[int]):
|
846 |
+
result = self.layers.num_activation(input_shape)
|
847 |
+
if self.final_layernorm is not None:
|
848 |
+
result += self.final_layernorm.num_activation(input_shape)
|
849 |
+
return result
|
850 |
+
|
851 |
+
def mock_forward(self, input_shape: list[int]):
|
852 |
+
return input_shape
|
853 |
+
|
854 |
+
|
855 |
+
class TopKRouter(MemEstimator):
|
856 |
+
|
857 |
+
def __init__(self, config: TransformerConfig) -> None:
|
858 |
+
super().__init__()
|
859 |
+
self.config = config
|
860 |
+
self.topk = self.config.moe_router_topk
|
861 |
+
self.routing_type = self.config.moe_router_load_balancing_type
|
862 |
+
self.input_jitter = None
|
863 |
+
|
864 |
+
def num_parameter(self):
|
865 |
+
return 0
|
866 |
+
|
867 |
+
def num_activation(self, input_shape: list[int]):
|
868 |
+
result = cum_mul(input_shape) * 2 # sinkhorn and sinkhorn activation
|
869 |
+
return result
|
870 |
+
|
871 |
+
def mock_forward(self, input_shape: list[int]):
|
872 |
+
return input_shape[:-1] + [self.topk]
|
873 |
+
|
874 |
+
|
875 |
+
class MoELayer(MemEstimator):
|
876 |
+
|
877 |
+
def __init__(
|
878 |
+
self, config: TransformerConfig, submodules=None, layer_number: int = None
|
879 |
+
):
|
880 |
+
super().__init__()
|
881 |
+
self.config = config
|
882 |
+
self.submodules = submodules
|
883 |
+
self.moe_layer_recompute = config.moe_layer_recompute
|
884 |
+
|
885 |
+
self.expert_parallel_size = get_expert_model_parallel_world_size()
|
886 |
+
assert (
|
887 |
+
self.expert_parallel_size > 0
|
888 |
+
), "Expected non-negative expert parallel size"
|
889 |
+
|
890 |
+
assert self.config.num_moe_experts % self.expert_parallel_size == 0
|
891 |
+
self.num_local_experts = (
|
892 |
+
self.config.num_moe_experts // self.expert_parallel_size
|
893 |
+
)
|
894 |
+
local_expert_indices_offset = (
|
895 |
+
get_expert_model_parallel_rank() * self.num_local_experts
|
896 |
+
)
|
897 |
+
|
898 |
+
self.router = TopKRouter(config=self.config)
|
899 |
+
self.use_shared_expert = (
|
900 |
+
self.config.moe_shared_expert_intermediate_size is not None
|
901 |
+
)
|
902 |
+
self.shared_expert_overlap = self.config.moe_shared_expert_overlap
|
903 |
+
|
904 |
+
self.local_expert_indices = [
|
905 |
+
local_expert_indices_offset + i for i in range(self.num_local_experts)
|
906 |
+
]
|
907 |
+
assert all(
|
908 |
+
map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices)
|
909 |
+
)
|
910 |
+
|
911 |
+
self.experts = None
|
912 |
+
self.shared_experts = None
|
913 |
+
self.token_dispatcher = None
|
914 |
+
self.layer_number = layer_number
|
915 |
+
# Initialize experts
|
916 |
+
self.experts = build_module(
|
917 |
+
self.submodules.experts, self.num_local_experts, self.config
|
918 |
+
)
|
919 |
+
|
920 |
+
# Initialize shared experts
|
921 |
+
if self.use_shared_expert:
|
922 |
+
self.shared_experts = SharedExpertMLP(
|
923 |
+
self.config, self.submodules.shared_experts
|
924 |
+
)
|
925 |
+
# if self.shared_expert_overlap:
|
926 |
+
# self.token_dispatcher.set_shared_experts(self.shared_experts)
|
927 |
+
|
928 |
+
def num_parameter(self):
|
929 |
+
ret = self.experts.num_parameter() + self.router.num_parameter()
|
930 |
+
if self.use_shared_expert:
|
931 |
+
ret += self.shared_experts.num_parameter()
|
932 |
+
return ret
|
933 |
+
|
934 |
+
def num_activation(self, input_shape: list[int]):
|
935 |
+
result = self.router.num_activation(input_shape)
|
936 |
+
result += cum_mul(input_shape) * self.router.topk # token dispatcher
|
937 |
+
moe_input_shape_average = deepcopy(input_shape)
|
938 |
+
moe_input_shape_average[1] = int(moe_input_shape_average[1] * self.router.topk)
|
939 |
+
|
940 |
+
result += self.experts.num_activation(moe_input_shape_average)
|
941 |
+
if self.use_shared_expert:
|
942 |
+
result += self.shared_experts.num_activation(input_shape)
|
943 |
+
|
944 |
+
if self.config.moe_layer_recompute:
|
945 |
+
result = cum_mul(input_shape) * 2
|
946 |
+
return result
|
947 |
+
|
948 |
+
def mock_forward(self, input_shape: list[int]):
|
949 |
+
return input_shape
|
950 |
+
|
951 |
+
|
952 |
+
class IdentityOp(MemEstimator):
|
953 |
+
def num_parameter(self):
|
954 |
+
return 0
|
955 |
+
|
956 |
+
def num_activation(self, input_shape: list[int]):
|
957 |
+
return 0
|
958 |
+
|
959 |
+
def mock_forward(self, input_shape: list[int]):
|
960 |
+
return input_shape
|
961 |
+
|
962 |
+
|
963 |
+
IdentityFuncOp = IdentityOp
|
964 |
+
TERowParallelLinear = RowParallelLinear
|
965 |
+
TEColumnParallelLinear = ColumnParallelLinear
|
966 |
+
TELayerNormColumnParallelLinear = ColumnParallelLinear
|
967 |
+
|
968 |
+
|
969 |
+
class TEDotProductAttention(MemEstimator):
|
970 |
+
def __init__(self, config: TransformerConfig, *args, **kwargs):
|
971 |
+
super().__init__()
|
972 |
+
self.config = config
|
973 |
+
|
974 |
+
def num_parameter(self):
|
975 |
+
return 0
|
976 |
+
|
977 |
+
def num_activation(
|
978 |
+
self, q_shape: list[int], k_shape: list[int], v_shape: list[int]
|
979 |
+
):
|
980 |
+
bs, seqs, heads, dim = q_shape
|
981 |
+
if self.config.multi_latent_attention and False:
|
982 |
+
result = bs * seqs * seqs * heads
|
983 |
+
else:
|
984 |
+
bs, seqs, heads, dim = k_shape
|
985 |
+
result = (
|
986 |
+
bs * seqs * dim * heads * 2 # * self.config.tensor_model_parallel_size
|
987 |
+
) # flash attention
|
988 |
+
if self.config.context_parallel_size > 1:
|
989 |
+
result *= 2
|
990 |
+
return result
|
991 |
+
|
992 |
+
def mock_forward(
|
993 |
+
self,
|
994 |
+
hidden_size: int,
|
995 |
+
q_shape: list[int],
|
996 |
+
k_shape: list[int],
|
997 |
+
v_shape: list[int],
|
998 |
+
):
|
999 |
+
seqs, bs, heads, dim = q_shape
|
1000 |
+
return [seqs, bs, hidden_size]
|
1001 |
+
|
1002 |
+
|
1003 |
+
class TransformerLayer(MemEstimator):
|
1004 |
+
def __init__(
|
1005 |
+
self,
|
1006 |
+
config: TransformerConfig,
|
1007 |
+
submodules,
|
1008 |
+
layer_number: int = 1,
|
1009 |
+
hidden_dropout: float = None,
|
1010 |
+
):
|
1011 |
+
super().__init__()
|
1012 |
+
self.config = config
|
1013 |
+
|
1014 |
+
if config.enable_cuda_graph and self.training:
|
1015 |
+
assert (
|
1016 |
+
not config.cpu_offloading and config.recompute_granularity is None
|
1017 |
+
), "Cudagraphs not supported"
|
1018 |
+
self.cudagraph_manager = CudaGraphManager()
|
1019 |
+
|
1020 |
+
self.submodules_config = submodules
|
1021 |
+
self.layer_number = layer_number + get_transformer_layer_offset(self.config)
|
1022 |
+
self.hidden_dropout = (
|
1023 |
+
config.hidden_dropout if hidden_dropout is None else hidden_dropout
|
1024 |
+
)
|
1025 |
+
|
1026 |
+
# [Module 1: Input Layernorm] Optional Layernorm on the input data
|
1027 |
+
# TODO: add pytorch only layernorm
|
1028 |
+
self.input_layernorm = build_module(
|
1029 |
+
submodules.input_layernorm,
|
1030 |
+
config=self.config,
|
1031 |
+
hidden_size=self.config.hidden_size,
|
1032 |
+
eps=self.config.layernorm_epsilon,
|
1033 |
+
)
|
1034 |
+
|
1035 |
+
# [Module 2: SelfAttention]
|
1036 |
+
self.self_attention = build_module(
|
1037 |
+
submodules.self_attention, config=self.config, layer_number=layer_number
|
1038 |
+
)
|
1039 |
+
|
1040 |
+
# [Module 3: BiasDropoutFusion]
|
1041 |
+
self.self_attn_bda = build_module(submodules.self_attn_bda)
|
1042 |
+
|
1043 |
+
# [Module 4: Post SelfAttention] Optional Layernorm after self-attn
|
1044 |
+
self.pre_cross_attn_layernorm = build_module(
|
1045 |
+
submodules.pre_cross_attn_layernorm,
|
1046 |
+
config=self.config,
|
1047 |
+
hidden_size=self.config.hidden_size,
|
1048 |
+
eps=self.config.layernorm_epsilon,
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
# [Module 5: CrossAttention]
|
1052 |
+
self.cross_attention = build_module(
|
1053 |
+
submodules.cross_attention, config=self.config, layer_number=layer_number
|
1054 |
+
)
|
1055 |
+
|
1056 |
+
# [Module 6: BiasDropoutFusion]
|
1057 |
+
self.cross_attn_bda = build_module(
|
1058 |
+
submodules.cross_attn_bda, config=self.config
|
1059 |
+
)
|
1060 |
+
|
1061 |
+
# [Module 7: Pre MLP] Optional Layernorm before MLP
|
1062 |
+
self.pre_mlp_layernorm = build_module(
|
1063 |
+
submodules.pre_mlp_layernorm,
|
1064 |
+
config=self.config,
|
1065 |
+
hidden_size=self.config.hidden_size,
|
1066 |
+
eps=self.config.layernorm_epsilon,
|
1067 |
+
)
|
1068 |
+
|
1069 |
+
# [Module 8: MLP block]
|
1070 |
+
self.mlp = build_module(submodules.mlp, config=self.config)
|
1071 |
+
if hasattr(self.mlp, "set_layer_number"):
|
1072 |
+
self.mlp.set_layer_number(self.layer_number)
|
1073 |
+
|
1074 |
+
# [Module 9: BiasDropoutFusion]
|
1075 |
+
self.mlp_bda = build_module(submodules.mlp_bda)
|
1076 |
+
|
1077 |
+
def num_parameter(self):
|
1078 |
+
result = self.input_layernorm.num_parameter()
|
1079 |
+
result += self.self_attention.num_parameter()
|
1080 |
+
result += self.pre_cross_attn_layernorm.num_parameter()
|
1081 |
+
result += self.cross_attention.num_parameter()
|
1082 |
+
result += self.cross_attn_bda.num_parameter()
|
1083 |
+
result += self.pre_mlp_layernorm.num_parameter()
|
1084 |
+
result += self.mlp.num_parameter()
|
1085 |
+
|
1086 |
+
return result
|
1087 |
+
|
1088 |
+
def num_activation(self, input_shape: list[int]):
|
1089 |
+
result = 0
|
1090 |
+
result += self.self_attention.num_activation(input_shape)
|
1091 |
+
result += self.mlp.num_activation(input_shape)
|
1092 |
+
# __import__('ipdb').set_trace()
|
1093 |
+
# sequence parallel
|
1094 |
+
if self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1:
|
1095 |
+
input_shape = deepcopy(input_shape)
|
1096 |
+
input_shape[1] /= self.config.tensor_model_parallel_size
|
1097 |
+
result += self.input_layernorm.num_activation(input_shape)
|
1098 |
+
result += self.pre_mlp_layernorm.num_activation(input_shape)
|
1099 |
+
result += self.self_attn_bda.num_activation(input_shape)
|
1100 |
+
result += self.mlp_bda.num_activation(input_shape)
|
1101 |
+
return result
|
1102 |
+
|
1103 |
+
def mock_forward(self, input_shape: list[int]):
|
1104 |
+
return input_shape
|
1105 |
+
|
1106 |
+
|
1107 |
+
class SelfAttention(MemEstimator):
|
1108 |
+
|
1109 |
+
def __init__(
|
1110 |
+
self,
|
1111 |
+
config: TransformerConfig,
|
1112 |
+
submodules,
|
1113 |
+
layer_number: int,
|
1114 |
+
attn_mask_type,
|
1115 |
+
):
|
1116 |
+
super().__init__()
|
1117 |
+
|
1118 |
+
self.config = config
|
1119 |
+
self.layer_number = layer_number
|
1120 |
+
self.attn_mask_type = attn_mask_type
|
1121 |
+
self.attention_type = ""
|
1122 |
+
|
1123 |
+
# For normal attention without groups, num_query_groups == num_attention_heads,
|
1124 |
+
# so these two will be the same
|
1125 |
+
self.query_projection_size = (
|
1126 |
+
self.config.kv_channels * self.config.num_attention_heads
|
1127 |
+
)
|
1128 |
+
self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
|
1129 |
+
|
1130 |
+
# Per attention head and per partition values.
|
1131 |
+
world_size = get_tensor_model_parallel_world_size()
|
1132 |
+
self.hidden_size_per_attention_head = divide(
|
1133 |
+
self.query_projection_size, self.config.num_attention_heads
|
1134 |
+
)
|
1135 |
+
self.num_attention_heads_per_partition = divide(
|
1136 |
+
self.config.num_attention_heads, world_size
|
1137 |
+
)
|
1138 |
+
self.num_query_groups_per_partition = divide(
|
1139 |
+
self.config.num_query_groups, world_size
|
1140 |
+
)
|
1141 |
+
self.core_attention = build_module(
|
1142 |
+
submodules.core_attention,
|
1143 |
+
config=self.config,
|
1144 |
+
layer_number=self.layer_number,
|
1145 |
+
attn_mask_type=self.attn_mask_type,
|
1146 |
+
)
|
1147 |
+
self.linear_qkv = build_module(
|
1148 |
+
submodules.linear_qkv,
|
1149 |
+
self.config.hidden_size,
|
1150 |
+
self.query_projection_size + 2 * self.kv_projection_size,
|
1151 |
+
config=self.config,
|
1152 |
+
init_method=self.config.init_method,
|
1153 |
+
gather_output=False,
|
1154 |
+
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
|
1155 |
+
skip_bias_add=False,
|
1156 |
+
is_expert=False,
|
1157 |
+
tp_comm_buffer_name="qkv",
|
1158 |
+
)
|
1159 |
+
|
1160 |
+
if submodules.q_layernorm is not None:
|
1161 |
+
self.q_layernorm = build_module(
|
1162 |
+
submodules.q_layernorm,
|
1163 |
+
hidden_size=self.hidden_size_per_attention_head,
|
1164 |
+
config=self.config,
|
1165 |
+
eps=self.config.layernorm_epsilon,
|
1166 |
+
)
|
1167 |
+
else:
|
1168 |
+
self.q_layernorm = None
|
1169 |
+
|
1170 |
+
if submodules.k_layernorm is not None:
|
1171 |
+
self.k_layernorm = build_module(
|
1172 |
+
submodules.k_layernorm,
|
1173 |
+
hidden_size=self.hidden_size_per_attention_head,
|
1174 |
+
config=self.config,
|
1175 |
+
eps=self.config.layernorm_epsilon,
|
1176 |
+
)
|
1177 |
+
else:
|
1178 |
+
self.k_layernorm = None
|
1179 |
+
self.linear_proj = build_module(
|
1180 |
+
submodules.linear_proj,
|
1181 |
+
self.query_projection_size,
|
1182 |
+
self.config.hidden_size,
|
1183 |
+
config=self.config,
|
1184 |
+
init_method=self.config.output_layer_init_method,
|
1185 |
+
bias=self.config.add_bias_linear,
|
1186 |
+
input_is_parallel=True,
|
1187 |
+
skip_bias_add=True,
|
1188 |
+
is_expert=False,
|
1189 |
+
tp_comm_buffer_name="proj",
|
1190 |
+
)
|
1191 |
+
self.checkpoint_core_attention = (
|
1192 |
+
self.config.recompute_granularity == "selective"
|
1193 |
+
)
|
1194 |
+
|
1195 |
+
def num_parameter(self):
|
1196 |
+
result = 0
|
1197 |
+
result += self.core_attention.num_parameter()
|
1198 |
+
result += self.linear_proj.num_parameter()
|
1199 |
+
result += self.linear_qkv.num_parameter()
|
1200 |
+
if self.q_layernorm is not None:
|
1201 |
+
result += self.q_layernorm.num_parameter()
|
1202 |
+
if self.k_layernorm is not None:
|
1203 |
+
result += self.k_layernorm.num_parameter()
|
1204 |
+
|
1205 |
+
return result
|
1206 |
+
|
1207 |
+
def num_activation(self, input_shape: list[int]):
|
1208 |
+
ret = 0
|
1209 |
+
## in estimator: act(linear) = 1.5*cum_mul(input_shape)
|
1210 |
+
## in reality: act(linear) = cum_mul(input_shape), act(rotary) = cum_mul(input_shape), act(attn_forward_func_with_cp) = cum_mul(input_shape)
|
1211 |
+
# ret += self.linear_qkv.num_activation(input_shape)
|
1212 |
+
mixed_qkv_shape = self.linear_qkv.mock_forward(input_shape)
|
1213 |
+
new_tensor_shape = mixed_qkv_shape[:-1] + [
|
1214 |
+
self.num_query_groups_per_partition,
|
1215 |
+
(
|
1216 |
+
(
|
1217 |
+
self.num_attention_heads_per_partition
|
1218 |
+
// self.num_query_groups_per_partition
|
1219 |
+
+ 2
|
1220 |
+
)
|
1221 |
+
* self.hidden_size_per_attention_head
|
1222 |
+
),
|
1223 |
+
]
|
1224 |
+
split_arg_list = [
|
1225 |
+
(
|
1226 |
+
self.num_attention_heads_per_partition
|
1227 |
+
// self.num_query_groups_per_partition
|
1228 |
+
* self.hidden_size_per_attention_head
|
1229 |
+
),
|
1230 |
+
self.hidden_size_per_attention_head,
|
1231 |
+
self.hidden_size_per_attention_head,
|
1232 |
+
]
|
1233 |
+
# [sq, b, ng, (np/ng + 2) * hn]
|
1234 |
+
# --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
|
1235 |
+
q_shape = new_tensor_shape[:-1] + [split_arg_list[0]]
|
1236 |
+
k_shape = new_tensor_shape[:-1] + [split_arg_list[1]]
|
1237 |
+
v_shape = new_tensor_shape[:-1] + [split_arg_list[2]]
|
1238 |
+
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
|
1239 |
+
q_shape = (
|
1240 |
+
q_shape[:2]
|
1241 |
+
+ [cum_mul(q_shape[-2:]) // self.hidden_size_per_attention_head]
|
1242 |
+
+ [self.hidden_size_per_attention_head]
|
1243 |
+
)
|
1244 |
+
|
1245 |
+
if not self.checkpoint_core_attention:
|
1246 |
+
ret += self.core_attention.num_activation(q_shape, k_shape, v_shape)
|
1247 |
+
ret += self.linear_proj.num_activation(input_shape)
|
1248 |
+
## in reality: act(linear) = cum_mul(input_shape), act(rotary) = cum_mul(input_shape), act(attn_forward_func_with_cp) = cum_mul(input_shape)
|
1249 |
+
ret += self.linear_proj.num_activation(input_shape) * 3
|
1250 |
+
|
1251 |
+
return ret
|
1252 |
+
|
1253 |
+
def mock_forward(self, input_shape: list[int]):
|
1254 |
+
return input_shape
|
1255 |
+
|
1256 |
+
|
1257 |
+
class Linear(MemEstimator):
|
1258 |
+
def __init__(
|
1259 |
+
self,
|
1260 |
+
in_features: int,
|
1261 |
+
out_features: int,
|
1262 |
+
bias: bool = True,
|
1263 |
+
device=None,
|
1264 |
+
dtype=None,
|
1265 |
+
) -> None:
|
1266 |
+
|
1267 |
+
super().__init__()
|
1268 |
+
self.weight = (in_features, out_features)
|
1269 |
+
|
1270 |
+
def num_parameter(self):
|
1271 |
+
return self.weight[0] * self.weight[1]
|
1272 |
+
|
1273 |
+
def num_activation(self, input_shape: list[int]):
|
1274 |
+
return cum_mul(input_shape[:-1]) * self.weight[1]
|
1275 |
+
|
1276 |
+
def mock_forward(self, input_shape: list[int]):
|
1277 |
+
return input_shape[:-1] + [self.weight[1]]
|
1278 |
+
|
1279 |
+
|
1280 |
+
class MLASelfAttention(MemEstimator):
|
1281 |
+
"""MLA Self-attention layer class
|
1282 |
+
|
1283 |
+
Self-attention layer takes input with size [s, b, h]
|
1284 |
+
and returns output of the same size.
|
1285 |
+
"""
|
1286 |
+
|
1287 |
+
def __init__(
|
1288 |
+
self,
|
1289 |
+
config: MLATransformerConfig,
|
1290 |
+
submodules,
|
1291 |
+
layer_number: int,
|
1292 |
+
attn_mask_type=AttnMaskType.padding,
|
1293 |
+
) -> None:
|
1294 |
+
|
1295 |
+
super().__init__()
|
1296 |
+
self.config = config
|
1297 |
+
self.layer_number = layer_number
|
1298 |
+
self.attn_mask_type = attn_mask_type
|
1299 |
+
self.attention_type = "self"
|
1300 |
+
self.world_size = get_tensor_model_parallel_world_size()
|
1301 |
+
# assert (
|
1302 |
+
# world_size == 1
|
1303 |
+
# ), "MLA is not supported with Tensor Parallelism yet, \
|
1304 |
+
# use Expert Parallelism and Pipeline Parallelism for better performance."
|
1305 |
+
|
1306 |
+
self.query_projection_size = (
|
1307 |
+
self.config.v_head_dim * self.config.num_attention_heads
|
1308 |
+
)
|
1309 |
+
|
1310 |
+
self.q_head_dim = self.config.qk_head_dim + self.config.qk_pos_emb_head_dim
|
1311 |
+
|
1312 |
+
mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale)
|
1313 |
+
self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim)
|
1314 |
+
|
1315 |
+
# Per attention head and per partition values.
|
1316 |
+
world_size = get_tensor_model_parallel_world_size()
|
1317 |
+
self.hidden_size_per_attention_head = divide(
|
1318 |
+
self.query_projection_size, self.config.num_attention_heads
|
1319 |
+
)
|
1320 |
+
self.num_attention_heads_per_partition = divide(
|
1321 |
+
self.config.num_attention_heads, world_size
|
1322 |
+
)
|
1323 |
+
self.num_query_groups_per_partition = divide(
|
1324 |
+
self.config.num_query_groups, world_size
|
1325 |
+
)
|
1326 |
+
# TODO Rotary Embedding
|
1327 |
+
# self.rotary_pos_emb = YarnRotaryEmbedding(
|
1328 |
+
# self.config.qk_pos_emb_head_dim,
|
1329 |
+
# rotary_base=self.config.rotary_base,
|
1330 |
+
# scaling_factor=self.config.rotary_scaling_factor,
|
1331 |
+
# original_max_position_embeddings=self.config.max_position_embeddings,
|
1332 |
+
# beta_fast=self.config.beta_fast,
|
1333 |
+
# beta_slow=self.config.beta_slow,
|
1334 |
+
# mscale=self.config.mscale,
|
1335 |
+
# mscale_all_dim=self.config.mscale_all_dim,
|
1336 |
+
# )
|
1337 |
+
|
1338 |
+
self.core_attention = build_module(
|
1339 |
+
submodules.core_attention,
|
1340 |
+
config=self.config,
|
1341 |
+
layer_number=self.layer_number,
|
1342 |
+
attn_mask_type=self.attn_mask_type,
|
1343 |
+
attention_type=self.attention_type,
|
1344 |
+
softmax_scale=self.softmax_scale,
|
1345 |
+
k_channels=self.q_head_dim,
|
1346 |
+
v_channels=self.config.v_head_dim,
|
1347 |
+
)
|
1348 |
+
|
1349 |
+
if self.config.q_lora_rank is None:
|
1350 |
+
# Not projectiing query
|
1351 |
+
self.linear_q_proj = build_module(
|
1352 |
+
submodules.linear_q_proj,
|
1353 |
+
self.config.hidden_size,
|
1354 |
+
self.config.num_attention_heads * self.q_head_dim,
|
1355 |
+
config=self.config,
|
1356 |
+
init_method=self.config.init_method,
|
1357 |
+
gather_output=False,
|
1358 |
+
bias=False,
|
1359 |
+
skip_bias_add=False,
|
1360 |
+
is_expert=False,
|
1361 |
+
is_mla=True,
|
1362 |
+
)
|
1363 |
+
|
1364 |
+
else:
|
1365 |
+
self.linear_q_down_proj = Linear(
|
1366 |
+
self.config.hidden_size, self.config.q_lora_rank, bias=False
|
1367 |
+
)
|
1368 |
+
|
1369 |
+
self.linear_q_up_proj = build_module(
|
1370 |
+
submodules.linear_q_up_proj,
|
1371 |
+
self.config.q_lora_rank,
|
1372 |
+
self.config.num_attention_heads * self.q_head_dim,
|
1373 |
+
config=self.config,
|
1374 |
+
init_method=self.config.init_method,
|
1375 |
+
gather_output=False,
|
1376 |
+
bias=False,
|
1377 |
+
skip_bias_add=False,
|
1378 |
+
is_expert=False,
|
1379 |
+
is_mla=True,
|
1380 |
+
)
|
1381 |
+
self.linear_kv_down_proj = Linear(
|
1382 |
+
self.config.hidden_size,
|
1383 |
+
self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim,
|
1384 |
+
bias=False,
|
1385 |
+
)
|
1386 |
+
|
1387 |
+
self.linear_kv_up_proj = build_module(
|
1388 |
+
submodules.linear_kv_up_proj,
|
1389 |
+
self.config.kv_lora_rank,
|
1390 |
+
self.config.num_attention_heads
|
1391 |
+
* (self.config.qk_head_dim + self.config.v_head_dim),
|
1392 |
+
config=self.config,
|
1393 |
+
init_method=self.config.init_method,
|
1394 |
+
gather_output=False,
|
1395 |
+
bias=False,
|
1396 |
+
skip_bias_add=False,
|
1397 |
+
is_expert=False,
|
1398 |
+
is_mla=True,
|
1399 |
+
)
|
1400 |
+
|
1401 |
+
if self.config.q_lora_rank is not None:
|
1402 |
+
self.q_layernorm = build_module(
|
1403 |
+
submodules.q_layernorm,
|
1404 |
+
hidden_size=self.config.q_lora_rank,
|
1405 |
+
config=self.config,
|
1406 |
+
eps=self.config.layernorm_epsilon,
|
1407 |
+
)
|
1408 |
+
|
1409 |
+
self.kv_layernorm = build_module(
|
1410 |
+
submodules.kv_layernorm,
|
1411 |
+
hidden_size=self.config.kv_lora_rank,
|
1412 |
+
config=self.config,
|
1413 |
+
eps=self.config.layernorm_epsilon,
|
1414 |
+
)
|
1415 |
+
|
1416 |
+
# Output.
|
1417 |
+
self.linear_proj = build_module(
|
1418 |
+
submodules.linear_proj,
|
1419 |
+
self.query_projection_size,
|
1420 |
+
self.config.hidden_size,
|
1421 |
+
config=self.config,
|
1422 |
+
init_method=self.config.output_layer_init_method,
|
1423 |
+
bias=self.config.add_bias_linear,
|
1424 |
+
input_is_parallel=True,
|
1425 |
+
skip_bias_add=True,
|
1426 |
+
is_expert=False,
|
1427 |
+
tp_comm_buffer_name="proj",
|
1428 |
+
)
|
1429 |
+
|
1430 |
+
self.checkpoint_core_attention = (
|
1431 |
+
self.config.recompute_granularity == "selective"
|
1432 |
+
)
|
1433 |
+
|
1434 |
+
def num_parameter(self):
|
1435 |
+
result = 0
|
1436 |
+
result += self.core_attention.num_parameter()
|
1437 |
+
result += self.linear_proj.num_parameter()
|
1438 |
+
if self.config.q_lora_rank is None:
|
1439 |
+
result += self.linear_q_proj.num_parameter()
|
1440 |
+
else:
|
1441 |
+
result += self.linear_q_down_proj.num_parameter()
|
1442 |
+
result += self.linear_q_up_proj.num_parameter()
|
1443 |
+
result += self.linear_kv_down_proj.num_parameter()
|
1444 |
+
result += self.linear_kv_up_proj.num_parameter()
|
1445 |
+
result += self.kv_layernorm.num_parameter()
|
1446 |
+
if self.config.q_lora_rank is not None:
|
1447 |
+
result += self.q_layernorm.num_parameter()
|
1448 |
+
|
1449 |
+
return result
|
1450 |
+
|
1451 |
+
def num_activation(self, input_shape: list[int]):
|
1452 |
+
q_len, bsz, _ = input_shape
|
1453 |
+
ret = 0
|
1454 |
+
if self.config.q_lora_rank is not None:
|
1455 |
+
ret += self.linear_q_down_proj.num_activation(input_shape)
|
1456 |
+
q_compressed_shape = self.linear_q_down_proj.mock_forward(input_shape)
|
1457 |
+
ret += self.q_layernorm.num_activation(q_compressed_shape)
|
1458 |
+
ret += self.linear_q_up_proj.num_activation(q_compressed_shape)
|
1459 |
+
q_shape = self.linear_q_up_proj.mock_forward(q_compressed_shape)
|
1460 |
+
else:
|
1461 |
+
# hidden_states:[s, b, 2048], q: [s, b, n * 192]
|
1462 |
+
ret += self.linear_q_proj.num_activation(input_shape)
|
1463 |
+
q_shape = self.linear_q_proj.mock_forward(input_shape)
|
1464 |
+
|
1465 |
+
# kv_combined: [s, b, 576]
|
1466 |
+
ret += self.linear_kv_down_proj.num_activation(input_shape)
|
1467 |
+
kv_combined_shape = self.linear_kv_down_proj.mock_forward(input_shape)
|
1468 |
+
# kv_compressed:[s, b, 512], k_pos_emb: [s, b, 64]
|
1469 |
+
kv_compressed_shape = kv_combined_shape[:-1] + [self.config.kv_lora_rank]
|
1470 |
+
|
1471 |
+
# kv: [s, b, 2048]
|
1472 |
+
ret += self.kv_layernorm.num_activation(kv_compressed_shape)
|
1473 |
+
ret += self.linear_kv_up_proj.num_activation(kv_compressed_shape)
|
1474 |
+
|
1475 |
+
q_shape = [q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim]
|
1476 |
+
k_shape = [q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim]
|
1477 |
+
v_shape = [
|
1478 |
+
q_len,
|
1479 |
+
bsz,
|
1480 |
+
self.num_attention_heads_per_partition,
|
1481 |
+
self.config.v_head_dim,
|
1482 |
+
]
|
1483 |
+
|
1484 |
+
if not self.checkpoint_core_attention:
|
1485 |
+
ret += self.core_attention.num_activation(q_shape, k_shape, v_shape)
|
1486 |
+
|
1487 |
+
ret += self.linear_proj.num_activation(input_shape)
|
1488 |
+
|
1489 |
+
return ret
|
1490 |
+
|
1491 |
+
def mock_forward(self, input_shape: list[int]):
|
1492 |
+
return input_shape
|
1493 |
+
|
1494 |
+
|
1495 |
+
class TENorm:
|
1496 |
+
def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
|
1497 |
+
from megatron.core.extensions.transformer_engine import _get_extra_te_kwargs, te
|
1498 |
+
|
1499 |
+
if config.normalization == "LayerNorm":
|
1500 |
+
# TODO layernorm
|
1501 |
+
pass
|
1502 |
+
elif config.normalization == "RMSNorm":
|
1503 |
+
assert hasattr(
|
1504 |
+
te.pytorch, "RMSNorm"
|
1505 |
+
), "Transformer-Engine >= v0.11 required to use this feature"
|
1506 |
+
instance = RMSNorm(
|
1507 |
+
hidden_size=hidden_size,
|
1508 |
+
eps=eps,
|
1509 |
+
sequence_parallel=config.sequence_parallel,
|
1510 |
+
zero_centered_gamma=config.layernorm_zero_centered_gamma,
|
1511 |
+
**_get_extra_te_kwargs(config),
|
1512 |
+
)
|
1513 |
+
else:
|
1514 |
+
raise Exception("Only LayerNorm and RMSNorm are curently supported")
|
1515 |
+
|
1516 |
+
return instance
|
1517 |
+
|
1518 |
+
|
1519 |
+
def build_module(
|
1520 |
+
spec_or_module: Union[ModuleSpec, type], *args, **kwargs
|
1521 |
+
) -> MemEstimator:
|
1522 |
+
"""replace module with MemEstimators"""
|
1523 |
+
if isinstance(spec_or_module, types.FunctionType):
|
1524 |
+
return globals()[spec_or_module.__name__]
|
1525 |
+
|
1526 |
+
if isinstance(spec_or_module, ModuleSpec) and isinstance(
|
1527 |
+
spec_or_module.module, types.FunctionType
|
1528 |
+
):
|
1529 |
+
assert False
|
1530 |
+
return spec_or_module.module
|
1531 |
+
|
1532 |
+
if isinstance(spec_or_module, type):
|
1533 |
+
module = spec_or_module
|
1534 |
+
elif hasattr(spec_or_module, "module") and isinstance(spec_or_module.module, type):
|
1535 |
+
module = spec_or_module.module
|
1536 |
+
else:
|
1537 |
+
module = import_module(spec_or_module.module)
|
1538 |
+
|
1539 |
+
if isinstance(module, types.FunctionType):
|
1540 |
+
assert False
|
1541 |
+
return module
|
1542 |
+
|
1543 |
+
if hasattr(spec_or_module, "submodules") and spec_or_module.submodules is not None:
|
1544 |
+
kwargs["submodules"] = spec_or_module.submodules
|
1545 |
+
|
1546 |
+
try:
|
1547 |
+
module = globals()[module.__name__]
|
1548 |
+
return module(
|
1549 |
+
*args,
|
1550 |
+
**spec_or_module.params if hasattr(spec_or_module, "params") else {},
|
1551 |
+
**kwargs,
|
1552 |
+
)
|
1553 |
+
except Exception as e:
|
1554 |
+
# import ipdb
|
1555 |
+
|
1556 |
+
# ipdb.set_trace()
|
1557 |
+
# improve the error message since we hide the module name in the line above
|
1558 |
+
import sys
|
1559 |
+
|
1560 |
+
raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
|
1561 |
+
sys.exc_info()[2]
|
1562 |
+
)
|
1563 |
+
|
1564 |
+
|
1565 |
+
from megatron.core.transformer.transformer_block import (
|
1566 |
+
TransformerBlockSubmodules,
|
1567 |
+
BaseTransformerLayer,
|
1568 |
+
LayerNormImpl,
|
1569 |
+
)
|
1570 |
+
|
1571 |
+
|
1572 |
+
def _get_block_submodules(
|
1573 |
+
config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec]
|
1574 |
+
) -> TransformerBlockSubmodules:
|
1575 |
+
"""
|
1576 |
+
Retrieve or construct TransformerBlockSubmodules based on the provided specification.
|
1577 |
+
|
1578 |
+
Args:
|
1579 |
+
config (TransformerConfig): Configuration object for the transformer model.
|
1580 |
+
spec (Union[TransformerBlockSubmodules, ModuleSpec]): Specification for the
|
1581 |
+
transformer block submodules. Can be either a TransformerBlockSubmodules
|
1582 |
+
instance or a ModuleSpec.
|
1583 |
+
|
1584 |
+
Returns:
|
1585 |
+
TransformerBlockSubmodules: The submodules for the transformer block.
|
1586 |
+
"""
|
1587 |
+
|
1588 |
+
# Transformer block submodules.
|
1589 |
+
if isinstance(spec, TransformerBlockSubmodules):
|
1590 |
+
return spec
|
1591 |
+
|
1592 |
+
# ModuleSpec here is generally assumed to be for a transformer layer that
|
1593 |
+
# is implemented in `transformer_layer.py` or if it subclasses
|
1594 |
+
# `BaseTransformerLayer` from the `transformer_layer.py` file.
|
1595 |
+
elif isinstance(spec, ModuleSpec):
|
1596 |
+
if issubclass(spec.module, TransformerBlock):
|
1597 |
+
return spec.submodules
|
1598 |
+
elif issubclass(spec.module, BaseTransformerLayer):
|
1599 |
+
num_layers = get_num_layers_to_build(config)
|
1600 |
+
return TransformerBlockSubmodules(
|
1601 |
+
layer_specs=[spec] * num_layers, layer_norm=LayerNormImpl
|
1602 |
+
)
|
1603 |
+
else:
|
1604 |
+
raise Exception(f"specialize for {spec.module.__name__}.")
|
1605 |
+
else:
|
1606 |
+
raise Exception(f"specialize for {type(spec).__name__}.")
|
1607 |
+
|
1608 |
+
|
1609 |
+
def get_num_layers_to_build(config: TransformerConfig) -> int:
|
1610 |
+
"""
|
1611 |
+
Determine the number of transformer layers to build for the current pipeline stage.
|
1612 |
+
Args:
|
1613 |
+
config (TransformerConfig): Configuration object containing transformer model parameters.
|
1614 |
+
|
1615 |
+
Returns:
|
1616 |
+
int: The number of layers to be built for the current pipeline stage.
|
1617 |
+
"""
|
1618 |
+
if (
|
1619 |
+
config.num_layers_in_first_pipeline_stage is not None
|
1620 |
+
or config.num_layers_in_last_pipeline_stage is not None
|
1621 |
+
):
|
1622 |
+
|
1623 |
+
assert not (
|
1624 |
+
config.account_for_embedding_in_pipeline_split
|
1625 |
+
or config.account_for_loss_in_pipeline_split
|
1626 |
+
), " \
|
1627 |
+
Does not support standalone embedding stage and standalone loss stage with uneven pp"
|
1628 |
+
# Number of layers to distribute over rest of pipeline stages
|
1629 |
+
layers_to_distribute = config.num_layers
|
1630 |
+
# Number of pipeline stages left for distributing transformer layers
|
1631 |
+
pipeline_stages_left = get_pipeline_model_parallel_world_size()
|
1632 |
+
|
1633 |
+
# If the uneven first (last) pipeline stage is enabled, remove the specified number
|
1634 |
+
# of layers to calculate the number of layers on each middle pipeline stage.
|
1635 |
+
if config.num_layers_in_first_pipeline_stage is not None:
|
1636 |
+
layers_to_distribute -= config.num_layers_in_first_pipeline_stage
|
1637 |
+
pipeline_stages_left -= 1
|
1638 |
+
|
1639 |
+
if config.num_layers_in_last_pipeline_stage is not None:
|
1640 |
+
layers_to_distribute -= config.num_layers_in_last_pipeline_stage
|
1641 |
+
pipeline_stages_left -= 1
|
1642 |
+
|
1643 |
+
assert (
|
1644 |
+
layers_to_distribute % pipeline_stages_left == 0
|
1645 |
+
), "With uneven pipelineing the left over layers must be divisible by left over stages"
|
1646 |
+
num_layers_per_pipeline_rank = layers_to_distribute // pipeline_stages_left
|
1647 |
+
|
1648 |
+
# If the uneven first (last) pipeline stage is enabled, return the specified number
|
1649 |
+
# of layers for all virtual pipeline parallel stages within the first (last) pipeline
|
1650 |
+
# parallel stage.
|
1651 |
+
if (
|
1652 |
+
is_pipeline_first_stage(ignore_virtual=True)
|
1653 |
+
and config.num_layers_in_first_pipeline_stage is not None
|
1654 |
+
):
|
1655 |
+
num_layers_per_pipeline_rank = config.num_layers_in_first_pipeline_stage
|
1656 |
+
|
1657 |
+
if (
|
1658 |
+
is_pipeline_last_stage(ignore_virtual=True)
|
1659 |
+
and config.num_layers_in_last_pipeline_stage is not None
|
1660 |
+
):
|
1661 |
+
num_layers_per_pipeline_rank = config.num_layers_in_last_pipeline_stage
|
1662 |
+
else:
|
1663 |
+
# Include the embedding layer and loss layer into pipeline parallelism partition
|
1664 |
+
num_layers = config.num_layers
|
1665 |
+
if config.account_for_embedding_in_pipeline_split:
|
1666 |
+
num_layers += 1
|
1667 |
+
|
1668 |
+
if config.account_for_loss_in_pipeline_split:
|
1669 |
+
num_layers += 1
|
1670 |
+
|
1671 |
+
assert (
|
1672 |
+
num_layers % config.pipeline_model_parallel_size == 0
|
1673 |
+
), "num_layers should be divisible by pipeline_model_parallel_size"
|
1674 |
+
num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size
|
1675 |
+
|
1676 |
+
# if get_virtual_pipeline_model_parallel_world_size() is not None:
|
1677 |
+
# # Interleaved pipeline parallelism:
|
1678 |
+
# # Number of layers in each model chunk is the number of layers in the stage,
|
1679 |
+
# # divided by the number of model chunks in a stage.
|
1680 |
+
# # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
|
1681 |
+
# # layers to stages like (each list is a model chunk):
|
1682 |
+
# # Stage 0: [0] [2] [4] [6]
|
1683 |
+
# # Stage 1: [1] [3] [5] [7]
|
1684 |
+
# # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
|
1685 |
+
# # layers to stages like (each list is a model chunk):
|
1686 |
+
# # Stage 0: [0, 1] [4, 5]
|
1687 |
+
# # Stage 1: [2, 3] [6, 7]
|
1688 |
+
# vp_size = get_virtual_pipeline_model_parallel_world_size()
|
1689 |
+
|
1690 |
+
# assert (
|
1691 |
+
# num_layers_per_pipeline_rank % vp_size == 0
|
1692 |
+
# ), "num_layers_per_pipeline_rank should be divisible by vp_size"
|
1693 |
+
# num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
|
1694 |
+
|
1695 |
+
# num_layers_to_build = num_layers_per_virtual_rank
|
1696 |
+
|
1697 |
+
# else:
|
1698 |
+
# # Non-interleaved pipeline parallelism:
|
1699 |
+
# # Each stage gets a contiguous set of layers.
|
1700 |
+
# num_layers_to_build = num_layers_per_pipeline_rank
|
1701 |
+
num_layers_to_build = num_layers_per_pipeline_rank
|
1702 |
+
# The embedding (or loss) layer cannot function as a standalone transformer layer
|
1703 |
+
# Reduce the number of layers to construct by 1 on the first (or last) stage if the
|
1704 |
+
# embedding (or loss) layer is included in the pipeline parallelism partition and placement.
|
1705 |
+
if is_pipeline_first_stage() and config.account_for_embedding_in_pipeline_split:
|
1706 |
+
num_layers_to_build -= 1
|
1707 |
+
assert (
|
1708 |
+
num_layers_to_build >= 0
|
1709 |
+
), "Not enough layers in the first virtual pipeline stage"
|
1710 |
+
|
1711 |
+
if is_pipeline_last_stage() and config.account_for_loss_in_pipeline_split:
|
1712 |
+
num_layers_to_build -= 1
|
1713 |
+
assert (
|
1714 |
+
num_layers_to_build >= 0
|
1715 |
+
), "Not enough layers in the last virtual pipeline stage"
|
1716 |
+
|
1717 |
+
return num_layers_to_build
|
1718 |
+
|
1719 |
+
|
1720 |
+
def get_transformer_layer_offset(config: TransformerConfig):
|
1721 |
+
"""Get the index offset of current pipeline stage, given the level of pipelining."""
|
1722 |
+
pipeline_rank = get_pipeline_model_parallel_rank()
|
1723 |
+
# if not is_inside_encoder():
|
1724 |
+
if True:
|
1725 |
+
pp_decoder_start = 0
|
1726 |
+
if pp_decoder_start is not None:
|
1727 |
+
pipeline_rank = pipeline_rank - pp_decoder_start
|
1728 |
+
|
1729 |
+
if config.pipeline_model_parallel_size > 1:
|
1730 |
+
|
1731 |
+
if (
|
1732 |
+
config.num_layers_in_first_pipeline_stage is not None
|
1733 |
+
or config.num_layers_in_last_pipeline_stage is not None
|
1734 |
+
):
|
1735 |
+
# Calculate number of pipeline stages to distribute the remaining Transformer
|
1736 |
+
# layers after deducting the Transformer layers in the first or the last stages
|
1737 |
+
middle_pipeline_stages = config.pipeline_model_parallel_size
|
1738 |
+
middle_pipeline_stages -= sum(
|
1739 |
+
[
|
1740 |
+
1 if x is not None else 0
|
1741 |
+
for x in (
|
1742 |
+
config.num_layers_in_first_pipeline_stage,
|
1743 |
+
config.num_layers_in_last_pipeline_stage,
|
1744 |
+
)
|
1745 |
+
]
|
1746 |
+
)
|
1747 |
+
|
1748 |
+
# Calculate layers to distribute in each pipeline stage. If the
|
1749 |
+
# num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage
|
1750 |
+
# are not set, we will not enable uneven pipeline. All layers will be treated
|
1751 |
+
# as middle layers.
|
1752 |
+
num_layers_in_first_pipeline_stage = (
|
1753 |
+
0
|
1754 |
+
if config.num_layers_in_first_pipeline_stage is None
|
1755 |
+
else config.num_layers_in_first_pipeline_stage
|
1756 |
+
)
|
1757 |
+
num_layers_in_last_pipeline_stage = (
|
1758 |
+
0
|
1759 |
+
if config.num_layers_in_last_pipeline_stage is None
|
1760 |
+
else config.num_layers_in_last_pipeline_stage
|
1761 |
+
)
|
1762 |
+
|
1763 |
+
middle_num_layers = (
|
1764 |
+
config.num_layers
|
1765 |
+
- num_layers_in_first_pipeline_stage
|
1766 |
+
- num_layers_in_last_pipeline_stage
|
1767 |
+
)
|
1768 |
+
|
1769 |
+
if middle_pipeline_stages > 0:
|
1770 |
+
num_layers_per_pipeline_rank = (
|
1771 |
+
middle_num_layers // middle_pipeline_stages
|
1772 |
+
)
|
1773 |
+
else:
|
1774 |
+
num_layers_per_pipeline_rank = 0
|
1775 |
+
|
1776 |
+
middle_pipeline_rank = (
|
1777 |
+
pipeline_rank
|
1778 |
+
if config.num_layers_in_first_pipeline_stage is None
|
1779 |
+
else pipeline_rank - 1
|
1780 |
+
)
|
1781 |
+
|
1782 |
+
if pipeline_rank == 0:
|
1783 |
+
offset = 0
|
1784 |
+
else:
|
1785 |
+
offset = (
|
1786 |
+
middle_pipeline_rank * num_layers_per_pipeline_rank
|
1787 |
+
) + num_layers_in_first_pipeline_stage
|
1788 |
+
else:
|
1789 |
+
num_layers = config.num_layers
|
1790 |
+
|
1791 |
+
# Increase the number of layers by one if we include the embedding (loss)
|
1792 |
+
# layer into pipeline parallelism partition and placement
|
1793 |
+
if config.account_for_embedding_in_pipeline_split:
|
1794 |
+
num_layers += 1
|
1795 |
+
|
1796 |
+
if config.account_for_loss_in_pipeline_split:
|
1797 |
+
num_layers += 1
|
1798 |
+
|
1799 |
+
num_layers_per_pipeline_rank = (
|
1800 |
+
num_layers // config.pipeline_model_parallel_size
|
1801 |
+
)
|
1802 |
+
|
1803 |
+
offset = pipeline_rank * num_layers_per_pipeline_rank
|
1804 |
+
|
1805 |
+
# Reduce the offset of embedding layer from the total layer number
|
1806 |
+
if (
|
1807 |
+
config.account_for_embedding_in_pipeline_split
|
1808 |
+
and not is_pipeline_first_stage()
|
1809 |
+
):
|
1810 |
+
offset -= 1
|
1811 |
+
else:
|
1812 |
+
offset = 0
|
1813 |
+
return offset
|
webui/index.html
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>Megatron Memory Estimator</title>
|
7 |
+
<link rel="stylesheet" href="style.css">
|
8 |
+
</head>
|
9 |
+
<body>
|
10 |
+
<div class="container">
|
11 |
+
<h1>Megatron Memory Estimator</h1>
|
12 |
+
<div class="disclaimer-banner">
|
13 |
+
Note: This estimator only measures the GPU memory directly managed by PyTorch when running Megatron. It does not include extra consumption from NCCL communication buffers, kernel fusion, overlap optimizations, CUDA Graphs, etc. Please use the "Overhead per GPU" option below to account for these additional costs.
|
14 |
+
</div>
|
15 |
+
|
16 |
+
<div class="main-layout">
|
17 |
+
<div class="top-section">
|
18 |
+
<div class="config-column">
|
19 |
+
<form id="config-form">
|
20 |
+
<h2>Configuration</h2>
|
21 |
+
<div class="form-group">
|
22 |
+
<label for="model-select">Select a Local Config:</label>
|
23 |
+
<select id="model-select" name="model">
|
24 |
+
<option value="">Loading...</option>
|
25 |
+
</select>
|
26 |
+
</div>
|
27 |
+
|
28 |
+
<!-- All settings are now in one block -->
|
29 |
+
<div class="form-row">
|
30 |
+
<div class="form-group">
|
31 |
+
<label for="num-gpus">Total GPUs:</label>
|
32 |
+
<input type="number" id="num-gpus" name="num_gpus" value="8" step="8" min="8">
|
33 |
+
</div>
|
34 |
+
<div class="form-group">
|
35 |
+
<label for="mbs">micro batch size:</label>
|
36 |
+
<input type="number" id="mbs" name="mbs" value="1" min="1">
|
37 |
+
</div>
|
38 |
+
<div class="form-group">
|
39 |
+
<label for="seq-len">SeqLen:</label>
|
40 |
+
<input type="number"id="seq-len" name="seq-len" value="4096" min="1">
|
41 |
+
</div>
|
42 |
+
</div>
|
43 |
+
|
44 |
+
<div class="form-group">
|
45 |
+
<input type="checkbox" id="use-distributed-optimizer" name="use_distributed_optimizer" checked>
|
46 |
+
<label for="use-distributed-optimizer" class="inline-label">Use Distributed Optimizer</label>
|
47 |
+
</div>
|
48 |
+
|
49 |
+
<div class="form-row">
|
50 |
+
<div class="form-group">
|
51 |
+
<label for="recompute-granularity">Recomputation:</label>
|
52 |
+
<select id="recompute-granularity" name="recompute_granularity">
|
53 |
+
<option value="none">None</option>
|
54 |
+
<option value="selective">Selective</option>
|
55 |
+
<option value="full">Full</option>
|
56 |
+
</select>
|
57 |
+
</div>
|
58 |
+
<div class="form-group recompute-options" style="display: none;">
|
59 |
+
<label for="recompute-method">Method:</label>
|
60 |
+
<select id="recompute-method" name="recompute_method">
|
61 |
+
<option value="uniform">Uniform</option>
|
62 |
+
<option value="block">Block</option>
|
63 |
+
</select>
|
64 |
+
</div>
|
65 |
+
<div class="form-group recompute-options" style="display: none;">
|
66 |
+
<label for="recompute-num-layers">Layers:</label>
|
67 |
+
<input type="number" id="recompute-num-layers" name="recompute_num_layers" value="1" min="1">
|
68 |
+
</div>
|
69 |
+
</div>
|
70 |
+
|
71 |
+
<div class="form-row">
|
72 |
+
<div class="form-group">
|
73 |
+
<label for="tp">TP:</label>
|
74 |
+
<select id="tp" name="tp"></select>
|
75 |
+
</div>
|
76 |
+
<div class="form-group">
|
77 |
+
<label for="pp">PP:</label>
|
78 |
+
<input type="number" id="pp" name="pp" value="1" min="1">
|
79 |
+
</div>
|
80 |
+
<div class="form-group">
|
81 |
+
<label for="ep">EP:</label>
|
82 |
+
<select id="ep" name="ep"></select>
|
83 |
+
</div>
|
84 |
+
<div class="form-group">
|
85 |
+
<label for="cp">CP:</label>
|
86 |
+
<select id="cp" name="cp"></select>
|
87 |
+
</div>
|
88 |
+
</div>
|
89 |
+
<div class="form-row">
|
90 |
+
<div class="form-group">
|
91 |
+
<label for="vpp">VPP:</label>
|
92 |
+
<input type="number" id="vpp" name="vpp" placeholder="None" min="1">
|
93 |
+
</div>
|
94 |
+
<div class="form-group">
|
95 |
+
<label for="etp">ETP:</label>
|
96 |
+
<input type="number" id="etp" name="etp" placeholder="None" min="1">
|
97 |
+
</div>
|
98 |
+
</div>
|
99 |
+
<div class="form-row">
|
100 |
+
<div class="form-group">
|
101 |
+
<label for="num_layers_in_first_pipeline_stage">First Stage Layers:</label>
|
102 |
+
<input type="number" id="num_layers_in_first_pipeline_stage" name="num_layers_in_first_pipeline_stage" placeholder="None" min="0">
|
103 |
+
</div>
|
104 |
+
<div class="form-group">
|
105 |
+
<label for="num_layers_in_last_pipeline_stage">Last Stage Layers:</label>
|
106 |
+
<input type="number" id="num_layers_in_last_pipeline_stage" name="num_layers_in_last_pipeline_stage" placeholder="None" min="0">
|
107 |
+
</div>
|
108 |
+
</div>
|
109 |
+
<div class="form-row">
|
110 |
+
<div class="form-group">
|
111 |
+
<label for="overhead">Overhead per GPU:</label>
|
112 |
+
<select id="overhead" name="overhead">
|
113 |
+
<option value="5">5GB</option>
|
114 |
+
<option value="10" selected>10GB</option>
|
115 |
+
</select>
|
116 |
+
</div>
|
117 |
+
</div>
|
118 |
+
|
119 |
+
<div id="validation-message" class="error-message" style="display: none;"></div>
|
120 |
+
<div class="button-container">
|
121 |
+
<button type="submit">Estimate</button>
|
122 |
+
</div>
|
123 |
+
</form>
|
124 |
+
</div>
|
125 |
+
|
126 |
+
<div class="output-column">
|
127 |
+
<div class="config-editor-wrapper">
|
128 |
+
<h2>Model Config (Editable)</h2>
|
129 |
+
<textarea id="config-editor" rows="20"></textarea>
|
130 |
+
</div>
|
131 |
+
</div>
|
132 |
+
</div>
|
133 |
+
|
134 |
+
<div class="bottom-section">
|
135 |
+
<div id="output-container">
|
136 |
+
<div id="loading" style="display: none;">Calculating...</div>
|
137 |
+
<div id="history-wrapper">
|
138 |
+
<h3>History</h3>
|
139 |
+
<table id="history-table">
|
140 |
+
<thead>
|
141 |
+
<tr>
|
142 |
+
<th>Model</th>
|
143 |
+
<th>Weight Optimizer (GB)</th>
|
144 |
+
<th>Activation (GB)</th>
|
145 |
+
<th>Total (GB/GPU)</th>
|
146 |
+
<th>Actions</th>
|
147 |
+
</tr>
|
148 |
+
</thead>
|
149 |
+
<tbody>
|
150 |
+
</tbody>
|
151 |
+
</table>
|
152 |
+
<button id="clear-history" style="margin-top: 1em;">Clear History</button>
|
153 |
+
</div>
|
154 |
+
</div>
|
155 |
+
</div>
|
156 |
+
</div>
|
157 |
+
</div>
|
158 |
+
<script src="script.js"></script>
|
159 |
+
<footer class="footer">
|
160 |
+
<p>© 2025 <a href="https://github.com/ISEEKYAN" target="_blank">ISEEKYAN</a>. Developed at NVIDIA.</p>
|
161 |
+
</footer>
|
162 |
+
</body>
|
163 |
+
</html>
|
webui/main.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
from fastapi import FastAPI, Body
|
4 |
+
from fastapi.staticfiles import StaticFiles
|
5 |
+
from fastapi.responses import FileResponse
|
6 |
+
import requests
|
7 |
+
from pydantic import BaseModel, field_validator
|
8 |
+
from typing import Optional
|
9 |
+
from mbridge import AutoBridge
|
10 |
+
from estimate import estimate_from_config
|
11 |
+
from megatron.core import parallel_state as mpu
|
12 |
+
import argparse
|
13 |
+
import json
|
14 |
+
import tempfile
|
15 |
+
|
16 |
+
# The directory of the current script (main.py)
|
17 |
+
WEBUI_DIR = os.path.dirname(os.path.abspath(__file__))
|
18 |
+
|
19 |
+
app = FastAPI()
|
20 |
+
|
21 |
+
# Mount static files from the webui directory
|
22 |
+
app.mount("/static", StaticFiles(directory=WEBUI_DIR), name="static")
|
23 |
+
|
24 |
+
|
25 |
+
@app.get("/")
|
26 |
+
async def read_index():
|
27 |
+
return FileResponse(os.path.join(WEBUI_DIR, 'index.html'))
|
28 |
+
|
29 |
+
@app.get("/style.css")
|
30 |
+
async def read_css():
|
31 |
+
return FileResponse(os.path.join(WEBUI_DIR, 'style.css'))
|
32 |
+
|
33 |
+
@app.get("/script.js")
|
34 |
+
async def read_js():
|
35 |
+
return FileResponse(os.path.join(WEBUI_DIR, 'script.js'))
|
36 |
+
|
37 |
+
|
38 |
+
SUPPORTED_MODELS = [
|
39 |
+
"Qwen/Qwen3-235B-A22B",
|
40 |
+
"Qwen/Qwen3-30B-A3B",
|
41 |
+
"Qwen/Qwen3-32B",
|
42 |
+
"Qwen/Qwen3-14B",
|
43 |
+
"Qwen/Qwen3-8B",
|
44 |
+
"Qwen/Qwen2.5-7B",
|
45 |
+
"Qwen/Qwen2.5-14B",
|
46 |
+
"Qwen/Qwen2.5-32B",
|
47 |
+
"Qwen/Qwen2.5-72B",
|
48 |
+
"moonshotai/Moonlight-16B-A3B",
|
49 |
+
"moonshotai/Kimi-K2-Instruct",
|
50 |
+
"deepseek-ai/DeepSeek-V3",
|
51 |
+
]
|
52 |
+
|
53 |
+
|
54 |
+
@app.get("/local-hf-configs")
|
55 |
+
async def get_supported_models():
|
56 |
+
"""Return the list of HF model identifiers supported by the UI."""
|
57 |
+
return SUPPORTED_MODELS
|
58 |
+
|
59 |
+
@app.get("/get-megatron-config/{model_path:path}")
|
60 |
+
async def get_remote_hf_config(model_path: str):
|
61 |
+
"""Fetch the HuggingFace config.json for the given model id."""
|
62 |
+
url = f"https://huggingface.co/{model_path}/raw/main/config.json"
|
63 |
+
try:
|
64 |
+
resp = requests.get(url, timeout=10)
|
65 |
+
resp.raise_for_status()
|
66 |
+
return resp.json()
|
67 |
+
except Exception as e:
|
68 |
+
return {"error": f"Failed to fetch config from {url}: {str(e)}"}
|
69 |
+
|
70 |
+
|
71 |
+
class MBridgeEstimateConfig(BaseModel):
|
72 |
+
hf_model_path: str
|
73 |
+
custom_hf_config: Optional[dict] = None # Renamed for clarity
|
74 |
+
|
75 |
+
# Hardware & Training
|
76 |
+
num_gpus: int = 8
|
77 |
+
mbs: int = 1
|
78 |
+
seq_len: int = 4096
|
79 |
+
use_distributed_optimizer: bool = True
|
80 |
+
# Recompute settings are now part of the main config
|
81 |
+
recompute_granularity: str = "selective"
|
82 |
+
recompute_method: str = "uniform"
|
83 |
+
recompute_num_layers: Optional[int] = 1
|
84 |
+
|
85 |
+
# Parallelism
|
86 |
+
tp: int = 1
|
87 |
+
pp: int = 1
|
88 |
+
ep: int = 1
|
89 |
+
cp: int = 1
|
90 |
+
vpp: Optional[int] = None
|
91 |
+
etp: Optional[int] = None
|
92 |
+
|
93 |
+
# Pipeline stage layer counts
|
94 |
+
num_layers_in_first_pipeline_stage: Optional[int] = None
|
95 |
+
num_layers_in_last_pipeline_stage: Optional[int] = None
|
96 |
+
|
97 |
+
@field_validator('num_gpus')
|
98 |
+
def num_gpus_must_be_multiple_of_8(cls, v):
|
99 |
+
if v <= 0 or v % 8 != 0:
|
100 |
+
raise ValueError('must be a positive multiple of 8')
|
101 |
+
return v
|
102 |
+
|
103 |
+
def patch_parallel_states(config: MBridgeEstimateConfig):
|
104 |
+
from mbridge.core.parallel_states import ParallelStates
|
105 |
+
ParallelStates.get_default_parallel_states = lambda: ParallelStates(
|
106 |
+
tp_size=config.tp,
|
107 |
+
pp_size=config.pp,
|
108 |
+
ep_size=config.ep,
|
109 |
+
cp_size=config.cp,
|
110 |
+
vpp_size=config.vpp,
|
111 |
+
etp_size=config.etp,
|
112 |
+
)
|
113 |
+
|
114 |
+
@app.post("/estimate_with_mbridge")
|
115 |
+
async def estimate_with_mbridge(config: MBridgeEstimateConfig):
|
116 |
+
# Validate Inputs
|
117 |
+
if config.num_gpus <= 0 or config.num_gpus % 8 != 0:
|
118 |
+
return {"error": "Total number of GPUs must be a positive multiple of 8."}
|
119 |
+
|
120 |
+
parallel_product = config.tp * config.pp * config.cp
|
121 |
+
if parallel_product == 0: # Avoid division by zero
|
122 |
+
return {"error": "Parallelism dimensions (TP, PP, CP) cannot be zero."}
|
123 |
+
|
124 |
+
if config.num_gpus % parallel_product != 0:
|
125 |
+
return {"error": f"Number of GPUs ({config.num_gpus}) must be divisible by the product of TP*PP*CP ({parallel_product})."}
|
126 |
+
|
127 |
+
patch_parallel_states(config)
|
128 |
+
|
129 |
+
# If the path is just a filename, assume it's in our local model-configs dir
|
130 |
+
hf_model_path = config.hf_model_path
|
131 |
+
# This logic needs to change. The custom config from the UI is an HF config, not a Megatron config.
|
132 |
+
# We need to load it via a temporary file.
|
133 |
+
if config.custom_hf_config:
|
134 |
+
try:
|
135 |
+
# Create a temporary file to save the custom HF config
|
136 |
+
with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".json", dir=os.path.join(WEBUI_DIR, 'model-configs')) as tmp:
|
137 |
+
json.dump(config.custom_hf_config, tmp)
|
138 |
+
tmp_path = tmp.name
|
139 |
+
|
140 |
+
# Load the bridge from the temporary config file
|
141 |
+
from transformers import AutoConfig
|
142 |
+
AutoConfig.trust_remote_code = True
|
143 |
+
bridge = AutoBridge.from_pretrained(tmp_path)
|
144 |
+
tf_config = bridge.config
|
145 |
+
hf_config = bridge.hf_config
|
146 |
+
|
147 |
+
finally:
|
148 |
+
# Ensure the temporary file is deleted
|
149 |
+
if 'tmp_path' in locals() and os.path.exists(tmp_path):
|
150 |
+
os.remove(tmp_path)
|
151 |
+
else:
|
152 |
+
# If no custom config, load from the original path
|
153 |
+
if not os.path.isabs(hf_model_path) and not hf_model_path.startswith(('http', './', '../')):
|
154 |
+
hf_model_path = os.path.join(WEBUI_DIR, 'model-configs', hf_model_path)
|
155 |
+
bridge = AutoBridge.from_pretrained(hf_model_path)
|
156 |
+
tf_config = bridge.config
|
157 |
+
hf_config = bridge.hf_config
|
158 |
+
|
159 |
+
# --- Configuration Unification ---
|
160 |
+
# Update the tf_config with values from the form. This makes tf_config the single source of truth.
|
161 |
+
tf_config.tensor_model_parallel_size = config.tp
|
162 |
+
tf_config.pipeline_model_parallel_size = config.pp
|
163 |
+
tf_config.expert_model_parallel_size = config.ep
|
164 |
+
tf_config.context_parallel_size = config.cp
|
165 |
+
tf_config.recompute_granularity = config.recompute_granularity
|
166 |
+
tf_config.recompute_method = config.recompute_method
|
167 |
+
tf_config.recompute_num_layers = config.recompute_num_layers
|
168 |
+
tf_config.num_layers_per_virtual_pipeline_stage = config.vpp if config.vpp and config.vpp > 1 else None
|
169 |
+
|
170 |
+
if config.num_layers_in_first_pipeline_stage is not None:
|
171 |
+
tf_config.num_layers_in_first_pipeline_stage = config.num_layers_in_first_pipeline_stage
|
172 |
+
if config.num_layers_in_last_pipeline_stage is not None:
|
173 |
+
tf_config.num_layers_in_last_pipeline_stage = config.num_layers_in_last_pipeline_stage
|
174 |
+
# print(tf_config)
|
175 |
+
|
176 |
+
# Create a minimal 'args' object with parameters not present in TransformerConfig
|
177 |
+
args = argparse.Namespace()
|
178 |
+
args.micro_batch_size = config.mbs
|
179 |
+
args.seq_length = config.seq_len
|
180 |
+
args.use_distributed_optimizer = config.use_distributed_optimizer
|
181 |
+
args.data_parallel_size = config.num_gpus // parallel_product
|
182 |
+
args.expert_tensor_parallel_size = config.etp if config.etp else 1
|
183 |
+
|
184 |
+
# These are required by the estimator but can be derived or defaulted
|
185 |
+
args.transformer_impl = "transformer_engine"
|
186 |
+
args.fp8 = False
|
187 |
+
args.num_experts = getattr(tf_config, 'num_moe_experts', 1) # Needed for layer spec
|
188 |
+
args.moe_grouped_gemm = True # Default
|
189 |
+
args.qk_layernorm = tf_config.qk_layernorm
|
190 |
+
args.multi_latent_attention = "deepseek" in getattr(hf_config, "model_type", "")
|
191 |
+
args.padded_vocab_size = getattr(hf_config, "vocab_size")
|
192 |
+
args.max_position_embeddings = getattr(hf_config, "max_position_embeddings")
|
193 |
+
args.tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)
|
194 |
+
|
195 |
+
|
196 |
+
# This function now returns a list of reports, one for each PP rank
|
197 |
+
raw_reports_list = estimate_from_config(tf_config, args)
|
198 |
+
|
199 |
+
# The report from estimate.py now has the correct units (GB), so no conversion is needed.
|
200 |
+
# We just need to remove the complex 'details' part for the main display table.
|
201 |
+
processed_reports = []
|
202 |
+
for report in raw_reports_list:
|
203 |
+
# Create a copy of the report and remove the 'details' key
|
204 |
+
processed_report = report.copy()
|
205 |
+
processed_report.pop('details', None)
|
206 |
+
processed_reports.append(processed_report)
|
207 |
+
|
208 |
+
return {
|
209 |
+
"processed_report": processed_reports,
|
210 |
+
"raw_report": raw_reports_list
|
211 |
+
}
|
webui/model-configs/qwen3-14b.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Qwen3ForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_bias": false,
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"bos_token_id": 151643,
|
8 |
+
"eos_token_id": 151645,
|
9 |
+
"head_dim": 128,
|
10 |
+
"hidden_act": "silu",
|
11 |
+
"hidden_size": 5120,
|
12 |
+
"initializer_range": 0.02,
|
13 |
+
"intermediate_size": 17408,
|
14 |
+
"max_position_embeddings": 40960,
|
15 |
+
"max_window_layers": 40,
|
16 |
+
"model_type": "qwen3",
|
17 |
+
"num_attention_heads": 40,
|
18 |
+
"num_hidden_layers": 40,
|
19 |
+
"num_key_value_heads": 8,
|
20 |
+
"rms_norm_eps": 1e-06,
|
21 |
+
"rope_scaling": null,
|
22 |
+
"rope_theta": 1000000,
|
23 |
+
"sliding_window": null,
|
24 |
+
"tie_word_embeddings": false,
|
25 |
+
"torch_dtype": "bfloat16",
|
26 |
+
"transformers_version": "4.51.0",
|
27 |
+
"use_cache": true,
|
28 |
+
"use_sliding_window": false,
|
29 |
+
"vocab_size": 151936
|
30 |
+
}
|
webui/model-configs/qwen3-235b-a22b.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Qwen3MoeForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_bias": false,
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"bos_token_id": 151643,
|
8 |
+
"decoder_sparse_step": 1,
|
9 |
+
"eos_token_id": 151645,
|
10 |
+
"head_dim": 128,
|
11 |
+
"hidden_act": "silu",
|
12 |
+
"hidden_size": 4096,
|
13 |
+
"initializer_range": 0.02,
|
14 |
+
"intermediate_size": 12288,
|
15 |
+
"max_position_embeddings": 40960,
|
16 |
+
"max_window_layers": 94,
|
17 |
+
"mlp_only_layers": [],
|
18 |
+
"model_type": "qwen3_moe",
|
19 |
+
"moe_intermediate_size": 1536,
|
20 |
+
"norm_topk_prob": true,
|
21 |
+
"num_attention_heads": 64,
|
22 |
+
"num_experts": 128,
|
23 |
+
"num_experts_per_tok": 8,
|
24 |
+
"num_hidden_layers": 94,
|
25 |
+
"num_key_value_heads": 4,
|
26 |
+
"output_router_logits": false,
|
27 |
+
"rms_norm_eps": 1e-06,
|
28 |
+
"rope_scaling": null,
|
29 |
+
"rope_theta": 1000000.0,
|
30 |
+
"router_aux_loss_coef": 0.001,
|
31 |
+
"sliding_window": null,
|
32 |
+
"tie_word_embeddings": false,
|
33 |
+
"torch_dtype": "bfloat16",
|
34 |
+
"transformers_version": "4.51.0",
|
35 |
+
"use_cache": true,
|
36 |
+
"use_sliding_window": false,
|
37 |
+
"vocab_size": 151936
|
38 |
+
}
|
webui/model-configs/qwen3-30b-a3b.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Qwen3MoeForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_bias": false,
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"bos_token_id": 151643,
|
8 |
+
"decoder_sparse_step": 1,
|
9 |
+
"eos_token_id": 151645,
|
10 |
+
"head_dim": 128,
|
11 |
+
"hidden_act": "silu",
|
12 |
+
"hidden_size": 2048,
|
13 |
+
"initializer_range": 0.02,
|
14 |
+
"intermediate_size": 6144,
|
15 |
+
"max_position_embeddings": 40960,
|
16 |
+
"max_window_layers": 48,
|
17 |
+
"mlp_only_layers": [],
|
18 |
+
"model_type": "qwen3_moe",
|
19 |
+
"moe_intermediate_size": 768,
|
20 |
+
"norm_topk_prob": true,
|
21 |
+
"num_attention_heads": 32,
|
22 |
+
"num_experts": 128,
|
23 |
+
"num_experts_per_tok": 8,
|
24 |
+
"num_hidden_layers": 48,
|
25 |
+
"num_key_value_heads": 4,
|
26 |
+
"output_router_logits": false,
|
27 |
+
"rms_norm_eps": 1e-06,
|
28 |
+
"rope_scaling": null,
|
29 |
+
"rope_theta": 1000000.0,
|
30 |
+
"router_aux_loss_coef": 0.001,
|
31 |
+
"sliding_window": null,
|
32 |
+
"tie_word_embeddings": false,
|
33 |
+
"torch_dtype": "bfloat16",
|
34 |
+
"transformers_version": "4.51.0",
|
35 |
+
"use_cache": true,
|
36 |
+
"use_sliding_window": false,
|
37 |
+
"vocab_size": 151936
|
38 |
+
}
|
webui/model-configs/qwen3-32b.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Qwen3ForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_bias": false,
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"bos_token_id": 151643,
|
8 |
+
"eos_token_id": 151645,
|
9 |
+
"head_dim": 128,
|
10 |
+
"hidden_act": "silu",
|
11 |
+
"hidden_size": 5120,
|
12 |
+
"initializer_range": 0.02,
|
13 |
+
"intermediate_size": 25600,
|
14 |
+
"max_position_embeddings": 40960,
|
15 |
+
"max_window_layers": 64,
|
16 |
+
"model_type": "qwen3",
|
17 |
+
"num_attention_heads": 64,
|
18 |
+
"num_hidden_layers": 64,
|
19 |
+
"num_key_value_heads": 8,
|
20 |
+
"rms_norm_eps": 1e-06,
|
21 |
+
"rope_scaling": null,
|
22 |
+
"rope_theta": 1000000,
|
23 |
+
"sliding_window": null,
|
24 |
+
"tie_word_embeddings": false,
|
25 |
+
"torch_dtype": "bfloat16",
|
26 |
+
"transformers_version": "4.51.0",
|
27 |
+
"use_cache": true,
|
28 |
+
"use_sliding_window": false,
|
29 |
+
"vocab_size": 151936
|
30 |
+
}
|
webui/model-configs/qwen3-8b.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Qwen3ForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_bias": false,
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"bos_token_id": 151643,
|
8 |
+
"eos_token_id": 151645,
|
9 |
+
"head_dim": 128,
|
10 |
+
"hidden_act": "silu",
|
11 |
+
"hidden_size": 4096,
|
12 |
+
"initializer_range": 0.02,
|
13 |
+
"intermediate_size": 12288,
|
14 |
+
"max_position_embeddings": 40960,
|
15 |
+
"max_window_layers": 36,
|
16 |
+
"model_type": "qwen3",
|
17 |
+
"num_attention_heads": 32,
|
18 |
+
"num_hidden_layers": 36,
|
19 |
+
"num_key_value_heads": 8,
|
20 |
+
"rms_norm_eps": 1e-06,
|
21 |
+
"rope_scaling": null,
|
22 |
+
"rope_theta": 1000000,
|
23 |
+
"sliding_window": null,
|
24 |
+
"tie_word_embeddings": false,
|
25 |
+
"torch_dtype": "bfloat16",
|
26 |
+
"transformers_version": "4.51.0",
|
27 |
+
"use_cache": true,
|
28 |
+
"use_sliding_window": false,
|
29 |
+
"vocab_size": 151936
|
30 |
+
}
|
webui/requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
fastapi
|
2 |
+
uvicorn[standard]
|
3 |
+
mbridge
|
webui/script.js
ADDED
@@ -0,0 +1,715 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
document.addEventListener('DOMContentLoaded', () => {
|
2 |
+
// Initial UI setup
|
3 |
+
loadLocalConfigs();
|
4 |
+
updateHistoryView();
|
5 |
+
setupEventListeners();
|
6 |
+
updateParallelismOptions();
|
7 |
+
validateParallelismLive();
|
8 |
+
toggleEpBasedOnConfig(); // Disable EP initially
|
9 |
+
});
|
10 |
+
|
11 |
+
// Utility: convert ANSI color codes (red 31, green 32) to HTML spans for display
|
12 |
+
function ansiToHtml(str) {
|
13 |
+
if (!str) return '';
|
14 |
+
// Replace known ANSI codes
|
15 |
+
return str
|
16 |
+
.replace(/\u001b\[31m/g, '<span class="ansi-red">')
|
17 |
+
.replace(/\u001b\[32m/g, '<span class="ansi-green">')
|
18 |
+
.replace(/\u001b\[33m/g, '<span class="ansi-yellow">')
|
19 |
+
.replace(/\u001b\[34m/g, '<span class="ansi-blue">')
|
20 |
+
.replace(/\u001b\[35m/g, '<span class="ansi-magenta">')
|
21 |
+
.replace(/\u001b\[36m/g, '<span class="ansi-cyan">')
|
22 |
+
.replace(/\u001b\[0m/g, '</span>');
|
23 |
+
}
|
24 |
+
|
25 |
+
function setupEventListeners() {
|
26 |
+
document.getElementById('config-form').addEventListener('submit', (e) => {
|
27 |
+
e.preventDefault();
|
28 |
+
submitForm();
|
29 |
+
});
|
30 |
+
|
31 |
+
document.getElementById('model-select').addEventListener('change', loadSelectedModelConfig);
|
32 |
+
|
33 |
+
document.getElementById('recompute-granularity').addEventListener('change', (e) => {
|
34 |
+
const recomputeOptions = document.querySelectorAll('.recompute-options');
|
35 |
+
recomputeOptions.forEach(opt => {
|
36 |
+
opt.style.display = e.target.value === 'full' ? 'block' : 'none';
|
37 |
+
});
|
38 |
+
});
|
39 |
+
|
40 |
+
const liveValidationInputs = ['num-gpus', 'tp', 'pp', 'ep', 'cp', 'etp', 'config-editor'];
|
41 |
+
liveValidationInputs.forEach(id => {
|
42 |
+
const input = document.getElementById(id);
|
43 |
+
if(input) {
|
44 |
+
input.addEventListener('change', validateParallelismLive);
|
45 |
+
if (id === 'num-gpus') {
|
46 |
+
input.addEventListener('change', updateParallelismOptions);
|
47 |
+
}
|
48 |
+
}
|
49 |
+
});
|
50 |
+
|
51 |
+
document.getElementById('config-editor').addEventListener('input', toggleEpBasedOnConfig);
|
52 |
+
document.getElementById('history-table').addEventListener('click', handleHistoryAction);
|
53 |
+
document.getElementById('clear-history').addEventListener('click', clearHistory);
|
54 |
+
}
|
55 |
+
|
56 |
+
|
57 |
+
async function loadLocalConfigs() {
|
58 |
+
const modelSelect = document.getElementById('model-select');
|
59 |
+
const defaultConfigName = 'Qwen/Qwen3-235B-A22B'; // Updated default model
|
60 |
+
|
61 |
+
try {
|
62 |
+
const response = await fetch('/local-hf-configs');
|
63 |
+
const configs = await response.json();
|
64 |
+
|
65 |
+
modelSelect.innerHTML = '<option value="">Select a model...</option>';
|
66 |
+
// Add custom option to allow user supplied configs
|
67 |
+
modelSelect.innerHTML += '<option value="__custom__">Custom (paste JSON below)...</option>';
|
68 |
+
configs.forEach(config => {
|
69 |
+
modelSelect.innerHTML += `<option value="${config}">${config}</option>`;
|
70 |
+
});
|
71 |
+
|
72 |
+
// Check if the default config exists and select it
|
73 |
+
if (configs.includes(defaultConfigName)) {
|
74 |
+
modelSelect.value = defaultConfigName;
|
75 |
+
// Await the loading of the model config to ensure it's ready
|
76 |
+
await loadSelectedModelConfig();
|
77 |
+
}
|
78 |
+
|
79 |
+
} catch (error) {
|
80 |
+
modelSelect.innerHTML = '<option value="">Error loading configs</option>';
|
81 |
+
console.error('Error loading local configs:', error);
|
82 |
+
}
|
83 |
+
}
|
84 |
+
|
85 |
+
async function loadSelectedModelConfig() {
|
86 |
+
const modelSelect = document.getElementById('model-select');
|
87 |
+
const editor = document.getElementById('config-editor');
|
88 |
+
const selectedConfig = modelSelect.value;
|
89 |
+
const messageDiv = document.getElementById('validation-message'); // move early for use in all branches
|
90 |
+
let configData = null; // declare for wider scope
|
91 |
+
|
92 |
+
if (!selectedConfig) {
|
93 |
+
editor.value = '';
|
94 |
+
toggleEpBasedOnConfig();
|
95 |
+
if (messageDiv) messageDiv.style.display = 'none';
|
96 |
+
return;
|
97 |
+
} else if (selectedConfig === '__custom__') {
|
98 |
+
// Custom config: do not fetch, user must paste JSON
|
99 |
+
editor.value = '';
|
100 |
+
toggleEpBasedOnConfig();
|
101 |
+
if (messageDiv) messageDiv.style.display = 'none';
|
102 |
+
return;
|
103 |
+
}
|
104 |
+
|
105 |
+
try {
|
106 |
+
const response = await fetch(`/get-megatron-config/${encodeURIComponent(selectedConfig)}`);
|
107 |
+
configData = await response.json();
|
108 |
+
if (configData.error) {
|
109 |
+
editor.value = `Error: ${configData.error}`;
|
110 |
+
} else {
|
111 |
+
editor.value = JSON.stringify(configData, null, 2);
|
112 |
+
}
|
113 |
+
} catch (error) {
|
114 |
+
editor.value = 'Failed to fetch model configuration.';
|
115 |
+
console.error('Error fetching model config:', error);
|
116 |
+
}
|
117 |
+
|
118 |
+
// Trigger validation and UI updates after loading new config
|
119 |
+
validateParallelismLive();
|
120 |
+
toggleEpBasedOnConfig();
|
121 |
+
|
122 |
+
// Show Kimi-K2-Instruct warning if needed
|
123 |
+
if (selectedConfig.includes('Kimi-K2-Instruct') && configData && configData.model_type !== 'deepseek_v3') {
|
124 |
+
messageDiv.textContent = 'Notice: For Kimi-K2-Instruct the config field "model_type" must be set to "deepseek_v3" before memory estimation.';
|
125 |
+
messageDiv.style.display = 'block';
|
126 |
+
} else if (messageDiv) {
|
127 |
+
messageDiv.style.display = 'none';
|
128 |
+
}
|
129 |
+
}
|
130 |
+
|
131 |
+
|
132 |
+
function getFormValues(isSubmission = false) {
|
133 |
+
const form = document.getElementById('config-form');
|
134 |
+
const formData = new FormData(form);
|
135 |
+
const modelSelect = document.getElementById('model-select');
|
136 |
+
|
137 |
+
const hfPath = modelSelect.value;
|
138 |
+
if (!hfPath) {
|
139 |
+
// We will now handle this case in the submitForm function instead of an alert.
|
140 |
+
return null;
|
141 |
+
}
|
142 |
+
|
143 |
+
const editor = document.getElementById('config-editor');
|
144 |
+
let customConfig = null;
|
145 |
+
try {
|
146 |
+
// Only parse if the editor has content
|
147 |
+
if (editor.value) {
|
148 |
+
customConfig = JSON.parse(editor.value);
|
149 |
+
}
|
150 |
+
} catch (e) {
|
151 |
+
// Only alert on final submission, not on live validation
|
152 |
+
if (isSubmission) {
|
153 |
+
// alert('Model Config is not valid JSON.'); // Removing alert
|
154 |
+
}
|
155 |
+
return null; // Return null if JSON is invalid
|
156 |
+
}
|
157 |
+
|
158 |
+
const vppInput = formData.get('vpp');
|
159 |
+
const etpInput = formData.get('etp');
|
160 |
+
|
161 |
+
return {
|
162 |
+
hf_model_path: hfPath,
|
163 |
+
custom_hf_config: customConfig, // Renamed for clarity
|
164 |
+
num_gpus: parseInt(formData.get('num_gpus')),
|
165 |
+
mbs: parseInt(formData.get('mbs')),
|
166 |
+
seq_len: parseInt(formData.get('seq-len')),
|
167 |
+
use_distributed_optimizer: document.getElementById('use-distributed-optimizer').checked,
|
168 |
+
recompute_granularity: formData.get('recompute_granularity'),
|
169 |
+
recompute_method: formData.get('recompute_method'),
|
170 |
+
recompute_num_layers: parseInt(formData.get('recompute_num_layers')),
|
171 |
+
tp: parseInt(formData.get('tp')),
|
172 |
+
pp: parseInt(formData.get('pp')),
|
173 |
+
ep: parseInt(formData.get('ep')) || 1, // Default to 1 if disabled/null
|
174 |
+
cp: parseInt(formData.get('cp')),
|
175 |
+
vpp: vppInput ? parseInt(vppInput) : null,
|
176 |
+
etp: etpInput ? parseInt(etpInput) : null,
|
177 |
+
num_layers_in_first_pipeline_stage: formData.get('num_layers_in_first_pipeline_stage') ? parseInt(formData.get('num_layers_in_first_pipeline_stage')) : null,
|
178 |
+
num_layers_in_last_pipeline_stage: formData.get('num_layers_in_last_pipeline_stage') ? parseInt(formData.get('num_layers_in_last_pipeline_stage')) : null,
|
179 |
+
overhead: parseInt(formData.get('overhead')),
|
180 |
+
};
|
181 |
+
}
|
182 |
+
|
183 |
+
async function submitForm() {
|
184 |
+
const messageDiv = document.getElementById('validation-message');
|
185 |
+
messageDiv.textContent = '';
|
186 |
+
messageDiv.style.display = 'none';
|
187 |
+
|
188 |
+
// Get all form values first. We use getFormValues(false) to avoid any legacy alerts
|
189 |
+
// and handle all validation directly within this function for clarity.
|
190 |
+
const formValues = getFormValues(false);
|
191 |
+
|
192 |
+
// === START SUBMISSION VALIDATION ===
|
193 |
+
|
194 |
+
// 1. Check if form values could be retrieved. This catches both missing model selection
|
195 |
+
// and invalid JSON, as getFormValues returns null in those cases.
|
196 |
+
if (!formValues) {
|
197 |
+
if (!document.getElementById('model-select').value) {
|
198 |
+
messageDiv.textContent = 'Validation Error: Please select a model config.';
|
199 |
+
} else {
|
200 |
+
messageDiv.textContent = 'Validation Error: Model Config is not valid JSON.';
|
201 |
+
}
|
202 |
+
messageDiv.style.display = 'block';
|
203 |
+
return;
|
204 |
+
}
|
205 |
+
|
206 |
+
// Custom config must have valid JSON
|
207 |
+
if (document.getElementById('model-select').value === '__custom__' && !formValues.custom_hf_config) {
|
208 |
+
messageDiv.textContent = 'Validation Error: Please paste a valid model configuration JSON for the custom model.';
|
209 |
+
messageDiv.style.display = 'block';
|
210 |
+
return;
|
211 |
+
}
|
212 |
+
|
213 |
+
// 2. Perform all numeric and parallelism validation.
|
214 |
+
const { num_gpus, tp, pp, ep, cp, etp, custom_hf_config } = formValues;
|
215 |
+
const num_kv_heads = custom_hf_config?.num_key_value_heads || null;
|
216 |
+
|
217 |
+
let errors = [];
|
218 |
+
if (tp * pp * cp > num_gpus) {
|
219 |
+
errors.push(`TP*PP*CP (${tp * pp * cp}) > GPUs (${num_gpus}).`);
|
220 |
+
}
|
221 |
+
if (etp){
|
222 |
+
if (etp * pp * cp * ep > num_gpus) {
|
223 |
+
errors.push(`ETP*PP*CP*EP (${etp * pp * cp * ep}) > GPUs (${num_gpus}).`);
|
224 |
+
}
|
225 |
+
} else {
|
226 |
+
if (tp * pp * cp * ep > num_gpus) {
|
227 |
+
errors.push(`TP*PP*CP*EP (${tp * pp * cp * ep}) > GPUs (${num_gpus}) when ETP is not set.`);
|
228 |
+
}
|
229 |
+
}
|
230 |
+
if (num_kv_heads && tp > num_kv_heads) {
|
231 |
+
errors.push(`TP (${tp}) > Num KV Heads (${num_kv_heads}).`);
|
232 |
+
}
|
233 |
+
|
234 |
+
if (errors.length > 0) {
|
235 |
+
messageDiv.textContent = 'Validation Error: ' + errors.join(' ');
|
236 |
+
messageDiv.style.display = 'block';
|
237 |
+
return;
|
238 |
+
}
|
239 |
+
// === END SUBMISSION VALIDATION ===
|
240 |
+
|
241 |
+
const loading = document.getElementById('loading');
|
242 |
+
const submitBtn = document.querySelector('#config-form button[type="submit"]');
|
243 |
+
loading.style.display = 'block';
|
244 |
+
if (submitBtn) submitBtn.disabled = true;
|
245 |
+
|
246 |
+
try {
|
247 |
+
const response = await fetch('/estimate_with_mbridge', {
|
248 |
+
method: 'POST',
|
249 |
+
headers: { 'Content-Type': 'application/json' },
|
250 |
+
body: JSON.stringify(formValues) // Send the now fully-validated formValues
|
251 |
+
});
|
252 |
+
|
253 |
+
console.log('Response Status:', response.status);
|
254 |
+
|
255 |
+
if (response.ok) {
|
256 |
+
const data = await response.json();
|
257 |
+
|
258 |
+
// FIX: Ensure history wrapper is visible before updating and showing details
|
259 |
+
document.getElementById('history-wrapper').style.display = 'block';
|
260 |
+
|
261 |
+
saveToHistory(formValues, data);
|
262 |
+
updateHistoryView();
|
263 |
+
const newEntryRow = document.querySelector('#history-table tbody tr:first-child');
|
264 |
+
if (newEntryRow) {
|
265 |
+
const detailBtn = newEntryRow.querySelector('.detail-btn');
|
266 |
+
if (detailBtn) {
|
267 |
+
// We need to pass the event object structure to handleHistoryAction
|
268 |
+
handleHistoryAction({ target: detailBtn });
|
269 |
+
}
|
270 |
+
}
|
271 |
+
} else {
|
272 |
+
const error = await response.text();
|
273 |
+
console.error('Server error response:', error);
|
274 |
+
// Since we removed the main results display, show error in the validation div
|
275 |
+
messageDiv.textContent = `Server Error: ${error}`;
|
276 |
+
messageDiv.style.display = 'block';
|
277 |
+
}
|
278 |
+
} catch (error) {
|
279 |
+
console.error('Fetch API Error:', error);
|
280 |
+
messageDiv.textContent = `Client Error: ${error.message}`;
|
281 |
+
messageDiv.style.display = 'block';
|
282 |
+
} finally {
|
283 |
+
loading.style.display = 'none';
|
284 |
+
if (submitBtn) submitBtn.disabled = false;
|
285 |
+
}
|
286 |
+
}
|
287 |
+
|
288 |
+
function renderTable(details, rawFullReport) {
|
289 |
+
if (!details || details.length === 0) {
|
290 |
+
return '<p>No detailed memory breakdown available.</p>';
|
291 |
+
}
|
292 |
+
|
293 |
+
const headers = Object.keys(details[0]);
|
294 |
+
headers.push('Breakdown');
|
295 |
+
|
296 |
+
let table = '<table><thead><tr>';
|
297 |
+
headers.forEach(h => table += `<th>${h}</th>`);
|
298 |
+
table += '</tr></thead><tbody>';
|
299 |
+
|
300 |
+
details.forEach(row => {
|
301 |
+
const ppRank = row.pp_rank;
|
302 |
+
// FIX: Look in the full raw report array passed in.
|
303 |
+
const rawDataForRank = rawFullReport ? rawFullReport.find(r => r.pp_rank === ppRank) : null;
|
304 |
+
|
305 |
+
// FIX: Change to `let` to allow modification for highlighting.
|
306 |
+
let modelBreakdown = (rawDataForRank && rawDataForRank.model_breakdown)
|
307 |
+
? rawDataForRank.model_breakdown
|
308 |
+
: 'No breakdown available.';
|
309 |
+
|
310 |
+
// Add syntax-like highlighting for params and activations
|
311 |
+
// Basic HTML escaping for safety before inserting spans
|
312 |
+
modelBreakdown = modelBreakdown.replace(/&/g, "&").replace(/</g, "<").replace(/>/g, ">");
|
313 |
+
modelBreakdown = modelBreakdown
|
314 |
+
.replace(/(n_params=[0-9.]+[a-zA-Z]*)/g, '<span class="highlight-red">$1</span>')
|
315 |
+
.replace(/(n_act=[0-9.]+[a-zA-Z]*)/g, '<span class="highlight-red">$1</span>');
|
316 |
+
|
317 |
+
// Main row with data
|
318 |
+
table += `<tr data-pp-rank="${ppRank}">`;
|
319 |
+
headers.forEach(h => {
|
320 |
+
if (h !== 'Breakdown') {
|
321 |
+
table += `<td>${row[h]}</td>`;
|
322 |
+
}
|
323 |
+
});
|
324 |
+
table += `<td><button class="action-btn raw-per-rank-btn" data-pp-rank="${ppRank}">Raw</button></td>`;
|
325 |
+
table += '</tr>';
|
326 |
+
|
327 |
+
// Hidden row for the breakdown
|
328 |
+
table += `<tr class="raw-breakdown-row" data-pp-rank="${ppRank}" style="display: none;">
|
329 |
+
<td colspan="${headers.length}">
|
330 |
+
<pre>${modelBreakdown}</pre>
|
331 |
+
</td>
|
332 |
+
</tr>`;
|
333 |
+
});
|
334 |
+
|
335 |
+
table += '</tbody></table>';
|
336 |
+
return table;
|
337 |
+
}
|
338 |
+
|
339 |
+
function saveToHistory(params, resultData) {
|
340 |
+
let history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
|
341 |
+
const historyEntry = {
|
342 |
+
params: params,
|
343 |
+
result: resultData, // Store the full result object { processed_report, raw_report }
|
344 |
+
id: new Date().getTime()
|
345 |
+
};
|
346 |
+
history.unshift(historyEntry); // Add to the beginning
|
347 |
+
if (history.length > 20) { // Keep history size manageable
|
348 |
+
history.pop();
|
349 |
+
}
|
350 |
+
localStorage.setItem('estimationHistory', JSON.stringify(history));
|
351 |
+
}
|
352 |
+
|
353 |
+
function updateHistoryView() {
|
354 |
+
const history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
|
355 |
+
const historyTableBody = document.querySelector('#history-table tbody');
|
356 |
+
const historyWrapper = document.getElementById('history-wrapper');
|
357 |
+
historyTableBody.innerHTML = '';
|
358 |
+
|
359 |
+
if (history.length === 0) {
|
360 |
+
historyWrapper.style.display = 'none';
|
361 |
+
return;
|
362 |
+
}
|
363 |
+
|
364 |
+
historyWrapper.style.display = 'block';
|
365 |
+
|
366 |
+
history.forEach(item => {
|
367 |
+
const row = document.createElement('tr');
|
368 |
+
|
369 |
+
const params = item.params;
|
370 |
+
const resultData = item.result || {};
|
371 |
+
|
372 |
+
// FIX: Handle both old and new data structures for compatibility.
|
373 |
+
const details = (resultData.report && resultData.report.details) ? resultData.report.details : (resultData.processed_report || []);
|
374 |
+
const pp0Result = details.find(r => r.pp_rank === 0) || details[0] || {};
|
375 |
+
|
376 |
+
const modelName = params.hf_model_path.split('/').pop();
|
377 |
+
|
378 |
+
// Build parallelism string, e.g., "TP2 PP2 VPP2"
|
379 |
+
const parallelismParts = [];
|
380 |
+
['tp', 'pp', 'ep', 'cp', 'vpp', 'etp'].forEach(p => {
|
381 |
+
const value = params[p];
|
382 |
+
if (value && value > 1) {
|
383 |
+
parallelismParts.push(`${p.toUpperCase()}${value}`);
|
384 |
+
}
|
385 |
+
});
|
386 |
+
const parallelismInfo = parallelismParts.join(' ') || 'No Parallelism';
|
387 |
+
|
388 |
+
const overheadGb = params.overhead ? parseInt(params.overhead) : 0;
|
389 |
+
const baseTotal = details.length > 0 ? Math.max(...details.map(r => r.total_gb || 0)) : null;
|
390 |
+
const totalGb = baseTotal !== null ? (baseTotal + overheadGb).toFixed(2) : 'N/A';
|
391 |
+
|
392 |
+
const seqLen = params.seq_len || 0;
|
393 |
+
const formattedSeqLen = seqLen >= 1024 ? `${seqLen / 1024}k` : seqLen;
|
394 |
+
const sequenceInfo = `${params.mbs || 'N/A'}*${formattedSeqLen}`;
|
395 |
+
|
396 |
+
row.innerHTML = `
|
397 |
+
<td>
|
398 |
+
<div>${modelName}</div>
|
399 |
+
<div class="model-meta-info">
|
400 |
+
<span>GPUs: ${params.num_gpus || 'N/A'}</span>
|
401 |
+
<span>${parallelismInfo}</span>
|
402 |
+
<span>Sequence: ${sequenceInfo}</span>
|
403 |
+
</div>
|
404 |
+
</td>
|
405 |
+
<td>${pp0Result.weight_optimizer_gb || 'N/A'}</td>
|
406 |
+
<td>${pp0Result.activation_gb || 'N/A'}</td>
|
407 |
+
<td>${totalGb}</td>
|
408 |
+
<td>
|
409 |
+
<button class="restore-btn" data-id="${item.id}">Restore</button>
|
410 |
+
<button class="detail-btn" data-id="${item.id}">Detail</button>
|
411 |
+
<button class="delete-btn" data-id="${item.id}">Delete</button>
|
412 |
+
</td>
|
413 |
+
`;
|
414 |
+
historyTableBody.appendChild(row);
|
415 |
+
});
|
416 |
+
}
|
417 |
+
|
418 |
+
async function handleHistoryAction(e) {
|
419 |
+
const button = e.target.closest('button');
|
420 |
+
if (!button) return;
|
421 |
+
|
422 |
+
// Handle breakdown toggle first
|
423 |
+
if (button.classList.contains('breakdown-btn')) {
|
424 |
+
const ppRank = button.dataset.ppRank;
|
425 |
+
const detailTable = button.closest('table');
|
426 |
+
if (!detailTable) return;
|
427 |
+
|
428 |
+
const breakdownRow = detailTable.querySelector(`tr.breakdown-row[data-pp-rank="${ppRank}"]`);
|
429 |
+
if (!breakdownRow) return;
|
430 |
+
|
431 |
+
const isVisible = breakdownRow.style.display !== 'none';
|
432 |
+
breakdownRow.style.display = isVisible ? 'none' : 'table-row';
|
433 |
+
button.textContent = isVisible ? 'Breakdown' : 'Hide';
|
434 |
+
return; // Do not continue to other handlers
|
435 |
+
}
|
436 |
+
|
437 |
+
if (!button.matches('.detail-btn, .restore-btn, .delete-btn')) return;
|
438 |
+
|
439 |
+
const id = parseInt(button.dataset.id, 10);
|
440 |
+
const history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
|
441 |
+
const entry = history.find(item => item.id === id);
|
442 |
+
|
443 |
+
if (!entry) {
|
444 |
+
console.error('History entry not found for id:', id);
|
445 |
+
return;
|
446 |
+
}
|
447 |
+
|
448 |
+
const row = button.closest('tr');
|
449 |
+
|
450 |
+
if (button.classList.contains('detail-btn')) {
|
451 |
+
const isDetailsVisible = row.nextElementSibling && row.nextElementSibling.classList.contains('detail-row');
|
452 |
+
|
453 |
+
document.querySelectorAll('.detail-row').forEach(detailRow => {
|
454 |
+
const prevRow = detailRow.previousElementSibling;
|
455 |
+
const detailBtn = prevRow.querySelector('.detail-btn');
|
456 |
+
if (detailRow !== row.nextElementSibling) {
|
457 |
+
detailRow.remove();
|
458 |
+
if (detailBtn) detailBtn.textContent = 'Detail';
|
459 |
+
}
|
460 |
+
});
|
461 |
+
|
462 |
+
if (isDetailsVisible) {
|
463 |
+
row.nextElementSibling.remove();
|
464 |
+
button.textContent = 'Detail';
|
465 |
+
} else {
|
466 |
+
const detailRow = document.createElement('tr');
|
467 |
+
detailRow.classList.add('detail-row');
|
468 |
+
const detailCell = detailRow.insertCell();
|
469 |
+
detailCell.colSpan = row.cells.length;
|
470 |
+
|
471 |
+
// FIX: Handle both old and new data structures for compatibility.
|
472 |
+
const report = entry.result.report;
|
473 |
+
const details = (report && report.details) ? report.details : (entry.result.processed_report || []);
|
474 |
+
const modelBreakdown = (report && report.model_breakdown) ? report.model_breakdown : null;
|
475 |
+
|
476 |
+
if (details && details.length > 0) {
|
477 |
+
const newTable = document.createElement('table');
|
478 |
+
// Determine if breakdown information exists per-row or globally
|
479 |
+
let headers = Object.keys(details[0]);
|
480 |
+
|
481 |
+
// If old-format data, there is a 'model_breakdown' key on each detail row
|
482 |
+
const hasRowBreakdown = headers.includes('model_breakdown');
|
483 |
+
|
484 |
+
// Remove the raw model_breakdown column from headers to keep table compact
|
485 |
+
if (hasRowBreakdown) {
|
486 |
+
headers = headers.filter(h => h !== 'model_breakdown');
|
487 |
+
}
|
488 |
+
|
489 |
+
// Include global breakdown if provided, or row breakdowns if present
|
490 |
+
const includeBreakdown = hasRowBreakdown || (modelBreakdown && typeof modelBreakdown === 'string');
|
491 |
+
|
492 |
+
if (includeBreakdown) {
|
493 |
+
headers.push('Breakdown');
|
494 |
+
}
|
495 |
+
|
496 |
+
const headerRow = newTable.insertRow();
|
497 |
+
headers.forEach(h => {
|
498 |
+
const th = document.createElement('th');
|
499 |
+
th.textContent = h;
|
500 |
+
headerRow.appendChild(th);
|
501 |
+
});
|
502 |
+
|
503 |
+
details.forEach(detail => {
|
504 |
+
const newRow = newTable.insertRow();
|
505 |
+
headers.forEach(header => {
|
506 |
+
if (header === 'Breakdown') {
|
507 |
+
const cell = newRow.insertCell();
|
508 |
+
cell.innerHTML = `<button class="breakdown-btn" data-pp-rank="${detail.pp_rank}">Breakdown</button>`;
|
509 |
+
} else {
|
510 |
+
const cell = newRow.insertCell();
|
511 |
+
let value = detail[header];
|
512 |
+
if (typeof value === 'number' && !Number.isInteger(value)) {
|
513 |
+
value = value.toFixed(4);
|
514 |
+
}
|
515 |
+
cell.textContent = value;
|
516 |
+
}
|
517 |
+
});
|
518 |
+
|
519 |
+
// Hidden breakdown row
|
520 |
+
if (includeBreakdown) {
|
521 |
+
const breakdownRow = newTable.insertRow();
|
522 |
+
breakdownRow.classList.add('breakdown-row');
|
523 |
+
breakdownRow.dataset.ppRank = detail.pp_rank;
|
524 |
+
breakdownRow.style.display = 'none';
|
525 |
+
const breakdownCell = breakdownRow.insertCell();
|
526 |
+
breakdownCell.colSpan = headers.length;
|
527 |
+
const rowSpecificBreakdown = hasRowBreakdown ? (detail.model_breakdown || '') : modelBreakdown;
|
528 |
+
const htmlBreakdown = ansiToHtml(rowSpecificBreakdown);
|
529 |
+
breakdownCell.innerHTML = `<pre class="model-breakdown-view">${htmlBreakdown || 'No breakdown available.'}</pre>`;
|
530 |
+
}
|
531 |
+
});
|
532 |
+
|
533 |
+
detailCell.appendChild(newTable);
|
534 |
+
} else {
|
535 |
+
detailCell.innerHTML = 'No detailed per-rank results available.';
|
536 |
+
}
|
537 |
+
|
538 |
+
row.after(detailRow);
|
539 |
+
button.textContent = 'Hide';
|
540 |
+
}
|
541 |
+
} else if (button.classList.contains('restore-btn')) {
|
542 |
+
restoreForm(entry.params);
|
543 |
+
} else if (button.classList.contains('delete-btn')) {
|
544 |
+
deleteHistoryEntry(id);
|
545 |
+
}
|
546 |
+
}
|
547 |
+
|
548 |
+
function deleteHistoryEntry(id) {
|
549 |
+
let history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
|
550 |
+
const updatedHistory = history.filter(item => item.id != id);
|
551 |
+
localStorage.setItem('estimationHistory', JSON.stringify(updatedHistory));
|
552 |
+
updateHistoryView();
|
553 |
+
|
554 |
+
// If history is now empty, hide the whole output container
|
555 |
+
if (updatedHistory.length === 0) {
|
556 |
+
// document.getElementById('output-container').style.display = 'none';
|
557 |
+
}
|
558 |
+
}
|
559 |
+
|
560 |
+
function clearHistory() {
|
561 |
+
localStorage.removeItem('estimationHistory');
|
562 |
+
updateHistoryView();
|
563 |
+
// document.getElementById('output-container').style.display = 'none';
|
564 |
+
}
|
565 |
+
|
566 |
+
|
567 |
+
function restoreForm(params) {
|
568 |
+
if (!params) return;
|
569 |
+
|
570 |
+
const setElementValue = (id, value, defaultValue = '') => {
|
571 |
+
const element = document.getElementById(id);
|
572 |
+
if (element) {
|
573 |
+
if (element.type === 'checkbox') {
|
574 |
+
element.checked = value ?? defaultValue;
|
575 |
+
} else {
|
576 |
+
element.value = value ?? defaultValue;
|
577 |
+
}
|
578 |
+
}
|
579 |
+
};
|
580 |
+
|
581 |
+
setElementValue('num-gpus', params.num_gpus, 8);
|
582 |
+
setElementValue('mbs', params.mbs, 1);
|
583 |
+
setElementValue('seq-len', params.seq_len, 4096);
|
584 |
+
setElementValue('use-distributed-optimizer', params.use_distributed_optimizer, true);
|
585 |
+
setElementValue('recompute_granularity', params.recompute_granularity, 'selective');
|
586 |
+
setElementValue('recompute_method', params.recompute_method, 'uniform');
|
587 |
+
setElementValue('recompute_num_layers', params.recompute_num_layers, 1);
|
588 |
+
setElementValue('tp', params.tp, 1);
|
589 |
+
setElementValue('pp', params.pp, 1);
|
590 |
+
setElementValue('ep', params.ep, 1);
|
591 |
+
setElementValue('cp', params.cp, 1);
|
592 |
+
setElementValue('vpp', params.vpp);
|
593 |
+
setElementValue('etp', params.etp);
|
594 |
+
setElementValue('num_layers_in_first_pipeline_stage', params.num_layers_in_first_pipeline_stage);
|
595 |
+
setElementValue('num_layers_in_last_pipeline_stage', params.num_layers_in_last_pipeline_stage);
|
596 |
+
setElementValue('overhead', params.overhead, 10);
|
597 |
+
|
598 |
+
const modelSelect = document.getElementById('model-select');
|
599 |
+
if (modelSelect && params.hf_model_path) {
|
600 |
+
modelSelect.value = params.hf_model_path;
|
601 |
+
}
|
602 |
+
|
603 |
+
// Manually trigger change event for UI updates
|
604 |
+
const recomputeSelect = document.getElementById('recompute_granularity');
|
605 |
+
if (recomputeSelect) {
|
606 |
+
recomputeSelect.dispatchEvent(new Event('change'));
|
607 |
+
}
|
608 |
+
}
|
609 |
+
|
610 |
+
function updateParallelismOptions() {
|
611 |
+
const numGpusInput = document.getElementById('num-gpus');
|
612 |
+
if (!numGpusInput) return;
|
613 |
+
|
614 |
+
const numGpus = parseInt(numGpusInput.value);
|
615 |
+
if (isNaN(numGpus) || numGpus <= 0) {
|
616 |
+
return; // Don't update if GPU count is invalid
|
617 |
+
}
|
618 |
+
|
619 |
+
const tpSelect = document.getElementById('tp');
|
620 |
+
const epSelect = document.getElementById('ep');
|
621 |
+
const cpSelect = document.getElementById('cp');
|
622 |
+
|
623 |
+
// PP is now a manual input, so we only handle TP, EP, CP here.
|
624 |
+
const selects = [tpSelect, epSelect, cpSelect];
|
625 |
+
|
626 |
+
const powersOfTwo = [1];
|
627 |
+
for (let i = 1; (1 << i) <= numGpus; i++) {
|
628 |
+
powersOfTwo.push(1 << i);
|
629 |
+
}
|
630 |
+
|
631 |
+
selects.forEach(select => {
|
632 |
+
if (!select) return;
|
633 |
+
const currentVal = select.value;
|
634 |
+
select.innerHTML = ''; // Clear existing options
|
635 |
+
|
636 |
+
powersOfTwo.forEach(val => {
|
637 |
+
const option = document.createElement('option');
|
638 |
+
option.value = val;
|
639 |
+
option.textContent = val;
|
640 |
+
select.appendChild(option);
|
641 |
+
});
|
642 |
+
|
643 |
+
// Try to restore the previous value, otherwise default to 1
|
644 |
+
if (powersOfTwo.includes(parseInt(currentVal))) {
|
645 |
+
select.value = currentVal;
|
646 |
+
} else {
|
647 |
+
select.value = 1;
|
648 |
+
}
|
649 |
+
});
|
650 |
+
}
|
651 |
+
|
652 |
+
function validateParallelismLive() {
|
653 |
+
const messageDiv = document.getElementById('validation-message');
|
654 |
+
// Pass isSubmission = false to getFormValues to prevent alerts during live validation
|
655 |
+
const formValues = getFormValues(false);
|
656 |
+
|
657 |
+
if (!formValues) {
|
658 |
+
messageDiv.textContent = '';
|
659 |
+
return true;
|
660 |
+
}
|
661 |
+
|
662 |
+
const { num_gpus, tp, pp, ep, cp, etp, custom_hf_config } = formValues;
|
663 |
+
// The key is the same in the HF config, so this logic remains valid.
|
664 |
+
const num_kv_heads = custom_hf_config?.num_key_value_heads || null;
|
665 |
+
|
666 |
+
let errors = [];
|
667 |
+
if (tp * pp * cp > num_gpus) {
|
668 |
+
errors.push(`TP*PP*CP (${tp*pp*cp}) > GPUs (${num_gpus}).`);
|
669 |
+
}
|
670 |
+
if (etp) {
|
671 |
+
if (etp * pp * cp * ep > num_gpus) {
|
672 |
+
errors.push(`ETP*PP*CP*EP (${etp*pp*cp*ep}) > GPUs (${num_gpus}).`);
|
673 |
+
}
|
674 |
+
} else {
|
675 |
+
if (tp * pp * cp * ep > num_gpus) {
|
676 |
+
errors.push(`TP*PP*CP*EP (${tp*pp*cp*ep}) > GPUs (${num_gpus}) when ETP is not set.`);
|
677 |
+
}
|
678 |
+
}
|
679 |
+
if (num_kv_heads && tp > num_kv_heads) {
|
680 |
+
errors.push(`TP (${tp}) > Num KV Heads (${num_kv_heads}).`);
|
681 |
+
}
|
682 |
+
|
683 |
+
if (errors.length > 0) {
|
684 |
+
messageDiv.textContent = 'Validation Error: ' + errors.join(' ');
|
685 |
+
messageDiv.style.display = 'block';
|
686 |
+
} else {
|
687 |
+
messageDiv.textContent = '';
|
688 |
+
messageDiv.style.display = 'none';
|
689 |
+
}
|
690 |
+
return errors.length === 0;
|
691 |
+
}
|
692 |
+
|
693 |
+
function toggleEpBasedOnConfig() {
|
694 |
+
const editor = document.getElementById('config-editor');
|
695 |
+
const epSelect = document.getElementById('ep');
|
696 |
+
if (!editor || !epSelect) return;
|
697 |
+
|
698 |
+
let config = null;
|
699 |
+
try {
|
700 |
+
if (editor.value) {
|
701 |
+
config = JSON.parse(editor.value);
|
702 |
+
}
|
703 |
+
} catch (e) {
|
704 |
+
// Invalid JSON, disable EP as a safety measure
|
705 |
+
epSelect.disabled = true;
|
706 |
+
return;
|
707 |
+
}
|
708 |
+
|
709 |
+
if (config && config.num_experts && config.num_experts > 0) {
|
710 |
+
epSelect.disabled = false;
|
711 |
+
} else {
|
712 |
+
epSelect.disabled = true;
|
713 |
+
epSelect.value = 1; // Reset to 1 if disabled
|
714 |
+
}
|
715 |
+
}
|
webui/style.css
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
body {
|
2 |
+
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
|
3 |
+
line-height: 1.6;
|
4 |
+
background-color: #f4f4f4;
|
5 |
+
color: #333;
|
6 |
+
margin: 0;
|
7 |
+
padding: 1em;
|
8 |
+
}
|
9 |
+
|
10 |
+
.container {
|
11 |
+
max-width: 1600px;
|
12 |
+
margin: auto;
|
13 |
+
background: #fff;
|
14 |
+
padding: 2em;
|
15 |
+
border-radius: 8px;
|
16 |
+
box-shadow: 0 0 20px rgba(0, 0, 0, 0.05);
|
17 |
+
}
|
18 |
+
|
19 |
+
.main-layout {
|
20 |
+
display: flex;
|
21 |
+
flex-direction: column; /* Main axis is vertical */
|
22 |
+
gap: 2em;
|
23 |
+
}
|
24 |
+
|
25 |
+
.top-section {
|
26 |
+
display: flex;
|
27 |
+
flex-direction: row; /* Children are horizontal */
|
28 |
+
gap: 2em;
|
29 |
+
}
|
30 |
+
|
31 |
+
.config-column, .output-column {
|
32 |
+
flex: 1; /* Each column takes up half the space */
|
33 |
+
display: flex;
|
34 |
+
flex-direction: column;
|
35 |
+
}
|
36 |
+
|
37 |
+
/* The editor wrapper should grow to fill the space */
|
38 |
+
.config-editor-wrapper {
|
39 |
+
flex-grow: 1;
|
40 |
+
display: flex;
|
41 |
+
flex-direction: column;
|
42 |
+
}
|
43 |
+
|
44 |
+
#config-editor {
|
45 |
+
flex-grow: 1; /* The textarea itself should grow */
|
46 |
+
width: 100%;
|
47 |
+
box-sizing: border-box; /* Include padding and border in the element's total width and height */
|
48 |
+
resize: vertical; /* Allow vertical resizing */
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
.bottom-section {
|
53 |
+
width: 100%;
|
54 |
+
}
|
55 |
+
|
56 |
+
.form-row {
|
57 |
+
display: flex;
|
58 |
+
gap: 1em;
|
59 |
+
align-items: flex-end;
|
60 |
+
}
|
61 |
+
|
62 |
+
.form-row .form-group {
|
63 |
+
flex: 1; /* Allow groups to grow and fill space */
|
64 |
+
margin-bottom: 0.8em;
|
65 |
+
}
|
66 |
+
|
67 |
+
.form-group {
|
68 |
+
margin-bottom: 0.8em; /* Reduced from default */
|
69 |
+
}
|
70 |
+
|
71 |
+
.form-group label {
|
72 |
+
display: block;
|
73 |
+
margin-bottom: 0.25em; /* Reduced */
|
74 |
+
font-weight: 500;
|
75 |
+
}
|
76 |
+
|
77 |
+
.form-group label.inline-label {
|
78 |
+
display: inline-block;
|
79 |
+
margin-left: 0.5em;
|
80 |
+
font-weight: normal;
|
81 |
+
}
|
82 |
+
|
83 |
+
.form-group input[type="number"],
|
84 |
+
.form-group select {
|
85 |
+
width: 100%;
|
86 |
+
padding: 6px 10px; /* Reduced padding */
|
87 |
+
border-radius: 4px;
|
88 |
+
border: 1px solid #ccc;
|
89 |
+
box-sizing: border-box;
|
90 |
+
}
|
91 |
+
|
92 |
+
button {
|
93 |
+
background-color: #3498db;
|
94 |
+
color: white;
|
95 |
+
padding: 10px 15px;
|
96 |
+
border: none;
|
97 |
+
border-radius: 4px;
|
98 |
+
cursor: pointer;
|
99 |
+
font-size: 16px;
|
100 |
+
margin-top: 10px;
|
101 |
+
}
|
102 |
+
|
103 |
+
button:hover {
|
104 |
+
background-color: #2980b9;
|
105 |
+
}
|
106 |
+
|
107 |
+
#results {
|
108 |
+
background-color: #ecf0f1;
|
109 |
+
padding: 15px;
|
110 |
+
border-radius: 4px;
|
111 |
+
white-space: pre-wrap;
|
112 |
+
word-wrap: break-word;
|
113 |
+
min-height: 100px;
|
114 |
+
}
|
115 |
+
|
116 |
+
.results-container {
|
117 |
+
margin-top: 20px;
|
118 |
+
}
|
119 |
+
|
120 |
+
/* New styles for results table */
|
121 |
+
table {
|
122 |
+
width: 100%;
|
123 |
+
border-collapse: collapse;
|
124 |
+
margin-top: 20px;
|
125 |
+
}
|
126 |
+
|
127 |
+
th, td {
|
128 |
+
border: 1px solid #ddd;
|
129 |
+
padding: 12px;
|
130 |
+
text-align: left;
|
131 |
+
}
|
132 |
+
|
133 |
+
th {
|
134 |
+
background-color: #f2f2f2;
|
135 |
+
font-weight: bold;
|
136 |
+
}
|
137 |
+
|
138 |
+
tbody tr:nth-child(even) {
|
139 |
+
background-color: #f9f9f9;
|
140 |
+
}
|
141 |
+
|
142 |
+
tbody tr:hover {
|
143 |
+
background-color: #f1f1f1;
|
144 |
+
}
|
145 |
+
|
146 |
+
.error {
|
147 |
+
color: #e74c3c;
|
148 |
+
font-weight: bold;
|
149 |
+
}
|
150 |
+
|
151 |
+
.button-container {
|
152 |
+
grid-column: 1 / -1; /* Span across all columns */
|
153 |
+
text-align: center;
|
154 |
+
margin-top: 20px;
|
155 |
+
}
|
156 |
+
|
157 |
+
/* History Section */
|
158 |
+
.history-container {
|
159 |
+
margin-top: 40px;
|
160 |
+
border-top: 1px solid #e0e0e0;
|
161 |
+
padding-top: 20px;
|
162 |
+
}
|
163 |
+
|
164 |
+
.history-container h2 {
|
165 |
+
display: flex;
|
166 |
+
justify-content: space-between;
|
167 |
+
align-items: center;
|
168 |
+
}
|
169 |
+
|
170 |
+
#history-list table {
|
171 |
+
margin-top: 10px;
|
172 |
+
}
|
173 |
+
|
174 |
+
.small-button {
|
175 |
+
padding: 4px 8px;
|
176 |
+
font-size: 0.8em;
|
177 |
+
background-color: #e74c3c;
|
178 |
+
}
|
179 |
+
|
180 |
+
.small-button:hover {
|
181 |
+
background-color: #c0392b;
|
182 |
+
}
|
183 |
+
|
184 |
+
.history-item-actions {
|
185 |
+
display: flex;
|
186 |
+
gap: 10px;
|
187 |
+
}
|
188 |
+
|
189 |
+
#output-container {
|
190 |
+
margin-top: 2em;
|
191 |
+
padding: 1.5em;
|
192 |
+
background-color: #f9f9f9;
|
193 |
+
border: 1px solid #ddd;
|
194 |
+
border-radius: 8px;
|
195 |
+
}
|
196 |
+
|
197 |
+
#results-wrapper h3, #history-wrapper h3 {
|
198 |
+
margin-top: 0;
|
199 |
+
border-bottom: 2px solid #eee;
|
200 |
+
padding-bottom: 0.5em;
|
201 |
+
margin-bottom: 1em;
|
202 |
+
}
|
203 |
+
|
204 |
+
#results-display table {
|
205 |
+
width: 100%;
|
206 |
+
border-collapse: collapse;
|
207 |
+
}
|
208 |
+
|
209 |
+
#results-display th, #results-display td {
|
210 |
+
padding: 8px 12px;
|
211 |
+
border: 1px solid #ddd;
|
212 |
+
text-align: left;
|
213 |
+
}
|
214 |
+
|
215 |
+
#results-display th {
|
216 |
+
background-color: #f2f2f2;
|
217 |
+
}
|
218 |
+
|
219 |
+
#history-table {
|
220 |
+
width: 100%;
|
221 |
+
border-collapse: collapse;
|
222 |
+
}
|
223 |
+
|
224 |
+
#history-table th, #history-table td {
|
225 |
+
padding: 8px 12px;
|
226 |
+
border: 1px solid #ddd;
|
227 |
+
text-align: left;
|
228 |
+
}
|
229 |
+
|
230 |
+
#history-table th {
|
231 |
+
background-color: #f2f2f2;
|
232 |
+
}
|
233 |
+
|
234 |
+
#history-table td:last-child {
|
235 |
+
text-align: right;
|
236 |
+
}
|
237 |
+
|
238 |
+
#raw-json-output {
|
239 |
+
background-color: #2d2d2d;
|
240 |
+
color: #f1f1f1;
|
241 |
+
padding: 1em;
|
242 |
+
border-radius: 5px;
|
243 |
+
max-height: 500px;
|
244 |
+
overflow-y: auto;
|
245 |
+
}
|
246 |
+
|
247 |
+
#clear-history {
|
248 |
+
background-color: #dc3545;
|
249 |
+
}
|
250 |
+
|
251 |
+
#clear-history:hover {
|
252 |
+
background-color: #c82333;
|
253 |
+
}
|
254 |
+
|
255 |
+
.error-message {
|
256 |
+
color: #dc3545;
|
257 |
+
background-color: #f8d7da;
|
258 |
+
border: 1px solid #f5c6cb;
|
259 |
+
padding: 0.75rem 1.25rem;
|
260 |
+
margin-top: 1rem;
|
261 |
+
margin-bottom: 1rem;
|
262 |
+
border-radius: 0.25rem;
|
263 |
+
text-align: center;
|
264 |
+
}
|
265 |
+
|
266 |
+
/* Responsive Design for smaller screens */
|
267 |
+
@media (max-width: 992px) {
|
268 |
+
.top-section {
|
269 |
+
flex-direction: column;
|
270 |
+
}
|
271 |
+
}
|
272 |
+
|
273 |
+
.history-detail-row td {
|
274 |
+
background-color: #333;
|
275 |
+
padding: 15px;
|
276 |
+
border-top: 2px solid #555;
|
277 |
+
text-align: left; /* Align content to the left */
|
278 |
+
}
|
279 |
+
|
280 |
+
.history-detail-row pre {
|
281 |
+
background-color: #1e1e1e;
|
282 |
+
color: #d4d4d4;
|
283 |
+
padding: 10px;
|
284 |
+
border-radius: 4px;
|
285 |
+
white-space: pre-wrap;
|
286 |
+
word-break: break-all;
|
287 |
+
}
|
288 |
+
|
289 |
+
.history-detail-row table {
|
290 |
+
width: 100%;
|
291 |
+
border-collapse: collapse;
|
292 |
+
margin: 0;
|
293 |
+
}
|
294 |
+
|
295 |
+
.history-detail-row table th {
|
296 |
+
background-color: #e0e0e0;
|
297 |
+
color: #333;
|
298 |
+
padding: 8px 12px;
|
299 |
+
border: 1px solid #555;
|
300 |
+
}
|
301 |
+
|
302 |
+
.history-detail-row table td {
|
303 |
+
color: #d4d4d4;
|
304 |
+
padding: 8px 12px;
|
305 |
+
border: 1px solid #555;
|
306 |
+
background-color: #2a2a2a;
|
307 |
+
}
|
308 |
+
|
309 |
+
.model-breakdown-view {
|
310 |
+
max-height: 400px; /* Or any other suitable height */
|
311 |
+
overflow-y: auto;
|
312 |
+
overflow-x: auto;
|
313 |
+
background-color: #2d2d2d;
|
314 |
+
color: #f1f1f1;
|
315 |
+
padding: 1em;
|
316 |
+
border-radius: 5px;
|
317 |
+
white-space: pre-wrap; /* Ensures the pre content wraps */
|
318 |
+
margin: 0;
|
319 |
+
font-family: monospace;
|
320 |
+
font-size: 0.85em;
|
321 |
+
}
|
322 |
+
|
323 |
+
.model-meta-info {
|
324 |
+
font-size: 0.9em;
|
325 |
+
color: #666;
|
326 |
+
margin-top: 4px;
|
327 |
+
}
|
328 |
+
|
329 |
+
.model-meta-info span {
|
330 |
+
margin-right: 15px;
|
331 |
+
}
|
332 |
+
|
333 |
+
.action-btn.raw-btn {
|
334 |
+
background-color: #555;
|
335 |
+
color: white;
|
336 |
+
}
|
337 |
+
|
338 |
+
.highlight-red {
|
339 |
+
color: #ff6b6b;
|
340 |
+
}
|
341 |
+
|
342 |
+
.ansi-red { color: #e74c3c; }
|
343 |
+
.ansi-green { color: #2ecc71; }
|
344 |
+
.ansi-yellow { color: #f1c40f; }
|
345 |
+
.ansi-blue { color: #3498db; }
|
346 |
+
.ansi-magenta { color: #9b59b6; }
|
347 |
+
.ansi-cyan { color: #1abc9c; }
|
348 |
+
|
349 |
+
.breakdown-row td {
|
350 |
+
text-align: left !important;
|
351 |
+
}
|
352 |
+
|
353 |
+
.footer {
|
354 |
+
margin-top: 2em;
|
355 |
+
font-size: 0.85em;
|
356 |
+
color: #555;
|
357 |
+
text-align: center;
|
358 |
+
}
|
359 |
+
|
360 |
+
.footer a {
|
361 |
+
color: #2a77d4;
|
362 |
+
text-decoration: none;
|
363 |
+
}
|
364 |
+
|
365 |
+
.footer a:hover {
|
366 |
+
text-decoration: underline;
|
367 |
+
}
|
368 |
+
|
369 |
+
.disclaimer {
|
370 |
+
margin-top: 0.5em;
|
371 |
+
font-style: italic;
|
372 |
+
}
|
373 |
+
|
374 |
+
.disclaimer-banner {
|
375 |
+
background-color: #fff3cd;
|
376 |
+
color: #856404;
|
377 |
+
border: 1px solid #ffeeba;
|
378 |
+
padding: 10px 15px;
|
379 |
+
border-radius: 4px;
|
380 |
+
margin: 15px 0;
|
381 |
+
font-weight: bold;
|
382 |
+
text-align: center;
|
383 |
+
}
|