import gradio as gr
import pandas as pd

col = ['Layer number', 'Hidden size', 'FFN Hidden size', 'Sequence length', 'Head number', 'Group number',
       'dp', 'tp', 'pp', 'cp', 'GPU numbers', 'Batch size', 'FP8', 'Model parameters', 'Model_states', 'Activation', 'Total']

# # global data
# table_data = pd.DataFrame(columns=col)

def Get_GigaByte(memory):
    return memory / 1024**3

def Get_BillionParameter(parameter):
    return parameter / 1000**3

# model states:
def Compute_Parameters_input(hidden_size, vocab_size, tp):
    num_parameters_word_embedding = hidden_size * vocab_size / tp
    num_parameters_position_embedding = 0  # learned position embedding (hidden_size * seq_length) not counted
    return num_parameters_word_embedding + num_parameters_position_embedding

def Compute_Parameters_output(hidden_size, vocab_size, tp):
    num_parameters_output_layernorm = 2 * hidden_size
    num_parameters_output_embedding = 0 # shares weights with the input word embedding
    return num_parameters_output_layernorm + num_parameters_output_embedding

def Compute_Parameters_attention(hidden_size, kv_hidden_size, is_bias, tp):
    # attention: 
    # layernorm: 2h 
    num_parameters_attention = 2 * hidden_size
    # QKV weight: 3h*h/tp, bias: 3h/tp
    # output linear weight: h*h/tp, bias: h
    num_parameters_attention_Q_weight = hidden_size * hidden_size / tp
    num_parameters_attention_KV_weight = 2 * kv_hidden_size * hidden_size / tp
    num_parameters_attention_Linear_weight = hidden_size * hidden_size / tp

    num_parameters_attention += num_parameters_attention_Q_weight + num_parameters_attention_KV_weight + num_parameters_attention_Linear_weight
    if is_bias == "True":
        num_parameters_attention += (hidden_size + 2 * kv_hidden_size) / tp + hidden_size
    
    return num_parameters_attention
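
# Worked example (illustrative, not part of the original app): with hidden_size=4096,
# no GQA (kv_hidden_size == hidden_size), tp=1 and no bias, the attention block holds
#   2*4096 + 4*4096^2 = 67,117,056 parameters (~67.1M) per layer.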

def Compute_Parameters_mlp(hidden_size, ffn_size, is_bias, act_func,  tp):
    # MLP: 
    # layernorm: 2h
    num_parameters_mlp = 2 * hidden_size
    # mlp1 weight: h*ffn/tp, bias: ffn/tp
    # mlp2 weight: ffn*h/tp, bias: h
    if act_func == "True":
        num_parameters_mlp += hidden_size * ffn_size * 3 / tp
        if is_bias == "True":
            num_parameters_mlp += ffn_size * 2 / tp + hidden_size
    else:
        num_parameters_mlp += hidden_size * ffn_size * 2 / tp
        if is_bias == "True":
            num_parameters_mlp += ffn_size / tp + hidden_size
    
    return num_parameters_mlp
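
# Worked example (illustrative): with hidden_size=4096, ffn_size=11008, GLU enabled,
# tp=1 and no bias, the MLP block holds
#   2*4096 + 3*4096*11008 = 135,274,496 parameters (~135.3M) per layer.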

def Compute_Parameters(vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, act_func, head_num, tp, pp):
    if is_group_query == "False":
        group_query_num = head_num
    kv_hidden_size = hidden_size / head_num * group_query_num
    
    # input part
    num_parameters_input = Compute_Parameters_input(hidden_size, vocab_size, tp)

    # middle layers part
    num_parameters_attention = Compute_Parameters_attention(hidden_size, kv_hidden_size, is_bias, tp)
    num_parameters_mlp = Compute_Parameters_mlp(hidden_size, ffn_size, is_bias, act_func, tp)
    num_parameters_in_single_layer = num_parameters_attention + num_parameters_mlp
    num_parameters_in_total_layers = num_parameters_in_single_layer * layer_num / pp
    
    # output part
    parameters_output = Compute_Parameters_output(hidden_size, vocab_size, tp)    

    if pp == 1:
        num_parameters_total = (
            num_parameters_input
            + num_parameters_in_total_layers
            + parameters_output # num_parameters_output_layernorm
        )    
    else:
        num_parameters_total = (
            num_parameters_input
            + num_parameters_in_total_layers
        )   
    
    return num_parameters_total
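
# Worked example (illustrative): a LLaMA-7B-like config (vocab=32000, 32 layers,
# hidden=4096, ffn=11008, 32 heads, no GQA, no bias, GLU) with tp=pp=1 gives
#   131,072,000 (tied embedding) + 32 * 202,391,552 (layers) + 8,192 (final LN)
#   = 6,607,609,856 parameters (~6.61B).
# Note that this calculator ties the input and output embeddings, so the count is a
# bit below the ~6.7B usually cited for LLaMA-7B, which keeps them untied.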

def Compute_Weight(numParametersTotal, is_fp8, is_fp8_init):
    if is_fp8 == "False":
        weight_memory = 2 * numParametersTotal
    elif is_fp8_init == "False":
        weight_memory = 4 * numParametersTotal
    else:
        weight_memory = 2 * numParametersTotal 
    
    return weight_memory

def Compute_Gradient(numParametersTotal, g_ty):
    if g_ty == "FP32":
        gradient_memory = 4 * numParametersTotal 
    elif g_ty =="BF16":
        gradient_memory = 2 * numParametersTotal
    
    return gradient_memory

def Compute_Optimizer_states(numParametersTotal, o_ty, is_dist_opt, dp, cp):
    if o_ty == "FP32":
        optimizer_memory = 4 * 2 * numParametersTotal 
    elif o_ty =="BF16":
        optimizer_memory = 2 * 2 * numParametersTotal
    
    if is_dist_opt == "True":
        optimizer_memory = optimizer_memory / (dp * cp)

    return optimizer_memory

def Compute_Master_weight(numParametersTotal, is_dist_opt, dp, cp):
    master_weight_memory = 4 * numParametersTotal
    if is_dist_opt == "True":
        master_weight_memory = master_weight_memory / (dp * cp)
    
    return master_weight_memory

def Compute_Model_states(vocab_size, layer_num, hidden_size, ffn_size, head_num, is_group_query, group_query_num, is_bias, act_func,
        dp, tp, pp, cp, is_dist_opt, is_fp8, is_fp8_init, g_ty, o_ty):
    numParametersTotal = Compute_Parameters(vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, act_func, head_num, tp, pp)

    weight_memory = Compute_Weight(numParametersTotal, is_fp8, is_fp8_init)
    gradient_memory = Compute_Gradient(numParametersTotal, g_ty)
    optimizer_memory = Compute_Optimizer_states(numParametersTotal, o_ty, is_dist_opt, dp, cp)
    master_weight_memory = Compute_Master_weight(numParametersTotal, is_dist_opt, dp, cp)

    return numParametersTotal, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, \
            weight_memory + gradient_memory + optimizer_memory + master_weight_memory
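
# Rough per-parameter budget implied by the functions above (no ZeRO-1 sharding,
# illustrative): BF16 weights (2 bytes) + FP32 gradients (4) + FP32 optimizer states
# (4 * 2 = 8) + FP32 master weights (4) = 18 bytes per parameter; FP8 training without
# FP8 initialization raises the weight term to 4 bytes, i.e. 20 bytes per parameter.
# With the distributed optimizer, the optimizer-state and master-weight terms are
# divided by dp * cp.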

# activation memory:
def compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp):
    # LN 2bsq
    activation_mem_attn_ln = seq_length * b * hidden_size * 2
    if is_sp == "False":
        activation_mem_attn_ln *= tp
    # attention input X, qkv 2bsh/1bsh
    activation_mem_attn_qkv = seq_length * b * hidden_size * activation_dtype
    if is_sp == "False":
        activation_mem_attn_qkv *= tp
    # attention q 2bsh
    activation_mem_attn_q = seq_length * b * hidden_size * 2
    # attention k and v 4bsh
    activation_mem_attn_kv = seq_length * b * kv_hidden_size * 2 * 2
    # attention proj input 2bsh/1bsh
    activation_mem_attn_proj = seq_length * b * hidden_size * activation_dtype
    # dropout bsh
    activation_mem_attn_dropout = seq_length * b * hidden_size
    if is_sp == "False":
        activation_mem_attn_dropout *= tp
    # bf16: 2+2+2+4+2+1=13bsh
    # fp8: 2+1+2+4+1+1=11bsh
    activation_memory_attn = (
        activation_mem_attn_ln
        + activation_mem_attn_qkv
        + activation_mem_attn_q 
        + activation_mem_attn_kv 
        + activation_mem_attn_proj 
        + activation_mem_attn_dropout
    )
    return activation_memory_attn

def compute_activation_memory_mlp(activation_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp):
    # LN 2bsh
    activation_mem_mlp_ln = seq_length * b * hidden_size * 2
    if is_sp == "False":
        activation_mem_mlp_ln *= tp
    # FC1 2bsh/1bsh
    activation_mem_mlp_fc1 = seq_length * b * hidden_size * activation_dtype
    if is_sp == "False":
        activation_mem_mlp_fc1 *= tp
    # activation-function input: with GLU both gate and up projections are stored (4*bs*ffn bytes), else 2*bs*ffn
    if act_func == "True":
        activation_mem_mlp_act = seq_length * b * ffn_size * 2 * 2
    else:
        activation_mem_mlp_act = seq_length * b * ffn_size * 2
    # FC2 8bsh/4bsh
    activation_mem_mlp_fc2 = seq_length * b * ffn_size * activation_dtype
    # dropout bsh
    activation_mem_mlp_dropout = seq_length * b * hidden_size
    if is_sp == "False":
        activation_mem_mlp_dropout *= tp
    # bf16: 2+2+8+8+1=21
    # fp8: 2+1+8+4+1=16
    activation_memory_mlp = (
        activation_mem_mlp_ln
        + activation_mem_mlp_fc1
        + activation_mem_mlp_act
        + activation_mem_mlp_fc2
        + activation_mem_mlp_dropout
    )
    return activation_memory_mlp

def compute_activation_memory_input(seq_length, b, hidden_size, pp):
    # embedding + Dropout
    return 8 * seq_length * b * pp + seq_length * b * hidden_size * pp

def compute_activation_memory_output(seq_length, b, hidden_size, vocab_size):
    # Inputs to output layer and CE loss(bf16, fp32 * 2).
    return 2 * seq_length * b * hidden_size + (2 + 4 + 4) * seq_length * b * vocab_size

def compute_activation_memory_pp(activation_memory, is_ip, vp, pp, num_microbatches):
    # Multiply by interleaved PP memory factor.
    if is_ip == "True":
        interleaved_schedule_memory_penalty = 1 + (pp - 1) / (pp * vp)
        activation_memory *= interleaved_schedule_memory_penalty

    # If using non-interleaved schedule, number of microbatches in pipeline can be less than pp_size,
    # so discount accordingly.
    if is_ip == "False" and pp > 1:
        if num_microbatches > 1:
            activation_memory *= min(1, num_microbatches / pp)

    return activation_memory 
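
# Quick sanity check (illustrative): with interleaving, pp=4 and vp=2 scale activations
# by 1 + (4 - 1) / (4 * 2) = 1.375; without interleaving, pp=4 with only 2 microbatches
# in flight scales them by min(1, 2 / 4) = 0.5.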

def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, is_ip, vp):
    # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
    # We are trying to compute the maximum activation footprint, so all calculations in this function
    # are for the first pipeline stage.

    # activation dataType
    if is_fp8 == "False":
        activation_dtype = 2
    else: 
        activation_dtype = 1

    # kv_hidden_size
    if is_group_query == "False":
        group_query_num = head_num
    kv_hidden_size = hidden_size / head_num * group_query_num

    activation_memory_attn = compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp)

    activation_memory_mlp = compute_activation_memory_mlp(activation_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp)

    activation_memory = activation_memory_attn + activation_memory_mlp

    activation_memory *= layer_num

    # Now add activation memory required for input embeddings, last LayerNorm and output layer.
    # Input to embedding (pp_size microbatches in flight).
    activation_memory_input = compute_activation_memory_input(seq_length, b, hidden_size, pp)
    activation_memory += activation_memory_input

    # get num_microbatches
    num_microbatches = b_global / b / dp / cp
    activation_memory = compute_activation_memory_pp(activation_memory, is_ip, vp, pp, num_microbatches)

    if pp == 1:
        # Inputs to output layer and CE loss(fp32).
        activation_memory_output = compute_activation_memory_output(seq_length, b, hidden_size, vocab_size)
        activation_memory += activation_memory_output
    elif pp > 1:
        # Sendrecv memory
        activation_memory += seq_length * b * hidden_size * 2

    # Activation memory is partitioned by TP size due to tensor and sequence model parallelism.
    return activation_memory / tp / cp
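
# Worked example (illustrative): with BF16, GLU, no GQA and sequence parallelism on,
# one layer holds roughly 13*S*B*H (attention) + 5*S*B*H + 6*S*B*FFN (MLP) bytes
# before the final /tp/cp division, e.g. S=1024, B=1, H=4096, FFN=11008 gives
# 143,130,624 bytes (~136.5 MiB) per layer.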

# compute_btn.click.function
def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_length, head_num, is_group_query, group_query_num, is_bias, act_func,
        dp, tp, pp, cp, is_sp, is_ip, vp, is_dist_opt, b, b_global, is_fp8, is_fp8_init, g_ty, o_ty, record_df, count):
    # get model states
    numParameters, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, model_states_memory = Compute_Model_states(vocab_size, layer_num, hidden_size, 
        ffn_size, head_num, is_group_query, group_query_num, is_bias, act_func, dp, tp, pp, cp, is_dist_opt, is_fp8, is_fp8_init, g_ty, o_ty)

    # get activation memory 
    activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, is_ip, vp)

    # get model parameters
    numParametersTotal = Compute_Parameters(vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, act_func, head_num, 1, 1)
    # get gpu number
    gpu_num = dp * tp * pp * cp

    # get B/GB
    numParametersTotal = round(Get_BillionParameter(numParametersTotal), 3)
    numParameters = round(Get_BillionParameter(numParameters), 3)
    model_states_memory = round(Get_GigaByte(model_states_memory), 3)
    activation_memory = round(Get_GigaByte(activation_memory), 3)
    Total = round(model_states_memory + activation_memory, 3)

    # record
    new_row = pd.DataFrame([[layer_num, hidden_size, ffn_size, seq_length, head_num, group_query_num, dp, tp, pp, cp, gpu_num, b, is_fp8, 
                            numParametersTotal, model_states_memory, activation_memory, Total]], 
                            columns=col)
    if count == 1:
        record_df = new_row
    else:
        record_df = pd.concat([record_df, new_row], ignore_index=True)
    count = count + 1

    # return str(gpu_num), str(model_states) + " GB", str(activation) + " GB", str(total) + " GB", table_data
    return f"""
                GPU numbers = {str(gpu_num)}, \n
                Total model parameters = {str(numParametersTotal)} B, \n
                Model parameters = {str(numParameters)} B, \n
                Model_states = {str(model_states_memory)} GB, \n
                Activation = {str(activation_memory)} GB, \n
                Total memory consumption = {str(Total)} GB \n
           """, record_df, count

def generate_csv(record_df):
    # Save the DataFrame as a CSV file
    csv_filename = "data.csv"
    record_df.to_csv(csv_filename, index=False)
    
    # Return the path to the CSV file
    return csv_filename

# formula string
formula = r"""
        > **Note**🔑: These formulas assume LLM training with FP32 gradients and optimizer states, bias = False, Zero1 = False, and SP = True.
        
        <!-- parameters: -->
        $$
        P_{input} = \frac{HV}{tp}, \quad
        P_{output} = 2H \\\\
        P_{attn} = 2H + \frac{2H^2 + 2H_{KV} \times H}{tp}, \quad
        P_{MLP} = 2H + 
        \\begin{cases} 
        \frac{3H \times FFN}{tp},  & \text{if }GLU\text{ is True} \\\\
        \frac{2H \times FFN}{tp}, & \text{if }GLU\text{ is False}
        \\end{cases} \\\\
        P_{middle} = \frac{(P_{attn} + P_{MLP}) \times L}{pp} \\\\
        P = P_{input} + P_{middle} + 
        \\begin{cases} 
        P_{output},  & \text{if }pp = 1 \\\\
        0, & \text{if }pp > 1
        \\end{cases} \\\\
        {Total\ Model\ parameters} = 
        \\begin{cases}
        P,  & \text{set tp = 1, pp = 1} \\\\
        2HV + 2H + (4H + 2H^2 + 2H_{KV} \times H + 3FFN \times H) \times L, & \text{general formula}
        \\end{cases} \\\\
        {Model\ states} = {Model\ weight} + {Gradient} + {Optimizer\ state} + {Master\ weight} = 
        \\begin{cases}
        18P,  & \text{BF16 training} \\\\
        18P, & \text{FP8 training with FP8 Init} \\\\
        20P, & \text{FP8 training w/o FP8 Init}
        \\end{cases} \\\\
        $$
        
        ***

        <!-- activations: -->
        $$
        A_{input} = (8SB + SBH) \times pp, \quad
        A_{output} = 2SBH + 
        \\begin{cases} 
        10SBV,  & \text{if }pp\text{ = 1} \\\\
        0, & \text{if }pp\text{ > 1}
        \\end{cases} \\\\
        A_{attn} = 5SBH + 4SB \times H_{KV} +
        \\begin{cases}
        2SBH, & \text{if } FP8  \text{ is True} \\\\
        4SBH, & \text{if } FP8  \text{ is False}
        \\end{cases} \\\\
        A_{MLP} = 3SBH + 
        \\begin{cases}
        SBH + SB \times FFN + 4SB \times FFN, & \text{if }FP8 \text{ is True and }GLU \text{ is True} \\\\
        2SBH + 2SB \times FFN + 4SB \times FFN, & \text{if }FP8 \text{ is False and }GLU \text{ is True} \\\\
        SBH + SB \times FFN + 2SB \times FFN, & \text{if }FP8 \text{ is True and }GLU \text{ is False} \\\\
        2SBH + 2SB \times FFN + 2SB \times FFN, & \text{if }FP8 \text{ is False and }GLU \text{ is False}
        \\end{cases} \\\\
        A_{middle} = (A_{attn} + A_{MLP}) \times L \\\\
        A_{ip} = (A_{input} + A_{middle}) \times 
        \\begin{cases}
        (1 + \frac{pp - 1}{pp \times vp}), & \text{if } Interleaved\ Pipeline  \text{ is True} \\\\
        min(1, \frac{microbatch}{pp}), & \text{if } Interleaved\ Pipeline \text{ is False and pp > 1} \\\\
        1, & \text{other}
        \\end{cases} \\\\
        Activation = 
        \\begin{cases}
        \frac{A_{ip} + A_{output}}{tp \times cp}, & \text{if pp = 1} \\\\
        \frac{A_{ip} + 2BSH}{tp \times cp}, & \text{if pp > 1}
        \\end{cases}
        $$

        ***

        $$
        \\begin{gather}
        {GPU\ numbers} = tp \times pp \times dp \times cp\\\\
        {Total\ memory\ consumption} = {Model\ states} + Activation
        \\end{gather}
        $$
        """

with gr.Blocks() as demo:
    with gr.Row():
        # Text
        gr.Markdown(
            """
            <div style="text-align: center;">
                <h1>GPU memory calculator 🌀</h1>
                <p style="font-size:16px;">Here's a GPU memory calculator; it helps you estimate memory consumption in LLM training.</p>
            </div>
            """
        )

    with gr.Column(): 
        # Input 1.[Model Parameters]
        gr.Markdown(
            """
            <h1>Model Parameters:</h1>
            """
        )
        with gr.Accordion("Model Parameters"):
            act_func = gr.Radio(["True", "False"], value="True", label="Model type", info="Activation function in the MLP: whether to use GLU (Gated Linear Unit). [e.g. \"True\" for LLaMA, \"False\" for GPT.]")
            vocab_size = gr.Number(label="Vocab size", value=32000)
            layer_num = gr.Number(label="Layer number", value=32)
            hidden_size = gr.Number(label="Hidden size", value=4096)
            ffn_size = gr.Number(label="FFN Hidden size", value=11008)
            sequence_len = gr.Number(label="Sequence length", value=1024)
            head_num = gr.Number(label="Number of Attention Heads", value=32)
            with gr.Row():
                is_group_query = gr.Radio(["True", "False"], value="True", label="Use Group Query Attention")
                group_query_num = gr.Number(label="Number of Query Groups", value=96)
            is_bias = gr.Radio(["True", "False"], value="False", label="Use Bias")
        
        # Input 2.[Parallelism]
        gr.Markdown(
            """
            <h1>Parallelism config:</h1>
            """
        )
        with gr.Accordion("Parallelism config"):
            dp = gr.Number(label="Data parallelism", value=1)
            tp = gr.Number(label="Tensor parallelism", value=2)
            pp = gr.Number(label="Pipeline parallelism", value=2)
            cp = gr.Number(label="Context parallelism", value=2)
            is_sp = gr.Radio(["True", "False"], value="True", label="Sequence parallelism")
            with gr.Row():
                is_ip = gr.Radio(["True", "False"], value="False", label="Use Interleaved Pipeline")
                vp = gr.Number(label="Virtual Pipeline Size", value=1)
            is_dist_opt = gr.Radio(["True", "False"], value="True", label="Use Distributed Optimizer(Zero1)")

        # Input 3.[Training Settings]
        gr.Markdown(
            """
            <h1>Training Config:</h1>
            """
        )
        with gr.Accordion("Training Config"):
            b = gr.Number(label="Micro Batch size", value=4)
            b_global = gr.Number(label="Global Batch size", value=64)
            gr.Checkbox(label="True", value=True, info="BF16 Training")  # display-only; not passed to the computation
            is_fp8 = gr.Radio(["True", "False"], value="True", label="FP8 Training")
            is_fp8_init = gr.Radio(["True", "False"], value="True", label="FP8 Initialization(will reduce memory)")
            g_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Gradients Dtype")
            o_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Optimizer State Dtype")

    with gr.Column():
        gr.Markdown(
            """
            <h1>Output Data:</h1>
            """
        )
        gr.Markdown(
            formula,
            latex_delimiters=[{"left": "$$", "right": "$$", "display": True}]
        )

        output_text = gr.Textbox(
            label="Compute result", 
            interactive=False, 
        )

    # Button
    with gr.Row():
        compute_btn = gr.Button("Compute")
        download_btn = gr.Button("Download")
    
    record_df = gr.Dataframe(
        label="Record Table",
        headers=col
    )
    count = gr.Number(label="Row count", value=1, visible=False)
    compute_btn.click(
        fn=Compute_ALL_Model_memory, 
        inputs=[vocab_size, layer_num, hidden_size, ffn_size, sequence_len, head_num, is_group_query, group_query_num, is_bias, act_func,
                dp, tp, pp, cp, is_sp, is_ip, vp, is_dist_opt, b, b_global, is_fp8, is_fp8_init, g_ty, o_ty, record_df, count],
        outputs=[output_text, record_df, count]
    )

    output_file = gr.File(label="After you click the Download button, the exported CSV file will appear here.")
    # download func
    download_btn.click(
        fn=generate_csv,
        inputs=record_df,
        outputs=output_file
    )


if __name__ == "__main__":
    demo.launch()