harmdevries commited on
Commit
4885a19
·
1 Parent(s): 1934207

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -28,7 +28,7 @@ mha_flop = round(mha_flop/1e9, 2)
28
  mha_bytes = round(mha_bytes/1e6, 2)
29
 
30
 
31
- st.subheader("Multi-query Attention")
32
  c1, c2 = st.columns([2, 3])
33
  c1.write("GFLOP:")
34
  c2.write(str(mha_flop))
@@ -42,6 +42,8 @@ c2.write(str(mha_time))
42
 
43
  mqa_flop = 2*bs*n*d*(1+2/h)*d
44
  mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
 
 
45
 
46
  mqa_flop = round(mqa_flop/1e9, 2)
47
  mqa_bytes = round(mqa_bytes/1e6, 2)
@@ -53,7 +55,9 @@ c2.write(str(mqa_flop))
53
  c1.write("MB:")
54
  c2.write(str(mqa_bytes))
55
  c1.write("Arithm. intensity:")
56
- c2.write(str(mqa_flop/mqa_bytes))
 
 
57
 
58
  st.header('Attention')
59
 
 
28
  mha_bytes = round(mha_bytes/1e6, 2)
29
 
30
 
31
+ st.subheader("Multi-head Attention")
32
  c1, c2 = st.columns([2, 3])
33
  c1.write("GFLOP:")
34
  c2.write(str(mha_flop))
 
42
 
43
  mqa_flop = 2*bs*n*d*(1+2/h)*d
44
  mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
45
+ mqa_intensity = mqa_flop/mqa_bytes
46
+ mqa_time = (mqa_flop/TFLOPS + mqa_bytes/GB_S)*1000
47
 
48
  mqa_flop = round(mqa_flop/1e9, 2)
49
  mqa_bytes = round(mqa_bytes/1e6, 2)
 
55
  c1.write("MB:")
56
  c2.write(str(mqa_bytes))
57
  c1.write("Arithm. intensity:")
58
+ c2.write(str(mqa_intensity))
59
+ c1.write("Time (ms):")
60
+ c2.write(str(mqa_time))
61
 
62
  st.header('Attention')
63