In [1]:
from utils import activation_memory, param_grads_opt

In [48]:
def activation_memory(
    a, # attention heads
    b, # micro batch size
    h, # hidden dimension size
    h_ff, # feedforward dimension size (often h_ff = 4h)
    L, # number of layers
    s, # sequence length
    mixed=True,
    recomputation="none",
    ff_activation="relu"
    ):
    
    # https://arxiv.org/pdf/2205.05198
    if mixed:
        bytes_per_value = 2 
    else:
        bytes_per_value = 4

    one_layer_attention = s * b * h * (bytes_per_value * 5 + 1) + ((2 * bytes_per_value + 1) * a * s * s * b) # eq (2)

    if ff_activation == "relu":
        one_layer_feedforward = (s * b * h * bytes_per_value + (s * b * h_ff * bytes_per_value)   # inputs of 1st/2nd linear layers
                + s * b * h)  # dropout
    elif ff_activation == "gelu":
        one_layer_feedforward = (s * b * h * bytes_per_value + (s * b * h_ff * bytes_per_value)   # inputs of 1st/2nd linear layers
                + s * b * h_ff * bytes_per_value # inputs of activation function (not really necessary for Relu)
                + s * b * h)  # dropout
    elif ff_activation == "swiglu":
        one_layer_feedforward = (s * b * h * bytes_per_value + (s * b * h_ff * bytes_per_value)   # inputs of input/output linear layers
         + s * b * h_ff * bytes_per_value * 3 # inputs of activation function
            + s * b * h)  # dropout (note that dropout is lower-precision - boolean)


    layer_norm = s * b * h * bytes_per_value

    if recomputation == "none":
        one_layer =  one_layer_attention + one_layer_feedforward + 2 * layer_norm # eq (2)
    elif recomputation =="selective":
        one_layer = s * b * h * 34 # eq (6)
    elif recomputation =="full":
        one_layer = s * b * h * 2
    else:
        raise ValueError()
    
    input_dropout = s * b * h # section 4.3

    total = L * one_layer + input_dropout
        
    return total




In [51]:
a = 16
b = 3
h = 1024
h_ff = 4 * h
L = 1
s = 7  # 128000
recomputation = "none"
mixed = True
ff_activation = "swiglu"


In [52]:
activation_memory(a=a, b=b, h=h, h_ff=h_ff, L=L, s=s, recomputation=recomputation, mixed=mixed, ff_activation=ff_activation)

1086960

In [18]:
from math import log

def format_bytes(bytes):
    sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB']
    if bytes == 0:
        return '0 Bytes'
    i = int(log(bytes, 1024))
    print(i)
    p = 1024 ** i
    s = round(bytes / p, 2)
    return f"{s} {sizes[i]}"



In [19]:
format_bytes(activation_memory(a, b, h, L, s, recomputation))

4


'22.13 TB'