import copy
import json
import sys
import warnings
from collections import defaultdict, namedtuple
from dataclasses import (MISSING,
                         fields,
                         is_dataclass  # type: ignore
                         )
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from typing import (Any, Collection, Mapping, Union, get_type_hints,
                    Tuple, TypeVar, Type)
from uuid import UUID

from typing_inspect import is_union_type  # type: ignore

from dataclasses_json import cfg
from dataclasses_json.utils import (_get_type_cons, _get_type_origin,
                                    _handle_undefined_parameters_safe,
                                    _is_collection, _is_mapping, _is_new_type,
                                    _is_optional, _isinstance_safe,
                                    _get_type_arg_param,
                                    _get_type_args, _is_counter,
                                    _NO_ARGS,
                                    _issubclass_safe, _is_tuple)

Json = Union[dict, list, str, int, float, bool, None]

confs = ['encoder', 'decoder', 'mm_field', 'letter_case', 'exclude']
FieldOverride = namedtuple('FieldOverride', confs)  # type: ignore


class _ExtendedEncoder(json.JSONEncoder):
    def default(self, o) -> Json:
        result: Json
        if _isinstance_safe(o, Collection):
            if _isinstance_safe(o, Mapping):
                result = dict(o)
            else:
                result = list(o)
        elif _isinstance_safe(o, datetime):
            result = o.timestamp()
        elif _isinstance_safe(o, UUID):
            result = str(o)
        elif _isinstance_safe(o, Enum):
            result = o.value
        elif _isinstance_safe(o, Decimal):
            result = str(o)
        else:
            result = json.JSONEncoder.default(self, o)
        return result
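
# Illustrative sketch, not part of the library source: the extended encoder
# lets `json.dumps` serialize values the stock JSONEncoder rejects, using
# only names already imported above, e.g.
#   json.dumps({"id": UUID(int=1), "price": Decimal("9.99")},
#              cls=_ExtendedEncoder)
#   -> '{"id": "00000000-0000-0000-0000-000000000001", "price": "9.99"}'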


def _user_overrides_or_exts(cls):
    global_metadata = defaultdict(dict)
    encoders = cfg.global_config.encoders
    decoders = cfg.global_config.decoders
    mm_fields = cfg.global_config.mm_fields
    for field in fields(cls):
        if field.type in encoders:
            global_metadata[field.name]['encoder'] = encoders[field.type]
        if field.type in decoders:
            global_metadata[field.name]['decoder'] = decoders[field.type]
        if field.type in mm_fields:
            global_metadata[field.name]['mm_field'] = mm_fields[field.type]
    try:
        cls_config = (cls.dataclass_json_config
                      if cls.dataclass_json_config is not None else {})
    except AttributeError:
        cls_config = {}

    overrides = {}
    for field in fields(cls):
        field_config = {}
        # first apply global overrides or extensions
        field_metadata = global_metadata[field.name]
        if 'encoder' in field_metadata:
            field_config['encoder'] = field_metadata['encoder']
        if 'decoder' in field_metadata:
            field_config['decoder'] = field_metadata['decoder']
        if 'mm_field' in field_metadata:
            field_config['mm_field'] = field_metadata['mm_field']
        # then apply class-level overrides or extensions
        field_config.update(cls_config)
        # last apply field-level overrides or extensions
        field_config.update(field.metadata.get('dataclasses_json', {}))
        overrides[field.name] = FieldOverride(*map(field_config.get, confs))
    return overrides
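
# Illustrative sketch (assumes `dataclass`/`field` imported from
# `dataclasses`): the three layers above compose as global config <
# class-level config < field-level metadata, so a per-field encoder wins
# over one registered globally for the same type:
#   cfg.global_config.encoders[datetime] = datetime.isoformat
#   @dataclass
#   class Event:
#       ts: datetime = field(
#           metadata={'dataclasses_json': {'encoder': datetime.timestamp}})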


def _encode_json_type(value, default=_ExtendedEncoder().default):
    if isinstance(value, Json.__args__):  # type: ignore
        if isinstance(value, list):
            return [_encode_json_type(i) for i in value]
        elif isinstance(value, dict):
            return {k: _encode_json_type(v) for k, v in value.items()}
        else:
            return value
    return default(value)


def _encode_overrides(kvs, overrides, encode_json=False):
    override_kvs = {}
    for k, v in kvs.items():
        if k in overrides:
            exclude = overrides[k].exclude
            # If the exclude predicate returns true, the key should be
            # excluded from encoding, so skip the rest of the loop
            if exclude and exclude(v):
                continue
            letter_case = overrides[k].letter_case
            original_key = k
            k = letter_case(k) if letter_case is not None else k
            if k in override_kvs:
                raise ValueError(
                    f"Multiple fields map to the same JSON "
                    f"key after letter case encoding: {k}"
                )
            encoder = overrides[original_key].encoder
            v = encoder(v) if encoder is not None else v
        if encode_json:
            v = _encode_json_type(v)
        override_kvs[k] = v
    return override_kvs
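
# Illustrative note: with a letter_case override such as `str.upper`, two
# fields named `value` and `VALUE` would both encode to the key 'VALUE';
# the ValueError above is what guards against that silent collision.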


def _decode_letter_case_overrides(field_names, overrides):
    """Override letter case of field names for encode/decode"""
    names = {}
    for field_name in field_names:
        field_override = overrides.get(field_name)
        if field_override is not None:
            letter_case = field_override.letter_case
            if letter_case is not None:
                names[letter_case(field_name)] = field_name
    return names


def _decode_dataclass(cls, kvs, infer_missing):
    if _isinstance_safe(kvs, cls):
        return kvs
    overrides = _user_overrides_or_exts(cls)
    kvs = {} if kvs is None and infer_missing else kvs
    field_names = [field.name for field in fields(cls)]
    decode_names = _decode_letter_case_overrides(field_names, overrides)
    kvs = {decode_names.get(k, k): v for k, v in kvs.items()}
    missing_fields = {field for field in fields(cls) if field.name not in kvs}

    for field in missing_fields:
        if field.default is not MISSING:
            kvs[field.name] = field.default
        elif field.default_factory is not MISSING:
            kvs[field.name] = field.default_factory()
        elif infer_missing:
            kvs[field.name] = None

    # Perform undefined parameter action
    kvs = _handle_undefined_parameters_safe(cls, kvs, usage="from")

    init_kwargs = {}
    types = get_type_hints(cls)
    for field in fields(cls):
        # The field should be skipped from being added
        # to init_kwargs as it's not intended as a constructor argument.
        if not field.init:
            continue

        field_value = kvs[field.name]
        field_type = types[field.name]
        if field_value is None:
            if not _is_optional(field_type):
                warning = (
                    f"value of non-optional type {field.name} detected "
                    f"when decoding {cls.__name__}"
                )
                if infer_missing:
                    warnings.warn(
                        f"Missing {warning} and was defaulted to None by "
                        f"infer_missing=True. "
                        f"Set infer_missing=False (the default) to prevent "
                        f"this behavior.", RuntimeWarning
                    )
                else:
                    warnings.warn(
                        f"'NoneType' object {warning}.", RuntimeWarning
                    )
            init_kwargs[field.name] = field_value
            continue

        # Unwrap NewType aliases down to their supertype before decoding
        while True:
            if not _is_new_type(field_type):
                break
            field_type = field_type.__supertype__

        if (field.name in overrides
                and overrides[field.name].decoder is not None):
            # FIXME hack
            if field_type is type(field_value):
                init_kwargs[field.name] = field_value
            else:
                init_kwargs[field.name] = overrides[field.name].decoder(
                    field_value)
        elif is_dataclass(field_type):
            # FIXME this is a band-aid to deal with the value already being
            # serialized when handling nested marshmallow schema
            # proper fix is to investigate the marshmallow schema generation
            # code
            if is_dataclass(field_value):
                value = field_value
            else:
                value = _decode_dataclass(field_type, field_value,
                                          infer_missing)
            init_kwargs[field.name] = value
        elif _is_supported_generic(field_type) and field_type != str:
            init_kwargs[field.name] = _decode_generic(field_type,
                                                      field_value,
                                                      infer_missing)
        else:
            init_kwargs[field.name] = _support_extended_types(field_type,
                                                              field_value)

    return cls(**init_kwargs)
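
# Illustrative sketch (assumes `dataclass` imported from `dataclasses`):
# nested dataclasses are decoded recursively, so
#   @dataclass
#   class Point:
#       x: int
#       y: int
#   @dataclass
#   class Segment:
#       start: Point
#       end: Point
#   _decode_dataclass(Segment, {'start': {'x': 0, 'y': 0},
#                               'end': {'x': 1, 'y': 1}}, False)
# yields Segment(start=Point(x=0, y=0), end=Point(x=1, y=1)).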


def _support_extended_types(field_type, field_value):
    if _issubclass_safe(field_type, datetime):
        # FIXME this is a hack to deal with mm already decoding
        # the issue is we want to leverage mm fields' missing argument
        # but need this for the object creation hook
        if isinstance(field_value, datetime):
            res = field_value
        else:
            tz = datetime.now(timezone.utc).astimezone().tzinfo
            res = datetime.fromtimestamp(field_value, tz=tz)
    elif _issubclass_safe(field_type, Decimal):
        res = (field_value
               if isinstance(field_value, Decimal)
               else Decimal(field_value))
    elif _issubclass_safe(field_type, UUID):
        res = (field_value
               if isinstance(field_value, UUID)
               else UUID(field_value))
    elif _issubclass_safe(field_type, (int, float, str, bool)):
        res = (field_value
               if isinstance(field_value, field_type)
               else field_type(field_value))
    else:
        res = field_value
    return res
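
# Illustrative sketch: raw JSON scalars are coerced into the annotated
# extended types, e.g.
#   _support_extended_types(UUID, '00000000-0000-0000-0000-000000000001')
#     -> UUID('00000000-0000-0000-0000-000000000001')
#   _support_extended_types(datetime, 0.0)
#     -> the Unix epoch as an aware datetime in the local timezone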


def _is_supported_generic(type_):
    if type_ is _NO_ARGS:
        return False
    not_str = not _issubclass_safe(type_, str)
    is_enum = _issubclass_safe(type_, Enum)
    return (not_str and _is_collection(type_)) or _is_optional(
        type_) or is_union_type(type_) or is_enum


def _decode_generic(type_, value, infer_missing):
    if value is None:
        res = value
    elif _issubclass_safe(type_, Enum):
        # Convert to an Enum using the type as a constructor.
        # Assumes a direct match is found.
        res = type_(value)
    # FIXME this is a hack to fix a deeper underlying issue. A refactor is due.
    elif _is_collection(type_):
        if _is_mapping(type_) and not _is_counter(type_):
            k_type, v_type = _get_type_args(type_, (Any, Any))
            # a mapping type has `.keys()` and `.values()`
            # (see collections.abc)
            ks = _decode_dict_keys(k_type, value.keys(), infer_missing)
            vs = _decode_items(v_type, value.values(), infer_missing)
            xs = zip(ks, vs)
        elif _is_tuple(type_):
            types = _get_type_args(type_)
            if Ellipsis in types:
                xs = _decode_items(types[0], value, infer_missing)
            else:
                xs = _decode_items(_get_type_args(type_) or _NO_ARGS,
                                   value, infer_missing)
        elif _is_counter(type_):
            xs = dict(zip(_decode_items(_get_type_arg_param(type_, 0),
                                        value.keys(), infer_missing),
                          value.values()))
        else:
            xs = _decode_items(_get_type_arg_param(type_, 0),
                               value, infer_missing)

        # get the constructor if using corresponding generic type in `typing`
        # otherwise fallback on constructing using type_ itself
        materialize_type = type_
        try:
            materialize_type = _get_type_cons(type_)
        except (TypeError, AttributeError):
            pass
        res = materialize_type(xs)
    else:  # Optional or Union
        _args = _get_type_args(type_)
        if _args is _NO_ARGS:
            # Any, just accept
            res = value
        elif _is_optional(type_) and len(_args) == 2:  # Optional
            type_arg = _get_type_arg_param(type_, 0)
            if is_dataclass(type_arg) or is_dataclass(value):
                res = _decode_dataclass(type_arg, value, infer_missing)
            elif _is_supported_generic(type_arg):
                res = _decode_generic(type_arg, value, infer_missing)
            else:
                res = _support_extended_types(type_arg, value)
        else:  # Union (already decoded or try to decode a dataclass)
            type_options = _get_type_args(type_)
            res = value  # assume already decoded
            if type(value) is dict and dict not in type_options:
                for type_option in type_options:
                    if is_dataclass(type_option):
                        try:
                            res = _decode_dataclass(type_option, value,
                                                    infer_missing)
                            break
                        except (KeyError, ValueError, AttributeError):
                            continue
                if res == value:
                    warnings.warn(
                        f"Failed to decode {value} into Union dataclasses. "
                        f"Expected Union to include a matching dataclass "
                        f"and it didn't."
                    )
    return res
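
# Illustrative sketch (assumes `List`/`Optional` imported from `typing`):
# generics decode element-wise and None short-circuits, e.g.
#   _decode_generic(List[int], [1, 2, 3], False)   -> [1, 2, 3]
#   _decode_generic(Optional[UUID], None, False)   -> None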


def _decode_dict_keys(key_type, xs, infer_missing):
    """
    Because JSON object keys must be strs, we need the extra step of decoding
    them back into the user's chosen python type
    """
    decode_function = key_type
    # handle NoneType keys... it's weird to type a Dict as NoneType keys
    # but it's valid...
    # Issue #341 and PR #346:
    # This is a special case for Python 3.7 and Python 3.8.
    # For some reason, "unbound" dicts are counted
    # as having the key type parameter TypeVar('KT')
    if key_type is None or key_type == Any or isinstance(key_type, TypeVar):
        decode_function = key_type = (lambda x: x)
    # handle a nested python dict that has tuples for keys. E.g. for
    # Dict[Tuple[int], int], key_type will be typing.Tuple[int], but
    # decode_function should be tuple, so map() doesn't break.
    #
    # Note: _get_type_origin() will return typing.Tuple for python
    # 3.6 and tuple for 3.7 and higher.
    elif _get_type_origin(key_type) in {tuple, Tuple}:
        decode_function = tuple
    return map(decode_function, _decode_items(key_type, xs, infer_missing))
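
# Illustrative sketch: for a Dict[Tuple[int, int], str] field, each decoded
# key is passed through `tuple`, so a key arriving as [0, 1] or (0, 1)
# comes back as the hashable tuple (0, 1).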


def _decode_items(type_args, xs, infer_missing):
    """
    This is a tricky situation where we need to check both the annotated
    type info (which is usually a type from `typing`) and check the
    value's type directly using `type()`.

    If the type_arg is a generic we can use the annotated type, but if the
    type_arg is a typevar we need to extract the reified type information,
    hence the check of `is_dataclass(xs)`
    """
    def _decode_item(type_arg, x):
        if is_dataclass(type_arg) or is_dataclass(xs):
            return _decode_dataclass(type_arg, x, infer_missing)
        if _is_supported_generic(type_arg):
            return _decode_generic(type_arg, x, infer_missing)
        return x

    def handle_pep0673(pre_0673_hint: str) -> Union[Type, str]:
        for module in sys.modules:
            maybe_resolved = getattr(sys.modules[module], type_args, None)
            if maybe_resolved:
                return maybe_resolved

        warnings.warn(
            f"Could not resolve self-reference for type {pre_0673_hint}, "
            f"decoded type might be incorrect or decode might fail "
            f"altogether."
        )
        return pre_0673_hint

    # Before https://peps.python.org/pep-0673 (3.11+), self-type hints are
    # simply strings
    if (sys.version_info.minor < 11 and type_args is not type
            and type(type_args) is str):
        type_args = handle_pep0673(type_args)

    if (_isinstance_safe(type_args, Collection)
            and not _issubclass_safe(type_args, Enum)):
        if len(type_args) == len(xs):
            return list(_decode_item(type_arg, x)
                        for type_arg, x in zip(type_args, xs))
        else:
            raise TypeError(
                f"Number of types specified in the collection type "
                f"{str(type_args)} does not match number of elements in the "
                f"collection. In case you are working with tuples, take a "
                f"look at this document: "
                f"docs.python.org/3/library/typing.html#annotating-tuples."
            )
    return list(_decode_item(type_args, x) for x in xs)


def _asdict(obj, encode_json=False):
    """
    A re-implementation of `asdict` (based on the original in the
    `dataclasses` source) to support arbitrary Collection and Mapping types.
    """
    if is_dataclass(obj):
        result = []
        overrides = _user_overrides_or_exts(obj)
        for field in fields(obj):
            if overrides[field.name].encoder:
                value = getattr(obj, field.name)
            else:
                value = _asdict(
                    getattr(obj, field.name),
                    encode_json=encode_json
                )
            result.append((field.name, value))

        result = _handle_undefined_parameters_safe(cls=obj, kvs=dict(result),
                                                   usage="to")
        return _encode_overrides(dict(result), _user_overrides_or_exts(obj),
                                 encode_json=encode_json)
    elif isinstance(obj, Mapping):
        return dict((_asdict(k, encode_json=encode_json),
                     _asdict(v, encode_json=encode_json)) for k, v in
                    obj.items())
    # enum.IntFlag and enum.Flag are regarded as collections in Python 3.11,
    # thus a check against Enum is needed
    elif (isinstance(obj, Collection)
          and not isinstance(obj, (str, bytes, Enum))):
        return list(_asdict(v, encode_json=encode_json) for v in obj)
    else:
        return copy.deepcopy(obj)
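
# Illustrative sketch (assumes `dataclass` imported from `dataclasses`):
# _asdict recurses through nested mappings, collections and dataclasses:
#   @dataclass
#   class Point:
#       x: int
#       y: int
#   _asdict({'points': [Point(0, 1)]})
#     -> {'points': [{'x': 0, 'y': 1}]}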