def create_deepspeed_config(args):
    """Build a DeepSpeed runtime config dict from parsed command-line args."""
    ds_config = {
        "steps_per_print": 1000,
        # DeepSpeed derives the per-GPU micro-batch size from
        # train_batch_size / (gradient_accumulation_steps * world_size).
        "train_batch_size": args.global_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,

"optimizer": { |
|
"type": "Adam", |
|
"adam_w_mode": True, |
|
"params": { |
|
"lr": args.lr, |
|
"weight_decay": args.weight_decay, |
|
"bias_correction": True, |
|
"betas": [ |
|
args.beta1, |
|
args.beta2 |
|
], |
|
} |
|
}, |
|
"fp16": { |
|
"enabled": args.mixed_precision == 'fp16', |
|
"loss_scale": 0, |
|
"initial_scale_power": 16, |
|
"loss_scale_window": 1000, |
|
"hysteresis": 2, |
|
"min_loss_scale": 1 |
|
}, |
|
"bf16": { |
|
"enabled": args.mixed_precision == 'bf16', |
|
}, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"zero_allow_untested_optimizer": True |
|
} |
|
|
|
    if args.clip_grad is not None:
        ds_config["gradient_clipping"] = args.clip_grad

    # ZeRO configuration: stage 1 shards optimizer states, stage 2 additionally
    # shards gradients, and stage 3 additionally shards the parameters.
    if args.zero_stage in (0, 1, 2, 3):
        zero_config = {
            "stage": args.zero_stage,
            "contiguous_gradients": True,
            "overlap_comm": True,
        }
        if args.zero_stage == 1:
            zero_config["reduce_bucket_size"] = 5e8
        elif args.zero_stage == 2:
            zero_config.update({
                "reduce_scatter": True,
                "reduce_bucket_size": 5e8,
                "allgather_bucket_size": 5e8,
            })
        elif args.zero_stage == 3:
            zero_config.update({
                "reduce_bucket_size": 5e8,
                "stage3_prefetch_bucket_size": 5e8,
                "stage3_param_persistence_threshold": 1e6,
                "stage3_max_live_parameters": 1e9,
                "stage3_max_reuse_distance": 1e9,
                # Reassemble full 16-bit weights on rank 0 when saving.
                "stage3_gather_16bit_weights_on_model_save": True,
            })
        ds_config["zero_optimization"] = zero_config

    return ds_config
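
# Usage sketch (not part of the original file): the dict built above can be
# passed straight to deepspeed.initialize via its `config` argument. The
# helper name is hypothetical, and `model`/`args` are assumed to come from
# the surrounding training script.
def initialize_with_deepspeed(model, args):
    import deepspeed

    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        config=create_deepspeed_config(args),
    )
    return model_engine, optimizer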
|
|