Update app.py
app.py CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
 # 'dp', 'tp', 'pp', 'cp', 'GPU numbers', 'Batch size', 'FP8', 'Model parameters', 'Model_states', 'Activation', 'Total']
 
 col=['L', 'H', 'FFN', 'S', 'A', 'G',
-     '
+     'DP', 'TP', 'PP', 'CP', 'GPUs', 'B', 'FP8', 'Model parameters (B)', 'Model states (GB)', 'Activation (GB)', 'Total (GB)']
 
 abbr = """
 <div align="center">
@@ -31,7 +31,7 @@ def Compute_Parameters_input(seq_length, hidden_size, vocab_size, act_func, tp):
     if act_func == "LLaMA":
         num_parameters_position_embedding = 0
     else:
-        num_parameters_position_embedding = seq_length * hidden_size
+        num_parameters_position_embedding = seq_length * hidden_size / tp
 
     return num_parameters_word_embedding + num_parameters_position_embedding
 
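The only change here is that the learned position-embedding parameters are now divided by the tensor-parallel size, presumably because the embedding table is sharded across the tp ranks. A minimal sketch of the per-rank count, with made-up sizes rather than the app's defaults:

```python
# Illustrative only: per-rank position-embedding parameters after this change.
S, H, tp = 4096, 4096, 8       # example sequence length, hidden size, tensor-parallel size
params_per_rank = S * H / tp   # the S*H embedding table split across tp ranks
print(int(params_per_rank))    # 2097152 parameters on each tensor-parallel rank
```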
@@ -119,13 +119,15 @@ def Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size,
 
     return num_parameters_total
 
-def Compute_Weight(numParametersTotal, is_fp8, is_fp8_init):
-
-
-    elif is_fp8_init == "False":
+def Compute_Weight(numParametersTotal, precision, is_fp8, is_fp8_init):
+    weight_memory = 0
+    if precision == "FP32":
         weight_memory = 4 * numParametersTotal
     else:
-        weight_memory = 2 * numParametersTotal
+        weight_memory = 2 * numParametersTotal
+
+    if is_fp8 == "True" and is_fp8_init == "False":
+        weight_memory += 2 * numParametersTotal
 
     return weight_memory
 
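Compute_Weight is rewritten around the new precision argument: 4 bytes per parameter for FP32 training, 2 bytes otherwise, plus another 2 bytes per parameter when FP8 training is enabled but the model is not initialized directly in FP8. A standalone sketch of that rule, using an illustrative 7B-parameter count:

```python
def weight_memory_bytes(n_params, precision, is_fp8, is_fp8_init):
    # Mirrors the updated Compute_Weight: 4 B/param for FP32 training, 2 B/param otherwise.
    mem = 4 * n_params if precision == "FP32" else 2 * n_params
    # FP8 training without FP8 initialization is charged an extra 2 B/param.
    if is_fp8 == "True" and is_fp8_init == "False":
        mem += 2 * n_params
    return mem

n = 7_000_000_000  # illustrative 7B-parameter model, not an app default
print(weight_memory_bytes(n, "BF16", "True", "False") / 1024**3)  # about 26.1 GiB
```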
@@ -137,7 +139,7 @@ def Compute_Gradient(numParametersTotal, g_ty):
 
     return gradient_memory
 
-def Compute_Optimizer_states(numParametersTotal, o_ty, is_dist_opt, dp, cp):
+def Compute_Optimizer_states(numParametersTotal, opt_func, o_ty, is_dist_opt, dp, cp):
     if o_ty == "FP32":
         optimizer_memory = 4 * 2 * numParametersTotal
     elif o_ty =="BF16":
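The FP32 branch charges 8 bytes per parameter. Assuming an Adam-style optimizer (the case the new opt_func control defaults to), that is two FP32 moments per parameter:

$$
M_{optimizer} = \underbrace{4}_{\text{bytes per FP32 value}} \times \underbrace{2}_{\text{moments } m,\, v} \times N_{params}
$$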
@@ -146,23 +148,30 @@ def Compute_Optimizer_states(numParametersTotal, o_ty, is_dist_opt, dp, cp):
     if is_dist_opt == "True":
         optimizer_memory = optimizer_memory / (dp * cp)
 
+    # for SGD, we have no optimizer states
+    if opt_func == "SGD":
+        optimizer_memory = 0
+
     return optimizer_memory
 
-def Compute_Master_weight(numParametersTotal, is_dist_opt, dp, cp):
-
+def Compute_Master_weight(numParametersTotal, precision, is_dist_opt, dp, cp):
+    if precision == "BF16":
+        master_weight_memory = 4 * numParametersTotal
+    else:
+        master_weight_memory = 0
     if is_dist_opt == "True":
         master_weight_memory = master_weight_memory / (dp * cp)
 
     return master_weight_memory
 
 def Compute_Model_states(seq_length, vocab_size, layer_num, hidden_size, ffn_size, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func,
-                         dp, tp, pp, cp, is_dist_opt, is_fp8, is_fp8_init, g_ty, o_ty):
+                         dp, tp, pp, cp, is_dist_opt, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty):
     numParametersTotal = Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, head_num, tp, pp)
 
-    weight_memory = Compute_Weight(numParametersTotal, is_fp8, is_fp8_init)
+    weight_memory = Compute_Weight(numParametersTotal, precision, is_fp8, is_fp8_init)
     gradient_memory = Compute_Gradient(numParametersTotal, g_ty)
-    optimizer_memory = Compute_Optimizer_states(numParametersTotal, o_ty, is_dist_opt, dp, cp)
-    master_weight_memory = Compute_Master_weight(numParametersTotal, is_dist_opt, dp, cp)
+    optimizer_memory = Compute_Optimizer_states(numParametersTotal, opt_func, o_ty, is_dist_opt, dp, cp)
+    master_weight_memory = Compute_Master_weight(numParametersTotal, precision, is_dist_opt, dp, cp)
 
     return numParametersTotal, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, \
            weight_memory + gradient_memory + optimizer_memory + master_weight_memory
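Putting the pieces of Compute_Model_states together, here is a hedged end-to-end sketch for one common setting (BF16 weights, FP32 gradients, FP32 Adam states, distributed optimizer). The gradient byte counts are an assumption, since Compute_Gradient's body is outside this diff:

```python
def model_states_gb(n_params, dp=8, cp=1):
    """Illustrative model-states estimate for the parameters resident on one GPU."""
    weight = 2 * n_params               # BF16 weights, 2 B/param
    grad = 4 * n_params                 # assumed: FP32 gradients at 4 B/param
    opt = 4 * 2 * n_params / (dp * cp)  # two FP32 Adam moments, sharded by the distributed optimizer
    master = 4 * n_params / (dp * cp)   # FP32 master weights, likewise sharded
    return (weight + grad + opt + master) / 1024**3

print(round(model_states_gb(7_000_000_000), 1))  # about 48.9 GB for an illustrative 7B parameter count at dp=8
```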
@@ -298,7 +307,7 @@ def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, he
 
 # compute_btn.click.function
 def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_length, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func,
-                             dp, tp, pp, cp, is_sp, vp, is_dist_opt, b, b_global, is_fp8, is_fp8_init, g_ty, o_ty, record_df, count):
+                             dp, tp, pp, cp, is_sp, vp, is_dist_opt, b, b_global, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty, record_df, count):
     # data type trans
     if is_group_query == "True":
         group_query_num = int(group_query_num)
@@ -312,7 +321,7 @@ def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_l
 
     # get model states
     numParameters, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, model_states_memory = Compute_Model_states(seq_length, vocab_size, layer_num, hidden_size,
-                                    ffn_size, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, dp, tp, pp, cp, is_dist_opt, is_fp8, is_fp8_init, g_ty, o_ty)
+                                    ffn_size, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, dp, tp, pp, cp, is_dist_opt, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty)
 
     # get activation memory
     activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp)
@@ -344,7 +353,7 @@ def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_l
     GPU numbers = {str(gpu_num)}, \n
     Model parameters = {str(numParametersTotal)} B, \n
     Model parameters on each device = {str(numParameters)} B, \n
-    Model_states = {str(model_states_memory)} GB, \n
+    Model_states = Weight + Gradient + Optimizer = {str(model_states_memory)} GB, \n
     Activation = {str(activation_memory)} GB, \n
     Total memory consumption = {str(Total)} GB \n
     """, record_df, count
@@ -389,7 +398,7 @@ formula = r"""
 $$
 {Activation} =
 (1 + \frac{pp-1}{pp \times vp}) \times
-\frac{(8BS + BSH) \times pp + 15BSH + 5BS \times FFN}{tp \times cp}
+\frac{(8BS + BSH) \times pp + (15BSH + 5BS \times FFN) \times L}{tp \times cp}
 $$
 
 ***
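The corrected formula scales the per-layer term 15BSH + 5BS×FFN by the layer count L, which the old expression omitted, so activation memory was undercounted for anything beyond a single layer. A quick numeric check of the new expression with illustrative values (the result is converted to GB on the assumption that the terms are in bytes, matching how the app reports it):

```python
# Example values only, not the app defaults.
B, S, H, FFN, L = 4, 4096, 4096, 11008, 32
pp, vp, tp, cp = 2, 2, 8, 1

per_layer = 15 * B * S * H + 5 * B * S * FFN   # the term that is now multiplied by L
schedule = 1 + (pp - 1) / (pp * vp)            # the (1 + (pp-1)/(pp*vp)) factor from the formula
activation = schedule * ((8 * B * S + B * S * H) * pp + per_layer * L) / (tp * cp)
print(activation / 1024**3)
```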
@@ -494,7 +503,7 @@ with gr.Blocks() as demo:
             )
         with gr.Accordion("Model Parameters"):
             # with gr.Row():
-            act_func = gr.Radio(["LLaMA", "GPT"], value="LLaMA", label="Model type") #, info="Action Function in MLP, whether to use GLU (Gated Linear Unit). [e.g \"True\" for LlaMA, \"False\" for GPT.]")
+            act_func = gr.Radio(["LLaMA", "GPT"], value="LLaMA", label="Model type", info="eg. LLaMa: SwiGLU, RoPE, RMSNorm") #, info="Action Function in MLP, whether to use GLU (Gated Linear Unit). [e.g \"True\" for LlaMA, \"False\" for GPT.]")
             with gr.Row():
                 vocab_size = gr.Number(label="Vocab size (V)", value=32000)
                 layer_num = gr.Number(label="Layer number (L)", value=32)
@@ -549,13 +558,14 @@ with gr.Blocks() as demo:
             # with gr.Row():
             b = gr.Number(label="Micro Batch size (B)", value=4)
             b_global = gr.Number(label="Global Batch size", value=64)
-
-            gr.
-
-
-            # with gr.Row():
+            precision = gr.Dropdown(["FP32", "BF16"], value="BF16", label="Training precision")
+            with gr.Row():
+                is_fp8 = gr.Radio(["True", "False"], value="True", label="FP8 Training")
+                is_fp8_init = gr.Radio(["True", "False"], value="True", label="FP8 Initialization(will reduce memory)")
             g_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Gradients Dtype")
-
+            with gr.Row():
+                opt_func = gr.Radio(["Adam", "SGD"], value="Adam", label="Optimizer function")
+                o_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Optimizer State Dtype")
 
             compute_btn = gr.Button("Compute")
         with gr.Tab("Output"):
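Worth noting for anyone extending these controls: Gradio Radio and Dropdown components pass the selected option to the callback as a plain string, which is why the compute functions above compare against literals like "True", "FP32", and "SGD". A minimal, self-contained sketch of the same pattern (not the app's actual layout):

```python
import gradio as gr

def echo(precision, opt_func, is_fp8):
    # Selections arrive as the chosen strings, e.g. "BF16", "SGD", "True".
    return f"precision={precision}, optimizer={opt_func}, fp8={is_fp8}"

with gr.Blocks() as demo:
    precision = gr.Dropdown(["FP32", "BF16"], value="BF16", label="Training precision")
    opt_func = gr.Radio(["Adam", "SGD"], value="Adam", label="Optimizer function")
    is_fp8 = gr.Radio(["True", "False"], value="True", label="FP8 Training")
    out = gr.Textbox(label="Selected")
    gr.Button("Show").click(fn=echo, inputs=[precision, opt_func, is_fp8], outputs=out)

demo.launch()
```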
@@ -590,7 +600,7 @@ with gr.Blocks() as demo:
     compute_btn.click(
         fn=Compute_ALL_Model_memory,
         inputs=[vocab_size, layer_num, hidden_size, ffn_size, sequence_len, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func,
-                dp, tp, pp, cp, is_sp, vp, is_dist_opt, b, b_global, is_fp8, is_fp8_init, g_ty, o_ty, record_df, count],
+                dp, tp, pp, cp, is_sp, vp, is_dist_opt, b, b_global, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty, record_df, count],
         outputs=[output_text, record_df, count]
     )
 