# Copyright 2024 The etils Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Dataclass utils.""" from __future__ import annotations import dataclasses import functools import typing from typing import Any, Callable, TypeVar from etils import epy from etils.edc import cast_utils from etils.edc import context from etils.edc import frozen_utils from etils.edc import helpers _Cls = Any _ClsT = TypeVar('_ClsT') _T = TypeVar('_T') @typing.overload def dataclass( cls: None = ..., *, kw_only: bool = ..., replace: bool = ..., # pylint: disable=redefined-outer-name repr: bool = ..., # pylint: disable=redefined-builtin auto_cast: bool = ..., contextvars: bool = ..., allow_unfrozen: bool = ..., ) -> Callable[[_ClsT], _ClsT]: ... @typing.overload def dataclass( cls: _ClsT, *, kw_only: bool = ..., replace: bool = ..., # pylint: disable=redefined-outer-name repr: bool = ..., # pylint: disable=redefined-builtin auto_cast: bool = ..., contextvars: bool = ..., allow_unfrozen: bool = ..., ) -> _ClsT: ... def dataclass( cls=None, *, kw_only=False, replace=True, # pylint: disable=redefined-outer-name repr=True, # pylint: disable=redefined-builtin auto_cast=True, contextvars=True, allow_unfrozen=False, ): """Augment a dataclass with additional features. `auto_cast`: Auto-convert init assignements to the annotated class. ```python @edc.dataclass class A: path: edc.AutoCast[epath.Path] some_enum: edc.AutoCast[MyEnum] x: edc.AutoCast[str] a = A( path='/some/path', some_enum='A', x=123 ) # Fields annotated with `AutoCast` are automatically casted to their type assert a.path == epath.Path('/some/path') assert a.some_enum is MyEnum.A assert a.x == '123' ``` `allow_unfrozen`: allow nested dataclass to be updated. This add two methods: * `.unfrozen()`: Create a lazy deep-copy of the current dataclass. Updates to nested attributes will be propagated to the top-level dataclass. * `.frozen()`: Returns the frozen dataclass, after it was mutated. Example: ```python old_x = X(y=Y(z=123)) x = old_x.unfrozen() x.y.z = 456 x = x.frozen() assert x == X(y=Y(z=123)) # Only new x is mutated assert old_x == X(y=Y(z=456)) # Old x is not mutated ``` Note: * Only the last `.frozen()` call resolve the dataclass by calling `.replace` recursivelly. * Dataclass returned by `.unfrozen()` and nested attributes are not the original dataclass but proxy objects which track the mutations. As such, those object are not compatible with `isinstance()`, `jax.tree.map`,... * Only the top-level dataclass need to be `allow_unfrozen=True` * Avoid using `unfrozen` if 2 attributes of the dataclass point to the same nested dataclass. Updates on one attribute might not be reflected on the other. ```python y = Y(y=123) x = X(x0=y, x1=y) # Same instance assigned twice in `x0` and `x1` x = x.unfrozen() x.x0.y = 456 # Changes in `x0` not reflected in `x1` x = x.frozen() assert x == X(x0=Y(y=456), x1=Y(y=123)) ``` This is because only attributes which are accessed are tracked, so `etils` do not know the object exist somewhere else in the attribute tree. * After `.frozen()` has been called, any of the temporary sub-attribute become invalid: ```python a = a.unfrozen() y = a.y a = a.frozen() y.x # Raise error (created between the unfrozen/frozen call) a.y.x # Work ``` `contextvars`: Fields annotated as `edc.ContextVar` are wrapped in a `contextvars.ContextVar`. Afterward each thread / asyncio coroutine will have its own version of the fields (similarly to `threading.local`). The contextvars are lazily initialized at first usage. Example: ```python @edc.dataclass @dataclasses.dataclass class Context: thread_id: edc.ContextVar[int] = dataclasses.field( default_factory=threading.get_native_id ) stack: edc.ContextVar[list[str]] = dataclasses.field(default_factory=list) # Global context object context = Context(thread_id=0) def worker(): # Inside each thread, the worker use its own context assert context.thread_id != 0 context.stack.append(1) with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: for _ in range(10): executor.submit(worker) ``` Args: cls: The dataclass to decorate kw_only: If True, make the dataclass `__init__` keyword-only. replace: If `True`, add a `.replace(` alias of `dataclasses.replace`. repr: If `True`, the class `__repr__` will return a pretty-printed `str` (one attribute per line) auto_cast: If `True`, fields annotated as `x: edc.AutoCast[Cls]` will be converted to `x: Cls = edc.field(validator=Cls)`. contextvars: It `True`, fields annotated as `x: edc.AutoCast[T]` are converted to `contextvars`. This allow to have a `threading.local`-like API for contextvars. allow_unfrozen: If `True`, add `.frozen`, `.unfrozen` methods. Returns: Decorated class """ # Return decorator if cls is None: return functools.partial( dataclass, kw_only=kw_only, replace=replace, repr=repr, auto_cast=auto_cast, allow_unfrozen=allow_unfrozen, ) if kw_only: cls = _make_kw_only(cls) if repr: cls = add_repr(cls) if replace: cls = _add_replace(cls) if allow_unfrozen: cls = frozen_utils.add_unfrozen(cls) descriptor_fns = [] if auto_cast: descriptor_fns.append( helpers.DescriptorInfo( annotation=cast_utils.AutoCast, descriptor_fn=cast_utils.make_auto_cast_descriptor, ) ) if contextvars: descriptor_fns.append( helpers.DescriptorInfo( annotation=context.ContextVar, descriptor_fn=context.make_contextvar_descriptor, ) ) cls = helpers.wrap_new(cls, descriptor_fns) return cls def _make_kw_only(cls: _ClsT) -> _ClsT: """Replace the `__init__` by a keyword-only version.""" # Use `cls.__dict__` and not `hasattr` to ignore parent classes if '__init__' not in cls.__dict__: return cls # Do not mutate the class if __init__ isn't present old_init = cls.__init__ # Despite `@functools.wraps`, the function has to be called `__init__` ( # see: https://stackoverflow.com/q/29919804/4172685) @functools.wraps(old_init) def __init__(self, *args, **kwargs): # pylint: disable=invalid-name if args: raise TypeError( f'{self.__class__.__name__} contructor is keyword-only. ' f'Got {len(args)} positional arguments.' ) return old_init(self, **kwargs) cls.__init__ = __init__ return cls def _add_replace(cls: _ClsT) -> _ClsT: """Add a `.replace` method to the class, if not already present.""" # Only add replace if not present if not hasattr(cls, 'replace'): cls.replace = replace return cls def replace(self: _T, **kwargs: Any) -> _T: """Similar to `dataclasses.replace`.""" return dataclasses.replace(self, **kwargs) def add_repr(cls: _ClsT) -> _ClsT: """Add a `.__repr__` method to the class, if not already present.""" # Use `cls.__dict__` and not `hasattr` to ignore parent classes if '__repr__' not in cls.__dict__: return cls if epy.text_utils.has_default_repr(cls): cls.__repr__ = __repr__ return cls __repr__ = epy.pretty_repr