Update app.py
app.py CHANGED
@@ -177,21 +177,21 @@ def Compute_Model_states(seq_length, vocab_size, layer_num, hidden_size, ffn_siz
         weight_memory + gradient_memory + optimizer_memory + master_weight_memory
 
 # activation memory:
-def compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp):
+def compute_activation_memory_attention(training_dtype, gemm_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp):
     # LN 2bsh
-    activation_mem_attn_ln = seq_length * b * hidden_size * activation_dtype
+    activation_mem_attn_ln = seq_length * b * hidden_size * training_dtype
     if is_sp == "False":
         activation_mem_attn_ln *= tp
     # attention input X, qkv 2bsh/1bsh
-    activation_mem_attn_qkv = seq_length * b * hidden_size * activation_dtype
+    activation_mem_attn_qkv = seq_length * b * hidden_size * gemm_dtype
     if is_sp == "False":
         activation_mem_attn_qkv *= tp
     # attention q 2bsh
-    activation_mem_attn_q = seq_length * b * hidden_size * activation_dtype
+    activation_mem_attn_q = seq_length * b * hidden_size * training_dtype
     # attention k and v 4bsh
-    activation_mem_attn_kv = seq_length * b * kv_hidden_size * activation_dtype * 2
+    activation_mem_attn_kv = seq_length * b * kv_hidden_size * training_dtype * 2
     # attention proj input 2bsh/1bsh
-    activation_mem_attn_proj = seq_length * b * hidden_size * activation_dtype
+    activation_mem_attn_proj = seq_length * b * hidden_size * gemm_dtype
     # dropout bsh
     activation_mem_attn_dropout = seq_length * b * hidden_size
     if is_sp == "False":
@@ -208,22 +208,22 @@ def compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_
     )
     return activation_memory_attn
 
-def compute_activation_memory_mlp(activation_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp):
+def compute_activation_memory_mlp(training_dtype, gemm_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp):
     # LN 2bsh
-    activation_mem_mlp_ln = seq_length * b * hidden_size * activation_dtype
+    activation_mem_mlp_ln = seq_length * b * hidden_size * training_dtype
     if is_sp == "False":
         activation_mem_mlp_ln *= tp
     # FC1 2bsh/1bsh
-    activation_mem_mlp_fc1 = seq_length * b * hidden_size * activation_dtype
+    activation_mem_mlp_fc1 = seq_length * b * hidden_size * gemm_dtype
     if is_sp == "False":
         activation_mem_mlp_fc1 *= tp
     # Act 8bsh
     if act_func == "LLaMA":
-        activation_mem_mlp_act = seq_length * b * ffn_size * activation_dtype * 2
+        activation_mem_mlp_act = seq_length * b * ffn_size * training_dtype * 2
     else:
-        activation_mem_mlp_act = seq_length * b * ffn_size * activation_dtype
+        activation_mem_mlp_act = seq_length * b * ffn_size * training_dtype
     # FC2 8bsh/4bsh
-    activation_mem_mlp_fc2 = seq_length * b * ffn_size * activation_dtype
+    activation_mem_mlp_fc2 = seq_length * b * ffn_size * gemm_dtype
     # dropout bsh
     activation_mem_mlp_dropout = seq_length * b * hidden_size
     if is_sp == "False":
@@ -261,25 +261,33 @@ def compute_activation_memory_pp(activation_memory, vp, pp, num_microbatches):
 
     return activation_memory
 
-def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp):
+def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, precision, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp):
     # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
     # We are trying to compute the maximum activation footprint, so all calculations in this function
     # are for the first pipeline stage.
 
-    # activation dataType
-    if is_fp8 == "False":
-        activation_dtype = 2
-    else:
-        activation_dtype = 1
+    # activation dataType for Training
+    if precision == "FP32":
+        training_dtype = 4
+    else:
+        training_dtype = 2
+
+    # activation dataType for GEMM
+    if precision == "FP32":
+        gemm_dtype = 4
+    elif is_fp8 == "False":
+        gemm_dtype = 2
+    else:
+        gemm_dtype = 1
 
     # kv_hidden_size
     if is_group_query == "False":
         group_query_num = head_num
     kv_hidden_size = hidden_size / head_num * group_query_num
 
-    activation_memory_attn = compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp)
+    activation_memory_attn = compute_activation_memory_attention(training_dtype, gemm_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp)
 
-    activation_memory_mlp = compute_activation_memory_mlp(activation_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp)
+    activation_memory_mlp = compute_activation_memory_mlp(training_dtype, gemm_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp)
 
     activation_memory = activation_memory_attn + activation_memory_mlp
 
@@ -324,7 +332,7 @@ def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_l
         ffn_size, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, dp, tp, pp, cp, is_dist_opt, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty)
 
     # get activation memory
-    activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp)
+    activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, precision, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp)
 
     # get model parameters
     numParametersTotal = Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, head_num, 1, 1)
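To sanity-check the new accounting outside the Space, here is a minimal standalone sketch (not part of the commit) of the two byte-width selections this change introduces and of the attention terms visible in the first hunk. Function names, the plain-bool is_sp/is_fp8 arguments, and the example figures are assumptions for illustration; the attention-score/softmax terms that fall between the first two hunks, and whatever follows the trailing "if is_sp == "False":", are not shown in the diff and are omitted here.

# Hypothetical sketch mirroring the new dtype handling and the visible attention terms.
def select_activation_dtypes(precision: str, is_fp8: bool):
    """Bytes per element for training-precision activations vs. GEMM inputs."""
    training_dtype = 4 if precision == "FP32" else 2   # FP32 -> 4 B, BF16/FP16 -> 2 B
    if precision == "FP32":
        gemm_dtype = 4
    elif not is_fp8:
        gemm_dtype = 2
    else:
        gemm_dtype = 1                                 # FP8 GEMM inputs -> 1 B
    return training_dtype, gemm_dtype

def attention_activation_bytes(seq_length, b, hidden_size, kv_hidden_size,
                               training_dtype, gemm_dtype, is_sp, tp):
    ln = seq_length * b * hidden_size * training_dtype          # LN output
    qkv = seq_length * b * hidden_size * gemm_dtype             # QKV GEMM input
    if not is_sp:                        # the diff scales these two terms by tp
        ln *= tp                         # when sequence parallelism is off
        qkv *= tp
    q = seq_length * b * hidden_size * training_dtype           # Q
    kv = seq_length * b * kv_hidden_size * training_dtype * 2   # K and V
    proj = seq_length * b * hidden_size * gemm_dtype             # proj GEMM input
    dropout = seq_length * b * hidden_size                       # dropout mask, 1 B/elem
    return ln + qkv + q + kv + proj + dropout

# Example (assumed figures): 4k sequence, micro-batch 1, hidden 4096, no GQA, tp=1, SP on.
t_b, g_b = select_activation_dtypes("BF16", is_fp8=True)
print(attention_activation_bytes(4096, 1, 4096, 4096, t_b, g_b, is_sp=True, tp=1) / 2**20, "MiB")

Splitting the old single activation_dtype into training_dtype and gemm_dtype is what lets FP8 shrink only the cached GEMM inputs (QKV and proj above), while LN outputs, Q, K/V and the dropout mask stay at the 2-byte (or 4-byte FP32) training width.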
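The MLP side follows the same pattern; below is a matching sketch of the terms visible in the second hunk, again with a hypothetical plain-bool is_sp argument (the Space itself passes the strings "True"/"False"). As with attention, the tail of the function after the final "if is_sp == "False":" lies outside the hunk and is not reproduced.

# Hypothetical sketch of the MLP activation terms from the second hunk.
def mlp_activation_bytes(seq_length, b, hidden_size, ffn_size, act_func,
                         training_dtype, gemm_dtype, is_sp, tp):
    ln = seq_length * b * hidden_size * training_dtype   # LN output
    fc1 = seq_length * b * hidden_size * gemm_dtype      # FC1 GEMM input
    if not is_sp:                        # the diff scales these two terms by tp
        ln *= tp                         # when sequence parallelism is off
        fc1 *= tp
    if act_func == "LLaMA":              # gated LLaMA-style MLP caches two ffn_size-wide tensors
        act = seq_length * b * ffn_size * training_dtype * 2
    else:
        act = seq_length * b * ffn_size * training_dtype
    fc2 = seq_length * b * ffn_size * gemm_dtype         # FC2 GEMM input
    dropout = seq_length * b * hidden_size               # dropout mask, 1 B/elem
    return ln + fc1 + act + fc2 + dropout

# Example (assumed figures): same layer as above with an 11008-wide LLaMA MLP,
# training_dtype=2 (BF16), gemm_dtype=1 (FP8).
print(mlp_activation_bytes(4096, 1, 4096, 11008, "LLaMA", 2, 1, is_sp=True, tp=1) / 2**20, "MiB")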
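The third hunk also shows how grouped-query attention shrinks the K/V term before the attention call: kv_hidden_size = hidden_size / head_num * group_query_num. A worked example with assumed figures:

# GQA example (assumed figures: hidden 4096, 32 query heads, 8 KV groups).
hidden_size, head_num, group_query_num = 4096, 32, 8
kv_hidden_size = hidden_size / head_num * group_query_num   # 4096 / 32 * 8 = 1024.0
print(kv_hidden_size)  # the K/V activation term above shrinks by 4x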