def get_precision_fac(precision: str):
    """Bytes per parameter held in the training dtype: 2 for fp16 mixed precision, 4 for fp32."""
    if precision == "mixed":
        return 2
    elif precision == "single":
        return 4
    else:
        raise ValueError("Precision must be either 'mixed' or 'single'")


def get_params_fac(model_dtype: str):
    """Bytes per parameter of the stored model weights."""
    if model_dtype == "float16":
        return 2
    elif model_dtype == "float32":
        return 4
    else:
        raise ValueError("Model dtype must be either 'float16' or 'float32'")

####################### Zero Redundancy Optimizer (ZeRO) #######################

# Per-parameter byte counts for the fp32 optimizer state and fp32 copies.
VARIANCE_FACTOR = 4  # fp32 second moment (variance) kept by Adam
MOMENTUM_FACTOR = 4  # fp32 first moment (momentum) kept by Adam
OPTIMIZER_FACTOR = VARIANCE_FACTOR + MOMENTUM_FACTOR  # Adam optimizer state
FP32_GRADS_FACTOR = 4  # fp32 gradients
FP32_PARAM_FACTOR = 4  # fp32 parameters
MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR  # fp32 master weights in mixed precision
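
# Sanity check of the accounting above (a sketch; the canonical numbers come
# from the ZeRO paper): with mixed-precision Adam each parameter costs 2 bytes
# (fp16 weight) + 2 bytes (fp16 gradient) + 4 bytes (fp32 master weight) +
# 4 bytes (momentum) + 4 bytes (variance) = 16 bytes in total, which is what
# the estimators below partition across GPUs and/or offload to CPU memory.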

def estimate_zero1_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16",
                                          ):
    total_gpus = num_nodes * num_gpus_per_node
    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)

    if cpu_offload:
        # Only the model weights stay on GPU; master weights, optimizer state,
        # and fp32 gradients live in CPU memory.
        gpu_mem = precision_fac * total_params  # + (grads_fac * total_params)
        cpu_mem = total_params * max(params_fac * total_gpus,
                                     MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR
                                     ) * additional_buffer_factor
    else:
        if precision == "mixed":
            # Weights and gradients are replicated on every GPU; optimizer state
            # and fp32 master weights are partitioned across all GPUs.
            gpu_mem = (precision_fac * total_params) + (FP32_GRADS_FACTOR * total_params) + int(
                (OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus)
        else:
            gpu_mem = (precision_fac * total_params) + (FP32_GRADS_FACTOR * total_params) + int(
                OPTIMIZER_FACTOR * total_params / total_gpus)
        cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
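
# Worked example (illustrative numbers, not from the source): a 1e9-parameter
# model in mixed precision with fp16 weights on one node with 8 GPUs and CPU
# offload enabled.
#   GPU: 2 bytes/param * 1e9                = 2.0e9 bytes (fp16 weights only)
#   CPU: 1e9 * max(2 * 8, 4 + 8 + 4) * 1.5  = 24.0e9 bytes
# i.e. estimate_zero1_model_states_mem_needs(1e9, num_gpus_per_node=8)
# returns (24_000_000_000, 2_000_000_000).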

def estimate_zero2_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16",
                                          ):
    total_gpus = num_nodes * num_gpus_per_node
    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)

    if cpu_offload:
        gpu_mem = precision_fac * total_params  # negligible memory usage for partitioned gradients
        cpu_mem = total_params * max(params_fac * total_gpus,
                                     MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR
                                     ) * additional_buffer_factor
    else:
        if precision == "mixed":
            # Gradients, optimizer state, and fp32 master weights are all
            # partitioned across GPUs; only the weights are replicated.
            gpu_mem = precision_fac * total_params + int(
                (FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus)
        else:
            gpu_mem = precision_fac * total_params + int(
                (FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * total_params / total_gpus)
        cpu_mem = params_fac * total_params * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
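
# Worked example (illustrative, not from the source): the same 1e9-parameter
# model on 8 GPUs without offload. ZeRO-2 additionally partitions gradients,
# so the per-GPU cost drops from ZeRO-1's 2 + 4 + (8 + 4)/8 = 7.5 bytes/param
# to 2 + (4 + 8 + 4)/8 = 4 bytes/param:
# estimate_zero2_model_states_mem_needs(1e9, num_gpus_per_node=8, cpu_offload=False)
# returns (24_000_000_000, 4_000_000_000).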

def estimate_zero3_model_states_mem_needs(total_params,
                                          largest_layer_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          cpu_offload_params=True,
                                          zero_init=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16",
                                          ):
    total_gpus = num_nodes * num_gpus_per_node
    gpus_factor = 1 / num_nodes
    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)
    grads_fac = precision_fac

    # Transient buffer needed to gather and compute the largest single layer.
    largest_layer_memory = (grads_fac + precision_fac) * largest_layer_params

    if cpu_offload:
        if cpu_offload_params:
            gpu_mem = largest_layer_memory
            if zero_init:
                cpu_mem = total_params * (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR +
                                          params_fac) * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(params_fac * num_gpus_per_node,
                                             (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR +
                                              params_fac) * gpus_factor) * additional_buffer_factor
        else:
            gpu_mem = max(
                largest_layer_memory,
                # No need to keep full gradients on GPU: ZeRO-Offload can transfer
                # gradients to CPU memory individually or in small groups as soon
                # as they are computed.
                int(precision_fac * total_params / total_gpus),
            )
            if zero_init:
                cpu_mem = total_params * (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR +
                                          OPTIMIZER_FACTOR) * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(params_fac * num_gpus_per_node,
                                             (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) *
                                             gpus_factor) * additional_buffer_factor
    else:
        if precision == "mixed":
            # Everything is partitioned across GPUs; each GPU also needs enough
            # room to gather the largest single layer.
            gpu_mem = max(
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * largest_layer_params),
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus),
            )
        else:
            gpu_mem = max(
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * largest_layer_params),
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * total_params / total_gpus),
            )
        if zero_init:
            cpu_mem = largest_layer_params * params_fac * num_gpus_per_node * additional_buffer_factor
        else:
            cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem), largest_layer_memory
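
# Minimal usage sketch (illustrative model sizes, not from the source): print
# the estimated CPU/GPU model-state memory per ZeRO stage on one 8-GPU node.
if __name__ == "__main__":
    GB = 10**9
    total_params = 1_500_000_000        # hypothetical 1.5B-parameter model
    largest_layer_params = 100_000_000  # hypothetical largest single layer

    cpu, gpu = estimate_zero1_model_states_mem_needs(total_params, num_gpus_per_node=8, cpu_offload=False)
    print(f"ZeRO-1: {cpu / GB:.2f} GB CPU, {gpu / GB:.2f} GB per GPU")

    cpu, gpu = estimate_zero2_model_states_mem_needs(total_params, num_gpus_per_node=8, cpu_offload=False)
    print(f"ZeRO-2: {cpu / GB:.2f} GB CPU, {gpu / GB:.2f} GB per GPU")

    cpu, gpu, largest = estimate_zero3_model_states_mem_needs(total_params, largest_layer_params,
                                                              num_gpus_per_node=8, cpu_offload=False)
    print(f"ZeRO-3: {cpu / GB:.2f} GB CPU, {gpu / GB:.2f} GB per GPU "
          f"(largest-layer buffer: {largest / GB:.2f} GB)")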