File size: 6,527 Bytes
964360b
 
 
 
 
 
 
 
 
 
 
6ca6353
 
964360b
6ca6353
964360b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
674c962
 
 
 
 
964360b
6ca6353
674c962
 
 
964360b
 
 
674c962
 
964360b
 
674c962
964360b
 
 
 
674c962
 
 
 
964360b
674c962
 
 
 
 
964360b
6ca6353
674c962
 
 
964360b
 
 
674c962
 
964360b
 
674c962
964360b
 
 
 
 
674c962
 
 
 
 
 
 
 
 
 
 
 
964360b
6ca6353
674c962
 
 
 
 
964360b
 
 
 
 
674c962
 
 
 
 
964360b
674c962
964360b
 
674c962
964360b
 
 
 
674c962
 
964360b
674c962
964360b
674c962
964360b
 
 
 
 
 
 
 
 
 
674c962
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch

def get_precision_fac(precision: str):
    """Map a training-precision mode to bytes per parameter.

    "mixed" -> 2 bytes (half-precision compute copy), "single" -> 4 bytes (fp32).

    Raises:
        ValueError: If `precision` is neither "mixed" nor "single".
    """
    if precision == "single":
        return 4
    if precision == "mixed":
        return 2
    raise ValueError("Precision must be either 'mixed' or 'single'")


def get_params_fac(model_dtype: str):
    """Return bytes used to store one model parameter for a dtype string.

    Args:
        model_dtype: Either "float16" (2 bytes/param) or "float32" (4 bytes/param).

    Returns:
        int: Bytes per parameter.

    Raises:
        ValueError: If `model_dtype` is not one of the supported strings.
    """
    if model_dtype == "float16":
        return 2
    elif model_dtype == "float32":
        return 4
    else:
        # Fix: the old message claimed torch dtypes ("torch.float16") were
        # expected, but the comparisons above are against plain strings.
        raise ValueError("Model dtype must be either 'float16' or 'float32'")



####################### Zero Redundancy Optimizer (ZeRO) #######################

# Per-parameter byte counts for fp32 Adam optimizer state and fp32 copies.
VARIANCE_FACTOR = 4  # fp32 second-moment (variance) estimate: 4 bytes/param
MOMENTUM_FACTOR = 4  # fp32 first-moment (momentum) estimate: 4 bytes/param
OPTIMIZER_FACTOR = VARIANCE_FACTOR + MOMENTUM_FACTOR # Adam optimizer
FP32_GRADS_FACTOR = 4  # fp32 gradient copy: 4 bytes/param
FP32_PARAM_FACTOR = 4  # fp32 parameter copy: 4 bytes/param
MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR  # fp32 master copy of the parameters


def estimate_zero1_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype = "float16",
                                          ):
    """Estimate model-state memory for ZeRO stage 1 (optimizer-state sharding).

    Args:
        total_params: Total number of model parameters.
        num_gpus_per_node: GPUs per node.
        num_nodes: Number of nodes.
        cpu_offload: Whether optimizer states live in CPU memory.
        additional_buffer_factor: Safety multiplier applied to CPU estimates.
        precision: "mixed" or "single" (see get_precision_fac).
        model_dtype: "float16" or "float32" (see get_params_fac).

    Returns:
        Tuple of (cpu_mem, gpu_mem) estimated bytes, both as ints.
    """
    total_gpus = num_nodes * num_gpus_per_node

    bytes_per_param = get_precision_fac(precision)
    bytes_per_stored_param = get_params_fac(model_dtype)

    if cpu_offload:
        # Only the compute copy of the parameters stays on GPU; master weights,
        # optimizer state and fp32 gradients are held on the host.
        gpu_mem = bytes_per_param * total_params
        per_param_cpu = max(bytes_per_stored_param * total_gpus,
                            MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR)
        cpu_mem = total_params * per_param_cpu * additional_buffer_factor
    else:
        # Optimizer state (plus the fp32 master params under mixed precision)
        # is sharded across all GPUs; grads are replicated in fp32.
        sharded_fac = OPTIMIZER_FACTOR + FP32_PARAM_FACTOR if precision == "mixed" else OPTIMIZER_FACTOR
        gpu_mem = (bytes_per_param * total_params
                   + FP32_GRADS_FACTOR * total_params
                   + int(sharded_fac * total_params / total_gpus))
        cpu_mem = total_params * bytes_per_stored_param * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)


