	Update app.py
    	
app.py CHANGED
    
@@ -177,21 +177,21 @@ def Compute_Model_states(seq_length, vocab_size, layer_num, hidden_size, ffn_siz
         weight_memory + gradient_memory + optimizer_memory + master_weight_memory
 
 # activation memory:
-def compute_activation_memory_attention(
+def compute_activation_memory_attention(training_dtype, gemm_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp):
     # LN 2bsh
-    activation_mem_attn_ln = seq_length * b * hidden_size *
+    activation_mem_attn_ln = seq_length * b * hidden_size * training_dtype
     if is_sp == "False":
         activation_mem_attn_ln *= tp
     # attention input X, qkv 2bsh/1bsh
-    activation_mem_attn_qkv = seq_length * b * hidden_size *
+    activation_mem_attn_qkv = seq_length * b * hidden_size * gemm_dtype
     if is_sp == "False":
         activation_mem_attn_qkv *= tp
     # attention q 2bsh
-    activation_mem_attn_q = seq_length * b * hidden_size *
+    activation_mem_attn_q = seq_length * b * hidden_size * training_dtype
     # attention k and v 4bsh
-    activation_mem_attn_kv = seq_length * b * kv_hidden_size *
+    activation_mem_attn_kv = seq_length * b * kv_hidden_size * training_dtype * 2
     # attention proj input 2bsh/1bsh
-    activation_mem_attn_proj = seq_length * b * hidden_size *
+    activation_mem_attn_proj = seq_length * b * hidden_size * gemm_dtype
     # dropout bsh
     activation_mem_attn_dropout = seq_length * b * hidden_size
     if is_sp == "False":
@@ -208,22 +208,22 @@ def compute_activation_memory_attention(activation_dtype, seq_length, b, hidden_
     )
     return activation_memory_attn
 
-def compute_activation_memory_mlp(
+def compute_activation_memory_mlp(training_dtype, gemm_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp):
     # LN 2bsh
-    activation_mem_mlp_ln = seq_length * b * hidden_size *
+    activation_mem_mlp_ln = seq_length * b * hidden_size * training_dtype
     if is_sp == "False":
         activation_mem_mlp_ln *= tp
     # FC1 2bsh/1bsh
-    activation_mem_mlp_fc1 = seq_length * b * hidden_size *
+    activation_mem_mlp_fc1 = seq_length * b * hidden_size * gemm_dtype
     if is_sp == "False":
         activation_mem_mlp_fc1 *= tp
     # Act 8bsh
     if act_func == "LLaMA":
-        activation_mem_mlp_act = seq_length * b * ffn_size *
+        activation_mem_mlp_act = seq_length * b * ffn_size * training_dtype * 2
     else:
-        activation_mem_mlp_act = seq_length * b * ffn_size *
+        activation_mem_mlp_act = seq_length * b * ffn_size * training_dtype
     # FC2 8bsh/4bsh
-    activation_mem_mlp_fc2 = seq_length * b * ffn_size *
+    activation_mem_mlp_fc2 = seq_length * b * ffn_size * gemm_dtype
     # dropout bsh
     activation_mem_mlp_dropout = seq_length * b * hidden_size
     if is_sp == "False":
@@ -261,25 +261,33 @@ def compute_activation_memory_pp(activation_memory, vp, pp, num_microbatches):
 
     return activation_memory
 
-def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp):
+def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, precision, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp):
     # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
     # We are trying to compute the maximum activation footprint, so all calculations in this function
     # are for the first pipeline stage.
 
-    # activation dataType
-    if
-
-    else:
-
+    # activation dataType for Training
+    if precision == "FP32":
+        training_dtype = 4
+    else:
+        training_dtype = 2
+
+    # activation dataType for GEMM
+    if precision == "FP32":
+        gemm_dtype = 4
+    elif is_fp8 == "False":
+        gemm_dtype = 2
+    else:
+        gemm_dtype = 1
 
     # kv_hidden_size
     if is_group_query == "False":
         group_query_num = head_num
     kv_hidden_size = hidden_size / head_num * group_query_num
 
-    activation_memory_attn = compute_activation_memory_attention(
+    activation_memory_attn = compute_activation_memory_attention(training_dtype, gemm_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp)
 
-    activation_memory_mlp = compute_activation_memory_mlp(
+    activation_memory_mlp = compute_activation_memory_mlp(training_dtype, gemm_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp)
 
     activation_memory = activation_memory_attn + activation_memory_mlp
 
@@ -324,7 +332,7 @@ def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_l
        ffn_size, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, dp, tp, pp, cp, is_dist_opt, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty)
 
     # get activation memory
-    activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp)
+    activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, precision, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp)
 
     # get model parameters
     numParametersTotal = Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, head_num, 1, 1)
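
As a rough cross-check of the new dtype handling, here is a minimal standalone sketch of the training/GEMM byte-size selection applied to the per-layer attention term. The configuration values (BF16 training with FP8 GEMMs, seq_length=4096, b=1, hidden_size=4096, kv_hidden_size=1024) are illustrative assumptions, and the tp/is_sp scaling from the Space's code is left out.

# Sketch only: mirrors the dtype selection introduced above, with assumed example values.

def pick_dtypes(precision, is_fp8):
    # Training-side activations: 4 bytes in FP32, otherwise 2 bytes (BF16/FP16).
    training_dtype = 4 if precision == "FP32" else 2
    # GEMM inputs: 4 bytes in FP32, 2 bytes in BF16/FP16, 1 byte when FP8 GEMMs are enabled.
    if precision == "FP32":
        gemm_dtype = 4
    elif is_fp8 == "False":
        gemm_dtype = 2
    else:
        gemm_dtype = 1
    return training_dtype, gemm_dtype

if __name__ == "__main__":
    training_dtype, gemm_dtype = pick_dtypes(precision="BF16", is_fp8="True")
    seq_length, b, hidden_size, kv_hidden_size = 4096, 1, 4096, 1024  # assumed values
    attn_bytes = (
        seq_length * b * hidden_size * training_dtype            # LN output
        + seq_length * b * hidden_size * gemm_dtype              # qkv GEMM input
        + seq_length * b * hidden_size * training_dtype          # q
        + seq_length * b * kv_hidden_size * training_dtype * 2   # k and v
        + seq_length * b * hidden_size * gemm_dtype              # proj GEMM input
        + seq_length * b * hidden_size                           # dropout mask
    )
    print(f"attention activations per layer: ~{attn_bytes / 1024**2:.0f} MiB")

With these assumed values the estimate comes out to about 128 MiB per layer; switching is_fp8 to "False" raises the two GEMM-input terms from 1 byte to 2 bytes per element.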