# Extra utilities for working with context managers that should have been
# in the standard library but are not.
import functools | |
import inspect | |
import warnings | |
import sys | |
from typing import Any, Callable, TypeVar, cast | |
# Generic type variables for annotating the decorator form of
# _DecoratorContextManager (for instance 'no_grad' and 'enable_grad'):
# ``__call__`` takes an ``F`` and returns an ``F``, so the decorated callable
# keeps its own type for static checkers.
# Pattern from https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
def _wrap_generator(ctx_factory, func): | |
""" | |
Wrap each generator invocation with the context manager factory. | |
The input should be a function that returns a context manager, | |
not a context manager itself, to handle one-shot context managers. | |
""" | |
def generator_context(*args, **kwargs): | |
gen = func(*args, **kwargs) | |
# Generators are suspended and unsuspended at `yield`, hence we | |
# make sure the grad mode is properly set every time the execution | |
# flow returns into the wrapped generator and restored when it | |
# returns through our `yield` to our caller (see PR #49017). | |
try: | |
# Issuing `None` to a generator fires it up | |
with ctx_factory(): | |
response = gen.send(None) | |
while True: | |
try: | |
# Forward the response to our caller and get its next request | |
request = yield response | |
except GeneratorExit: | |
# Inform the still active generator about its imminent closure | |
with ctx_factory(): | |
gen.close() | |
raise | |
except BaseException: | |
# Propagate the exception thrown at us by the caller | |
with ctx_factory(): | |
response = gen.throw(*sys.exc_info()) | |
else: | |
# Pass the last request to the generator and get its response | |
with ctx_factory(): | |
response = gen.send(request) | |
# We let the exceptions raised above by the generator's `.throw` or | |
# `.send` methods bubble up to our caller, except for StopIteration | |
except StopIteration as e: | |
# The generator informed us that it is done: take whatever its | |
# returned value (if any) was and indicate that we're done too | |
# by returning it (see docs for python's return-statement). | |
return e.value | |
return generator_context | |
def context_decorator(ctx, func):
    """
    Like contextlib.ContextDecorator.

    But with the following differences:
    1. Is done by wrapping, rather than inheritance, so it works with context
       managers that are implemented in C and thus cannot easily inherit from
       Python classes
    2. Wraps generators in the intuitive way (cf. https://bugs.python.org/issue37743)
    3. Errors out if you try to wrap a class, because it is ambiguous whether
       or not you intended to wrap only the constructor

    The input argument can either be a context manager (in which case it must
    be a multi-shot context manager that can be directly invoked multiple times)
    or a callable that produces a context manager.

    Args:
        ctx: a reusable context manager, or a zero-argument callable that
            returns a context manager.
        func: the function or generator function to wrap.

    Returns:
        A wrapper that preserves ``func``'s metadata (via ``functools.wraps``)
        and runs each call of ``func`` inside the context.

    Raises:
        AssertionError: if ``ctx`` is both callable and has ``__enter__``.
        RuntimeError: if ``func`` is a class.
    """
    # Explicit raise rather than `assert` so the check survives `python -O`;
    # the exception type is kept as AssertionError for existing callers.
    if callable(ctx) and hasattr(ctx, '__enter__'):
        raise AssertionError(
            f"Passed in {ctx} is both callable and also a valid context manager "
            "(has __enter__), making it ambiguous which interface to use. If you "
            "intended to pass a context manager factory, rewrite your call as "
            "context_decorator(lambda: ctx()); if you intended to pass a context "
            "manager directly, rewrite your call as context_decorator(lambda: ctx)"
        )

    # Normalise both input forms to a zero-argument factory.
    if not callable(ctx):
        def ctx_factory():
            return ctx
    else:
        ctx_factory = ctx

    if inspect.isclass(func):
        raise RuntimeError(
            "Cannot decorate classes; it is ambiguous whether or not only the "
            "constructor or all methods should have the context manager applied; "
            "additionally, decorating a class at definition-site will prevent "
            "use of the identifier as a conventional type. "
            "To specify which methods to decorate, decorate each of them "
            "individually."
        )

    # Generators need the context re-entered on every resume, not just once.
    if inspect.isgeneratorfunction(func):
        return _wrap_generator(ctx_factory, func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        with ctx_factory():
            return func(*args, **kwargs)

    return decorate_context
class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator.

    Subclasses implement ``__enter__``/``__exit__``. Calling an instance on a
    function returns a wrapper that, on every call, enters a fresh context
    manager obtained from :meth:`clone`.
    """

    def __call__(self, orig_func: F) -> F:
        """Decorate ``orig_func`` so each call runs inside a fresh clone of self.

        Decorating a class is still accepted (it wraps construction only) but
        emits a deprecation warning.
        """
        if inspect.isclass(orig_func):
            # stacklevel=2 attributes the warning to the decoration site
            # rather than to this line.
            warnings.warn(
                "Decorating classes is deprecated and will be disabled in "
                "future versions. You should only decorate functions or methods. "
                "To preserve the current behavior of class decoration, you can "
                "directly decorate the `__init__` method and nothing else.",
                stacklevel=2,
            )
            # Hide the class behind a plain function, because context_decorator
            # refuses to wrap classes directly.
            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
        else:
            func = orig_func

        # The bound `clone` method is the context-manager factory, so each call
        # of the wrapper uses a new instance.
        return cast(F, context_decorator(self.clone, func))

    def __enter__(self) -> None:
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        raise NotImplementedError

    def clone(self):
        # Override this method if your subclass's __init__ takes parameters.
        return self.__class__()
class _NoParamDecoratorContextManager(_DecoratorContextManager):
    """Allow a context manager to be used as a decorator without parentheses.

    Both ``@Ctx`` and ``@Ctx()`` work: when the class is called with a
    function, that function is decorated immediately by a default instance and
    the resulting wrapper is returned instead of a new instance.
    """

    def __new__(cls, orig_func=None):
        if orig_func is not None:
            # Bare-decorator form: build a default instance and apply it.
            decorator = cls()
            return decorator(orig_func)
        # Plain construction (also the `@Ctx()` form).
        return super().__new__(cls)