#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from ...utils.dataclasses import (
    ComputeEnvironment,
    DistributedType,
    DynamoBackend,
    PrecisionType,
    SageMakerDistributedType,
)
from ..menu import BulletMenu


# Names looked up by index in `_convert_dynamo_backend`; order matters because
# the position in this list is what gets converted.
# NOTE(review): "TORHCHXLA_TRACE_ONCE" looks like a misspelling of
# "TORCHXLA_TRACE_ONCE". It is left unchanged here because the `DynamoBackend`
# enum is not visible from this file -- confirm against the enum and fix both
# places together if needed.
DYNAMO_BACKENDS = [
    "EAGER",
    "AOT_EAGER",
    "INDUCTOR",
    "AOT_TS_NVFUSER",
    "NVPRIMS_NVFUSER",
    "CUDAGRAPHS",
    "OFI",
    "FX2TRT",
    "ONNXRT",
    "TENSORRT",
    "AOT_TORCHXLA_TRACE_ONCE",
    "TORHCHXLA_TRACE_ONCE",
    "IPEX",
    "TVM",
]


def _ask_field(input_text, convert_value=None, default=None, error_message=None):
    """
    Prompt on stdin until a usable answer is obtained.

    Args:
        input_text: Prompt passed to `input()`.
        convert_value: Optional callable applied to the raw answer; its return value is returned.
        default: Returned as-is (without conversion) when the answer is empty and `default` is not `None`.
        error_message: Optional text printed whenever handling the answer raises, before asking again.

    Returns:
        `default`, the converted answer, or the raw answer string when no converter is given.
    """
    while True:
        answer = input(input_text)
        try:
            if default is not None and not answer:
                return default
            if convert_value is None:
                return answer
            return convert_value(answer)
        except Exception:
            # Deliberately broad: any failure in the user-supplied converter
            # means "invalid answer", so report it (if asked to) and re-prompt.
            if error_message is not None:
                print(error_message)


def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Show a `BulletMenu` and return the (optionally converted) result of running it.

    Args:
        input_text: Prompt handed to `BulletMenu`.
        options: Choices handed to `BulletMenu`; `None` is treated as an empty list.
        convert_value: Optional callable applied to the menu result.
        default: Forwarded to `BulletMenu.run` as `default_choice`.

    Returns:
        The value returned by `menu.run(...)`, passed through `convert_value` when one is given.
    """
    # A `None` sentinel avoids sharing one mutable list object across calls.
    if options is None:
        options = []
    menu = BulletMenu(input_text, options)
    choice = menu.run(default_choice=default)
    if convert_value is None:
        return choice
    return convert_value(choice)


def _convert_compute_environment(value):
    """Map an index (anything `int()` accepts) to the matching `ComputeEnvironment`."""
    names = ["LOCAL_MACHINE", "AMAZON_SAGEMAKER"]
    return ComputeEnvironment(names[int(value)])


def _convert_distributed_mode(value):
    """Map an index (anything `int()` accepts) to the matching `DistributedType`."""
    names = ["NO", "MULTI_CPU", "MULTI_XPU", "MULTI_GPU", "MULTI_NPU", "MULTI_MLU", "XLA"]
    return DistributedType(names[int(value)])


def _convert_dynamo_backend(value):
    """Map an index into `DYNAMO_BACKENDS` to the `.value` of the matching `DynamoBackend`."""
    return DynamoBackend(DYNAMO_BACKENDS[int(value)]).value


def _convert_mixed_precision(value):
    """Map an index (anything `int()` accepts) to the matching `PrecisionType`."""
    names = ["no", "fp16", "bf16", "fp8"]
    return PrecisionType(names[int(value)])


def _convert_sagemaker_distributed_mode(value):
    """Map an index (anything `int()` accepts) to the matching `SageMakerDistributedType`."""
    names = ["NO", "DATA_PARALLEL", "MODEL_PARALLEL"]
    return SageMakerDistributedType(names[int(value)])


def _convert_yes_no_to_bool(value):
    """
    Convert a "yes"/"no" answer to a bool, ignoring case and surrounding whitespace.

    Raises:
        KeyError: If the answer is neither "yes" nor "no".
    """
    return {"yes": True, "no": False}[value.strip().lower()]


class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    A custom formatter that will remove the usage line from the help message for subcommands.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        """Build the usage text with the parent formatter, then drop the `" [] "` fragment from it."""
        formatted = super()._format_usage(usage, actions, groups, prefix)
        # NOTE(review): the literal " [] " looks like placeholder text may have
        # been lost from it -- confirm against the parent parser's usage string.
        return formatted.replace(" [] ", "")