Yan Bai committed
Commit 55e1701 · 1 Parent(s): 9eb3690
.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__/
Dockerfile ADDED
@@ -0,0 +1,23 @@
1
+ FROM whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
2
+
3
+ # Install extra dependencies (pip skips any that the base image already provides)
4
+ RUN pip install --no-cache-dir \
5
+ fastapi \
6
+ uvicorn[standard] \
7
+ mbridge \
8
+ termcolor \
9
+ ipdb
10
+ # Add Megatron-LM core_v0.12.2
11
+ RUN git clone -b core_v0.12.2 --depth 1 https://github.com/NVIDIA/Megatron-LM.git /opt/Megatron-LM
12
+
13
+ # Copy the code into the working directory
14
+ WORKDIR /app
15
+ COPY . /app
16
+
17
+ # HF Spaces injects the serving port via $PORT by default
18
+ ENV PYTHONPATH=/opt/Megatron-LM:$PYTHONPATH
19
+ ENV PORT=7860
20
+ EXPOSE 7860
21
+
22
+ # Start the FastAPI service
23
+ CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port $PORT"]
__init__.py ADDED
@@ -0,0 +1 @@
1
+
app.py ADDED
@@ -0,0 +1 @@
1
+ from webui.main import app
estimate.py ADDED
@@ -0,0 +1,499 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ """Memory estimator for GPT/MoE models, adapted from Megatron-LM's pretrain_gpt.py."""
3
+ import warnings
4
+
5
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
6
+ warnings.filterwarnings("ignore", category=FutureWarning)
7
+ warnings.filterwarnings("ignore")
8
+ import os
9
+ import torch
10
+ from functools import partial
11
+ from contextlib import nullcontext
12
+ import inspect
13
+
14
+ from typing import Union
15
+ from megatron.training import get_args
16
+ from megatron.training import print_rank_0
17
+ from megatron.training import get_timers
18
+ from megatron.training import get_tokenizer
19
+ from megatron.core import mpu
20
+ from megatron.core.enums import ModelType
21
+ from megatron.core.datasets.blended_megatron_dataset_builder import (
22
+ BlendedMegatronDatasetBuilder,
23
+ )
24
+ from megatron.core.datasets.utils import get_blend_from_list
25
+ from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
26
+ from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
27
+ import megatron.legacy.model
28
+ from megatron.training import pretrain
29
+ from megatron.core.utils import StragglerDetector
30
+ from megatron.core.transformer.spec_utils import import_module
31
+ from megatron.training.utils import (
32
+ get_batch_on_this_cp_rank,
33
+ get_batch_on_this_tp_rank,
34
+ )
35
+ from megatron.training.arguments import core_transformer_config_from_args
36
+ from megatron.training.yaml_arguments import core_transformer_config_from_yaml
37
+ from megatron.core.models.gpt.gpt_layer_specs import (
38
+ get_gpt_layer_local_spec,
39
+ get_gpt_layer_with_transformer_engine_spec,
40
+ )
41
+ from megatron.training.initialize import initialize_megatron
42
+ from moe_mem_estimator.gpt_model import GPTModel
43
+ from moe_mem_estimator.base import (
44
+ is_pipeline_first_stage,
45
+ is_pipeline_last_stage,
46
+ set_global_config,
47
+ set_pipeline_model_parallel_rank,
48
+ )
49
+ from moe_mem_estimator.layers import MLASelfAttention, MoELayer
50
+
51
+
52
+ def _calculate_rank_memory(config, args, input_shape, pp_rank=0, pp_size=1):
53
+ """
54
+ Calculates the memory for a single pipeline parallel rank, containing the detailed logic.
55
+ """
56
+ # Build the model for the current rank
57
+ set_global_config(config)
58
+ pre_process = (pp_rank == 0)
59
+ post_process = (pp_rank == pp_size - 1)
60
+
61
+ use_te = True
62
+ if hasattr(config, 'spec') and config.spec is not None:
63
+ transformer_layer_spec = import_module(config.spec)
64
+ else:
65
+ if use_te:
66
+ transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
67
+ config.num_moe_experts, config.moe_grouped_gemm, config.qk_layernorm,
68
+ config.multi_latent_attention, config.fp8
69
+ )
70
+ else:
71
+ transformer_layer_spec = get_gpt_layer_local_spec(
72
+ config.num_moe_experts, config.moe_grouped_gemm, config.qk_layernorm,
73
+ config.multi_latent_attention
74
+ )
75
+
76
+ model = GPTModel(
77
+ config=config,
78
+ transformer_layer_spec=transformer_layer_spec,
79
+ vocab_size=args.padded_vocab_size,
80
+ max_sequence_length=args.max_position_embeddings,
81
+ pre_process=pre_process,
82
+ post_process=post_process,
83
+ fp16_lm_cross_entropy=getattr(config, 'fp16_lm_cross_entropy', False),
84
+ parallel_output=True,
85
+ share_embeddings_and_output_weights=args.tie_word_embeddings,
86
+ position_embedding_type="rope",
87
+ rotary_percent=getattr(args, 'rotary_percent', 1.0),
88
+ rotary_base=getattr(args, 'rotary_base', 10000),
89
+ rope_scaling=getattr(config, 'use_rope_scaling', False),
90
+ )
91
+
92
+ # --- Start of detailed memory calculation logic ---
93
+ num_parameter_this_shard = model.num_parameter()
94
+ num_activation = model.num_activation(input_shape)
95
+ output_shape = model.mock_forward(input_shape)
96
+
97
+ num_parameter_this_shard_sparse = sum(
98
+ layer.mlp.num_parameter() for layer in model.decoder.layers.modules
99
+ if isinstance(layer.mlp, MoELayer)
100
+ )
101
+ num_activation_this_shard_mlp = sum(
102
+ m.mlp.num_activation() for m in model.decoder.layers.modules
103
+ )
104
+
105
+ num_microbatch_this_pp_rank = pp_size - pp_rank
106
+ if config.num_layers_per_virtual_pipeline_stage is not None:
107
+ layers_this_pprank = len(model.decoder.layers.modules)
108
+ vpp_size = layers_this_pprank // config.num_layers_per_virtual_pipeline_stage
109
+ if vpp_size > 0:
110
+ num_microbatch_this_pp_rank = (pp_size * (vpp_size - 1) + (pp_size - pp_rank) * 2 - 1) / vpp_size
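+ # The expression above approximates the number of in-flight microbatches per
+ # virtual stage under the interleaved 1F1B schedule; it is used only for
+ # peak-activation accounting, not as an exact scheduler trace.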
111
+
112
+ # Activation Recomputation
113
+ # The base activation number is for one microbatch. With pipeline parallelism,
114
+ # the total activation is multiplied by the number of microbatches in flight.
115
+ # Recomputation reduces this by re-calculating activations during the backward pass
116
+ # instead of storing them.
117
+
118
+ # This is the activation memory without any recomputation.
119
+ num_activation = (num_activation - model.num_act_post) * num_microbatch_this_pp_rank + model.num_act_post
120
+
121
+ if config.recompute_granularity == "full":
122
+ # This logic is transplanted from the more detailed `report_memory_usage_one_pp_rank`
123
+ recompute_num_layers = config.recompute_num_layers
124
+ num_layers = model.num_layers
125
+ # Activations of a model with recompute enabled.
126
+ # The activation of a layer is an input to the next layer.
127
+ # So, the total activation is the sum of the activations of all layers,
128
+ # plus the activation of the embedding layer.
129
+ # The activation of a layer is stored only if it is not recomputed.
130
+ common_act = (
131
+ model.num_act_pre
132
+ + model.num_act_between_layers * num_layers * num_microbatch_this_pp_rank
133
+ )
134
+ if config.recompute_method == "block":
135
+ num_layers_with_loss = num_layers - recompute_num_layers
136
+ if num_layers_with_loss == 0:
137
+ peak1 = common_act + model.num_act_post
138
+ peak2 = common_act + model.num_act_per_layer
139
+ recomputed_activation = max(peak1, peak2)
140
+ else:
141
+ recomputed_activation = (
142
+ common_act
143
+ + model.num_act_post
144
+ + model.num_act_per_layer
145
+ * num_layers_with_loss
146
+ * num_microbatch_this_pp_rank
147
+ )
148
+ elif config.recompute_method == "uniform":
149
+ peak1 = common_act + model.num_act_post
150
+ peak2 = (
151
+ common_act
152
+ + model.num_act_per_layer
153
+ * recompute_num_layers
154
+ * num_microbatch_this_pp_rank
155
+ )
156
+ recomputed_activation = max(peak1, peak2)
157
+
158
+ if isinstance(model.decoder.layers.modules[0].self_attention, MLASelfAttention):
159
+ recomputed_activation += model.decoder.layers.modules[0].self_attention.core_attention.num_activation()
160
+
161
+ num_activation = recomputed_activation
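+ # In the branches above, peak1 is the moment the last stage holds the
+ # logits/loss activations on top of the retained per-layer inputs, while
+ # peak2 is the moment a checkpointed layer is re-run during backward; the
+ # larger of the two is taken as the activation peak.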
162
+
163
+ elif config.recompute_granularity == "selective":
164
+ # Selective recomputation is the default in Megatron-LM and is handled
165
+ # by Transformer Engine. The base `num_activation` calculation from `GPTModel`
166
+ # already reflects this. We just need to scale it by the number of in-flight microbatches.
167
+ # This is already the case, so we do nothing here.
168
+ pass
169
+
170
+
171
+ # Context Parallelism
172
+ if config.context_parallel_size > 1:
173
+ num_activation = (num_activation - num_activation_this_shard_mlp) / config.context_parallel_size + num_activation_this_shard_mlp
174
+
175
+ # Calculate bytes per parameter for optimizer states
176
+ if args.use_distributed_optimizer:
177
+ base_optim_bytes = 6 # bf16/fp16 weight (2 B) + fp32 gradient (4 B), kept on every rank
178
+ world_optim_bytes = 12 # fp32 master weight + Adam momentum + variance, sharded across DP ranks
179
+ else:
180
+ base_optim_bytes = 18 # All states on each GPU
181
+ world_optim_bytes = 0
182
+
183
+ num_bytes_per_parameter = base_optim_bytes + (world_optim_bytes / (args.data_parallel_size * config.context_parallel_size))
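+ # Illustrative example (assumed values, not read from the config): with the
+ # distributed optimizer, data_parallel_size=8 and context_parallel_size=1 give
+ # 6 + 12 / 8 = 7.5 bytes per dense parameter.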
184
+
185
+ # Handle MoE optimizer state sharding if applicable
186
+ if num_parameter_this_shard_sparse > 0 and config.expert_model_parallel_size > 1:
187
+ moe_dp_size = args.data_parallel_size * config.tensor_model_parallel_size // (config.expert_model_parallel_size * args.expert_tensor_parallel_size)
188
+ num_bytes_per_parameter_moe = base_optim_bytes + (world_optim_bytes / moe_dp_size)
189
+
190
+ weight_and_optimizer_memory = (
191
+ (num_parameter_this_shard - num_parameter_this_shard_sparse) * num_bytes_per_parameter +
192
+ num_parameter_this_shard_sparse * num_bytes_per_parameter_moe
193
+ ) / NUM_BYTES_IN_GIGABYTE
194
+ else:
195
+ weight_and_optimizer_memory = (num_parameter_this_shard * num_bytes_per_parameter) / NUM_BYTES_IN_GIGABYTE
196
+
197
+ activation_memory = num_activation * 2 / NUM_BYTES_IN_GIGABYTE # 2 bytes per element (fp16/bf16 activations)
198
+ total_memory = weight_and_optimizer_memory + activation_memory
199
+
200
+ report = {
201
+ "pp_rank": pp_rank,
202
+ "parameters_b": num_parameter_this_shard / 1e9,
203
+ "activation_b": num_activation / 1e9, # activation element count, in billions
204
+ "weight_optimizer_gb": round(weight_and_optimizer_memory, 2),
205
+ "activation_gb": round(activation_memory, 2),
206
+ "total_gb": round(total_memory, 2),
207
+ "details": model.dump(),
208
+ "model_breakdown": str(model)
209
+ }
210
+ print(model)
211
+
212
+ return report, output_shape
213
+
214
+
215
+ def estimate_from_config(config, args):
216
+ """
217
+ Estimate memory usage from a given config and args, instead of global state.
218
+ This version iterates over pipeline parallel ranks for accurate estimation.
219
+ """
220
+ reports = []
221
+ input_shape = [args.micro_batch_size, args.seq_length]
222
+ pp_size = config.pipeline_model_parallel_size
223
+
224
+ if pp_size > 1:
225
+ for pp_rank in range(pp_size):
226
+ set_pipeline_model_parallel_rank(pp_rank)
227
+ report_for_rank, new_input_shape = _calculate_rank_memory(config, args, input_shape, pp_rank, pp_size)
228
+ reports.append(report_for_rank)
229
+ input_shape = new_input_shape # Pass output shape to the next stage
230
+ else:
231
+ report_for_rank, _ = _calculate_rank_memory(config, args, input_shape, 0, 1)
232
+ reports.append(report_for_rank)
233
+
234
+ return reports
235
+
236
+
237
+ def model_provider() -> GPTModel:
238
+ args = get_args()
239
+ use_te = args.transformer_impl == "transformer_engine"
240
+
241
+ # Experimental loading arguments from yaml
242
+ if args.yaml_cfg is not None:
243
+ config = core_transformer_config_from_yaml(args, "language_model")
244
+ else:
245
+ config = core_transformer_config_from_args(args)
246
+ assert not args.use_legacy_models
247
+
248
+ if args.spec is not None:
249
+ transformer_layer_spec = import_module(args.spec)
250
+ else:
251
+ if use_te:
252
+ transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
253
+ args.num_experts,
254
+ args.moe_grouped_gemm,
255
+ args.qk_layernorm,
256
+ args.multi_latent_attention,
257
+ args.fp8,
258
+ )
259
+ else:
260
+ transformer_layer_spec = get_gpt_layer_local_spec(
261
+ args.num_experts,
262
+ args.moe_grouped_gemm,
263
+ args.qk_layernorm,
264
+ args.multi_latent_attention,
265
+ )
266
+ set_global_config(config)
267
+ pre_process = is_pipeline_first_stage()
268
+ post_process = is_pipeline_last_stage()
269
+ # TODO fp8
270
+ model = GPTModel(
271
+ config=config,
272
+ transformer_layer_spec=transformer_layer_spec,
273
+ vocab_size=args.padded_vocab_size,
274
+ max_sequence_length=args.max_position_embeddings,
275
+ pre_process=pre_process,
276
+ post_process=post_process,
277
+ fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
278
+ parallel_output=True,
279
+ share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
280
+ position_embedding_type=args.position_embedding_type,
281
+ rotary_percent=args.rotary_percent,
282
+ rotary_base=args.rotary_base,
283
+ rope_scaling=args.use_rope_scaling,
284
+ )
285
+
286
+ return model
287
+
288
+
289
+ NUM_BYTES_IN_MEGABYTE = 1024 * 1024
290
+ NUM_BYTES_IN_GIGABYTE = 1024 * 1024 * 1024
291
+
292
+ def report_memory_usage():
293
+ args = get_args()
294
+ if args.yaml_cfg is not None:
295
+ config = core_transformer_config_from_yaml(args, "language_model")
296
+ else:
297
+ config = core_transformer_config_from_args(args)
298
+
299
+ input_shape = [args.micro_batch_size, args.seq_length]
300
+
301
+ if config.pipeline_model_parallel_size > 1:
302
+ for pp_rank in range(config.pipeline_model_parallel_size):
303
+ set_pipeline_model_parallel_rank(pp_rank)
304
+ print(f"\n----------[Pipeline_Parallelism_Rank={pp_rank}]----------")
305
+ input_shape = report_memory_usage_one_pp_rank(
306
+ input_shape, pp_rank, config.pipeline_model_parallel_size
307
+ )
308
+ else:
309
+ report_memory_usage_one_pp_rank(input_shape)
310
+
311
+
312
+ def report_memory_usage_one_pp_rank(
313
+ input_shape: list[int], pp_rank=0, pp_size=1
314
+ ) -> list[int]:
315
+ args = get_args()
316
+
317
+ print(f"{input_shape=}")
318
+ model: GPTModel = model_provider()
319
+ num_parameter_this_shard = model.num_parameter()
320
+ num_activation = model.num_activation(input_shape)
321
+ output_shape = model.mock_forward(input_shape)
322
+
323
+ num_parameter_this_shard_sparse = 0
324
+ for layer in model.decoder.layers.modules:
325
+ if isinstance(layer.mlp, MoELayer):
326
+ num_parameter_this_shard_sparse += layer.mlp.num_parameter()
327
+ if (
328
+ "shared_experts" in layer.mlp.__dir__()
329
+ and layer.mlp.shared_experts is not None
330
+ ):
331
+ num_parameter_this_shard_sparse -= (
332
+ layer.mlp.shared_experts.num_parameter()
333
+ )
334
+ num_activation_this_shard_mlp = sum(
335
+ [m.mlp.num_activation() for m in model.decoder.layers.modules]
336
+ )
337
+ num_microbatch_this_pp_rank = pp_size - pp_rank
338
+ # vpp
339
+ if args.num_layers_per_virtual_pipeline_stage is not None:
340
+ layers_this_pprank = model.decoder.layers.modules.__len__()
341
+ vpp_size = layers_this_pprank // args.num_layers_per_virtual_pipeline_stage
342
+ num_microbatch_this_pp_rank = (
343
+ pp_size * (vpp_size - 1) + (pp_size - pp_rank) * 2 - 1
344
+ ) / vpp_size
345
+
365
+ model.__repr__()
366
+ print(model)
367
+ print(
368
+ f"Number of parameters in every GPU in billions: "
369
+ f"{num_parameter_this_shard / 10**9: .2f} where mlp part is {num_parameter_this_shard_sparse / 10**9: .2f}"
370
+ )
371
+ # recompute
372
+ if args.recompute_granularity == "full":
373
+ recompute_num_layers = args.recompute_num_layers
374
+ num_layers = model.num_layers
375
+ common_act = (
376
+ model.num_act_pre
377
+ + model.num_act_between_layers * num_layers * num_microbatch_this_pp_rank
378
+ ) # recompute with pipeline parallel
379
+ info = (
380
+ "With this recompute setting, activation memory peaks when "
381
+ )
382
+ if args.recompute_method == "block":
383
+ num_layers_with_loss = num_layers - recompute_num_layers
384
+ if num_layers_with_loss == 0:
385
+ peak1 = common_act + model.num_act_post
386
+ peak2 = common_act + model.num_act_per_layer
387
+ if peak1 > peak2:
388
+ info += "calculating loss"
389
+ else:
390
+ info += "back-propagating loss"
391
+ num_activation = max(peak1, peak2)
392
+ else:
393
+ info += (
394
+ f"calculating loss with {num_layers_with_loss} non-recompute layers"
395
+ )
396
+ num_activation = (
397
+ common_act
398
+ + model.num_act_post
399
+ + model.num_act_per_layer
400
+ * num_layers_with_loss
401
+ * num_microbatch_this_pp_rank
402
+ )
403
+ elif args.recompute_method == "uniform":
404
+ peak1 = common_act + model.num_act_post
405
+ peak2 = (
406
+ common_act
407
+ + model.num_act_per_layer
408
+ * recompute_num_layers
409
+ * num_microbatch_this_pp_rank
410
+ )
411
+ if peak1 > peak2:
412
+ info += "calculating loss"
413
+ else:
414
+ info += f"back-propagating loss, recomputing every {recompute_num_layers} layers"
415
+ num_activation = max(peak1, peak2)
416
+ if isinstance(
417
+ model.decoder.layers.modules[0].self_attention, MLASelfAttention
418
+ ): # MLA recompute reaches its peak during backward
419
+ num_activation += model.decoder.layers.modules[
420
+ 0
421
+ ].self_attention.core_attention.num_activation()
422
+ print(info)
423
+
424
+ else:
425
+ num_activation = (
426
+ num_activation - model.num_act_post
427
+ ) * num_microbatch_this_pp_rank + model.num_act_post
428
+
429
+ # CP
430
+ num_activation = (
431
+ num_activation - num_activation_this_shard_mlp
432
+ ) / args.context_parallel_size + num_activation_this_shard_mlp
433
+ if pp_size == 1:
434
+ print(
435
+ f"Number of activation in every GPU in billions: "
436
+ f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
437
+ )
438
+ else:
439
+ print(
440
+ f"Number of activation per microbatch in every GPU in billions: "
441
+ f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
442
+ f", {num_microbatch_this_pp_rank=}"
443
+ )
444
+ num_bytes_per_parameter = (
445
+ 18
446
+ if not args.use_distributed_optimizer
447
+ else 6 + (12 / args.data_parallel_size / args.context_parallel_size)
448
+ )
449
+ if args.expert_model_parallel_size * args.expert_tensor_parallel_size > 1:
450
+ num_bytes_per_parameter_dense = num_bytes_per_parameter
451
+ num_bytes_per_parameter_moe = (
452
+ 18
453
+ if not args.use_distributed_optimizer
454
+ else 6
455
+ + (
456
+ 12
457
+ / (
458
+ args.data_parallel_size
459
+ * args.context_parallel_size
460
+ * args.tensor_model_parallel_size
461
+ / args.expert_model_parallel_size
462
+ / args.expert_tensor_parallel_size
463
+ )
464
+ )
465
+ )
466
+ print(f"{num_bytes_per_parameter_dense=} {num_bytes_per_parameter_moe=}")
467
+
468
+ weight_and_optimizer_memory = (
469
+ (num_parameter_this_shard - num_parameter_this_shard_sparse)
470
+ * num_bytes_per_parameter_dense
471
+ + num_parameter_this_shard_sparse * num_bytes_per_parameter_moe
472
+ ) / NUM_BYTES_IN_GIGABYTE
473
+ else:
474
+ print(f"{num_bytes_per_parameter=}")
475
+ weight_and_optimizer_memory = (
476
+ num_parameter_this_shard * num_bytes_per_parameter / NUM_BYTES_IN_GIGABYTE
477
+ )
478
+
479
+ activation_memory = num_activation * 2 / NUM_BYTES_IN_GIGABYTE # only support fp16
480
+ total_memory = weight_and_optimizer_memory + activation_memory
481
+ print(
482
+ f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory/1024:.2f} GB, "
483
+ f"activation={activation_memory/1024:.2f} GB, total={total_memory/1024:.2f} GB\n"
484
+ )
485
+
486
+ # import ipdb
487
+
488
+ # ipdb.set_trace()
489
+ return output_shape
491
+
492
+
493
+ if __name__ == "__main__":
494
+ initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True)
495
+
496
+ import ipdb
497
+
498
+ with ipdb.launch_ipdb_on_exception():
499
+ report_memory_usage()
moe_mem_estimator/__init__.py ADDED
File without changes
moe_mem_estimator/base.py ADDED
@@ -0,0 +1,211 @@
1
+ from abc import ABC
2
+
3
+ from megatron.core.transformer.transformer_config import TransformerConfig
4
+ from torch.nn.modules.module import _addindent
5
+ from termcolor import colored
6
+
7
+
8
+ def prehook_save_input_shape(func):
9
+ def wrapper(self, *input_shapes, **kw_input_shapes):
10
+ if len(input_shapes) + len(kw_input_shapes) == 0:
11
+ if "_input_shape" in self.__dict__:
12
+ return func(self, *self._input_shape, **self._kw_input_shapes)
13
+ else:
14
+ return 0
15
+ self._input_shape = input_shapes
16
+ self._kw_input_shapes = kw_input_shapes
17
+ return func(self, *self._input_shape, **self._kw_input_shapes)
18
+
19
+ return wrapper
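+ # The wrapper above caches the most recent positional/keyword input shapes on
+ # the module, so a later num_activation() call with no arguments reuses the
+ # cached shapes, and returns 0 if the module has never been given an input.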
20
+
21
+
22
+ class MetaBase(type):
23
+ def __new__(cls, name, bases, attrs):
24
+ if "num_activation" in attrs:
25
+ attrs["num_activation"] = prehook_save_input_shape(attrs["num_activation"])
26
+
27
+ return super().__new__(cls, name, bases, attrs)
28
+
29
+
30
+ class MemEstimator(metaclass=MetaBase):
31
+ def __init__(self, *args, **kwargs):
32
+ self._modules = {}
33
+ pass
34
+
35
+ def __repr__(self):
36
+ # We treat the extra repr like the sub-module, one item per line
37
+ extra_lines = []
38
+ # extra_repr = self.extra_repr()
39
+ # # empty string will be split into list ['']
40
+ # if extra_repr:
41
+ # extra_lines = extra_repr.split("\n")
42
+ child_lines = []
43
+ for key, module in self._modules.items():
44
+ mod_str = repr(module)
45
+ mod_str = _addindent(mod_str, 2)
46
+ child_lines.append("(" + key + "): " + mod_str)
47
+ lines = extra_lines + child_lines
48
+
49
+ stat = (
50
+ "\t/* n_params="
51
+ + colored(f"{self.num_parameter()/1024/1024:.2f}M", "red")
52
+ + "\tn_act="
53
+ + colored(f"{self.num_activation()/1024/1024:.2f}M", "green")
54
+ + " */"
55
+ )
56
+ main_str = self._get_name() + stat + " ("
57
+ if lines:
58
+ # simple one-liner info, which most builtin Modules will use
59
+ if len(extra_lines) == 1 and not child_lines:
60
+ main_str += extra_lines[0]
61
+ else:
62
+ main_str += "\n " + "\n ".join(lines) + "\n"
63
+
64
+ main_str += ")"
65
+ return main_str
66
+ return f"{self.__class__.__name__} n_param={self.num_parameter()}"
67
+
68
+ def dump(self):
69
+ ret = {}
70
+ ret['name'] = self._get_name()
71
+ ret['n_params'] = self.num_parameter()
72
+ ret['n_act'] = self.num_activation()
73
+ modules = {}
74
+ for key, module in self._modules.items():
75
+ modules[key] = module.dump()
76
+ if len(modules)>0:
77
+ ret['modules'] = modules
78
+ return ret
79
+
80
+
81
+ def _get_name(self):
82
+ return self.__class__.__name__
83
+
84
+ def num_parameter(self):
85
+ """
86
+ Calculate the number of model parameters.
87
+ """
88
+ raise NotImplementedError
89
+
90
+ def num_activation(self, input_shape: list[int]):
91
+ """
92
+ Calculate the number of activation elements for the given input_shape.
93
+ Args:
94
+ input shape
95
+ """
96
+ raise NotImplementedError
97
+
98
+ def mock_forward(self, input_shape: list[int]):
99
+ """
100
+ Mock the forward.
101
+ Args:
102
+ input shape
103
+ return:
104
+ output shape
105
+ """
106
+ raise NotImplementedError
107
+
108
+ def __setattr__(self, name: str, value) -> None:
109
+ if isinstance(value, MemEstimator):
110
+ modules = self.__dict__.get("_modules")
111
+ modules[name] = value
112
+ else:
113
+ pass
114
+ return super().__setattr__(name, value)
115
+
116
+ def __delattr__(self, name):
117
+ modules = self.__dict__.get("_modules")
118
+ if name in modules:
119
+ del modules[name]
120
+ return super().__delattr__(name)
121
+
122
+
123
+ _global_config: TransformerConfig = None
124
+
125
+
126
+ def set_global_config(cfg):
127
+ global _global_config
128
+ _global_config = cfg
129
+
130
+
131
+ def get_tensor_model_parallel_world_size():
132
+ global _global_config
133
+ return _global_config.tensor_model_parallel_size
134
+
135
+
136
+ def get_tensor_model_parallel_rank():
137
+ return 0
138
+
139
+
140
+ def get_expert_tensor_parallel_world_size():
141
+ global _global_config
142
+ return _global_config.expert_tensor_parallel_size
143
+
144
+
145
+ def get_expert_tensor_parallel_rank():
146
+ return 0
147
+
148
+
149
+ _pp_rank = 0
150
+
151
+
152
+ def set_pipeline_model_parallel_rank(rank):
153
+ global _pp_rank
154
+ _pp_rank = rank
155
+
156
+
157
+ def get_pipeline_model_parallel_rank():
158
+ global _pp_rank
159
+ return _pp_rank
160
+
161
+
162
+ def get_virtual_pipeline_model_parallel_rank():
163
+ return 0
164
+
165
+
166
+ def get_pipeline_model_parallel_world_size():
167
+ global _global_config
168
+ return _global_config.pipeline_model_parallel_size
169
+
170
+
171
+ def get_expert_model_parallel_rank():
172
+ return 0
173
+
174
+
175
+ def get_expert_model_parallel_world_size():
176
+ global _global_config
177
+ return _global_config.expert_model_parallel_size
178
+
179
+
180
+ def get_virtual_pipeline_model_parallel_world_size():
181
+ global _global_config
182
+ return _global_config.virtual_pipeline_model_parallel_size
183
+
184
+
185
+ def is_pipeline_first_stage(ignore_virtual=False):
186
+ """Return True if in the first pipeline model-parallel stage, False otherwise."""
187
+ if not ignore_virtual:
188
+ if (
189
+ get_virtual_pipeline_model_parallel_world_size() is not None
190
+ and get_virtual_pipeline_model_parallel_rank() != 0
191
+ ):
192
+ return False
193
+ return get_pipeline_model_parallel_rank() == 0
194
+
195
+
196
+ def is_pipeline_last_stage(ignore_virtual=False):
197
+ """Return True if in the last pipeline-model-parallel stage, False otherwise."""
198
+ return get_pipeline_model_parallel_rank() == (
199
+ get_pipeline_model_parallel_world_size() - 1
200
+ )
201
+
202
+
203
+ def cum_mul(l: list):
204
+ try:
205
+ ret = 1
206
+ for one in l:
207
+ ret *= one
208
+ return ret
209
+ except:
210
+ return 0
moe_mem_estimator/gpt_model.py ADDED
@@ -0,0 +1,151 @@
1
+ from .base import (
2
+ MemEstimator,
3
+ set_global_config,
4
+ get_tensor_model_parallel_world_size,
5
+ get_tensor_model_parallel_rank,
6
+ cum_mul,
7
+ )
8
+
9
+ from megatron.core.transformer.spec_utils import ModuleSpec
10
+ from typing import Dict, Literal, Optional, Union
11
+ from megatron.core.transformer.transformer_config import TransformerConfig
12
+ from megatron.core.model_parallel_config import ModelParallelConfig
13
+ from megatron.core.tensor_parallel.utils import VocabUtility
14
+ from megatron.core.transformer.transformer_block import (
15
+ TransformerBlockSubmodules,
16
+ _get_block_submodules,
17
+ )
18
+ from megatron.core.transformer.enums import ModelType
19
+ from .layers import LanguageModelEmbedding, TransformerBlock, ColumnParallelLinear
20
+
21
+
22
+ class GPTModel(MemEstimator):
23
+ def __init__(
24
+ self,
25
+ config: TransformerConfig,
26
+ transformer_layer_spec: ModuleSpec,
27
+ vocab_size: int,
28
+ max_sequence_length: int,
29
+ pre_process: bool = True,
30
+ post_process: bool = True,
31
+ fp16_lm_cross_entropy: bool = False,
32
+ parallel_output: bool = True,
33
+ share_embeddings_and_output_weights: bool = False,
34
+ position_embedding_type: Literal[
35
+ "learned_absolute", "rope", "none"
36
+ ] = "learned_absolute",
37
+ rotary_percent: float = 1.0,
38
+ rotary_base: int = 10000,
39
+ rope_scaling: bool = False,
40
+ seq_len_interpolation_factor: Optional[float] = None,
41
+ ):
42
+ super().__init__()
43
+
44
+ self.config = config
45
+ config.use_cpu_initialization = True
46
+
47
+ self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
48
+ self.vocab_size = vocab_size
49
+ self.max_sequence_length = max_sequence_length
50
+ self.pre_process = pre_process
51
+ self.post_process = post_process
52
+ self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
53
+ self.parallel_output = parallel_output
54
+ self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
55
+ self.position_embedding_type = position_embedding_type
56
+
57
+ # megatron core pipelining currently depends on model type
58
+ # TODO: remove this dependency ?
59
+ self.model_type = ModelType.encoder_or_decoder
60
+
61
+ # These 4 attributes are needed for TensorRT-LLM export.
62
+ self.max_position_embeddings = max_sequence_length
63
+ self.rotary_percent = rotary_percent
64
+ self.rotary_base = rotary_base
65
+ self.rotary_scaling = rope_scaling
66
+
67
+ if self.pre_process:
68
+ self.embedding = LanguageModelEmbedding(
69
+ config=self.config,
70
+ vocab_size=self.vocab_size,
71
+ max_sequence_length=self.max_sequence_length,
72
+ position_embedding_type=position_embedding_type,
73
+ )
74
+
75
+ # remove RotaryEmbedding
76
+
77
+ # Transformer.
78
+ self.decoder = TransformerBlock(
79
+ config=self.config,
80
+ spec=transformer_layer_spec,
81
+ pre_process=self.pre_process,
82
+ post_process=self.post_process,
83
+ )
84
+
85
+ # Output
86
+ if post_process:
87
+ if self.config.defer_embedding_wgrad_compute:
88
+ self.embedding_activation_buffer = []
89
+ self.grad_output_buffer = []
90
+ else:
91
+ self.embedding_activation_buffer = None
92
+ self.grad_output_buffer = None
93
+
94
+ self.output_layer = ColumnParallelLinear(
95
+ config.hidden_size,
96
+ self.vocab_size,
97
+ config=config,
98
+ init_method=config.init_method,
99
+ bias=False,
100
+ skip_bias_add=False,
101
+ gather_output=not self.parallel_output,
102
+ skip_weight_param_allocation=self.pre_process
103
+ and self.share_embeddings_and_output_weights,
104
+ embedding_activation_buffer=self.embedding_activation_buffer,
105
+ grad_output_buffer=self.grad_output_buffer,
106
+ )
107
+
108
+ def num_parameter(self):
109
+ ret = 0
110
+ if self.pre_process:
111
+ ret += self.embedding.num_parameter()
112
+ ret += self.decoder.num_parameter()
113
+ if self.post_process:
114
+ ret += self.output_layer.num_parameter()
115
+ return ret
116
+
117
+ def num_activation(self, input_shape: list[int]):
118
+ self._inited = True
119
+ ret = 0
120
+
121
+ self.num_act_pre = 0
122
+ self.num_act_post = 0
123
+ self.num_act_per_layer = 0
124
+ self.num_act_between_layers = 0
125
+ self.num_layers = len(self.decoder.layers.modules)
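+ # Bookkeeping consumed by the recompute logic in estimate.py: activations
+ # created before the decoder (num_act_pre, embedding), after it (num_act_post,
+ # logits + fp32 softmax), per transformer layer (num_act_per_layer), and the
+ # always-kept hidden-state tensor passed between layers (num_act_between_layers).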
126
+
127
+ if self.pre_process:
128
+ self.num_act_pre = self.embedding.num_activation(input_shape)
129
+ ret += self.num_act_pre
130
+ input_shape = self.embedding.mock_forward(input_shape)
131
+ ret += self.decoder.num_activation(input_shape)
132
+ self.num_act_per_layer = self.decoder.layers.modules[0].num_activation()
133
+ input_shape = self.decoder.mock_forward(input_shape)
134
+ self.num_act_between_layers = cum_mul(input_shape)
135
+
136
+ if self.post_process:
137
+ self.num_act_post = self.output_layer.num_activation(input_shape)
138
+ softmax_activation = (
139
+ self.output_layer.num_activation(input_shape) * 2
140
+ ) # softmax/cross-entropy runs in fp32, so it stores twice the fp16 logits
141
+ self.num_act_post += softmax_activation
142
+ ret += self.num_act_post
143
+ return ret
144
+
145
+ def mock_forward(self, input_shape: list[int]):
146
+ if self.pre_process:
147
+ input_shape = self.embedding.mock_forward(input_shape)
148
+ input_shape = self.decoder.mock_forward(input_shape)
149
+ if self.post_process:
150
+ input_shape = self.output_layer.mock_forward(input_shape)
151
+ return input_shape
moe_mem_estimator/layers.py ADDED
@@ -0,0 +1,1813 @@
1
+ from .base import (
2
+ MemEstimator,
3
+ set_global_config,
4
+ get_tensor_model_parallel_world_size,
5
+ get_tensor_model_parallel_rank,
6
+ cum_mul,
7
+ get_expert_tensor_parallel_world_size,
8
+ get_expert_tensor_parallel_rank,
9
+ get_pipeline_model_parallel_world_size,
10
+ get_pipeline_model_parallel_rank,
11
+ get_expert_model_parallel_rank,
12
+ get_expert_model_parallel_world_size,
13
+ is_pipeline_first_stage,
14
+ is_pipeline_last_stage,
15
+ _addindent,
16
+ colored,
17
+ )
18
+
19
+ from megatron.core.transformer.spec_utils import ModuleSpec
20
+ from typing import Dict, Literal, Optional, Union
21
+ from megatron.core.transformer.transformer_config import (
22
+ TransformerConfig,
23
+ MLATransformerConfig,
24
+ )
25
+ from megatron.core.model_parallel_config import ModelParallelConfig
26
+ from megatron.core.tensor_parallel.utils import VocabUtility
27
+ from megatron.core.transformer.transformer_block import (
28
+ TransformerBlockSubmodules,
29
+ )
30
+ from megatron.core.models.common.embeddings import (
31
+ _yarn_get_mscale,
32
+ apply_rotary_pos_emb,
33
+ )
34
+ from megatron.core.extensions.transformer_engine import (
35
+ _get_extra_te_kwargs,
36
+ get_expert_parallel_rng_tracker_name,
37
+ condition_init_method,
38
+ )
39
+ from megatron.core.transformer.enums import AttnMaskType
40
+ from megatron.core.transformer.mlp import MLPSubmodules
41
+ from megatron.core.utils import divide
42
+ from megatron.core.transformer.spec_utils import import_module
43
+ from megatron.core.transformer import transformer_layer
44
+ import types, math
45
+ import warnings
46
+ from copy import deepcopy
47
+
48
+
49
+ class LanguageModelEmbedding(MemEstimator):
50
+ def __init__(
51
+ self,
52
+ config: TransformerConfig,
53
+ vocab_size: int,
54
+ max_sequence_length: int,
55
+ position_embedding_type: Literal[
56
+ "learned_absolute", "rope", "none"
57
+ ] = "learned_absolute",
58
+ num_tokentypes: int = 0,
59
+ ):
60
+ super().__init__()
61
+
62
+ self.config: TransformerConfig = config
63
+ self.vocab_size: int = vocab_size
64
+ self.max_sequence_length: int = max_sequence_length
65
+ self.add_position_embedding: bool = (
66
+ position_embedding_type == "learned_absolute"
67
+ )
68
+ self.num_tokentypes = num_tokentypes
69
+ self.reduce_scatter_embeddings = (
70
+ (not self.add_position_embedding)
71
+ and self.num_tokentypes <= 0
72
+ and self.config.sequence_parallel
73
+ )
74
+ # Word embeddings (parallel).
75
+ self.word_embeddings = VocabParallelEmbedding(
76
+ num_embeddings=self.vocab_size,
77
+ embedding_dim=self.config.hidden_size,
78
+ init_method=self.config.init_method,
79
+ reduce_scatter_embeddings=self.reduce_scatter_embeddings,
80
+ config=self.config,
81
+ )
82
+
83
+ # TODO if self.add_position_embedding:
84
+
85
+ # TODO if self.num_tokentypes > 0:
86
+
87
+ self.embedding_dropout = Dropout(self.config.hidden_dropout)
88
+
89
+ def num_parameter(self):
90
+ ret = self.word_embeddings.num_parameter()
91
+ ret += self.embedding_dropout.num_parameter()
92
+ return ret
93
+
94
+ def num_activation(self, input_shape: list[int]):
95
+ ret = self.word_embeddings.num_activation(input_shape)
96
+ input_shape = self.word_embeddings.mock_forward(input_shape)
97
+ ret += self.embedding_dropout.num_activation(input_shape)
98
+ return ret
99
+
100
+ def mock_forward(self, input_shape: list[int]):
101
+ input_shape = self.word_embeddings.mock_forward(input_shape)
102
+ return input_shape
103
+
104
+
105
+ class VocabParallelEmbedding(MemEstimator):
106
+ def __init__(
107
+ self,
108
+ num_embeddings: int,
109
+ embedding_dim: int,
110
+ *,
111
+ init_method,
112
+ reduce_scatter_embeddings: bool = False,
113
+ config: ModelParallelConfig,
114
+ ):
115
+ super().__init__()
116
+ # Keep the input dimensions.
117
+ self.num_embeddings = num_embeddings
118
+ self.embedding_dim = embedding_dim
119
+ self.reduce_scatter_embeddings = reduce_scatter_embeddings
120
+ self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
121
+ # Divide the weight matrix along the vocabulary dimension.
122
+ (self.vocab_start_index, self.vocab_end_index) = (
123
+ VocabUtility.vocab_range_from_global_vocab_size(
124
+ self.num_embeddings,
125
+ get_tensor_model_parallel_rank(),
126
+ self.tensor_model_parallel_size,
127
+ )
128
+ )
129
+ self.num_embeddings_per_partition = (
130
+ self.vocab_end_index - self.vocab_start_index
131
+ )
132
+ self.deterministic_mode = config.deterministic_mode
133
+ self.weight = (self.num_embeddings_per_partition, self.embedding_dim)
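+ # The estimator never allocates tensors: `weight` is just the shape of this
+ # rank's vocabulary shard, so num_parameter() is vocab_shard * hidden_size and
+ # num_activation() counts the embedded elements for a given input shape.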
134
+
135
+ def num_parameter(self):
136
+ return self.weight[0] * self.weight[1]
137
+
138
+ def num_activation(self, input_shape: list[int]):
139
+ return cum_mul(input_shape) * self.weight[1]
140
+
141
+ def mock_forward(self, input_shape: list[int]):
142
+ return input_shape + [self.weight[1]]
143
+
144
+
145
+ class Dropout(MemEstimator):
146
+ def __init__(self, p=0, *args, **kwargs):
147
+ super().__init__()
148
+ self.p = p
149
+
150
+ def num_parameter(self):
151
+ return 0
152
+
153
+ def num_activation(self, input_shape: list[int]):
154
+ if self.p == 0:
155
+ return 0
156
+ return cum_mul(input_shape[:])
157
+
158
+ def mock_forward(self, input_shape: list[int]):
159
+ return input_shape
160
+
161
+
162
+ class ColumnParallelLinear(MemEstimator):
163
+ def __init__(
164
+ self,
165
+ input_size,
166
+ output_size,
167
+ *,
168
+ config: ModelParallelConfig,
169
+ init_method,
170
+ bias=True,
171
+ gather_output=False,
172
+ stride=1,
173
+ keep_master_weight_for_test=False,
174
+ skip_bias_add=False,
175
+ skip_weight_param_allocation: bool = False,
176
+ embedding_activation_buffer=None,
177
+ grad_output_buffer=None,
178
+ is_expert: bool = False,
179
+ tp_comm_buffer_name: str = None, # Not used
180
+ disable_grad_reduce: bool = False,
181
+ is_mla: bool = False,
182
+ ):
183
+ super().__init__()
184
+
185
+ if is_mla and config.sequence_parallel:
186
+ tp_size = get_tensor_model_parallel_world_size()
187
+ output_size = divide(output_size, tp_size)
188
+ parallel_mode = None
189
+ tp_size = 1
190
+ tp_group = None
191
+ # Keep input parameters
192
+ self.input_size = input_size
193
+ self.output_size = output_size
194
+ self.gather_output = gather_output
195
+ # Divide the weight matrix along the last dimension.
196
+ self.skip_bias_add = skip_bias_add
197
+ self.is_expert = is_expert
198
+ self.expert_parallel = config.expert_model_parallel_size > 1
199
+ self.embedding_activation_buffer = embedding_activation_buffer
200
+ self.grad_output_buffer = grad_output_buffer
201
+ self.config = config
202
+ self.disable_grad_reduce = disable_grad_reduce
203
+
204
+ if is_expert:
205
+ world_size = get_expert_tensor_parallel_world_size()
206
+ rank = get_expert_tensor_parallel_rank()
207
+ else:
208
+ world_size = get_tensor_model_parallel_world_size()
209
+ rank = get_tensor_model_parallel_rank()
210
+
211
+ self.output_size_per_partition = divide(output_size, world_size)
212
+
213
+ # Parameters.
214
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result
215
+ # we allocate the transpose.
216
+ # Initialize weight.
217
+ if not skip_weight_param_allocation:
218
+ self.weight = (self.output_size_per_partition, self.input_size)
219
+ else:
220
+ self.weight = (self.output_size_per_partition, self.input_size)
221
+
222
+
223
+ if bias:
224
+ self.bias = [self.output_size_per_partition]
225
+ else:
226
+ self.bias = None
227
+
228
+ self.sequence_parallel = config.sequence_parallel
229
+ if self.sequence_parallel and world_size <= 1:
230
+ warnings.warn(
231
+ "`sequence_parallel` is set to `True`, but tensor model parallel size "
232
+ f"is {world_size}. Disabling sequence parallel."
233
+ )
234
+ self.sequence_parallel = False
235
+
236
+ self.allreduce_dgrad = (
237
+ world_size > 1
238
+ and not self.sequence_parallel
239
+ and not self.disable_grad_reduce
240
+ )
241
+ self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
242
+
243
+ def num_parameter(self):
244
+ ret = cum_mul(self.weight)
245
+ if self.bias is not None:
246
+ ret += self.bias[0]
247
+ return ret
248
+
249
+ def num_activation(self, input_shape: list[int]):
250
+ return cum_mul(input_shape[:-1]) * self.weight[0]
251
+
252
+ def mock_forward(self, input_shape: list[int]):
253
+ assert self.weight[-1] == input_shape[-1]
254
+ return input_shape[:-1] + [self.weight[0]]
255
+
256
+
257
+ class RowParallelLinear(MemEstimator):
258
+ def __init__(
259
+ self,
260
+ input_size: int,
261
+ output_size: int,
262
+ *,
263
+ config: ModelParallelConfig,
264
+ init_method,
265
+ bias: bool,
266
+ input_is_parallel: bool,
267
+ skip_bias_add: bool,
268
+ stride: int = 1,
269
+ keep_master_weight_for_test: bool = False,
270
+ is_expert: bool = False,
271
+ tp_comm_buffer_name: str = None, # Not used
272
+ ):
273
+ super().__init__()
274
+
275
+ # Keep input parameters
276
+ self.input_size = input_size
277
+ self.output_size = output_size
278
+ self.input_is_parallel = input_is_parallel
279
+ self.skip_bias_add = skip_bias_add
280
+ self.config = config
281
+ self.is_expert = is_expert
282
+ self.expert_parallel = config.expert_model_parallel_size > 1
283
+ self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
284
+ self.sequence_parallel = config.sequence_parallel
285
+ if self.sequence_parallel and not self.input_is_parallel:
286
+ raise RuntimeError(
287
+ "To enable `sequence_parallel`, `input_is_parallel` must be `True`"
288
+ )
289
+
290
+ # Divide the weight matrix along the last dimension.
291
+ if self.is_expert:
292
+ world_size = get_expert_tensor_parallel_world_size()
293
+ rank = get_expert_tensor_parallel_rank()
294
+ else:
295
+ world_size = get_tensor_model_parallel_world_size()
296
+ rank = get_tensor_model_parallel_rank()
297
+
298
+ self.input_size_per_partition = divide(input_size, world_size)
299
+
300
+ self.weight = (self.output_size, self.input_size_per_partition)
301
+ if bias:
302
+ self.bias = [self.output_size]
303
+ else:
304
+ self.bias = None
305
+
306
+ def num_parameter(self):
307
+ ret = cum_mul(self.weight)
308
+ if self.bias is not None:
309
+ ret += self.bias[0]
310
+ return ret
311
+
312
+ def num_activation(self, input_shape: list[int]):
313
+ return cum_mul(input_shape[:-1]) * self.weight[1]
314
+
315
+ def mock_forward(self, input_shape: list[int]):
316
+ assert self.weight[0] == input_shape[-1]
317
+ return input_shape[:-1] + [self.weight[1]]
318
+
319
+
320
+ class RMSNorm(MemEstimator):
321
+ def __init__(self, hidden_size: int, *args, **kwargs):
322
+ super().__init__()
323
+ self.weight = hidden_size
324
+
325
+ def num_parameter(self):
326
+ return self.weight
327
+
328
+ def num_activation(self, input_shape: list[int]):
329
+ return cum_mul(input_shape[:])
330
+
331
+ def mock_forward(self, input_shape: list[int]):
332
+ return input_shape
333
+
334
+
335
+ class GetBiasDropoutAdd(MemEstimator):
336
+ def __init__(self, *args, **kwargs):
337
+ super().__init__()
338
+
339
+ def num_parameter(self):
340
+ return 0
341
+
342
+ def num_activation(self, input_shape: list[int]):
343
+ return cum_mul(input_shape[:])
344
+
345
+ def mock_forward(self, input_shape: list[int]):
346
+ return input_shape
347
+
348
+
349
+ get_bias_dropout_add = GetBiasDropoutAdd()
350
+
351
+
352
+ class MLP(MemEstimator):
353
+
354
+ def __init__(
355
+ self,
356
+ config: TransformerConfig,
357
+ submodules,
358
+ is_expert: bool = False,
359
+ input_size: int = None,
360
+ ):
361
+ super().__init__()
362
+
363
+ self.config: TransformerConfig = config
364
+
365
+ self.input_size = input_size if input_size != None else self.config.hidden_size
366
+
367
+ # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
368
+ ffn_hidden_size = self.config.ffn_hidden_size
369
+ if self.config.gated_linear_unit:
370
+ ffn_hidden_size *= 2
371
+
372
+ self.linear_fc1 = build_module(
373
+ submodules.linear_fc1,
374
+ self.input_size,
375
+ ffn_hidden_size,
376
+ config=self.config,
377
+ init_method=self.config.init_method,
378
+ gather_output=False,
379
+ bias=self.config.add_bias_linear,
380
+ skip_bias_add=True,
381
+ is_expert=is_expert,
382
+ tp_comm_buffer_name="fc1",
383
+ )
384
+
385
+ self.activation_func = self.config.activation_func
386
+
387
+ self.linear_fc2 = build_module(
388
+ submodules.linear_fc2,
389
+ self.config.ffn_hidden_size,
390
+ self.config.hidden_size,
391
+ config=self.config,
392
+ init_method=self.config.output_layer_init_method,
393
+ bias=self.config.add_bias_linear,
394
+ input_is_parallel=True,
395
+ skip_bias_add=True,
396
+ is_expert=is_expert,
397
+ tp_comm_buffer_name="fc2",
398
+ )
399
+
400
+ def num_parameter(self):
401
+ return self.linear_fc1.num_parameter() + self.linear_fc2.num_parameter()
402
+
403
+ def num_activation(self, input_shape: list[int]):
404
+ result = 0
405
+ result += self.linear_fc1.num_activation(input_shape)
406
+ intermediate_shape = self.linear_fc1.mock_forward(input_shape)
407
+ result += cum_mul(intermediate_shape) / 2 # activation output is half the gated fc1 width
408
+ self.linear_fc2.num_activation(intermediate_shape)
409
+
410
+ return result
411
+
412
+ def mock_forward(self, input_shape: list[int]):
413
+ intermediate_shape = self.linear_fc1.mock_forward(input_shape)
414
+ output_shape = self.linear_fc2.mock_forward(intermediate_shape)
415
+ return output_shape
416
+
417
+
418
+ class ModuleList(MemEstimator):
419
+ def __init__(self, modules: list[MemEstimator] = None):
420
+ super().__init__()
421
+ if modules is None:
422
+ modules = []
423
+ self.modules = modules
424
+
425
+ def __repr__(self):
426
+ """Return a custom repr for ModuleList that compresses repeated module representations."""
427
+ list_of_reprs = [repr(item) for item in self.modules]
428
+ if len(list_of_reprs) == 0:
429
+ return self._get_name() + "()"
430
+
431
+ start_end_indices = [[0, 0]]
432
+ repeated_blocks = [list_of_reprs[0]]
433
+ for i, r in enumerate(list_of_reprs[1:], 1):
434
+ if r == repeated_blocks[-1]:
435
+ start_end_indices[-1][1] += 1
436
+ continue
437
+
438
+ start_end_indices.append([i, i])
439
+ repeated_blocks.append(r)
440
+
441
+ lines = []
442
+ stat = (
443
+ "\t/* n_params="
444
+ + colored(f"{self.num_parameter()/1024/1024:.2f}M", "red")
445
+ + "\tn_act="
446
+ + colored(f"{self.num_activation()/1024/1024:.2f}M", "green")
447
+ + " */"
448
+ )
449
+ main_str = self._get_name() + stat + " ("
450
+ for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
451
+ local_repr = f"({start_id}): {b}" # default repr
452
+
453
+ if start_id != end_id:
454
+ n = end_id - start_id + 1
455
+ local_repr = f"({start_id}-{end_id}): {n} x {b}"
456
+
457
+ local_repr = _addindent(local_repr, 2)
458
+ lines.append(local_repr)
459
+
460
+ main_str += "\n " + "\n ".join(lines) + "\n"
461
+ main_str += ")"
462
+ return main_str
463
+
464
+ def dump(self):
465
+ list_of_reprs = [repr(item) for item in self.modules]
466
+ if len(list_of_reprs) == 0:
467
+ return self._get_name() + "()"
468
+ list_of_dumps = [item.dump() for item in self.modules]
469
+
470
+ start_end_indices = [[0, 0]]
471
+ repeated_blocks = [list_of_reprs[0]]
472
+ repeated_blocks_dump = [list_of_dumps[0]]
473
+ for i, r in enumerate(list_of_reprs[1:], 1):
474
+ if r == repeated_blocks[-1]:
475
+ start_end_indices[-1][1] += 1
476
+ continue
477
+
478
+ start_end_indices.append([i, i])
479
+ repeated_blocks.append(r)
480
+ repeated_blocks_dump.append(list_of_dumps[i])
481
+ modules = {}
482
+ for (start_id, end_id), b in zip(start_end_indices, repeated_blocks_dump):
483
+ key = f"({start_id})"
484
+ if start_id != end_id:
485
+ n = end_id - start_id + 1
486
+ key = f"({start_id}-{end_id}) {n} layers"
487
+ modules[key] = b
488
+
489
+ ret = {}
490
+ ret["name"] = self._get_name()
491
+ ret["n_params"] = self.num_parameter()
492
+ ret["n_act"] = self.num_activation()
493
+ if len(modules) > 0:
494
+ ret["modules"] = modules
495
+ return ret
496
+
497
+ def append(self, m: MemEstimator):
498
+ self.modules.append(m)
499
+
500
+ def __len__(
501
+ self,
502
+ ):
503
+ return self.modules.__len__()
504
+
505
+ def num_parameter(self):
506
+ return sum([x.num_parameter() for x in self.modules])
507
+
508
+ def num_activation(self, input_shape: list[int]):
509
+ result = 0
510
+ for m in self.modules:
511
+ result += m.num_activation(input_shape)
512
+ input_shape = m.mock_forward(input_shape)
513
+
514
+ return result
515
+
516
+ def mock_forward(self, input_shape: list[int]):
517
+ for m in self.modules:
518
+ result += m.num_activation(input_shape)
519
+ input_shape = m.mock_forward(input_shape)
520
+ return input_shape
521
+
522
+
523
+ class SequentialMLP(MemEstimator):
524
+ def __init__(self, num_local_experts, config: TransformerConfig, submodules):
525
+ super().__init__()
526
+ self.config = config
527
+ self.add_bias = config.add_bias_linear
528
+ self.moe_extended_tp = config.moe_extended_tp
529
+ self.num_local_experts = num_local_experts
530
+ self.local_experts = ModuleList()
531
+ for _ in range(self.num_local_experts):
532
+ expert = MLP(self.config, submodules, is_expert=True)
533
+ self.local_experts.append(expert)
534
+
535
+ def num_parameter(self):
536
+ return self.local_experts.num_parameter()
537
+
538
+ def num_activation(self, input_shape: list[int], tokens_per_expert=None):
539
+ # assume all the inputs are routed equally
540
+ all_tokens = input_shape[1]
541
+ result = 0
542
+ for m in self.local_experts.modules:
543
+ result += m.num_activation(
544
+ input_shape[:1]
545
+ + [all_tokens // self.num_local_experts]
546
+ + input_shape[2:]
547
+ )
548
+ return result
549
+
550
+ def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
551
+ # assume all the inputs are routed to the first expert
552
+ input_shape = self.local_experts.modules[0].mock_forward(input_shape)
553
+ return input_shape
554
+
555
+
556
+ class TEGroupedMLP(MemEstimator):
557
+ """An efficient implementation of the Experts layer using TE's GroupedLinear.
558
+
559
+ Executes multiple experts in parallel to maximize computational efficiency.
560
+ """
561
+
562
+ def __init__(self, num_local_experts, config: TransformerConfig, submodules):
563
+ super().__init__()
564
+ self.config = config
565
+ self.moe_extended_tp = config.moe_extended_tp
566
+ self.num_local_experts = num_local_experts
567
+ self.input_size = self.config.hidden_size
568
+
569
+ # Double the output width with gated linear unit, see https://arxiv.org/pdf/2002.05202.pdf
570
+ ffn_hidden_size = self.config.moe_ffn_hidden_size
571
+ if self.config.gated_linear_unit:
572
+ ffn_hidden_size *= 2
573
+
574
+ self.linear_fc1 = build_module(
575
+ submodules.linear_fc1,
576
+ self.num_local_experts,
577
+ self.input_size,
578
+ ffn_hidden_size,
579
+ config=self.config,
580
+ init_method=self.config.init_method,
581
+ bias=self.config.add_bias_linear,
582
+ skip_bias_add=True,
583
+ is_expert=True,
584
+ tp_comm_buffer_name="fc1",
585
+ )
586
+
587
+ self.activation_func = self.config.activation_func
588
+
589
+ self.linear_fc2 = build_module(
590
+ submodules.linear_fc2,
591
+ self.num_local_experts,
592
+ self.config.moe_ffn_hidden_size,
593
+ self.config.hidden_size,
594
+ config=self.config,
595
+ init_method=self.config.output_layer_init_method,
596
+ bias=self.config.add_bias_linear,
597
+ skip_bias_add=True,
598
+ is_expert=True,
599
+ tp_comm_buffer_name="fc2",
600
+ )
601
+ # TODO if self.config.fp8:
602
+
603
+ def num_parameter(self):
604
+ ret = self.linear_fc1.num_parameter()
605
+ ret += self.linear_fc2.num_parameter()
606
+ return ret
607
+
608
+ def num_activation(self, input_shape: list[int], tokens_per_expert=None):
609
+ ret = 0
610
+ ret += self.linear_fc1.num_activation(input_shape)
611
+ input_shape = self.linear_fc1.mock_forward(input_shape)
612
+
613
+ # activation
614
+ ret += cum_mul(input_shape) / 2 # swiglu or gelu
615
+ input_shape = deepcopy(input_shape)
616
+ input_shape[-1] //= 2
617
+
618
+ ret += self.linear_fc2.num_activation(input_shape)
619
+ return ret
620
+
621
+ def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
622
+ # the grouped MLP projects back to hidden_size, so the output shape matches the input
624
+ return input_shape
625
+
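+
+ # Rough parameter count implied by the two grouped linears above, ignoring any
+ # tensor-parallel split (hypothetical sizes): with num_local_experts = 8,
+ # hidden_size = 2048, moe_ffn_hidden_size = 768 and gated_linear_unit = True,
+ #   linear_fc1: 8 * 2048 * (768 * 2) = 25,165,824 elements
+ #   linear_fc2: 8 * 768 * 2048       = 12,582,912 elements
+ # which is the sum returned by num_parameter().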
626
+
627
+ class TEGroupedLinear(MemEstimator):
628
+ def __init__(
629
+ self,
630
+ num_gemms: int,
631
+ input_size: int,
632
+ output_size: int,
633
+ *,
634
+ parallel_mode: str,
635
+ config: ModelParallelConfig,
636
+ init_method,
637
+ bias: bool,
638
+ skip_bias_add: bool,
639
+ is_expert: bool = False,
640
+ tp_comm_buffer_name: str = None,
641
+ ):
642
+ super().__init__()
643
+ self.config = config
644
+
645
+ # TE returns a zero length Tensor when bias=False and
646
+ # return_bias=True, but we prefer None. So in that case we
647
+ # tell TE to not return the bias, and return None
648
+ # ourselves. This way our forward always returns two values
649
+ # and we don't have to deal with the zero length Tensor.
650
+ self.te_return_bias = skip_bias_add and bias
651
+ self.is_first_microbatch = True
652
+ self.disable_parameter_transpose_cache = (
653
+ self.config.disable_parameter_transpose_cache
654
+ )
655
+
656
+ extra_kwargs = _get_extra_te_kwargs(config)
657
+ extra_kwargs["ub_name"] = tp_comm_buffer_name
658
+
659
+ self.expert_parallel = self.config.expert_model_parallel_size > 1
660
+ if self.expert_parallel:
661
+ extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name()
662
+
663
+ # For MoE models, the comms between TP and EP group is explicitly handled by
664
+ # MoE token dispatcher. So we disable comms by making TE agnostic of model parallel.
665
+ self.explicit_expert_comm = is_expert and (
666
+ config.tensor_model_parallel_size > 1 or self.expert_parallel
667
+ )
668
+ if is_expert:
669
+ tp_size = get_expert_tensor_parallel_world_size()
670
+ else:
671
+ tp_size = get_tensor_model_parallel_world_size()
672
+ if self.explicit_expert_comm:
673
+ if parallel_mode == "column":
674
+ output_size = divide(output_size, tp_size)
675
+ elif parallel_mode == "row":
676
+ input_size = divide(input_size, tp_size)
677
+ parallel_mode = None
678
+ tp_size = 1
679
+ assert not bias, "bias is not considered for now"
680
+
681
+ self.num_gemms = num_gemms
682
+ self.input_size = input_size
683
+ self.output_size = output_size
684
+
685
+ def num_parameter(self):
686
+ ret = self.num_gemms * self.input_size * self.output_size
687
+ return ret
688
+
689
+ def num_activation(self, input_shape: list[int], tokens_per_expert=None):
690
+ ret = cum_mul(self.mock_forward(input_shape))
691
+ return ret
692
+
693
+ def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
694
+ return input_shape[:-1] + [self.output_size]
695
+
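+
+ # Partitioning sketch (hypothetical sizes): when explicit_expert_comm is set with a
+ # tensor-parallel size of 2, a "column" grouped linear of 2048 -> 1536 keeps
+ # input_size = 2048 and stores output_size = 1536 // 2 = 768 per rank, while a
+ # "row" grouped linear of 768 -> 2048 stores input_size = 768 // 2 = 384 per rank,
+ # so num_parameter() counts the per-rank weight shard.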
696
+
697
+ class TEColumnParallelGroupedLinear(TEGroupedLinear):
698
+ def __init__(
699
+ self,
700
+ num_gemms: int,
701
+ input_size: int,
702
+ output_size: int,
703
+ *,
704
+ config: ModelParallelConfig,
705
+ init_method,
706
+ bias: bool,
707
+ skip_bias_add: bool,
708
+ is_expert: bool,
709
+ tp_comm_buffer_name: str = None,
710
+ ):
711
+ super().__init__(
712
+ num_gemms=num_gemms,
713
+ input_size=input_size,
714
+ output_size=output_size,
715
+ parallel_mode="column",
716
+ config=config,
717
+ init_method=condition_init_method(config, init_method),
718
+ bias=bias,
719
+ skip_bias_add=skip_bias_add,
720
+ is_expert=is_expert,
721
+ tp_comm_buffer_name=tp_comm_buffer_name,
722
+ )
723
+
724
+
725
+ class TERowParallelGroupedLinear(TEGroupedLinear):
726
+ def __init__(
727
+ self,
728
+ num_gemms: int,
729
+ input_size: int,
730
+ output_size: int,
731
+ *,
732
+ config: ModelParallelConfig,
733
+ init_method,
734
+ bias: bool,
735
+ skip_bias_add: bool,
736
+ is_expert: bool,
737
+ tp_comm_buffer_name: str = None,
738
+ ):
739
+
740
+ super().__init__(
741
+ num_gemms=num_gemms,
742
+ input_size=input_size,
743
+ output_size=output_size,
744
+ parallel_mode="row",
745
+ config=config,
746
+ init_method=condition_init_method(config, init_method),
747
+ bias=bias,
748
+ skip_bias_add=skip_bias_add,
749
+ is_expert=is_expert,
750
+ tp_comm_buffer_name=tp_comm_buffer_name,
751
+ )
752
+
753
+
754
+ class SharedExpertMLP(MLP):
755
+ """
756
+ MLP layer for Shared Experts.
757
+ """
758
+
759
+ def __init__(self, config: TransformerConfig, spec: ModuleSpec):
760
+ config = deepcopy(config)
761
+ assert (
762
+ config.add_bias_linear == False
763
+ ), "bias is not supported in the shared experts, please set '--disable-bias-linear' instead."
765
+
766
+ config.ffn_hidden_size = config.moe_shared_expert_intermediate_size
767
+ super().__init__(config=config, submodules=spec.submodules)
768
+
769
+ self.use_shared_expert_gate = spec.params.get("gate", False)
770
+ if self.use_shared_expert_gate:
771
+ assert False, "use_shared_expert_gate is not implemented"
772
+ # self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size)))
773
+ # if config.perform_initialization:
774
+ # if get_cuda_rng_tracker().is_initialized():
775
+ # with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
776
+ # config.init_method(self.gate_weight)
777
+ # else:
778
+ # config.init_method(self.gate_weight)
779
+ # self.gate_weight.data = self.gate_weight.data.to(dtype=config.params_dtype)
780
+ # setattr(self.gate_weight, 'sequence_parallel', self.config.sequence_parallel)
781
+ else:
782
+ self.gate_weight = None
783
+
784
+
785
+ class TransformerBlock(MemEstimator):
786
+ """Transformer class."""
787
+
788
+ def __init__(
789
+ self,
790
+ config: TransformerConfig,
791
+ spec: Union[TransformerBlockSubmodules, ModuleSpec],
792
+ post_layer_norm: bool = True,
793
+ pre_process: bool = True,
794
+ post_process: bool = True,
795
+ ):
796
+ super().__init__()
797
+ self.config = config
798
+
799
+ self.submodules = _get_block_submodules(config, spec)
800
+ self.post_layer_norm = post_layer_norm
801
+ self.pre_process = pre_process
802
+ self.post_process = post_process
803
+ self.cuda_graphs = {}
804
+ self.current_microbatch = -1
805
+ self.input_tensor = None
806
+ self.checkpoint_core_attention = (
807
+ self.config.recompute_granularity == "selective"
808
+ )
809
+
810
+ self._build_layers()
811
+ self.num_layers_per_pipeline_rank = len(self.layers)
812
+ self.tp_only_amax_red = config.tp_only_amax_red
813
+
814
+ def _build_layers(self):
815
+ def build_layer(layer_spec, layer_number):
816
+ return build_module(
817
+ layer_spec, config=self.config, layer_number=layer_number
818
+ )
819
+
820
+ # offset is implicit in TransformerLayer
821
+ self.layers = ModuleList(
822
+ [
823
+ build_layer(layer_spec, i + 1)
824
+ for i, layer_spec in enumerate(self.submodules.layer_specs)
825
+ ]
826
+ )
827
+
828
+ if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
829
+ self.final_layernorm = build_module(
830
+ self.submodules.layer_norm,
831
+ config=self.config,
832
+ hidden_size=self.config.hidden_size,
833
+ eps=self.config.layernorm_epsilon,
834
+ )
835
+ else:
836
+ self.final_layernorm = None # Either this or nn.Identity
837
+
838
+ def num_parameter(self):
839
+ ret = self.layers.num_parameter()
840
+ if self.final_layernorm is not None:
841
+ ret += self.final_layernorm.num_parameter()
842
+
843
+ return ret
844
+
845
+ def num_activation(self, input_shape: list[int]):
846
+ result = self.layers.num_activation(input_shape)
847
+ if self.final_layernorm is not None:
848
+ result += self.final_layernorm.num_activation(input_shape)
849
+ return result
850
+
851
+ def mock_forward(self, input_shape: list[int]):
852
+ return input_shape
853
+
854
+
855
+ class TopKRouter(MemEstimator):
856
+
857
+ def __init__(self, config: TransformerConfig) -> None:
858
+ super().__init__()
859
+ self.config = config
860
+ self.topk = self.config.moe_router_topk
861
+ self.routing_type = self.config.moe_router_load_balancing_type
862
+ self.input_jitter = None
863
+
864
+ def num_parameter(self):
865
+ return 0
866
+
867
+ def num_activation(self, input_shape: list[int]):
868
+ result = cum_mul(input_shape) * 2 # sinkhorn and sinkhorn activation
869
+ return result
870
+
871
+ def mock_forward(self, input_shape: list[int]):
872
+ return input_shape[:-1] + [self.topk]
873
+
874
+
875
+ class MoELayer(MemEstimator):
876
+
877
+ def __init__(
878
+ self, config: TransformerConfig, submodules=None, layer_number: int = None
879
+ ):
880
+ super().__init__()
881
+ self.config = config
882
+ self.submodules = submodules
883
+ self.moe_layer_recompute = config.moe_layer_recompute
884
+
885
+ self.expert_parallel_size = get_expert_model_parallel_world_size()
886
+ assert (
887
+ self.expert_parallel_size > 0
888
+ ), "Expected a positive expert parallel size"
889
+
890
+ assert self.config.num_moe_experts % self.expert_parallel_size == 0
891
+ self.num_local_experts = (
892
+ self.config.num_moe_experts // self.expert_parallel_size
893
+ )
894
+ local_expert_indices_offset = (
895
+ get_expert_model_parallel_rank() * self.num_local_experts
896
+ )
897
+
898
+ self.router = TopKRouter(config=self.config)
899
+ self.use_shared_expert = (
900
+ self.config.moe_shared_expert_intermediate_size is not None
901
+ )
902
+ self.shared_expert_overlap = self.config.moe_shared_expert_overlap
903
+
904
+ self.local_expert_indices = [
905
+ local_expert_indices_offset + i for i in range(self.num_local_experts)
906
+ ]
907
+ assert all(
908
+ map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices)
909
+ )
910
+
911
+ self.experts = None
912
+ self.shared_experts = None
913
+ self.token_dispatcher = None
914
+ self.layer_number = layer_number
915
+ # Initialize experts
916
+ self.experts = build_module(
917
+ self.submodules.experts, self.num_local_experts, self.config
918
+ )
919
+
920
+ # Initialize shared experts
921
+ if self.use_shared_expert:
922
+ self.shared_experts = SharedExpertMLP(
923
+ self.config, self.submodules.shared_experts
924
+ )
925
+ # if self.shared_expert_overlap:
926
+ # self.token_dispatcher.set_shared_experts(self.shared_experts)
927
+
928
+ def num_parameter(self):
929
+ ret = self.experts.num_parameter() + self.router.num_parameter()
930
+ if self.use_shared_expert:
931
+ ret += self.shared_experts.num_parameter()
932
+ return ret
933
+
934
+ def num_activation(self, input_shape: list[int]):
935
+ result = self.router.num_activation(input_shape)
936
+ result += cum_mul(input_shape) * self.router.topk # token dispatcher
937
+ moe_input_shape_average = deepcopy(input_shape)
938
+ moe_input_shape_average[1] = int(moe_input_shape_average[1] * self.router.topk)
939
+
940
+ result += self.experts.num_activation(moe_input_shape_average)
941
+ if self.use_shared_expert:
942
+ result += self.shared_experts.num_activation(input_shape)
943
+
944
+ if self.config.moe_layer_recompute:
945
+ result = cum_mul(input_shape) * 2
946
+ return result
947
+
948
+ def mock_forward(self, input_shape: list[int]):
949
+ return input_shape
950
+
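+
+ # Accounting sketch for num_activation above (hypothetical sizes): with
+ # input_shape = [4096, 1, 2048] and moe_router_topk = 8 the estimate adds
+ #   router:          2 * 4096 * 1 * 2048 elements
+ #   token dispatch:  8 * 4096 * 1 * 2048 elements
+ #   experts:         experts.num_activation([4096, 8, 2048])  (dim 1 scaled by topk)
+ # plus the shared experts when moe_shared_expert_intermediate_size is set; with
+ # moe_layer_recompute the whole layer is instead charged 2 * 4096 * 1 * 2048.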
951
+
952
+ class IdentityOp(MemEstimator):
953
+ def num_parameter(self):
954
+ return 0
955
+
956
+ def num_activation(self, input_shape: list[int]):
957
+ return 0
958
+
959
+ def mock_forward(self, input_shape: list[int]):
960
+ return input_shape
961
+
962
+
963
+ IdentityFuncOp = IdentityOp
964
+ TERowParallelLinear = RowParallelLinear
965
+ TEColumnParallelLinear = ColumnParallelLinear
966
+ TELayerNormColumnParallelLinear = ColumnParallelLinear
967
+
968
+
969
+ class TEDotProductAttention(MemEstimator):
970
+ def __init__(self, config: TransformerConfig, *args, **kwargs):
971
+ super().__init__()
972
+ self.config = config
973
+
974
+ def num_parameter(self):
975
+ return 0
976
+
977
+ def num_activation(
978
+ self, q_shape: list[int], k_shape: list[int], v_shape: list[int]
979
+ ):
980
+ bs, seqs, heads, dim = q_shape
981
+ if self.config.multi_latent_attention and False:
982
+ result = bs * seqs * seqs * heads
983
+ else:
984
+ bs, seqs, heads, dim = k_shape
985
+ result = (
986
+ bs * seqs * dim * heads * 2 # * self.config.tensor_model_parallel_size
987
+ ) # flash attention
988
+ if self.config.context_parallel_size > 1:
989
+ result *= 2
990
+ return result
991
+
992
+ def mock_forward(
993
+ self,
994
+ hidden_size: int,
995
+ q_shape: list[int],
996
+ k_shape: list[int],
997
+ v_shape: list[int],
998
+ ):
999
+ seqs, bs, heads, dim = q_shape
1000
+ return [seqs, bs, hidden_size]
1001
+
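+
+ # Worked example (hypothetical sizes): with k_shape = [4096, 1, 8, 128] the
+ # flash-attention estimate above is 4096 * 1 * 8 * 128 * 2 = 8,388,608 elements,
+ # doubled again when context_parallel_size > 1.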
1002
+
1003
+ class TransformerLayer(MemEstimator):
1004
+ def __init__(
1005
+ self,
1006
+ config: TransformerConfig,
1007
+ submodules,
1008
+ layer_number: int = 1,
1009
+ hidden_dropout: float = None,
1010
+ ):
1011
+ super().__init__()
1012
+ self.config = config
1013
+
1014
+ if config.enable_cuda_graph and self.training:
1015
+ assert (
1016
+ not config.cpu_offloading and config.recompute_granularity is None
1017
+ ), "Cudagraphs not supported"
1018
+ self.cudagraph_manager = CudaGraphManager()
1019
+
1020
+ self.submodules_config = submodules
1021
+ self.layer_number = layer_number + get_transformer_layer_offset(self.config)
1022
+ self.hidden_dropout = (
1023
+ config.hidden_dropout if hidden_dropout is None else hidden_dropout
1024
+ )
1025
+
1026
+ # [Module 1: Input Layernorm] Optional Layernorm on the input data
1027
+ # TODO: add pytorch only layernorm
1028
+ self.input_layernorm = build_module(
1029
+ submodules.input_layernorm,
1030
+ config=self.config,
1031
+ hidden_size=self.config.hidden_size,
1032
+ eps=self.config.layernorm_epsilon,
1033
+ )
1034
+
1035
+ # [Module 2: SelfAttention]
1036
+ self.self_attention = build_module(
1037
+ submodules.self_attention, config=self.config, layer_number=layer_number
1038
+ )
1039
+
1040
+ # [Module 3: BiasDropoutFusion]
1041
+ self.self_attn_bda = build_module(submodules.self_attn_bda)
1042
+
1043
+ # [Module 4: Post SelfAttention] Optional Layernorm after self-attn
1044
+ self.pre_cross_attn_layernorm = build_module(
1045
+ submodules.pre_cross_attn_layernorm,
1046
+ config=self.config,
1047
+ hidden_size=self.config.hidden_size,
1048
+ eps=self.config.layernorm_epsilon,
1049
+ )
1050
+
1051
+ # [Module 5: CrossAttention]
1052
+ self.cross_attention = build_module(
1053
+ submodules.cross_attention, config=self.config, layer_number=layer_number
1054
+ )
1055
+
1056
+ # [Module 6: BiasDropoutFusion]
1057
+ self.cross_attn_bda = build_module(
1058
+ submodules.cross_attn_bda, config=self.config
1059
+ )
1060
+
1061
+ # [Module 7: Pre MLP] Optional Layernorm before MLP
1062
+ self.pre_mlp_layernorm = build_module(
1063
+ submodules.pre_mlp_layernorm,
1064
+ config=self.config,
1065
+ hidden_size=self.config.hidden_size,
1066
+ eps=self.config.layernorm_epsilon,
1067
+ )
1068
+
1069
+ # [Module 8: MLP block]
1070
+ self.mlp = build_module(submodules.mlp, config=self.config)
1071
+ if hasattr(self.mlp, "set_layer_number"):
1072
+ self.mlp.set_layer_number(self.layer_number)
1073
+
1074
+ # [Module 9: BiasDropoutFusion]
1075
+ self.mlp_bda = build_module(submodules.mlp_bda)
1076
+
1077
+ def num_parameter(self):
1078
+ result = self.input_layernorm.num_parameter()
1079
+ result += self.self_attention.num_parameter()
1080
+ result += self.pre_cross_attn_layernorm.num_parameter()
1081
+ result += self.cross_attention.num_parameter()
1082
+ result += self.cross_attn_bda.num_parameter()
1083
+ result += self.pre_mlp_layernorm.num_parameter()
1084
+ result += self.mlp.num_parameter()
1085
+
1086
+ return result
1087
+
1088
+ def num_activation(self, input_shape: list[int]):
1089
+ result = 0
1090
+ result += self.self_attention.num_activation(input_shape)
1091
+ result += self.mlp.num_activation(input_shape)
1092
+ # __import__('ipdb').set_trace()
1093
+ # sequence parallel
1094
+ if self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1:
1095
+ input_shape = deepcopy(input_shape)
1096
+ input_shape[1] /= self.config.tensor_model_parallel_size
1097
+ result += self.input_layernorm.num_activation(input_shape)
1098
+ result += self.pre_mlp_layernorm.num_activation(input_shape)
1099
+ result += self.self_attn_bda.num_activation(input_shape)
1100
+ result += self.mlp_bda.num_activation(input_shape)
1101
+ return result
1102
+
1103
+ def mock_forward(self, input_shape: list[int]):
1104
+ return input_shape
1105
+
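+
+ # Sequence-parallel sketch: with sequence parallelism enabled and
+ # tensor_model_parallel_size = 4 (hypothetical), the two layernorms and the two
+ # bias-dropout-add modules above are charged on a quarter of the elements of the
+ # full [s, b, h] input, while the attention and MLP estimates use the full shape.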
1106
+
1107
+ class SelfAttention(MemEstimator):
1108
+
1109
+ def __init__(
1110
+ self,
1111
+ config: TransformerConfig,
1112
+ submodules,
1113
+ layer_number: int,
1114
+ attn_mask_type,
1115
+ ):
1116
+ super().__init__()
1117
+
1118
+ self.config = config
1119
+ self.layer_number = layer_number
1120
+ self.attn_mask_type = attn_mask_type
1121
+ self.attention_type = ""
1122
+
1123
+ # For normal attention without groups, num_query_groups == num_attention_heads,
1124
+ # so these two will be the same
1125
+ self.query_projection_size = (
1126
+ self.config.kv_channels * self.config.num_attention_heads
1127
+ )
1128
+ self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
1129
+
1130
+ # Per attention head and per partition values.
1131
+ world_size = get_tensor_model_parallel_world_size()
1132
+ self.hidden_size_per_attention_head = divide(
1133
+ self.query_projection_size, self.config.num_attention_heads
1134
+ )
1135
+ self.num_attention_heads_per_partition = divide(
1136
+ self.config.num_attention_heads, world_size
1137
+ )
1138
+ self.num_query_groups_per_partition = divide(
1139
+ self.config.num_query_groups, world_size
1140
+ )
1141
+ self.core_attention = build_module(
1142
+ submodules.core_attention,
1143
+ config=self.config,
1144
+ layer_number=self.layer_number,
1145
+ attn_mask_type=self.attn_mask_type,
1146
+ )
1147
+ self.linear_qkv = build_module(
1148
+ submodules.linear_qkv,
1149
+ self.config.hidden_size,
1150
+ self.query_projection_size + 2 * self.kv_projection_size,
1151
+ config=self.config,
1152
+ init_method=self.config.init_method,
1153
+ gather_output=False,
1154
+ bias=self.config.add_bias_linear or self.config.add_qkv_bias,
1155
+ skip_bias_add=False,
1156
+ is_expert=False,
1157
+ tp_comm_buffer_name="qkv",
1158
+ )
1159
+
1160
+ if submodules.q_layernorm is not None:
1161
+ self.q_layernorm = build_module(
1162
+ submodules.q_layernorm,
1163
+ hidden_size=self.hidden_size_per_attention_head,
1164
+ config=self.config,
1165
+ eps=self.config.layernorm_epsilon,
1166
+ )
1167
+ else:
1168
+ self.q_layernorm = None
1169
+
1170
+ if submodules.k_layernorm is not None:
1171
+ self.k_layernorm = build_module(
1172
+ submodules.k_layernorm,
1173
+ hidden_size=self.hidden_size_per_attention_head,
1174
+ config=self.config,
1175
+ eps=self.config.layernorm_epsilon,
1176
+ )
1177
+ else:
1178
+ self.k_layernorm = None
1179
+ self.linear_proj = build_module(
1180
+ submodules.linear_proj,
1181
+ self.query_projection_size,
1182
+ self.config.hidden_size,
1183
+ config=self.config,
1184
+ init_method=self.config.output_layer_init_method,
1185
+ bias=self.config.add_bias_linear,
1186
+ input_is_parallel=True,
1187
+ skip_bias_add=True,
1188
+ is_expert=False,
1189
+ tp_comm_buffer_name="proj",
1190
+ )
1191
+ self.checkpoint_core_attention = (
1192
+ self.config.recompute_granularity == "selective"
1193
+ )
1194
+
1195
+ def num_parameter(self):
1196
+ result = 0
1197
+ result += self.core_attention.num_parameter()
1198
+ result += self.linear_proj.num_parameter()
1199
+ result += self.linear_qkv.num_parameter()
1200
+ if self.q_layernorm is not None:
1201
+ result += self.q_layernorm.num_parameter()
1202
+ if self.k_layernorm is not None:
1203
+ result += self.k_layernorm.num_parameter()
1204
+
1205
+ return result
1206
+
1207
+ def num_activation(self, input_shape: list[int]):
1208
+ ret = 0
1209
+ ## in estimator: act(linear) = 1.5*cum_mul(input_shape)
1210
+ ## in reality: act(linear) = cum_mul(input_shape), act(rotary) = cum_mul(input_shape), act(attn_forward_func_with_cp) = cum_mul(input_shape)
1211
+ # ret += self.linear_qkv.num_activation(input_shape)
1212
+ mixed_qkv_shape = self.linear_qkv.mock_forward(input_shape)
1213
+ new_tensor_shape = mixed_qkv_shape[:-1] + [
1214
+ self.num_query_groups_per_partition,
1215
+ (
1216
+ (
1217
+ self.num_attention_heads_per_partition
1218
+ // self.num_query_groups_per_partition
1219
+ + 2
1220
+ )
1221
+ * self.hidden_size_per_attention_head
1222
+ ),
1223
+ ]
1224
+ split_arg_list = [
1225
+ (
1226
+ self.num_attention_heads_per_partition
1227
+ // self.num_query_groups_per_partition
1228
+ * self.hidden_size_per_attention_head
1229
+ ),
1230
+ self.hidden_size_per_attention_head,
1231
+ self.hidden_size_per_attention_head,
1232
+ ]
1233
+ # [sq, b, ng, (np/ng + 2) * hn]
1234
+ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
1235
+ q_shape = new_tensor_shape[:-1] + [split_arg_list[0]]
1236
+ k_shape = new_tensor_shape[:-1] + [split_arg_list[1]]
1237
+ v_shape = new_tensor_shape[:-1] + [split_arg_list[2]]
1238
+ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
1239
+ q_shape = (
1240
+ q_shape[:2]
1241
+ + [cum_mul(q_shape[-2:]) // self.hidden_size_per_attention_head]
1242
+ + [self.hidden_size_per_attention_head]
1243
+ )
1244
+
1245
+ if not self.checkpoint_core_attention:
1246
+ ret += self.core_attention.num_activation(q_shape, k_shape, v_shape)
1247
+ ret += self.linear_proj.num_activation(input_shape)
1248
+ ## in reality: act(linear) = cum_mul(input_shape), act(rotary) = cum_mul(input_shape), act(attn_forward_func_with_cp) = cum_mul(input_shape)
1249
+ ret += self.linear_proj.num_activation(input_shape) * 3
1250
+
1251
+ return ret
1252
+
1253
+ def mock_forward(self, input_shape: list[int]):
1254
+ return input_shape
1255
+
1256
+
1257
+ class Linear(MemEstimator):
1258
+ def __init__(
1259
+ self,
1260
+ in_features: int,
1261
+ out_features: int,
1262
+ bias: bool = True,
1263
+ device=None,
1264
+ dtype=None,
1265
+ ) -> None:
1266
+
1267
+ super().__init__()
1268
+ self.weight = (in_features, out_features)
1269
+
1270
+ def num_parameter(self):
1271
+ return self.weight[0] * self.weight[1]
1272
+
1273
+ def num_activation(self, input_shape: list[int]):
1274
+ return cum_mul(input_shape[:-1]) * self.weight[1]
1275
+
1276
+ def mock_forward(self, input_shape: list[int]):
1277
+ return input_shape[:-1] + [self.weight[1]]
1278
+
1279
+
1280
+ class MLASelfAttention(MemEstimator):
1281
+ """MLA Self-attention layer class
1282
+
1283
+ Self-attention layer takes input with size [s, b, h]
1284
+ and returns output of the same size.
1285
+ """
1286
+
1287
+ def __init__(
1288
+ self,
1289
+ config: MLATransformerConfig,
1290
+ submodules,
1291
+ layer_number: int,
1292
+ attn_mask_type=AttnMaskType.padding,
1293
+ ) -> None:
1294
+
1295
+ super().__init__()
1296
+ self.config = config
1297
+ self.layer_number = layer_number
1298
+ self.attn_mask_type = attn_mask_type
1299
+ self.attention_type = "self"
1300
+ self.world_size = get_tensor_model_parallel_world_size()
1301
+ # assert (
1302
+ # world_size == 1
1303
+ # ), "MLA is not supported with Tensor Parallelism yet, \
1304
+ # use Expert Parallelism and Pipeline Parallelism for better performance."
1305
+
1306
+ self.query_projection_size = (
1307
+ self.config.v_head_dim * self.config.num_attention_heads
1308
+ )
1309
+
1310
+ self.q_head_dim = self.config.qk_head_dim + self.config.qk_pos_emb_head_dim
1311
+
1312
+ mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale)
1313
+ self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim)
1314
+
1315
+ # Per attention head and per partition values.
1316
+ world_size = get_tensor_model_parallel_world_size()
1317
+ self.hidden_size_per_attention_head = divide(
1318
+ self.query_projection_size, self.config.num_attention_heads
1319
+ )
1320
+ self.num_attention_heads_per_partition = divide(
1321
+ self.config.num_attention_heads, world_size
1322
+ )
1323
+ self.num_query_groups_per_partition = divide(
1324
+ self.config.num_query_groups, world_size
1325
+ )
1326
+ # TODO Rotary Embedding
1327
+ # self.rotary_pos_emb = YarnRotaryEmbedding(
1328
+ # self.config.qk_pos_emb_head_dim,
1329
+ # rotary_base=self.config.rotary_base,
1330
+ # scaling_factor=self.config.rotary_scaling_factor,
1331
+ # original_max_position_embeddings=self.config.max_position_embeddings,
1332
+ # beta_fast=self.config.beta_fast,
1333
+ # beta_slow=self.config.beta_slow,
1334
+ # mscale=self.config.mscale,
1335
+ # mscale_all_dim=self.config.mscale_all_dim,
1336
+ # )
1337
+
1338
+ self.core_attention = build_module(
1339
+ submodules.core_attention,
1340
+ config=self.config,
1341
+ layer_number=self.layer_number,
1342
+ attn_mask_type=self.attn_mask_type,
1343
+ attention_type=self.attention_type,
1344
+ softmax_scale=self.softmax_scale,
1345
+ k_channels=self.q_head_dim,
1346
+ v_channels=self.config.v_head_dim,
1347
+ )
1348
+
1349
+ if self.config.q_lora_rank is None:
1350
+ # Not projecting the query
1351
+ self.linear_q_proj = build_module(
1352
+ submodules.linear_q_proj,
1353
+ self.config.hidden_size,
1354
+ self.config.num_attention_heads * self.q_head_dim,
1355
+ config=self.config,
1356
+ init_method=self.config.init_method,
1357
+ gather_output=False,
1358
+ bias=False,
1359
+ skip_bias_add=False,
1360
+ is_expert=False,
1361
+ is_mla=True,
1362
+ )
1363
+
1364
+ else:
1365
+ self.linear_q_down_proj = Linear(
1366
+ self.config.hidden_size, self.config.q_lora_rank, bias=False
1367
+ )
1368
+
1369
+ self.linear_q_up_proj = build_module(
1370
+ submodules.linear_q_up_proj,
1371
+ self.config.q_lora_rank,
1372
+ self.config.num_attention_heads * self.q_head_dim,
1373
+ config=self.config,
1374
+ init_method=self.config.init_method,
1375
+ gather_output=False,
1376
+ bias=False,
1377
+ skip_bias_add=False,
1378
+ is_expert=False,
1379
+ is_mla=True,
1380
+ )
1381
+ self.linear_kv_down_proj = Linear(
1382
+ self.config.hidden_size,
1383
+ self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim,
1384
+ bias=False,
1385
+ )
1386
+
1387
+ self.linear_kv_up_proj = build_module(
1388
+ submodules.linear_kv_up_proj,
1389
+ self.config.kv_lora_rank,
1390
+ self.config.num_attention_heads
1391
+ * (self.config.qk_head_dim + self.config.v_head_dim),
1392
+ config=self.config,
1393
+ init_method=self.config.init_method,
1394
+ gather_output=False,
1395
+ bias=False,
1396
+ skip_bias_add=False,
1397
+ is_expert=False,
1398
+ is_mla=True,
1399
+ )
1400
+
1401
+ if self.config.q_lora_rank is not None:
1402
+ self.q_layernorm = build_module(
1403
+ submodules.q_layernorm,
1404
+ hidden_size=self.config.q_lora_rank,
1405
+ config=self.config,
1406
+ eps=self.config.layernorm_epsilon,
1407
+ )
1408
+
1409
+ self.kv_layernorm = build_module(
1410
+ submodules.kv_layernorm,
1411
+ hidden_size=self.config.kv_lora_rank,
1412
+ config=self.config,
1413
+ eps=self.config.layernorm_epsilon,
1414
+ )
1415
+
1416
+ # Output.
1417
+ self.linear_proj = build_module(
1418
+ submodules.linear_proj,
1419
+ self.query_projection_size,
1420
+ self.config.hidden_size,
1421
+ config=self.config,
1422
+ init_method=self.config.output_layer_init_method,
1423
+ bias=self.config.add_bias_linear,
1424
+ input_is_parallel=True,
1425
+ skip_bias_add=True,
1426
+ is_expert=False,
1427
+ tp_comm_buffer_name="proj",
1428
+ )
1429
+
1430
+ self.checkpoint_core_attention = (
1431
+ self.config.recompute_granularity == "selective"
1432
+ )
1433
+
1434
+ def num_parameter(self):
1435
+ result = 0
1436
+ result += self.core_attention.num_parameter()
1437
+ result += self.linear_proj.num_parameter()
1438
+ if self.config.q_lora_rank is None:
1439
+ result += self.linear_q_proj.num_parameter()
1440
+ else:
1441
+ result += self.linear_q_down_proj.num_parameter()
1442
+ result += self.linear_q_up_proj.num_parameter()
1443
+ result += self.linear_kv_down_proj.num_parameter()
1444
+ result += self.linear_kv_up_proj.num_parameter()
1445
+ result += self.kv_layernorm.num_parameter()
1446
+ if self.config.q_lora_rank is not None:
1447
+ result += self.q_layernorm.num_parameter()
1448
+
1449
+ return result
1450
+
1451
+ def num_activation(self, input_shape: list[int]):
1452
+ q_len, bsz, _ = input_shape
1453
+ ret = 0
1454
+ if self.config.q_lora_rank is not None:
1455
+ ret += self.linear_q_down_proj.num_activation(input_shape)
1456
+ q_compressed_shape = self.linear_q_down_proj.mock_forward(input_shape)
1457
+ ret += self.q_layernorm.num_activation(q_compressed_shape)
1458
+ ret += self.linear_q_up_proj.num_activation(q_compressed_shape)
1459
+ q_shape = self.linear_q_up_proj.mock_forward(q_compressed_shape)
1460
+ else:
1461
+ # hidden_states:[s, b, 2048], q: [s, b, n * 192]
1462
+ ret += self.linear_q_proj.num_activation(input_shape)
1463
+ q_shape = self.linear_q_proj.mock_forward(input_shape)
1464
+
1465
+ # kv_combined: [s, b, 576]
1466
+ ret += self.linear_kv_down_proj.num_activation(input_shape)
1467
+ kv_combined_shape = self.linear_kv_down_proj.mock_forward(input_shape)
1468
+ # kv_compressed:[s, b, 512], k_pos_emb: [s, b, 64]
1469
+ kv_compressed_shape = kv_combined_shape[:-1] + [self.config.kv_lora_rank]
1470
+
1471
+ # kv: [s, b, 2048]
1472
+ ret += self.kv_layernorm.num_activation(kv_compressed_shape)
1473
+ ret += self.linear_kv_up_proj.num_activation(kv_compressed_shape)
1474
+
1475
+ q_shape = [q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim]
1476
+ k_shape = [q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim]
1477
+ v_shape = [
1478
+ q_len,
1479
+ bsz,
1480
+ self.num_attention_heads_per_partition,
1481
+ self.config.v_head_dim,
1482
+ ]
1483
+
1484
+ if not self.checkpoint_core_attention:
1485
+ ret += self.core_attention.num_activation(q_shape, k_shape, v_shape)
1486
+
1487
+ ret += self.linear_proj.num_activation(input_shape)
1488
+
1489
+ return ret
1490
+
1491
+ def mock_forward(self, input_shape: list[int]):
1492
+ return input_shape
1493
+
1494
+
1495
+ class TENorm:
1496
+ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
1497
+ from megatron.core.extensions.transformer_engine import _get_extra_te_kwargs, te
1498
+
1499
+ if config.normalization == "LayerNorm":
1500
+ # TODO layernorm
1501
+ pass
1502
+ elif config.normalization == "RMSNorm":
1503
+ assert hasattr(
1504
+ te.pytorch, "RMSNorm"
1505
+ ), "Transformer-Engine >= v0.11 required to use this feature"
1506
+ instance = RMSNorm(
1507
+ hidden_size=hidden_size,
1508
+ eps=eps,
1509
+ sequence_parallel=config.sequence_parallel,
1510
+ zero_centered_gamma=config.layernorm_zero_centered_gamma,
1511
+ **_get_extra_te_kwargs(config),
1512
+ )
1513
+ else:
1514
+ raise Exception("Only LayerNorm and RMSNorm are currently supported")
1515
+
1516
+ return instance
1517
+
1518
+
1519
+ def build_module(
1520
+ spec_or_module: Union[ModuleSpec, type], *args, **kwargs
1521
+ ) -> MemEstimator:
1522
+ """replace module with MemEstimators"""
1523
+ if isinstance(spec_or_module, types.FunctionType):
1524
+ return globals()[spec_or_module.__name__]
1525
+
1526
+ if isinstance(spec_or_module, ModuleSpec) and isinstance(
1527
+ spec_or_module.module, types.FunctionType
1528
+ ):
1529
+ assert False
1530
+ return spec_or_module.module
1531
+
1532
+ if isinstance(spec_or_module, type):
1533
+ module = spec_or_module
1534
+ elif hasattr(spec_or_module, "module") and isinstance(spec_or_module.module, type):
1535
+ module = spec_or_module.module
1536
+ else:
1537
+ module = import_module(spec_or_module.module)
1538
+
1539
+ if isinstance(module, types.FunctionType):
1540
+ assert False
1541
+ return module
1542
+
1543
+ if hasattr(spec_or_module, "submodules") and spec_or_module.submodules is not None:
1544
+ kwargs["submodules"] = spec_or_module.submodules
1545
+
1546
+ try:
1547
+ module = globals()[module.__name__]
1548
+ return module(
1549
+ *args,
1550
+ **spec_or_module.params if hasattr(spec_or_module, "params") else {},
1551
+ **kwargs,
1552
+ )
1553
+ except Exception as e:
1554
+ # import ipdb
1555
+
1556
+ # ipdb.set_trace()
1557
+ # improve the error message since we hide the module name in the line above
1558
+ import sys
1559
+
1560
+ raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
1561
+ sys.exc_info()[2]
1562
+ )
1563
+
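+
+ # Minimal usage sketch (hypothetical spec): build_module resolves the class named in
+ # a ModuleSpec against the same-named estimator classes defined in this file, e.g.
+ #
+ #   spec = ModuleSpec(module=MLP, submodules=mlp_submodules)
+ #   estimator = build_module(spec, config=cfg)   # returns the MLP estimator above
+ #   estimator.num_parameter()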
1564
+
1565
+ from megatron.core.transformer.transformer_block import (
1566
+ TransformerBlockSubmodules,
1567
+ BaseTransformerLayer,
1568
+ LayerNormImpl,
1569
+ )
1570
+
1571
+
1572
+ def _get_block_submodules(
1573
+ config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec]
1574
+ ) -> TransformerBlockSubmodules:
1575
+ """
1576
+ Retrieve or construct TransformerBlockSubmodules based on the provided specification.
1577
+
1578
+ Args:
1579
+ config (TransformerConfig): Configuration object for the transformer model.
1580
+ spec (Union[TransformerBlockSubmodules, ModuleSpec]): Specification for the
1581
+ transformer block submodules. Can be either a TransformerBlockSubmodules
1582
+ instance or a ModuleSpec.
1583
+
1584
+ Returns:
1585
+ TransformerBlockSubmodules: The submodules for the transformer block.
1586
+ """
1587
+
1588
+ # Transformer block submodules.
1589
+ if isinstance(spec, TransformerBlockSubmodules):
1590
+ return spec
1591
+
1592
+ # ModuleSpec here is generally assumed to be for a transformer layer that
1593
+ # is implemented in `transformer_layer.py` or if it subclasses
1594
+ # `BaseTransformerLayer` from the `transformer_layer.py` file.
1595
+ elif isinstance(spec, ModuleSpec):
1596
+ if issubclass(spec.module, TransformerBlock):
1597
+ return spec.submodules
1598
+ elif issubclass(spec.module, BaseTransformerLayer):
1599
+ num_layers = get_num_layers_to_build(config)
1600
+ return TransformerBlockSubmodules(
1601
+ layer_specs=[spec] * num_layers, layer_norm=LayerNormImpl
1602
+ )
1603
+ else:
1604
+ raise Exception(f"specialize for {spec.module.__name__}.")
1605
+ else:
1606
+ raise Exception(f"specialize for {type(spec).__name__}.")
1607
+
1608
+
1609
+ def get_num_layers_to_build(config: TransformerConfig) -> int:
1610
+ """
1611
+ Determine the number of transformer layers to build for the current pipeline stage.
1612
+ Args:
1613
+ config (TransformerConfig): Configuration object containing transformer model parameters.
1614
+
1615
+ Returns:
1616
+ int: The number of layers to be built for the current pipeline stage.
1617
+ """
1618
+ if (
1619
+ config.num_layers_in_first_pipeline_stage is not None
1620
+ or config.num_layers_in_last_pipeline_stage is not None
1621
+ ):
1622
+
1623
+ assert not (
1624
+ config.account_for_embedding_in_pipeline_split
1625
+ or config.account_for_loss_in_pipeline_split
1626
+ ), "Does not support standalone embedding stage and standalone loss stage with uneven pp"
1628
+ # Number of layers to distribute over rest of pipeline stages
1629
+ layers_to_distribute = config.num_layers
1630
+ # Number of pipeline stages left for distributing transformer layers
1631
+ pipeline_stages_left = get_pipeline_model_parallel_world_size()
1632
+
1633
+ # If the uneven first (last) pipeline stage is enabled, remove the specified number
1634
+ # of layers to calculate the number of layers on each middle pipeline stage.
1635
+ if config.num_layers_in_first_pipeline_stage is not None:
1636
+ layers_to_distribute -= config.num_layers_in_first_pipeline_stage
1637
+ pipeline_stages_left -= 1
1638
+
1639
+ if config.num_layers_in_last_pipeline_stage is not None:
1640
+ layers_to_distribute -= config.num_layers_in_last_pipeline_stage
1641
+ pipeline_stages_left -= 1
1642
+
1643
+ assert (
1644
+ layers_to_distribute % pipeline_stages_left == 0
1645
+ ), "With uneven pipelining the leftover layers must be divisible by the leftover stages"
1646
+ num_layers_per_pipeline_rank = layers_to_distribute // pipeline_stages_left
1647
+
1648
+ # If the uneven first (last) pipeline stage is enabled, return the specified number
1649
+ # of layers for all virtual pipeline parallel stages within the first (last) pipeline
1650
+ # parallel stage.
1651
+ if (
1652
+ is_pipeline_first_stage(ignore_virtual=True)
1653
+ and config.num_layers_in_first_pipeline_stage is not None
1654
+ ):
1655
+ num_layers_per_pipeline_rank = config.num_layers_in_first_pipeline_stage
1656
+
1657
+ if (
1658
+ is_pipeline_last_stage(ignore_virtual=True)
1659
+ and config.num_layers_in_last_pipeline_stage is not None
1660
+ ):
1661
+ num_layers_per_pipeline_rank = config.num_layers_in_last_pipeline_stage
1662
+ else:
1663
+ # Include the embedding layer and loss layer into pipeline parallelism partition
1664
+ num_layers = config.num_layers
1665
+ if config.account_for_embedding_in_pipeline_split:
1666
+ num_layers += 1
1667
+
1668
+ if config.account_for_loss_in_pipeline_split:
1669
+ num_layers += 1
1670
+
1671
+ assert (
1672
+ num_layers % config.pipeline_model_parallel_size == 0
1673
+ ), "num_layers should be divisible by pipeline_model_parallel_size"
1674
+ num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size
1675
+
1676
+ # if get_virtual_pipeline_model_parallel_world_size() is not None:
1677
+ # # Interleaved pipeline parallelism:
1678
+ # # Number of layers in each model chunk is the number of layers in the stage,
1679
+ # # divided by the number of model chunks in a stage.
1680
+ # # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
1681
+ # # layers to stages like (each list is a model chunk):
1682
+ # # Stage 0: [0] [2] [4] [6]
1683
+ # # Stage 1: [1] [3] [5] [7]
1684
+ # # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
1685
+ # # layers to stages like (each list is a model chunk):
1686
+ # # Stage 0: [0, 1] [4, 5]
1687
+ # # Stage 1: [2, 3] [6, 7]
1688
+ # vp_size = get_virtual_pipeline_model_parallel_world_size()
1689
+
1690
+ # assert (
1691
+ # num_layers_per_pipeline_rank % vp_size == 0
1692
+ # ), "num_layers_per_pipeline_rank should be divisible by vp_size"
1693
+ # num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
1694
+
1695
+ # num_layers_to_build = num_layers_per_virtual_rank
1696
+
1697
+ # else:
1698
+ # # Non-interleaved pipeline parallelism:
1699
+ # # Each stage gets a contiguous set of layers.
1700
+ # num_layers_to_build = num_layers_per_pipeline_rank
1701
+ num_layers_to_build = num_layers_per_pipeline_rank
1702
+ # The embedding (or loss) layer cannot function as a standalone transformer layer
1703
+ # Reduce the number of layers to construct by 1 on the first (or last) stage if the
1704
+ # embedding (or loss) layer is included in the pipeline parallelism partition and placement.
1705
+ if is_pipeline_first_stage() and config.account_for_embedding_in_pipeline_split:
1706
+ num_layers_to_build -= 1
1707
+ assert (
1708
+ num_layers_to_build >= 0
1709
+ ), "Not enough layers in the first virtual pipeline stage"
1710
+
1711
+ if is_pipeline_last_stage() and config.account_for_loss_in_pipeline_split:
1712
+ num_layers_to_build -= 1
1713
+ assert (
1714
+ num_layers_to_build >= 0
1715
+ ), "Not enough layers in the last virtual pipeline stage"
1716
+
1717
+ return num_layers_to_build
1718
+
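+
+ # Worked example (hypothetical): num_layers = 60, pipeline_model_parallel_size = 4,
+ # num_layers_in_first_pipeline_stage = 12 -> the remaining 48 layers are spread over
+ # 3 stages (16 each) while the first stage builds 12; without uneven staging every
+ # stage simply builds 60 // 4 = 15 layers.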
1719
+
1720
+ def get_transformer_layer_offset(config: TransformerConfig):
1721
+ """Get the index offset of current pipeline stage, given the level of pipelining."""
1722
+ pipeline_rank = get_pipeline_model_parallel_rank()
1723
+ # if not is_inside_encoder():
1724
+ if True:
1725
+ pp_decoder_start = 0
1726
+ if pp_decoder_start is not None:
1727
+ pipeline_rank = pipeline_rank - pp_decoder_start
1728
+
1729
+ if config.pipeline_model_parallel_size > 1:
1730
+
1731
+ if (
1732
+ config.num_layers_in_first_pipeline_stage is not None
1733
+ or config.num_layers_in_last_pipeline_stage is not None
1734
+ ):
1735
+ # Calculate number of pipeline stages to distribute the remaining Transformer
1736
+ # layers after deducting the Transformer layers in the first or the last stages
1737
+ middle_pipeline_stages = config.pipeline_model_parallel_size
1738
+ middle_pipeline_stages -= sum(
1739
+ [
1740
+ 1 if x is not None else 0
1741
+ for x in (
1742
+ config.num_layers_in_first_pipeline_stage,
1743
+ config.num_layers_in_last_pipeline_stage,
1744
+ )
1745
+ ]
1746
+ )
1747
+
1748
+ # Calculate layers to distribute in each pipeline stage. If the
1749
+ # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage
1750
+ # are not set, we will not enable uneven pipeline. All layers will be treated
1751
+ # as middle layers.
1752
+ num_layers_in_first_pipeline_stage = (
1753
+ 0
1754
+ if config.num_layers_in_first_pipeline_stage is None
1755
+ else config.num_layers_in_first_pipeline_stage
1756
+ )
1757
+ num_layers_in_last_pipeline_stage = (
1758
+ 0
1759
+ if config.num_layers_in_last_pipeline_stage is None
1760
+ else config.num_layers_in_last_pipeline_stage
1761
+ )
1762
+
1763
+ middle_num_layers = (
1764
+ config.num_layers
1765
+ - num_layers_in_first_pipeline_stage
1766
+ - num_layers_in_last_pipeline_stage
1767
+ )
1768
+
1769
+ if middle_pipeline_stages > 0:
1770
+ num_layers_per_pipeline_rank = (
1771
+ middle_num_layers // middle_pipeline_stages
1772
+ )
1773
+ else:
1774
+ num_layers_per_pipeline_rank = 0
1775
+
1776
+ middle_pipeline_rank = (
1777
+ pipeline_rank
1778
+ if config.num_layers_in_first_pipeline_stage is None
1779
+ else pipeline_rank - 1
1780
+ )
1781
+
1782
+ if pipeline_rank == 0:
1783
+ offset = 0
1784
+ else:
1785
+ offset = (
1786
+ middle_pipeline_rank * num_layers_per_pipeline_rank
1787
+ ) + num_layers_in_first_pipeline_stage
1788
+ else:
1789
+ num_layers = config.num_layers
1790
+
1791
+ # Increase the number of layers by one if we include the embedding (loss)
1792
+ # layer into pipeline parallelism partition and placement
1793
+ if config.account_for_embedding_in_pipeline_split:
1794
+ num_layers += 1
1795
+
1796
+ if config.account_for_loss_in_pipeline_split:
1797
+ num_layers += 1
1798
+
1799
+ num_layers_per_pipeline_rank = (
1800
+ num_layers // config.pipeline_model_parallel_size
1801
+ )
1802
+
1803
+ offset = pipeline_rank * num_layers_per_pipeline_rank
1804
+
1805
+ # Reduce the offset of embedding layer from the total layer number
1806
+ if (
1807
+ config.account_for_embedding_in_pipeline_split
1808
+ and not is_pipeline_first_stage()
1809
+ ):
1810
+ offset -= 1
1811
+ else:
1812
+ offset = 0
1813
+ return offset
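+
+ # Worked example (hypothetical): with pipeline_model_parallel_size = 4, num_layers = 60
+ # and num_layers_in_first_pipeline_stage = 12, the middle stages hold 48 // 3 = 16
+ # layers each, so the per-rank offsets are 0, 12, 28 and 44; without uneven staging
+ # the offset is simply pipeline_rank * (60 // 4).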
webui/index.html ADDED
@@ -0,0 +1,163 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Megatron Memory Estimator</title>
7
+ <link rel="stylesheet" href="style.css">
8
+ </head>
9
+ <body>
10
+ <div class="container">
11
+ <h1>Megatron Memory Estimator</h1>
12
+ <div class="disclaimer-banner">
13
+ Note: This estimator only measures the GPU memory directly managed by PyTorch when running Megatron. It does not include extra consumption from NCCL communication buffers, kernel fusion, overlap optimizations, CUDA Graphs, etc. Please use the "Overhead per GPU" option below to account for these additional costs.
14
+ </div>
15
+
16
+ <div class="main-layout">
17
+ <div class="top-section">
18
+ <div class="config-column">
19
+ <form id="config-form">
20
+ <h2>Configuration</h2>
21
+ <div class="form-group">
22
+ <label for="model-select">Select a Local Config:</label>
23
+ <select id="model-select" name="model">
24
+ <option value="">Loading...</option>
25
+ </select>
26
+ </div>
27
+
28
+ <!-- All settings are now in one block -->
29
+ <div class="form-row">
30
+ <div class="form-group">
31
+ <label for="num-gpus">Total GPUs:</label>
32
+ <input type="number" id="num-gpus" name="num_gpus" value="8" step="8" min="8">
33
+ </div>
34
+ <div class="form-group">
35
+ <label for="mbs">Micro Batch Size:</label>
36
+ <input type="number" id="mbs" name="mbs" value="1" min="1">
37
+ </div>
38
+ <div class="form-group">
39
+ <label for="seq-len">SeqLen:</label>
40
+ <input type="number" id="seq-len" name="seq-len" value="4096" min="1">
41
+ </div>
42
+ </div>
43
+
44
+ <div class="form-group">
45
+ <input type="checkbox" id="use-distributed-optimizer" name="use_distributed_optimizer" checked>
46
+ <label for="use-distributed-optimizer" class="inline-label">Use Distributed Optimizer</label>
47
+ </div>
48
+
49
+ <div class="form-row">
50
+ <div class="form-group">
51
+ <label for="recompute-granularity">Recomputation:</label>
52
+ <select id="recompute-granularity" name="recompute_granularity">
53
+ <option value="none">None</option>
54
+ <option value="selective">Selective</option>
55
+ <option value="full">Full</option>
56
+ </select>
57
+ </div>
58
+ <div class="form-group recompute-options" style="display: none;">
59
+ <label for="recompute-method">Method:</label>
60
+ <select id="recompute-method" name="recompute_method">
61
+ <option value="uniform">Uniform</option>
62
+ <option value="block">Block</option>
63
+ </select>
64
+ </div>
65
+ <div class="form-group recompute-options" style="display: none;">
66
+ <label for="recompute-num-layers">Layers:</label>
67
+ <input type="number" id="recompute-num-layers" name="recompute_num_layers" value="1" min="1">
68
+ </div>
69
+ </div>
70
+
71
+ <div class="form-row">
72
+ <div class="form-group">
73
+ <label for="tp">TP:</label>
74
+ <select id="tp" name="tp"></select>
75
+ </div>
76
+ <div class="form-group">
77
+ <label for="pp">PP:</label>
78
+ <input type="number" id="pp" name="pp" value="1" min="1">
79
+ </div>
80
+ <div class="form-group">
81
+ <label for="ep">EP:</label>
82
+ <select id="ep" name="ep"></select>
83
+ </div>
84
+ <div class="form-group">
85
+ <label for="cp">CP:</label>
86
+ <select id="cp" name="cp"></select>
87
+ </div>
88
+ </div>
89
+ <div class="form-row">
90
+ <div class="form-group">
91
+ <label for="vpp">VPP:</label>
92
+ <input type="number" id="vpp" name="vpp" placeholder="None" min="1">
93
+ </div>
94
+ <div class="form-group">
95
+ <label for="etp">ETP:</label>
96
+ <input type="number" id="etp" name="etp" placeholder="None" min="1">
97
+ </div>
98
+ </div>
99
+ <div class="form-row">
100
+ <div class="form-group">
101
+ <label for="num_layers_in_first_pipeline_stage">First Stage Layers:</label>
102
+ <input type="number" id="num_layers_in_first_pipeline_stage" name="num_layers_in_first_pipeline_stage" placeholder="None" min="0">
103
+ </div>
104
+ <div class="form-group">
105
+ <label for="num_layers_in_last_pipeline_stage">Last Stage Layers:</label>
106
+ <input type="number" id="num_layers_in_last_pipeline_stage" name="num_layers_in_last_pipeline_stage" placeholder="None" min="0">
107
+ </div>
108
+ </div>
109
+ <div class="form-row">
110
+ <div class="form-group">
111
+ <label for="overhead">Overhead per GPU:</label>
112
+ <select id="overhead" name="overhead">
113
+ <option value="5">5GB</option>
114
+ <option value="10" selected>10GB</option>
115
+ </select>
116
+ </div>
117
+ </div>
118
+
119
+ <div id="validation-message" class="error-message" style="display: none;"></div>
120
+ <div class="button-container">
121
+ <button type="submit">Estimate</button>
122
+ </div>
123
+ </form>
124
+ </div>
125
+
126
+ <div class="output-column">
127
+ <div class="config-editor-wrapper">
128
+ <h2>Model Config (Editable)</h2>
129
+ <textarea id="config-editor" rows="20"></textarea>
130
+ </div>
131
+ </div>
132
+ </div>
133
+
134
+ <div class="bottom-section">
135
+ <div id="output-container">
136
+ <div id="loading" style="display: none;">Calculating...</div>
137
+ <div id="history-wrapper">
138
+ <h3>History</h3>
139
+ <table id="history-table">
140
+ <thead>
141
+ <tr>
142
+ <th>Model</th>
143
+ <th>Weight Optimizer (GB)</th>
144
+ <th>Activation (GB)</th>
145
+ <th>Total (GB/GPU)</th>
146
+ <th>Actions</th>
147
+ </tr>
148
+ </thead>
149
+ <tbody>
150
+ </tbody>
151
+ </table>
152
+ <button id="clear-history" style="margin-top: 1em;">Clear History</button>
153
+ </div>
154
+ </div>
155
+ </div>
156
+ </div>
157
+ </div>
158
+ <script src="script.js"></script>
159
+ <footer class="footer">
160
+ <p>&copy; 2025 <a href="https://github.com/ISEEKYAN" target="_blank">ISEEKYAN</a>. Developed at NVIDIA.</p>
161
+ </footer>
162
+ </body>
163
+ </html>
webui/main.py ADDED
@@ -0,0 +1,211 @@
1
+ import os
2
+ import glob
3
+ from fastapi import FastAPI, Body
4
+ from fastapi.staticfiles import StaticFiles
5
+ from fastapi.responses import FileResponse
6
+ import requests
7
+ from pydantic import BaseModel, field_validator
8
+ from typing import Optional
9
+ from mbridge import AutoBridge
10
+ from estimate import estimate_from_config
11
+ from megatron.core import parallel_state as mpu
12
+ import argparse
13
+ import json
14
+ import tempfile
15
+
16
+ # The directory of the current script (main.py)
17
+ WEBUI_DIR = os.path.dirname(os.path.abspath(__file__))
18
+
19
+ app = FastAPI()
20
+
21
+ # Mount static files from the webui directory
22
+ app.mount("/static", StaticFiles(directory=WEBUI_DIR), name="static")
23
+
24
+
25
+ @app.get("/")
26
+ async def read_index():
27
+ return FileResponse(os.path.join(WEBUI_DIR, 'index.html'))
28
+
29
+ @app.get("/style.css")
30
+ async def read_css():
31
+ return FileResponse(os.path.join(WEBUI_DIR, 'style.css'))
32
+
33
+ @app.get("/script.js")
34
+ async def read_js():
35
+ return FileResponse(os.path.join(WEBUI_DIR, 'script.js'))
36
+
37
+
38
+ SUPPORTED_MODELS = [
39
+ "Qwen/Qwen3-235B-A22B",
40
+ "Qwen/Qwen3-30B-A3B",
41
+ "Qwen/Qwen3-32B",
42
+ "Qwen/Qwen3-14B",
43
+ "Qwen/Qwen3-8B",
44
+ "Qwen/Qwen2.5-7B",
45
+ "Qwen/Qwen2.5-14B",
46
+ "Qwen/Qwen2.5-32B",
47
+ "Qwen/Qwen2.5-72B",
48
+ "moonshotai/Moonlight-16B-A3B",
49
+ "moonshotai/Kimi-K2-Instruct",
50
+ "deepseek-ai/DeepSeek-V3",
51
+ ]
52
+
53
+
54
+ @app.get("/local-hf-configs")
55
+ async def get_supported_models():
56
+ """Return the list of HF model identifiers supported by the UI."""
57
+ return SUPPORTED_MODELS
58
+
59
+ @app.get("/get-megatron-config/{model_path:path}")
60
+ async def get_remote_hf_config(model_path: str):
61
+ """Fetch the HuggingFace config.json for the given model id."""
62
+ url = f"https://huggingface.co/{model_path}/raw/main/config.json"
63
+ try:
64
+ resp = requests.get(url, timeout=10)
65
+ resp.raise_for_status()
66
+ return resp.json()
67
+ except Exception as e:
68
+ return {"error": f"Failed to fetch config from {url}: {str(e)}"}
69
+
70
+
71
+ class MBridgeEstimateConfig(BaseModel):
72
+ hf_model_path: str
73
+ custom_hf_config: Optional[dict] = None # Renamed for clarity
74
+
75
+ # Hardware & Training
76
+ num_gpus: int = 8
77
+ mbs: int = 1
78
+ seq_len: int = 4096
79
+ use_distributed_optimizer: bool = True
80
+ # Recompute settings are now part of the main config
81
+ recompute_granularity: str = "selective"
82
+ recompute_method: str = "uniform"
83
+ recompute_num_layers: Optional[int] = 1
84
+
85
+ # Parallelism
86
+ tp: int = 1
87
+ pp: int = 1
88
+ ep: int = 1
89
+ cp: int = 1
90
+ vpp: Optional[int] = None
91
+ etp: Optional[int] = None
92
+
93
+ # Pipeline stage layer counts
94
+ num_layers_in_first_pipeline_stage: Optional[int] = None
95
+ num_layers_in_last_pipeline_stage: Optional[int] = None
96
+
97
+ @field_validator('num_gpus')
98
+ def num_gpus_must_be_multiple_of_8(cls, v):
99
+ if v <= 0 or v % 8 != 0:
100
+ raise ValueError('must be a positive multiple of 8')
101
+ return v
102
+
103
+ def patch_parallel_states(config: MBridgeEstimateConfig):
104
+ from mbridge.core.parallel_states import ParallelStates
105
+ ParallelStates.get_default_parallel_states = lambda: ParallelStates(
106
+ tp_size=config.tp,
107
+ pp_size=config.pp,
108
+ ep_size=config.ep,
109
+ cp_size=config.cp,
110
+ vpp_size=config.vpp,
111
+ etp_size=config.etp,
112
+ )
113
+
114
+ @app.post("/estimate_with_mbridge")
115
+ async def estimate_with_mbridge(config: MBridgeEstimateConfig):
116
+ # Validate Inputs
117
+ if config.num_gpus <= 0 or config.num_gpus % 8 != 0:
118
+ return {"error": "Total number of GPUs must be a positive multiple of 8."}
119
+
120
+ parallel_product = config.tp * config.pp * config.cp
121
+ if parallel_product == 0: # Avoid division by zero
122
+ return {"error": "Parallelism dimensions (TP, PP, CP) cannot be zero."}
123
+
124
+ if config.num_gpus % parallel_product != 0:
125
+ return {"error": f"Number of GPUs ({config.num_gpus}) must be divisible by the product of TP*PP*CP ({parallel_product})."}
126
+
127
+ patch_parallel_states(config)
128
+
129
+ # If the path is just a filename, assume it's in our local model-configs dir
+ hf_model_path = config.hf_model_path
+ # The custom config pasted in the UI is a Hugging Face config, not a Megatron config,
+ # so it is written to a temporary file and loaded through the bridge from there.
+ if config.custom_hf_config:
+ try:
+ # Create a temporary file to save the custom HF config
+ with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".json", dir=os.path.join(WEBUI_DIR, 'model-configs')) as tmp:
+ json.dump(config.custom_hf_config, tmp)
+ tmp_path = tmp.name
+
+ # Load the bridge from the temporary config file
+ from transformers import AutoConfig
+ AutoConfig.trust_remote_code = True
+ bridge = AutoBridge.from_pretrained(tmp_path)
+ tf_config = bridge.config
+ hf_config = bridge.hf_config
+
+ finally:
+ # Ensure the temporary file is deleted
+ if 'tmp_path' in locals() and os.path.exists(tmp_path):
+ os.remove(tmp_path)
+ else:
+ # If no custom config, load from the original path
+ if not os.path.isabs(hf_model_path) and not hf_model_path.startswith(('http', './', '../')):
+ hf_model_path = os.path.join(WEBUI_DIR, 'model-configs', hf_model_path)
+ bridge = AutoBridge.from_pretrained(hf_model_path)
+ tf_config = bridge.config
+ hf_config = bridge.hf_config
+
+ # --- Configuration Unification ---
+ # Update tf_config with values from the form so that tf_config is the single source of truth.
+ tf_config.tensor_model_parallel_size = config.tp
+ tf_config.pipeline_model_parallel_size = config.pp
+ tf_config.expert_model_parallel_size = config.ep
+ tf_config.context_parallel_size = config.cp
+ tf_config.recompute_granularity = config.recompute_granularity
+ tf_config.recompute_method = config.recompute_method
+ tf_config.recompute_num_layers = config.recompute_num_layers
+ tf_config.num_layers_per_virtual_pipeline_stage = config.vpp if config.vpp and config.vpp > 1 else None
+
+ if config.num_layers_in_first_pipeline_stage is not None:
+ tf_config.num_layers_in_first_pipeline_stage = config.num_layers_in_first_pipeline_stage
+ if config.num_layers_in_last_pipeline_stage is not None:
+ tf_config.num_layers_in_last_pipeline_stage = config.num_layers_in_last_pipeline_stage
+ # print(tf_config)
+
+ # Create a minimal 'args' namespace with parameters not present in TransformerConfig
+ args = argparse.Namespace()
+ args.micro_batch_size = config.mbs
+ args.seq_length = config.seq_len
+ args.use_distributed_optimizer = config.use_distributed_optimizer
+ args.data_parallel_size = config.num_gpus // parallel_product
+ args.expert_tensor_parallel_size = config.etp if config.etp else 1
+
+ # These are required by the estimator but can be derived or defaulted
+ args.transformer_impl = "transformer_engine"
+ args.fp8 = False
+ args.num_experts = getattr(tf_config, 'num_moe_experts', 1) # Needed for layer spec
+ args.moe_grouped_gemm = True # Default
+ args.qk_layernorm = tf_config.qk_layernorm
+ args.multi_latent_attention = "deepseek" in getattr(hf_config, "model_type", "")
+ args.padded_vocab_size = getattr(hf_config, "vocab_size")
+ args.max_position_embeddings = getattr(hf_config, "max_position_embeddings")
+ args.tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)
+
+ # estimate_from_config returns one report per pipeline-parallel (PP) rank
+ raw_reports_list = estimate_from_config(tf_config, args)
+
+ # The reports from estimate.py are already in GB, so no unit conversion is needed;
+ # only the verbose 'details' field is stripped for the main display table.
+ processed_reports = []
+ for report in raw_reports_list:
+ # Copy the report and drop the 'details' key
+ processed_report = report.copy()
+ processed_report.pop('details', None)
+ processed_reports.append(processed_report)
+
+ return {
+ "processed_report": processed_reports,
+ "raw_report": raw_reports_list
+ }
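For reference, the endpoint wired up above can also be exercised without the web UI. The sketch below is illustrative only: the field names mirror the payload built in webui/script.js, the response keys match the return value above, and the local URL/port as well as the requests dependency are assumptions rather than part of this commit.

# Minimal client sketch for the /estimate_with_mbridge endpoint (illustrative,
# not part of this commit). Field names mirror webui/script.js; the URL/port
# and the requests dependency are assumptions.
import requests

payload = {
    "hf_model_path": "Qwen/Qwen3-235B-A22B",   # default model used by the UI
    "custom_hf_config": None,                  # or a dict holding a pasted HF config
    "num_gpus": 8,
    "mbs": 1,
    "seq_len": 4096,
    "use_distributed_optimizer": True,
    "recompute_granularity": "selective",
    "recompute_method": "uniform",
    "recompute_num_layers": 1,
    "tp": 2, "pp": 2, "ep": 2, "cp": 1,        # TP*PP*CP*EP = 8 <= num_gpus
    "vpp": None, "etp": None,
    "num_layers_in_first_pipeline_stage": None,
    "num_layers_in_last_pipeline_stage": None,
    "overhead": 10,
}

resp = requests.post("http://localhost:7860/estimate_with_mbridge", json=payload, timeout=600)
resp.raise_for_status()
result = resp.json()
for row in result["processed_report"]:         # one summary row per PP rank
    print(row)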
webui/model-configs/qwen3-14b.json ADDED
@@ -0,0 +1,30 @@
+ {
+ "architectures": [
+ "Qwen3ForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 151643,
+ "eos_token_id": 151645,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 5120,
+ "initializer_range": 0.02,
+ "intermediate_size": 17408,
+ "max_position_embeddings": 40960,
+ "max_window_layers": 40,
+ "model_type": "qwen3",
+ "num_attention_heads": 40,
+ "num_hidden_layers": 40,
+ "num_key_value_heads": 8,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": null,
+ "rope_theta": 1000000,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.51.0",
+ "use_cache": true,
+ "use_sliding_window": false,
+ "vocab_size": 151936
+ }
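A config like the one above is enough for a rough, back-of-the-envelope parameter count. The sketch below is illustrative only: it ignores RMSNorm weights and biases, assumes untied embeddings (tie_word_embeddings is false here), and the file path is an assumption relative to the repository root.

# Rough parameter count from a local HF config (illustrative sketch only).
import json

with open("webui/model-configs/qwen3-14b.json") as f:   # path is an assumption
    cfg = json.load(f)

h = cfg["hidden_size"]
n_layers = cfg["num_hidden_layers"]
ffn = cfg["intermediate_size"]
q_heads = cfg["num_attention_heads"]
kv_heads = cfg["num_key_value_heads"]
head_dim = cfg["head_dim"]
vocab = cfg["vocab_size"]

attn = h * q_heads * head_dim + 2 * h * kv_heads * head_dim + q_heads * head_dim * h  # Q, K+V, O projections
mlp = 3 * h * ffn                                                                      # gate, up, down projections
embed = 2 * vocab * h                                                                  # input embeddings + LM head (untied)
total = n_layers * (attn + mlp) + embed
print(f"~{total / 1e9:.1f}B parameters (norms and biases ignored)")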
webui/model-configs/qwen3-235b-a22b.json ADDED
@@ -0,0 +1,38 @@
+ {
+ "architectures": [
+ "Qwen3MoeForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 151643,
+ "decoder_sparse_step": 1,
+ "eos_token_id": 151645,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 12288,
+ "max_position_embeddings": 40960,
+ "max_window_layers": 94,
+ "mlp_only_layers": [],
+ "model_type": "qwen3_moe",
+ "moe_intermediate_size": 1536,
+ "norm_topk_prob": true,
+ "num_attention_heads": 64,
+ "num_experts": 128,
+ "num_experts_per_tok": 8,
+ "num_hidden_layers": 94,
+ "num_key_value_heads": 4,
+ "output_router_logits": false,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": null,
+ "rope_theta": 1000000.0,
+ "router_aux_loss_coef": 0.001,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.51.0",
+ "use_cache": true,
+ "use_sliding_window": false,
+ "vocab_size": 151936
+ }
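For MoE configs such as this one, the memory picture is dominated by the gap between the expert weights that stay resident on the GPUs and the few experts actually used per token, which is also why expert parallelism (EP) matters for the estimate. The sketch below is illustrative only and counts expert MLP weights alone; attention, embedding, and router weights are left out.

# Resident vs. per-token-active expert weights for a MoE config (illustrative sketch).
import json

with open("webui/model-configs/qwen3-235b-a22b.json") as f:   # path is an assumption
    cfg = json.load(f)

h = cfg["hidden_size"]
n_layers = cfg["num_hidden_layers"]
moe_ffn = cfg["moe_intermediate_size"]
n_experts = cfg["num_experts"]
top_k = cfg["num_experts_per_tok"]

per_expert = 3 * h * moe_ffn                    # gate, up, down projections of one expert
resident = n_layers * n_experts * per_expert    # sharded across EP ranks in practice
active = n_layers * top_k * per_expert          # experts actually hit per token
print(f"expert weights resident: ~{resident / 1e9:.0f}B, active per token: ~{active / 1e9:.1f}B")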
webui/model-configs/qwen3-30b-a3b.json ADDED
@@ -0,0 +1,38 @@
+ {
+ "architectures": [
+ "Qwen3MoeForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 151643,
+ "decoder_sparse_step": 1,
+ "eos_token_id": 151645,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 2048,
+ "initializer_range": 0.02,
+ "intermediate_size": 6144,
+ "max_position_embeddings": 40960,
+ "max_window_layers": 48,
+ "mlp_only_layers": [],
+ "model_type": "qwen3_moe",
+ "moe_intermediate_size": 768,
+ "norm_topk_prob": true,
+ "num_attention_heads": 32,
+ "num_experts": 128,
+ "num_experts_per_tok": 8,
+ "num_hidden_layers": 48,
+ "num_key_value_heads": 4,
+ "output_router_logits": false,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": null,
+ "rope_theta": 1000000.0,
+ "router_aux_loss_coef": 0.001,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.51.0",
+ "use_cache": true,
+ "use_sliding_window": false,
+ "vocab_size": 151936
+ }
webui/model-configs/qwen3-32b.json ADDED
@@ -0,0 +1,30 @@
+ {
+ "architectures": [
+ "Qwen3ForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 151643,
+ "eos_token_id": 151645,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 5120,
+ "initializer_range": 0.02,
+ "intermediate_size": 25600,
+ "max_position_embeddings": 40960,
+ "max_window_layers": 64,
+ "model_type": "qwen3",
+ "num_attention_heads": 64,
+ "num_hidden_layers": 64,
+ "num_key_value_heads": 8,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": null,
+ "rope_theta": 1000000,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.51.0",
+ "use_cache": true,
+ "use_sliding_window": false,
+ "vocab_size": 151936
+ }
webui/model-configs/qwen3-8b.json ADDED
@@ -0,0 +1,30 @@
+ {
+ "architectures": [
+ "Qwen3ForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 151643,
+ "eos_token_id": 151645,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 12288,
+ "max_position_embeddings": 40960,
+ "max_window_layers": 36,
+ "model_type": "qwen3",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 36,
+ "num_key_value_heads": 8,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": null,
+ "rope_theta": 1000000,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.51.0",
+ "use_cache": true,
+ "use_sliding_window": false,
+ "vocab_size": 151936
+ }
webui/requirements.txt ADDED
@@ -0,0 +1,3 @@
+ fastapi
+ uvicorn[standard]
+ mbridge
webui/script.js ADDED
@@ -0,0 +1,715 @@
1
+ document.addEventListener('DOMContentLoaded', () => {
2
+ // Initial UI setup
3
+ loadLocalConfigs();
4
+ updateHistoryView();
5
+ setupEventListeners();
6
+ updateParallelismOptions();
7
+ validateParallelismLive();
8
+ toggleEpBasedOnConfig(); // Disable EP initially
9
+ });
10
+
11
+ // Utility: convert ANSI color codes (red 31, green 32) to HTML spans for display
12
+ function ansiToHtml(str) {
13
+ if (!str) return '';
14
+ // Replace known ANSI codes
15
+ return str
16
+ .replace(/\u001b\[31m/g, '<span class="ansi-red">')
17
+ .replace(/\u001b\[32m/g, '<span class="ansi-green">')
18
+ .replace(/\u001b\[33m/g, '<span class="ansi-yellow">')
19
+ .replace(/\u001b\[34m/g, '<span class="ansi-blue">')
20
+ .replace(/\u001b\[35m/g, '<span class="ansi-magenta">')
21
+ .replace(/\u001b\[36m/g, '<span class="ansi-cyan">')
22
+ .replace(/\u001b\[0m/g, '</span>');
23
+ }
24
+
25
+ function setupEventListeners() {
26
+ document.getElementById('config-form').addEventListener('submit', (e) => {
27
+ e.preventDefault();
28
+ submitForm();
29
+ });
30
+
31
+ document.getElementById('model-select').addEventListener('change', loadSelectedModelConfig);
32
+
33
+ document.getElementById('recompute-granularity').addEventListener('change', (e) => {
34
+ const recomputeOptions = document.querySelectorAll('.recompute-options');
35
+ recomputeOptions.forEach(opt => {
36
+ opt.style.display = e.target.value === 'full' ? 'block' : 'none';
37
+ });
38
+ });
39
+
40
+ const liveValidationInputs = ['num-gpus', 'tp', 'pp', 'ep', 'cp', 'etp', 'config-editor'];
41
+ liveValidationInputs.forEach(id => {
42
+ const input = document.getElementById(id);
43
+ if(input) {
44
+ input.addEventListener('change', validateParallelismLive);
45
+ if (id === 'num-gpus') {
46
+ input.addEventListener('change', updateParallelismOptions);
47
+ }
48
+ }
49
+ });
50
+
51
+ document.getElementById('config-editor').addEventListener('input', toggleEpBasedOnConfig);
52
+ document.getElementById('history-table').addEventListener('click', handleHistoryAction);
53
+ document.getElementById('clear-history').addEventListener('click', clearHistory);
54
+ }
55
+
56
+
57
+ async function loadLocalConfigs() {
58
+ const modelSelect = document.getElementById('model-select');
59
+ const defaultConfigName = 'Qwen/Qwen3-235B-A22B'; // Updated default model
60
+
61
+ try {
62
+ const response = await fetch('/local-hf-configs');
63
+ const configs = await response.json();
64
+
65
+ modelSelect.innerHTML = '<option value="">Select a model...</option>';
66
+ // Add custom option to allow user supplied configs
67
+ modelSelect.innerHTML += '<option value="__custom__">Custom (paste JSON below)...</option>';
68
+ configs.forEach(config => {
69
+ modelSelect.innerHTML += `<option value="${config}">${config}</option>`;
70
+ });
71
+
72
+ // Check if the default config exists and select it
73
+ if (configs.includes(defaultConfigName)) {
74
+ modelSelect.value = defaultConfigName;
75
+ // Await the loading of the model config to ensure it's ready
76
+ await loadSelectedModelConfig();
77
+ }
78
+
79
+ } catch (error) {
80
+ modelSelect.innerHTML = '<option value="">Error loading configs</option>';
81
+ console.error('Error loading local configs:', error);
82
+ }
83
+ }
84
+
85
+ async function loadSelectedModelConfig() {
86
+ const modelSelect = document.getElementById('model-select');
87
+ const editor = document.getElementById('config-editor');
88
+ const selectedConfig = modelSelect.value;
89
+ const messageDiv = document.getElementById('validation-message'); // move early for use in all branches
90
+ let configData = null; // declare for wider scope
91
+
92
+ if (!selectedConfig) {
93
+ editor.value = '';
94
+ toggleEpBasedOnConfig();
95
+ if (messageDiv) messageDiv.style.display = 'none';
96
+ return;
97
+ } else if (selectedConfig === '__custom__') {
98
+ // Custom config: do not fetch, user must paste JSON
99
+ editor.value = '';
100
+ toggleEpBasedOnConfig();
101
+ if (messageDiv) messageDiv.style.display = 'none';
102
+ return;
103
+ }
104
+
105
+ try {
106
+ const response = await fetch(`/get-megatron-config/${encodeURIComponent(selectedConfig)}`);
107
+ configData = await response.json();
108
+ if (configData.error) {
109
+ editor.value = `Error: ${configData.error}`;
110
+ } else {
111
+ editor.value = JSON.stringify(configData, null, 2);
112
+ }
113
+ } catch (error) {
114
+ editor.value = 'Failed to fetch model configuration.';
115
+ console.error('Error fetching model config:', error);
116
+ }
117
+
118
+ // Trigger validation and UI updates after loading new config
119
+ validateParallelismLive();
120
+ toggleEpBasedOnConfig();
121
+
122
+ // Show Kimi-K2-Instruct warning if needed
123
+ if (selectedConfig.includes('Kimi-K2-Instruct') && configData && configData.model_type !== 'deepseek_v3') {
124
+ messageDiv.textContent = 'Notice: For Kimi-K2-Instruct the config field "model_type" must be set to "deepseek_v3" before memory estimation.';
125
+ messageDiv.style.display = 'block';
126
+ } else if (messageDiv) {
127
+ messageDiv.style.display = 'none';
128
+ }
129
+ }
130
+
131
+
132
+ function getFormValues(isSubmission = false) {
133
+ const form = document.getElementById('config-form');
134
+ const formData = new FormData(form);
135
+ const modelSelect = document.getElementById('model-select');
136
+
137
+ const hfPath = modelSelect.value;
138
+ if (!hfPath) {
139
+ // We will now handle this case in the submitForm function instead of an alert.
140
+ return null;
141
+ }
142
+
143
+ const editor = document.getElementById('config-editor');
144
+ let customConfig = null;
145
+ try {
146
+ // Only parse if the editor has content
147
+ if (editor.value) {
148
+ customConfig = JSON.parse(editor.value);
149
+ }
150
+ } catch (e) {
151
+ // Only alert on final submission, not on live validation
152
+ if (isSubmission) {
153
+ // alert('Model Config is not valid JSON.'); // Removing alert
154
+ }
155
+ return null; // Return null if JSON is invalid
156
+ }
157
+
158
+ const vppInput = formData.get('vpp');
159
+ const etpInput = formData.get('etp');
160
+
161
+ return {
162
+ hf_model_path: hfPath,
163
+ custom_hf_config: customConfig, // Renamed for clarity
164
+ num_gpus: parseInt(formData.get('num_gpus')),
165
+ mbs: parseInt(formData.get('mbs')),
166
+ seq_len: parseInt(formData.get('seq-len')),
167
+ use_distributed_optimizer: document.getElementById('use-distributed-optimizer').checked,
168
+ recompute_granularity: formData.get('recompute_granularity'),
169
+ recompute_method: formData.get('recompute_method'),
170
+ recompute_num_layers: parseInt(formData.get('recompute_num_layers')),
171
+ tp: parseInt(formData.get('tp')),
172
+ pp: parseInt(formData.get('pp')),
173
+ ep: parseInt(formData.get('ep')) || 1, // Default to 1 if disabled/null
174
+ cp: parseInt(formData.get('cp')),
175
+ vpp: vppInput ? parseInt(vppInput) : null,
176
+ etp: etpInput ? parseInt(etpInput) : null,
177
+ num_layers_in_first_pipeline_stage: formData.get('num_layers_in_first_pipeline_stage') ? parseInt(formData.get('num_layers_in_first_pipeline_stage')) : null,
178
+ num_layers_in_last_pipeline_stage: formData.get('num_layers_in_last_pipeline_stage') ? parseInt(formData.get('num_layers_in_last_pipeline_stage')) : null,
179
+ overhead: parseInt(formData.get('overhead')),
180
+ };
181
+ }
182
+
183
+ async function submitForm() {
184
+ const messageDiv = document.getElementById('validation-message');
185
+ messageDiv.textContent = '';
186
+ messageDiv.style.display = 'none';
187
+
188
+ // Get all form values first. We use getFormValues(false) to avoid any legacy alerts
189
+ // and handle all validation directly within this function for clarity.
190
+ const formValues = getFormValues(false);
191
+
192
+ // === START SUBMISSION VALIDATION ===
193
+
194
+ // 1. Check if form values could be retrieved. This catches both missing model selection
195
+ // and invalid JSON, as getFormValues returns null in those cases.
196
+ if (!formValues) {
197
+ if (!document.getElementById('model-select').value) {
198
+ messageDiv.textContent = 'Validation Error: Please select a model config.';
199
+ } else {
200
+ messageDiv.textContent = 'Validation Error: Model Config is not valid JSON.';
201
+ }
202
+ messageDiv.style.display = 'block';
203
+ return;
204
+ }
205
+
206
+ // Custom config must have valid JSON
207
+ if (document.getElementById('model-select').value === '__custom__' && !formValues.custom_hf_config) {
208
+ messageDiv.textContent = 'Validation Error: Please paste a valid model configuration JSON for the custom model.';
209
+ messageDiv.style.display = 'block';
210
+ return;
211
+ }
212
+
213
+ // 2. Perform all numeric and parallelism validation.
214
+ const { num_gpus, tp, pp, ep, cp, etp, custom_hf_config } = formValues;
215
+ const num_kv_heads = custom_hf_config?.num_key_value_heads || null;
216
+
217
+ let errors = [];
218
+ if (tp * pp * cp > num_gpus) {
219
+ errors.push(`TP*PP*CP (${tp * pp * cp}) > GPUs (${num_gpus}).`);
220
+ }
221
+ if (etp){
222
+ if (etp * pp * cp * ep > num_gpus) {
223
+ errors.push(`ETP*PP*CP*EP (${etp * pp * cp * ep}) > GPUs (${num_gpus}).`);
224
+ }
225
+ } else {
226
+ if (tp * pp * cp * ep > num_gpus) {
227
+ errors.push(`TP*PP*CP*EP (${tp * pp * cp * ep}) > GPUs (${num_gpus}) when ETP is not set.`);
228
+ }
229
+ }
230
+ if (num_kv_heads && tp > num_kv_heads) {
231
+ errors.push(`TP (${tp}) > Num KV Heads (${num_kv_heads}).`);
232
+ }
233
+
234
+ if (errors.length > 0) {
235
+ messageDiv.textContent = 'Validation Error: ' + errors.join(' ');
236
+ messageDiv.style.display = 'block';
237
+ return;
238
+ }
239
+ // === END SUBMISSION VALIDATION ===
240
+
241
+ const loading = document.getElementById('loading');
242
+ const submitBtn = document.querySelector('#config-form button[type="submit"]');
243
+ loading.style.display = 'block';
244
+ if (submitBtn) submitBtn.disabled = true;
245
+
246
+ try {
247
+ const response = await fetch('/estimate_with_mbridge', {
248
+ method: 'POST',
249
+ headers: { 'Content-Type': 'application/json' },
250
+ body: JSON.stringify(formValues) // Send the now fully-validated formValues
251
+ });
252
+
253
+ console.log('Response Status:', response.status);
254
+
255
+ if (response.ok) {
256
+ const data = await response.json();
257
+
258
+ // FIX: Ensure history wrapper is visible before updating and showing details
259
+ document.getElementById('history-wrapper').style.display = 'block';
260
+
261
+ saveToHistory(formValues, data);
262
+ updateHistoryView();
263
+ const newEntryRow = document.querySelector('#history-table tbody tr:first-child');
264
+ if (newEntryRow) {
265
+ const detailBtn = newEntryRow.querySelector('.detail-btn');
266
+ if (detailBtn) {
267
+ // We need to pass the event object structure to handleHistoryAction
268
+ handleHistoryAction({ target: detailBtn });
269
+ }
270
+ }
271
+ } else {
272
+ const error = await response.text();
273
+ console.error('Server error response:', error);
274
+ // Since we removed the main results display, show error in the validation div
275
+ messageDiv.textContent = `Server Error: ${error}`;
276
+ messageDiv.style.display = 'block';
277
+ }
278
+ } catch (error) {
279
+ console.error('Fetch API Error:', error);
280
+ messageDiv.textContent = `Client Error: ${error.message}`;
281
+ messageDiv.style.display = 'block';
282
+ } finally {
283
+ loading.style.display = 'none';
284
+ if (submitBtn) submitBtn.disabled = false;
285
+ }
286
+ }
287
+
288
+ function renderTable(details, rawFullReport) {
289
+ if (!details || details.length === 0) {
290
+ return '<p>No detailed memory breakdown available.</p>';
291
+ }
292
+
293
+ const headers = Object.keys(details[0]);
294
+ headers.push('Breakdown');
295
+
296
+ let table = '<table><thead><tr>';
297
+ headers.forEach(h => table += `<th>${h}</th>`);
298
+ table += '</tr></thead><tbody>';
299
+
300
+ details.forEach(row => {
301
+ const ppRank = row.pp_rank;
302
+ // FIX: Look in the full raw report array passed in.
303
+ const rawDataForRank = rawFullReport ? rawFullReport.find(r => r.pp_rank === ppRank) : null;
304
+
305
+ // FIX: Change to `let` to allow modification for highlighting.
306
+ let modelBreakdown = (rawDataForRank && rawDataForRank.model_breakdown)
307
+ ? rawDataForRank.model_breakdown
308
+ : 'No breakdown available.';
309
+
310
+ // Add syntax-like highlighting for params and activations
311
+ // Basic HTML escaping for safety before inserting spans
312
+ modelBreakdown = modelBreakdown.replace(/&/g, "&amp;").replace(/</g, "&lt;").replace(/>/g, "&gt;");
313
+ modelBreakdown = modelBreakdown
314
+ .replace(/(n_params=[0-9.]+[a-zA-Z]*)/g, '<span class="highlight-red">$1</span>')
315
+ .replace(/(n_act=[0-9.]+[a-zA-Z]*)/g, '<span class="highlight-red">$1</span>');
316
+
317
+ // Main row with data
318
+ table += `<tr data-pp-rank="${ppRank}">`;
319
+ headers.forEach(h => {
320
+ if (h !== 'Breakdown') {
321
+ table += `<td>${row[h]}</td>`;
322
+ }
323
+ });
324
+ table += `<td><button class="action-btn raw-per-rank-btn" data-pp-rank="${ppRank}">Raw</button></td>`;
325
+ table += '</tr>';
326
+
327
+ // Hidden row for the breakdown
328
+ table += `<tr class="raw-breakdown-row" data-pp-rank="${ppRank}" style="display: none;">
329
+ <td colspan="${headers.length}">
330
+ <pre>${modelBreakdown}</pre>
331
+ </td>
332
+ </tr>`;
333
+ });
334
+
335
+ table += '</tbody></table>';
336
+ return table;
337
+ }
338
+
339
+ function saveToHistory(params, resultData) {
340
+ let history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
341
+ const historyEntry = {
342
+ params: params,
343
+ result: resultData, // Store the full result object { processed_report, raw_report }
344
+ id: new Date().getTime()
345
+ };
346
+ history.unshift(historyEntry); // Add to the beginning
347
+ if (history.length > 20) { // Keep history size manageable
348
+ history.pop();
349
+ }
350
+ localStorage.setItem('estimationHistory', JSON.stringify(history));
351
+ }
352
+
353
+ function updateHistoryView() {
354
+ const history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
355
+ const historyTableBody = document.querySelector('#history-table tbody');
356
+ const historyWrapper = document.getElementById('history-wrapper');
357
+ historyTableBody.innerHTML = '';
358
+
359
+ if (history.length === 0) {
360
+ historyWrapper.style.display = 'none';
361
+ return;
362
+ }
363
+
364
+ historyWrapper.style.display = 'block';
365
+
366
+ history.forEach(item => {
367
+ const row = document.createElement('tr');
368
+
369
+ const params = item.params;
370
+ const resultData = item.result || {};
371
+
372
+ // FIX: Handle both old and new data structures for compatibility.
373
+ const details = (resultData.report && resultData.report.details) ? resultData.report.details : (resultData.processed_report || []);
374
+ const pp0Result = details.find(r => r.pp_rank === 0) || details[0] || {};
375
+
376
+ const modelName = params.hf_model_path.split('/').pop();
377
+
378
+ // Build parallelism string, e.g., "TP2 PP2 VPP2"
379
+ const parallelismParts = [];
380
+ ['tp', 'pp', 'ep', 'cp', 'vpp', 'etp'].forEach(p => {
381
+ const value = params[p];
382
+ if (value && value > 1) {
383
+ parallelismParts.push(`${p.toUpperCase()}${value}`);
384
+ }
385
+ });
386
+ const parallelismInfo = parallelismParts.join(' ') || 'No Parallelism';
387
+
388
+ const overheadGb = params.overhead ? parseInt(params.overhead) : 0;
389
+ const baseTotal = details.length > 0 ? Math.max(...details.map(r => r.total_gb || 0)) : null;
390
+ const totalGb = baseTotal !== null ? (baseTotal + overheadGb).toFixed(2) : 'N/A';
391
+
392
+ const seqLen = params.seq_len || 0;
393
+ const formattedSeqLen = seqLen >= 1024 ? `${seqLen / 1024}k` : seqLen;
394
+ const sequenceInfo = `${params.mbs || 'N/A'}*${formattedSeqLen}`;
395
+
396
+ row.innerHTML = `
397
+ <td>
398
+ <div>${modelName}</div>
399
+ <div class="model-meta-info">
400
+ <span>GPUs: ${params.num_gpus || 'N/A'}</span>
401
+ <span>${parallelismInfo}</span>
402
+ <span>Sequence: ${sequenceInfo}</span>
403
+ </div>
404
+ </td>
405
+ <td>${pp0Result.weight_optimizer_gb || 'N/A'}</td>
406
+ <td>${pp0Result.activation_gb || 'N/A'}</td>
407
+ <td>${totalGb}</td>
408
+ <td>
409
+ <button class="restore-btn" data-id="${item.id}">Restore</button>
410
+ <button class="detail-btn" data-id="${item.id}">Detail</button>
411
+ <button class="delete-btn" data-id="${item.id}">Delete</button>
412
+ </td>
413
+ `;
414
+ historyTableBody.appendChild(row);
415
+ });
416
+ }
417
+
418
+ async function handleHistoryAction(e) {
419
+ const button = e.target.closest('button');
420
+ if (!button) return;
421
+
422
+ // Handle breakdown toggle first
423
+ if (button.classList.contains('breakdown-btn')) {
424
+ const ppRank = button.dataset.ppRank;
425
+ const detailTable = button.closest('table');
426
+ if (!detailTable) return;
427
+
428
+ const breakdownRow = detailTable.querySelector(`tr.breakdown-row[data-pp-rank="${ppRank}"]`);
429
+ if (!breakdownRow) return;
430
+
431
+ const isVisible = breakdownRow.style.display !== 'none';
432
+ breakdownRow.style.display = isVisible ? 'none' : 'table-row';
433
+ button.textContent = isVisible ? 'Breakdown' : 'Hide';
434
+ return; // Do not continue to other handlers
435
+ }
436
+
437
+ if (!button.matches('.detail-btn, .restore-btn, .delete-btn')) return;
438
+
439
+ const id = parseInt(button.dataset.id, 10);
440
+ const history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
441
+ const entry = history.find(item => item.id === id);
442
+
443
+ if (!entry) {
444
+ console.error('History entry not found for id:', id);
445
+ return;
446
+ }
447
+
448
+ const row = button.closest('tr');
449
+
450
+ if (button.classList.contains('detail-btn')) {
451
+ const isDetailsVisible = row.nextElementSibling && row.nextElementSibling.classList.contains('detail-row');
452
+
453
+ document.querySelectorAll('.detail-row').forEach(detailRow => {
454
+ const prevRow = detailRow.previousElementSibling;
455
+ const detailBtn = prevRow.querySelector('.detail-btn');
456
+ if (detailRow !== row.nextElementSibling) {
457
+ detailRow.remove();
458
+ if (detailBtn) detailBtn.textContent = 'Detail';
459
+ }
460
+ });
461
+
462
+ if (isDetailsVisible) {
463
+ row.nextElementSibling.remove();
464
+ button.textContent = 'Detail';
465
+ } else {
466
+ const detailRow = document.createElement('tr');
467
+ detailRow.classList.add('detail-row');
468
+ const detailCell = detailRow.insertCell();
469
+ detailCell.colSpan = row.cells.length;
470
+
471
+ // FIX: Handle both old and new data structures for compatibility.
472
+ const report = entry.result.report;
473
+ const details = (report && report.details) ? report.details : (entry.result.processed_report || []);
474
+ const modelBreakdown = (report && report.model_breakdown) ? report.model_breakdown : null;
475
+
476
+ if (details && details.length > 0) {
477
+ const newTable = document.createElement('table');
478
+ // Determine if breakdown information exists per-row or globally
479
+ let headers = Object.keys(details[0]);
480
+
481
+ // If old-format data, there is a 'model_breakdown' key on each detail row
482
+ const hasRowBreakdown = headers.includes('model_breakdown');
483
+
484
+ // Remove the raw model_breakdown column from headers to keep table compact
485
+ if (hasRowBreakdown) {
486
+ headers = headers.filter(h => h !== 'model_breakdown');
487
+ }
488
+
489
+ // Include global breakdown if provided, or row breakdowns if present
490
+ const includeBreakdown = hasRowBreakdown || (modelBreakdown && typeof modelBreakdown === 'string');
491
+
492
+ if (includeBreakdown) {
493
+ headers.push('Breakdown');
494
+ }
495
+
496
+ const headerRow = newTable.insertRow();
497
+ headers.forEach(h => {
498
+ const th = document.createElement('th');
499
+ th.textContent = h;
500
+ headerRow.appendChild(th);
501
+ });
502
+
503
+ details.forEach(detail => {
504
+ const newRow = newTable.insertRow();
505
+ headers.forEach(header => {
506
+ if (header === 'Breakdown') {
507
+ const cell = newRow.insertCell();
508
+ cell.innerHTML = `<button class="breakdown-btn" data-pp-rank="${detail.pp_rank}">Breakdown</button>`;
509
+ } else {
510
+ const cell = newRow.insertCell();
511
+ let value = detail[header];
512
+ if (typeof value === 'number' && !Number.isInteger(value)) {
513
+ value = value.toFixed(4);
514
+ }
515
+ cell.textContent = value;
516
+ }
517
+ });
518
+
519
+ // Hidden breakdown row
520
+ if (includeBreakdown) {
521
+ const breakdownRow = newTable.insertRow();
522
+ breakdownRow.classList.add('breakdown-row');
523
+ breakdownRow.dataset.ppRank = detail.pp_rank;
524
+ breakdownRow.style.display = 'none';
525
+ const breakdownCell = breakdownRow.insertCell();
526
+ breakdownCell.colSpan = headers.length;
527
+ const rowSpecificBreakdown = hasRowBreakdown ? (detail.model_breakdown || '') : modelBreakdown;
528
+ const htmlBreakdown = ansiToHtml(rowSpecificBreakdown);
529
+ breakdownCell.innerHTML = `<pre class="model-breakdown-view">${htmlBreakdown || 'No breakdown available.'}</pre>`;
530
+ }
531
+ });
532
+
533
+ detailCell.appendChild(newTable);
534
+ } else {
535
+ detailCell.innerHTML = 'No detailed per-rank results available.';
536
+ }
537
+
538
+ row.after(detailRow);
539
+ button.textContent = 'Hide';
540
+ }
541
+ } else if (button.classList.contains('restore-btn')) {
542
+ restoreForm(entry.params);
543
+ } else if (button.classList.contains('delete-btn')) {
544
+ deleteHistoryEntry(id);
545
+ }
546
+ }
547
+
548
+ function deleteHistoryEntry(id) {
549
+ let history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
550
+ const updatedHistory = history.filter(item => item.id != id);
551
+ localStorage.setItem('estimationHistory', JSON.stringify(updatedHistory));
552
+ updateHistoryView();
553
+
554
+ // If history is now empty, hide the whole output container
555
+ if (updatedHistory.length === 0) {
556
+ // document.getElementById('output-container').style.display = 'none';
557
+ }
558
+ }
559
+
560
+ function clearHistory() {
561
+ localStorage.removeItem('estimationHistory');
562
+ updateHistoryView();
563
+ // document.getElementById('output-container').style.display = 'none';
564
+ }
565
+
566
+
567
+ function restoreForm(params) {
568
+ if (!params) return;
569
+
570
+ const setElementValue = (id, value, defaultValue = '') => {
571
+ const element = document.getElementById(id);
572
+ if (element) {
573
+ if (element.type === 'checkbox') {
574
+ element.checked = value ?? defaultValue;
575
+ } else {
576
+ element.value = value ?? defaultValue;
577
+ }
578
+ }
579
+ };
580
+
581
+ setElementValue('num-gpus', params.num_gpus, 8);
582
+ setElementValue('mbs', params.mbs, 1);
583
+ setElementValue('seq-len', params.seq_len, 4096);
584
+ setElementValue('use-distributed-optimizer', params.use_distributed_optimizer, true);
585
+ setElementValue('recompute_granularity', params.recompute_granularity, 'selective');
586
+ setElementValue('recompute_method', params.recompute_method, 'uniform');
587
+ setElementValue('recompute_num_layers', params.recompute_num_layers, 1);
588
+ setElementValue('tp', params.tp, 1);
589
+ setElementValue('pp', params.pp, 1);
590
+ setElementValue('ep', params.ep, 1);
591
+ setElementValue('cp', params.cp, 1);
592
+ setElementValue('vpp', params.vpp);
593
+ setElementValue('etp', params.etp);
594
+ setElementValue('num_layers_in_first_pipeline_stage', params.num_layers_in_first_pipeline_stage);
595
+ setElementValue('num_layers_in_last_pipeline_stage', params.num_layers_in_last_pipeline_stage);
596
+ setElementValue('overhead', params.overhead, 10);
597
+
598
+ const modelSelect = document.getElementById('model-select');
599
+ if (modelSelect && params.hf_model_path) {
600
+ modelSelect.value = params.hf_model_path;
601
+ }
602
+
603
+ // Manually trigger change event for UI updates
604
+ const recomputeSelect = document.getElementById('recompute_granularity');
605
+ if (recomputeSelect) {
606
+ recomputeSelect.dispatchEvent(new Event('change'));
607
+ }
608
+ }
609
+
610
+ function updateParallelismOptions() {
611
+ const numGpusInput = document.getElementById('num-gpus');
612
+ if (!numGpusInput) return;
613
+
614
+ const numGpus = parseInt(numGpusInput.value);
615
+ if (isNaN(numGpus) || numGpus <= 0) {
616
+ return; // Don't update if GPU count is invalid
617
+ }
618
+
619
+ const tpSelect = document.getElementById('tp');
620
+ const epSelect = document.getElementById('ep');
621
+ const cpSelect = document.getElementById('cp');
622
+
623
+ // PP is now a manual input, so we only handle TP, EP, CP here.
624
+ const selects = [tpSelect, epSelect, cpSelect];
625
+
626
+ const powersOfTwo = [1];
627
+ for (let i = 1; (1 << i) <= numGpus; i++) {
628
+ powersOfTwo.push(1 << i);
629
+ }
630
+
631
+ selects.forEach(select => {
632
+ if (!select) return;
633
+ const currentVal = select.value;
634
+ select.innerHTML = ''; // Clear existing options
635
+
636
+ powersOfTwo.forEach(val => {
637
+ const option = document.createElement('option');
638
+ option.value = val;
639
+ option.textContent = val;
640
+ select.appendChild(option);
641
+ });
642
+
643
+ // Try to restore the previous value, otherwise default to 1
644
+ if (powersOfTwo.includes(parseInt(currentVal))) {
645
+ select.value = currentVal;
646
+ } else {
647
+ select.value = 1;
648
+ }
649
+ });
650
+ }
651
+
652
+ function validateParallelismLive() {
653
+ const messageDiv = document.getElementById('validation-message');
654
+ // Pass isSubmission = false to getFormValues to prevent alerts during live validation
655
+ const formValues = getFormValues(false);
656
+
657
+ if (!formValues) {
658
+ messageDiv.textContent = '';
659
+ return true;
660
+ }
661
+
662
+ const { num_gpus, tp, pp, ep, cp, etp, custom_hf_config } = formValues;
663
+ // The key is the same in the HF config, so this logic remains valid.
664
+ const num_kv_heads = custom_hf_config?.num_key_value_heads || null;
665
+
666
+ let errors = [];
667
+ if (tp * pp * cp > num_gpus) {
668
+ errors.push(`TP*PP*CP (${tp*pp*cp}) > GPUs (${num_gpus}).`);
669
+ }
670
+ if (etp) {
671
+ if (etp * pp * cp * ep > num_gpus) {
672
+ errors.push(`ETP*PP*CP*EP (${etp*pp*cp*ep}) > GPUs (${num_gpus}).`);
673
+ }
674
+ } else {
675
+ if (tp * pp * cp * ep > num_gpus) {
676
+ errors.push(`TP*PP*CP*EP (${tp*pp*cp*ep}) > GPUs (${num_gpus}) when ETP is not set.`);
677
+ }
678
+ }
679
+ if (num_kv_heads && tp > num_kv_heads) {
680
+ errors.push(`TP (${tp}) > Num KV Heads (${num_kv_heads}).`);
681
+ }
682
+
683
+ if (errors.length > 0) {
684
+ messageDiv.textContent = 'Validation Error: ' + errors.join(' ');
685
+ messageDiv.style.display = 'block';
686
+ } else {
687
+ messageDiv.textContent = '';
688
+ messageDiv.style.display = 'none';
689
+ }
690
+ return errors.length === 0;
691
+ }
692
+
693
+ function toggleEpBasedOnConfig() {
694
+ const editor = document.getElementById('config-editor');
695
+ const epSelect = document.getElementById('ep');
696
+ if (!editor || !epSelect) return;
697
+
698
+ let config = null;
699
+ try {
700
+ if (editor.value) {
701
+ config = JSON.parse(editor.value);
702
+ }
703
+ } catch (e) {
704
+ // Invalid JSON, disable EP as a safety measure
705
+ epSelect.disabled = true;
706
+ return;
707
+ }
708
+
709
+ if (config && config.num_experts && config.num_experts > 0) {
710
+ epSelect.disabled = false;
711
+ } else {
712
+ epSelect.disabled = true;
713
+ epSelect.value = 1; // Reset to 1 if disabled
714
+ }
715
+ }
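The submit-time checks above reduce to a few products that must not exceed the GPU count. The Python sketch below mirrors that arithmetic (as in validateParallelismLive/submitForm) for use when scripting the endpoint without the UI; it is illustrative only and does not define any new rules.

# Mirror of the front-end parallelism checks (illustrative sketch).
def validate_parallelism(num_gpus, tp, pp, cp, ep=1, etp=None, num_kv_heads=None):
    errors = []
    if tp * pp * cp > num_gpus:
        errors.append(f"TP*PP*CP ({tp * pp * cp}) > GPUs ({num_gpus})")
    base_tp = etp if etp else tp                   # ETP replaces TP in the expert product
    if base_tp * pp * cp * ep > num_gpus:
        errors.append(f"{'ETP' if etp else 'TP'}*PP*CP*EP ({base_tp * pp * cp * ep}) > GPUs ({num_gpus})")
    if num_kv_heads and tp > num_kv_heads:
        errors.append(f"TP ({tp}) > num_key_value_heads ({num_kv_heads})")
    return errors

# 8 GPUs with TP=2, PP=2, CP=1, EP=2 passes; with the usual Megatron convention
# DP = GPUs / (TP*PP*CP), this leaves a data-parallel size of 2.
print(validate_parallelism(8, tp=2, pp=2, cp=1, ep=2, num_kv_heads=4))   # -> []
print(validate_parallelism(8, tp=4, pp=4, cp=1))                         # -> two "> GPUs (8)" errors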
webui/style.css ADDED
@@ -0,0 +1,383 @@
1
+ body {
2
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
3
+ line-height: 1.6;
4
+ background-color: #f4f4f4;
5
+ color: #333;
6
+ margin: 0;
7
+ padding: 1em;
8
+ }
9
+
10
+ .container {
11
+ max-width: 1600px;
12
+ margin: auto;
13
+ background: #fff;
14
+ padding: 2em;
15
+ border-radius: 8px;
16
+ box-shadow: 0 0 20px rgba(0, 0, 0, 0.05);
17
+ }
18
+
19
+ .main-layout {
20
+ display: flex;
21
+ flex-direction: column; /* Main axis is vertical */
22
+ gap: 2em;
23
+ }
24
+
25
+ .top-section {
26
+ display: flex;
27
+ flex-direction: row; /* Children are horizontal */
28
+ gap: 2em;
29
+ }
30
+
31
+ .config-column, .output-column {
32
+ flex: 1; /* Each column takes up half the space */
33
+ display: flex;
34
+ flex-direction: column;
35
+ }
36
+
37
+ /* The editor wrapper should grow to fill the space */
38
+ .config-editor-wrapper {
39
+ flex-grow: 1;
40
+ display: flex;
41
+ flex-direction: column;
42
+ }
43
+
44
+ #config-editor {
45
+ flex-grow: 1; /* The textarea itself should grow */
46
+ width: 100%;
47
+ box-sizing: border-box; /* Include padding and border in the element's total width and height */
48
+ resize: vertical; /* Allow vertical resizing */
49
+ }
50
+
51
+
52
+ .bottom-section {
53
+ width: 100%;
54
+ }
55
+
56
+ .form-row {
57
+ display: flex;
58
+ gap: 1em;
59
+ align-items: flex-end;
60
+ }
61
+
62
+ .form-row .form-group {
63
+ flex: 1; /* Allow groups to grow and fill space */
64
+ margin-bottom: 0.8em;
65
+ }
66
+
67
+ .form-group {
68
+ margin-bottom: 0.8em; /* Reduced from default */
69
+ }
70
+
71
+ .form-group label {
72
+ display: block;
73
+ margin-bottom: 0.25em; /* Reduced */
74
+ font-weight: 500;
75
+ }
76
+
77
+ .form-group label.inline-label {
78
+ display: inline-block;
79
+ margin-left: 0.5em;
80
+ font-weight: normal;
81
+ }
82
+
83
+ .form-group input[type="number"],
84
+ .form-group select {
85
+ width: 100%;
86
+ padding: 6px 10px; /* Reduced padding */
87
+ border-radius: 4px;
88
+ border: 1px solid #ccc;
89
+ box-sizing: border-box;
90
+ }
91
+
92
+ button {
93
+ background-color: #3498db;
94
+ color: white;
95
+ padding: 10px 15px;
96
+ border: none;
97
+ border-radius: 4px;
98
+ cursor: pointer;
99
+ font-size: 16px;
100
+ margin-top: 10px;
101
+ }
102
+
103
+ button:hover {
104
+ background-color: #2980b9;
105
+ }
106
+
107
+ #results {
108
+ background-color: #ecf0f1;
109
+ padding: 15px;
110
+ border-radius: 4px;
111
+ white-space: pre-wrap;
112
+ word-wrap: break-word;
113
+ min-height: 100px;
114
+ }
115
+
116
+ .results-container {
117
+ margin-top: 20px;
118
+ }
119
+
120
+ /* New styles for results table */
121
+ table {
122
+ width: 100%;
123
+ border-collapse: collapse;
124
+ margin-top: 20px;
125
+ }
126
+
127
+ th, td {
128
+ border: 1px solid #ddd;
129
+ padding: 12px;
130
+ text-align: left;
131
+ }
132
+
133
+ th {
134
+ background-color: #f2f2f2;
135
+ font-weight: bold;
136
+ }
137
+
138
+ tbody tr:nth-child(even) {
139
+ background-color: #f9f9f9;
140
+ }
141
+
142
+ tbody tr:hover {
143
+ background-color: #f1f1f1;
144
+ }
145
+
146
+ .error {
147
+ color: #e74c3c;
148
+ font-weight: bold;
149
+ }
150
+
151
+ .button-container {
152
+ grid-column: 1 / -1; /* Span across all columns */
153
+ text-align: center;
154
+ margin-top: 20px;
155
+ }
156
+
157
+ /* History Section */
158
+ .history-container {
159
+ margin-top: 40px;
160
+ border-top: 1px solid #e0e0e0;
161
+ padding-top: 20px;
162
+ }
163
+
164
+ .history-container h2 {
165
+ display: flex;
166
+ justify-content: space-between;
167
+ align-items: center;
168
+ }
169
+
170
+ #history-list table {
171
+ margin-top: 10px;
172
+ }
173
+
174
+ .small-button {
175
+ padding: 4px 8px;
176
+ font-size: 0.8em;
177
+ background-color: #e74c3c;
178
+ }
179
+
180
+ .small-button:hover {
181
+ background-color: #c0392b;
182
+ }
183
+
184
+ .history-item-actions {
185
+ display: flex;
186
+ gap: 10px;
187
+ }
188
+
189
+ #output-container {
190
+ margin-top: 2em;
191
+ padding: 1.5em;
192
+ background-color: #f9f9f9;
193
+ border: 1px solid #ddd;
194
+ border-radius: 8px;
195
+ }
196
+
197
+ #results-wrapper h3, #history-wrapper h3 {
198
+ margin-top: 0;
199
+ border-bottom: 2px solid #eee;
200
+ padding-bottom: 0.5em;
201
+ margin-bottom: 1em;
202
+ }
203
+
204
+ #results-display table {
205
+ width: 100%;
206
+ border-collapse: collapse;
207
+ }
208
+
209
+ #results-display th, #results-display td {
210
+ padding: 8px 12px;
211
+ border: 1px solid #ddd;
212
+ text-align: left;
213
+ }
214
+
215
+ #results-display th {
216
+ background-color: #f2f2f2;
217
+ }
218
+
219
+ #history-table {
220
+ width: 100%;
221
+ border-collapse: collapse;
222
+ }
223
+
224
+ #history-table th, #history-table td {
225
+ padding: 8px 12px;
226
+ border: 1px solid #ddd;
227
+ text-align: left;
228
+ }
229
+
230
+ #history-table th {
231
+ background-color: #f2f2f2;
232
+ }
233
+
234
+ #history-table td:last-child {
235
+ text-align: right;
236
+ }
237
+
238
+ #raw-json-output {
239
+ background-color: #2d2d2d;
240
+ color: #f1f1f1;
241
+ padding: 1em;
242
+ border-radius: 5px;
243
+ max-height: 500px;
244
+ overflow-y: auto;
245
+ }
246
+
247
+ #clear-history {
248
+ background-color: #dc3545;
249
+ }
250
+
251
+ #clear-history:hover {
252
+ background-color: #c82333;
253
+ }
254
+
255
+ .error-message {
256
+ color: #dc3545;
257
+ background-color: #f8d7da;
258
+ border: 1px solid #f5c6cb;
259
+ padding: 0.75rem 1.25rem;
260
+ margin-top: 1rem;
261
+ margin-bottom: 1rem;
262
+ border-radius: 0.25rem;
263
+ text-align: center;
264
+ }
265
+
266
+ /* Responsive Design for smaller screens */
267
+ @media (max-width: 992px) {
268
+ .top-section {
269
+ flex-direction: column;
270
+ }
271
+ }
272
+
273
+ .history-detail-row td {
274
+ background-color: #333;
275
+ padding: 15px;
276
+ border-top: 2px solid #555;
277
+ text-align: left; /* Align content to the left */
278
+ }
279
+
280
+ .history-detail-row pre {
281
+ background-color: #1e1e1e;
282
+ color: #d4d4d4;
283
+ padding: 10px;
284
+ border-radius: 4px;
285
+ white-space: pre-wrap;
286
+ word-break: break-all;
287
+ }
288
+
289
+ .history-detail-row table {
290
+ width: 100%;
291
+ border-collapse: collapse;
292
+ margin: 0;
293
+ }
294
+
295
+ .history-detail-row table th {
296
+ background-color: #e0e0e0;
297
+ color: #333;
298
+ padding: 8px 12px;
299
+ border: 1px solid #555;
300
+ }
301
+
302
+ .history-detail-row table td {
303
+ color: #d4d4d4;
304
+ padding: 8px 12px;
305
+ border: 1px solid #555;
306
+ background-color: #2a2a2a;
307
+ }
308
+
309
+ .model-breakdown-view {
310
+ max-height: 400px; /* Or any other suitable height */
311
+ overflow-y: auto;
312
+ overflow-x: auto;
313
+ background-color: #2d2d2d;
314
+ color: #f1f1f1;
315
+ padding: 1em;
316
+ border-radius: 5px;
317
+ white-space: pre-wrap; /* Ensures the pre content wraps */
318
+ margin: 0;
319
+ font-family: monospace;
320
+ font-size: 0.85em;
321
+ }
322
+
323
+ .model-meta-info {
324
+ font-size: 0.9em;
325
+ color: #666;
326
+ margin-top: 4px;
327
+ }
328
+
329
+ .model-meta-info span {
330
+ margin-right: 15px;
331
+ }
332
+
333
+ .action-btn.raw-btn {
334
+ background-color: #555;
335
+ color: white;
336
+ }
337
+
338
+ .highlight-red {
339
+ color: #ff6b6b;
340
+ }
341
+
342
+ .ansi-red { color: #e74c3c; }
343
+ .ansi-green { color: #2ecc71; }
344
+ .ansi-yellow { color: #f1c40f; }
345
+ .ansi-blue { color: #3498db; }
346
+ .ansi-magenta { color: #9b59b6; }
347
+ .ansi-cyan { color: #1abc9c; }
348
+
349
+ .breakdown-row td {
350
+ text-align: left !important;
351
+ }
352
+
353
+ .footer {
354
+ margin-top: 2em;
355
+ font-size: 0.85em;
356
+ color: #555;
357
+ text-align: center;
358
+ }
359
+
360
+ .footer a {
361
+ color: #2a77d4;
362
+ text-decoration: none;
363
+ }
364
+
365
+ .footer a:hover {
366
+ text-decoration: underline;
367
+ }
368
+
369
+ .disclaimer {
370
+ margin-top: 0.5em;
371
+ font-style: italic;
372
+ }
373
+
374
+ .disclaimer-banner {
375
+ background-color: #fff3cd;
376
+ color: #856404;
377
+ border: 1px solid #ffeeba;
378
+ padding: 10px 15px;
379
+ border-radius: 4px;
380
+ margin: 15px 0;
381
+ font-weight: bold;
382
+ text-align: center;
383
+ }