|
from __future__ import annotations |
|
|
|
import array |
|
import asyncio |
|
import concurrent.futures |
|
import math |
|
import socket |
|
import sys |
|
from asyncio.base_events import _run_until_complete_cb |
|
from collections import OrderedDict, deque |
|
from concurrent.futures import Future |
|
from contextvars import Context, copy_context |
|
from dataclasses import dataclass |
|
from functools import partial, wraps |
|
from inspect import ( |
|
CORO_RUNNING, |
|
CORO_SUSPENDED, |
|
GEN_RUNNING, |
|
GEN_SUSPENDED, |
|
getcoroutinestate, |
|
getgeneratorstate, |
|
) |
|
from io import IOBase |
|
from os import PathLike |
|
from queue import Queue |
|
from socket import AddressFamily, SocketKind |
|
from threading import Thread |
|
from types import TracebackType |
|
from typing import ( |
|
IO, |
|
Any, |
|
AsyncGenerator, |
|
Awaitable, |
|
Callable, |
|
Collection, |
|
Coroutine, |
|
Generator, |
|
Iterable, |
|
Mapping, |
|
Optional, |
|
Sequence, |
|
Tuple, |
|
TypeVar, |
|
Union, |
|
cast, |
|
) |
|
from weakref import WeakKeyDictionary |
|
|
|
import sniffio |
|
|
|
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc |
|
from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable |
|
from .._core._eventloop import claim_worker_thread, threadlocals |
|
from .._core._exceptions import ( |
|
BrokenResourceError, |
|
BusyResourceError, |
|
ClosedResourceError, |
|
EndOfStream, |
|
WouldBlock, |
|
) |
|
from .._core._exceptions import ExceptionGroup as BaseExceptionGroup |
|
from .._core._sockets import GetAddrInfoReturnType, convert_ipv6_sockaddr |
|
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter |
|
from .._core._synchronization import Event as BaseEvent |
|
from .._core._synchronization import ResourceGuard |
|
from .._core._tasks import CancelScope as BaseCancelScope |
|
from ..abc import IPSockAddrType, UDPPacketType |
|
from ..lowlevel import RunVar |
|
|
|
if sys.version_info < (3, 8):

    def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
        """Return the coroutine object wrapped by *task*."""
        # Interpreters older than 3.8 take this branch and read the private
        # ``_coro`` attribute instead of calling ``Task.get_coro()``.
        return task._coro

else:

    def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
        """Return the coroutine object wrapped by *task*."""
        return task.get_coro()
|
|
|
|
|
from asyncio import all_tasks, create_task, current_task, get_running_loop |
|
from asyncio import run as native_run |
|
|
|
|
|
def _get_task_callbacks(task: asyncio.Task) -> Iterable[Callable]: |
|
return [cb for cb, context in task._callbacks] |
|
|
|
|
|
# Type variable for the return value of user-supplied callables
T_Retval = TypeVar("T_Retval")
# Contravariant type variable; used for the value passed to
# ``_AsyncioTaskStatus.started()``
T_contra = TypeVar("T_contra", contravariant=True)

# True when ``asyncio.Task`` has a ``get_name`` attribute, i.e. tasks can be
# named through asyncio itself
_native_task_names = hasattr(asyncio.Task, "get_name")