def estimate_zero2_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype = "float16",
                                          ):
    """Estimate model-state memory for ZeRO stage 2 (optimizer + gradient sharding).

    Args:
        total_params: Total number of model parameters.
        num_gpus_per_node: GPUs per node.
        num_nodes: Number of nodes.
        cpu_offload: Whether optimizer states live in CPU memory.
        additional_buffer_factor: Safety multiplier applied to CPU estimates.
        precision: "mixed" or "single" (see get_precision_fac).
        model_dtype: "float16" or "float32" (see get_params_fac).

    Returns:
        Tuple of (cpu_mem, gpu_mem) estimated bytes, both as ints.
    """
    total_gpus = num_nodes * num_gpus_per_node

    bytes_per_param = get_precision_fac(precision)
    bytes_per_stored_param = get_params_fac(model_dtype)

    if cpu_offload:
        # Partitioned gradients are streamed to the host, so their GPU
        # footprint is negligible; only the compute params remain on GPU.
        gpu_mem = bytes_per_param * total_params
        per_param_cpu = max(bytes_per_stored_param * total_gpus,
                            MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR)
        cpu_mem = total_params * per_param_cpu * additional_buffer_factor
    else:
        # Gradients and optimizer state (plus fp32 master params under mixed
        # precision) are all sharded across the GPUs.
        sharded_fac = FP32_GRADS_FACTOR + OPTIMIZER_FACTOR
        if precision == "mixed":
            sharded_fac += FP32_PARAM_FACTOR
        gpu_mem = bytes_per_param * total_params + int(sharded_fac * total_params / total_gpus)
        cpu_mem = bytes_per_stored_param * total_params * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)


def estimate_zero3_model_states_mem_needs(total_params,
                                          largest_layer_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          cpu_offload_params=True,
                                          zero_init=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype = "float16",
                                          ):
    """Estimate model-state memory for ZeRO stage 3 (full model-state sharding).

    Args:
        total_params: Total number of model parameters.
        largest_layer_params: Parameter count of the single largest layer.
        num_gpus_per_node: GPUs per node.
        num_nodes: Number of nodes.
        cpu_offload: Whether optimizer states live in CPU memory.
        cpu_offload_params: Whether parameters are also offloaded to CPU.
        zero_init: Whether the model is constructed with zero.Init (params
            sharded at construction time, shrinking the host-side footprint).
        additional_buffer_factor: Safety multiplier applied to CPU estimates.
        precision: "mixed" or "single" (see get_precision_fac).
        model_dtype: "float16" or "float32" (see get_params_fac).

    Returns:
        Tuple of (cpu_mem, gpu_mem, largest_layer_memory) — estimated CPU and
        GPU bytes as ints, plus the transient working-set bytes for the
        largest layer.
    """
    total_gpus = num_nodes * num_gpus_per_node
    gpus_factor = 1 / num_nodes

    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)
    grads_fac = precision_fac  # gradients are kept at compute precision

    # Transient working set: one layer's params and grads must fit on GPU.
    largest_layer_memory = (grads_fac + precision_fac) * largest_layer_params

    if cpu_offload:
        # Host holds master params, fp32 grads and optimizer state per param.
        host_fac = MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR
        if cpu_offload_params:
            # Parameters offloaded too: GPU only needs the largest layer's
            # working set; host additionally stores the param copies.
            host_fac += params_fac
            gpu_mem = largest_layer_memory
        else:
            # Params stay sharded on GPU. No persistent gradient buffer is
            # needed: ZeRO-Offload can transfer gradients to CPU memory
            # individually or in small groups right after they are computed.
            gpu_mem = max(largest_layer_memory,
                          int(precision_fac * total_params / total_gpus))
        if zero_init:
            cpu_mem = total_params * host_fac * gpus_factor * additional_buffer_factor
        else:
            # Without zero.Init the full dtype-sized copy may exist per local rank.
            cpu_mem = total_params * max(params_fac * num_gpus_per_node,
                                         host_fac * gpus_factor) * additional_buffer_factor
    else:
        # Everything lives on GPU, sharded; the largest layer sets the floor.
        state_fac = precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR
        if precision == "mixed":
            state_fac += FP32_PARAM_FACTOR
        gpu_mem = max(int(state_fac * largest_layer_params),
                      int(state_fac * total_params / total_gpus))
        # CPU only buffers model construction; zero.Init bounds it by the
        # largest layer instead of the whole model.
        base_params = largest_layer_params if zero_init else total_params
        cpu_mem = base_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem), largest_layer_memory