xxyux committed · Commit 5c53556 · verified · 1 Parent(s): affd6ff

Update app.py

Files changed (1): app.py +29 -21
app.py CHANGED
@@ -177,21 +177,21 @@ def Compute_Model_states(seq_length, vocab_size, layer_num, hidden_size, ffn_siz
         weight_memory + gradient_memory + optimizer_memory + master_weight_memory
 
 # activation memory:
-def compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp):
+def compute_activation_memory_attention(training_dtype, gemm_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp):
     # LN 2bsh
-    activation_mem_attn_ln = seq_length * b * hidden_size * 2
+    activation_mem_attn_ln = seq_length * b * hidden_size * training_dtype
     if is_sp == "False":
         activation_mem_attn_ln *= tp
     # attention input X, qkv 2bsh/1bsh
-    activation_mem_attn_qkv = seq_length * b * hidden_size * activation_dtype
+    activation_mem_attn_qkv = seq_length * b * hidden_size * gemm_dtype
     if is_sp == "False":
         activation_mem_attn_qkv *= tp
     # attention q 2bsh
-    activation_mem_attn_q = seq_length * b * hidden_size * 2
+    activation_mem_attn_q = seq_length * b * hidden_size * training_dtype
     # attention k and v 4bsh
-    activation_mem_attn_kv = seq_length * b * kv_hidden_size * 2 * 2
+    activation_mem_attn_kv = seq_length * b * kv_hidden_size * training_dtype * 2
     # attention proj input 2bsh/1bsh
-    activation_mem_attn_proj = seq_length * b * hidden_size * activation_dtype
+    activation_mem_attn_proj = seq_length * b * hidden_size * gemm_dtype
     # dropout bsh
     activation_mem_attn_dropout = seq_length * b * hidden_size
     if is_sp == "False":
@@ -208,22 +208,22 @@ def compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_
     )
     return activation_memory_attn
 
-def compute_activation_memory_mlp(activation_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp):
+def compute_activation_memory_mlp(training_dtype, gemm_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp):
     # LN 2bsh
-    activation_mem_mlp_ln = seq_length * b * hidden_size * 2
+    activation_mem_mlp_ln = seq_length * b * hidden_size * training_dtype
     if is_sp == "False":
         activation_mem_mlp_ln *= tp
     # FC1 2bsh/1bsh
-    activation_mem_mlp_fc1 = seq_length * b * hidden_size * activation_dtype
+    activation_mem_mlp_fc1 = seq_length * b * hidden_size * gemm_dtype
     if is_sp == "False":
         activation_mem_mlp_fc1 *= tp
     # Act 8bsh
     if act_func == "LLaMA":
-        activation_mem_mlp_act = seq_length * b * ffn_size * 2 * 2
+        activation_mem_mlp_act = seq_length * b * ffn_size * training_dtype * 2
     else:
-        activation_mem_mlp_act = seq_length * b * ffn_size * 2
+        activation_mem_mlp_act = seq_length * b * ffn_size * training_dtype
     # FC2 8bsh/4bsh
-    activation_mem_mlp_fc2 = seq_length * b * ffn_size * activation_dtype
+    activation_mem_mlp_fc2 = seq_length * b * ffn_size * gemm_dtype
     # dropout bsh
     activation_mem_mlp_dropout = seq_length * b * hidden_size
     if is_sp == "False":
@@ -261,25 +261,33 @@ def compute_activation_memory_pp(activation_memory, vp, pp, num_microbatches):
 
     return activation_memory
 
-def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp):
+def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, precision, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp):
     # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
     # We are trying to compute the maximum activation footprint, so all calculations in this function
     # are for the first pipeline stage.
 
-    # activation dataType
-    if is_fp8 == "False":
-        activation_dtype = 2
-    else:
-        activation_dtype = 1
+    # activation dataType for Training
+    if precision == "FP32":
+        training_dtype = 4
+    else:
+        training_dtype = 2
+
+    # activation dataType for GEMM
+    if precision == "FP32":
+        gemm_dtype = 4
+    elif is_fp8 == "False":
+        gemm_dtype = 2
+    else:
+        gemm_dtype = 1
 
     # kv_hidden_size
     if is_group_query == "False":
         group_query_num = head_num
     kv_hidden_size = hidden_size / head_num * group_query_num
 
-    activation_memory_attn = compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp)
+    activation_memory_attn = compute_activation_memory_attention(training_dtype, gemm_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp)
 
-    activation_memory_mlp = compute_activation_memory_mlp(activation_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp)
+    activation_memory_mlp = compute_activation_memory_mlp(training_dtype, gemm_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp)
 
     activation_memory = activation_memory_attn + activation_memory_mlp
 
@@ -324,7 +332,7 @@ def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_l
         ffn_size, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, dp, tp, pp, cp, is_dist_opt, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty)
 
     # get activation memory
-    activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp)
+    activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, precision, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp)
 
     # get model parameters
     numParametersTotal = Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, head_num, 1, 1)
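
To gauge what this change does to the estimate, here is a minimal standalone sketch (not part of the commit) that reproduces the patched dtype selection and only the per-tensor activation terms visible in the hunks above. The helper names (select_dtypes, visible_attention_terms, visible_mlp_terms) and the concrete shapes are illustrative assumptions; sequence-parallel/tensor-parallel scaling and the terms outside the shown hunks (e.g. the softmax buffers inside the attention core) are omitted.

# Illustrative sketch (not from app.py): mirrors the patched dtype selection and
# the per-layer activation terms visible in this diff, for assumed shapes.

def select_dtypes(precision, is_fp8):
    # New logic: non-GEMM activations follow the training precision,
    # while FP8 only shrinks the GEMM input copies to 1 byte.
    training_dtype = 4 if precision == "FP32" else 2
    if precision == "FP32":
        gemm_dtype = 4
    elif is_fp8 == "False":
        gemm_dtype = 2
    else:
        gemm_dtype = 1
    return training_dtype, gemm_dtype

def visible_attention_terms(training_dtype, gemm_dtype, s, b, h, kv_h):
    # Per-tensor byte counts as written in the new compute_activation_memory_attention
    # (the is_sp/tp scaling branches are left out for brevity).
    return {
        "ln": s * b * h * training_dtype,
        "qkv_input": s * b * h * gemm_dtype,
        "q": s * b * h * training_dtype,
        "kv": s * b * kv_h * training_dtype * 2,
        "proj_input": s * b * h * gemm_dtype,
        "dropout_mask": s * b * h,
    }

def visible_mlp_terms(training_dtype, gemm_dtype, s, b, h, ffn, act_func):
    # Per-tensor byte counts as written in the new compute_activation_memory_mlp;
    # the "LLaMA" branch doubles the activation buffer for the gated unit.
    act = s * b * ffn * training_dtype * (2 if act_func == "LLaMA" else 1)
    return {
        "ln": s * b * h * training_dtype,
        "fc1_input": s * b * h * gemm_dtype,
        "act": act,
        "fc2_input": s * b * ffn * gemm_dtype,
        "dropout_mask": s * b * h,
    }

if __name__ == "__main__":
    # Assumed shapes (not from the commit): seq 4096, micro-batch 1,
    # hidden 4096, FFN 11008, no GQA, LLaMA-style gated MLP.
    s, b, h, ffn = 4096, 1, 4096, 11008
    for precision, is_fp8 in (("FP32", "False"), ("BF16", "False"), ("BF16", "True")):
        t, g = select_dtypes(precision, is_fp8)
        total = sum(visible_attention_terms(t, g, s, b, h, h).values())
        total += sum(visible_mlp_terms(t, g, s, b, h, ffn, "LLaMA").values())
        print(f"precision={precision}, fp8={is_fp8}: "
              f"{total / 2**20:.0f} MiB per layer (visible terms only)")

Under these assumptions, the effect of splitting the old single activation_dtype into training_dtype and gemm_dtype is that FP32 runs now count 4-byte activations everywhere, while FP8 still only reduces the GEMM input copies (qkv, proj, FC1, FC2 inputs) to 1 byte; this matches the extra precision argument threaded through compute_activation_memory and Compute_ALL_Model_memory in the hunks above.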