# Streamlit page that estimates the compute (FLOPs) and memory traffic (bytes)
# of the fused Query/Key/Value projection for user-chosen model dimensions.
import streamlit as st

st.header("Transformer parameters")

# Each widget re-runs the script on change, so the estimates below always
# reflect the current inputs.
bs = st.number_input('Batch size: ', value=10)
h = st.number_input('Num heads: ', value=16)  # collected but not used by the formulas below
d = st.number_input('Dimension: ', value=768)
n = st.number_input('Seq length: ', value=1024)

st.header('Query, Key, Value projection')

# (bs*n, d) activations times a (d, 3*d) weight: 2 FLOPs (multiply + add)
# per multiply-accumulate.
mha_flop = 2 * bs * n * d * 3 * d

# Input activations + projection weights + projected output.
# NOTE(review): the leading factor 2 presumably means 2 bytes per element
# (fp16/bf16) -- confirm the intended dtype.
mha_bytes = 2 * bs * n * d + 2 * 3 * d * d + 2 * bs * n * 3 * d

# The original called st.write() with no arguments, which renders nothing;
# show the computed estimates instead.
st.write(f"FLOP: {mha_flop:,}")
st.write(f"Bytes: {mha_bytes:,}")