harmdevries committed on
Commit
729a063
·
1 Parent(s): 3a743e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -5
app.py CHANGED
@@ -76,13 +76,13 @@ st.caption("Multi-Head Attention")
76
  mha_flop = 2*bs*h*n*(d/h)
77
  mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
78
  c1, c2 = st.columns([2, 3])
79
- att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
80
 
81
  st.caption("Multi-Query Attention")
82
  mqa_flop = 2*bs*h*n*(d/h)
83
  mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
84
  c1, c2 = st.columns([2, 3])
85
- att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
86
 
87
  st.subheader('Output projection')
88
  out_flop = 2*bs*1*d*d
@@ -91,15 +91,17 @@ c1, c2 = st.columns([2, 3])
91
  out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
92
 
93
  st.subheader('Element-wise ops')
94
- st.write("We also need to take into the softmax layer and layer norm")
95
 
96
  st.caption("Softmax")
97
  softmax_bytes = 2*bs*h*n + 2*bs*h*n
98
  c1, c2 = st.columns([2, 3])
99
  softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
100
 
101
- st.caption("Layer norm")
102
-
 
 
103
 
104
  st.header('MLP')
105
  st.subheader('First Linear')
@@ -113,3 +115,17 @@ mlp2_flop = 2*bs*1*d*4*d
113
  mlp2_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
114
  c1, c2 = st.columns([2, 3])
115
  mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  mha_flop = 2*bs*h*n*(d/h)
77
  mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
78
  c1, c2 = st.columns([2, 3])
79
+ att2_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
80
 
81
  st.caption("Multi-Query Attention")
82
  mqa_flop = 2*bs*h*n*(d/h)
83
  mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
84
  c1, c2 = st.columns([2, 3])
85
+ att2_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
86
 
87
  st.subheader('Output projection')
88
  out_flop = 2*bs*1*d*d
 
91
  out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
92
 
93
  st.subheader('Element-wise ops')
94
+ st.write("We also need to take into the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound. ")
95
 
96
  st.caption("Softmax")
97
  softmax_bytes = 2*bs*h*n + 2*bs*h*n
98
  c1, c2 = st.columns([2, 3])
99
  softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
100
 
101
+ st.caption("Layer norm/residual connection")
102
+ ln_bytes = 2*bs*1*d
103
+ ln_flop = 0
104
+ ln_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
105
 
106
  st.header('MLP')
107
  st.subheader('First Linear')
 
115
  mlp2_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
116
  c1, c2 = st.columns([2, 3])
117
  mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
118
+
119
+ st.subheader('Element-wise ops')
120
+ st.write("We also need to take into account the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
121
+ ln_bytes = 2*bs*1*d  # NOTE(review): presumably 2 bytes per element over bs*1*d activations - confirm
122
+ ln_flop = 0  # treated as memory bound (see the text above), so no FLOPs are counted
123
+ ln_time = print_kernel_execution(c1, c2, ln_flop, ln_bytes)  # use the element-wise estimates; previously passed 0 and softmax_bytes, duplicating the softmax timing
124
+
125
+ st.header("Adding it all up")
126
+
127
+ shared_time = out_time + softmax_time + 2*ln_time + mlp1_time + mlp2_time + 3*ln_time
128
+ mha_total_time = qkv_mha_time + att1_mha_time + att2_mha_time + shared_time
129
+ mqa_total_time = qkv_mqa_time + att1_mqa_time + att2_mqa_time + shared_time
130
+ st.write("MHA exec time (ms): " + str(mha_total_time))
131
+ st.write("MQA exec time (ms): " + str(mqa_total_time))