#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union

import yaml

from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType
from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION


hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
cache_dir = os.path.join(hf_cache_home, "accelerate")
# The two defaults must be distinct files, otherwise the JSON fallback below can never trigger.
default_json_config_file = os.path.join(cache_dir, "default_config.json")
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")

# For backward compatibility: the default config is the json one if it's the only existing file.
if os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file):
    default_config_file = default_yaml_config_file
else:
    default_config_file = default_json_config_file


def load_config_from_file(config_file):
    """
    Load a configuration file and return it as the matching config dataclass.

    The file is parsed as JSON when its name ends with `.json` and as YAML otherwise. A config whose
    `compute_environment` is missing or equal to `ComputeEnvironment.LOCAL_MACHINE` is loaded as a
    `ClusterConfig`; anything else is loaded as a `SageMakerConfig`.

    Args:
        config_file (`str` or `None`):
            Path to the configuration file. When `None`, `default_config_file` is used.

    Returns:
        `ClusterConfig` or `SageMakerConfig`: The loaded configuration.

    Raises:
        FileNotFoundError: If an explicit `config_file` is passed and does not exist.
    """
    if config_file is None:
        config_file = default_config_file
    elif not os.path.isfile(config_file):
        raise FileNotFoundError(
            f"The passed configuration file `{config_file}` does not exist. "
            "Please pass an existing file to `accelerate launch`, or use the default one "
            "created through `accelerate config` and run `accelerate launch` "
            "without the `--config_file` argument."
        )

    is_json = config_file.endswith(".json")
    # First pass: only peek at `compute_environment` to decide which dataclass to build.
    with open(config_file, "r", encoding="utf-8") as f:
        raw_config = json.load(f) if is_json else yaml.safe_load(f)

    # NOTE(review): this compares the raw file value against the enum member, so it presumably relies on
    # `ComputeEnvironment` comparing equal to its string value -- confirm against its definition.
    environment = raw_config.get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
    config_class = ClusterConfig if environment == ComputeEnvironment.LOCAL_MACHINE else SageMakerConfig

    # Second pass: the classmethods re-read the file and apply defaults / legacy-key upgrades.
    if is_json:
        return config_class.from_json_file(json_file=config_file)
    return config_class.from_yaml_file(yaml_file=config_file)


