Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						28b4830
	
1
								Parent(s):
							
							8cbefab
								
Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,17 +1,17 @@ | |
| 1 | 
             
            import streamlit as st
         | 
| 2 |  | 
| 3 | 
            -
            def number_field(label, ** | 
| 4 | 
             
              c1, c2 = st.columns([2, 4])
         | 
| 5 | 
             
              c1.write(label)
         | 
| 6 |  | 
| 7 | 
            -
              return c2.number_input('', ** | 
| 8 |  | 
| 9 | 
             
            def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
         | 
| 10 | 
             
              arith_int = comp_flop/mem_bytes
         | 
| 11 | 
             
              exec_time = (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
         | 
| 12 |  | 
| 13 | 
            -
              comp_flop = round( | 
| 14 | 
            -
              mem_bytes = round( | 
| 15 |  | 
| 16 | 
             
              c1.write("GFLOP:")
         | 
| 17 | 
             
              c2.write(str(comp_flop))
         | 
| @@ -66,17 +66,17 @@ mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n | |
| 66 | 
             
            c1, c2 = st.columns([2, 3])
         | 
| 67 | 
             
            att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
         | 
| 68 |  | 
| 69 | 
            -
            st.header('Attention scores: ')
         | 
| 70 | 
             
            st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
         | 
| 71 | 
             
            st.subheader("Multi-Head Attention")
         | 
| 72 | 
            -
            mha_flop = 2*bs*h*(d/h) | 
| 73 | 
            -
            mha_bytes = 2*bs*h* | 
| 74 | 
             
            c1, c2 = st.columns([2, 3])
         | 
| 75 | 
             
            att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
         | 
| 76 |  | 
| 77 | 
             
            st.subheader("Multi-Query Attention")
         | 
| 78 | 
            -
            mqa_flop = 2*bs*h*(d/h) | 
| 79 | 
            -
            mqa_bytes = 2*bs* | 
| 80 | 
             
            c1, c2 = st.columns([2, 3])
         | 
| 81 | 
             
            att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
         | 
| 82 |  | 
|  | |
| 1 | 
             
            import streamlit as st
         | 
| 2 |  | 
| 3 | 
            +
            def number_field(label, **kwargs):
         | 
| 4 | 
             
              c1, c2 = st.columns([2, 4])
         | 
| 5 | 
             
              c1.write(label)
         | 
| 6 |  | 
| 7 | 
            +
              return c2.number_input('', **kwargs)
         | 
| 8 |  | 
| 9 | 
             
            def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
         | 
| 10 | 
             
              arith_int = comp_flop/mem_bytes
         | 
| 11 | 
             
              exec_time = (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
         | 
| 12 |  | 
| 13 | 
            +
              comp_flop = round(comp_flop/1e9, 2)
         | 
| 14 | 
            +
              mem_bytes = round(mem_bytes/1e6, 2)
         | 
| 15 |  | 
| 16 | 
             
              c1.write("GFLOP:")
         | 
| 17 | 
             
              c2.write(str(comp_flop))
         | 
|  | |
| 66 | 
             
            c1, c2 = st.columns([2, 3])
         | 
| 67 | 
             
            att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
         | 
| 68 |  | 
| 69 | 
            +
            st.header('Attention scores: attention-value gemm')
         | 
| 70 | 
             
            st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
         | 
| 71 | 
             
            st.subheader("Multi-Head Attention")
         | 
| 72 | 
            +
            mha_flop = 2*bs*h*n*(d/h)
         | 
| 73 | 
            +
            mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
         | 
| 74 | 
             
            c1, c2 = st.columns([2, 3])
         | 
| 75 | 
             
            att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
         | 
| 76 |  | 
| 77 | 
             
            st.subheader("Multi-Query Attention")
         | 
| 78 | 
            +
            mqa_flop = 2*bs*h*n*(d/h)
         | 
| 79 | 
            +
            mqa_bytes = 2*bs*h*n + 2*bs*n*(d/h) + 2*bs*h*(d/h)
         | 
| 80 | 
             
            c1, c2 = st.columns([2, 3])
         | 
| 81 | 
             
            att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
         | 
| 82 |  | 
