hotfix for capabilities loading (#1331)
Browse files
src/axolotl/cli/__init__.py
CHANGED
|
@@ -30,7 +30,6 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
|
| 30 |
from axolotl.logging_config import configure_logging
|
| 31 |
from axolotl.train import TrainDatasetMeta
|
| 32 |
from axolotl.utils.config import (
|
| 33 |
-
GPUCapabilities,
|
| 34 |
normalize_cfg_datasets,
|
| 35 |
normalize_config,
|
| 36 |
validate_config,
|
|
@@ -350,14 +349,15 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|
| 350 |
except: # pylint: disable=bare-except # noqa: E722
|
| 351 |
gpu_version = None
|
| 352 |
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
|
|
|
|
|
|
|
|
|
| 357 |
)
|
| 358 |
|
| 359 |
-
cfg = validate_config(cfg, capabilities=capabilities)
|
| 360 |
-
|
| 361 |
prepare_optim_env(cfg)
|
| 362 |
|
| 363 |
normalize_config(cfg)
|
|
|
|
| 30 |
from axolotl.logging_config import configure_logging
|
| 31 |
from axolotl.train import TrainDatasetMeta
|
| 32 |
from axolotl.utils.config import (
|
|
|
|
| 33 |
normalize_cfg_datasets,
|
| 34 |
normalize_config,
|
| 35 |
validate_config,
|
|
|
|
| 349 |
except: # pylint: disable=bare-except # noqa: E722
|
| 350 |
gpu_version = None
|
| 351 |
|
| 352 |
+
cfg = validate_config(
|
| 353 |
+
cfg,
|
| 354 |
+
capabilities={
|
| 355 |
+
"bf16": is_torch_bf16_gpu_available(),
|
| 356 |
+
"n_gpu": os.environ.get("WORLD_SIZE", 1),
|
| 357 |
+
"compute_capability": gpu_version,
|
| 358 |
+
},
|
| 359 |
)
|
| 360 |
|
|
|
|
|
|
|
| 361 |
prepare_optim_env(cfg)
|
| 362 |
|
| 363 |
normalize_config(cfg)
|
src/axolotl/utils/config/__init__.py
CHANGED
|
@@ -13,7 +13,6 @@ from axolotl.utils.config.models.input.v0_4_1 import (
|
|
| 13 |
AxolotlConfigWCapabilities,
|
| 14 |
AxolotlInputConfig,
|
| 15 |
)
|
| 16 |
-
from axolotl.utils.config.models.internals import GPUCapabilities
|
| 17 |
from axolotl.utils.dict import DictDefault
|
| 18 |
from axolotl.utils.models import load_model_config
|
| 19 |
|
|
@@ -197,7 +196,7 @@ def normalize_cfg_datasets(cfg):
|
|
| 197 |
cfg.datasets[idx].conversation = "chatml"
|
| 198 |
|
| 199 |
|
| 200 |
-
def validate_config(cfg: DictDefault, capabilities: Optional[
|
| 201 |
if capabilities:
|
| 202 |
return DictDefault(
|
| 203 |
dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
|
|
|
|
| 13 |
AxolotlConfigWCapabilities,
|
| 14 |
AxolotlInputConfig,
|
| 15 |
)
|
|
|
|
| 16 |
from axolotl.utils.dict import DictDefault
|
| 17 |
from axolotl.utils.models import load_model_config
|
| 18 |
|
|
|
|
| 196 |
cfg.datasets[idx].conversation = "chatml"
|
| 197 |
|
| 198 |
|
| 199 |
+
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
| 200 |
if capabilities:
|
| 201 |
return DictDefault(
|
| 202 |
dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
|