@dataclass
class BaseConfig:
    """
    Fields and (de)serialization helpers shared by every configuration class in this module.
    """

    compute_environment: ComputeEnvironment
    distributed_type: Union[DistributedType, SageMakerDistributedType]
    mixed_precision: str
    use_cpu: bool

    def to_dict(self):
        """
        Return a serializable `dict` view of this config.

        Enum members are replaced by their `.value`, and entries that are `None` or an empty `dict` are
        dropped. A new dictionary is built so that the instance's own attributes are left untouched.
        """
        result = {}
        for key, value in self.__dict__.items():
            # For serialization, it's best to convert Enums to strings (or their underlying value type).
            if isinstance(value, Enum):
                value = value.value
            # Empty sub-configs carry no information, so they are omitted just like `None`.
            if isinstance(value, dict) and not value:
                value = None
            if value is not None:
                result[key] = value
        return result

    @staticmethod
    def _upgrade_legacy_keys(config_dict):
        """
        Fill in missing defaults and convert old-format keys of `config_dict` in place, then return it.
        """
        if "compute_environment" not in config_dict:
            config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
        if "mixed_precision" not in config_dict:
            config_dict["mixed_precision"] = "fp16" if config_dict.get("fp16") else None
        # Convert the config to the new format: the boolean `fp16` key was replaced by `mixed_precision`.
        config_dict.pop("fp16", None)
        if "dynamo_backend" in config_dict:
            # Convert the config to the new format: `dynamo_backend` now lives inside `dynamo_config`.
            dynamo_backend = config_dict.pop("dynamo_backend")
            config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
        if "use_cpu" not in config_dict:
            config_dict["use_cpu"] = False
        return config_dict

    @classmethod
    def from_json_file(cls, json_file=None):
        """
        Build a config from a JSON file (`default_json_config_file` when `json_file` is `None`).
        """
        json_file = default_json_config_file if json_file is None else json_file
        with open(json_file, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls(**cls._upgrade_legacy_keys(config_dict))

    def to_json_file(self, json_file):
        """
        Write `self.to_dict()` to `json_file` as indented, key-sorted JSON followed by a newline.
        """
        with open(json_file, "w", encoding="utf-8") as f:
            content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
            f.write(content)

    @classmethod
    def from_yaml_file(cls, yaml_file=None):
        """
        Build a config from a YAML file (`default_yaml_config_file` when `yaml_file` is `None`).
        """
        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
        with open(yaml_file, "r", encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        return cls(**cls._upgrade_legacy_keys(config_dict))

    def to_yaml_file(self, yaml_file):
        """
        Write `self.to_dict()` to `yaml_file` as YAML.
        """
        with open(yaml_file, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f)

    def __post_init__(self):
        # Values read from a config file arrive as plain strings; coerce them to the matching enums.
        if isinstance(self.compute_environment, str):
            self.compute_environment = ComputeEnvironment(self.compute_environment)
        if isinstance(self.distributed_type, str):
            if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
                self.distributed_type = SageMakerDistributedType(self.distributed_type)
            else:
                self.distributed_type = DistributedType(self.distributed_type)
        # `dynamo_config` is not a field of `BaseConfig` itself; it is declared by the subclasses
        # defined in this file (`ClusterConfig` and `SageMakerConfig`).
        if self.dynamo_config is None:
            self.dynamo_config = {}


@dataclass
class ClusterConfig(BaseConfig):
    """
    Configuration used when `compute_environment` is `ComputeEnvironment.LOCAL_MACHINE`.
    """

    num_processes: int
    machine_rank: int = 0
    num_machines: int = 1
    gpu_ids: Optional[str] = None
    main_process_ip: Optional[str] = None
    main_process_port: Optional[int] = None
    rdzv_backend: Optional[str] = "static"
    same_network: Optional[bool] = False
    main_training_function: str = "main"

    # args for deepspeed_plugin
    deepspeed_config: Optional[dict] = None
    # args for fsdp
    fsdp_config: Optional[dict] = None
    # args for megatron_lm
    megatron_lm_config: Optional[dict] = None
    # args for TPU
    downcast_bf16: bool = False

    # args for TPU pods
    tpu_name: Optional[str] = None
    tpu_zone: Optional[str] = None
    tpu_use_cluster: bool = False
    tpu_use_sudo: bool = False
    command_file: Optional[str] = None
    commands: Optional[List[str]] = None
    tpu_vm: Optional[List[str]] = None
    tpu_env: Optional[List[str]] = None

    # args for dynamo
    dynamo_config: Optional[dict] = None

    def __post_init__(self):
        # Replace missing sub-configs with fresh empty dicts (one per instance, never shared).
        if self.deepspeed_config is None:
            self.deepspeed_config = {}
        if self.fsdp_config is None:
            self.fsdp_config = {}
        if self.megatron_lm_config is None:
            self.megatron_lm_config = {}
        return super().__post_init__()


@dataclass
class SageMakerConfig(BaseConfig):
    """
    Configuration used when `compute_environment` is not `ComputeEnvironment.LOCAL_MACHINE`.
    """

    ec2_instance_type: str
    iam_role_name: str
    image_uri: Optional[str] = None
    profile: Optional[str] = None
    region: str = "us-east-1"
    num_machines: int = 1
    gpu_ids: str = "all"
    # NOTE: this f-string is evaluated once, when the class body runs, using the class-level default of
    # `num_machines` above (1). It does not follow the `num_machines` value of an individual instance.
    base_job_name: str = f"accelerate-sagemaker-{num_machines}"
    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
    py_version: str = SAGEMAKER_PYTHON_VERSION
    sagemaker_inputs_file: Optional[str] = None
    sagemaker_metrics_file: Optional[str] = None
    additional_args: Optional[dict] = None
    dynamo_config: Optional[dict] = None