|
|
|
def estimate_zero1_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision_fac=2,
                                          params_fac=4):
    """Estimate CPU and GPU memory needed for model states under ZeRO stage 1.

    All ``*_fac`` values are per-element sizes (presumably bytes), so the
    results are expressed in the same unit.

    Args:
        total_params: Total number of model parameters.
        num_gpus_per_node: Number of GPUs on each node.
        num_nodes: Number of nodes.
        cpu_offload: Selects the CPU-offload formulas. When False, the
            optimizer-side term is divided evenly across all
            ``num_nodes * num_gpus_per_node`` GPUs.
        additional_buffer_factor: Safety multiplier applied to the CPU
            estimate only.
        precision_fac: Per-element size of working-precision tensors.
        params_fac: Per-parameter size used in the CPU model-loading term.

    Returns:
        Tuple ``(cpu_mem, gpu_mem)``, each truncated to ``int``.

    Raises:
        ZeroDivisionError: If ``cpu_offload`` is False and the total GPU
            count is zero.
    """
    total_gpus = num_nodes * num_gpus_per_node

    # Fixed factors (4 each) for the optimizer-side states.
    master_params_fac = 4
    variance_fac = 4
    momentum_fac = 4
    grads_fac = 4
    optimizer_fac = variance_fac + momentum_fac

    # Two full-size precision_fac terms that are NOT divided across GPUs.
    unpartitioned_gpu_mem = (precision_fac * total_params) + (precision_fac * total_params)

    if cpu_offload:
        gpu_mem = unpartitioned_gpu_mem
        # CPU has to hold the larger of the model-loading footprint for every
        # GPU process and the offloaded optimizer-side states.
        cpu_mem = total_params * max(params_fac * total_gpus,
                                     master_params_fac + optimizer_fac + grads_fac) * additional_buffer_factor
    else:
        # NOTE(review): precision_fac appears twice here and grads_fac not at
        # all; with the defaults the factor sums to 16. Confirm this breakdown
        # is intended.
        partitioned_fac = precision_fac + optimizer_fac + master_params_fac + precision_fac
        gpu_mem = unpartitioned_gpu_mem + int(partitioned_fac * total_params / total_gpus)
        cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
|
|
|
def estimate_zero2_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision_fac=2,
                                          params_fac=4):
    """Estimate CPU and GPU memory needed for model states under ZeRO stage 2.

    All ``*_fac`` values are per-element sizes (presumably bytes), so the
    results are expressed in the same unit.

    Args:
        total_params: Total number of model parameters.
        num_gpus_per_node: Number of GPUs on each node.
        num_nodes: Number of nodes.
        cpu_offload: Selects the CPU-offload formulas. When False, the
            gradient/optimizer-side term is divided evenly across all
            ``num_nodes * num_gpus_per_node`` GPUs.
        additional_buffer_factor: Safety multiplier applied to the CPU
            estimate only.
        precision_fac: Per-element size of working-precision tensors.
        params_fac: Per-parameter size used in the CPU model-loading term.

    Returns:
        Tuple ``(cpu_mem, gpu_mem)``, each truncated to ``int``.

    Raises:
        ZeroDivisionError: If ``cpu_offload`` is False and the total GPU
            count is zero.
    """
    total_gpus = num_nodes * num_gpus_per_node

    # Fixed factors (4 each) for the optimizer-side states.
    master_params_fac = 4
    variance_fac = 4
    momentum_fac = 4
    grads_fac = 4
    optimizer_fac = variance_fac + momentum_fac

    # One full-size precision_fac term that is NOT divided across GPUs.
    unpartitioned_gpu_mem = precision_fac * total_params

    if cpu_offload:
        gpu_mem = unpartitioned_gpu_mem
        # CPU has to hold the larger of the model-loading footprint for every
        # GPU process and the offloaded optimizer-side states.
        cpu_mem = total_params * max(params_fac * total_gpus,
                                     master_params_fac + optimizer_fac + grads_fac) * additional_buffer_factor
    else:
        # NOTE(review): with the defaults this yields 2P + 20P/N, whereas
        # DeepSpeed's reference ZeRO-2 estimator uses 4P + 16P/N (the second
        # precision_fac below looks suspicious). Confirm intended breakdown.
        partitioned_fac = precision_fac + grads_fac + optimizer_fac + master_params_fac + precision_fac
        gpu_mem = unpartitioned_gpu_mem + int(partitioned_fac * total_params / total_gpus)
        cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
