def activation_memory(
    a, # attention heads
    b, # micro batch size
    h, # hidden dimension size
    L, # number of layers
    s, # sequence length
    mixed=True,
    recomputation=None
    ):
    
    # per-layer activation memory follows https://arxiv.org/pdf/2205.05198,
    # which reports bytes assuming 2-byte (fp16/bf16) activations
    if recomputation is None:
        one_layer = s * b * h * (34 + (5 * a * s / h))  # eq (2), no recomputation
    elif recomputation == "selective":
        one_layer = s * b * h * 34  # eq (6), attention scores/softmax recomputed
    elif recomputation == "full":
        one_layer = s * b * h * 2  # only each layer's input is stored
    else:
        raise ValueError(f"unknown recomputation setting: {recomputation}")

    input_dropout = s * b * h  # dropout mask on the embedding output, section 4.3

    if mixed:
        bytes_per_value = 2
    else:
        bytes_per_value = 4

    # the per-layer formulas above already count bytes for 2-byte activations,
    # so only rescale when training in full precision
    return (bytes_per_value / 2) * L * one_layer + input_dropout
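
# Hedged usage sketch for activation_memory: the model shape below (40 heads,
# hidden size 5120, 40 layers, sequence length 2048, roughly a 13B-parameter
# GPT-style configuration) is an illustrative assumption, not something defined
# elsewhere in this file.
def _activation_memory_example():
    gib = 1024 ** 3
    for mode in (None, "selective", "full"):
        mem = activation_memory(a=40, b=1, h=5120, L=40, s=2048, recomputation=mode)
        print(f"activation memory ({mode}): {mem / gib:.1f} GiB")
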


def param_grads_opt(
    h, # hidden dimension size
    L, # number of layers
    s, # sequence length
    v, # vocab size
    k=8, # bytes of optimizer state per parameter (Adam: 4 bytes first moment + 4 bytes second moment/variance)
    mixed=True # mixed precision training
    ):
    
    # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
    # note: this is without GQA or MQA
    
    emb = h * (v + s)                # token embeddings + learned positional embeddings
    one_layer = 12 * h**2 + 13 * h   # attention, MLP, and layernorm parameters per block
    other = 2 * h                    # final layernorm

    n = emb + L * one_layer + other  # total parameter count
    
    # memory multipliers follow section 3.1 of https://arxiv.org/pdf/1910.02054 (ZeRO)
    if mixed:
        k += 4  # extra fp32 master copy of the weights kept with the optimizer states
        bytes_per_parameter = 2
    else:
        bytes_per_parameter = 4

    # returns (parameter memory, gradient memory, optimizer state memory) in bytes
    return bytes_per_parameter * n, bytes_per_parameter * n, k * n
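
# Hedged end-to-end sketch combining both estimators into a rough per-GPU memory
# total for training without any model parallelism or ZeRO sharding. The model
# shape and the GPT-2 vocabulary size (50257) are assumed example values, not
# something this file prescribes.
if __name__ == "__main__":
    gib = 1024 ** 3
    heads, micro_batch, hidden, layers, seq, vocab = 40, 1, 5120, 40, 2048, 50257

    params, grads, opt_states = param_grads_opt(h=hidden, L=layers, s=seq, v=vocab)
    activations = activation_memory(
        a=heads, b=micro_batch, h=hidden, L=layers, s=seq, recomputation="selective"
    )

    print(f"parameters:       {params / gib:.1f} GiB")
    print(f"gradients:        {grads / gib:.1f} GiB")
    print(f"optimizer states: {opt_states / gib:.1f} GiB")
    print(f"activations:      {activations / gib:.1f} GiB")
    total = params + grads + opt_states + activations
    print(f"total (approx.):  {total / gib:.1f} GiB")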