"""Memory estimators for ZeRO stage 1/2/3 model states (params, grads, optimizer states)."""


def get_precision_fac(precision: str):
    # Bytes per parameter for the training-time weights: 2 (fp16) under mixed
    # precision, 4 (fp32) under single precision.
    if precision == "mixed":
        return 2
    elif precision == "single":
        return 4
    else:
        raise ValueError("Precision must be either 'mixed' or 'single'")


def get_params_fac(model_dtype: str):
    # Bytes per parameter for the model weights themselves.
    if model_dtype == "float16":
        return 2
    elif model_dtype == "float32":
        return 4
    else:
        raise ValueError("Model dtype must be either 'float16' or 'float32'")


####################### Zero Redundancy Optimizer (ZeRO) #######################

# Per-parameter byte counts for Adam's fp32 optimizer states.
VARIANCE_FACTOR = 4
MOMENTUM_FACTOR = 4
OPTIMIZER_FACTOR = VARIANCE_FACTOR + MOMENTUM_FACTOR  # Adam optimizer
FP32_GRADS_FACTOR = 4
FP32_PARAM_FACTOR = 4
MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR  # fp32 master copy of the weights


def estimate_zero1_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16"):
    total_gpus = num_nodes * num_gpus_per_node
    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)

    if cpu_offload:
        # Only the training-precision weights stay on the GPU; gradients are
        # offloaded, so no grads term here.
        gpu_mem = precision_fac * total_params
        cpu_mem = total_params * max(params_fac * total_gpus,
                                     MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR
                                     ) * additional_buffer_factor
    else:
        if precision == "mixed":
            # Weights and fp32 gradients are replicated on every GPU; optimizer
            # states and the fp32 master weights are partitioned across GPUs.
            gpu_mem = (precision_fac * total_params) + (FP32_GRADS_FACTOR * total_params) \
                + int((OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus)
        else:
            gpu_mem = (precision_fac * total_params) + (FP32_GRADS_FACTOR * total_params) \
                + int(OPTIMIZER_FACTOR * total_params / total_gpus)
        cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)


def estimate_zero2_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16"):
    total_gpus = num_nodes * num_gpus_per_node
    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)

    if cpu_offload:
        # Negligible GPU memory usage for partitioned gradients.
        gpu_mem = precision_fac * total_params
        cpu_mem = total_params * max(params_fac * total_gpus,
                                     MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR
                                     ) * additional_buffer_factor
    else:
        if precision == "mixed":
            # Weights are replicated; fp32 gradients, optimizer states, and the
            # fp32 master weights are all partitioned across the GPUs.
            gpu_mem = precision_fac * total_params + int(
                (FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus)
        else:
            gpu_mem = precision_fac * total_params + int(
                (FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * total_params / total_gpus)
        cpu_mem = params_fac * total_params * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)


def estimate_zero3_model_states_mem_needs(total_params,
                                          largest_layer_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          cpu_offload_params=True,
                                          zero_init=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16"):
    total_gpus = num_nodes * num_gpus_per_node
    gpus_factor = 1 / num_nodes
    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)
    grads_fac = precision_fac

    # Working memory: the largest layer must be fully materialized (params +
    # grads) on each GPU at some point during the forward/backward pass.
    largest_layer_memory = (grads_fac + precision_fac) * largest_layer_params

    if cpu_offload:
        if cpu_offload_params:
            gpu_mem = largest_layer_memory
            if zero_init:
                cpu_mem = total_params * (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR +
                                          params_fac) * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(
                    params_fac * num_gpus_per_node,
                    (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + params_fac) * gpus_factor
                ) * additional_buffer_factor
        else:
            # No gradient term here: ZeRO-Offload can transfer each parameter's
            # gradient, individually or in small groups, to CPU memory
            # immediately after it is computed.
            gpu_mem = max(largest_layer_memory,
                          int(precision_fac * total_params / total_gpus))
            if zero_init:
                cpu_mem = total_params * (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR +
                                          OPTIMIZER_FACTOR) * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(
                    params_fac * num_gpus_per_node,
                    (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * gpus_factor
                ) * additional_buffer_factor
    else:
        if precision == "mixed":
            gpu_mem = max(
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * largest_layer_params),
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus)
            )
        else:
            gpu_mem = max(
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * largest_layer_params),
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * total_params / total_gpus)
            )
        if zero_init:
            cpu_mem = largest_layer_params * params_fac * num_gpus_per_node * additional_buffer_factor
        else:
            cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem), largest_layer_memory
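

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the model size, largest-layer size,
# cluster shape, and GiB conversion below are assumptions for the example, not
# part of the estimators above.
if __name__ == "__main__":
    GiB = 1024 ** 3
    total_params = 1_500_000_000        # e.g. a ~1.5B-parameter model (assumed)
    largest_layer_params = 50_000_000   # assumed size of the biggest layer

    for name, (cpu, gpu) in {
        "ZeRO-1": estimate_zero1_model_states_mem_needs(
            total_params, num_gpus_per_node=8, num_nodes=1, cpu_offload=False),
        "ZeRO-2": estimate_zero2_model_states_mem_needs(
            total_params, num_gpus_per_node=8, num_nodes=1, cpu_offload=False),
    }.items():
        print(f"{name}: cpu={cpu / GiB:.1f} GiB, gpu={gpu / GiB:.1f} GiB per GPU")

    cpu, gpu, largest = estimate_zero3_model_states_mem_needs(
        total_params, largest_layer_params,
        num_gpus_per_node=8, num_nodes=1, cpu_offload=False)
    print(f"ZeRO-3: cpu={cpu / GiB:.1f} GiB, gpu={gpu / GiB:.1f} GiB per GPU, "
          f"largest-layer working set={largest / GiB:.2f} GiB")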