match up gradient checkpointing when using lora w config
src/axolotl/utils/models.py
```diff
@@ -305,7 +305,9 @@ def load_model(
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
-        model = prepare_model_for_kbit_training(model)
+        model = prepare_model_for_kbit_training(
+            model, use_gradient_checkpointing=cfg.gradient_checkpointing
+        )

         model, lora_config = load_adapter(model, cfg, adapter)

```
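For context, `prepare_model_for_kbit_training` comes from the PEFT library, and its `use_gradient_checkpointing` argument controls whether PEFT also enables gradient checkpointing on the model as part of the k-bit training prep. The sketch below (not axolotl's actual code) shows the call pattern the diff adopts; the `Cfg` stand-in and the model name are illustrative assumptions, not taken from the repo.

```python
# Minimal sketch: passing the config's gradient-checkpointing flag
# through to PEFT, as the diff above does.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training


class Cfg:
    # Hypothetical stand-in for axolotl's cfg object; only the field used here.
    gradient_checkpointing = True


cfg = Cfg()

# Illustrative 4-bit load; the model name is an example, not from the diff.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float16,
)

# When use_gradient_checkpointing is True, PEFT calls
# model.gradient_checkpointing_enable() in addition to the usual k-bit prep.
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
```

Previously the call relied on PEFT's default for that argument regardless of what the user configured; passing `cfg.gradient_checkpointing` keeps the PEFT-side setting in sync with the trainer's setting, which is what the commit title refers to.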