|
from __future__ import annotations |
|
|
|
import socket |
|
from abc import abstractmethod |
|
from contextlib import AsyncExitStack |
|
from io import IOBase |
|
from ipaddress import IPv4Address, IPv6Address |
|
from socket import AddressFamily |
|
from typing import ( |
|
Any, |
|
Callable, |
|
Collection, |
|
Mapping, |
|
Tuple, |
|
TypeVar, |
|
Union, |
|
) |
|
|
|
from .._core._tasks import create_task_group |
|
from .._core._typedattr import ( |
|
TypedAttributeProvider, |
|
TypedAttributeSet, |
|
typed_attribute, |
|
) |
|
from ._streams import ByteStream, Listener, UnreliableObjectStream |
|
from ._tasks import TaskGroup |
|
|
|
# An IP address given either as text or as an ``ipaddress`` object
IPAddressType = Union[str, IPv4Address, IPv6Address]
# A (host, port) pair
IPSockAddrType = Tuple[str, int]
# Either a (host, port) pair or a plain string address
SockAddrType = Union[IPSockAddrType, str]
# A UDP payload paired with the (host, port) address it belongs to
UDPPacketType = Tuple[bytes, IPSockAddrType]
# Generic placeholder for a callable's return type
T_Retval = TypeVar("T_Retval")
|
|
|
|
|
class SocketAttribute(TypedAttributeSet):
    """Typed attribute keys describing the socket behind a stream or listener."""

    #: address family of the underlying socket
    family: AddressFamily = typed_attribute()

    #: local address of the underlying socket
    local_address: SockAddrType = typed_attribute()

    #: local port of the underlying socket (IP based sockets only)
    local_port: int = typed_attribute()

    #: the underlying stdlib socket object
    raw_socket: socket.socket = typed_attribute()

    #: address of the peer the underlying socket is connected to
    remote_address: SockAddrType = typed_attribute()

    #: port of the peer the underlying socket is connected to (IP based sockets only)
    remote_port: int = typed_attribute()
|
|
|
|
|
class _SocketProvider(TypedAttributeProvider):
    """
    Mixin that publishes details of an underlying socket as typed extra attributes.

    Subclasses supply the socket through the abstract ``_raw_socket`` property.
    """

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Map each applicable :class:`~SocketAttribute` key to a getter callable.

        The family, local address and raw socket are always offered. The remote address
        is offered only if ``getpeername()`` succeeds at the time this property is read,
        and the port attributes are offered only for ``AF_INET``/``AF_INET6`` sockets.
        """
        from .._core._sockets import convert_ipv6_sockaddr as convert

        def local_address() -> Any:
            return convert(self._raw_socket.getsockname())

        def local_port() -> Any:
            return self._raw_socket.getsockname()[1]

        extras: dict[Any, Callable[[], Any]] = {}
        extras[SocketAttribute.family] = lambda: self._raw_socket.family
        extras[SocketAttribute.local_address] = local_address
        extras[SocketAttribute.raw_socket] = lambda: self._raw_socket

        # The peer address is looked up right away; an OSError is treated as "no peer"
        try:
            peer: tuple[str, int] | None = convert(self._raw_socket.getpeername())
        except OSError:
            peer = None

        connected = peer is not None
        if connected:
            extras[SocketAttribute.remote_address] = lambda: peer

        # Port numbers are only offered for IP based sockets
        ip_families = (AddressFamily.AF_INET, AddressFamily.AF_INET6)
        if self._raw_socket.family not in ip_families:
            return extras

        extras[SocketAttribute.local_port] = local_port
        if connected:
            peer_port = peer[1]
            extras[SocketAttribute.remote_port] = lambda: peer_port

        return extras

    @property
    @abstractmethod
    def _raw_socket(self) -> socket.socket:
        """The underlying socket object that the extra attributes are read from."""
|
|
|
|
|
class SocketStream(ByteStream, _SocketProvider):
    """
    A byte stream that uses a socket as its transport.

    Every applicable extra attribute from :class:`~SocketAttribute` is supported.
    """
|
|
|
|
|
class UNIXSocketStream(SocketStream):
    """A socket stream that can additionally pass file descriptors to and from the peer."""

    @abstractmethod
    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Deliver a message to the peer together with a set of file descriptors.

        :param message: the bytestring to send (must not be empty)
        :param fds: the files to pass along, each given either as a numeric file descriptor
            or as an open file or socket object

        """

    @abstractmethod
    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Read a message from the peer together with any file descriptors sent with it.

        :param msglen: the expected length of the incoming message
        :param maxfds: the largest number of file descriptors to expect
        :return: a tuple of (message, file descriptors)

        """
|
|
|
|
|
class SocketListener(Listener[SocketStream], _SocketProvider):
    """
    A listener that hands out a :class:`SocketStream` for each incoming socket connection.

    Every applicable extra attribute from :class:`~SocketAttribute` is supported.
    """

    @abstractmethod
    async def accept(self) -> SocketStream:
        """Accept an incoming connection and return its stream."""

    async def serve(
        self,
        handler: Callable[[SocketStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Accept connections in an endless loop, starting ``handler`` in a task for each one.

        :param handler: a callable started with each accepted stream as its sole argument
        :param task_group: the task group in which handler tasks are started (if omitted,
            a new task group is created and kept open for the duration of this call)

        """
        async with AsyncExitStack() as stack:
            group = task_group
            if group is None:
                # No task group supplied: open one that is closed along with the stack
                group = await stack.enter_async_context(create_task_group())

            while True:
                connection = await self.accept()
                group.start_soon(handler, connection)
|
|
|
|
|
class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
    """
    A UDP socket that has not been connected to a particular peer.

    Every applicable extra attribute from :class:`~SocketAttribute` is supported.
    """

    async def sendto(self, data: bytes, host: str, port: int) -> None:
        """
        Send ``data`` to the given host and port.

        This is a shorthand for calling :meth:`~.UnreliableObjectSendStream.send` with
        ``(data, (host, port))``.
        """
        destination = (host, port)
        packet = (data, destination)
        return await self.send(packet)
|
|
|
|
|
class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    A UDP socket that has been connected to a peer.

    Every applicable extra attribute from :class:`~SocketAttribute` is supported.
    """
|
|