Fix checkpoint saving when save_steps is fractional (#1643)
Browse files
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
|
|
| 43 |
LossWatchDogCallback,
|
| 44 |
SaveAxolotlConfigtoWandBCallback,
|
| 45 |
SaveBetterTransformerModelCallback,
|
| 46 |
-
SaveModelOnTrainEndCallback,
|
| 47 |
bench_eval_callback_factory,
|
| 48 |
causal_lm_bench_eval_callback_factory,
|
| 49 |
log_prediction_callback_factory,
|
|
@@ -945,7 +945,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
| 945 |
if self.cfg.loss_watchdog_threshold is not None:
|
| 946 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
| 947 |
|
| 948 |
-
callbacks.append(SaveModelOnTrainEndCallback())
|
| 949 |
|
| 950 |
return callbacks
|
| 951 |
|
|
@@ -1431,7 +1431,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
| 1431 |
|
| 1432 |
def get_callbacks(self):
|
| 1433 |
callbacks = super().get_callbacks()
|
| 1434 |
-
callbacks.append(SaveModelOnTrainEndCallback())
|
| 1435 |
|
| 1436 |
return callbacks
|
| 1437 |
|
|
|
|
| 43 |
LossWatchDogCallback,
|
| 44 |
SaveAxolotlConfigtoWandBCallback,
|
| 45 |
SaveBetterTransformerModelCallback,
|
| 46 |
+
SaveModelCallback,
|
| 47 |
bench_eval_callback_factory,
|
| 48 |
causal_lm_bench_eval_callback_factory,
|
| 49 |
log_prediction_callback_factory,
|
|
|
|
| 945 |
if self.cfg.loss_watchdog_threshold is not None:
|
| 946 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
| 947 |
|
| 948 |
+
callbacks.append(SaveModelCallback())
|
| 949 |
|
| 950 |
return callbacks
|
| 951 |
|
|
|
|
| 1431 |
|
| 1432 |
def get_callbacks(self):
|
| 1433 |
callbacks = super().get_callbacks()
|
| 1434 |
+
callbacks.append(SaveModelCallback())
|
| 1435 |
|
| 1436 |
return callbacks
|
| 1437 |
|
src/axolotl/utils/callbacks/__init__.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import logging
|
|
|
|
| 6 |
import os
|
| 7 |
from shutil import copyfile
|
| 8 |
from tempfile import NamedTemporaryFile
|
|
@@ -775,7 +776,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|
| 775 |
return control
|
| 776 |
|
| 777 |
|
| 778 |
-
class SaveModelOnTrainEndCallback(TrainerCallback):
|
| 779 |
"""Callback to save model on train end"""
|
| 780 |
|
| 781 |
def on_step_end( # pylint: disable=unused-argument
|
|
@@ -788,6 +789,13 @@ class SaveModelOnTrainEndCallback(TrainerCallback):
|
|
| 788 |
# Save
|
| 789 |
if state.global_step >= state.max_steps:
|
| 790 |
control.should_save = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 791 |
|
| 792 |
def on_train_end( # pylint: disable=unused-argument
|
| 793 |
self, args, state, control, **kwargs
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import logging
|
| 6 |
+
import math
|
| 7 |
import os
|
| 8 |
from shutil import copyfile
|
| 9 |
from tempfile import NamedTemporaryFile
|
|
|
|
| 776 |
return control
|
| 777 |
|
| 778 |
|
| 779 |
+
class SaveModelCallback(TrainerCallback):
|
| 780 |
"""Callback to save model on train end"""
|
| 781 |
|
| 782 |
def on_step_end( # pylint: disable=unused-argument
|
|
|
|
| 789 |
# Save
|
| 790 |
if state.global_step >= state.max_steps:
|
| 791 |
control.should_save = True
|
| 792 |
+
elif (
|
| 793 |
+
args.save_strategy == IntervalStrategy.STEPS
|
| 794 |
+
and state.save_steps < 1.0
|
| 795 |
+
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
|
| 796 |
+
):
|
| 797 |
+
# workaround to save model on fractional save_steps
|
| 798 |
+
control.should_save = True
|
| 799 |
|
| 800 |
def on_train_end( # pylint: disable=unused-argument
|
| 801 |
self, args, state, control, **kwargs
|