import os

from transformers import (
    Seq2SeqTrainer,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
class SavePeftModelCallback(TrainerCallback):
    """Callback that saves the PEFT adapter weights whenever the Trainer checkpoints.

    The Trainer's default save writes the full model state; this hook additionally
    (or, with PEFT models, usefully) persists just the adapter under an
    ``adapter_model`` subdirectory of the current checkpoint folder, so the
    lightweight adapter can be reloaded without the full base-model weights.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Write the adapter to <output_dir>/checkpoint-<global_step>/adapter_model.

        ``kwargs["model"]`` is the model the Trainer passes to save callbacks;
        for a PEFT model, ``save_pretrained`` emits only the adapter weights.
        Returns ``control`` unchanged so the save flow proceeds normally.
        """
        step_dir = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        adapter_dir = os.path.join(args.output_dir, step_dir, "adapter_model")
        kwargs["model"].save_pretrained(adapter_dir)
        return control