|
|
|
|
|
def estimate_zero3_model_states_mem_needs(total_params,
                                          largest_layer_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          cpu_offload_params=True,
                                          zero_init=True,
                                          additional_buffer_factor=1.5,
                                          precision_fac=2,
                                          params_fac=4):
    """Estimate CPU and GPU memory needed for model states under ZeRO stage 3.

    All ``*_fac`` values are per-element sizes (presumably bytes), so the
    results are expressed in the same unit.

    Args:
        total_params: Total number of model parameters.
        largest_layer_params: Number of parameters in the largest layer.
        num_gpus_per_node: Number of GPUs on each node.
        num_nodes: Number of nodes; CPU-side state is divided by this.
        cpu_offload: Use the optimizer-offload formulas.
        cpu_offload_params: When offloading, also count parameters on CPU
            instead of a per-GPU partition of them.
        zero_init: Selects the CPU formula that omits the
            ``params_fac * num_gpus_per_node`` model-loading term.
        additional_buffer_factor: Safety multiplier applied to the CPU
            estimate only.
        precision_fac: Per-element size of working-precision tensors.
        params_fac: Per-parameter size used in the CPU model-loading term.

    Returns:
        Tuple ``(cpu_mem, gpu_mem, largest_layer_memory)``; the first two are
        truncated to ``int``, the third is returned as computed.

    Raises:
        ZeroDivisionError: If ``num_nodes`` is zero, or a branch that divides
            by the total GPU count is taken with zero GPUs.
    """
    total_gpus = num_nodes * num_gpus_per_node
    gpus_factor = 1 / num_nodes

    # Fixed factors (4 each) for the optimizer-side states.
    master_params_fac = 4
    variance_fac = 4
    momentum_fac = 4
    grads_fac = 4
    optimizer_fac = variance_fac + momentum_fac

    # Optimizer-side states that are offloaded/partitioned in every mode.
    states_fac = master_params_fac + grads_fac + optimizer_fac

    # 2 * precision_fac per element of the largest layer (presumably its
    # gathered params plus grads).
    largest_layer_memory = (2 * precision_fac) * largest_layer_params

    if cpu_offload:
        if cpu_offload_params:
            gpu_mem = largest_layer_memory
            # Offloaded parameters are counted at working precision
            # (precision_fac), not params_fac: params_fac is the model-loading
            # factor used in the max() below. With the defaults this gives the
            # reference 18 = 16 + 2 factor.
            cpu_states_fac = states_fac + precision_fac
        else:
            gpu_mem = largest_layer_memory + int(precision_fac * total_params / total_gpus)
            cpu_states_fac = states_fac

        if zero_init:
            cpu_mem = total_params * cpu_states_fac * gpus_factor * additional_buffer_factor
        else:
            # Without zero_init, CPU must also fit the model-loading footprint
            # of each GPU process on the node.
            cpu_mem = total_params * max(params_fac * num_gpus_per_node,
                                         cpu_states_fac * gpus_factor) * additional_buffer_factor
    else:
        gpu_mem = largest_layer_memory + int((states_fac + precision_fac) * total_params / total_gpus)

        if zero_init:
            cpu_mem = largest_layer_params * params_fac * num_gpus_per_node * additional_buffer_factor
        else:
            cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem), largest_layer_memory
|
|
|
|
|
|