harmdevries committed
Commit 28b4830 · 1 Parent(s): 8cbefab

Update app.py

Files changed (1):
  1. app.py +9 -9
app.py CHANGED
@@ -1,17 +1,17 @@
 import streamlit as st
 
-def number_field(label, **args):
+def number_field(label, **kwargs):
     c1, c2 = st.columns([2, 4])
     c1.write(label)
 
-    return c2.number_input('', **args)
+    return c2.number_input('', **kwargs)
 
 def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
     arith_int = comp_flop/mem_bytes
     exec_time = (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
 
-    comp_flop = round(mha_flop/1e9, 2)
-    mem_bytes = round(mha_bytes/1e6, 2)
+    comp_flop = round(comp_flop/1e9, 2)
+    mem_bytes = round(mem_bytes/1e6, 2)
 
     c1.write("GFLOP:")
     c2.write(str(comp_flop))
@@ -66,17 +66,17 @@ mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
 c1, c2 = st.columns([2, 3])
 att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
 
-st.header('Attention scores: ')
+st.header('Attention scores: attention-value gemm')
 st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
 st.subheader("Multi-Head Attention")
-mha_flop = 2*bs*h*(d/h)*n
-mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
+mha_flop = 2*bs*h*n*(d/h)
+mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
 c1, c2 = st.columns([2, 3])
 att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
 
 st.subheader("Multi-Query Attention")
-mqa_flop = 2*bs*h*(d/h)*n
-mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
+mqa_flop = 2*bs*h*n*(d/h)
+mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
 c1, c2 = st.columns([2, 3])
 att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
 
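For readers who want to sanity-check the updated numbers outside Streamlit, here is a minimal standalone sketch of the roofline-style estimate that print_kernel_execution performs, using the corrected attention-value gemm formulas from this commit. The hardware constants TFLOPS and GB_S and the shapes bs, h, n, d are illustrative assumptions (A100-class peaks and example sizes), not values taken from app.py; the leading 2s on the byte terms are read here as 2-byte (fp16) elements.

# Standalone sketch of the roofline-style estimate in print_kernel_execution.
# TFLOPS and GB_S are assumed A100-class peaks, not constants from app.py.
TFLOPS = 312e12   # assumed peak fp16 compute, FLOP/s
GB_S = 1.935e12   # assumed peak memory bandwidth, bytes/s

def kernel_time_ms(comp_flop, mem_bytes):
    # Execution time = compute time + memory time, reported in milliseconds.
    return (comp_flop / TFLOPS + mem_bytes / GB_S) * 1000

# Illustrative shapes: batch size, heads, max sequence length, model dim.
bs, h, n, d = 32, 16, 2048, 2048

# Attention-value gemm, per the updated formulas in this commit
# (the factor 2 on each byte term presumably counts fp16 elements).
mha_flop  = 2*bs*h*n*(d/h)
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
mqa_flop  = 2*bs*h*n*(d/h)
mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)

for name, flop, nbytes in [("MHA", mha_flop, mha_bytes),
                           ("MQA", mqa_flop, mqa_bytes)]:
    print(f"{name}: {flop/1e9:.2f} GFLOP, {nbytes/1e6:.2f} MB, "
          f"arith. intensity {flop/nbytes:.1f}, "
          f"{kernel_time_ms(flop, nbytes):.3f} ms")

Note how the FLOP counts are identical while the MQA byte count is smaller, since keys and values are shared across heads; that higher arithmetic intensity is exactly the gap the app is built to display.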