Spaces:
Runtime error
Runtime error
Commit
·
729a063
1
Parent(s):
3a743e7
Update app.py
Browse files
app.py
CHANGED
@@ -76,13 +76,13 @@ st.caption("Multi-Head Attention")
|
|
76 |
mha_flop = 2*bs*h*n*(d/h)
|
77 |
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
|
78 |
c1, c2 = st.columns([2, 3])
|
79 |
-
|
80 |
|
81 |
st.caption("Multi-Query Attention")
|
82 |
mqa_flop = 2*bs*h*n*(d/h)
|
83 |
mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
|
84 |
c1, c2 = st.columns([2, 3])
|
85 |
-
|
86 |
|
87 |
st.subheader('Output projection')
|
88 |
out_flop = 2*bs*1*d*d
|
@@ -91,15 +91,17 @@ c1, c2 = st.columns([2, 3])
|
|
91 |
out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
|
92 |
|
93 |
st.subheader('Element-wise ops')
|
94 |
-
st.write("We also need to take into the softmax layer and
|
95 |
|
96 |
st.caption("Softmax")
|
97 |
softmax_bytes = 2*bs*h*n + 2*bs*h*n
|
98 |
c1, c2 = st.columns([2, 3])
|
99 |
softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
|
100 |
|
101 |
-
st.caption("Layer norm")
|
102 |
-
|
|
|
|
|
103 |
|
104 |
st.header('MLP')
|
105 |
st.subheader('First Linear')
|
@@ -113,3 +115,17 @@ mlp2_flop = 2*bs*1*d*4*d
|
|
113 |
mlp2_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
|
114 |
c1, c2 = st.columns([2, 3])
|
115 |
mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
mha_flop = 2*bs*h*n*(d/h)
|
77 |
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
|
78 |
c1, c2 = st.columns([2, 3])
|
79 |
+
att2_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
80 |
|
81 |
st.caption("Multi-Query Attention")
|
82 |
mqa_flop = 2*bs*h*n*(d/h)
|
83 |
mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
|
84 |
c1, c2 = st.columns([2, 3])
|
85 |
+
att2_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
86 |
|
87 |
st.subheader('Output projection')
|
88 |
out_flop = 2*bs*1*d*d
|
|
|
91 |
out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
|
92 |
|
93 |
st.subheader('Element-wise ops')
|
94 |
+
st.write("We also need to take into account the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound. ")
|
95 |
|
96 |
st.caption("Softmax")
|
97 |
softmax_bytes = 2*bs*h*n + 2*bs*h*n
|
98 |
c1, c2 = st.columns([2, 3])
|
99 |
softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
|
100 |
|
101 |
+
st.caption("Layer norm/residual connection")
|
102 |
+
ln_bytes = 2*bs*1*d
|
103 |
+
ln_flop = 0
|
104 |
+
ln_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
|
105 |
|
106 |
st.header('MLP')
|
107 |
st.subheader('First Linear')
|
|
|
115 |
mlp2_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
|
116 |
c1, c2 = st.columns([2, 3])
|
117 |
mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
|
118 |
+
|
119 |
+
st.subheader('Element-wise ops')
|
120 |
+
st.write("We also need to take into account the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
|
121 |
+
ln_bytes = 2*bs*1*d
|
122 |
+
ln_flop = 0
|
123 |
+
ln_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
|
124 |
+
|
125 |
+
st.header("Adding it all up")
|
126 |
+
|
127 |
+
shared_time = out_time + softmax_time + 2*ln_time + mlp1_time + mlp2_time + 3*ln_time
|
128 |
+
mha_total_time = qkv_mha_time + att1_mha_time + att2_mha_time + shared_time
|
129 |
+
mqa_total_time = qkv_mqa_time + att1_mqa_time + att2_mqa_time + shared_time
|
130 |
+
st.write("MHA exec time (ms): " + str(mha_total_time))
|
131 |
+
st.write("MQA exec time (ms): " + str(mqa_total_time))
|