# Caches the task found by ``find_root_task()``
# NOTE(review): presumably scoped per event loop by RunVar -- confirm in
# ``lowlevel.RunVar``
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
|
|
|
|
|
def find_root_task() -> asyncio.Task:
    """
    Locate the task considered to be the root of the current event loop run.

    The lookup proceeds in three stages:

    1. return the task cached in ``_root_task`` if it is still running;
    2. otherwise look for an unfinished task carrying asyncio's
       ``_run_until_complete_cb`` done-callback (or a callback whose module is
       ``uvloop.loop``), cache it and return it;
    3. otherwise walk up the cancel scope chain of the current task and return
       the host task of the outermost scope, falling back to the current task
       itself.
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Stage 2: a task with the run_until_complete() callback attached
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        for callback in _get_task_callbacks(candidate):
            if (
                callback is _run_until_complete_cb
                or getattr(callback, "__module__", None) == "uvloop.loop"
            ):
                _root_task.set(candidate)
                return candidate

    # Stage 3: the host task of the outermost cancel scope of the current task
    current = cast(asyncio.Task, current_task())
    state = _task_states.get(current)
    if not state:
        return current

    scope = state.cancel_scope
    while scope and scope._parent_scope is not None:
        scope = scope._parent_scope

    if scope is None:
        return current

    return cast(asyncio.Task, scope._host_task)
|
|
|
|
|
def get_callable_name(func: Callable) -> str:
    """
    Build a dotted name for *func* from its ``__module__`` and ``__qualname__``.

    Attributes that are missing or falsy are left out, so the result may be
    just one of the two parts, or an empty string when neither is available.
    """
    parts = []
    for attribute in ("__module__", "__qualname__"):
        value = getattr(func, attribute, None)
        if value:
            parts.append(value)

    return ".".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Weakly keyed storage, so an entry disappears once its key is garbage
# collected.
# NOTE(review): presumably keyed by event loop and consumed by RunVar --
# confirm against the RunVar implementation
_run_vars = (
    WeakKeyDictionary()
)

# The "current token" of this backend is simply the running event loop
current_token = get_running_loop
|
|
|
|
|
def _task_started(task: asyncio.Task) -> bool:
    """
    Tell whether *task* has begun executing and has not yet finished.

    The task's coroutine is inspected first as a native coroutine and, if that
    raises ``AttributeError``, as a generator.

    :raises Exception: if the object supports neither kind of inspection

    """
    coroutine = cast(Coroutine[Any, Any, Any], get_coro(task))
    live_coroutine_states = (CORO_RUNNING, CORO_SUSPENDED)
    live_generator_states = (GEN_RUNNING, GEN_SUSPENDED)
    try:
        return getcoroutinestate(coroutine) in live_coroutine_states
    except AttributeError:
        # Not a native coroutine object; try it as a generator instead
        try:
            generator_state = getgeneratorstate(cast(Generator, coroutine))
            return generator_state in live_generator_states
        except AttributeError:
            raise Exception(f"Cannot determine if task {task} has started or not")
|
|
|
|
|
def _maybe_set_event_loop_policy( |
|
policy: asyncio.AbstractEventLoopPolicy | None, use_uvloop: bool |
|
) -> None: |
|
|
|
|
|
if policy is None and use_uvloop and sys.implementation.name == "cpython": |
|
try: |
|
import uvloop |
|
except ImportError: |
|
pass |
|
else: |
|
|
|
if not hasattr( |
|
asyncio.AbstractEventLoop, "shutdown_default_executor" |
|
) or hasattr(uvloop.loop.Loop, "shutdown_default_executor"): |
|
policy = uvloop.EventLoopPolicy() |
|
|
|
if policy is not None: |
|
asyncio.set_event_loop_policy(policy) |
|
|
|
|
|
def run(
    func: Callable[..., Awaitable[T_Retval]],
    *args: object,
    debug: bool = False,
    use_uvloop: bool = False,
    policy: asyncio.AbstractEventLoopPolicy | None = None,
) -> T_Retval:
    """
    Run ``func(*args)`` to completion in a new event loop via ``asyncio.run()``.

    :param func: the async function to call
    :param args: positional arguments passed to ``func``
    :param debug: forwarded to ``asyncio.run()``
    :param use_uvloop: passed on to ``_maybe_set_event_loop_policy()``
    :param policy: passed on to ``_maybe_set_event_loop_policy()``
    :return: whatever ``func`` returned

    """

    @wraps(func)
    async def wrapper() -> T_Retval:
        # Register a task state for the root task for the duration of the call
        root_task = cast(asyncio.Task, current_task())
        root_state = TaskState(None, get_callable_name(func), None)
        _task_states[root_task] = root_state
        if _native_task_names:
            root_task.set_name(root_state.name)

        try:
            result = await func(*args)
        finally:
            del _task_states[root_task]

        return result

    _maybe_set_event_loop_policy(policy, use_uvloop)
    entry_point = wrapper()
    return native_run(entry_point, debug=debug)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sleeping is delegated directly to asyncio
sleep = asyncio.sleep

# The cancellation exception of this backend is asyncio's own
CancelledError = asyncio.CancelledError
|
|
|
|
|
class CancelScope(BaseCancelScope):
    """
    Cancel scope implementation for the asyncio backend.

    A scope tracks the tasks it contains (``_tasks``), the task that entered it
    (``_host_task``) and the scope that was current in the host task when it
    was entered (``_parent_scope``). Once cancelled, it keeps calling
    ``Task.cancel()`` on eligible tasks, rescheduling itself on the event loop
    for as long as it still contains tasks that are not shielded from it.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # The arguments are consumed by __init__(); allocation goes straight to
        # object.__new__() rather than through the base class
        instance = object.__new__(cls)
        return instance

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        # Configuration
        self._deadline = deadline
        self._shield = shield
        # Position in the host task's scope stack
        self._active = False
        self._parent_scope: CancelScope | None = None
        self._host_task: asyncio.Task | None = None
        self._tasks: set[asyncio.Task] = set()
        # Cancellation state
        self._cancel_called = False
        self._timeout_expired = False
        self._cancel_calls: int = 0
        # Callbacks currently scheduled on the event loop
        self._timeout_handle: asyncio.TimerHandle | None = None
        self._cancel_handle: asyncio.Handle | None = None

    def __enter__(self) -> CancelScope:
        """
        Make this scope the current cancel scope of the running task.

        :raises RuntimeError: if the scope is already active

        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        host_task = cast(asyncio.Task, current_task())
        self._host_task = host_task
        self._tasks.add(host_task)
        try:
            host_state = _task_states[host_task]
        except KeyError:
            # No state registered for this task yet: create one that has this
            # scope as its current scope
            host_name = host_task.get_name() if _native_task_names else None
            _task_states[host_task] = TaskState(None, host_name, self)
        else:
            # Stack this scope on top of the task's current scope
            self._parent_scope = host_state.cancel_scope
            host_state.cancel_scope = self

        self._timeout()
        self._active = True

        # cancel() may already have been called, either explicitly before
        # entering or by _timeout() above
        if self._cancel_called:
            self._deliver_cancellation()

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Leave the scope and decide whether a cancellation error is swallowed.

        :return: the result of ``_uncancel()`` when the exception consists only
            of cancellation errors attributable to this scope (its deadline
            expired, or it was cancelled while no unshielded parent was),
            otherwise ``None``
        :raises RuntimeError: if the scope is not active, is exited from another
            task, or is not the host task's current scope

        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_state = _task_states.get(self._host_task)
        if host_state is None or host_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._tasks.remove(self._host_task)

        # Pop this scope off the host task's scope stack
        host_state.cancel_scope = self._parent_scope

        # A shielded scope may have been holding back a cancelled parent scope
        if self._shield:
            self._deliver_cancellation_to_parent()

        if exc_val is None:
            return None

        if isinstance(exc_val, ExceptionGroup):
            raised = exc_val.exceptions
        else:
            raised = [exc_val]

        if not all(isinstance(exc, CancelledError) for exc in raised):
            return None

        if self._timeout_expired:
            return self._uncancel()

        if not self._cancel_called:
            # The cancellation did not come from this scope
            return None

        if not self._parent_cancelled():
            # This scope was cancelled while no unshielded parent was
            return self._uncancel()

        return None

    def _uncancel(self) -> bool:
        """
        Undo the ``Task.cancel()`` calls this scope made on its host task.

        :return: ``True`` on Python versions before 3.11 or when there is no
            host task; otherwise ``True`` only if the host task has no
            cancellation requests left after the uncancel calls

        """
        host_task = self._host_task
        if sys.version_info < (3, 11) or host_task is None:
            self._cancel_calls = 0
            return True

        # One uncancel() per cancel() call counted in _deliver_cancellation()
        for _ in range(self._cancel_calls):
            host_task.uncancel()

        self._cancel_calls = 0
        return not host_task.cancelling()

    def _timeout(self) -> None:
        """Cancel the scope if its deadline has passed, else schedule a recheck."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self._timeout_expired = True
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self) -> None:
        """
        Call ``cancel()`` on the eligible tasks contained in this scope.

        A task is considered when this scope can be reached from the task's
        current scope by following parent links without passing a shielded
        scope. If at least one such task exists, another delivery round is
        scheduled on the event loop.
        """
        retry_needed = False
        running_task = current_task()
        for task in self._tasks:
            if task._must_cancel:
                continue

            # Walk from the task's current scope up towards this one; the
            # ``else`` clause only runs if the walk actually reached this scope
            scope = _task_states[task].cancel_scope
            while scope is not self:
                if scope is None or scope._shield:
                    break

                scope = scope._parent_scope
            else:
                retry_needed = True
                if task is not running_task and (
                    task is self._host_task or _task_started(task)
                ):
                    self._cancel_calls += 1
                    task.cancel()

        if retry_needed:
            self._cancel_handle = get_running_loop().call_soon(
                self._deliver_cancellation
            )
        else:
            self._cancel_handle = None

    def _deliver_cancellation_to_parent(self) -> None:
        """
        Restart cancellation delivery in the outermost eligible parent scope.

        Parent scopes are examined up to and including the first shielded one;
        the farthest one that has been cancelled and has no delivery round
        scheduled is asked to deliver cancellation.
        """
        candidate: CancelScope | None = None
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called and scope._cancel_handle is None:
                candidate = scope

            # Scopes beyond a shielded scope are not examined
            if scope._shield:
                break

            scope = scope._parent_scope

        if candidate is not None:
            candidate._deliver_cancellation()

    def _parent_cancelled(self) -> bool:
        """Tell whether a parent scope not hidden behind a shield was cancelled."""
        scope = self._parent_scope
        while scope is not None and not scope._shield:
            if scope._cancel_called:
                return True

            scope = scope._parent_scope

        return False

    def cancel(self) -> DeprecatedAwaitable:
        """
        Cancel this scope; only the first call has an effect.

        Delivery of the cancellation starts right away if the scope has a host
        task.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation()

        return DeprecatedAwaitable(self.cancel)

    @property
    def deadline(self) -> float:
        """The deadline of this scope, as a value of the event loop clock."""
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        # Drop any timer scheduled for the previous deadline
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        """``True`` once ``cancel()`` has been called on this scope."""
        return self._cancel_called

    @property
    def shield(self) -> bool:
        """Whether this scope is shielded."""
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Removing the shield may let a cancelled parent scope through
            if not value:
                self._deliver_cancellation_to_parent()
|
|
|
|
|
async def checkpoint() -> None:
    """Yield control to the event loop once by sleeping for zero seconds."""
    await sleep(0)
|
|
|
|
|
async def checkpoint_if_cancelled() -> None:
    """
    Yield to the event loop only if an enclosing cancel scope was cancelled.

    The cancel scopes of the current task are examined from the innermost
    outwards, stopping at the first shielded scope. While the scope being
    examined has been cancelled, the function keeps sleeping for zero seconds.
    Nothing happens when there is no current task or the task has no
    registered state.
    """
    task = current_task()
    if task is None:
        return

    try:
        scope = _task_states[task].cancel_scope
    except KeyError:
        return

    while scope:
        if scope.cancel_called:
            await sleep(0)
            continue

        if scope.shield:
            break

        scope = scope._parent_scope
|
|
|
|
|
async def cancel_shielded_checkpoint() -> None:
    """Yield to the event loop once from inside a shielded cancel scope."""
    shielded_scope = CancelScope(shield=True)
    with shielded_scope:
        await sleep(0)
|
|
|
|
|
def current_effective_deadline() -> float:
    """
    Return the nearest deadline among the cancel scopes of the current task.

    The scopes are examined from the innermost outwards, up to and including
    the first shielded one.

    :return: the smallest deadline found, ``-math.inf`` if one of the examined
        scopes has been cancelled, or ``math.inf`` if there is no current task
        or the task has no registered state

    """
    task = current_task()
    if task is None:
        # The task state mapping is weakly keyed and a weak reference to None
        # cannot be created, so the lookup below would raise TypeError
        return math.inf

    try:
        cancel_scope = _task_states[task].cancel_scope
    except KeyError:
        return math.inf

    deadline = math.inf
    while cancel_scope:
        deadline = min(deadline, cancel_scope.deadline)
        if cancel_scope._cancel_called:
            # A cancelled scope means the deadline has effectively passed
            deadline = -math.inf
            break
        elif cancel_scope.shield:
            # Scopes outside a shielded scope have no effect inside it
            break
        else:
            cancel_scope = cancel_scope._parent_scope

    return deadline
|
|
|
|
|
def current_time() -> float:
    """Return the current value of the running event loop's clock."""
    loop = get_running_loop()
    return loop.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TaskState:
    """
    Holds per-task bookkeeping data kept outside of the Task object itself,
    since nothing can be assumed about how Task is implemented.

    The attributes are the id of the parent task (if any), the task's name
    (if any) and the task's current cancel scope (if any).
    """

    __slots__ = ("parent_id", "name", "cancel_scope")

    def __init__(
        self,
        parent_id: int | None,
        name: str | None,
        cancel_scope: CancelScope | None,
    ):
        self.cancel_scope = cancel_scope
        self.name = name
        self.parent_id = parent_id
