xxyux committed
Commit 9791f0c · verified · 1 Parent(s): 4a329eb

Init commit

Files changed (1):
  1. app.py +507 -59

app.py CHANGED
@@ -1,63 +1,511 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
 
 
  if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
+ import pandas as pd
+
+ col = ['Layer number', 'Hidden size', 'FFN Hidden size', 'Sequence length', 'Head number', 'Group number',
+        'dp', 'tp', 'pp', 'cp', 'GPU numbers', 'Micro batch size', 'FP8', 'Model parameters (B)', 'Model states (GB)', 'Activation (GB)', 'Total (GB)']
+
+ # global record table
+ table_data = pd.DataFrame(columns=col)
+
+ def Get_GigaByte(memory):
+     return memory / 1024**3
+
+ def Get_BillionParameter(parameter):
+     return parameter / 1000**3
+
+ # model states:
+ def Compute_Parameters_input(hidden_size, vocab_size, tp):
+     num_parameters_word_embedding = hidden_size * vocab_size / tp
+     num_parameters_position_embedding = 0  # args.hidden_size * args.seq_length
+     return num_parameters_word_embedding + num_parameters_position_embedding
+
+ def Compute_Parameters_output(hidden_size, vocab_size, tp):
+     num_parameters_output_layernorm = 2 * hidden_size
+     num_parameters_output_embedding = 0  # 0 due to the shared word embedding
+     return num_parameters_output_layernorm + num_parameters_output_embedding
+
+ def Compute_Parameters_attention(hidden_size, kv_hidden_size, is_bias, tp):
+     # attention:
+     # layernorm: 2h
+     num_parameters_attention = 2 * hidden_size
+     # QKV weight: 3h*h/tp, bias: 3h/tp
+     # output linear weight: h*h/tp, bias: h
+     num_parameters_attention_Q_weight = hidden_size * hidden_size / tp
+     num_parameters_attention_KV_weight = 2 * kv_hidden_size * hidden_size / tp
+     num_parameters_attention_Linear_weight = hidden_size * hidden_size / tp
+
+     num_parameters_attention += num_parameters_attention_Q_weight + num_parameters_attention_KV_weight + num_parameters_attention_Linear_weight
+     if is_bias == "True":
+         num_parameters_attention += (hidden_size + 2 * kv_hidden_size) / tp + hidden_size
+
+     return num_parameters_attention
+
+ def Compute_Parameters_mlp(hidden_size, ffn_size, is_bias, act_func, tp):
+     # MLP:
+     # layernorm: 2h
+     num_parameters_mlp = 2 * hidden_size
+     # mlp1 weight: h*ffn/tp, bias: ffn/tp
+     # mlp2 weight: ffn*h/tp, bias: h
+     if act_func == "True":
+         num_parameters_mlp += hidden_size * ffn_size * 3 / tp
+         if is_bias == "True":
+             num_parameters_mlp += ffn_size * 2 / tp + hidden_size
+     else:
+         num_parameters_mlp += hidden_size * ffn_size * 2 / tp
+         if is_bias == "True":
+             num_parameters_mlp += ffn_size / tp + hidden_size
+
+     return num_parameters_mlp
+
+ def Compute_Parameters(vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, act_func, head_num, tp, pp):
+     if is_group_query == "False":
+         group_query_num = head_num
+     kv_hidden_size = hidden_size / head_num * group_query_num
+
+     # input part
+     num_parameters_input = Compute_Parameters_input(hidden_size, vocab_size, tp)
+
+     # middle layers part
+     num_parameters_attention = Compute_Parameters_attention(hidden_size, kv_hidden_size, is_bias, tp)
+     num_parameters_mlp = Compute_Parameters_mlp(hidden_size, ffn_size, is_bias, act_func, tp)
+     num_parameters_in_single_layer = num_parameters_attention + num_parameters_mlp
+     num_parameters_in_total_layers = num_parameters_in_single_layer * layer_num / pp
+
+     # output part
+     parameters_output = Compute_Parameters_output(hidden_size, vocab_size, tp)
+
+     if pp == 1:
+         num_parameters_total = (
+             num_parameters_input
+             + num_parameters_in_total_layers
+             + parameters_output  # num_parameters_output_layernorm
+         )
+     else:
+         num_parameters_total = (
+             num_parameters_input
+             + num_parameters_in_total_layers
+         )
+
+     return num_parameters_total
+
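As a sanity check of the counting above, plugging in the Llama-7B-style defaults the UI uses below (GLU on, no bias, no GQA) with tp = pp = 1 reproduces the familiar ~6.6 B parameter count. A minimal sketch under those assumptions, with the helpers above in scope:

n = Compute_Parameters(
    vocab_size=32000, layer_num=32, hidden_size=4096, ffn_size=11008,
    is_group_query="False", group_query_num=32, is_bias="False",
    act_func="True", head_num=32, tp=1, pp=1,
)
# HV + (4H + 2H^2 + 2H*H_KV + 3*FFN*H) * L + 2H = 6,607,609,856
print(Get_BillionParameter(n))  # ≈ 6.608 (billion parameters)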
+ def Compute_Weight(numParametersTotal, is_fp8, is_fp8_init):
+     if is_fp8 == "False":
+         weight_memory = 2 * numParametersTotal
+     elif is_fp8_init == "False":
+         weight_memory = 4 * numParametersTotal
+     else:
+         weight_memory = 2 * numParametersTotal
+
+     return weight_memory
+
+ def Compute_Gradient(numParametersTotal, g_ty):
+     if g_ty == "FP32":
+         gradient_memory = 4 * numParametersTotal
+     elif g_ty == "BF16":
+         gradient_memory = 2 * numParametersTotal
+
+     return gradient_memory
+
+ def Compute_Optimizer_states(numParametersTotal, o_ty, is_dist_opt, dp, cp):
+     if o_ty == "FP32":
+         optimizer_memory = 4 * 2 * numParametersTotal
+     elif o_ty == "BF16":
+         optimizer_memory = 2 * 2 * numParametersTotal
+
+     if is_dist_opt == "True":
+         optimizer_memory = optimizer_memory / (dp * cp)
+
+     return optimizer_memory
+
+ def Compute_Master_weight(numParametersTotal, is_dist_opt, dp, cp):
+     master_weight_memory = 4 * numParametersTotal
+     if is_dist_opt == "True":
+         master_weight_memory = master_weight_memory / (dp * cp)
+
+     return master_weight_memory
+
+ def Compute_Model_states(vocab_size, layer_num, hidden_size, ffn_size, head_num, is_group_query, group_query_num, is_bias, act_func,
+                          dp, tp, pp, cp, is_dist_opt, is_fp8, is_fp8_init, g_ty, o_ty):
+     numParametersTotal = Compute_Parameters(vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, act_func, head_num, tp, pp)
+
+     weight_memory = Compute_Weight(numParametersTotal, is_fp8, is_fp8_init)
+     gradient_memory = Compute_Gradient(numParametersTotal, g_ty)
+     optimizer_memory = Compute_Optimizer_states(numParametersTotal, o_ty, is_dist_opt, dp, cp)
+     master_weight_memory = Compute_Master_weight(numParametersTotal, is_dist_opt, dp, cp)
+
+     return numParametersTotal, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, \
+         weight_memory + gradient_memory + optimizer_memory + master_weight_memory
+
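These four components reproduce the 18-bytes-per-parameter rule of thumb for BF16 training with FP32 gradients and a two-state FP32 optimizer (2 weight + 4 gradient + 8 optimizer + 4 master weight). A minimal check under the same assumed defaults as above, with the distributed optimizer off:

n, w, g, o, m, total = Compute_Model_states(
    32000, 32, 4096, 11008, 32, "False", 32, "False", "True",
    dp=1, tp=1, pp=1, cp=1, is_dist_opt="False",
    is_fp8="False", is_fp8_init="False", g_ty="FP32", o_ty="FP32",
)
print(total / n)            # 18.0 bytes per parameter
print(Get_GigaByte(total))  # ≈ 110.8 GB for the ~6.6 B-parameter model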
+ # activation memory:
+ def compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp):
+     # LN 2bsh
+     activation_mem_attn_ln = seq_length * b * hidden_size * 2
+     if is_sp == "False":
+         activation_mem_attn_ln *= tp
+     # attention input X, qkv 2bsh/1bsh
+     activation_mem_attn_qkv = seq_length * b * hidden_size * activation_dtype
+     if is_sp == "False":
+         activation_mem_attn_qkv *= tp
+     # attention q 2bsh
+     activation_mem_attn_q = seq_length * b * hidden_size * 2
+     # attention k and v 4bsh
+     activation_mem_attn_kv = seq_length * b * kv_hidden_size * 2 * 2
+     # attention proj input 2bsh/1bsh
+     activation_mem_attn_proj = seq_length * b * hidden_size * activation_dtype
+     # dropout bsh
+     activation_mem_attn_dropout = seq_length * b * hidden_size
+     if is_sp == "False":
+         activation_mem_attn_dropout *= tp
+     # bf16: 2+2+2+4+2+1=13bsh
+     # fp8: 2+1+2+4+1+1=11bsh
+     activation_memory_attn = (
+         activation_mem_attn_ln
+         + activation_mem_attn_qkv
+         + activation_mem_attn_q
+         + activation_mem_attn_kv
+         + activation_mem_attn_proj
+         + activation_mem_attn_dropout
+     )
+     return activation_memory_attn
+
+ def compute_activation_memory_mlp(activation_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp):
+     # LN 2bsh
+     activation_mem_mlp_ln = seq_length * b * hidden_size * 2
+     if is_sp == "False":
+         activation_mem_mlp_ln *= tp
+     # FC1 2bsh/1bsh
+     activation_mem_mlp_fc1 = seq_length * b * hidden_size * activation_dtype
+     if is_sp == "False":
+         activation_mem_mlp_fc1 *= tp
+     # Act: 8bsh with GLU, 4bsh without
+     if act_func == "True":
+         activation_mem_mlp_act = seq_length * b * ffn_size * 2 * 2
+     else:
+         activation_mem_mlp_act = seq_length * b * ffn_size * 2
+     # FC2 8bsh/4bsh
+     activation_mem_mlp_fc2 = seq_length * b * ffn_size * activation_dtype
+     # dropout bsh
+     activation_mem_mlp_dropout = seq_length * b * hidden_size
+     if is_sp == "False":
+         activation_mem_mlp_dropout *= tp
+     # bf16: 2+2+8+8+1=21
+     # fp8: 2+1+8+4+1=16
+     activation_memory_mlp = (
+         activation_mem_mlp_ln
+         + activation_mem_mlp_fc1
+         + activation_mem_mlp_act
+         + activation_mem_mlp_fc2
+         + activation_mem_mlp_dropout
+     )
+     return activation_memory_mlp
+
+ def compute_activation_memory_input(seq_length, b, hidden_size, pp):
+     # embedding + Dropout
+     return 8 * seq_length * b * pp + seq_length * b * hidden_size * pp
+
+ def compute_activation_memory_output(seq_length, b, hidden_size, vocab_size):
+     # Inputs to output layer and CE loss (bf16, fp32 * 2).
+     return 2 * seq_length * b * hidden_size + (2 + 4 + 4) * seq_length * b * vocab_size
+
+ def compute_activation_memory_pp(activation_memory, is_ip, vp, pp, num_microbatches):
+     # Multiply by the interleaved-PP memory factor.
+     if is_ip == "True":
+         interleaved_schedule_memory_penalty = 1 + (pp - 1) / (pp * vp)
+         activation_memory *= interleaved_schedule_memory_penalty
+
+     # With a non-interleaved schedule, the number of microbatches in flight
+     # can be less than pp_size, so discount accordingly.
+     if is_ip == "False" and pp > 1:
+         if num_microbatches > 1:
+             activation_memory *= min(1, num_microbatches / pp)
+
+     return activation_memory
+
+ def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, is_ip, vp):
+     # Using the formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
+     # We are trying to compute the maximum activation footprint, so all calculations in this function
+     # are for the first pipeline stage.
+
+     # activation dataType
+     if is_fp8 == "False":
+         activation_dtype = 2
+     else:
+         activation_dtype = 1
+
+     # kv_hidden_size
+     if is_group_query == "False":
+         group_query_num = head_num
+     kv_hidden_size = hidden_size / head_num * group_query_num
+
+     activation_memory_attn = compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp)
+
+     activation_memory_mlp = compute_activation_memory_mlp(activation_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp)
+
+     activation_memory = activation_memory_attn + activation_memory_mlp
+
+     activation_memory *= layer_num
+
+     # Now add activation memory required for input embeddings, last LayerNorm and output layer.
+     # Input to embedding (pp_size microbatches in flight).
+     activation_memory_input = compute_activation_memory_input(seq_length, b, hidden_size, pp)
+     activation_memory += activation_memory_input
+
+     # get num_microbatches
+     num_microbatches = b_global / b / dp / cp
+     activation_memory = compute_activation_memory_pp(activation_memory, is_ip, vp, pp, num_microbatches)
+
+     if pp == 1:
+         # Inputs to output layer and CE loss (fp32).
+         activation_memory_output = compute_activation_memory_output(seq_length, b, hidden_size, vocab_size)
+         activation_memory += activation_memory_output
+     elif pp > 1:
+         # Send/recv buffer memory
+         activation_memory += seq_length * b * hidden_size * 2
+
+     # Activation memory is partitioned by TP size due to tensor and sequence model parallelism.
+     return activation_memory / tp / cp
+
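The per-term comments above can be cross-checked against the per-layer totals: in BF16 with sequence parallelism and H_KV = H, attention activations come to 13·S·B·H bytes and the GLU MLP to 5·S·B·H + 6·S·B·FFN bytes. A small sketch under those assumptions (values are arbitrary illustrative sizes):

S, B, H, FFN = 1024, 4, 4096, 11008
attn = compute_activation_memory_attention(2, S, B, H, H, "True", 1)
mlp = compute_activation_memory_mlp(2, S, B, H, FFN, "True", "True", 1)
print(attn == 13 * S * B * H)                   # True
print(mlp == 5 * S * B * H + 6 * S * B * FFN)   # True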
+ # compute_btn.click handler
+ def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_length, head_num, is_group_query, group_query_num, is_bias, act_func,
+                              dp, tp, pp, cp, is_sp, is_ip, vp, is_dist_opt, b, b_global, is_fp8, is_fp8_init, g_ty, o_ty):
+     # get model states
+     numParameters, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, model_states_memory = Compute_Model_states(vocab_size, layer_num, hidden_size,
+         ffn_size, head_num, is_group_query, group_query_num, is_bias, act_func, dp, tp, pp, cp, is_dist_opt, is_fp8, is_fp8_init, g_ty, o_ty)
+
+     # get activation memory
+     activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, is_ip, vp)
+
+     # get total model parameters (without parallelism, i.e. tp = pp = 1)
+     numParametersTotal = Compute_Parameters(vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, act_func, head_num, 1, 1)
+     # get GPU number
+     gpu_num = dp * tp * pp * cp
+
+     # convert to B / GB
+     numParametersTotal = round(Get_BillionParameter(numParametersTotal), 3)
+     numParameters = round(Get_BillionParameter(numParameters), 3)
+     model_states_memory = round(Get_GigaByte(model_states_memory), 3)
+     activation_memory = round(Get_GigaByte(activation_memory), 3)
+     Total = round(model_states_memory + activation_memory, 3)
+
+     # record
+     global table_data
+     new_row = pd.DataFrame([[layer_num, hidden_size, ffn_size, seq_length, head_num, group_query_num, dp, tp, pp, cp, gpu_num, b, is_fp8,
+                              numParametersTotal, model_states_memory, activation_memory, Total]],
+                            columns=col)
+     table_data = pd.concat([table_data, new_row], ignore_index=True)
+
+     return f"""
+     GPU numbers = {gpu_num}, \n
+     Total model parameters = {numParametersTotal} B, \n
+     Model parameters = {numParameters} B, \n
+     Model states = {model_states_memory} GB, \n
+     Activation = {activation_memory} GB, \n
+     Total memory consumption = {Total} GB \n
+     """, table_data
+
+ def generate_csv():
+     # the accumulated record table
+     df = table_data
+
+     # save the DataFrame as a CSV file
+     csv_filename = "data.csv"
+     df.to_csv(csv_filename, index=False)
+
+     # return the CSV file path
+     return csv_filename
+
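The same export can be driven without the button; a sketch assuming at least one Compute call has already populated table_data:

path = generate_csv()
print(pd.read_csv(path).tail(1))  # the most recently recorded configuration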
+ # formula string
+ formula = r"""
+ > **Note**🔑: In these formulas, we assume LLM training with FP32 gradients and optimizer states, and bias = False, Zero1 = False, SP = True.
+
+ <!-- parameters: -->
+ $$
+ P_{input} = \frac{HV}{tp}, \quad
+ P_{output} = 2H \\\\
+ P_{attn} = 2H + \frac{2H^2 + 2H_{KV} \times H}{tp}, \quad
+ P_{MLP} = 2H +
+ \\begin{cases}
+ \frac{3H \times FFN}{tp}, & \text{if }GLU\text{ is True} \\\\
+ \frac{2H \times FFN}{tp}, & \text{if }GLU\text{ is False}
+ \\end{cases} \\\\
+ P_{middle} = \frac{(P_{attn} + P_{MLP}) \times L}{pp} \\\\
+ P = P_{input} + P_{middle} +
+ \\begin{cases}
+ P_{output}, & \text{if }pp = 1 \\\\
+ 0, & \text{if }pp > 1
+ \\end{cases} \\\\
+ {Total\ Model\ parameters} =
+ \\begin{cases}
+ P, & \text{set tp = 1, pp = 1} \\\\
+ HV + 2H + (4H + 2H^2 + 2H_{KV} \times H + 3FFN \times H) \times L, & \text{general formula (GLU on)}
+ \\end{cases} \\\\
+ {Model\ states} = {Model\ weight} + {Gradient} + {Optimizer\ state} + {Master\ weight} =
+ \\begin{cases}
+ 18P, & \text{BF16 training} \\\\
+ 18P, & \text{FP8 training with FP8 Init} \\\\
+ 20P, & \text{FP8 training w/o FP8 Init}
+ \\end{cases} \\\\
+ $$
+
+ ***
+
+ <!-- activations: -->
+ $$
+ A_{input} = (8SB + SBH) \times pp, \quad
+ A_{output} = 2SBH +
+ \\begin{cases}
+ 10SBV, & \text{if }pp\text{ = 1} \\\\
+ 0, & \text{if }pp\text{ > 1}
+ \\end{cases} \\\\
+ A_{attn} = 5SBH + 4SB \times H_{KV} +
+ \\begin{cases}
+ 2SBH, & \text{if } FP8 \text{ is True} \\\\
+ 4SBH, & \text{if } FP8 \text{ is False}
+ \\end{cases} \\\\
+ A_{MLP} = 3SBH +
+ \\begin{cases}
+ SBH + SB \times FFN + 4SB \times FFN, & \text{if }FP8 \text{ is True and }GLU \text{ is True} \\\\
+ 2SBH + 2SB \times FFN + 4SB \times FFN, & \text{if }FP8 \text{ is False and }GLU \text{ is True} \\\\
+ SBH + SB \times FFN + 2SB \times FFN, & \text{if }FP8 \text{ is True and }GLU \text{ is False} \\\\
+ 2SBH + 2SB \times FFN + 2SB \times FFN, & \text{if }FP8 \text{ is False and }GLU \text{ is False}
+ \\end{cases} \\\\
+ A_{middle} = (A_{attn} + A_{MLP}) \times L \\\\
+ A_{ip} = (A_{input} + A_{middle}) \times
+ \\begin{cases}
+ (1 + \frac{pp - 1}{pp \times vp}), & \text{if } Interleaved\ Pipeline \text{ is True} \\\\
+ min(1, \frac{microbatch}{pp}), & \text{if } Interleaved\ Pipeline \text{ is False and pp > 1} \\\\
+ 1, & \text{otherwise}
+ \\end{cases} \\\\
+ Activation =
+ \\begin{cases}
+ \frac{A_{ip} + A_{output}}{tp \times cp}, & \text{if pp = 1} \\\\
+ \frac{A_{ip} + 2SBH}{tp \times cp}, & \text{if pp > 1}
+ \\end{cases}
+ $$
+
+ ***
+
+ $$
+ \\begin{gather}
+ {GPU\ numbers} = tp \times pp \times dp \times cp \\\\
+ {Total\ memory\ consumption} = {Model\ states} + Activation
+ \\end{gather}
+ $$
+ """
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         # Text
+         gr.Markdown(
+             """
+             <div style="text-align: center;">
+             <h1>GPU memory calculator 🌀</h1>
+             <p style="font-size:16px;">Here's a GPU memory calculator; it helps you compute the memory consumption of LLM training. </p>
+             </div>
+             """
+         )
+
+     with gr.Column():
+         # Input 1. [Model Parameters]
+         gr.Markdown(
+             """
+             <h1>Model Parameters:</h1>
+             """
+         )
+         with gr.Accordion("Model Parameters"):
+             act_func = gr.Radio(["True", "False"], value="True", label="Model type", info="Activation function in the MLP: whether to use GLU (Gated Linear Unit). [e.g. \"True\" for LLaMA, \"False\" for GPT.]")
+             vocab_size = gr.Number(label="Vocab size", value=32000)
+             layer_num = gr.Number(label="Layer number", value=32)
+             hidden_size = gr.Number(label="Hidden size", value=4096)
+             ffn_size = gr.Number(label="FFN Hidden size", value=11008)
+             sequence_len = gr.Number(label="Sequence length", value=1024)
+             head_num = gr.Number(label="Number of Attention Heads", value=32)
+             with gr.Row():
+                 is_group_query = gr.Radio(["True", "False"], value="True", label="Use Group Query Attention")
+                 group_query_num = gr.Number(label="Number of Query Groups", value=96)
+             is_bias = gr.Radio(["True", "False"], value="False", label="Use Bias")
+
+         # Input 2. [Parallelism]
+         gr.Markdown(
+             """
+             <h1>Parallelism config:</h1>
+             """
+         )
+         with gr.Accordion("Parallelism config"):
+             dp = gr.Number(label="Data parallelism", value=1)
+             tp = gr.Number(label="Tensor parallelism", value=2)
+             pp = gr.Number(label="Pipeline parallelism", value=2)
+             cp = gr.Number(label="Context parallelism", value=2)
+             is_sp = gr.Radio(["True", "False"], value="True", label="Sequence parallelism")
+             with gr.Row():
+                 is_ip = gr.Radio(["True", "False"], value="False", label="Use Interleaved Pipeline")
+                 vp = gr.Number(label="Virtual Pipeline Size")
+             is_dist_opt = gr.Radio(["True", "False"], value="True", label="Use Distributed Optimizer (Zero1)")
+
+         # Input 3. [Training Settings]
+         gr.Markdown(
+             """
+             <h1>Training Config:</h1>
+             """
+         )
+         with gr.Accordion("Training Config"):
+             b = gr.Number(label="Micro Batch size", value=4)
+             b_global = gr.Number(label="Global Batch size", value=64)
+             gr.Checkbox(label="True", value=True, info="BF16 Training")
+             is_fp8 = gr.Radio(["True", "False"], value="True", label="FP8 Training")
+             is_fp8_init = gr.Radio(["True", "False"], value="True", label="FP8 Initialization (will reduce memory)")
+             g_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Gradients Dtype")
+             o_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Optimizer State Dtype")
+
+     with gr.Column():
+         gr.Markdown(
+             """
+             <h1>Output Data:</h1>
+             """
+         )
+
+         gr.Markdown(
+             formula,
+             latex_delimiters=[{"left": "$$", "right": "$$", "display": True}],
+         )
+
+         output_text = gr.Textbox(
+             label="Compute result",
+             interactive=False,
+         )
+
+         # Button
+         with gr.Row():
+             compute_btn = gr.Button("Compute")
+             download_btn = gr.Button("Download")
+
+         record_df = gr.Dataframe(
+             label="Record Table",
+             headers=col
+         )
+
+         compute_btn.click(
+             fn=Compute_ALL_Model_memory,
+             inputs=[vocab_size, layer_num, hidden_size, ffn_size, sequence_len, head_num, is_group_query, group_query_num, is_bias, act_func,
+                     dp, tp, pp, cp, is_sp, is_ip, vp, is_dist_opt, b, b_global, is_fp8, is_fp8_init, g_ty, o_ty],
+             outputs=[output_text, record_df]
+         )
+
+         output_file = gr.File(label="When you click the Download button, the exported table will appear here.")
+         # download func
+         download_btn.click(
+             fn=generate_csv,
+             outputs=output_file
+         )
 
 
  if __name__ == "__main__":
+     demo.launch()
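For reference, the click handler can also be exercised headlessly; a sketch passing the same default values the UI wires into the Compute button (vp=1 is an arbitrary placeholder, since the UI leaves Virtual Pipeline Size blank):

text, df = Compute_ALL_Model_memory(
    32000, 32, 4096, 11008, 1024, 32, "True", 96, "False", "True",
    dp=1, tp=2, pp=2, cp=2, is_sp="True", is_ip="False", vp=1,
    is_dist_opt="True", b=4, b_global=64, is_fp8="True", is_fp8_init="True",
    g_ty="FP32", o_ty="FP32",
)
print(text)  # GPU numbers = 1*2*2*2 = 8, plus the parameter and memory breakdown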