|
import torch |
|
|
|
def get_precision_fac(precision: str):
    """Return the bytes-per-parameter factor for a training precision mode.

    "mixed" trains with 2-byte (half-precision) parameters, "single" with
    4-byte (fp32) parameters.

    Raises:
        ValueError: if ``precision`` is neither "mixed" nor "single".
    """
    if precision not in ("mixed", "single"):
        raise ValueError("Precision must be either 'mixed' or 'single'")
    return 2 if precision == "mixed" else 4
|
|
|
|
|
def get_params_fac(model_dtype: torch.dtype):
    """Return the per-element byte size of the model's parameter dtype.

    torch.float16 parameters take 2 bytes each, torch.float32 takes 4.

    Raises:
        ValueError: if ``model_dtype`` is neither torch.float16 nor
            torch.float32.
    """
    # Compare with == (not a dict lookup) so any dtype-like object that
    # equals these dtypes keeps working exactly as before.
    for dtype, nbytes in ((torch.float16, 2), (torch.float32, 4)):
        if model_dtype == dtype:
            return nbytes
    raise ValueError("Model dtype must be either torch.float16 or torch.float32")
|
|
|
|
|
|
|
|
|
|
|
# Per-parameter byte costs of the various optimizer/model state tensors.
# Each value of 4 corresponds to one fp32 copy (4 bytes per element).

# fp32 second-moment (variance) optimizer state.
# NOTE(review): the momentum/variance split matches an Adam-style
# optimizer — confirm against the optimizer actually used.
VARIANCE_FACTOR = 4

# fp32 first-moment (momentum) optimizer state.
MOMENTUM_FACTOR = 4

# Total optimizer state per parameter: variance + momentum (8 bytes).
OPTIMIZER_FACTOR = VARIANCE_FACTOR + MOMENTUM_FACTOR

# fp32 gradient copy.
FP32_GRADS_FACTOR = 4

# fp32 parameter copy.
FP32_PARAM_FACTOR = 4

# fp32 "master" weights kept alongside low-precision params in mixed training.
MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def estimate_zero1_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype = torch.float16,
                                          ):
    """Estimate memory needs (bytes) for ZeRO stage-1 model states.

    Args:
        total_params: total number of model parameters.
        num_gpus_per_node: GPUs per node.
        num_nodes: number of nodes.
        cpu_offload: if True, optimizer state / master params / fp32 grads
            are assumed to live in host memory.
        additional_buffer_factor: safety margin applied to the CPU estimate.
        precision: "mixed" or "single" (see ``get_precision_fac``).
        model_dtype: torch.float16 or torch.float32 (see ``get_params_fac``).

    Returns:
        Tuple of ints ``(cpu_mem, gpu_mem)``.
    """
    total_gpus = num_nodes * num_gpus_per_node

    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)

    if cpu_offload:
        # Only the training-width parameters stay on the GPU.
        gpu_mem = precision_fac * total_params
        host_bytes_per_param = max(
            params_fac * total_gpus,
            MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR,
        )
        cpu_mem = total_params * host_bytes_per_param * additional_buffer_factor
    else:
        # Stage 1 partitions only the optimizer state (plus the fp32 master
        # copy under mixed precision); gradients are replicated in full.
        sharded_fac = OPTIMIZER_FACTOR
        if precision == "mixed":
            sharded_fac += FP32_PARAM_FACTOR
        gpu_mem = (precision_fac * total_params
                   + FP32_GRADS_FACTOR * total_params
                   + int(sharded_fac * total_params / total_gpus))
        cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
