Draken007's picture
Upload 7228 files
2a0bc63 verified
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from ...utils.dataclasses import (
ComputeEnvironment,
DistributedType,
DynamoBackend,
PrecisionType,
SageMakerDistributedType,
)
from ..menu import BulletMenu
# Backend names offered for TorchDynamo compilation. The order is significant:
# `_convert_dynamo_backend` maps a numeric choice to an entry by its position here.
DYNAMO_BACKENDS = [
    "EAGER",
    "AOT_EAGER",
    "INDUCTOR",
    "AOT_TS_NVFUSER",
    "NVPRIMS_NVFUSER",
    "CUDAGRAPHS",
    "OFI",
    "FX2TRT",
    "ONNXRT",
    "TENSORRT",
    "IPEX",
    "TVM",
]
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
ask_again = True
while ask_again:
result = input(input_text)
try:
if default is not None and len(result) == 0:
return default
return convert_value(result) if convert_value is not None else result
except Exception:
if error_message is not None:
print(error_message)
def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Show a `BulletMenu` and return the user's selection.

    Args:
        input_text (`str`):
            Text passed to `BulletMenu` as the menu prompt.
        options (`list`, *optional*):
            Menu entries. Defaults to an empty list. A fresh list is created on every call so that no mutable
            default is shared between calls.
        convert_value (`Callable`, *optional*):
            Applied to the value returned by `BulletMenu.run`. When omitted, that value is returned as-is.
        default (`int`, *optional*, defaults to 0):
            Forwarded to `BulletMenu.run` as `default_choice`.

    Returns:
        The (optionally converted) result of `BulletMenu.run`. Presumably this is the index of the chosen entry,
        since the `_convert_*` helpers call `int()` on it; confirm against `BulletMenu`.
    """
    if options is None:
        options = []
    menu = BulletMenu(input_text, options)
    result = menu.run(default_choice=default)
    return convert_value(result) if convert_value is not None else result
def _convert_compute_environment(value):
    """
    Map a numeric choice (int or numeric string) to a `ComputeEnvironment` member.

    0 -> "LOCAL_MACHINE", 1 -> "AMAZON_SAGEMAKER". Non-numeric input raises `ValueError` and out-of-range input
    raises `IndexError`, both from the conversion/indexing below.
    """
    environments = ("LOCAL_MACHINE", "AMAZON_SAGEMAKER")
    return ComputeEnvironment(environments[int(value)])
def _convert_distributed_mode(value):
    """
    Map a numeric choice (int or numeric string) to a `DistributedType` member.

    Indices follow the order: NO, MULTI_CPU, MULTI_XPU, MULTI_GPU, MULTI_NPU, MULTI_MLU, XLA.
    """
    modes = ("NO", "MULTI_CPU", "MULTI_XPU", "MULTI_GPU", "MULTI_NPU", "MULTI_MLU", "XLA")
    return DistributedType(modes[int(value)])
def _convert_dynamo_backend(value):
    """
    Map a numeric choice (int or numeric string) to a Dynamo backend.

    The choice indexes `DYNAMO_BACKENDS`. The returned value is the `.value` of the matching `DynamoBackend`
    member, not the member itself.
    """
    backend_name = DYNAMO_BACKENDS[int(value)]
    backend = DynamoBackend(backend_name)
    return backend.value
def _convert_mixed_precision(value):
    """
    Map a numeric choice (int or numeric string) to a `PrecisionType` member.

    Indices follow the order: "no", "fp16", "bf16", "fp8".
    """
    precisions = ("no", "fp16", "bf16", "fp8")
    return PrecisionType(precisions[int(value)])
def _convert_sagemaker_distributed_mode(value):
    """
    Map a numeric choice (int or numeric string) to a `SageMakerDistributedType` member.

    Indices follow the order: NO, DATA_PARALLEL, MODEL_PARALLEL.
    """
    sagemaker_modes = ("NO", "DATA_PARALLEL", "MODEL_PARALLEL")
    return SageMakerDistributedType(sagemaker_modes[int(value)])
def _convert_yes_no_to_bool(value):
return {"yes": True, "no": False}[value.lower()]
class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    Help formatter for subcommand parsers.

    Behaves like `argparse.RawDescriptionHelpFormatter`, except that the literal placeholder fragment
    `"<command> [<args>] "` is stripped from the formatted usage text.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        formatted = super()._format_usage(usage, actions, groups, prefix)
        # Drop the generic subcommand placeholder so only the concrete arguments remain in the usage line.
        return formatted.replace("<command> [<args>] ", "")