|
def activation_memory(
    a,
    b,
    h,
    L,
    s,
    mixed=True,
    recomputation=None
):
    """Estimate activation memory, in bytes, for training a transformer stack.

    Parameter meanings are not declared in this file; from the formula shape
    they presumably follow the usual Megatron-style convention — TODO confirm:
        a: attention heads, b: micro-batch size, h: hidden size,
        L: number of layers, s: sequence length.

    Args:
        mixed: if True, activations are assumed to be 2 bytes each
            (half precision); otherwise 4 bytes.
        recomputation: None (store everything), "selective"
            (drop the attention-score term), or "full" (keep only the
            per-layer input).

    Returns:
        Estimated bytes: per-layer activations scaled by element width and
        layer count, plus one extra s*b*h term for the input dropout mask.

    Raises:
        ValueError: if ``recomputation`` is not one of the accepted values.
    """
    if recomputation is None:
        # Full storage: the 5*a*s/h term grows with sequence length squared
        # once multiplied through by the leading s.
        one_layer = s * b * h * (34 + (5 * a * s / h))
    elif recomputation == "selective":
        # Selective recomputation removes the attention-score contribution.
        one_layer = s * b * h * 34
    elif recomputation == "full":
        # Full recomputation stores only each layer's input.
        one_layer = s * b * h * 2
    else:
        # Previously raised a bare ValueError() with no message; name the
        # offending argument and value so callers can diagnose the mistake.
        raise ValueError(
            "recomputation must be None, 'selective' or 'full', "
            f"got {recomputation!r}"
        )

    # Dropout mask on the stack input, counted once (not per layer).
    input_dropout = s * b * h

    bytes_per_value = 2 if mixed else 4

    return bytes_per_value * L * one_layer + input_dropout
|
|
|
|
|
def param_grads_opt( |
|
h, |
|
L, |
|
s, |
|
v, |
|
k=8, |
|
mixed=True |
|
): |
|
|
|
|
|
|
|
|
|
emb = h*(v+s) |
|
one_layer = 12 * h**2 + 13*h |
|
other = 2*h |
|
|
|
n = emb + L * one_layer + other |
|
|
|
|
|
|
|
if mixed: |
|
k += 4 |
|
bytes_per_paramter = 2 |
|
else: |
|
bytes_per_paramter = 4 |
|
|
|
return bytes_per_paramter*n, bytes_per_paramter*n, k*n |
|
|