# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
import re
import sys
from argparse import ArgumentParser, ArgumentTypeError
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union


DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)


# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def string_to_bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise ArgumentTypeError(
            f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
        )
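

# Illustrative behaviour of the helper above (a sketch added for clarity, not part of the original module):
#   string_to_bool("yes")   -> True
#   string_to_bool("0")     -> False
#   string_to_bool("maybe") -> raises ArgumentTypeError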


class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.

    The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
    arguments to the parser after initialization and you'll get the output back after parsing as an additional
    namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
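
    Example (a minimal sketch, not part of the original module; `ExampleArguments` is a hypothetical dataclass used
    purely for illustration):

        @dataclasses.dataclass
        class ExampleArguments:
            learning_rate: float = 5e-5
            do_train: bool = False

        parser = HfArgumentParser(ExampleArguments)
        (example_args,) = parser.parse_args_into_dataclasses(["--learning_rate", "3e-5", "--do_train"])
        # example_args.learning_rate == 3e-5 and example_args.do_train is True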
""" | |

    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
            kwargs:
                (Optional) Passed to `argparse.ArgumentParser()` in the regular way.
        """
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        self.dataclass_types = dataclass_types
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    def _add_dataclass_arguments(self, dtype: DataClassType):
        if hasattr(dtype, "_argument_group_name"):
            parser = self.add_argument_group(dtype._argument_group_name)
        else:
            parser = self
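        # Illustrative mapping performed by the loop below (a sketch of the intended behaviour, not exhaustive):
        #   seed: int = 42              -> --seed, type=int, default=42
        #   name: Optional[str] = None  -> --name, type=str, default=None
        #   do_eval: bool = False       -> --do_eval, takes no value (True), an explicit value, or is omitted (False)
        #   tags: List[str] = ...       -> --tags, nargs="+", type=str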
        for field in dataclasses.fields(dtype):
            if not field.init:
                continue
            field_name = f"--{field.name}"
            kwargs = field.metadata.copy()
            # field.metadata is not used at all by Data Classes,
            # it is provided as a third-party extension mechanism.
            if isinstance(field.type, str):
                raise ImportError(
                    "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563), "
                    "which can be opted in from Python 3.7 with `from __future__ import annotations`. "
                    "We will add compatibility when Python 3.9 is released."
                )
            typestring = str(field.type)
            for prim_type in (int, float, str):
                for collection in (List,):
                    if (
                        typestring == f"typing.Union[{collection[prim_type]}, NoneType]"
                        or typestring == f"typing.Optional[{collection[prim_type]}]"
                    ):
                        field.type = collection[prim_type]
                if (
                    typestring == f"typing.Union[{prim_type.__name__}, NoneType]"
                    or typestring == f"typing.Optional[{prim_type.__name__}]"
                ):
                    field.type = prim_type

            if isinstance(field.type, type) and issubclass(field.type, Enum):
                kwargs["choices"] = [x.value for x in field.type]
                kwargs["type"] = type(kwargs["choices"][0])
                if field.default is not dataclasses.MISSING:
                    kwargs["default"] = field.default
                else:
                    kwargs["required"] = True
            elif field.type is bool or field.type == Optional[bool]:
                if field.default is True:
                    parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)

                # Hack because type=bool in argparse does not behave as we want.
                kwargs["type"] = string_to_bool
                if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
                    # Default is set to False for a bool field that has no explicit default.
                    default = False if field.default is dataclasses.MISSING else field.default
                    # This is the value that will get picked if we don't include --field_name in any way.
                    kwargs["default"] = default
                    # This tells argparse we accept 0 or 1 value after --field_name.
                    kwargs["nargs"] = "?"
                    # This is the value that will get picked if we do --field_name (without value).
                    kwargs["const"] = True
            elif (
                hasattr(field.type, "__origin__") and re.search(r"^typing\.List\[(.*)\]$", str(field.type)) is not None
            ):
                kwargs["nargs"] = "+"
                kwargs["type"] = field.type.__args__[0]
                assert all(
                    x == kwargs["type"] for x in field.type.__args__
                ), f"{field.name} cannot be a List of mixed types"
                if field.default_factory is not dataclasses.MISSING:
                    kwargs["default"] = field.default_factory()
                elif field.default is dataclasses.MISSING:
                    kwargs["required"] = True
            else:
                kwargs["type"] = field.type
                if field.default is not dataclasses.MISSING:
                    kwargs["default"] = field.default
                elif field.default_factory is not dataclasses.MISSING:
                    kwargs["default"] = field.default_factory()
                else:
                    kwargs["required"] = True
            parser.add_argument(field_name, **kwargs)

    def parse_args_into_dataclasses(
        self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
    ) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.
            look_for_args_file:
                If true, will look for a ".args" file with the same base name as the entry point script for this
                process, and will append its potential content to the command line args.
            args_filename:
                If not None, will use this file instead of the ".args" file specified in the previous argument.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
                - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
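
        Example (a minimal sketch, not part of the original module; reuses the hypothetical `ExampleArguments`
        dataclass from the class docstring):

            parser = HfArgumentParser(ExampleArguments)
            (example_args,) = parser.parse_args_into_dataclasses(["--learning_rate", "3e-5"])
            # example_args.learning_rate == 3e-5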
""" | |
        if args_filename or (look_for_args_file and len(sys.argv)):
            if args_filename:
                args_file = Path(args_filename)
            else:
                args_file = Path(sys.argv[0]).with_suffix(".args")

            if args_file.exists():
                fargs = args_file.read_text().split()
                args = fargs + args if args is not None else fargs + sys.argv[1:]
                # In case of duplicate arguments the last one has precedence with argparse, so the file args are
                # prepended and the explicitly passed (or command-line) args overwrite them.

        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            for k in keys:
                delattr(namespace, k)
            obj = dtype(**inputs)
            outputs.append(obj)
        if len(namespace.__dict__) > 0:
            # additional namespace.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)
        else:
            if remaining_args:
                raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

            return (*outputs,)

    def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.
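
        Example (a sketch, assuming a hypothetical "config.json" file containing {"learning_rate": 3e-5} and the
        hypothetical `ExampleArguments` dataclass from the class docstring):

            parser = HfArgumentParser(ExampleArguments)
            (example_args,) = parser.parse_json_file("config.json")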
""" | |
data = json.loads(Path(json_file).read_text()) | |
outputs = [] | |
for dtype in self.dataclass_types: | |
keys = {f.name for f in dataclasses.fields(dtype) if f.init} | |
inputs = {k: v for k, v in data.items() if k in keys} | |
obj = dtype(**inputs) | |
outputs.append(obj) | |
return (*outputs,) | |

    def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead using a dict and populating the
        dataclass types.
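
        Example (a sketch, reusing the hypothetical `ExampleArguments` dataclass from the class docstring):

            parser = HfArgumentParser(ExampleArguments)
            (example_args,) = parser.parse_dict({"learning_rate": 3e-5})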
""" | |
outputs = [] | |
for dtype in self.dataclass_types: | |
keys = {f.name for f in dataclasses.fields(dtype) if f.init} | |
inputs = {k: v for k, v in args.items() if k in keys} | |
obj = dtype(**inputs) | |
outputs.append(obj) | |
return (*outputs,) | |