|
|
|
|
|
# Maps each known task to its TaskState; weakly keyed, so an entry goes away
# when its task is garbage collected
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExceptionGroup(BaseExceptionGroup):
    """Exception group raised by this backend, carrying a list of exceptions."""

    def __init__(self, exceptions: list[BaseException]):
        """
        :param exceptions: the exceptions to store in the ``exceptions``
            attribute of the group

        """
        super().__init__()
        self.exceptions = exceptions
|
|
|
|
|
class _AsyncioTaskStatus(abc.TaskStatus):
    """
    Task status object given to a task that was launched with a status future.

    :param future: the future that receives the value passed to ``started()``
    :param parent_id: the parent id assigned to the task once it has started
    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._parent_id = parent_id
        self._future = future

    def started(self, value: T_contra | None = None) -> None:
        """
        Signal that the task has started, optionally passing back a value.

        :param value: the result to set on the status future
        :raises RuntimeError: if the future already has a result

        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            raise RuntimeError(
                "called 'started' twice on the same task status"
            ) from None

        # Record the parent given at construction in the calling task's state
        calling_task = cast(asyncio.Task, current_task())
        calling_state = _task_states[calling_task]
        calling_state.parent_id = self._parent_id
|
|
|
|
|
class TaskGroup(abc.TaskGroup):
    """
    Task group implementation for the asyncio backend.

    Child tasks are tracked in the group's cancel scope. Exceptions raised by
    the children (and by the host task inside the ``async with`` block) are
    collected and raised when the group is exited.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        self._exceptions: list[BaseException] = []

    async def __aenter__(self) -> TaskGroup:
        """Enter the group's cancel scope and allow tasks to be started."""
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Wait for all child tasks to finish and raise any collected exceptions.

        :return: the return value of the cancel scope's ``__exit__()``
        :raises ExceptionGroup: if more than one exception was collected
            (unless they are all argument-less cancellation errors, in which
            case a single ``CancelledError`` is raised)

        """
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            # The host task failed: cancel the children and record the error
            self.cancel_scope.cancel()
            self._exceptions.append(exc_val)

        while self.cancel_scope._tasks:
            try:
                await asyncio.wait(self.cancel_scope._tasks)
            except asyncio.CancelledError:
                # The host task was cancelled while waiting: pass the
                # cancellation on to the children and keep waiting for them
                self.cancel_scope.cancel()

        self._active = False
        if not self.cancel_scope._parent_cancelled():
            exceptions = self._filter_cancellation_errors(self._exceptions)
        else:
            exceptions = self._exceptions

        try:
            if len(exceptions) > 1:
                if all(
                    isinstance(e, CancelledError) and not e.args for e in exceptions
                ):
                    # Nothing but cancellation errors without arguments
                    raise CancelledError
                else:
                    raise ExceptionGroup(exceptions)
            elif exceptions and exceptions[0] is not exc_val:
                raise exceptions[0]
        except BaseException as exc:
            # Drop the implicit exception context before the exception
            # propagates out of this method
            exc.__context__ = None
            raise

        return ignore_exception

    @staticmethod
    def _filter_cancellation_errors(
        exceptions: Sequence[BaseException],
    ) -> list[BaseException]:
        """
        Return *exceptions* without argument-less cancellation errors.

        Nested exception groups are filtered recursively. A group left with a
        single exception is replaced by that exception, a group left empty is
        dropped, and a group that lost only some of its members is replaced by
        a new group holding the survivors (with the original's cause, context
        and traceback). Groups that were not changed are kept as they are.
        """
        filtered_exceptions: list[BaseException] = []
        for exc in exceptions:
            if isinstance(exc, ExceptionGroup):
                new_exceptions = TaskGroup._filter_cancellation_errors(exc.exceptions)
                if len(new_exceptions) == 1:
                    filtered_exceptions.append(new_exceptions[0])
                elif len(new_exceptions) > 1:
                    unchanged = len(new_exceptions) == len(exc.exceptions) and all(
                        new is old
                        for new, old in zip(new_exceptions, exc.exceptions)
                    )
                    if unchanged:
                        filtered_exceptions.append(exc)
                    else:
                        # Something was filtered out: build a replacement
                        # group that keeps the original's metadata
                        new_exc = ExceptionGroup(new_exceptions)
                        new_exc.__cause__ = exc.__cause__
                        new_exc.__context__ = exc.__context__
                        # Assigning __cause__ switches context suppression on,
                        # so restore the original's setting
                        new_exc.__suppress_context__ = exc.__suppress_context__
                        new_exc.__traceback__ = exc.__traceback__
                        filtered_exceptions.append(new_exc)
            elif not isinstance(exc, CancelledError) or exc.args:
                filtered_exceptions.append(exc)

        return filtered_exceptions

    async def _run_wrapped_task(
        self, coro: Coroutine, task_status_future: asyncio.Future | None
    ) -> None:
        """
        Await *coro* and do the bookkeeping otherwise done by a done-callback.

        Used by ``_spawn()`` for coroutine-like objects without a frame
        attribute and on Python versions before 3.8.
        """
        # Marker variables asking traceback-filtering tools to hide this frame
        __traceback_hide__ = __tracebackhide__ = True
        task = cast(asyncio.Task, current_task())
        try:
            await coro
        except BaseException as exc:
            if task_status_future is None or task_status_future.done():
                self._exceptions.append(exc)
                self.cancel_scope.cancel()
            else:
                # The task failed before calling started(): report the error
                # to whoever awaits the status future
                task_status_future.set_exception(exc)
        else:
            if task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )
        finally:
            if task in self.cancel_scope._tasks:
                self.cancel_scope._tasks.remove(task)
                del _task_states[task]

    def _spawn(
        self,
        func: Callable[..., Awaitable[Any]],
        args: tuple,
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task running ``func(*args)`` inside this group.

        :param func: the async function to call
        :param args: positional arguments for ``func``
        :param name: name of the task; if ``None``, one is derived from ``func``
        :param task_status_future: if given, ``func`` also receives a
            ``task_status`` keyword argument bound to this future
        :return: the newly created task
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` did not return a coroutine

        """

        def task_done(_task: asyncio.Task) -> None:
            # Done-callback: unregister the task and collect its outcome
            assert _task in self.cancel_scope._tasks
            self.cancel_scope._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Follow the context chain to the earliest cancellation error
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                if task_status_future is None or task_status_future.done():
                    self._exceptions.append(exc)
                    self.cancel_scope.cancel()
                else:
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        options: dict[str, Any] = {}
        name = get_callable_name(func) if name is None else str(name)
        if _native_task_names:
            options["name"] = name

        kwargs = {}
        if task_status_future:
            # Until started() is called, the task spawning this one counts as
            # the parent; started() then switches it to the group's host task
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not asyncio.iscoroutine(coro):
            raise TypeError(
                f"Expected an async function, but {func} appears to be synchronous"
            )

        # Objects without a coroutine/generator frame attribute, and Python
        # versions before 3.8, get the wrapper coroutine instead of the
        # done-callback
        foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame")
        if foreign_coro or sys.version_info < (3, 8):
            coro = self._run_wrapped_task(coro, task_status_future)

        task = create_task(coro, **options)
        if not foreign_coro and sys.version_info >= (3, 8):
            task.add_done_callback(task_done)

        # Register the new task in the group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, name=name, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        return task

    def start_soon(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        """Start ``func(*args)`` as a child task without waiting for it."""
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """
        Start ``func(*args)`` as a child task and wait until it has started.

        :return: the value the task passed to ``task_status.started()``

        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # The wait is shielded, so cancellation of enclosing scopes is not
        # delivered to this task while it awaits the future
        with CancelScope(shield=True):
            try:
                return await future
            except CancelledError:
                task.cancel()
                raise
|
|
|
|
|
|
|
|
|
|
|
|
|
# Alias for a (return value, exception) pair in which either item may be None
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
|
|
|
|
|
class WorkerThread(Thread):
    """
    Thread that runs synchronous callables on behalf of an event loop.

    Work items arrive through ``queue`` as ``(context, func, args, future)``
    tuples; ``None`` tells the thread to exit. Results are reported back to the
    event loop through ``call_soon_threadsafe()``.
    """

    # Idle time, measured with the event loop clock, after which an idle
    # worker is pruned by run_sync_in_worker_thread()
    MAX_IDLE_TIME = 10

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.loop = root_task._loop
        self.workers = workers
        self.idle_workers = idle_workers
        # Holds at most two pending items
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future] | None
        ] = Queue(2)
        self.stopping = False
        self.idle_since = current_time()

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """
        Mark the worker idle again and transfer the outcome to *future*.

        Scheduled on the event loop by ``run()`` after a work item completes.
        """
        self.idle_since = current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if future.cancelled():
            return

        if exc is None:
            future.set_result(result)
            return

        if isinstance(exc, StopIteration):
            # StopIteration is passed on wrapped in a RuntimeError
            wrapped = RuntimeError("coroutine raised StopIteration")
            wrapped.__cause__ = exc
            exc = wrapped

        future.set_exception(exc)

    def run(self) -> None:
        """Process queued work items until the ``None`` sentinel is received."""
        with claim_worker_thread("asyncio"):
            threadlocals.loop = self.loop
            # iter() with a sentinel ends the loop when None is dequeued
            for job in iter(self.queue.get, None):
                context, func, args, future = job
                if not future.cancelled():
                    outcome = None
                    failure: BaseException | None = None
                    try:
                        outcome = context.run(func, *args)
                    except BaseException as exc:
                        failure = exc

                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, outcome, failure
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """
        Tell the thread to exit and remove it from the worker collections.

        :param f: ignored; lets the method be used as a task done-callback

        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            # The worker was not idle
            pass
|
|
|
|
|
# Deque of worker threads currently waiting for work; set up on first use by
# run_sync_in_worker_thread()
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
# Set of worker threads started by run_sync_in_worker_thread()
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
|
|
|
|
|
async def run_sync_in_worker_thread(
    func: Callable[..., T_Retval],
    *args: object,
    cancellable: bool = False,
    limiter: CapacityLimiter | None = None,
) -> T_Retval:
    """
    Call ``func(*args)`` in a worker thread and return its result.

    :param func: the synchronous callable to run
    :param args: positional arguments for ``func``
    :param cancellable: if ``False`` (the default), waiting for the result
        happens inside a shielded cancel scope
    :param limiter: capacity limiter held while the call is in progress; the
        default thread limiter is used when omitted
    :return: the return value of ``func``

    """
    await checkpoint()

    # Set up the worker collections the first time they are looked up
    try:
        idle_workers = _threadpool_idle_workers.get()
        workers = _threadpool_workers.get()
    except LookupError:
        idle_workers = deque()
        workers = set()
        _threadpool_idle_workers.set(idle_workers)
        _threadpool_workers.set(workers)

    effective_limiter = limiter or current_default_thread_limiter()
    async with effective_limiter:
        with CancelScope(shield=not cancellable):
            result_future: asyncio.Future = asyncio.Future()
            root_task = find_root_task()
            if idle_workers:
                # Reuse the worker that became idle most recently
                worker = idle_workers.pop()

                # Stop the workers that have been idle the longest, as long as
                # they have reached the maximum idle time
                now = current_time()
                while idle_workers:
                    oldest = idle_workers[0]
                    if now - oldest.idle_since < WorkerThread.MAX_IDLE_TIME:
                        break

                    idle_workers.popleft()
                    oldest.root_task.remove_done_callback(oldest.stop)
                    oldest.stop()
            else:
                # No idle worker available: start a new one which is stopped
                # when the root task finishes
                worker = WorkerThread(root_task, workers, idle_workers)
                worker.start()
                workers.add(worker)
                root_task.add_done_callback(worker.stop)

            # Run the callable in a copy of the current context in which
            # sniffio's current async library variable is set to None
            context = copy_context()
            context.run(sniffio.current_async_library_cvar.set, None)
            worker.queue.put_nowait((context, func, args, result_future))
            return await result_future
|
|
|
|
|
def run_sync_from_thread(
    func: Callable[..., T_Retval],
    *args: object,
    loop: asyncio.AbstractEventLoop | None = None,
) -> T_Retval:
    """
    Call ``func(*args)`` in the event loop thread and wait for its result.

    :param func: the callable to run in the event loop thread
    :param args: positional arguments for ``func``
    :param loop: the event loop to schedule the call on; defaults to the loop
        stored in the calling thread's ``threadlocals``
    :return: the return value of ``func``; an exception raised by ``func`` is
        raised from this function instead

    """
    outcome: concurrent.futures.Future[T_Retval] = Future()

    @wraps(func)
    def wrapper() -> None:
        try:
            outcome.set_result(func(*args))
        except BaseException as exc:
            outcome.set_exception(exc)
            # Anything that is not an Exception subclass is also re-raised
            # where the wrapper runs
            if not isinstance(exc, Exception):
                raise

    target_loop = loop or threadlocals.loop
    target_loop.call_soon_threadsafe(wrapper)
    return outcome.result()
|
|
|
|
|
def run_async_from_thread(
    func: Callable[..., Awaitable[T_Retval]], *args: object
) -> T_Retval:
    """
    Run ``func(*args)`` in the event loop stored in the calling thread's
    ``threadlocals`` and block until its result is available.

    :return: the result of the coroutine

    """
    coroutine = func(*args)
    outcome: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe(
        coroutine, threadlocals.loop
    )
    return outcome.result()
|
|
|
|
|
class BlockingPortal(abc.BlockingPortal):
    """Blocking portal bound to the event loop running at construction time."""

    def __new__(cls) -> BlockingPortal:
        # Allocate via object.__new__() rather than through the base class
        instance = object.__new__(cls)
        return instance

    def __init__(self) -> None:
        super().__init__()
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable,
        args: tuple,
        kwargs: dict[str, Any],
        name: object,
        future: Future,
    ) -> None:
        """
        Start a task in the portal's task group from another thread.

        The task group's ``start_soon()`` is invoked in the portal's event loop
        with ``self._call_func`` as the task function and ``func``, ``args``,
        ``kwargs`` and ``future`` as its arguments.
        """
        start_soon = partial(self._task_group.start_soon, name=name)
        run_sync_from_thread(
            start_soon,
            self._call_func,
            func,
            args,
            kwargs,
            future,
            loop=self._loop,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Byte receive stream backed by an ``asyncio.StreamReader``."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read up to *max_bytes* bytes from the wrapped reader.

        :raises EndOfStream: if the reader returned no data

        """
        data = await self._stream.read(max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def aclose(self) -> None:
        """Feed an end-of-file marker to the wrapped reader."""
        self._stream.feed_eof()
|
|
|
|
|
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Byte send stream backed by an ``asyncio.StreamWriter``."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        """Write *item* to the wrapped writer and wait for it to drain."""
        writer = self._stream
        writer.write(item)
        await writer.drain()

    async def aclose(self) -> None:
        """Close the wrapped writer."""
        self._stream.close()
|
|
|
|
|
@dataclass(eq=False)
class Process(abc.Process):
    """
    Wrapper around an asyncio subprocess and its standard stream wrappers.

    Any of the three stream attributes may be ``None``.
    """

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close the available standard streams, then wait for the process."""
        for stream in (self._stdin, self._stdout, self._stderr):
            if stream:
                await stream.aclose()

        await self.wait()

    async def wait(self) -> int:
        """Wait for the process to exit and return its return code."""
        return_code = await self._process.wait()
        return return_code

    def terminate(self) -> None:
        """Call ``terminate()`` on the underlying process."""
        self._process.terminate()

    def kill(self) -> None:
        """Call ``kill()`` on the underlying process."""
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        """Send the given signal number to the underlying process."""
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        """The process ID of the underlying process."""
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        """The ``returncode`` attribute of the underlying process."""
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        """The stream wrapping standard input, if any."""
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        """The stream wrapping standard output, if any."""
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        """The stream wrapping standard error, if any."""
        return self._stderr
|
|
|
|
|
async def open_process(
    command: str | bytes | Sequence[str | bytes],
    *,
    shell: bool,
    stdin: int | IO[Any] | None,
    stdout: int | IO[Any] | None,
    stderr: int | IO[Any] | None,
    cwd: str | bytes | PathLike | None = None,
    env: Mapping[str, str] | None = None,
    start_new_session: bool = False,
) -> Process:
    """
    Start a subprocess and wrap it in a :class:`Process`.

    :param command: the command to run; handed to the shell as is if ``shell``
        is true, otherwise unpacked as the program and its arguments
    :param shell: whether to run the command through the shell
    :param stdin: forwarded to the asyncio subprocess call
    :param stdout: forwarded to the asyncio subprocess call
    :param stderr: forwarded to the asyncio subprocess call
    :param cwd: forwarded to the asyncio subprocess call
    :param env: forwarded to the asyncio subprocess call
    :param start_new_session: forwarded to the asyncio subprocess call
    :return: a process object whose stream attributes wrap those standard
        streams the asyncio process object provides

    """
    await checkpoint()
    options: dict[str, Any] = dict(
        stdin=stdin,
        stdout=stdout,
        stderr=stderr,
        cwd=cwd,
        env=env,
        start_new_session=start_new_session,
    )
    if shell:
        process = await asyncio.create_subprocess_shell(
            cast(Union[str, bytes], command), **options
        )
    else:
        process = await asyncio.create_subprocess_exec(*command, **options)

    stdin_stream = None
    if process.stdin:
        stdin_stream = StreamWriterWrapper(process.stdin)

    stdout_stream = None
    if process.stdout:
        stdout_stream = StreamReaderWrapper(process.stdout)

    stderr_stream = None
    if process.stderr:
        stderr_stream = StreamReaderWrapper(process.stderr)

    return Process(process, stdin_stream, stdout_stream, stderr_stream)
|
|
|
|
|
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shut down worker processes belonging to this event loop.

    Meant to be used as a done-callback (with ``workers`` bound in advance);
    the ``_task`` argument is ignored.
    """
    child_watcher: asyncio.AbstractChildWatcher | None
    try:
        child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
    except NotImplementedError:
        child_watcher = None

    # Only synchronous calls are made here, as this is a plain function
    for process in workers:
        # NOTE(review): processes whose returncode is still None are skipped,
        # so only processes that already have a return code are touched below
        # -- confirm that this is the intended condition
        if process.returncode is None:
            continue

        for stream in (process._stdin, process._stdout, process._stderr):
            stream._stream._transport.close()

        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
|
|
|
|
|
async def _shutdown_process_pool_on_exit(workers: set[Process]) -> None:
    """
    Shut down the worker processes belonging to this event loop.

    The coroutine sleeps indefinitely; when it is cancelled, it kills every
    worker process that has no return code yet and then closes all of them.

    NOTE: this only works when the event loop was started using asyncio.run()
    or anyio.run().

    """
    process: Process
    try:
        await sleep(math.inf)
    except asyncio.CancelledError:
        still_running = [worker for worker in workers if worker.returncode is None]
        for process in still_running:
            process.kill()

        for process in workers:
            await process.aclose()
|
|
|
|
|
def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None:
    """
    Arrange for the given worker processes to be shut down with the event loop.

    A task running ``_shutdown_process_pool_on_exit()`` is created, and
    ``_forcibly_shutdown_process_pool_on_exit()`` is registered as a
    done-callback of the root task.
    """
    task_options: dict[str, Any] = {}
    if _native_task_names:
        task_options["name"] = "AnyIO process pool shutdown task"

    create_task(_shutdown_process_pool_on_exit(workers), **task_options)
    root_task = find_root_task()
    root_task.add_done_callback(
        partial(_forcibly_shutdown_process_pool_on_exit, workers)
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamProtocol(asyncio.Protocol):
    """
    asyncio protocol that buffers received data for a stream object.

    Incoming chunks are appended to ``read_queue``. ``read_event`` is set when
    there is something for a reader to act on (data, EOF or a lost
    connection), and ``write_event`` is set while writing is not paused.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    # Set to a BrokenResourceError if the connection was lost due to an error
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Create the queue and events and set the write buffer limit to 0."""
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        """Record the error, if any, and wake up pending readers and writers."""
        if exc:
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        """Queue the received chunk and signal that data is available."""
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        """Wake up a pending reader; always returns ``True``."""
        self.read_event.set()
        return True

    def pause_writing(self) -> None:
        """Replace ``write_event`` with a new event that is not set."""
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        """Set the current ``write_event``."""
        self.write_event.set()
|
|
|
|
|
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    asyncio protocol that buffers received datagrams for a socket object.

    Received ``(data, address)`` pairs are kept in ``read_queue``, which holds
    at most 100 entries.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    # The last error passed to error_received(), if any
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Create the bounded queue and the read/write events."""
        self.read_queue = deque(maxlen=100)
        self.read_event = asyncio.Event()
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable

    def connection_lost(self, exc: Exception | None) -> None:
        """Wake up pending readers and writers."""
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        """Queue the datagram with its converted source address."""
        source = convert_ipv6_sockaddr(addr)
        self.read_queue.append((data, source))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        """Remember the reported error."""
        self.exception = exc

    def pause_writing(self) -> None:
        """Clear ``write_event``."""
        self.write_event.clear()

    def resume_writing(self) -> None:
        """Set ``write_event``."""
        self.write_event.set()
|
|
|
|
|
class SocketStream(abc.SocketStream):
    """Socket stream built on an asyncio transport/protocol pair."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._closed = False
        # Guards used around receive() and send() respectively
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket object reported by the transport's extra info."""
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next chunk of received data, at most *max_bytes* long.

        :raises ClosedResourceError: if no data is left and this stream was
            closed
        :raises EndOfStream: if no data is left and no error was recorded

        """
        protocol = self._protocol
        transport = self._transport
        with self._receive_guard:
            await checkpoint()

            # Reading on the transport is only resumed while waiting for data
            if not protocol.read_event.is_set() and not transport.is_closing():
                transport.resume_reading()
                await protocol.read_event.wait()
                transport.pause_reading()

            try:
                chunk = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                if protocol.exception:
                    raise protocol.exception

                raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Return only the requested amount and put the rest back at
                # the front of the queue
                leftover = chunk[max_bytes:]
                chunk = chunk[:max_bytes]
                protocol.read_queue.appendleft(leftover)

            # With nothing left in the queue, the next call has to wait
            if not protocol.read_queue:
                protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write *item* to the transport and wait until writing is not paused.

        :raises ClosedResourceError: if this stream was closed
        :raises BrokenResourceError: if the write fails with a ``RuntimeError``
            while the transport is closing

        """
        with self._send_guard:
            await checkpoint()

            if self._closed:
                raise ClosedResourceError

            if self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if not self._transport.is_closing():
                    raise

                raise BrokenResourceError from exc

            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        """Write an end-of-file marker to the transport, ignoring ``OSError``."""
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """
        Close the stream unless the transport is already closing.

        An end-of-file marker is written first (an ``OSError`` from that is
        ignored); the transport is then closed, and aborted after yielding to
        the event loop once.
        """
        if self._transport.is_closing():
            return

        self._closed = True
        try:
            self._transport.write_eof()
        except OSError:
            pass

        self._transport.close()
        await sleep(0)
        self._transport.abort()
|
|
|
|
|
class UNIXSocketStream(abc.SocketStream):
    """
    Byte stream operating directly on a raw socket, with file descriptor passing.

    Each operation is attempted on the socket right away; when it raises
    ``BlockingIOError``, the event loop is asked to report readiness through a future
    and the operation is retried afterwards.
    """

    # Future of the receive/send operation currently waiting for readiness, if any.
    # aclose() resolves these so that the waiting operation retries on the closed socket.
    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    # Set by aclose(); makes OSErrors surface as ClosedResourceError
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        """
        :param raw_socket: the socket to wrap (NOTE(review): presumably already
            non-blocking, as the methods rely on ``BlockingIOError`` — confirm)
        """
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        """The wrapped socket object."""
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """
        Return a future that is resolved once the socket is reported readable.

        :param loop: the event loop used for removing the reader again
        """

        def callback(f: object) -> None:
            # Drop the instance attribute (the class-level None shows through again)
            # and stop watching the socket
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        f = self._receive_future = asyncio.Future()
        self._loop.add_reader(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """
        Return a future that is resolved once the socket is reported writable.

        :param loop: the event loop used for removing the writer again
        """

        def callback(f: object) -> None:
            # Drop the instance attribute (the class-level None shows through again)
            # and stop watching the socket
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        f = self._send_future = asyncio.Future()
        self._loop.add_writer(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    async def send_eof(self) -> None:
        """Shut down the writing side of the socket."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes from the socket.

        :param max_bytes: maximum number of bytes to read
        :raises EndOfStream: if the socket returned no data
        :raises ClosedResourceError: if the read failed after :meth:`aclose` was called
        :raises BrokenResourceError: if the read failed with any other ``OSError``
        """
        loop = get_running_loop()
        await checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self.__raw_socket.recv(max_bytes)
                except BlockingIOError:
                    # No data yet; wait for readiness and try again
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not data:
                        raise EndOfStream

                    return data

    async def send(self, item: bytes) -> None:
        """
        Send all of ``item``, retrying until every byte has been written.

        :param item: the bytes to send
        :raises ClosedResourceError: if the write failed after :meth:`aclose` was called
        :raises BrokenResourceError: if the write failed with any other ``OSError``
        """
        loop = get_running_loop()
        await checkpoint()
        with self._send_guard:
            # The view is sliced after each partial send, so the loop ends when empty
            view = memoryview(item)
            while view:
                try:
                    bytes_sent = self.__raw_socket.send(view)
                except BlockingIOError:
                    # Cannot write right now; wait for readiness and try again
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    view = view[bytes_sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message along with file descriptors sent as ancillary data.

        :param msglen: maximum number of message bytes to read (non-negative integer)
        :param maxfds: maximum number of file descriptors to accept (positive integer)
        :return: a tuple of (message bytes, list of received file descriptors)
        :raises ValueError: if ``msglen`` or ``maxfds`` is out of range or not an integer
        :raises EndOfStream: if neither message data nor ancillary data was received
        :raises RuntimeError: if ancillary data other than ``SCM_RIGHTS`` was received
        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await checkpoint()
        with self._receive_guard:
            while True:
                try:
                    # The ancillary buffer is sized for maxfds C ints
                    message, ancdata, flags, addr = self.__raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not message and not ancdata:
                        raise EndOfStream

                    break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore trailing bytes that do not make up a whole array item
            fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send a message along with file descriptors as ``SCM_RIGHTS`` ancillary data.

        :param message: the (non-empty) message to send
        :param fds: file descriptors, given as integers or ``IOBase`` objects; items of
            any other type are skipped
        :raises ValueError: if ``message`` or ``fds`` is empty
        :raises ClosedResourceError: if sending failed after :meth:`aclose` was called
        :raises BrokenResourceError: if sending failed with any other ``OSError``
        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = []
        for fd in fds:
            if isinstance(fd, int):
                filenos.append(fd)
            elif isinstance(fd, IOBase):
                filenos.append(fd.fileno())

        fdarray = array.array("i", filenos)
        await checkpoint()
        with self._send_guard:
            while True:
                try:
                    self.__raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                    break
                except BlockingIOError:
                    # Cannot write right now; wait for readiness and try again
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Close the socket (once) and resolve the futures of pending operations."""
        if not self._closing:
            self._closing = True
            # fileno() is -1 once the socket object has been closed
            if self.__raw_socket.fileno() != -1:
                self.__raw_socket.close()

            # Resolve pending readiness futures so the waiting operations run again
            if self._receive_future:
                self._receive_future.set_result(None)
            if self._send_future:
                self._send_future.set_result(None)
|
|
|
|
|
class TCPSocketListener(abc.SocketListener):
    """Listener that accepts TCP connections and wraps them in :class:`SocketStream`."""

    # Cancel scope of the accept() call currently in progress; None when idle
    _accept_scope: CancelScope | None = None
    # Set by aclose()
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        """
        :param raw_socket: the listening socket to accept connections from
        """
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        """The wrapped listening socket."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept the next incoming connection.

        :return: a stream wrapping the accepted connection
        :raises ClosedResourceError: if the listener is closed, or gets closed while
            waiting for a connection
        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await checkpoint()
            # Expose the scope on the instance so that aclose() can cancel this call
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # NOTE(review): appears to work around sock_accept() leaving its
                    # reader registered on the socket after cancellation — confirm
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    # A cancellation triggered by aclose() is reported as a closed
                    # resource instead
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listener, cancelling an accept() call in progress; idempotent."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # NOTE(review): same reader clean-up as in accept() — confirm the
            # underlying event loop issue it works around
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            # Yield to the event loop once before the socket is closed below
            await sleep(0)

        self._raw_socket.close()
|
|
|
|
|
class UNIXSocketListener(abc.SocketListener):
    """Listener that accepts connections on a raw socket as :class:`UNIXSocketStream`."""

    def __init__(self, raw_socket: socket.socket):
        self._closed = False
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        """The wrapped listening socket."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept the next incoming connection.

        :return: a stream wrapping the accepted socket, switched to non-blocking mode
        :raises ClosedResourceError: if accepting failed after :meth:`aclose` was called
        :raises BrokenResourceError: if accepting failed with any other ``OSError``
        """

        def stop_watching(_: object) -> None:
            self._loop.remove_reader(self.__raw_socket)

        await checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    # No connection pending; wait until the socket is readable, then retry
                    readable: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(
                        self.__raw_socket, readable.set_result, None
                    )
                    readable.add_done_callback(stop_watching)
                    await readable
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Mark the listener as closed and close the socket."""
        self._closed = True
        self.__raw_socket.close()
|
|
|
|
|
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket backed by an asyncio datagram transport/protocol pair."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._closed = False
        self._transport = transport
        self._protocol = protocol
        self._send_guard = ResourceGuard("writing to")
        self._receive_guard = ResourceGuard("reading from")

    @property
    def _raw_socket(self) -> socket.socket:
        """The ``"socket"`` entry of the transport's extra info."""
        transport = self._transport
        return transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next queued packet as a ``(data, address)`` tuple.

        :raises ClosedResourceError: if no packet is queued and the socket was closed
            via :meth:`aclose`
        :raises BrokenResourceError: if no packet is queued for any other reason
        """
        with self._receive_guard:
            await checkpoint()

            protocol = self._protocol
            if not (protocol.read_queue or self._transport.is_closing()):
                # Nothing buffered: reset the event and wait for the next packet
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """
        Send a packet once writing is not paused.

        :param item: the arguments for the transport's ``sendto()``, unpacked as given
        :raises ClosedResourceError: if the socket was closed via :meth:`aclose`
        :raises BrokenResourceError: if the transport is closing for another reason
        """
        with self._send_guard:
            await checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError

            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
|
|
|
|
|
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket backed by an asyncio datagram transport/protocol pair."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._closed = False
        self._transport = transport
        self._protocol = protocol
        self._send_guard = ResourceGuard("writing to")
        self._receive_guard = ResourceGuard("reading from")

    @property
    def _raw_socket(self) -> socket.socket:
        """The ``"socket"`` entry of the transport's extra info."""
        transport = self._transport
        return transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        """
        Return the payload of the next queued packet.

        :raises ClosedResourceError: if no packet is queued and the socket was closed
            via :meth:`aclose`
        :raises BrokenResourceError: if no packet is queued for any other reason
        """
        with self._receive_guard:
            await checkpoint()

            protocol = self._protocol
            if not (protocol.read_queue or self._transport.is_closing()):
                # Nothing buffered: reset the event and wait for the next packet
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                packet = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                raise BrokenResourceError from None
            else:
                # Queue entries are (data, address) tuples; only the data is returned
                return packet[0]

    async def send(self, item: bytes) -> None:
        """
        Send a payload once writing is not paused.

        :param item: the bytes to send
        :raises ClosedResourceError: if the socket was closed via :meth:`aclose`
        :raises BrokenResourceError: if the transport is closing for another reason
        """
        with self._send_guard:
            await checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError

            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
|
|
|
|
|
async def connect_tcp(
    host: str, port: int, local_addr: tuple[str, int] | None = None
) -> SocketStream:
    """
    Open a TCP connection and wrap it in a :class:`SocketStream`.

    :param host: host to connect to
    :param port: port to connect to
    :param local_addr: optional local address passed on to ``create_connection()``
    :return: the connected stream
    """
    loop = get_running_loop()
    endpoint = await loop.create_connection(
        StreamProtocol, host, port, local_addr=local_addr
    )
    transport, protocol = cast(Tuple[asyncio.Transport, StreamProtocol], endpoint)
    # Reading starts out paused; SocketStream.receive() resumes it when needed
    transport.pause_reading()
    return SocketStream(transport, protocol)
|
|
|
|
|
async def connect_unix(path: str) -> UNIXSocketStream:
    """
    Connect to a UNIX domain socket and wrap it in a :class:`UNIXSocketStream`.

    :param path: filesystem path of the socket to connect to
    :return: a stream wrapping the connected, non-blocking socket
    :raises BaseException: whatever connecting, or waiting for the socket to become
        writable, raised; the socket is closed before the exception propagates
    """
    await checkpoint()
    loop = get_running_loop()
    raw_socket = socket.socket(socket.AF_UNIX)
    raw_socket.setblocking(False)
    # The clean-up handler wraps the whole retry loop on purpose: an exception raised
    # while awaiting inside the ``except BlockingIOError`` handler (e.g. cancellation)
    # is not caught by a sibling ``except`` clause, which would leak the socket.
    try:
        while True:
            try:
                raw_socket.connect(path)
            except BlockingIOError:
                # Not connected yet; wait until the socket is writable, then retry
                f: asyncio.Future = asyncio.Future()
                loop.add_writer(raw_socket, f.set_result, None)
                f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                await f
            else:
                return UNIXSocketStream(raw_socket)
    except BaseException:
        raw_socket.close()
        raise
|
|
|
|
|
async def create_udp_socket(
    family: socket.AddressFamily,
    local_address: IPSockAddrType | None,
    remote_address: IPSockAddrType | None,
    reuse_port: bool,
) -> UDPSocket | ConnectedUDPSocket:
    """
    Create a datagram endpoint and wrap it in the matching UDP socket class.

    :param family: address family for the endpoint
    :param local_address: local address to bind to, if any
    :param remote_address: remote address to connect to, if any
    :param reuse_port: passed on to ``create_datagram_endpoint()``
    :return: a :class:`ConnectedUDPSocket` if ``remote_address`` is truthy, otherwise
        a :class:`UDPSocket`
    :raises Exception: the error recorded by the protocol during endpoint creation
    """
    loop = get_running_loop()
    transport, protocol = await loop.create_datagram_endpoint(
        DatagramProtocol,
        local_addr=local_address,
        remote_addr=remote_address,
        family=family,
        reuse_port=reuse_port,
    )
    if protocol.exception:
        # An error was already reported; do not hand out a broken socket
        transport.close()
        raise protocol.exception

    if remote_address:
        return ConnectedUDPSocket(transport, protocol)

    return UDPSocket(transport, protocol)
|
|
|
|
|
async def getaddrinfo(
    host: bytes | str,
    port: str | int | None,
    *,
    family: int | AddressFamily = 0,
    type: int | SocketKind = 0,
    proto: int = 0,
    flags: int = 0,
) -> GetAddrInfoReturnType:
    """
    Resolve a host/port pair through the running event loop's ``getaddrinfo()``.

    All arguments are forwarded unchanged; the result is only cast for type checking.
    """
    loop = get_running_loop()
    lookup = loop.getaddrinfo(
        host, port, family=family, type=type, proto=proto, flags=flags
    )
    return cast(GetAddrInfoReturnType, await lookup)
|
|
|
|
|
async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> tuple[str, str]:
    """
    Look up the name information for a socket address via the running event loop.

    :param sockaddr: the socket address to look up
    :param flags: flags forwarded to the event loop's ``getnameinfo()``
    """
    loop = get_running_loop()
    return await loop.getnameinfo(sockaddr, flags)
|
|
|
|
|
# RunVar-scoped registries that map a socket to the asyncio.Event its current waiter
# (wait_socket_readable() / wait_socket_writable()) is blocked on
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
|
|
|
|
|
async def wait_socket_readable(sock: socket.socket) -> None:
    """
    Wait until the given socket is reported readable by the event loop.

    :param sock: the socket to watch
    :raises BusyResourceError: if another waiter is already registered for this socket
    :raises ClosedResourceError: if the registration for this socket was removed by
        other code while waiting
    """
    await checkpoint()
    try:
        registry = _read_events.get()
    except LookupError:
        registry = {}
        _read_events.set(registry)

    if registry.get(sock):
        raise BusyResourceError("reading from") from None

    loop = get_running_loop()
    ready = asyncio.Event()
    registry[sock] = ready
    loop.add_reader(sock, ready.set)
    try:
        await ready.wait()
    finally:
        # Only undo the registration if it is still ours to undo
        still_registered = registry.pop(sock, None) is not None
        if still_registered:
            loop.remove_reader(sock)

    if not still_registered:
        raise ClosedResourceError
|
|
|
|
|
async def wait_socket_writable(sock: socket.socket) -> None:
    """
    Wait until the given socket is reported writable by the event loop.

    :param sock: the socket to watch
    :raises BusyResourceError: if another waiter is already registered for this socket
    :raises ClosedResourceError: if the registration for this socket was removed by
        other code while waiting
    """
    await checkpoint()
    try:
        write_events = _write_events.get()
    except LookupError:
        write_events = {}
        _write_events.set(write_events)

    if write_events.get(sock):
        raise BusyResourceError("writing to") from None

    loop = get_running_loop()
    event = write_events[sock] = asyncio.Event()
    # Register with the socket object itself (not its file descriptor number) so that
    # registration and removal use the same key, mirroring wait_socket_readable()
    loop.add_writer(sock, event.set)
    try:
        await event.wait()
    finally:
        # Only undo the registration if it is still ours to undo
        if write_events.pop(sock, None) is not None:
            loop.remove_writer(sock)
            writable = True
        else:
            writable = False

    if not writable:
        raise ClosedResourceError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Event(BaseEvent):
    """Event implementation that delegates to a wrapped :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Create the instance directly through ``object``, bypassing the base class
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> DeprecatedAwaitable:
        """Set the wrapped event."""
        inner = self._event
        inner.set()
        return DeprecatedAwaitable(self.set)

    def is_set(self) -> bool:
        """Return whether the wrapped event is set."""
        inner = self._event
        return inner.is_set()

    async def wait(self) -> None:
        """Wait for the wrapped event, then run a checkpoint if it reported truthy."""
        was_set = await self._event.wait()
        if was_set:
            await checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics holding the number of waiters on the wrapped event."""
        waiter_count = len(self._event._waiters)
        return EventStatistics(waiter_count)
|
|
|
|
|
class CapacityLimiter(BaseCapacityLimiter):
    """
    Limits how many borrowers may hold a token at the same time.

    Borrowers that cannot get a token immediately are queued in FIFO order and are
    notified through an :class:`asyncio.Event` when capacity becomes available. A
    notified borrower is removed from the wait queue by whoever notifies it.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Create the instance directly through ``object``, bypassing the base class
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        """
        :param total_tokens: the initial number of tokens (an ``int`` or ``math.inf``)
        """
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """The total number of tokens available for borrowing."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """
        Change the token count, waking up waiters if the count was increased.

        :raises TypeError: if ``value`` is neither an ``int`` nor infinite
        :raises ValueError: if ``value`` is less than 1
        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        old_value = self._total_tokens
        self._total_tokens = value

        # Only an increase frees up capacity. The comparison (rather than a plain
        # subtraction) also keeps inf -> inf from producing NaN.
        to_notify = value - old_value if value > old_value else 0

        # Notified waiters are removed from the queue right away, just like in
        # release_on_behalf_of(). Leaving them queued would make a later release pop an
        # already-notified entry and lose the wake-up for a borrower still waiting.
        while self._wait_queue and to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens currently borrowed."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """The number of tokens currently not borrowed."""
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> DeprecatedAwaitable:
        """Acquire a token for the current task without waiting."""
        self.acquire_on_behalf_of_nowait(current_task())
        return DeprecatedAwaitable(self.acquire_nowait)

    def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
        """
        Acquire a token for ``borrower`` without waiting.

        :raises RuntimeError: if ``borrower`` already holds a token
        :raises WouldBlock: if others are already waiting or no token is available
        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        # Queued waiters go first, even if a token happens to be free
        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)
        return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait)

    async def acquire(self) -> None:
        """Acquire a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower``, waiting in the queue if necessary.

        :raises RuntimeError: if ``borrower`` already holds a token
        """
        await checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                # Give up the place in the queue if it has not been handed out yet
                self._wait_queue.pop(borrower, None)
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await cancel_shielded_checkpoint()
            except BaseException:
                # Roll back the token taken above. It was acquired for ``borrower``,
                # which is not necessarily the current task.
                self.release_on_behalf_of(borrower)
                raise

    def release(self) -> None:
        """Release the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by ``borrower`` and wake up the next waiter, if any.

        :raises RuntimeError: if ``borrower`` does not hold a token
        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's " "tokens"
            ) from None

        # Notify the first waiter in line if there is free capacity now
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        """Return the borrowed/total token counts, the borrowers and the queue length."""
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
|
|
|
|
|
# Holds the limiter that current_default_thread_limiter() creates on first use
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
|
|
|
|
|
def current_default_thread_limiter() -> CapacityLimiter:
    """
    Return the default thread limiter, creating it with 40 tokens on first use.

    :return: the limiter stored in ``_default_thread_limiter``
    """
    try:
        limiter = _default_thread_limiter.get()
    except LookupError:
        limiter = CapacityLimiter(40)
        _default_thread_limiter.set(limiter)

    return limiter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]):
    """
    Context manager and async iterator that yields the signal numbers it receives.

    Signal handlers are installed on the event loop on entry and removed on exit;
    delivered signals are queued until the iterator consumes them.
    """

    def __init__(self, signals: tuple[int, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        self._signal_queue: deque[int] = deque()
        self._future: asyncio.Future = asyncio.Future()
        self._handled_signals: set[int] = set()

    def _deliver(self, signum: int) -> None:
        """Queue a received signal and wake up the iterator if it is waiting."""
        self._signal_queue.append(signum)
        waiter = self._future
        if not waiter.done():
            waiter.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        # Duplicates in the requested signals are collapsed by the set
        for signum in set(self._signals):
            self._loop.add_signal_handler(signum, self._deliver, signum)
            self._handled_signals.add(signum)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        # Only handlers that were actually installed are removed
        for signum in self._handled_signals:
            self._loop.remove_signal_handler(signum)

        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> int:
        """Return the next queued signal number, waiting for one if necessary."""
        await checkpoint()
        pending = self._signal_queue
        if not pending:
            # Install a fresh future for _deliver() to resolve
            waiter = self._future = asyncio.Future()
            await waiter

        return pending.popleft()
|
|
|
|
|
def open_signal_receiver(*signals: int) -> _SignalReceiver:
    """
    Create a receiver for the given signals.

    :param signals: the signal numbers to listen for
    :return: a new :class:`_SignalReceiver` for those signals
    """
    receiver = _SignalReceiver(signals)
    return receiver
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """
    Build a :class:`TaskInfo` for the given asyncio task.

    The name and parent id come from the task's entry in ``_task_states`` when there
    is one. Otherwise there is no parent id, and the name is the task's own name if
    ``_native_task_names`` is truthy, else ``None``.
    """
    state = _task_states.get(task)
    if state is not None:
        name, parent_id = state.name, state.parent_id
    elif _native_task_names:
        name, parent_id = task.get_name(), None
    else:
        name = parent_id = None

    return TaskInfo(id(task), parent_id, name, get_coro(task))
|
|
|
|
|
def get_current_task() -> TaskInfo:
    """Return a :class:`TaskInfo` describing the currently running task."""
    running_task = current_task()
    return _create_task_info(running_task)
|
|
|
|
|
def get_running_tasks() -> list[TaskInfo]:
    """Return a :class:`TaskInfo` for every task that has not finished yet."""
    infos: list[TaskInfo] = []
    for task in all_tasks():
        if task.done():
            continue

        infos.append(_create_task_info(task))

    return infos
|
|
|
|
|
async def wait_all_tasks_blocked() -> None:
    """
    Wait until every task other than the current one is waiting on a pending future.

    All tasks are re-checked every 0.1 seconds until none of them is found with a
    missing or already completed ``_fut_waiter``.
    """
    await checkpoint()
    me = current_task()
    while True:
        for task in all_tasks():
            if task is me:
                continue

            if task._fut_waiter is None or task._fut_waiter.done():
                # This task is not blocked (yet); check everything again later
                break
        else:
            return

        await sleep(0.1)
|
|
|
|
|
class TestRunner(abc.TestRunner):
    """
    Runs coroutine tests and fixtures on a dedicated asyncio event loop.

    Exceptions reported to the loop's exception handler are collected and re-raised
    after the fixture or test that was running has finished.
    """

    def __init__(
        self,
        debug: bool = False,
        use_uvloop: bool = False,
        policy: asyncio.AbstractEventLoopPolicy | None = None,
    ):
        self._exceptions: list[BaseException] = []
        _maybe_set_event_loop_policy(policy, use_uvloop)
        loop = asyncio.new_event_loop()
        self._loop = loop
        loop.set_debug(debug)
        loop.set_exception_handler(self._exception_handler)
        asyncio.set_event_loop(loop)

    def _cancel_all_tasks(self) -> None:
        """Cancel every task on the loop and re-raise the first non-cancel failure."""
        pending = all_tasks(self._loop)
        if not pending:
            return

        for task in pending:
            task.cancel()

        # Let all the tasks finish processing their cancellation
        gathered = asyncio.gather(*pending, return_exceptions=True)
        self._loop.run_until_complete(gathered)

        for task in pending:
            if not task.cancelled() and task.exception() is not None:
                raise cast(BaseException, task.exception())

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """Collect ``Exception`` instances; defer everything else to the default."""
        error = context.get("exception")
        if isinstance(error, Exception):
            self._exceptions.append(error)
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        """Re-raise (and forget) the exceptions collected by the exception handler."""
        if not self._exceptions:
            return

        collected, self._exceptions = self._exceptions, []
        if len(collected) == 1:
            raise collected[0]

        raise ExceptionGroup(collected)

    def close(self) -> None:
        """Cancel remaining tasks, shut down async generators and close the loop."""
        loop = self._loop
        try:
            self._cancel_all_tasks()
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """
        Run an async generator fixture, yielding the value it produced.

        The fixture's setup and teardown both run inside one task; teardown starts
        when this generator is resumed after the yield.

        :raises RuntimeError: (from the fixture task) if the async generator yields a
            second time
        """

        async def fixture_runner() -> None:
            agen = fixture_func(**kwargs)
            try:
                value = await agen.asend(None)
                self._raise_async_exceptions()
            except BaseException as exc:
                setup_done.set_exception(exc)
                return
            else:
                setup_done.set_result(value)

            # Stay suspended until the consumer is done with the fixture value
            await teardown_requested.wait()
            try:
                await agen.asend(None)
            except StopAsyncIteration:
                pass
            else:
                await agen.aclose()
                raise RuntimeError("Async generator fixture did not stop")

        loop = self._loop
        setup_done = loop.create_future()
        teardown_requested = asyncio.Event()
        fixture_task = loop.create_task(fixture_runner())
        loop.run_until_complete(setup_done)
        yield setup_done.result()
        teardown_requested.set()
        loop.run_until_complete(fixture_task)
        self._raise_async_exceptions()

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture to completion and return its result."""
        fixture_coro = fixture_func(**kwargs)
        result = self._loop.run_until_complete(fixture_coro)
        self._raise_async_exceptions()
        return result

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a coroutine test, re-raising its failure along with collected ones."""
        try:
            self._loop.run_until_complete(test_func(**kwargs))
        except Exception as exc:
            # Queue the failure so it is raised together with any async exceptions
            self._exceptions.append(exc)

        self._raise_async_exceptions()
|
|