def activation_memory(
    a,                   # attention heads
    b,                   # micro batch size
    h,                   # hidden dimension size
    L,                   # number of layers
    s,                   # sequence length
    mixed=True,          # mixed precision training
    recomputation=None,  # None, "selective", or "full"
):
    # activation memory estimate following https://arxiv.org/pdf/2205.05198
    if recomputation is None:
        one_layer = s * b * h * (34 + (5 * a * s / h))  # eq (2)
    elif recomputation == "selective":
        one_layer = s * b * h * 34  # eq (6)
    elif recomputation == "full":
        one_layer = s * b * h * 2
    else:
        raise ValueError(f"unknown recomputation strategy: {recomputation!r}")
    input_dropout = s * b * h  # section 4.3
    if mixed:
        bytes_per_value = 2
    else:
        bytes_per_value = 4
    return bytes_per_value * L * one_layer + input_dropout
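

# Illustrative usage (a sketch, not part of the original file): the 7B-scale
# config below (a=32, b=1, h=4096, L=32, s=4096) is an assumed example chosen
# only to show the call shape and compare recomputation strategies.
def example_activation_memory():
    gib = 1024**3
    no_recompute = activation_memory(a=32, b=1, h=4096, L=32, s=4096)
    selective = activation_memory(a=32, b=1, h=4096, L=32, s=4096, recomputation="selective")
    full = activation_memory(a=32, b=1, h=4096, L=32, s=4096, recomputation="full")
    print(f"activations, no recomputation:        {no_recompute / gib:.1f} GiB")
    print(f"activations, selective recomputation: {selective / gib:.1f} GiB")
    print(f"activations, full recomputation:      {full / gib:.1f} GiB")
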
def param_grads_opt(
    h,           # hidden dimension size
    L,           # number of layers
    s,           # sequence length
    v,           # vocab size
    k=8,         # optimizer state bytes per parameter (Adam: 8 = 4 bytes momentum + 4 bytes variance)
    mixed=True,  # mixed precision training
):
    # parameter count following https://michaelwornow.net/2024/01/18/counting-params-in-transformer
    # note: this is without GQA or MQA
    emb = h * (v + s)               # token + learned position embeddings
    one_layer = 12 * h**2 + 13 * h  # attention, MLP, and layer norms per transformer layer
    other = 2 * h                   # final layer norm
    n = emb + L * one_layer + other
    # memory multipliers from section 3.1 of https://arxiv.org/pdf/1910.02054
    if mixed:
        k += 4  # additional full precision copy of the weights
        bytes_per_parameter = 2
    else:
        bytes_per_parameter = 4
    return bytes_per_parameter * n, bytes_per_parameter * n, k * n  # params, grads, optimizer states (bytes)
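

# Illustrative end-to-end estimate (a sketch, not part of the original file):
# combines both helpers for a single device with no parallelism, using an
# assumed 7B-style config (h=4096, L=32, s=4096, v=32000, a=32, b=1) and
# selective recomputation.
if __name__ == "__main__":
    example_activation_memory()
    gib = 1024**3
    a, b, h, L, s, v = 32, 1, 4096, 32, 4096, 32000
    act = activation_memory(a, b, h, L, s, recomputation="selective")
    params, grads, opt = param_grads_opt(h, L, s, v)
    total = act + params + grads + opt
    print(f"parameters:  {params / gib:.1f} GiB")
    print(f"gradients:   {grads / gib:.1f} GiB")
    print(f"optimizer:   {opt / gib:.1f} GiB")
    print(f"activations: {act / gib:.1f} GiB")
    print(f"total:       {total / gib:.1f} GiB")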