Jan Philipp Harries
Jan Philipp Harries
committed on
Added advanced DDP args (#515)
Browse files
* add ddp_config
* add advanced ddp config
* add ddp_config
* add advanced ddp config
---------
Co-authored-by: Jan Philipp Harries <[email protected]>
- README.md +5 -0
- src/axolotl/utils/trainer.py +9 -0
README.md
CHANGED
|
@@ -623,6 +623,11 @@ fsdp_config:
|
|
| 623 |
# Deepspeed config path
|
| 624 |
deepspeed:
|
| 625 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
# Path to torch distx for optim 'adamw_anyprecision'
|
| 627 |
torchdistx_path:
|
| 628 |
|
|
|
|
| 623 |
# Deepspeed config path
|
| 624 |
deepspeed:
|
| 625 |
|
| 626 |
+
# Advanced DDP Arguments
|
| 627 |
+
ddp_timeout:
|
| 628 |
+
ddp_bucket_cap_mb:
|
| 629 |
+
ddp_broadcast_buffers:
|
| 630 |
+
|
| 631 |
# Path to torch distx for optim 'adamw_anyprecision'
|
| 632 |
torchdistx_path:
|
| 633 |
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -579,6 +579,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 579 |
if cfg.bench_dataset:
|
| 580 |
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
| 581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
| 583 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
| 584 |
max_seq_length=cfg.sequence_len,
|
|
|
|
| 579 |
if cfg.bench_dataset:
|
| 580 |
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
| 581 |
|
| 582 |
+
# DDP Config
|
| 583 |
+
if cfg.ddp_timeout:
|
| 584 |
+
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
|
| 585 |
+
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
|
| 586 |
+
if cfg.ddp_bucket_cap_mb:
|
| 587 |
+
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
|
| 588 |
+
if cfg.ddp_broadcast_buffers is not None:
|
| 589 |
+
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
|
| 590 |
+
|
| 591 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
| 592 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
| 593 |
max_seq_length=cfg.sequence_len,
|