def activation_memory(
    a,        # number of attention heads
    b,        # micro-batch size
    h,        # hidden dimension
    h_ff,     # feed-forward hidden dimension (typically 4 * h)
    L,        # number of transformer layers
    s,        # sequence length
    mixed=True,
    recomputation="none",
):
    """Estimate activation memory in bytes for a transformer model."""
    # Activations take 2 bytes per value in mixed precision (bf16/fp16) and
    # 4 bytes in full precision; dropout masks take 1 byte per element.
    if mixed:
        bytes_per_value = 2
    else:
        bytes_per_value = 4

    # Attention block: five h-sized activations plus a 1-byte dropout mask per
    # token, and the s*s attention scores, softmax output and softmax dropout
    # mask per head.
    one_layer_attention = (
        s * b * h * (bytes_per_value * 5 + 1)
        + (2 * bytes_per_value + 1) * a * s * s * b
    )

    # Feed-forward block with a standard MLP.
    one_layer_feedforward_mlp = (
        s * b * h * bytes_per_value       # input to the up-projection
        + s * b * h_ff * bytes_per_value  # input to the activation function
        + s * b * h_ff * bytes_per_value  # input to the down-projection
        + s * b * h                       # dropout mask
    )
    # SwiGLU variant (not used in the total below): the gated activation
    # stores three h_ff-sized tensors instead of one.
    one_layer_feedforward_swiglu = (
        s * b * h * bytes_per_value
        + s * b * h_ff * bytes_per_value
        + s * b * h_ff * bytes_per_value * 3
        + s * b * h
    )

    if recomputation == "none":
        # Keep everything: attention plus MLP activations (layer-norm inputs
        # are not counted here).
        one_layer = one_layer_attention + one_layer_feedforward_mlp
    elif recomputation == "selective":
        # Selective recomputation drops the a * s * s * b attention-score
        # activations, leaving roughly 34 * s * b * h bytes per layer
        # (the constant assumes 2-byte activations).
        one_layer = s * b * h * 34
    elif recomputation == "full":
        # Full recomputation keeps only each layer's input.
        one_layer = s * b * h * 2
    else:
        raise ValueError(f"Unknown recomputation strategy: {recomputation!r}")

    input_dropout = 0  # input dropout mask, ignored here

    total = L * one_layer + input_dropout
    return total
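
# Illustrative usage (hypothetical values, not from the original source): a
# GPT-3-scale configuration with 96 heads, h=12288, h_ff=4*h, 96 layers,
# sequence length 2048 and micro-batch size 1, in mixed precision. Selective
# recomputation drops the a*s*s*b attention-score terms, which dominate
# activation memory at long sequence lengths.
GiB = 1024**3
act_none = activation_memory(a=96, b=1, h=12288, h_ff=4 * 12288, L=96, s=2048,
                             mixed=True, recomputation="none")
act_selective = activation_memory(a=96, b=1, h=12288, h_ff=4 * 12288, L=96, s=2048,
                                  mixed=True, recomputation="selective")
print(f"activations, no recomputation:        {act_none / GiB:.1f} GiB")
print(f"activations, selective recomputation: {act_selective / GiB:.1f} GiB")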
def param_grads_opt(
    h,          # hidden dimension
    L,          # number of transformer layers
    s,          # sequence length (for learned positional embeddings)
    v,          # vocabulary size
    k=8,        # bytes of optimizer state per parameter (e.g. Adam's two fp32 moments)
    mixed=True,
):
    """Estimate memory in bytes for parameters, gradients and optimizer states.

    Returns a tuple (parameter_bytes, gradient_bytes, optimizer_state_bytes).
    """
    # Parameter count: token and positional embeddings, per-layer weights and
    # biases (12*h^2 + 13*h), and the final layer norm (2*h).
    emb = h * (v + s)
    one_layer = 12 * h**2 + 13 * h
    other = 2 * h

    n = emb + L * one_layer + other

    if mixed:
        # Mixed precision keeps 2-byte parameters and gradients, plus an extra
        # 4-byte fp32 master copy of the weights alongside the optimizer states.
        k += 4
        bytes_per_parameter = 2
    else:
        bytes_per_parameter = 4

    return bytes_per_parameter * n, bytes_per_parameter * n, k * n
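
# Illustrative usage (hypothetical values, not from the original source):
# parameter, gradient and optimizer-state memory for the same GPT-3-scale
# configuration, with k=8 bytes of Adam state per parameter plus the fp32
# master copy added by the mixed-precision branch.
params_bytes, grads_bytes, opt_bytes = param_grads_opt(h=12288, L=96, s=2048,
                                                       v=50257, k=8, mixed=True)
print(f"parameters:       {params_bytes / GiB:.1f} GiB")
print(f"gradients:        {grads_bytes / GiB:.1f} GiB")
print(f"optimizer states: {opt_bytes / GiB:.1f} GiB")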