xxyux committed
Commit deacdbd · verified · 1 Parent(s): 9baaa6b

Update app.py

Files changed (1):
  1. app.py +221 -144
app.py CHANGED
@@ -1,11 +1,22 @@
  import gradio as gr
  import pandas as pd

- col=['Layer number', 'Hidden size', 'FFN Hidden size', 'Sequence length', 'Head number', 'Group number',
- 'dp', 'tp', 'pp', 'cp', 'GPU numbers', 'Batch size', 'FP8', 'Model parameters', 'Model_states', 'Activation', 'Total']

- # # global data
- # table_data = pd.DataFrame(columns=col)

  def Get_GigaByte(memory):
  return memory / 1024**3
@@ -46,7 +57,7 @@ def Compute_Parameters_mlp(hidden_size, ffn_size, is_bias, act_func, tp):
  num_parameters_mlp = 2 * hidden_size
  # mlp1 weight: h*ffn/tp, bias: ffn/tp
  # mlp2 weight: ffn*h/tp, bias: h
- if act_func == "True":
  num_parameters_mlp += hidden_size * ffn_size * 3 / tp
  if is_bias == "True":
  num_parameters_mlp += ffn_size * 2 / tp + hidden_size
@@ -178,7 +189,7 @@ def compute_activation_memory_mlp(activation_dtype, seq_length, b, hidden_size,
  if is_sp == "False":
  activation_mem_mlp_fc1 *= tp
  # Act 8bsh
- if act_func == "Swiglu":
  activation_mem_mlp_act = seq_length * b * ffn_size * 2 * 2
  else:
  activation_mem_mlp_act = seq_length * b * ffn_size * 2
@@ -207,21 +218,21 @@ def compute_activation_memory_output(seq_length, b, hidden_size, vocab_size):
  # Inputs to output layer and CE loss(bf16, fp32 * 2).
  return 2 * seq_length * b * hidden_size + (2 + 4 + 4) * seq_length * b * vocab_size

- def compute_activation_memory_pp(activation_memory, is_ip, vp, pp, num_microbatches):
  # Multiply by interleaved PP memory factor.
- if is_ip == "True":
  interleaved_schedule_memory_penalty = 1 + (pp - 1) / (pp * vp)
  activation_memory *= interleaved_schedule_memory_penalty

  # If using non-interleaved schedule, number of microbatches in pipeline can be less than pp_size,
  # so discount accordingly.
- if is_ip == "False" and pp > 1:
  if num_microbatches > 1:
  activation_memory *= min(1, num_microbatches / pp)

  return activation_memory

- def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, is_ip, vp):
  # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
  # We are trying to compute the maximum activation footprint, so all calculations in this function
  # are for the first pipeline stage.
@@ -252,7 +263,7 @@ def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, he

  # get num_microbatches
  num_microbatches = b_global / b / dp / cp
- activation_memory = compute_activation_memory_pp(activation_memory, is_ip, vp, pp, num_microbatches)

  if pp == 1:
  # Inputs to output layer and CE loss(fp32).
@@ -267,13 +278,22 @@ def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, he

  # compute_btn.click.function
  def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_length, head_num, is_group_query, group_query_num, is_bias, act_func,
- dp, tp, pp, cp, is_sp, is_ip, vp, is_dist_opt, b, b_global, is_fp8, is_fp8_init, g_ty, o_ty, record_df, count):
  # get model states
  numParameters, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, model_states_memory = Compute_Model_states(vocab_size, layer_num, hidden_size,
  ffn_size, head_num, is_group_query, group_query_num, is_bias, act_func, dp, tp, pp, cp, is_dist_opt, is_fp8, is_fp8_init, g_ty, o_ty)

  # get activation memory
- activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, is_ip, vp)

  # get model parameters
  numParametersTotal = Compute_Parameters(vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, act_func, head_num, 1, 1)
@@ -289,7 +309,7 @@ def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_l

  # record
  new_row = pd.DataFrame([[layer_num, hidden_size, ffn_size, seq_length, head_num, group_query_num, dp, tp, pp, cp, gpu_num, b, is_fp8,
- numParametersTotal, model_states_memory, activation_memory, Total]],
  columns=col)
  if count == 1:
  record_df = new_row
@@ -300,8 +320,8 @@ def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_l
  # return str(gpu_num), str(model_states) + " GB", str(activation) + " GB", str(total) + " GB", table_data
  return f"""
  GPU numbers = {str(gpu_num)}, \n
- Total model parameters = {str(numParametersTotal)} B, \n
- Model parameters = {str(numParameters)} B, \n
  Model_states = {str(model_states_memory)} GB, \n
  Activation = {str(activation_memory)} GB, \n
  Total memory consumption = {str(Total)} GB \n
@@ -317,71 +337,36 @@ def generate_csv(record_df):

  # formula string
  formula = r"""
- > **Note**🔑: In this formula, we assume LLM training with FP32 Gradient and Optimizer state, and bias = False, Zero1 = False, SP = True.
-
- <!-- parameters: -->
  $$
- P_{input} = \frac{HV}{tp}, \quad
- P_{output} = 2H \\\\
- P_{attn} = 2H + \frac{2H^2 + 2H_{KV} \times H}{tp}, \quad
- P_{MLP} = 2H +
- \\begin{cases}
- \frac{3H \times FFN}{tp}, & \text{if }GLU\text{ is True} \\\\
- \frac{2H \times FFN}{tp}, & \text{if }GLU\text{ is False}
- \\end{cases} \\\\
- P_{middle} = \frac{(P_{attn} + P_{MLP}) \times L}{pp} \\\\
- P = P_{input} + P_{middle} +
- \\begin{cases}
- P_{output}, & \text{if }pp = 1 \\\\
- 0, & \text{if }pp > 1
- \\end{cases} \\\\
  {Total\ Model\ parameters} =
- \\begin{cases}
- P, & \text{set tp = 1, pp = 1} \\\\
- 2HV + 2H + (4H + 2H^2 + 2H_{KV} \times H + 3FFN \times H) \times L, & \text{general formula}
- \\end{cases} \\\\
- {Model\ states} = {Model\ weight} + {Gradient} + {Optimizer\ state} + {Master\ weight} =
- \\begin{cases}
- 18P, & \text{BF16 training} \\\\
- 18P, & \text{FP8 training with FP8 Init} \\\\
- 20P, & \text{FP8 training w/o FP8 Init}
- \\end{cases} \\\\
  $$
-
  ***

- <!-- activations: -->
  $$
- A_{input} = (8SB + SBH) \times pp, \quad
- A_{output} = 2SBH +
- \\begin{cases}
- 10SBV, & \text{if }pp\text{ = 1} \\\\
- 0, & \text{if }pp\text{ > 1}
- \\end{cases} \\\\
- A_{attn} = 5SBH + 4SB \times H_{KV} +
- \\begin{cases}
- 2SBH, & \text{if } FP8 \text{ is True} \\\\
- 4SBH, & \text{if } FP8 \text{ is False}
- \\end{cases} \\\\
- A_{MLP} = 3SBH +
- \\begin{cases}
- SBH + SB \times FFN + 4SB \times FFN, & \text{if }FP8 \text{ is True and }GLU \text{ is True} \\\\
- 2SBH + 2SB \times FFN + 4SB \times FFN, & \text{if }FP8 \text{ is False and }GLU \text{ is True} \\\\
- SBH + SB \times FFN + 2SB \times FFN, & \text{if }FP8 \text{ is True and }GLU \text{ is False} \\\\
- 2SBH + 2SB \times FFN + 2SB \times FFN, & \text{if }FP8 \text{ is False and }GLU \text{ is False}
- \\end{cases} \\\\
- A_{middle} = (A_{attn} + A_{MLP}) \times L \\\\
- A_{ip} = (A_{input} + A_{middle}) \times
- \\begin{cases}
- (1 + \frac{pp - 1}{pp \times vp}), & \text{if } Interleaved\ Pipeline \text{ is True} \\\\
- min(1, \frac{microbatch}{pp}), & \text{if } Interleaved\ Pipeline \text{ is False and pp > 1} \\\\
- 1, & \text{other}
- \\end{cases} \\\\
- Activation =
- \\begin{cases}
- \frac{A_{ip} + A_{output}}{tp \times cp}, & \text{if pp = 1} \\\\
- \frac{A_{ip} + 2BSH}{tp \times cp}, & \text{if pp > 1}
- \\end{cases}
  $$

  ***
@@ -394,6 +379,76 @@ formula = r"""
  $$
  """

  with gr.Blocks() as demo:
  with gr.Row():
  # Text
@@ -406,64 +461,92 @@ with gr.Blocks() as demo:
  """
  )

- with gr.Column():
- # Input 1.[Model Parameters]
- gr.Markdown(
- """
- <h1>Model Parameters:</h1>
- """
- )
- with gr.Accordion("Model Parameters"):
- act_func = gr.Radio(["True", "False"], value="True", label="Model type", info="Action Function in MLP, whether to use GLU (Gated Linear Unit). [e.g \"True\" for LlaMA, \"False\" for GPT.]")
- vocab_size = gr.Number(label="Vocab size", value=32000)
- layer_num = gr.Number(label="Layer number", value=32)
- hidden_size = gr.Number(label="Hidden size", value=4096)
- ffn_size = gr.Number(label="FFN Hidden size", value=11008)
- sequence_len = gr.Number(label="Sequence length", value=1024)
- head_num = gr.Number(label="Number of Attention Heads", value=32)
- with gr.Row():
- is_group_query = gr.Radio(["True", "False"], value="True", label="Use Group Query Attention")
- group_query_num = gr.Number(label="Number of Query Groups", value=96)
- is_bias = gr.Radio(["True", "False"], value="False", label="Use Bias")
-
- # Input 2.[Parallelism]
- gr.Markdown(
- """
- <h1>Parallelism config:</h1>
- """
- )
- with gr.Accordion("Parallelism config"):
- dp = gr.Number(label="Data parallelism", value=1)
- tp = gr.Number(label="Tensor parallelism", value=2)
- pp = gr.Number(label="Pipeline parallelism", value=2)
- cp = gr.Number(label="Context parallelism", value=2)
- is_sp = gr.Radio(["True", "False"], value="True", label="Sequence parallelism")
- with gr.Row():
- is_ip = gr.Radio(["True", "False"], value="False", label="Use Interleaved Pipeline")
  vp = gr.Number(label="Virtual Pipeline Size")
- is_dist_opt = gr.Radio(["True", "False"], value="True", label="Use Distributed Optimizer(Zero1)")
-
- # Input 3.[Training Settings]
- gr.Markdown(
- """
- <h1>Training Config:</h1>
- """
- )
- with gr.Accordion("Training Config"):
- b = gr.Number(label="Micro Batch size", value=4)
- b_global = gr.Number(label="Global Batch size", value=64)
- gr.Checkbox(label="True", value=True, info="BF16 Training")
- is_fp8 = gr.Radio(["True", "False"], value="True", label="FP8 Training")
- is_fp8_init = gr.Radio(["True", "False"], value="True", label="FP8 Initialization(will reduce memory)")
- g_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Gradients Dtype")
- o_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Optimizer State Dtype")
-
- with gr.Column():
- gr.Markdown(
- """
- <h1>Output Data:</h1>
- """
- )
  formula = formula

  gr.Markdown(
@@ -471,25 +554,19 @@ with gr.Blocks() as demo:
  , latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }]
  )

- output_text = gr.Textbox(
- label="Compute result",
- interactive=False,
- )

- # Button
- with gr.Row():
- compute_btn = gr.Button("Compute")
- download_btn = gr.Button("Download")
-
  record_df = gr.Dataframe(
  label="Record Table",
- headers=col
  )
  count = gr.Number(label="Row count", value=1, visible=False)
  compute_btn.click(
  fn=Compute_ALL_Model_memory,
  inputs=[vocab_size, layer_num, hidden_size, ffn_size, sequence_len, head_num, is_group_query, group_query_num, is_bias, act_func,
- dp, tp, pp, cp, is_sp, is_ip, vp, is_dist_opt, b, b_global, is_fp8, is_fp8_init, g_ty, o_ty, record_df, count],
  outputs=[output_text, record_df, count]
  )

@@ -503,4 +580,4 @@ with gr.Blocks() as demo:


  if __name__ == "__main__":
- demo.launch()
 
  import gradio as gr
  import pandas as pd

+ # col=['Layer number', 'Hidden size', 'FFN Hidden size', 'Sequence length', 'Head number', 'Group number',
+ # 'dp', 'tp', 'pp', 'cp', 'GPU numbers', 'Batch size', 'FP8', 'Model parameters', 'Model_states', 'Activation', 'Total']

+ col=['L', 'H', 'FFN', 'S', 'A', 'G',
+ 'dp', 'tp', 'pp', 'cp', 'GPU number', 'Batch size', 'FP8', 'Model parameters', 'Model states', 'Activation', 'Total']
+
+ abbr = """
+ <div align="center">
+
+ > **Abbreviations of symbols:**
+ |Abbr|Full name|Abbr|Full name|Abbr|Full name|Abbr|Full name|Abbr|Full name|Abbr|Full name|
+ |---|---|---|---|---|---|---|---|---|---|---|---|
+ |L|Layer number|H|Hidden size|FFN|FFN Hidden size|S|Sequence length|A|Head number|G|Group number|
+
+ </div>
+ """

  def Get_GigaByte(memory):
  return memory / 1024**3
 
  num_parameters_mlp = 2 * hidden_size
  # mlp1 weight: h*ffn/tp, bias: ffn/tp
  # mlp2 weight: ffn*h/tp, bias: h
+ if act_func == "LLaMA":
  num_parameters_mlp += hidden_size * ffn_size * 3 / tp
  if is_bias == "True":
  num_parameters_mlp += ffn_size * 2 / tp + hidden_size
 
  if is_sp == "False":
  activation_mem_mlp_fc1 *= tp
  # Act 8bsh
+ if act_func == "LLaMA":
  activation_mem_mlp_act = seq_length * b * ffn_size * 2 * 2
  else:
  activation_mem_mlp_act = seq_length * b * ffn_size * 2
 
  # Inputs to output layer and CE loss(bf16, fp32 * 2).
  return 2 * seq_length * b * hidden_size + (2 + 4 + 4) * seq_length * b * vocab_size

+ def compute_activation_memory_pp(activation_memory, vp, pp, num_microbatches):
  # Multiply by interleaved PP memory factor.
+ if vp > 0:
  interleaved_schedule_memory_penalty = 1 + (pp - 1) / (pp * vp)
  activation_memory *= interleaved_schedule_memory_penalty

  # If using non-interleaved schedule, number of microbatches in pipeline can be less than pp_size,
  # so discount accordingly.
+ if vp == 0 and pp > 1:
  if num_microbatches > 1:
  activation_memory *= min(1, num_microbatches / pp)

  return activation_memory
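For a quick sense of the scaling applied here, a minimal standalone sketch (hypothetical values: pp=2, vp=2, 8 microbatches) that re-applies the two branches of `compute_activation_memory_pp`:

```python
# Minimal sketch of the interleaved-PP scaling above (hypothetical values).
pp, vp, num_microbatches = 2, 2, 8
activation_memory = 100.0  # arbitrary per-stage activation footprint

if vp > 0:
    # interleaved schedule: factor 1 + (pp - 1) / (pp * vp) = 1.25 here
    activation_memory *= 1 + (pp - 1) / (pp * vp)
elif pp > 1 and num_microbatches > 1:
    # non-interleaved: discount when fewer microbatches are in flight than stages
    activation_memory *= min(1, num_microbatches / pp)

print(activation_memory)  # 125.0 with the interleaved settings above
```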
 
+ def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp):
  # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
  # We are trying to compute the maximum activation footprint, so all calculations in this function
  # are for the first pipeline stage.
 

  # get num_microbatches
  num_microbatches = b_global / b / dp / cp
+ activation_memory = compute_activation_memory_pp(activation_memory, vp, pp, num_microbatches)

  if pp == 1:
  # Inputs to output layer and CE loss(fp32).
 

  # compute_btn.click.function
  def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_length, head_num, is_group_query, group_query_num, is_bias, act_func,
+ dp, tp, pp, cp, is_sp, vp, is_dist_opt, b, b_global, is_fp8, is_fp8_init, g_ty, o_ty, record_df, count):
+ # data type trans
+ if is_group_query == "True":
+ group_query_num = int(group_query_num)
+
+ # check input
+ [result, Error_message] = check_input(dp, tp, pp, cp, hidden_size, head_num, layer_num, seq_length, vp, b, b_global)
+ if result == False:
+ return Error_message, record_df, count
+
  # get model states
  numParameters, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, model_states_memory = Compute_Model_states(vocab_size, layer_num, hidden_size,
  ffn_size, head_num, is_group_query, group_query_num, is_bias, act_func, dp, tp, pp, cp, is_dist_opt, is_fp8, is_fp8_init, g_ty, o_ty)

  # get activation memory
+ activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp)

  # get model parameters
  numParametersTotal = Compute_Parameters(vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, act_func, head_num, 1, 1)
 

  # record
  new_row = pd.DataFrame([[layer_num, hidden_size, ffn_size, seq_length, head_num, group_query_num, dp, tp, pp, cp, gpu_num, b, is_fp8,
+ numParameters, model_states_memory, activation_memory, Total]],
  columns=col)
  if count == 1:
  record_df = new_row
 
  # return str(gpu_num), str(model_states) + " GB", str(activation) + " GB", str(total) + " GB", table_data
  return f"""
  GPU numbers = {str(gpu_num)}, \n
+ Model parameters = {str(numParametersTotal)} B, \n
+ Model parameters on each device = {str(numParameters)} B, \n
  Model_states = {str(model_states_memory)} GB, \n
  Activation = {str(activation_memory)} GB, \n
  Total memory consumption = {str(Total)} GB \n
 

  # formula string
  formula = r"""
+ > **Note**🔑: In this formula, we assume LLM training with FP8 training.
+ > 1. Interleaved pipeline.
+ > 2. bias = False.
+ > 3. SP = True.
+
+ <div align="center">
+ <img src=file/T1.jpg width=50%/>
+ </div>
+
  $$
  {Total\ Model\ parameters} =
+ HV + HS + (4H^2 + 3H \times FFN + 2H) \times L + 2H + HV
  $$
+
  ***

+ <div align="center">
+ <img src=file/ms.png width=40%/>
+ </div>
+
+ $$
+ {Model\ states} =
+ (6 + \frac{12}{dp}) \times
+ (\frac{(\frac{4H^2 + 3H \times FFN}{tp} + 2H) \times L}{pp} + \frac{HV}{tp} + HS)
+ $$
+
  $$
+ {Activation} =
+ (1 + \frac{pp-1}{pp \times vp}) \times
+ \frac{(8BS + BSH) \times pp + 15BSH + 5BS \times FFN}{tp \times cp}
  $$

  ***

  $$
  """
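To make these formulas concrete, a small sketch that plugs the app's default inputs (V=32000, H=4096, FFN=11008, S=2048, L=32, dp=1, tp=2, pp=2, as shown in the UI below) into the displayed expressions; the printed numbers are only as good as the formulas themselves:

```python
# Sketch: evaluate the displayed formulas with the UI default values (assumed).
V, H, FFN, S, L = 32000, 4096, 11008, 2048, 32
dp, tp, pp = 1, 2, 2

total_params = H * V + H * S + (4 * H**2 + 3 * H * FFN + 2 * H) * L + 2 * H + H * V
per_gpu_params = ((4 * H**2 + 3 * H * FFN) / tp + 2 * H) * L / pp + H * V / tp + H * S
model_states_bytes = (6 + 12 / dp) * per_gpu_params

print(f"Total parameters ~ {total_params / 1e9:.2f} B")                 # ~6.75 B
print(f"Model states per GPU ~ {model_states_bytes / 1024**3:.1f} GB")  # ~28.4 GB
```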
 
+ def check_tp(tp, head_num):
+ if head_num % tp == 0:
+ return True
+ else:
+ return False
+
+ def check_pp(pp, layer_num):
+ if layer_num % pp == 0:
+ return True
+ else:
+ return False
+
+ def check_cp(cp, seq_length):
+ if seq_length % cp == 0:
+ return True
+ else:
+ return False
+
+ def check_hidden(hidden_size, head_num):
+ if hidden_size % head_num == 0:
+ return True
+ else:
+ return False
+
+ def check_b_global(b_global, b, dp, cp):
+ if b_global % (b * dp * cp) == 0:
+ return True
+ else:
+ return False
+
+ def check_num_microbatch(layer_num, vp, pp, num_microbatches):
+ if vp > 0:
+ if layer_num % (pp * vp) == 0:
+ return True
+ else:
+ return False
+
+ if vp == 0 and pp > 1:
+ if num_microbatches > 1:
+ if num_microbatches % pp == 0:
+ return True
+ else:
+ return False
+ return True
+
+
+ def check_input(dp, tp, pp, cp, hidden_size, head_num, layer_num, seq_length, vp, b, b_global):
+ result = True
+ Error_message = ""
+ if check_tp(tp, head_num) == False:
+ result = False
+ Error_message += "Error message: Please reset Tensor parallelism or head_num, make head_num % tp = 0. \n"
+ if check_pp(pp, layer_num) == False:
+ result = False
+ Error_message += "Error message: Please reset Pipeline parallelism or layer_num, make layer_num % pp = 0. \n"
+ if check_cp(cp, seq_length) == False:
+ result = False
+ Error_message += "Error message: Please reset Context parallelism or seq_length, make seq_length % cp = 0. \n"
+ if check_hidden(hidden_size, head_num) == False:
+ result = False
+ Error_message += "Error message: Please reset hidden_size or head_num, make hidden_size % head_num = 0. \n"
+ if check_b_global(b_global, b, dp, cp) == False:
+ result = False
+ Error_message += "Error message: Please reset b_global or batch_size, make b_global % (batch_size * dp * cp) = 0. \n"
+ if check_num_microbatch(layer_num, vp, pp, b_global / b / dp / cp) == False:
+ result = False
+ Error_message += "Error message: Please reset b_global or batch_size or layer_num or Virtual Pipeline Size, make layer_num % (pp * vp) = 0, num_microbatches % pp = 0. \n"
+
+ return result, Error_message
+
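A quick sanity check of these validation helpers, assuming a call with the UI defaults and vp=0 (no virtual pipeline); every divisibility check passes, so the result is True with an empty error message:

```python
ok, msg = check_input(dp=1, tp=2, pp=2, cp=2, hidden_size=4096, head_num=32,
                      layer_num=32, seq_length=2048, vp=0, b=4, b_global=64)
print(ok, repr(msg))  # True ''
```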
  with gr.Blocks() as demo:
  with gr.Row():
  # Text

  """
  )

+ with gr.Row():
+ with gr.Column():
+ # Input 1.[Model Parameters]
+ gr.Markdown(
+ """
+ <h1>Model Parameters:</h1>
+ """
+ )
+ with gr.Accordion("Model Parameters"):
+ # with gr.Row():
+ act_func = gr.Radio(["LLaMA", "GPT"], value="LLaMA", label="Model type") #, info="Action Function in MLP, whether to use GLU (Gated Linear Unit). [e.g \"True\" for LlaMA, \"False\" for GPT.]")
+ with gr.Row():
+ vocab_size = gr.Number(label="Vocab size", value=32000)
+ layer_num = gr.Number(label="Layer number", value=32)
+ with gr.Row():
+ hidden_size = gr.Number(label="Hidden size", value=4096)
+ ffn_size = gr.Number(label="FFN Hidden size", value=11008)
+ with gr.Row():
+ sequence_len = gr.Number(label="Sequence length", value=2048)
+ head_num = gr.Number(label="Number of Attention Heads", value=32)
+ with gr.Row():
+ is_group_query = gr.Radio(["True", "False"], value="False", label="Use Group Query Attention")
+ group_query_num = gr.Textbox(label="Number of Query Groups", max_lines=1, value=None, interactive=False)
+ is_bias = gr.Radio(["True", "False"], value="False", label="Use Bias")
+
+ # change editable function
+ def toggle_textbox_editable(radio_value):
+ # decide whether the textbox is editable based on the radio value
+ if radio_value == "True":
+ return gr.update(interactive=True, value="96")
+ else:
+ return gr.update(interactive=False, value="")
+ # connect the radio component's change event to the function
+ is_group_query.change(toggle_textbox_editable, inputs=is_group_query, outputs=group_query_num)
+
+ with gr.Column():
+ # Input 2.[Parallelism]
+ gr.Markdown(
+ """
+ <h1>Parallelism config:</h1>
+ """
+ )
+ with gr.Accordion("Parallelism config"):
+ # with gr.Row():
+ dp = gr.Number(label="Data parallelism", value=1)
+ tp = gr.Number(label="Tensor parallelism", value=2)
+ pp = gr.Number(label="Pipeline parallelism", value=2)
+ cp = gr.Number(label="Context parallelism", value=2)
+ # with gr.Row():
+ is_sp = gr.Radio(["True", "False"], value="True", label="Sequence parallelism")
  vp = gr.Number(label="Virtual Pipeline Size")
+ is_dist_opt = gr.Radio(["True", "False"], value="True", label="Use Distributed Optimizer(Zero1)")
+
+ with gr.Column():
+ # Input 3.[Training Settings]
+ gr.Markdown(
+ """
+ <h1>Training Config:</h1>
+ """
+ )
+ with gr.Accordion("Training Config"):
+ # with gr.Row():
+ b = gr.Number(label="Micro Batch size", value=4)
+ b_global = gr.Number(label="Global Batch size", value=64)
+ # with gr.Row():
+ gr.Checkbox(label="True", value=True, info="BF16 Training")
+ is_fp8 = gr.Radio(["True", "False"], value="True", label="FP8 Training")
+ is_fp8_init = gr.Radio(["True", "False"], value="True", label="FP8 Initialization(will reduce memory)")
+ # with gr.Row():
+ g_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Gradients Dtype")
+ o_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Optimizer State Dtype")
+
+ compute_btn = gr.Button("Compute")
+ with gr.Tab("Output"):
+ with gr.Column():
+ gr.Markdown(
+ """
+ <h1>Output Data:</h1>
+ """
+ )
+ output_text = gr.Textbox(
+ label="Compute result",
+ interactive=False,
+ )
+
+ with gr.Tab("Formula"):
  formula = formula

  gr.Markdown(
 
  , latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }]
  )

+ gr.Markdown(abbr)

  record_df = gr.Dataframe(
  label="Record Table",
+ headers=col,
+ interactive=False
  )
+ download_btn = gr.Button("Download")
  count = gr.Number(label="Row count", value=1, visible=False)
  compute_btn.click(
  fn=Compute_ALL_Model_memory,
  inputs=[vocab_size, layer_num, hidden_size, ffn_size, sequence_len, head_num, is_group_query, group_query_num, is_bias, act_func,
+ dp, tp, pp, cp, is_sp, vp, is_dist_opt, b, b_global, is_fp8, is_fp8_init, g_ty, o_ty, record_df, count],
  outputs=[output_text, record_df, count]
  )


  if __name__ == "__main__":
+ demo.launch(allowed_paths=["/"])