Spaces:
Running
Running
#!/usr/bin/env python | |
# Copyright 2021 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import json | |
import os | |
from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES | |
from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType | |
from ...utils.imports import is_boto3_available | |
from .config_args import SageMakerConfig | |
from .config_utils import ( | |
DYNAMO_BACKENDS, | |
_ask_field, | |
_ask_options, | |
_convert_dynamo_backend, | |
_convert_mixed_precision, | |
_convert_sagemaker_distributed_mode, | |
_convert_yes_no_to_bool, | |
) | |
if is_boto3_available(): | |
import boto3 # noqa: F401 | |
def _create_iam_role_for_sagemaker(role_name): | |
iam_client = boto3.client("iam") | |
sagemaker_trust_policy = { | |
"Version": "2012-10-17", | |
"Statement": [ | |
{"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"} | |
], | |
} | |
try: | |
# create the role, associated with the chosen trust policy | |
iam_client.create_role( | |
RoleName=role_name, AssumeRolePolicyDocument=json.dumps(sagemaker_trust_policy, indent=2) | |
) | |
policy_document = { | |
"Version": "2012-10-17", | |
"Statement": [ | |
{ | |
"Effect": "Allow", | |
"Action": [ | |
"sagemaker:*", | |
"ecr:GetDownloadUrlForLayer", | |
"ecr:BatchGetImage", | |
"ecr:BatchCheckLayerAvailability", | |
"ecr:GetAuthorizationToken", | |
"cloudwatch:PutMetricData", | |
"cloudwatch:GetMetricData", | |
"cloudwatch:GetMetricStatistics", | |
"cloudwatch:ListMetrics", | |
"logs:CreateLogGroup", | |
"logs:CreateLogStream", | |
"logs:DescribeLogStreams", | |
"logs:PutLogEvents", | |
"logs:GetLogEvents", | |
"s3:CreateBucket", | |
"s3:ListBucket", | |
"s3:GetBucketLocation", | |
"s3:GetObject", | |
"s3:PutObject", | |
], | |
"Resource": "*", | |
} | |
], | |
} | |
# attach policy to role | |
iam_client.put_role_policy( | |
RoleName=role_name, | |
PolicyName=f"{role_name}_policy_permission", | |
PolicyDocument=json.dumps(policy_document, indent=2), | |
) | |
except iam_client.exceptions.EntityAlreadyExistsException: | |
print(f"role {role_name} already exists. Using existing one") | |
def _get_iam_role_arn(role_name): | |
iam_client = boto3.client("iam") | |
return iam_client.get_role(RoleName=role_name)["Role"]["Arn"] | |
def get_sagemaker_input(): | |
credentials_configuration = _ask_options( | |
"How do you want to authorize?", | |
["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "], | |
int, | |
) | |
aws_profile = None | |
if credentials_configuration == 0: | |
aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default") | |
os.environ["AWS_PROFILE"] = aws_profile | |
else: | |
print( | |
"Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with," | |
"`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`" | |
) | |
aws_access_key_id = _ask_field("AWS Access Key ID: ") | |
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id | |
aws_secret_access_key = _ask_field("AWS Secret Access Key: ") | |
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key | |
aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1") | |
os.environ["AWS_DEFAULT_REGION"] = aws_region | |
role_management = _ask_options( | |
"Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?", | |
["Provide IAM Role name", "Create new IAM role using credentials"], | |
int, | |
) | |
if role_management == 0: | |
iam_role_name = _ask_field("Enter your IAM role name: ") | |
else: | |
iam_role_name = "accelerate_sagemaker_execution_role" | |
print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials') | |
_create_iam_role_for_sagemaker(iam_role_name) | |
is_custom_docker_image = _ask_field( | |
"Do you want to use custom Docker image? [yes/NO]: ", | |
_convert_yes_no_to_bool, | |
default=False, | |
error_message="Please enter yes or no.", | |
) | |
docker_image = None | |
if is_custom_docker_image: | |
docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower()) | |
is_sagemaker_inputs_enabled = _ask_field( | |
"Do you want to provide SageMaker input channels with data locations? [yes/NO]: ", | |
_convert_yes_no_to_bool, | |
default=False, | |
error_message="Please enter yes or no.", | |
) | |
sagemaker_inputs_file = None | |
if is_sagemaker_inputs_enabled: | |
sagemaker_inputs_file = _ask_field( | |
"Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ", | |
lambda x: str(x).lower(), | |
) | |
is_sagemaker_metrics_enabled = _ask_field( | |
"Do you want to enable SageMaker metrics? [yes/NO]: ", | |
_convert_yes_no_to_bool, | |
default=False, | |
error_message="Please enter yes or no.", | |
) | |
sagemaker_metrics_file = None | |
if is_sagemaker_metrics_enabled: | |
sagemaker_metrics_file = _ask_field( | |
"Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ", | |
lambda x: str(x).lower(), | |
) | |
distributed_type = _ask_options( | |
"What is the distributed mode?", | |
["No distributed training", "Data parallelism"], | |
_convert_sagemaker_distributed_mode, | |
) | |
dynamo_config = {} | |
use_dynamo = _ask_field( | |
"Do you wish to optimize your script with torch dynamo?[yes/NO]:", | |
_convert_yes_no_to_bool, | |
default=False, | |
error_message="Please enter yes or no.", | |
) | |
if use_dynamo: | |
prefix = "dynamo_" | |
dynamo_config[prefix + "backend"] = _ask_options( | |
"Which dynamo backend would you like to use?", | |
[x.lower() for x in DYNAMO_BACKENDS], | |
_convert_dynamo_backend, | |
default=2, | |
) | |
use_custom_options = _ask_field( | |
"Do you want to customize the defaults sent to torch.compile? [yes/NO]: ", | |
_convert_yes_no_to_bool, | |
default=False, | |
error_message="Please enter yes or no.", | |
) | |
if use_custom_options: | |
dynamo_config[prefix + "mode"] = _ask_options( | |
"Which mode do you want to use?", | |
TORCH_DYNAMO_MODES, | |
lambda x: TORCH_DYNAMO_MODES[int(x)], | |
default="default", | |
) | |
dynamo_config[prefix + "use_fullgraph"] = _ask_field( | |
"Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ", | |
_convert_yes_no_to_bool, | |
default=False, | |
error_message="Please enter yes or no.", | |
) | |
dynamo_config[prefix + "use_dynamic"] = _ask_field( | |
"Do you want to enable dynamic shape tracing? [yes/NO]: ", | |
_convert_yes_no_to_bool, | |
default=False, | |
error_message="Please enter yes or no.", | |
) | |
ec2_instance_query = "Which EC2 instance type you want to use for your training?" | |
if distributed_type != SageMakerDistributedType.NO: | |
ec2_instance_type = _ask_options( | |
ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)] | |
) | |
else: | |
ec2_instance_query += "? [ml.p3.2xlarge]:" | |
ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge") | |
debug = False | |
if distributed_type != SageMakerDistributedType.NO: | |
debug = _ask_field( | |
"Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ", | |
_convert_yes_no_to_bool, | |
default=False, | |
error_message="Please enter yes or no.", | |
) | |
num_machines = 1 | |
if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL): | |
num_machines = _ask_field( | |
"How many machines do you want use? [1]: ", | |
int, | |
default=1, | |
) | |
mixed_precision = _ask_options( | |
"Do you wish to use FP16 or BF16 (mixed precision)?", | |
["no", "fp16", "bf16", "fp8"], | |
_convert_mixed_precision, | |
) | |
if use_dynamo and mixed_precision == "no": | |
print( | |
"Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts." | |
) | |
return SageMakerConfig( | |
image_uri=docker_image, | |
compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER, | |
distributed_type=distributed_type, | |
use_cpu=False, | |
dynamo_config=dynamo_config, | |
ec2_instance_type=ec2_instance_type, | |
profile=aws_profile, | |
region=aws_region, | |
iam_role_name=iam_role_name, | |
mixed_precision=mixed_precision, | |
num_machines=num_machines, | |
sagemaker_inputs_file=sagemaker_inputs_file, | |
sagemaker_metrics_file=sagemaker_metrics_file, | |
debug=debug, | |
) | |