|
|
|
|
|
def estimate_zero2_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype = torch.float16,
                                          ):
    """Estimate memory needs (bytes) for ZeRO stage-2 model states.

    Args:
        total_params: total number of model parameters.
        num_gpus_per_node: GPUs per node.
        num_nodes: number of nodes.
        cpu_offload: if True, optimizer state / master params / fp32 grads
            are assumed to live in host memory.
        additional_buffer_factor: safety margin applied to the CPU estimate.
        precision: "mixed" or "single" (see ``get_precision_fac``).
        model_dtype: torch.float16 or torch.float32 (see ``get_params_fac``).

    Returns:
        Tuple of ints ``(cpu_mem, gpu_mem)``.
    """
    total_gpus = num_nodes * num_gpus_per_node

    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)

    if cpu_offload:
        # Only the training-width parameters stay on the GPU.
        gpu_mem = precision_fac * total_params
        host_bytes_per_param = max(
            params_fac * total_gpus,
            MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR,
        )
        cpu_mem = total_params * host_bytes_per_param * additional_buffer_factor
    else:
        # Stage 2 partitions gradients and optimizer state (plus the fp32
        # master copy under mixed precision) across all ranks.
        sharded_fac = FP32_GRADS_FACTOR + OPTIMIZER_FACTOR
        if precision == "mixed":
            sharded_fac += FP32_PARAM_FACTOR
        gpu_mem = precision_fac * total_params + int(sharded_fac * total_params / total_gpus)
        cpu_mem = params_fac * total_params * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
|
|
|
|
|
def estimate_zero3_model_states_mem_needs(total_params,
                                          largest_layer_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          cpu_offload_params=True,
                                          zero_init=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype = torch.float16,
                                          ):
    """Estimate memory needs (bytes) for ZeRO stage-3 model states.

    Args:
        total_params: total number of model parameters.
        largest_layer_params: parameter count of the single largest layer;
            drives the transient per-layer working memory on GPU.
        num_gpus_per_node: GPUs per node.
        num_nodes: number of nodes.
        cpu_offload: offload optimizer state / master params / fp32 grads
            to host memory.
        cpu_offload_params: additionally offload the parameters themselves
            (only consulted when ``cpu_offload`` is True).
        zero_init: NOTE(review): presumably whether the model was built
            under DeepSpeed's ``zero.Init`` so params are partitioned from
            construction — confirm; it selects the smaller CPU estimate.
        additional_buffer_factor: safety margin applied to the CPU estimate.
        precision: "mixed" or "single" (see ``get_precision_fac``).
        model_dtype: torch.float16 or torch.float32 (see ``get_params_fac``).

    Returns:
        Tuple ``(cpu_mem, gpu_mem, largest_layer_memory)``; the first two
        are truncated to int, ``largest_layer_memory`` is returned as-is.
    """
    total_gpus = num_nodes * num_gpus_per_node
    # CPU-side estimates below scale with 1/num_nodes — presumably because
    # host RAM is shared per node; TODO confirm.
    gpus_factor = 1 / num_nodes

    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)
    # Gradients are kept at the same width as the training parameters.
    grads_fac = precision_fac

    # Transient GPU working set: params + grads of the largest single layer,
    # which must be fully materialized during its forward/backward pass.
    largest_layer_memory = (grads_fac + precision_fac) * largest_layer_params

    if cpu_offload:
        if cpu_offload_params:
            # Everything but the transient largest-layer buffer is on host.
            gpu_mem = largest_layer_memory
            if zero_init:
                cpu_mem = total_params * (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + params_fac) * gpus_factor * additional_buffer_factor
            else:
                # Without zero_init, account for each local rank holding a
                # full params copy: hence max() with params_fac * GPUs/node.
                cpu_mem = total_params * max(params_fac * num_gpus_per_node, (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + params_fac) * gpus_factor) * additional_buffer_factor

        else:
            # Params remain sharded across GPUs; GPU cost is whichever is
            # larger — the transient layer buffer or this rank's shard.
            gpu_mem = max(
                largest_layer_memory,
                int((precision_fac) * total_params / total_gpus)
            )

            if zero_init:
                cpu_mem = total_params * (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(params_fac * num_gpus_per_node, (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * gpus_factor) * additional_buffer_factor
    else:
        # No offload: all sharded state lives on GPU alongside the transient
        # largest-layer buffer; take whichever term dominates. Mixed
        # precision adds the fp32 master-param copy to the sharded state.
        if precision == "mixed":
            gpu_mem = max(
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * largest_layer_params),
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus)
            )
        else:
            gpu_mem = max(
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * largest_layer_params),
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * total_params / total_gpus)
            )

        if zero_init:
            # With zero_init only the largest layer is assumed fully present
            # on the host at any time during construction.
            cpu_mem = largest_layer_params * params_fac * num_gpus_per_node * additional_buffer_factor
        else:
            cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem), largest_layer_memory
|
|
|
|
|
|